Skip to content

Commit 12c4c60

Browse files
authored
Merge pull request #139 from GianlucaFuwa/master
Fix poor performance when one of `batch`'s arguments is a `Type{...}`
2 parents c0ef9e1 + 3038f82 commit 12c4c60

File tree

3 files changed

+48
-29
lines changed

3 files changed

+48
-29
lines changed

src/batch.jl

+15-22
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# S is a Val{Bool} indicating whether we will need to load the thread index
22
# C is a Tuple{...} containing the types of the reduction variables
33
struct BatchClosure{F,A,S,C}
4-
f::F
4+
f::F
55
end
66
function (b::BatchClosure{F,A,S,C})(p::Ptr{UInt}) where {F,A,S,C}
77
(offset, args) = ThreadingUtilities.load(p, A, 2 * sizeof(UInt))
@@ -130,7 +130,7 @@ end
130130
start,
131131
stop,
132132
i,
133-
reductup
133+
reductup,
134134
) do p, fptr, argtup, start, stop, i, reductup
135135
setup_batch!(p, fptr, argtup, start, stop, i, reductup)
136136
end
@@ -185,12 +185,9 @@ end
185185
# nthread_total = sum(nthread_tuple)
186186
Ndp = Nd + one(Nd)
187187
end
188-
C !== 0 && push!(
189-
q.args,
190-
quote
191-
@nexprs $C j -> RVAR_j = reducinits[j]
192-
end
193-
)
188+
C !== 0 && push!(q.args, quote
189+
@nexprs $C j -> RVAR_j = reducinits[j]
190+
end)
194191
launch_quote = if S
195192
if C === 0
196193
:(launch_batched_thread!(cfunc, tid, argtup, start, stop, tid % UInt))
@@ -211,27 +208,23 @@ end
211208
if C !== 0
212209
push!(
213210
rem_quote.args,
214-
:(@nexprs $C j -> RVAR_j = reducops[j](RVAR_j, thread_results[j]))
211+
:(@nexprs $C j -> RVAR_j = reducops[j](RVAR_j, thread_results[j])),
215212
)
216213
end
217214
update_retv = if C === 0
218-
Expr(:block)
215+
Expr(:block)
219216
else
220217
quote
221218
thread_results = load_threadlocals(tid, argtup, needtid, reducinits)
222219
@nexprs $C j -> RVAR_j = reducops[j](RVAR_j, thread_results[j])
223220
end
224221
end
225222
ret_quote = Expr(:return)
226-
if C === 0
227-
push!(ret_quote.args, nothing)
228-
else
229-
redtup = Expr(:tuple)
230-
for j in 1:C
231-
push!(redtup.args, Symbol("RVAR_", j))
232-
end
233-
push!(ret_quote.args, redtup)
223+
redtup = Expr(:tuple)
224+
for j 1:C
225+
push!(redtup.args, Symbol("RVAR_", j))
234226
end
227+
push!(ret_quote.args, redtup)
235228

236229
block = quote
237230
start = zero(UInt)
@@ -321,16 +314,16 @@ end
321314
@label SERIAL
322315
if S
323316
if C === 0
324-
reducres = f!(args, one(Int), ulen % Int, 1)
325-
return reducres
317+
f!(args, one(Int), ulen % Int, 1)
318+
return ()
326319
else
327320
reducres = f!(args, one(Int), ulen % Int, 1, reducinits)
328321
return reducres
329322
end
330323
else
331324
if C === 0
332-
reducres = f!(args, one(Int), ulen % Int)
333-
return reducres
325+
f!(args, one(Int), ulen % Int)
326+
return ()
334327
else
335328
reducres = f!(args, one(Int), ulen % Int, reducinits)
336329
return reducres

src/closure.jl

+13-1
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,13 @@ Base.@propagate_inbounds combine(x::AbstractArray, I, j) =
207207
x[combine(CombineIndices(), I, j)]
208208
Base.@propagate_inbounds combine(x::AbstractArray, ::NoLoop, j) = x[j]
209209

