Skip to content

Commit

Permalink
Merge pull request #40 from LuxDL/ap/patch_cuda_tracker
Browse files Browse the repository at this point in the history
Workaround CuPtr issue in Tracker
  • Loading branch information
avik-pal authored Oct 10, 2023
2 parents 79928f9 + 400ed83 commit f3609c9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion lib/LuxLib/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.3.7"
version = "0.3.8"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
5 changes: 4 additions & 1 deletion lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector),
end
end

__make_nothing(x) = x
__make_nothing(::CuPtr{Nothing}) = 0

@grad function LuxLib.batchnorm_cudnn(running_mean, running_var, scale, bias, x, momentum,
eps, training)
y, xmean, xivar = batchnorm_cudnn(data(running_mean), data(running_var), data(scale),
Expand All @@ -47,7 +50,7 @@ end
data(running_mean), data(running_var), xmean, xivar; ϵ=eps)
return (nothing, nothing, ∂g, ∂b, ∂x, nothing, nothing, nothing)
end
return (y, xmean, xivar), ∇batchnorm_cudnn_internal
return (y, __make_nothing(xmean), __make_nothing(xivar)), ∇batchnorm_cudnn_internal
end

end

0 comments on commit f3609c9

Please sign in to comment.