-
Notifications
You must be signed in to change notification settings - Fork 95
/
Copy pathSymbolicRegressionEnzymeExt.jl
61 lines (51 loc) · 1.92 KB
/
SymbolicRegressionEnzymeExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
module SymbolicRegressionEnzymeExt
using SymbolicRegression.LossFunctionsModule: eval_loss
using DynamicExpressions:
AbstractExpression,
AbstractExpressionNode,
get_scalar_constants,
set_scalar_constants!,
extract_gradient,
with_contents,
get_contents
using ADTypes: AutoEnzyme
using Enzyme: autodiff, Reverse, Active, Const, Duplicated
import SymbolicRegression.ConstantOptimizationModule: GradEvaluator
# Build a `GradEvaluator` for the Enzyme backend.
#
# Alongside the wrapped objective `f`, this stores a copy of `f.tree`, the
# scalar-constant references into that copy, and a deep copy of `f.dataset`.
# The call method for this backend hands these copies to Enzyme as the shadow
# halves of its `Duplicated` arguments.
function GradEvaluator(f::F, backend::AE) where {F,AE<:AutoEnzyme}
    shadow_tree = copy(f.tree)
    # Only the references are kept; the constant values themselves are discarded.
    (_, shadow_refs) = get_scalar_constants(shadow_tree)
    # TODO: It is super inefficient to deepcopy; how can we skip this
    shadow_dataset = deepcopy(f.dataset)
    extra = (;
        storage_tree=shadow_tree,
        storage_refs=shadow_refs,
        storage_dataset=shadow_dataset,
    )
    return GradEvaluator(f, backend, extra)
end
# Evaluate the unregularized loss of `tree` on `dataset` (restricted by `idx`)
# and store it in the one-slot container `output`.
#
# The result is written through `output` and the function itself returns
# `nothing`; the caller passes `output` to Enzyme as a `Duplicated` argument.
function evaluator(tree, dataset, options, idx, output)
    loss = eval_loss(tree, dataset, options; regularization=false, idx)
    output[] = loss
    return nothing
end
# Run the zero-argument callable `f` on a fresh `Task` created with `n` as its
# reserved stack size, block until it finishes, and return its result.
#
# Taking `f` first allows `with_stacksize(n) do ... end`.
function with_stacksize(f::F, n) where {F}
    task = Task(f, n)
    schedule(task)
    return fetch(task)
end
# Loss-and-gradient callback for the Enzyme backend.
#
# Arguments:
# - first positional argument: ignored.
# - `G`: either `nothing`, or a buffer that is overwritten in place with the
#   scalar constants read back from the shadow tree after differentiation.
# - `x`: vector of scalar constants to load into the wrapped tree.
#
# Returns the loss value that `evaluator` wrote into the output cell.
function (g::GradEvaluator{<:Any,<:AutoEnzyme})(_, G, x::AbstractVector{T}) where {T}
    objective = g.f
    shadow = g.extra

    # Load the candidate constants into the primal tree and reset every shadow
    # buffer to zero before differentiating.
    set_scalar_constants!(objective.tree, x, objective.refs)
    set_scalar_constants!(shadow.storage_tree, zero(x), shadow.storage_refs)
    fill!(shadow.storage_dataset, 0)

    # One-element vectors: the loss slot and its shadow, which starts at one.
    loss_cell = [zero(T)]
    seed_cell = [one(T)]

    run_reverse_pass = function ()
        return autodiff(
            Reverse,
            evaluator,
            Duplicated(objective.tree, shadow.storage_tree),
            Duplicated(objective.dataset, shadow.storage_dataset),
            Const(objective.options),
            Const(objective.idx),
            Duplicated(loss_cell, seed_cell),
        )
    end
    # Differentiate on a task with a 32 * 1024 * 1024 reserved stack
    # (presumably to avoid stack overflow in the differentiated code; confirm).
    with_stacksize(run_reverse_pass, 32 * 1024 * 1024)

    G === nothing && return loss_cell[]

    # TODO: This is redundant since we already have the references.
    # Should just be able to extract from the references directly.
    G .= first(get_scalar_constants(shadow.storage_tree))
    return loss_cell[]
end
end