Skip to content

Commit

Permalink
update train test
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Feb 18, 2025
1 parent 4da6f3d commit 92ea510
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions tests/operations/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dacapo.store.create_store import create_stats_store
from ..fixtures import *

from dacapo.experiments import Run
from dacapo.experiments import RunConfig
from dacapo.store.create_store import create_config_store, create_weights_store
from dacapo.train import train_run

Expand All @@ -26,7 +26,7 @@
)
def test_large(
options,
run_config,
run_config: RunConfig,
):
# create a store

Expand All @@ -37,21 +37,32 @@ def test_large(
# store the configs

store.store_run_config(run_config)
run = Run(run_config)
run = run_config

# -------------------------------------

# train

weights_store.store_weights(run, 0)
train_run(run)
train_run(run, validate=False)

init_weights = weights_store.retrieve_weights(run.name, 0)
final_weights = weights_store.retrieve_weights(run.name, run.train_until)

weight_diffs = []
for name, weight in init_weights.model.items():
weight_diff = (weight - final_weights.model[name]).sum()
assert abs(weight_diff) > np.finfo(weight_diff.numpy().dtype).eps, weight_diff
weight_diffs.append((name, weight_diff))

# test that some weights have changed instead of all weights have changed
# depending on the design of the task and the model, some weights might
# not get updated (dummy run for example)
assert any(
[
abs(weight_diff) > np.finfo(weight_diff.numpy().dtype).eps
for _, weight_diff in weight_diffs
]
), weight_diffs

# assert train_stats and validation_scores are available

Expand Down

0 comments on commit 92ea510

Please sign in to comment.