Skip to content

Commit

Permalink
bring back initial sample sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszkolodziejczyk committed Feb 19, 2025
1 parent 47c897d commit 4fad14f
Showing 1 changed file with 15 additions and 18 deletions.
33 changes: 15 additions & 18 deletions tests/_local/end_to_end/test_simple_flat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,18 @@

@pytest.mark.parametrize(
"encoding_types",
[
{"a": "AUTO", "b": "AUTO"},
{"a": "LANGUAGE_CATEGORICAL", "b": "LANGUAGE_NUMERIC"},
],
[{"a": "AUTO", "b": "AUTO"}, {"a": "LANGUAGE_CATEGORICAL", "b": "LANGUAGE_NUMERIC"}],
)
def test_simple_flat(tmp_path, encoding_types):
mostly = MostlyAI(local=True, local_dir=tmp_path, quiet=True)

# create mock data
df = pd.DataFrame(
{
"id": range(100),
"a": ["a1", "a2"] * 50,
"b": [1, 2] * 50,
"text": ["c", "d"] * 50,
"id": range(200),
"a": ["a1", "a2"] * 100,
"b": [1, 2] * 100,
"text": ["c", "d"] * 100,
}
)

Expand Down Expand Up @@ -134,30 +131,30 @@ def test_simple_flat(tmp_path, encoding_types):
g.training.logs(tmp_path)

## SYNTHETIC PROBE
df = mostly.probe(g, size=5)
assert len(df) == 5
df = mostly.probe(g, size=10)
assert len(df) == 10

df = mostly.probe(g, seed=pd.DataFrame({"a": ["a1"]}))
assert len(df) == 1
df = mostly.probe(g, seed=pd.DataFrame({"a": ["a1"] * 10}))
assert len(df) == 10

## SYNTHETIC DATASET

# config via sugar
sd = mostly.generate(g, start=False)
assert sd.tables[0].configuration.sample_size == 100
assert sd.tables[0].configuration.sample_size == 200
sd.delete()

# config via dict
config = {"tables": [{"name": "data", "configuration": {"sample_size": 20}}]}
config = {"tables": [{"name": "data", "configuration": {"sample_size": 100}}]}
sd = mostly.generate(g, config=config, start=False)
assert sd.name == "Test 2"
sd_config = sd.config()
assert isinstance(sd_config, SyntheticDatasetConfig)
assert sd_config.tables[0].configuration.sample_size == 20
assert sd_config.tables[0].configuration.sample_size == 100
sd.delete()

# config via class
config = {"tables": [{"name": "data", "configuration": {"sample_size": 20}}]}
config = {"tables": [{"name": "data", "configuration": {"sample_size": 100}}]}
config = SyntheticDatasetConfig(**config)
sd = mostly.generate(g, config=config, start=False)

Expand All @@ -169,15 +166,15 @@ def test_simple_flat(tmp_path, encoding_types):
sd_config = sd.config()
assert isinstance(sd_config, SyntheticDatasetConfig)
assert sd_config.name == "Test 2"
assert sd_config.tables[0].configuration.sample_size == 20
assert sd_config.tables[0].configuration.sample_size == 100

# generate
sd.generation.start()
sd.generation.wait()
assert sd.generation_status == "DONE"
sd.download(tmp_path)
syn = sd.data()
assert len(syn) == 20
assert len(syn) == 100
assert list(syn.columns) == list(df.columns)

# reports
Expand Down

0 comments on commit 4fad14f

Please sign in to comment.