diff --git a/clrs/_src/dataset.py b/clrs/_src/dataset.py index aa22cbe..0399094 100644 --- a/clrs/_src/dataset.py +++ b/clrs/_src/dataset.py @@ -78,6 +78,7 @@ def _num_samples(self, algorithm_name): return num_samples def _create_data(self, single_sample): + assert self._builder_config is not None algorithm_name = '_'.join(self._builder_config.name.split('_')[:-1]) num_samples = self._num_samples(algorithm_name) sampler, _ = samplers.build_sampler( @@ -118,6 +119,7 @@ def _info(self) -> tfds.core.DatasetInfo: def _split_generators(self, dl_manager: tfds.download.DownloadManager): """Download the data and define splits.""" + assert self._builder_config is not None if (self._instantiated_dataset_name != self._builder_config.name or self._instantiated_dataset_split != self._builder_config.split): # pytype: disable=attribute-error # always-use-return-annotations self._create_data(single_sample=False) @@ -127,6 +129,8 @@ def _split_generators(self, dl_manager: tfds.download.DownloadManager): def _generate_examples(self): """Generator of examples for each split.""" + assert self._builder_config is not None + assert self._instantiated_dataset is not None algorithm_name = '_'.join(self._builder_config.name.split('_')[:-1]) for i in range(self._num_samples(algorithm_name)): data = {k: _correct_axis_filtering(v, i, k)