Skip to content

Commit

Permalink
update defaults for builds
Browse files Browse the repository at this point in the history
  • Loading branch information
nix-apollo committed Apr 19, 2024
1 parent 2ef90e4 commit e15aa71
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 55 deletions.
4 changes: 2 additions & 2 deletions tests/test_build_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def test_naive_gradient_flow_interface(run_type, use_out_dir, tmpdir):
}
)
results = graph_build_test(config=config, atol=atol)
get_rib_acts_test(results, atol=0) # Need atol=1e-3 if float32
get_rib_acts_test(results, atol=1e-15) # Need atol=1e-3 if float32
if use_out_dir:
# Full run saves both Cs and graph (this is only true for NGF)
assert Cs_path.exists()
Expand Down Expand Up @@ -295,7 +295,7 @@ def test_modular_arithmetic_build_graph(
}
)
results = graph_build_test(config=config, atol=atol)
get_rib_acts_test(results, atol=0) # Need atol=1e-3 if float32
get_rib_acts_test(results, atol=1e-15) # Need atol=1e-3 if float32


@pytest.mark.slow
Expand Down
61 changes: 8 additions & 53 deletions tests/test_float_precision.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Test that float32 and float64 results match.
Note that they do not match everywhere yet."""
Note that they do not match everywhere yet. We normally run on float64 for this reason."""

import json
import tempfile
Expand All @@ -11,6 +11,7 @@
from rib.ablations import AblationConfig, load_bases_and_ablate
from rib.log import logger
from rib.rib_builder import RibBuildConfig, RibBuildResults, rib_build
from tests.utils import get_pythia_config


@pytest.mark.slow
Expand All @@ -23,54 +24,12 @@ def temp_object(self) -> tempfile.TemporaryDirectory:
@pytest.fixture(scope="class")
def rib_results(self, temp_object) -> dict[str, RibBuildResults]:
"""Run RIB build with float32 and float64 and return the results keyed by dtype."""
rib_config_str = """
tlens_pretrained: pythia-14m
tlens_model_path: null
dataset:
dataset_type: huggingface
name: NeelNanda/pile-10k
tokenizer_name: EleutherAI/pythia-14m
return_set: train # pile-10k only has train, so we take the first 90% for building and last 10% for ablations
return_set_frac: null
n_documents: 30
n_samples: 3
return_set_portion: first
node_layers:
- mlp_out.0
- ln2.3
- mlp_out.3
- mlp_out.5
- output
batch_size: 4 # A100 can handle 24
gram_batch_size: 20 # A100 can handle 80
truncation_threshold: 1e-6
rotate_final_node_layer: false
n_intervals: 0
calculate_edges: false
eval_type: null
seed: 42
"""
rib_config = yaml.safe_load(rib_config_str)
temp_dir = temp_object.name
rib_config["out_dir"] = temp_dir

rib_results = {}
for dtype in ["float32", "float64"]:
exp_name = f"float-precision-test-pythia-14m-{dtype}"
rib_config["dtype"] = dtype
rib_config["exp_name"] = exp_name
if not torch.cuda.is_available():
# Try to reduce memory usage for CI
rib_config["batch_size"] = 1
rib_config["gram_batch_size"] = 1
logger.info(
("Running RIB build with batch size", rib_config["batch_size"], "for", dtype)
)
rib_build(RibBuildConfig(**rib_config))
interaction_rotations = RibBuildResults(
**torch.load(f"{temp_dir}/float-precision-test-pythia-14m-{dtype}_rib_Cs.pt")
)
rib_results[dtype] = interaction_rotations
config = get_pythia_config({"dtype": dtype, "out_dir": temp_dir, "exp_name": exp_name})
rib_results[dtype] = rib_build(config)

return rib_results

Expand Down Expand Up @@ -147,10 +106,7 @@ def ablation_results(self, temp_object, rib_results) -> dict:
n_samples: 3
return_set_portion: first
ablation_node_layers:
- mlp_out.0
- ln2.3
- mlp_out.3
- mlp_out.5
- ln2.1
batch_size: 30 # A100 can handle 60
eval_type: ce_loss
seed: 42
Expand Down Expand Up @@ -185,12 +141,11 @@ def ablation_results(self, temp_object, rib_results) -> dict:

return ablation_results

@pytest.mark.xfail(
reason="Insufficient precision. See https://github.com/ApolloResearch/rib/issues/212)"
)
def test_ablation_result_float_precision(self, ablation_results: dict) -> None:
# ln2.3 (and others) are broken (https://github.com/ApolloResearch/rib/issues/212)
# ln1.- are broken. ln1.0 seemed fine on GPU (a6000) but broken on CPU
for node_layer in ablation_results["float32"].keys():
if node_layer in ["ln2.3", "ln1.5", "ln1.0"]:
continue
for n_vecs_ablated in ablation_results["float32"][node_layer].keys():
float32_ablation_result = ablation_results["float32"][node_layer][n_vecs_ablated]
float64_ablation_result = ablation_results["float64"][node_layer][n_vecs_ablated]
Expand Down
10 changes: 10 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def get_modular_arithmetic_config(*updates: dict) -> RibBuildConfig:
basis_formula: (1-0)*alpha
edge_formula: squared
n_stochastic_sources_edges: null
integration_method: gauss-legendre
center: false
"""
config_dict = deep_update(yaml.safe_load(config_str), *updates)
return RibBuildConfig(**config_dict)
Expand Down Expand Up @@ -71,6 +73,8 @@ def get_pythia_config(*updates: dict) -> RibBuildConfig:
eval_type: ce_loss
out_dir: null
basis_formula: (1-0)*alpha
integration_method: gauss-legendre
center: false
"""
config_dict = deep_update(yaml.safe_load(config_str), *updates)
return RibBuildConfig(**config_dict)
Expand Down Expand Up @@ -109,6 +113,8 @@ def get_tinystories_config(*updates: dict) -> RibBuildConfig:
eval_type: ce_loss
basis_formula: jacobian
edge_formula: squared
integration_method: gauss-legendre
center: false
"""
config_dict = deep_update(yaml.safe_load(config_str), *updates)
return RibBuildConfig(**config_dict)
Expand Down Expand Up @@ -136,6 +142,8 @@ def get_mnist_config(*updates: dict) -> RibBuildConfig:
out_dir: null
basis_formula: (1-0)*alpha
edge_formula: squared
integration_method: gauss-legendre
center: false
"""
config_dict = deep_update(yaml.safe_load(config_str), *updates)
return RibBuildConfig(**config_dict)
Expand Down Expand Up @@ -171,6 +179,8 @@ def get_modular_mlp_config(*updates: dict) -> RibBuildConfig:
rotate_final_node_layer: false
basis_formula: (1-0)*alpha
edge_formula: squared
integration_method: gauss-legendre
center: false
"""
config_dict = deep_update(yaml.safe_load(config_str), *updates)
config = RibBuildConfig(**config_dict)
Expand Down

0 comments on commit e15aa71

Please sign in to comment.