diff --git a/.actions/assistant.py b/.actions/assistant.py index c5043addffeec..d7a7a850f9d58 100644 --- a/.actions/assistant.py +++ b/.actions/assistant.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import glob +import logging import os import pathlib import re @@ -43,6 +44,11 @@ "requirements/fabric/base.txt", "requirements/fabric/strategies.txt", ), + "data": ( + "requirements/data/data.txt", + "requirements/data/cloud.txt", + "requirements/data/examples.txt", + ), } REQUIREMENT_FILES_ALL = list(chain(*REQUIREMENT_FILES.values())) @@ -146,6 +152,9 @@ def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str """ assert unfreeze in {"none", "major", "all"} path = Path(path_dir) / file_name + if not path.exists(): + logging.warning(f"Folder {path_dir} does not have any base requirements.") + return [] assert path.exists(), (path_dir, file_name, path) text = path.read_text() return [req.adjust(unfreeze) for req in _parse_requirements(text)] @@ -240,7 +249,7 @@ def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requireme requires = [ load_requirements(d, unfreeze="none" if freeze_requirements else "major") for d in glob.glob(os.path.join(req_dir, "*")) - # skip empty folder as git artefacts, and resolving Will's special issue + # skip empty folder (git artifacts), and resolving Will's special issue if os.path.isdir(d) and len(glob.glob(os.path.join(d, "*"))) > 0 and not os.path.basename(d).startswith("_") ] if not requires: @@ -404,6 +413,7 @@ def _replace_min(fname: str) -> None: def replace_oldest_ver(requirement_fnames: Sequence[str] = REQUIREMENT_FILES_ALL) -> None: """Replace the min package version by fixed one.""" for fname in requirement_fnames: + print(fname) AssistantCLI._replace_min(fname) @staticmethod diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 86480027deb14..eb45ac3f1bebe 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -41,6 +41,11 @@ /src/lightning/pytorch/core/hooks.py @williamfalcon @tchaton @awaelchli @carmocca /src/lightning/pytorch/core/module.py @williamfalcon @tchaton @awaelchli @carmocca +# Data Utilities +/examples/data/ @nohalon @justusschock +/src/lightning/data/ @nohalon @justusschock +/tests/tests_data @nohalon @justusschock + # Lightning Fabric /src/lightning/fabric @awaelchli @carmocca @justusschock /src/lightning_fabric @awaelchli @carmocca @justusschock diff --git a/.github/checkgroup.yml b/.github/checkgroup.yml index dbaa992619009..7ab4560931916 100644 --- a/.github/checkgroup.yml +++ b/.github/checkgroup.yml @@ -150,6 +150,26 @@ subprojects: - "build-pl (3.9, 1.13, 11.7.1)" - "build-pl (3.10, 2.0, 11.7.1)" + # SECTIONS: lightning_data + + - id: "lightning_data: CPU workflow" + paths: + - ".actions/**" + - "requirements/data/**" + - "src/lightning/data/**" + - "src/lightning_data/*" + - "tests/tests_data/**" + - "examples/data/**" + - "pyproject.toml" # includes pytest config + - ".github/workflows/ci-tests-data.yml" + - "!requirements/*/docs.txt" + - "!*.md" + - "!**/*.md" + checks: + - "data-cpu (macOS-11, lightning, 3.10, 2.0)" + - "data-cpu (ubuntu-20.04, lightning, 3.10, 2.0)" + - "data-cpu (windows-2022, lightning, 3.10, 2.0)" + # SECTION: lightning_fabric - id: "lightning_fabric: CPU workflow" diff --git a/.github/workflows/ci-tests-data.yml b/.github/workflows/ci-tests-data.yml new file mode 100644 index 0000000000000..665ff7c55d4f7 --- /dev/null +++ b/.github/workflows/ci-tests-data.yml @@ -0,0 +1,118 @@ 
+name: Test Data + +# see: https://help.github.com/en/actions/reference/events-that-trigger-workflows +on: + push: + branches: [master, "release/*"] + pull_request: + branches: [master, "release/*"] + types: [opened, reopened, ready_for_review, synchronize] # added `ready_for_review` since draft is skipped + paths: + - ".actions/**" + - "requirements/data/**" + - "src/lightning/data/**" + - "tests/tests_data/**" + - "pyproject.toml" # includes pytest config + - ".github/workflows/ci-tests-data.yml" + - "!requirements/*/docs.txt" + - "!*.md" + - "!**/*.md" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }} + cancel-in-progress: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }} + +defaults: + run: + shell: bash + +jobs: + data-cpu: + runs-on: ${{ matrix.os }} + if: github.event.pull_request.draft == false + strategy: + fail-fast: false + matrix: + include: + - {os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0"} + - {os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0"} + - {os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0"} + # "oldest" versions tests, only on minimum Python + # - {os: "macOS-11", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest"} + # - {os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest"} + # - {os: "windows-2022", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest"} + timeout-minutes: 25 # because of building grpcio on Mac + env: + PACKAGE_NAME: ${{ matrix.pkg-name }} + FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }} + # PYPI_CACHE_DIR: "_pip-wheels" + TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/torch_stable.html" + TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch_test.html" + steps: + - uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: basic setup + run: pip install -q -r .actions/requirements.txt + + - name: Set min. 
dependencies + if: ${{ matrix.requires == 'oldest' }} + run: | + python .actions/assistant.py replace_oldest_ver + + - name: Adjust PyTorch versions in requirements files + if: ${{ matrix.requires != 'oldest' && matrix.release != 'pre' }} + run: | + pip install -q wget packaging + python -m wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py + for fpath in `ls requirements/data/*.txt`; do \ + python ./adjust-torch-versions.py $fpath ${{ matrix.pytorch-version }}; \ + done + cat requirements/data/data.txt + cat requirements/data/cloud.txt + + # - name: pip wheels cache + # uses: actions/cache/restore@v3 + # with: + # path: ${{ env.PYPI_CACHE_DIR }} + # key: pypi_wheels + # - run: | + # mkdir -p $PYPI_CACHE_DIR + # ls -lh $PYPI_CACHE_DIR + + # removing torch stable line: + # pip install -e ".[${extra}test]" "pytest-timeout" -U -f ${TORCH_URL} ${TORCH_PREINSTALL} -f ${PYPI_CACHE_DIR} --prefer-binary + - name: Install package & dependencies + run: | + python -m pip install -q pip -U + pip install -e ".[data-dev]" "pytest-timeout" -U -f ${TORCH_URL} --prefer-binary + pip list + + - name: Testing Data + working-directory: tests/tests_data + # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 + run: | + python -m coverage run --source lightning \ + -m pytest -v --timeout=30 --durations=50 + + - name: Statistics + if: success() + working-directory: tests/tests_data + run: | + coverage report + coverage xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + # see: https://github.com/actions/toolkit/issues/399 + continue-on-error: true + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: tests/tests_data/coverage.xml + flags: lightning,cpu,pytest,python${{ matrix.python-version }} + name: CPU-coverage + fail_ci_if_error: false diff --git a/.gitignore b/.gitignore index 17c18b06ee99a..8cbc5a81a325b 100644 --- a/.gitignore +++ b/.gitignore @@ -189,6 +189,10 @@ our_model.tar test.png saved_models data/ +!src/lightning/data/ +!examples/data/ +!tests/tests_pytorch/utilities/data/ +!requirements/data/ .shared .lightning node_modules/ diff --git a/examples/data/image/imagenet.py b/examples/data/image/imagenet.py new file mode 100644 index 0000000000000..c9cd50fa256aa --- /dev/null +++ b/examples/data/image/imagenet.py @@ -0,0 +1,190 @@ +import os +import traceback +from argparse import ArgumentParser +from typing import Callable, Literal, Optional + +import torch +import torch.nn.functional as F +import torch.optim as optim +import torch.optim.lr_scheduler as lr_scheduler + +import lightning as L +from lightning.pytorch.utilities.model_helpers import get_torchvision_model + +parser = ArgumentParser() +parser.add_argument("--workers", default=4, type=int) +parser.add_argument("--batchsize", default=56, type=int) +parser.add_argument("-e", "--evaluate", dest="evaluate", action="store_true", help="evaluate model on validation set") +args = parser.parse_args() + +# -------------------------------- +# Step 1: Define a LightningModule +# -------------------------------- + + +class ImageNetLightningModel(L.LightningModule): + """ + >>> ImageNetLightningModel(data_path='missing') # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE + ImageNetLightningModel( + (model): ResNet(...) 
+ ) + """ + + from torchvision.models.resnet import ResNet18_Weights + + def __init__( + self, + data_path: str, + index_file_path: str = None, + arch: str = "resnet18", + weights=ResNet18_Weights.IMAGENET1K_V1, + lr: float = 1e-4, + momentum: float = 0.9, + weight_decay: float = 1e-4, + batch_size: int = 256, + workers: int = 4, + ): + super().__init__() + self.arch = arch + self.weights = weights + self.lr = lr + self.momentum = momentum + self.weight_decay = weight_decay + self.batch_size = batch_size + self.workers = workers + self.data_path = data_path + self.index_file_path = index_file_path + self.model = get_torchvision_model(self.arch, weights=self.weights) + self.train_dataset: Optional[Dataset] = None + self.eval_dataset: Optional[Dataset] = None + + def forward(self, x): + return self.model(x) + + def training_step(self, batch, batch_idx): + images, target = batch + output = self.model(images) + loss_train = F.cross_entropy(output, target) + self.log("train_loss", loss_train) + return loss_train + + def eval_step(self, batch, batch_idx, prefix: str): + images, target = batch + output = self.model(images) + loss_val = F.cross_entropy(output, target) + self.log(f"{prefix}_loss", loss_val) + return loss_val + + def validation_step(self, batch, batch_idx): + return self.eval_step(batch, batch_idx, "val") + + def test_step(self, batch, batch_idx): + return self.eval_step(batch, batch_idx, "test") + + def configure_optimizers(self): + optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=self.momentum, weight_decay=self.weight_decay) + scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.1 ** (epoch // 30)) + return [optimizer], [scheduler] + + def train_dataloader(self): + import torchvision as tv + + transforms = tv.transforms.Compose([tv.transforms.RandomResizedCrop(224), tv.transforms.ToTensor()]) + + train_dataset = S3LightningImagenetDataset( + data_source=self.data_path, split="train", transforms=transforms, path_to_index_file=self.index_file_path + ) + + return torch.utils.data.DataLoader( + dataset=train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.workers + ) + + def val_dataloader(self): + import torchvision as tv + + transforms = tv.transforms.Compose([tv.transforms.RandomResizedCrop(224), tv.transforms.ToTensor()]) + + val_dataset = S3LightningImagenetDataset( + data_source=self.data_path, split="val", transforms=transforms, path_to_index_file=self.index_file_path + ) + + return torch.utils.data.DataLoader( + dataset=val_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.workers + ) + + def test_dataloader(self): + return self.val_dataloader() + + +# ------------------- +# Step 2: Define data +# ------------------- + + +class S3LightningImagenetDataset(L.LightningDataset): + def __init__( + self, + data_source: str, + split: Literal["train", "val"], + transforms: Optional[Callable] = None, + path_to_index_file: Optional[str] = None, + ): + from torchvision.models._meta import _IMAGENET_CATEGORIES + + super().__init__(data_source=data_source, backend="s3", path_to_index_file=path_to_index_file) + + # only get files for the split + self.files = tuple([x for x in self.files if split in x]) + + # get unique classes + self.classes = _IMAGENET_CATEGORIES + + self.transforms = transforms + + def load_sample(self, file_path, stream): + from PIL import Image + + try: + img = Image.open(stream) + + if self.transforms is not None: + img = self.transforms(img) + + # Converting grey scale images to RGB + if img.shape[0] == 1: + 
img = img.repeat((3, 1, 1)) + + curr_cls = os.path.basename(os.path.dirname(file_path)).replace("_", " ") + cls_idx = self.classes.index(curr_cls) + return img, cls_idx + except Exception: + print(file_path, traceback.print_exc()) + pass + + +if __name__ == "__main__": + # os.environ["AWS_ACCESS_KEY"] = + # os.environ["AWS_SECRET_ACCESS_KEY"] = + + data_path = "s3://imagenet-tiny" + index_file_path = "imagenet/imagenet-index.txt" + + # ------------------- + # Step 3: Train + # ------------------- + + print("Instantiate Model") + model = ImageNetLightningModel( + weights=None, + data_path=data_path, + index_file_path=index_file_path, + batch_size=args.batchsize, + workers=args.workers, + ) + trainer = L.Trainer() + + print("Train Model") + if args.evaluate: + trainer.test(model) + else: + trainer.fit(model) diff --git a/requirements/data/cloud.txt b/requirements/data/cloud.txt new file mode 100644 index 0000000000000..f882fdcee1272 --- /dev/null +++ b/requirements/data/cloud.txt @@ -0,0 +1,5 @@ +# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package +# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment + +fsspec[http] >2021.06.0, <2023.5.0 +s3fs >=2022.5.0, <=2022.11.1 diff --git a/requirements/data/data.txt b/requirements/data/data.txt new file mode 100644 index 0000000000000..7d176a12478ca --- /dev/null +++ b/requirements/data/data.txt @@ -0,0 +1,8 @@ +# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package +# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment + +lightning-utilities >=0.8.0, <0.9.0 +# to be able to include also 0.6 and preserve `>` needed for CI min version bypass +torchdata >0.5.9, <0.7.0 +# to be able to include also PL 2.0 and preserve `>` needed for CI min version bypass +torch >0.14.0, <2.1.0 diff --git a/requirements/data/examples.txt b/requirements/data/examples.txt new file mode 100644 index 0000000000000..4daff66969d97 --- /dev/null +++ b/requirements/data/examples.txt @@ -0,0 +1,3 @@ +Pillow >= 9.5.0 +# min version to match torch >= 2.0.1 +torchvision >=0.15.2, <=0.16 diff --git a/requirements/data/test.txt b/requirements/data/test.txt new file mode 100644 index 0000000000000..a02407eb28c7c --- /dev/null +++ b/requirements/data/test.txt @@ -0,0 +1,5 @@ +coverage ==7.2.5 +pytest ==7.3.1 +pytest-cov ==4.0.0 +pytest-rerunfailures ==10.3 +pytest-random-order ==1.1.0 diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt index 052d078392018..567c648759849 100644 --- a/requirements/fabric/base.txt +++ b/requirements/fabric/base.txt @@ -2,7 +2,7 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment numpy >=1.17.2, <1.24.4 -torch >=1.11.0, <=2.0.1 +torch >=1.11.0, <2.1.0 fsspec[http]>2021.06.0, <2023.5.0 packaging >=17.1, <=23.0 typing-extensions >=4.0.0, <=4.4.0 diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index cbc937420f70f..2d202af15f518 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -2,7 +2,7 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment numpy >=1.17.2, <1.24.4 -torch >=1.11.0, <=2.0.1 +torch >=1.11.0, <2.1.0 tqdm >=4.57.0, <4.66.0 PyYAML >=5.4, <=6.0 fsspec[http] 
>2021.06.0, <2023.5.0 diff --git a/setup.py b/setup.py index e2dfc48d08088..043308ebb683c 100755 --- a/setup.py +++ b/setup.py @@ -41,6 +41,7 @@ """ import contextlib import glob +import logging import os import tempfile from importlib.util import module_from_spec, spec_from_file_location @@ -87,11 +88,7 @@ def _set_manifest_path(manifest_dir: str, aggregate: bool = False, mapping: Mapp if aggregate: # aggregate all MANIFEST.in contents into a single temporary file manifest_path = _named_temporary_file(manifest_dir) - lines = [ - "include src/lightning/version.info\n", - "include src/lightning/py.typed\n", - "include requirements/base.txt\n", - ] + lines = [] # load manifest and aggregated all manifests for pkg in mapping.values(): pkg_manifest = os.path.join(_PATH_SRC, pkg, "MANIFEST.in") @@ -104,6 +101,7 @@ def _set_manifest_path(manifest_dir: str, aggregate: bool = False, mapping: Mapp continue # avoid `lightning` -> `lightning/lightning` lines = [ln.replace(old, f"lightning/{new}") for ln in lines] lines = sorted(set(filter(lambda ln: not ln.strip().startswith("#"), lines))) + logging.debug(f"aggregated manifest consists of: {lines}") with open(manifest_path, mode="w") as fp: fp.writelines(lines) else: @@ -111,7 +109,7 @@ def _set_manifest_path(manifest_dir: str, aggregate: bool = False, mapping: Mapp assert os.path.exists(manifest_path) # avoid error: setup script specifies an absolute path manifest_path = os.path.relpath(manifest_path, _PATH_ROOT) - print("Set manifest path to", manifest_path) + logging.info("Set manifest path to %s", manifest_path) setuptools.command.egg_info.manifest_maker.template = manifest_path yield # cleanup diff --git a/src/lightning/MANIFEST.in b/src/lightning/MANIFEST.in new file mode 100644 index 0000000000000..30c7f8a75ac4b --- /dev/null +++ b/src/lightning/MANIFEST.in @@ -0,0 +1,6 @@ +# on top of all other aggregated manifests from sub packages add also: + +include src/lightning/version.info +include src/lightning/py.typed +include requirements/base.txt +recursive-include requirements/data *.txt diff --git a/src/lightning/__init__.py b/src/lightning/__init__.py index cd66a2fa80a42..0a209e9064131 100644 --- a/src/lightning/__init__.py +++ b/src/lightning/__init__.py @@ -22,6 +22,7 @@ from lightning.app.perf import pdb # noqa: E402 from lightning.app.utilities.packaging.build_config import BuildConfig # noqa: E402 from lightning.app.utilities.packaging.cloud_compute import CloudCompute # noqa: E402 +from lightning.data import LightningDataset # noqa: E402 from lightning.fabric.fabric import Fabric # noqa: E402 from lightning.fabric.utilities.seed import seed_everything # noqa: E402 from lightning.pytorch.callbacks import Callback # noqa: E402 @@ -42,6 +43,7 @@ "BuildConfig", "CloudCompute", "Trainer", + "LightningDataset", "LightningDataModule", "LightningModule", "Callback", diff --git a/src/lightning/__setup__.py b/src/lightning/__setup__.py index 25ad816cbffaf..96ce7c27de944 100644 --- a/src/lightning/__setup__.py +++ b/src/lightning/__setup__.py @@ -47,6 +47,8 @@ def _prepare_extras() -> Dict[str, Any]: extras["app-extra"] = extras["app-cloud"] + extras["app-ui"] + extras["app-components"] extras["app-all"] = extras["app-extra"] extras["app-dev"] = extras["app-all"] + extras["app-test"] + extras["data-all"] = extras["data-data"] + extras["data-cloud"] + extras["data-examples"] + extras["data-dev"] = extras["data-all"] + extras["data-test"] # merge per-project extras of the same category, e.g.
`app-test` + `fabric-test` for extra in list(extras): name = "-".join(extra.split("-")[1:]) diff --git a/src/lightning/data/CHANGELOG.md b/src/lightning/data/CHANGELOG.md new file mode 100644 index 0000000000000..552f04c4a8c89 --- /dev/null +++ b/src/lightning/data/CHANGELOG.md @@ -0,0 +1,11 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). + +## \[UnReleased\] - 2023-MM-DD + +### Added + +- Added `LightningDataset` for optimized data loading including fast loading for S3 buckets. ([#17743](https://github.com/Lightning-AI/lightning/pull/17743)) diff --git a/src/lightning/data/__init__.py b/src/lightning/data/__init__.py new file mode 100644 index 0000000000000..e976bab846553 --- /dev/null +++ b/src/lightning/data/__init__.py @@ -0,0 +1,3 @@ +from lightning.data.dataset import LightningDataset + +__all__ = ["LightningDataset"] diff --git a/src/lightning/data/backends.py b/src/lightning/data/backends.py new file mode 100644 index 0000000000000..9cdc764185d4e --- /dev/null +++ b/src/lightning/data/backends.py @@ -0,0 +1,71 @@ +import os +from typing import Dict, Optional, Protocol, runtime_checkable, TYPE_CHECKING + +if TYPE_CHECKING: + try: + from botocore.credentials import RefreshableCredentials + except ImportError: + RefreshableCredentials = object + + +@runtime_checkable +class _DatasetBackend(Protocol): + """This class is used to detect if an object implements a valid dataset backend using `isinstance(obj, + _DatasetBackend)`.""" + + def credentials(self) -> Dict[str, Optional[str]]: + ... + + def handle_error(self, exc: Exception) -> None: + ... + + +class S3DatasetBackend: + """A backend handler for datasets stored on S3.""" + + @staticmethod + def get_aws_credentials() -> "RefreshableCredentials": + """Gets AWS credentials from the current IAM role. + + Returns: + credentials object to be used for file reading + """ + from botocore.credentials import InstanceMetadataProvider + from botocore.utils import InstanceMetadataFetcher + + provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=1000, num_attempts=2)) + + credentials = provider.load() + + os.environ["AWS_ACCESS_KEY"] = credentials.access_key + os.environ["AWS_SECRET_ACCESS_KEY"] = credentials.secret_key + os.environ["AWS_SESSION_TOKEN"] = credentials.token + + return credentials + + def credentials(self) -> Dict[str, Optional[str]]: + if os.getenv("AWS_ACCESS_KEY") and os.getenv("AWS_SECRET_ACCESS_KEY"): + return {"access_key": os.getenv("AWS_ACCESS_KEY"), "secret_key": os.getenv("AWS_SECRET_ACCESS_KEY")} + + return self.get_aws_credentials() + + def handle_error(self, exc: Exception) -> None: + from botocore.exceptions import NoCredentialsError + + if isinstance(exc, NoCredentialsError): + raise ValueError( + "Unable to locate credentials. 
\ + Make sure you have set the following environment variables: \nAWS_ACCESS_KEY\\AWS_SECRET_ACCESS_KEY" + ) from exc + + raise exc + + +class LocalDatasetBackend: + """A backend handler for datasets stored locally.""" + + def credentials(self) -> Dict[str, Optional[str]]: + return {} + + def handle_error(self, exc: Exception) -> None: + raise exc diff --git a/src/lightning/data/dataset.py b/src/lightning/data/dataset.py new file mode 100644 index 0000000000000..831306dfa08ce --- /dev/null +++ b/src/lightning/data/dataset.py @@ -0,0 +1,98 @@ +import os +import tempfile +from abc import ABC, abstractmethod +from typing import Any, Optional + +from torch.utils.data import Dataset as TorchDataset + +from lightning.data.backends import _DatasetBackend, LocalDatasetBackend, S3DatasetBackend +from lightning.data.dataset_index import get_index +from lightning.data.fileio import OpenCloudFileObj + + +class LightningDataset(TorchDataset, ABC): + """Dataset wrapper for optimized dataloading. + + Arguments: + + data_source: path of data directory. ex. s3://mybucket/path + + backend: storage location of the data_source. current options are "s3" or "local" + + path_to_index_file: path to index file that lists all file contents of the data_source. + """ + + def __init__(self, data_source: str, backend: str = "local", path_to_index_file: Optional[str] = None): + super().__init__() + self.data_source = data_source + + if not path_to_index_file: + tmpdir = tempfile.mkdtemp() + path_to_index_file = os.path.join(tmpdir, "index.txt") + + self.index_file = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_index_file))) + + self.files = self.get_index() + + self.backend = self._init_backend(backend=backend) + + assert isinstance(self.backend, _DatasetBackend) + + def _init_backend(self, backend: str) -> _DatasetBackend: + """Picks the correct backend handler.""" + if backend == "s3": + return S3DatasetBackend() + if backend == "local": + return LocalDatasetBackend() + raise ValueError(f"Unsupported backend {backend}") + + def get_index(self) -> Any: + """Gets existing index or triggers an index generation if it doesn't exist for the provided data_source. + + Returns: + The contents of the index file (all the file paths in the data_source) + """ + if not os.path.isfile(self.index_file): + get_index(self.data_source, self.index_file) + + with open(self.index_file) as f: + index = f.readlines() + return (line.strip("\n") for line in index) + + def open(self, file: str, mode: str = "r", kwargs_for_open: Any = {}, **kwargs: Any) -> OpenCloudFileObj: + """Opens a stream for the given file. + + Returns: + A stream object of the file. + """ + return OpenCloudFileObj( + path=file, mode=mode, kwargs_for_open={**self.backend.credentials(), **kwargs_for_open}, **kwargs + ) + + def __getitem__(self, idx: int) -> Any: + """Get's item from the dataset at provided index. + + Returns: + The loaded item + """ + file_path = self.files[idx] + + try: + with self.open( + file_path, + "rb", + ) as stream: + return self.load_sample(file_path, stream) + except Exception as exc: + self.backend.handle_error(exc) + + @abstractmethod + def load_sample(self, file_path: str, stream: OpenCloudFileObj) -> Any: + """Loads each sample in the dataset. + + Any data prep/cleaning logic goes here. For ex. image transformations, text cleaning, etc. 
+ """ + pass + + def __len__(self) -> int: + return len(self.files) diff --git a/src/lightning/data/dataset_index.py b/src/lightning/data/dataset_index.py new file mode 100644 index 0000000000000..b8742d8ff7086 --- /dev/null +++ b/src/lightning/data/dataset_index.py @@ -0,0 +1,136 @@ +import math +import os +from io import TextIOWrapper + +from lightning.app.utilities.network import LightningClient + + +def get_index(s3_connection_path: str, index_file_path: str) -> bool: + """Creates an index of file paths that are in the provided s3 path. + + Returns: + Returns True is the index got created and False if it wasn't + """ + + if s3_connection_path.startswith("/data/"): + s3_connection_path = s3_connection_path[len("/data/") :] + if s3_connection_path.startswith("s3://"): + s3_connection_path = s3_connection_path[len("s3://") :] + + try: + index_exists = _get_index(s3_connection_path, index_file_path) + except KeyError: + index_exists = False + except Exception as exc: + raise ValueError(f"Could not get index file with error: {exc}") + + # Fallback to creating an index from scratch + if not index_exists: + index_exists = _create_index(s3_connection_path, index_file_path) + + return index_exists + + +def _create_index_recursive(root: str, write_to: TextIOWrapper) -> None: + """Recursively pull files from s3 prefixes until full path is available.""" + from fsspec.core import url_to_fs + from torchdata.datapipes.iter import FSSpecFileLister + + files = FSSpecFileLister(root).list_files_by_fsspec() + + for file in files: + if file == root: + continue + + fs, path = url_to_fs(file) + + if not fs.isfile(file): + _create_index_recursive(root=file, write_to=write_to) + else: + write_to.write(file + "\n") + + +def _create_index(data_connection_path: str, index_file_path: str) -> bool: + """Fallback mechanism for index creation.""" + from botocore.exceptions import NoCredentialsError + + print(f"Creating Index for {data_connection_path} in {index_file_path}") + try: + list_from = f"s3://{data_connection_path}" if not os.path.isdir(data_connection_path) else data_connection_path + + if not os.path.exists(os.path.dirname(index_file_path)): + os.makedirs(os.path.dirname(index_file_path)) + + with open(index_file_path, "w") as f: + _create_index_recursive(root=list_from, write_to=f) + + return True + except NoCredentialsError as exc: + print( + "Unable to locate credentials. \ + Make sure you have set the following environment variables: \nAWS_ACCESS_KEY\\AWS_SECRET_ACCESS_KEY" + ) + os.remove(index_file_path) + raise ValueError(exc) + except Exception as exc: + os.remove(index_file_path) + raise ValueError(exc) + + +def _get_index(data_connection_path: str, index_file_path: str) -> bool: + """Expecting a string in the format s3:// or /data/... 
+ + Returns: + True if the index retrieved + """ + + PROJECT_ID_ENV = "LCP_ID" + + client = LightningClient(retry=False) + + if PROJECT_ID_ENV in os.environ: + project_id = os.environ[PROJECT_ID_ENV] + else: + return False + + try: + cluster_bindings = client.projects_service_list_project_cluster_bindings(project_id).clusters + + # For now just use the first one + # For BYOC we will have to update this + cluster = cluster_bindings[0] + + # Find the data connection object first + data_connections = client.data_connection_service_list_data_connections(project_id).data_connections + data_connection = [con for con in data_connections if con.name == data_connection_path] + + if len(data_connection) == 1: + print(f"Placing existing index for {data_connection_path} in {index_file_path}") + + data_connection = data_connection[0] + # Then use the ID of the data connection for retrieving the index + folder_index = client.data_connection_service_get_data_connection_folder_index( + project_id=project_id, id=data_connection.id, cluster_id=cluster.cluster_id + ) + + # Compute number of pages we need to retrieve + num_pages = math.ceil(int(folder_index.nested_file_count) / folder_index.page_size) + + # Get all the pages and append to the index + with open(index_file_path, "a") as f: + f.truncate(0) + + for page_num in range(num_pages): + page = client.data_connection_service_get_data_connection_artifacts_page( + project_id=project_id, + id=data_connection.id, + cluster_id="litng-ai-03", + page_number=str(page_num), + ).artifacts + + f.writelines([f"s3://{data_connection_path}/{item.filename}" + "\n" for item in page]) + return True + return False + + except Exception: + return False diff --git a/src/lightning/data/fileio.py b/src/lightning/data/fileio.py new file mode 100644 index 0000000000000..a6e2ec887d301 --- /dev/null +++ b/src/lightning/data/fileio.py @@ -0,0 +1,126 @@ +import os +from typing import Any, Dict, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + try: + from torchdata.datapipes.utils import StreamWrapper + except ImportError: + StreamWrapper = object + + +def is_url(path: str) -> bool: + return path.startswith("s3://") + + +def is_path(path: str) -> bool: + return not is_url(path) and path.startswith("/") + + +def path_to_url(path: str, bucket_name: str, bucket_root_path: str = "/") -> str: + """Gets full S3 path given bucket info. + + Returns: + Full S3 url path + """ + if not path.startswith(bucket_root_path): + raise ValueError(f"Cannot create a path from {path} relative to {bucket_root_path}") + + rel_path = os.path.relpath(path, bucket_root_path).replace("\\", "/") + return f"s3://{bucket_name}/{rel_path}" + + +def open_single_file( + path_or_url: str, mode: str = "r", kwargs_for_open: Optional[Dict[str, Any]] = None, **kwargs: Any +) -> "StreamWrapper": + """Streams the given file. + + Returns: + The opened file stream. + """ + from torchdata.datapipes.iter import FSSpecFileOpener, IterableWrapper + + datapipe = IterableWrapper([path_or_url]) + + # iterable of length 1, still better than manually instantiating iterator and calling next + for _, stream in FSSpecFileOpener(datapipe, mode=mode, kwargs_for_open=kwargs_for_open, **kwargs): + return stream + return None + + +def open_single_file_with_retry( + path_or_url: str, mode: str = "r", kwargs_for_open: Optional[Dict[str, Any]] = None, **kwargs: Any +) -> "StreamWrapper": + """Streams the given file with a retry mechanism in case of high batch_size (>128) parallel opens. + + Returns: + The opened file stream. 
+ """ + from torchdata.datapipes.iter import FSSpecFileOpener, IterableWrapper + + datapipe = IterableWrapper([path_or_url], **kwargs) + + num_attempts = 5 + + for _, stream in FSSpecFileOpener(datapipe, mode=mode, kwargs_for_open=kwargs_for_open, **kwargs): + curr_attempt = 0 + while curr_attempt < num_attempts: + try: + return stream + except Exception: + curr_attempt += 1 + + raise RuntimeError(f"Could not open {path_or_url}") + + +# Necessary to support both a context manager and a call +class OpenCloudFileObj: + """File object wrapper that streams files on open. + + Arguments: + + path: string containg the path of the file to be opened. + + mode: An optional string that specifies the mode in which the file is opened (``"r"`` by default). + + kwargs_for_open: Optional Dict to specify kwargs for opening files (``fs.open()``). + """ + + def __init__( + self, + path: str, + mode: str = "r", + kwargs_for_open: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ): + from torchdata.datapipes.utils import StreamWrapper + + self._path = path + self._stream: Optional[StreamWrapper] = None + self._mode = mode + self._kwargs_for_open = kwargs_for_open + self._kwargs = kwargs + + def __enter__(self) -> "StreamWrapper": + return self._conditionally_open() + + def __exit__(self, exc_type: str, exc_val: str, exc_tb: str) -> None: + if self._stream is not None: + self._stream.close() + + def _conditionally_open(self) -> "StreamWrapper": + if self._stream is None: + self._stream = open_single_file( + self._path, mode=self._mode, kwargs_for_open=self._kwargs_for_open, **self._kwargs + ) + + return self._stream + + def _conditionally_close(self) -> None: + if self._stream is not None: + self._stream.close() + + def __call__(self) -> "StreamWrapper": + return self._conditionally_open() + + def __getattr__(self, attr: str) -> Any: + return getattr(self._conditionally_open(), attr) diff --git a/tests/tests_data/__init__.py b/tests/tests_data/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tests_data/test_data/test_index.txt b/tests/tests_data/test_data/test_index.txt new file mode 100644 index 0000000000000..69d0437bc8ea5 --- /dev/null +++ b/tests/tests_data/test_data/test_index.txt @@ -0,0 +1,17 @@ +1/img-1000007.jpeg +1/img-1000002.jpeg +1/img-1.jpeg +1/img-100000.jpeg +1/img-1000006.jpeg +1/img-1000009.jpeg +1/img-1000008.jpeg +1/img-1000004.jpeg +1/img-1000000.jpeg +1/img-10000.jpeg +1/img-100.jpeg +1/img-0.jpeg +1/img-1000003.jpeg +1/img-10.jpeg +1/img-1000001.jpeg +1/img-1000.jpeg +1/img-1000005.jpeg diff --git a/tests/tests_data/test_data/test_index_s3.txt b/tests/tests_data/test_data/test_index_s3.txt new file mode 100644 index 0000000000000..234dad9506630 --- /dev/null +++ b/tests/tests_data/test_data/test_index_s3.txt @@ -0,0 +1,17 @@ +s3://nohaspublictestbucket/1/img-0.jpeg +s3://nohaspublictestbucket/1/img-1.jpeg +s3://nohaspublictestbucket/1/img-10.jpeg +s3://nohaspublictestbucket/1/img-100.jpeg +s3://nohaspublictestbucket/1/img-1000.jpeg +s3://nohaspublictestbucket/1/img-10000.jpeg +s3://nohaspublictestbucket/1/img-100000.jpeg +s3://nohaspublictestbucket/1/img-1000000.jpeg +s3://nohaspublictestbucket/1/img-1000001.jpeg +s3://nohaspublictestbucket/1/img-1000002.jpeg +s3://nohaspublictestbucket/1/img-1000003.jpeg +s3://nohaspublictestbucket/1/img-1000004.jpeg +s3://nohaspublictestbucket/1/img-1000005.jpeg +s3://nohaspublictestbucket/1/img-1000006.jpeg +s3://nohaspublictestbucket/1/img-1000007.jpeg +s3://nohaspublictestbucket/1/img-1000008.jpeg 
+s3://nohaspublictestbucket/1/img-1000009.jpeg diff --git a/tests/tests_data/test_dataset.py b/tests/tests_data/test_dataset.py new file mode 100644 index 0000000000000..0ac232608ed8e --- /dev/null +++ b/tests/tests_data/test_dataset.py @@ -0,0 +1,116 @@ +import os +import socket +from types import GeneratorType +from unittest import mock +from unittest.mock import MagicMock + +import numpy as np +import pytest +from lightning_utilities.core.imports import package_available + +from lightning.data import dataset_index +from lightning.data.dataset import LightningDataset +from lightning.data.fileio import OpenCloudFileObj + + +def isConnectedWithInternet(): + try: + socket.create_connection(("1.1.1.1", 53)) + return True + except OSError: + pass + return False + + +@pytest.fixture(scope="session") +def image_set(tmp_path_factory): + from PIL import Image + + file_nums = [ + 0, + 1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 1000001, + 1000002, + 1000003, + 1000004, + 1000005, + 1000006, + 1000007, + 1000008, + 1000009, + ] + + img = np.random.randint(255, size=(800, 800)) + img = img.astype(np.uint8) + im = Image.fromarray(img) + + for i in file_nums: + fn = tmp_path_factory.mktemp("test_data") / f"img-{i}.jpeg" + im.save(fn) + + return tmp_path_factory.getbasetemp()._str + + +class TestLightningDataset(LightningDataset): + def __init__(self, data_source, backend, path_to_index_file): + super().__init__(data_source=data_source, backend=backend, path_to_index_file=path_to_index_file) + + def load_sample(self, file_path, stream): + from PIL import Image + + img = Image.open(stream) + return img + + +@pytest.mark.skipif(not isConnectedWithInternet(), reason="Not connected to internet") +@pytest.mark.skipif(not package_available("lightning"), reason="Supported only with mono-package") +@mock.patch("lightning.data.dataset_index.LightningClient", MagicMock()) +def test_lightning_dataset(tmpdir, image_set, monkeypatch): + client = MagicMock() + client.projects_service_list_project_cluster_bindings.return_value = None + client.data_connection_service_list_data_connections.return_value = None + client.data_connection_service_get_data_connection_folder_index.return_value = None + client.data_connection_service_get_data_connection_artifacts_page.return_value = None + + monkeypatch.setattr(dataset_index, "LightningClient", MagicMock(return_value=client)) + + index_path = os.path.join(tmpdir, "index.txt") + + dset = TestLightningDataset(image_set, backend="local", path_to_index_file=index_path) + tuple_of_files = dset.get_index() + assert isinstance(tuple_of_files, GeneratorType) + files_list = list(tuple_of_files) + + assert os.path.isfile(index_path) + with open(index_path) as f: + file_content = f.readlines() + + assert len(file_content) == len(files_list) + + for file_entry, tuple_entry in zip(file_content, tuple_of_files): + assert file_content == tuple_entry + "\n" + + assert isinstance(dset.open(index_path), OpenCloudFileObj) + + foo_path = os.path.join(tmpdir, "foo.txt") + with open(foo_path, "w") as f: + f.write("bar!") + + with dset.open(foo_path, "r") as f: + assert f.read() == "bar!" + + with dset.open(foo_path, "w") as f: + f.write("not bar anymore!") + + with open(foo_path) as f: + assert f.read() == "not bar anymore!" 
+ + file_obj = dset.open(foo_path, "w") + file_obj.close() + assert file_obj._stream.closed diff --git a/tests/tests_data/test_fileio.py b/tests/tests_data/test_fileio.py new file mode 100644 index 0000000000000..3b40b992ba1b9 --- /dev/null +++ b/tests/tests_data/test_fileio.py @@ -0,0 +1,121 @@ +import os +from unittest import mock + +import pytest + +from lightning.data.fileio import is_path, is_url, open_single_file, OpenCloudFileObj, path_to_url + + +@pytest.mark.parametrize( + ("input_str", "expected"), + [ + ("s3://my_bucket/a", True), + ("s3:/my_bucket", False), + ("my_bucket", False), + ("my_bucket_s3://", False), + ], +) +def test_is_url(input_str, expected): + assert is_url(input_str) == expected + + +@pytest.mark.parametrize( + ("input_str", "expected"), + [ + ("s3://my_bucket/a", False), + ("s3:/my_bucket", False), + ("my_bucket", False), + ("my_bucket_s3://", False), + ("/my_bucket", True), + ], +) +def test_is_path(input_str, expected): + assert is_path(input_str) == expected + + +@pytest.mark.parametrize( + ("path", "bucket_name", "bucket_root_path", "expected"), + [ + ("/data/abc/def", "my_bucket", "/data/abc", "s3://my_bucket/def"), + ("/data/abc/def", "my_bucket", "/data", "s3://my_bucket/abc/def"), + ], +) +def test_path_to_url(path, bucket_name, bucket_root_path, expected): + assert path_to_url(path, bucket_name, bucket_root_path) == expected + + +def test_path_to_url_error(): + with pytest.raises(ValueError, match="Cannot create a path from /path1/abc relative to /path2"): + path_to_url("/path1/abc", "foo", "/path2") + + +@pytest.mark.parametrize("path", ["s3://my_bucket/da.txt", "abc.txt"]) +@mock.patch("s3fs.S3FileSystem", autospec=True) +def test_read_single_file_read(patch: mock.Mock, path, tmpdir): + from torchdata.datapipes.utils import StreamWrapper + + is_s3 = is_url(path) + + if not is_s3: + path = os.path.join(tmpdir, path) + with open(path, "w") as f: + f.write("mytestfile") + + file_stream = open_single_file(path) + assert isinstance(file_stream, StreamWrapper) + + content = file_stream.read() + + if is_s3: + assert isinstance(file_stream.file_obj, mock.Mock) + assert patch.open.assert_called_once + + else: + assert content == "mytestfile" + + +@pytest.mark.parametrize("path", ["s3://my_bucket/da.txt", "abc.txt"]) +@mock.patch("s3fs.S3FileSystem", autospec=True) +def test_read_single_file_write(patch: mock.Mock, path, tmpdir): + from torchdata.datapipes.utils import StreamWrapper + + is_s3 = is_url(path) + + if not is_s3: + path = os.path.join(tmpdir, path) + + file_stream = open_single_file(path, mode="w") + assert isinstance(file_stream, StreamWrapper) + file_stream.write("mytestfile") + file_stream.close() + + if is_s3: + assert isinstance(file_stream.file_obj, mock.Mock) + assert patch.open.assert_called_once + + else: + with open(path) as f: + assert f.read() == "mytestfile" + + +def test_open_cloud_file_obj(tmpdir): + path = os.path.join(tmpdir, "foo.txt") + with open(path, "w") as f: + f.write("bar!") + + f = OpenCloudFileObj(path) + + with f: + assert f.read() == "bar!" + assert f._stream.closed + + f = OpenCloudFileObj(path) + assert f.read() == "bar!" + f.close() + assert f._stream.closed + + with OpenCloudFileObj(path, "w") as f: + f.write("not bar anymore!") + + with open(path) as f: + assert f.read() == "not bar anymore!" 
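Editor's note: the `test_fileio.py` tests above double as a usage reference for `OpenCloudFileObj`. The wrapper resolves both local paths and `s3://` URLs through fsspec, and works either as a context manager or as a plain handle that is closed explicitly. A minimal sketch, assuming the data extras (torchdata, fsspec/s3fs) are installed; the file path is illustrative only:

from lightning.data.fileio import OpenCloudFileObj

# Context-manager form: the underlying stream is opened lazily and closed on exit.
with OpenCloudFileObj("/tmp/example.txt", mode="w") as f:
    f.write("hello")

# Plain-handle form: attribute access is delegated to the opened stream; close it explicitly.
f = OpenCloudFileObj("/tmp/example.txt", mode="r")
print(f.read())
f.close()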
diff --git a/tests/tests_data/test_get_index.py b/tests/tests_data/test_get_index.py new file mode 100644 index 0000000000000..81703b4d45de5 --- /dev/null +++ b/tests/tests_data/test_get_index.py @@ -0,0 +1,131 @@ +import os +from unittest import mock +from unittest.mock import MagicMock + +import numpy as np +import pytest +from lightning_utilities.core.imports import package_available + +from lightning.data import dataset_index +from lightning.data.dataset_index import get_index + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) + + +def get_test_index_data(index_path): + with open(index_path) as f: + data = f.readlines() + return list(dict.fromkeys([item.split("/")[-1] for item in data if "jpeg" in item])) + + +@pytest.fixture(scope="session") +def image_set(tmp_path_factory): + from PIL import Image + + file_nums = [ + 0, + 1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 1000001, + 1000002, + 1000003, + 1000004, + 1000005, + 1000006, + 1000007, + 1000008, + 1000009, + ] + + img = np.random.randint(255, size=(800, 800)) + img = img.astype(np.uint8) + im = Image.fromarray(img) + + for i in file_nums: + fn = tmp_path_factory.mktemp("test_data") / f"img-{i}.jpeg" + im.save(fn) + + return tmp_path_factory.getbasetemp()._str + + +@pytest.mark.skip(reason="Need a valid AWS key and AWS secret key in CI for this to work") +@mock.patch("lightning.data.dataset_index.LightningClient", MagicMock()) +def test_get_index_generate_for_s3_bucket(monkeypatch): + """Can generate an index for an s3 bucket mounted locally on the Lightning AI platform.""" + + client = MagicMock() + client.projects_service_list_project_cluster_bindings.return_value = None + client.data_connection_service_list_data_connections.return_value = None + client.data_connection_service_get_data_connection_folder_index.return_value = None + client.data_connection_service_get_data_connection_artifacts_page.return_value = None + + monkeypatch.setattr(dataset_index, "LightningClient", MagicMock(return_value=client)) + + test_index_path = f"{THIS_DIR}/test_data/test_index_s3.txt" + test_index_data = get_test_index_data(test_index_path) + + test_bucket = "s3://nohaspublictestbucket" + index_path = os.path.join(os.getcwd(), "index_1.txt") + print(index_path) + got_index = get_index(s3_connection_path=test_bucket, index_file_path=index_path) + + assert got_index + + generated_index = get_test_index_data(index_path) + print("generated index", generated_index) + + assert len(test_index_data) == len(generated_index) + assert test_index_data == generated_index + + +@pytest.mark.skipif(not package_available("lightning"), reason="Supported only with mono-package") +@mock.patch("lightning.data.dataset_index.LightningClient", MagicMock()) +def test_get_index_generate_for_local_folder(image_set, monkeypatch): + """Can generate an index for a local folder.""" + + client = MagicMock() + client.projects_service_list_project_cluster_bindings.return_value = None + client.data_connection_service_list_data_connections.return_value = None + client.data_connection_service_get_data_connection_folder_index.return_value = None + client.data_connection_service_get_data_connection_artifacts_page.return_value = None + + monkeypatch.setattr(dataset_index, "LightningClient", MagicMock(return_value=client)) + + test_index_path = f"{THIS_DIR}/test_data/test_index.txt" + test_index_data = get_test_index_data(test_index_path) + + # test_local_bucket = "data/test_dataset" + index_path = os.path.join(THIS_DIR, "index_2.txt") + got_index =
get_index(s3_connection_path=image_set, index_file_path=index_path) + + assert got_index + + generated_index = get_test_index_data(index_path) + + assert len(test_index_data) == len(generated_index) + + item_from_gen_list = list(dict.fromkeys([item.split("/")[-1] for item in generated_index if "jpeg" in item])) + assert sorted(test_index_data) == sorted(item_from_gen_list) + + +@pytest.mark.skip(reason="Not required at the moment") +def test_get_index_generate_for_mounted_s3_bucket(): + """Can generate an index for an s3 bucket mounted locally under /data/.""" + test_index_path = f"{THIS_DIR}/test_data/test_index_s3.txt" + test_index_data = get_test_index_data(test_index_path) + + test_local_bucket = "/data/nohaspublictestbucket" + index_path = os.path.join(THIS_DIR, "index_3.txt") + got_index = get_index(s3_connection_path=test_local_bucket, index_file_path=index_path) + + assert got_index + + generated_index = get_test_index_data(index_path) + + assert len(test_index_data) == len(generated_index) + assert test_index_data == generated_index
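Editor's note: taken together, the new API in this PR amounts to subclassing `LightningDataset` and implementing `load_sample`, as `examples/data/image/imagenet.py` and the tests above do. A minimal sketch for a local folder of text files, assuming the data extras (torchdata, fsspec) are installed; the class name, directory, and index path are illustrative only. Note that `__getitem__` opens every file in binary mode, and that `get_index()` yields paths lazily, which the PR's own example materializes into a tuple:

from typing import Any

from torch.utils.data import DataLoader

from lightning.data import LightningDataset


class TextFolderDataset(LightningDataset):
    """Hypothetical subclass used only to illustrate the API."""

    def __init__(self, data_source: str, path_to_index_file: str):
        super().__init__(data_source=data_source, backend="local", path_to_index_file=path_to_index_file)
        # `self.files` is a generator of paths; materialize it so indexing and len() work.
        self.files = tuple(self.files)

    def load_sample(self, file_path: str, stream: Any) -> str:
        # `stream` was opened in "rb" mode by `LightningDataset.__getitem__`.
        return stream.read().decode("utf-8")


dataset = TextFolderDataset(data_source="/tmp/my_text_files", path_to_index_file="/tmp/index.txt")
loader = DataLoader(dataset, batch_size=4, num_workers=2)

For S3-backed data, the same subclass would pass `backend="s3"` and rely on `AWS_ACCESS_KEY`/`AWS_SECRET_ACCESS_KEY` (or an IAM role) being available, which is how `S3DatasetBackend.credentials()` in `backends.py` resolves credentials.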