Reviewer notes (free text before the first "diff --git" line is skipped by
patch tools):

* unit_scaling/_modules.py: GELU.__init__ no longer takes `approximate` and
  calls super().__init__() with no arguments.
  NOTE(review): a first positional argument, e.g. GELU("tanh"), would now bind
  to `constraint`; confirm that no caller passes `approximate`.
* unit_scaling/utils.py: trace() now also creates `_tracer_extras` when the
  traced graph has no such attribute, not only when the attribute is None.
* NOTE(review): this patch was recovered from a copy with collapsed whitespace.
  Line counts match the hunk headers, but leading indentation, the wrap point
  of the two-line docstring and the spacing before "# type: ignore" are
  reconstructed; verify against the target files before applying.

diff --git a/unit_scaling/_modules.py b/unit_scaling/_modules.py
index 2918d55..69e882f 100644
--- a/unit_scaling/_modules.py
+++ b/unit_scaling/_modules.py
@@ -25,10 +25,9 @@ class GELU(nn.GELU):
 
     def __init__(
         self,
-        approximate: str = "none",
         constraint: Optional[BinaryConstraint] = gmean,
     ) -> None:
-        super().__init__(approximate)
+        super().__init__()
         self.constraint = constraint
 
     def forward(self, input: Tensor) -> Tensor:
diff --git a/unit_scaling/utils.py b/unit_scaling/utils.py
index 4a852f5..a05ce9f 100644
--- a/unit_scaling/utils.py
+++ b/unit_scaling/utils.py
@@ -222,7 +222,7 @@ def trace(
         """Adds the `target_to_function` dict to the graph so the interpreter can use
         it downstream."""
         graph = super().trace(root, concrete_args)
-        if graph._tracer_extras is None:
+        if not hasattr(graph, "_tracer_extras") or graph._tracer_extras is None:
             graph._tracer_extras = {}
         graph._tracer_extras["target_to_function"] = self.target_to_function
         return graph  # type: ignore