# pyproject.toml
# Forked from ArthurConmy/Automatic-Circuit-Discovery.
[tool.poetry]
name = "acdc"
# Placeholder: this should be set automatically by the CD pipeline on release.
version = "0.0.0"
description = "ACDC: Automatic Circuit DisCovery implementation on top of TransformerLens"
# One list entry per author, so each person is recorded separately.
authors = ["Arthur Conmy", "Adrià Garriga-Alonso"]
license = "MIT"
readme = "README.md"
# Both top-level packages are included in the distribution.
packages = [
    { include = "acdc" },
    { include = "subnetwork_probing" },
]

[tool.poetry.dependencies]
python = "^3.10"
einops = "^0.7.0"
# Single constraint: a separate `python < 3.10` variant could never apply
# because this project requires Python ^3.10.
numpy = "^1.26"
torch = ">=2.2.0"
datasets = "^2.7.1"
transformers = "^4.37"
tokenizers = "^0.15.0"
tqdm = "^4.66"
pandas = "^2.1.4"
wandb = "^0.16"
torchtyping = "^0.1.4"
huggingface-hub = "^0.24.0"
cmapy = "^0.6.6"
networkx = "^3.1"
plotly = "^5.12.0"
# NOTE(review): exact pin; the reason is not recorded in this file.
kaleido = "0.2.1"
pygraphviz = "^1.11"
# Installed from a fork. NOTE(review): no branch/tag/rev is given, so installs
# are not pinned to a specific commit; consider adding `rev = "..."`.
tracr = { git = "https://github.com/FlyingPumba/tracr.git" }
# NOTE(review): exact pin; the reason is not recorded in this file.
transformer-lens = "1.19.0"
typer = "^0.9.0"
jaxtyping = "^0.2.25"
joblib = "^1.3.2"
hydra-core = "^1.3.2"
icecream = "^2.1.3"
scikit-learn = "^1.4.2"
jsonpickle = "^3.0.4"

[tool.poetry.group.dev.dependencies]
# Testing
pytest = "^7.2.0"
pytest-cov = "^4.0.0"
# Linting and static type checking
pyright = "^1.1.348"
ruff = "^0.2.1"
# Notebooks
jupyter = "^1.0.0"
jupyterlab = "^3.5.0"

[build-system]
# Builds go through poetry-core's backend; it is the only build requirement.
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core"]

[tool.pytest.ini_options]
filterwarnings = [
    # Ignore numpy.distutils deprecation warning caused by pandas
    # More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils
    "ignore:distutils Version classes are deprecated:DeprecationWarning",
]
markers = [
    "slow: marks tests as slow (deselect with '-m \"not slow\"')",
]
# Paths pytest collects from when none are given on the command line.
# Written as a native array rather than a whitespace-separated string.
testpaths = [
    "acdc",
    "subnetwork_probing",
    "experiments",
    "notebooks",
    "tests",
]
# Use jaxtyping runtime checks when using pytest; see https://docs.kidger.site/jaxtyping/api/runtime-type-checking/#pytest-hook
# nice feature, but it slows down some tests a lot, hence commented out
# addopts = "--jaxtyping-packages=acdc,beartype.beartype"

[tool.ruff]
line-length = 120
# Skip everything under submodules/ in addition to ruff's default excludes.
extend-exclude = ["submodules/"]

[tool.ruff.lint]
# Enable the isort rules on top of the default rule set.
extend-select = ["I"]
# The jaxtyping annotations cause F722, so that rule is turned off.
ignore = ["F722"]

[tool.pyright]
# Paths pyright skips. Entries are sorted within each group.
exclude = [
    # other people's code
    "submodules",
    "subnetwork_probing/transformer_lens",

    # code that is part of the ACDC repo, but has not been cleaned up yet
    "acdc/tracr_task",
    "experiments",
    "ims",
    "notebooks",
    "tests/subnetwork_probing/run_edge_sp_tests.py",

    # these ones below should also be checked in the future
    # ("acdc/TLACDCEdge.py" was commented out of this list, so it is checked;
    # "acdc/TLACDCExperiment.py" is still excluded)
    "acdc/TLACDCExperiment.py",
    "acdc/acdc_graphics.py",
    "acdc/acdc_utils.py",
    "acdc/docstring/prompts.py",
    "acdc/docstring/utils.py",
    "acdc/global_cache.py",
    "acdc/greaterthan/utils.py",
    "acdc/induction/utils.py",
    "acdc/ioi/ioi_dataset.py",
    "acdc/ioi/utils.py",
    "acdc/logic_gates/utils.py",
    "acdc/main.py",
    "acdc/nudb/adv_opt/analysis/2024_02_28_scratch.py",
    "acdc/nudb/adv_opt/analysis/2024_w16_4_plot_heatmaps.py",
    "subnetwork_probing/create_reset_networks.py",
    "subnetwork_probing/launch_all_edge_sp.py",
    "subnetwork_probing/sp_utils.py",
    "subnetwork_probing/train.py",
    "subnetwork_probing/train_edge_sp.py",
    "tests/acdc/test_acdc.py",
]