Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fastmath flag to PTXCompilerTarget #492

Merged
merged 2 commits into from
Aug 14, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ Base.@kwdef struct PTXCompilerTarget <: AbstractCompilerTarget
blocks_per_sm::Union{Nothing,Int} = nothing
maxregs::Union{Nothing,Int} = nothing

fastmath::Bool = false

# deprecated; remove with next major version
exitable::Union{Nothing,Bool} = nothing
unreachable::Union{Nothing,Bool} = nothing
Expand All @@ -33,6 +35,7 @@ function Base.hash(target::PTXCompilerTarget, h::UInt)
h = hash(target.maxthreads, h)
h = hash(target.blocks_per_sm, h)
h = hash(target.maxregs, h)
h = hash(target.fastmath, h)

h
end
Expand Down Expand Up @@ -81,6 +84,7 @@ function Base.show(io::IO, @nospecialize(job::CompilerJob{PTXCompilerTarget}))
job.config.target.maxthreads !== nothing && print(io, ", maxthreads=$(job.config.target.maxthreads)")
job.config.target.blocks_per_sm !== nothing && print(io, ", blocks_per_sm=$(job.config.target.blocks_per_sm)")
job.config.target.maxregs !== nothing && print(io, ", maxregs=$(job.config.target.maxregs)")
job.config.target.fastmath && print(io, ", fast math enabled")
end

const ptx_intrinsics = ("vprintf", "__assertfail", "malloc", "free")
Expand Down Expand Up @@ -423,7 +427,7 @@ function nvvm_reflect!(fun::LLVM.Function)
# handle possible cases
# XXX: put some of these property in the compiler job?
# and/or first set the "nvvm-reflect-*" module flag like Clang does?
fast_math = Base.JLOptions().fast_math == 1
fast_math = current_job.config.target.fastmath
# NOTE: we follow nvcc's --use_fast_math
reflect_val = if reflect_arg == "__CUDA_FTZ"
# single-precision denormals support
Expand All @@ -432,7 +436,7 @@ function nvvm_reflect!(fun::LLVM.Function)
# single-precision floating-point division and reciprocals.
ConstantInt(reflect_typ, fast_math ? 0 : 1)
elseif reflect_arg == "__CUDA_PREC_SQRT"
# single-precision denormals support
# single-precision floating point square roots.
ConstantInt(reflect_typ, fast_math ? 0 : 1)
elseif reflect_arg == "__CUDA_FMAD"
# contraction of floating-point multiplies and adds/subtracts into
Expand Down