diff --git a/sae_bench/evals/absorption/common.py b/sae_bench/evals/absorption/common.py index 0837069..1c14b77 100644 --- a/sae_bench/evals/absorption/common.py +++ b/sae_bench/evals/absorption/common.py @@ -73,6 +73,7 @@ def load_probe( probe = torch.load( Path(probes_dir) / f"{model_name}" / f"layer_{layer}" / "probe.pth", map_location=device, + weights_only=False, ).to(dtype=dtype) return probe