210+
struct WrapType{T} end
211+
wrap_type(@nospecialize(x)) = x
212+
wrap_type(::Type{T}) where {T} = WrapType{T}()
213+
214+
unwrap_type(@nospecialize(x)) = x
215+
unwrap_type(::WrapType{T}) where {T} = T
216+
210217
function makestatic!(expr)
211218
expr isa Expr || return expr
212219
for i in eachindex(expr.args)
@@ -421,9 +428,11 @@ function enclose(exorig::Expr, minbatchsize, per, threadlocal, reduction, stride
421428
loop_stop_expr =
422429
:(var"##SUBSTOP##" * var"##LOOP_STEP##" + var"##LOOPOFFSET##" - var"##LOOP_STEP##")
423430
end
431+
unwrap_args = Expr(:block)
424432
closureq = quote
425433
$closure = let
426434
@inline $closure_args -> begin
435+
$unwrap_args
427436
local var"##STEP##" = $(
428437
stride ?
429438
:($loop_step * min(Threads.nthreads()::Int, Sys.CPU_THREADS::Int)) :
@@ -470,12 +479,15 @@ function enclose(exorig::Expr, minbatchsize, per, threadlocal, reduction, stride
470479
end
471480
for a arguments
472481
push!(args.args, get(defined, a, a))
473-
push!(batchcall.args, esc(a))
482+
push!(batchcall.args, :($wrap_type($(esc(a)))))
474483
end
475484
if threadlocal_val !== Symbol("")
476485
push!(args.args, threadlocal_accum)
477486
push!(batchcall.args, esc(threadlocal_accum))
478487
end
488+
for a in args.args
489+
push!(unwrap_args.args, :($a = $unwrap_type($a)))
490+
end
479491
push!(q.args, batchcall)
480492
quote
481493
var"##NUM#THREADS##" = $(Threads.nthreads)()

test/runtests.jl

+20-6
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,12 @@ function issue25_but_with_strides!(dest, x, y)
112112
dest
113113
end
114114

115+
function issue116!(y::Vector{T}, x::Vector{T}) where {T}
116+
@batch for i in 1:length(x)
117+
y[i] = exp(x[i] + one(T))
118+
end
119+
end
120+
115121

116122
@testset "Range Map" begin
117123

@@ -297,6 +303,15 @@ Base.eachindex(e::Iterators.Enumerate{LazyTree{T}}) where {T} = eachindex(e.itr)
297303
evt = 5
298304
end
299305

306+
@testset "not-specializing-on-type heuristics" begin
307+
allocated(f::F, args...) where {F} = @allocated f(args...)
308+
x = rand(10000)
309+
y = similar(x)
310+
allocated(issue116!, y, x)
311+
@test y exp.(x .+ 1.0)
312+
@test allocated(issue116!, y, x) == 0
313+
end
314+
300315
@testset "threadlocal storage" begin
301316
local1 = let
302317
@batch threadlocal = 0 for i = 0:9
@@ -338,15 +353,13 @@ end
338353
end
339354
@test local1 == local2 == local3 == local4 == local5 == local6
340355
# check that each thread has a separate init
341-
myvar = 0
342-
myinitC() = myvar += 1
343356
inits = let
344-
@batch threadlocal = myinitC() for i = 0:9
357+
@batch threadlocal = rand() for i = 0:9
345358
threadlocal += 1
346359
end
347360
threadlocal
348361
end
349-
@test length(inits) == 1 || inits[1] != inits[end] # this test has a race condition and can (rarely) fail
362+
@test length(inits) == 1 || inits[1] != inits[end]
350363
# check that types are respected
351364
myinitD() = Float16(1.0)
352365
settingtype = let
@@ -461,14 +474,15 @@ end
461474
# check for name interference with threadlocal (used to error on single threaded runs)
462475
function f()
463476
n = 1000
464-
threadlocal = false
477+
threadlocal = 0
465478
@batch minbatch = 10 reduction = (+,threadlocal) for i = 1:n
466-
threadlocal += true
479+
threadlocal += 1
467480
end
468481
return threadlocal
469482
end
470483
allocated(f::F) where {F} = @allocated f()
471484
inferred(f::F) where {F} = try @inferred f(); true catch; false end
485+
allocated(f)
472486
@test allocated(f) == 0
473487
@test inferred(f) == true
474488
# remaining supported operations

0 commit comments

Comments
 (0)