Synthetic anomaly for testing and validation #634

Merged: 110 commits, merged on Dec 14, 2022

Commits (changes shown from 109 of 110 commits)
ee1cfce
move sample generation to datamodule instead of dataset
djdameln Sep 9, 2022
ec5199e
move sample generation from init to setup
djdameln Sep 12, 2022
9f0a35e
remove inference stage and add base classes
djdameln Sep 13, 2022
dea176f
replace dataset classes with AnomalibDataset
djdameln Sep 13, 2022
62a04f8
move setup to base class, create samples as class method
djdameln Sep 13, 2022
e91afad
update docstrings
djdameln Sep 13, 2022
df4a805
refactor btech to new format
djdameln Sep 14, 2022
c225a83
allow training with no anomalous data
djdameln Sep 14, 2022
ac0dc8a
remove MVTec name from comment
djdameln Sep 15, 2022
5d90209
raise NotImplementedError in base class
djdameln Sep 15, 2022
c1e6724
allow both png and bmp images for btech
djdameln Sep 15, 2022
2d70d89
use label_index to check if dataset contains anomalous images
djdameln Sep 16, 2022
f5f17db
refactor getitem in dataset class
djdameln Sep 16, 2022
f02065f
use iloc for indexing
djdameln Sep 16, 2022
9cba9da
move dataloader getters to base class
djdameln Sep 16, 2022
5b3e841
refactor to add validate stage in setup
djdameln Sep 16, 2022
f652227
implement alternative datamodules solution
djdameln Sep 21, 2022
0e565a4
small improvements
djdameln Sep 21, 2022
297195a
improve design
djdameln Oct 7, 2022
94cabb7
remove unused constructor arguments
djdameln Oct 7, 2022
1ee8a96
adapt btech to new design
djdameln Oct 7, 2022
7fc5483
add prepare_data method for mvtec
djdameln Oct 7, 2022
8a9a30c
solve merge conflicts
djdameln Oct 7, 2022
1ac7c65
implement more generic random splitting function
djdameln Oct 10, 2022
965ea94
update docstrings for folder module
djdameln Oct 10, 2022
2a9f6f8
ensure type consistency when performing operations on dataset
djdameln Oct 10, 2022
84997b9
change imports
djdameln Oct 10, 2022
f21c652
change variable names
djdameln Oct 10, 2022
ab7d0ff
replace pass with NotImplementedError
djdameln Oct 10, 2022
d7e47a9
allow training on folder without test images
djdameln Oct 11, 2022
da851c6
use relative path for normal_test_dir
djdameln Oct 11, 2022
f3e38ba
fix dataset tests
djdameln Oct 11, 2022
f4719f2
update validation set parameter in configs
djdameln Oct 11, 2022
e25a587
change default argument
djdameln Oct 11, 2022
1170ca3
Merge branch 'main' into da/datamodules-alternative
djdameln Oct 11, 2022
fb84cd1
use setter for samples
djdameln Oct 12, 2022
cfa4f52
hint options for val_split_mode
djdameln Oct 12, 2022
624e522
update assert message and docstring
djdameln Oct 12, 2022
0bd77f9
revert name change dataset vs datamodule
djdameln Oct 12, 2022
6bed98f
typing and docstrings
djdameln Oct 12, 2022
fc34f8e
remove samples argument from dataset constructor
djdameln Oct 12, 2022
1482c13
val/test -> eval
djdameln Oct 12, 2022
e168163
remove Split.Full from enum
djdameln Oct 13, 2022
5071dcf
sort samples when setting
djdameln Oct 13, 2022
e175d7d
update warn message
djdameln Oct 13, 2022
03773b0
formatting
djdameln Oct 13, 2022
3910c32
use setter when creating samples in dataset classes
djdameln Oct 13, 2022
894ef12
add tests for new dataset class
djdameln Oct 13, 2022
44009e2
add test case for label aware random split
djdameln Oct 13, 2022
012ed47
update parameter name in inferencers
djdameln Oct 14, 2022
62b176e
move _setup implementation to base class
djdameln Oct 14, 2022
7e957b6
address codacy issues
djdameln Oct 14, 2022
25f503d
fix pylint issues
djdameln Oct 14, 2022
1245928
codacy
djdameln Oct 14, 2022
d9bd6e0
Merge branch 'main' into da/datamodules-alternative
djdameln Oct 14, 2022
0459a0d
update example dataset config in docs
djdameln Oct 14, 2022
30dc45a
fix test
djdameln Oct 14, 2022
85c475a
move base classes to separate files (avoid circular import)
djdameln Oct 14, 2022
0552c1a
add synthetic dataset class
djdameln Oct 14, 2022
bf4f537
move augmenter to data directory
djdameln Oct 14, 2022
cc32896
add base classes
djdameln Oct 14, 2022
23d4766
update docstring
djdameln Oct 14, 2022
b06fc63
Merge branch 'da/datamodules-alternative' into da/synthetic-validatio…
djdameln Oct 14, 2022
05ba31d
use synthetic dataset in base datamodule
djdameln Oct 14, 2022
e8d7998
fix imports
djdameln Oct 14, 2022
26b6b83
clean up synthetic anomaly dataset implementation
djdameln Oct 17, 2022
c32fee9
fix mistake in augmenter
djdameln Oct 17, 2022
e120434
change default split ratio
djdameln Oct 17, 2022
14ee645
remove accidentally added file
djdameln Oct 17, 2022
9c4e7bf
validation_split_mode -> val_split_mode
djdameln Oct 18, 2022
e5d22aa
Merge branch 'da/datamodules-alternative' into da/synthetic-validatio…
djdameln Oct 18, 2022
067d601
update docs
djdameln Oct 19, 2022
c84c99c
Update anomalib/data/base/dataset.py
djdameln Oct 21, 2022
b680d44
get length from self.samples
djdameln Oct 21, 2022
95c37b0
assert unique indices
djdameln Oct 21, 2022
3e77014
check is_setup for individual datasets
djdameln Oct 21, 2022
ede213a
remove assert in __getitem__
djdameln Oct 21, 2022
f5e2d24
Update anomalib/data/btech.py
djdameln Oct 21, 2022
d9e1369
clearer assert message
djdameln Oct 21, 2022
2e6bc60
clarify list inversion in comment
djdameln Oct 21, 2022
af0cd99
comments and typing
djdameln Oct 21, 2022
d508786
Merge branch 'da/datamodules-alternative' of https://github.com/openv…
djdameln Oct 21, 2022
c85713c
Merge branch 'main' into da/datamodules-alternative
djdameln Oct 21, 2022
5ee8480
validate contents of samples dataframe before setting
djdameln Oct 21, 2022
a5e876a
add file paths check
djdameln Oct 21, 2022
c490e30
add seed to random_split function
djdameln Oct 21, 2022
4808287
fix expected columns
djdameln Oct 24, 2022
10bbf9c
fix typo
djdameln Oct 24, 2022
81d3ca3
add seed parameter to datamodules
djdameln Oct 28, 2022
b372dd1
set global seed in test entrypoint
djdameln Oct 28, 2022
e07a12c
add NONE option to valsplitmode
djdameln Oct 28, 2022
ffdb47c
clarify setup behaviour in docstring
djdameln Oct 28, 2022
9523ad0
merge latest changes to datamodules
djdameln Nov 4, 2022
622f1b9
merge feature branch
djdameln Nov 4, 2022
79e09e9
Merge branch 'feature/datamodules' into da/synthetic-validation-set
djdameln Nov 28, 2022
9fe6da3
Merge branch 'feature/datamodules' into da/synthetic-validation-set
djdameln Dec 5, 2022
63801a2
add logging message
djdameln Dec 5, 2022
d482cec
Merge branch 'feature/datamodules' into da/synthetic-validation-set
djdameln Dec 5, 2022
74cbc0a
use val_split_ratio for synthetic validation set
djdameln Dec 6, 2022
090cec2
pathlib
djdameln Dec 6, 2022
5f16140
merge feature branch
djdameln Dec 6, 2022
2a8df7b
make synthetic anomaly available for test set
djdameln Dec 9, 2022
ea00442
update configs
djdameln Dec 9, 2022
dfd2d80
add tests
djdameln Dec 9, 2022
ce43e09
simplify test set splitting logic
djdameln Dec 9, 2022
8b2d356
update docstring
djdameln Dec 12, 2022
a126af1
add missing licence
djdameln Dec 12, 2022
b2879c8
split_normal_and_anomalous -> split_by_label
djdameln Dec 12, 2022
532ff8b
VideoAnomalib -> AnomalibVideo
djdameln Dec 12, 2022
b1d7eb1
merge feature branch
djdameln Dec 14, 2022
7 changes: 6 additions & 1 deletion anomalib/data/__init__.py
@@ -44,6 +44,8 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> AnomalibDataModule:
task=config.dataset.task,
transform_config_train=config.dataset.transform_config.train,
transform_config_eval=config.dataset.transform_config.eval,
test_split_mode=config.dataset.test_split_mode,
test_split_ratio=config.dataset.test_split_ratio,
val_split_mode=config.dataset.val_split_mode,
val_split_ratio=config.dataset.val_split_ratio,
)
@@ -58,6 +60,8 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> AnomalibDataModule:
task=config.dataset.task,
transform_config_train=config.dataset.transform_config.train,
transform_config_eval=config.dataset.transform_config.eval,
test_split_mode=config.dataset.test_split_mode,
test_split_ratio=config.dataset.test_split_ratio,
val_split_mode=config.dataset.val_split_mode,
val_split_ratio=config.dataset.val_split_ratio,
)
@@ -70,13 +74,14 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> AnomalibDataModule:
normal_test_dir=config.dataset.normal_test_dir,
mask_dir=config.dataset.mask,
extensions=config.dataset.extensions,
normal_split_ratio=config.dataset.normal_split_ratio,
image_size=(config.dataset.image_size[0], config.dataset.image_size[1]),
train_batch_size=config.dataset.train_batch_size,
eval_batch_size=config.dataset.eval_batch_size,
num_workers=config.dataset.num_workers,
transform_config_train=config.dataset.transform_config.train,
transform_config_eval=config.dataset.transform_config.eval,
test_split_mode=config.dataset.test_split_mode,
test_split_ratio=config.dataset.test_split_ratio,
val_split_mode=config.dataset.val_split_mode,
val_split_ratio=config.dataset.val_split_ratio,
)
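For reference, a minimal sketch of a dataset config section carrying the new keys that get_datamodule reads in these hunks. Only the test_split_mode / test_split_ratio keys come from this PR; every other key and value is an illustrative assumption that follows the existing config layout.

from omegaconf import OmegaConf

# Hypothetical dataset section; only the test_split_* keys are introduced by this PR.
config = OmegaConf.create(
    """
    dataset:
      format: mvtec
      path: ./datasets/MVTec          # assumed dataset location
      category: bottle
      task: segmentation
      image_size: [256, 256]
      train_batch_size: 32
      eval_batch_size: 32
      num_workers: 8
      transform_config:
        train: null
        eval: null
      test_split_mode: from_dir       # new: from_dir or synthetic
      test_split_ratio: 0.2           # new: fraction of normal train images moved to the test set
      val_split_mode: same_as_test
      val_split_ratio: 0.5
    """
)

# get_datamodule(config) dispatches on config.dataset.format and forwards these keys.
print(config.dataset.test_split_mode, config.dataset.val_split_ratio)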
8 changes: 5 additions & 3 deletions anomalib/data/avenue.py
@@ -26,7 +26,7 @@
from pandas import DataFrame
from torch import Tensor

from anomalib.data.base import AnomalibDataModule, VideoAnomalibDataset
from anomalib.data.base import AnomalibVideoDataModule, AnomalibVideoDataset
from anomalib.data.task_type import TaskType
from anomalib.data.utils import DownloadProgressBar, Split, ValSplitMode, hash_check
from anomalib.data.utils.video import ClipsIndexer
@@ -121,7 +121,7 @@ def get_mask(self, idx) -> Optional[Tensor]:
return masks


class AvenueDataset(VideoAnomalibDataset):
class AvenueDataset(AnomalibVideoDataset):
"""Avenue Dataset class.

Args:
@@ -156,7 +156,7 @@ def _setup(self):
self.samples = make_avenue_dataset(self.root, self.gt_dir, self.split)


class Avenue(AnomalibDataModule):
class Avenue(AnomalibVideoDataModule):
"""Avenue DataModule class.

Args:
@@ -177,6 +177,8 @@ class Avenue(AnomalibDataModule):
during validation.
Defaults to None.
val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
seed (Optional[int], optional): Seed which may be set to a fixed value for reproducibility.
"""

def __init__(
4 changes: 2 additions & 2 deletions anomalib/data/base/__init__.py
@@ -6,6 +6,6 @@

from .datamodule import AnomalibDataModule
from .dataset import AnomalibDataset
from .video import VideoAnomalibDataset
from .video import AnomalibVideoDataModule, AnomalibVideoDataset

__all__ = ["AnomalibDataset", "AnomalibDataModule", "VideoAnomalibDataset"]
__all__ = ["AnomalibDataset", "AnomalibDataModule", "AnomalibVideoDataset", "AnomalibVideoDataModule"]
44 changes: 43 additions & 1 deletion anomalib/data/base/datamodule.py
@@ -15,7 +15,13 @@
from torch.utils.data import DataLoader, default_collate

from anomalib.data.base.dataset import AnomalibDataset
from anomalib.data.utils import ValSplitMode, random_split
from anomalib.data.synthetic import SyntheticAnomalyDataset
from anomalib.data.utils import (
TestSplitMode,
ValSplitMode,
random_split,
split_by_label,
)

logger = logging.getLogger(__name__)

@@ -60,12 +66,16 @@ def __init__(
num_workers: int,
val_split_mode: ValSplitMode,
val_split_ratio: float,
test_split_mode: Optional[TestSplitMode] = None,
test_split_ratio: Optional[float] = None,
seed: Optional[int] = None,
):
super().__init__()
self.train_batch_size = train_batch_size
self.eval_batch_size = eval_batch_size
self.num_workers = num_workers
self.test_split_mode = test_split_mode
self.test_split_ratio = test_split_ratio
self.val_split_mode = val_split_mode
self.val_split_ratio = val_split_ratio
self.seed = seed
@@ -101,12 +111,44 @@ def _setup(self, _stage: Optional[str] = None) -> None:

self.train_data.setup()
self.test_data.setup()

self._create_test_split()
self._create_val_split()

def _create_test_split(self):
"""Obtain the test set based on the settings in the config."""
if self.test_data.has_normal:
# split the test data into normal and anomalous so these can be processed separately
normal_test_data, self.test_data = split_by_label(self.test_data)
else:
# when the user did not provide any normal images for testing, we sample some from the training set
logger.info(
"No normal test images found. Sampling from training set using a split ratio of %d",
self.test_split_ratio,
)
self.train_data, normal_test_data = random_split(self.train_data, self.test_split_ratio)

if self.test_split_mode == TestSplitMode.FROM_DIR:
self.test_data += normal_test_data
elif self.test_split_mode == TestSplitMode.SYNTHETIC:
self.test_data = SyntheticAnomalyDataset.from_dataset(normal_test_data)
else:
raise ValueError(f"Unsupported Test Split Mode: {self.test_split_mode}")

def _create_val_split(self):
"""Obtain the validation set based on the settings in the config."""
if self.val_split_mode == ValSplitMode.FROM_TEST:
# randomly sampled from test set
self.test_data, self.val_data = random_split(
self.test_data, self.val_split_ratio, label_aware=True, seed=self.seed
)
elif self.val_split_mode == ValSplitMode.SAME_AS_TEST:
# equal to test set
self.val_data = self.test_data
elif self.val_split_mode == ValSplitMode.SYNTHETIC:
# converted from random training sample
self.train_data, normal_val_data = random_split(self.train_data, self.val_split_ratio)
self.val_data = SyntheticAnomalyDataset.from_dataset(normal_val_data)
elif self.val_split_mode != ValSplitMode.NONE:
raise ValueError(f"Unknown validation split mode: {self.val_split_mode}")

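The two helpers above are driven entirely by the split-mode settings passed to the datamodule constructor. A usage sketch, using MVTec as an arbitrary concrete datamodule; the root path, ratios and seed are illustrative values, and the dataset is assumed to be available locally.

from anomalib.data import MVTec
from anomalib.data.utils import TestSplitMode, ValSplitMode

# ValSplitMode.FROM_TEST    -> validation set sampled (label-aware) from the test set
# ValSplitMode.SAME_AS_TEST -> validation set is the test set itself
# ValSplitMode.SYNTHETIC    -> validation set built by corrupting normal training images
# ValSplitMode.NONE         -> no validation set
datamodule = MVTec(
    root="./datasets/MVTec",                 # assumed local dataset location
    category="bottle",
    test_split_mode=TestSplitMode.FROM_DIR,  # keep the dataset's own test images
    val_split_mode=ValSplitMode.SYNTHETIC,   # synthetic anomalies for validation
    val_split_ratio=0.1,                     # fraction of normal train images converted
    seed=42,
)
datamodule.setup()  # triggers the _create_test_split / _create_val_split logic above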
29 changes: 26 additions & 3 deletions anomalib/data/base/video.py
@@ -6,14 +6,15 @@
import torch
from torch import Tensor

from anomalib.data.base.datamodule import AnomalibDataModule
from anomalib.data.base.dataset import AnomalibDataset
from anomalib.data.task_type import TaskType
from anomalib.data.utils import masks_to_boxes
from anomalib.data.utils import ValSplitMode, masks_to_boxes
from anomalib.data.utils.video import ClipsIndexer
from anomalib.pre_processing import PreProcessor


class VideoAnomalibDataset(AnomalibDataset, ABC):
class AnomalibVideoDataset(AnomalibDataset, ABC):
"""Base video anomalib dataset class.

Args:
@@ -48,7 +49,7 @@ def samples(self):
@samples.setter
def samples(self, samples):
"""Overwrite samples and re-index subvideos."""
super(VideoAnomalibDataset, self.__class__).samples.fset(self, samples)
super(AnomalibVideoDataset, self.__class__).samples.fset(self, samples)
self._setup_clips()

def _setup_clips(self) -> None:
@@ -93,3 +94,25 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
item.pop("mask")

return item


class AnomalibVideoDataModule(AnomalibDataModule):
"""Base class for video data modules."""

def _setup(self, _stage: Optional[str] = None) -> None:
"""Set up the datasets and perform dynamic subset splitting.

This method may be overridden in subclass for custom splitting behaviour.

Video datamodules are not compatible with synthetic anomaly generation.
"""
assert self.train_data is not None
assert self.test_data is not None

self.train_data.setup()
self.test_data.setup()

if self.val_split_mode == ValSplitMode.SYNTHETIC:
raise ValueError(f"Val split mode {self.test_split_mode} not supported for video datasets.")

self._create_val_split()
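A short sketch of the restriction above: video datamodules keep the regular validation split modes but reject ValSplitMode.SYNTHETIC. The Avenue constructor arguments are assumptions based on the dataset class earlier in this file, not confirmed by this diff.

from anomalib.data import Avenue
from anomalib.data.utils import ValSplitMode

# Supported: draw the validation set from the test set.
datamodule = Avenue(
    root="./datasets/avenue",                      # assumed dataset location
    gt_dir="./datasets/avenue/ground_truth_demo",  # assumed ground-truth location
    val_split_mode=ValSplitMode.FROM_TEST,
)

# Passing ValSplitMode.SYNTHETIC instead would raise the ValueError above during setup().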
17 changes: 16 additions & 1 deletion anomalib/data/btech.py
@@ -25,7 +25,13 @@

from anomalib.data.base import AnomalibDataModule, AnomalibDataset
from anomalib.data.task_type import TaskType
from anomalib.data.utils import DownloadProgressBar, Split, ValSplitMode, hash_check
from anomalib.data.utils import (
DownloadProgressBar,
Split,
TestSplitMode,
ValSplitMode,
hash_check,
)
from anomalib.pre_processing import PreProcessor

logger = logging.getLogger(__name__)
@@ -181,6 +187,8 @@ def __init__(
task: TaskType = TaskType.SEGMENTATION,
transform_config_train: Optional[Union[str, A.Compose]] = None,
transform_config_eval: Optional[Union[str, A.Compose]] = None,
test_split_mode: TestSplitMode = TestSplitMode.FROM_DIR,
test_split_ratio: float = 0.2,
val_split_mode: ValSplitMode = ValSplitMode.SAME_AS_TEST,
val_split_ratio: float = 0.5,
seed: Optional[int] = None,
@@ -199,6 +207,11 @@ def __init__(
transform_config_val: Config for pre-processing during validation.
create_validation_set: Create a validation subset in addition to the train and test subsets
seed (Optional[int], optional): Seed used during random subset splitting.
test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
seed (Optional[int], optional): Seed which may be set to a fixed value for reproducibility.

Examples:
>>> from anomalib.data import BTech
@@ -230,6 +243,8 @@ def __init__(
train_batch_size=train_batch_size,
eval_batch_size=eval_batch_size,
num_workers=num_workers,
test_split_mode=test_split_mode,
test_split_ratio=test_split_ratio,
val_split_mode=val_split_mode,
val_split_ratio=val_split_ratio,
seed=seed,
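A short usage sketch mirroring the Examples block referenced in the docstring above, extended with the new split arguments; the root and category values are assumptions.

from anomalib.data import BTech
from anomalib.data.utils import TestSplitMode, ValSplitMode

datamodule = BTech(
    root="./datasets/BTech",                 # assumed dataset location
    category="01",                           # assumed category name
    test_split_mode=TestSplitMode.FROM_DIR,  # default: use the provided test images
    test_split_ratio=0.2,                    # only used if no normal test images are found
    val_split_mode=ValSplitMode.SAME_AS_TEST,
    val_split_ratio=0.5,
    seed=42,
)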
24 changes: 8 additions & 16 deletions anomalib/data/folder.py
@@ -15,7 +15,7 @@

from anomalib.data.base import AnomalibDataModule, AnomalibDataset
from anomalib.data.task_type import TaskType
from anomalib.data.utils import Split, ValSplitMode, random_split
from anomalib.data.utils import Split, TestSplitMode, ValSplitMode
from anomalib.pre_processing.pre_process import PreProcessor


@@ -237,7 +237,10 @@ class Folder(AnomalibDataModule):
transform_config_val (Optional[Union[str, A.Compose]], optional): Config for pre-processing
during validation.
Defaults to None.
test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
seed (Optional[int], optional): Seed used during random subset splitting.
"""

@@ -258,6 +261,8 @@ def __init__(
task: TaskType = TaskType.SEGMENTATION,
transform_config_train: Optional[Union[str, A.Compose]] = None,
transform_config_eval: Optional[Union[str, A.Compose]] = None,
test_split_mode: TestSplitMode = TestSplitMode.FROM_DIR,
test_split_ratio: float = 0.2,
val_split_mode: ValSplitMode = ValSplitMode.FROM_TEST,
val_split_ratio: float = 0.5,
seed: Optional[int] = None,
@@ -266,6 +271,8 @@ def __init__(
train_batch_size=train_batch_size,
eval_batch_size=eval_batch_size,
num_workers=num_workers,
test_split_mode=test_split_mode,
test_split_ratio=test_split_ratio,
val_split_mode=val_split_mode,
val_split_ratio=val_split_ratio,
seed=seed,
@@ -299,18 +306,3 @@ def __init__(
mask_dir=mask_dir,
extensions=extensions,
)

def _setup(self, _stage: Optional[str] = None):
"""Set up the datasets for the Folder Data Module."""
assert self.train_data is not None
assert self.test_data is not None

self.train_data.setup()
self.test_data.setup()

# add some normal images to the test set
if not self.test_data.has_normal:
self.train_data, normal_test_data = random_split(self.train_data, self.normal_split_ratio, seed=self.seed)
self.test_data += normal_test_data

super()._setup()
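The dedicated Folder._setup above can be removed because the base datamodule now covers the case where no normal test images are supplied: test_split_ratio controls how many normal training images are moved to the test set. A sketch of that scenario with hypothetical directory names:

from anomalib.data import Folder
from anomalib.data.task_type import TaskType
from anomalib.data.utils import TestSplitMode, ValSplitMode

# No normal_test_dir is given, so normal test images are sampled from the
# training set according to test_split_ratio by the base class.
datamodule = Folder(
    root="./datasets/my_dataset",   # hypothetical folder dataset
    normal_dir="good",
    abnormal_dir="defect",
    task=TaskType.CLASSIFICATION,   # no masks in this sketch
    test_split_mode=TestSplitMode.FROM_DIR,
    test_split_ratio=0.2,
    val_split_mode=ValSplitMode.FROM_TEST,
    val_split_ratio=0.5,
    seed=42,
)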
36 changes: 34 additions & 2 deletions anomalib/data/mvtec.py
@@ -34,7 +34,13 @@

from anomalib.data.base import AnomalibDataModule, AnomalibDataset
from anomalib.data.task_type import TaskType
from anomalib.data.utils import DownloadProgressBar, Split, ValSplitMode, hash_check
from anomalib.data.utils import (
DownloadProgressBar,
Split,
TestSplitMode,
ValSplitMode,
hash_check,
)
from anomalib.pre_processing import PreProcessor

logger = logging.getLogger(__name__)
@@ -149,7 +155,29 @@ def _setup(self):


class MVTec(AnomalibDataModule):
"""MVTec Datamodule."""
"""MVTec Datamodule.

Args:
root (str): Path to the root of the dataset
category (str): Category of the MVTec dataset (e.g. "bottle" or "cable").
image_size (Optional[Union[int, Tuple[int, int]]], optional): Size of the input image.
Defaults to None.
train_batch_size (int, optional): Training batch size. Defaults to 32.
eval_batch_size (int, optional): Test batch size. Defaults to 32.
num_workers (int, optional): Number of workers. Defaults to 8.
task (TaskType): Task type, 'classification', 'detection' or 'segmentation'
transform_config_train (Optional[Union[str, A.Compose]], optional): Config for pre-processing
during training.
Defaults to None.
transform_config_eval (Optional[Union[str, A.Compose]], optional): Config for pre-processing
during evaluation.
Defaults to None.
test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
seed (Optional[int], optional): Seed which may be set to a fixed value for reproducibility.
"""

def __init__(
self,
@@ -162,6 +190,8 @@ def __init__(
task: TaskType = TaskType.SEGMENTATION,
transform_config_train: Optional[Union[str, A.Compose]] = None,
transform_config_eval: Optional[Union[str, A.Compose]] = None,
test_split_mode: TestSplitMode = TestSplitMode.FROM_DIR,
test_split_ratio: float = 0.2,
val_split_mode: ValSplitMode = ValSplitMode.SAME_AS_TEST,
val_split_ratio: float = 0.5,
seed: Optional[int] = None,
@@ -170,6 +200,8 @@ def __init__(
train_batch_size=train_batch_size,
eval_batch_size=eval_batch_size,
num_workers=num_workers,
test_split_mode=test_split_mode,
test_split_ratio=test_split_ratio,
val_split_mode=val_split_mode,
val_split_ratio=val_split_ratio,
seed=seed,
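Finally, a sketch of the other option documented above: a fully synthetic test set generated from normal images instead of the dataset's own anomalous test images. Values are illustrative.

from anomalib.data import MVTec
from anomalib.data.utils import TestSplitMode, ValSplitMode

datamodule = MVTec(
    root="./datasets/MVTec",                  # assumed local dataset location
    category="bottle",
    test_split_mode=TestSplitMode.SYNTHETIC,  # corrupt the normal test images to build the test set
    test_split_ratio=0.2,                     # only used when the dataset has no normal test images
    val_split_mode=ValSplitMode.SAME_AS_TEST,
    seed=42,
)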