
Commit

Add BaseDataset and FileSystemDataset classes (#76)
* Add BaseDataset and FileSystemDataset classes

* Fix tests

* Fix tests

* Type fix

* Add column_bytes_file_reader parameter to BaseDataset init

* Add DatasetFactory class to hide implementation details

* Fix test

* Replace factory class with builder function

* Fix builder function to accept a dataset configuration type and update tests

* Fix type-check

* Support config for one S3Dataset and multiple FileSystemDatasets

* Version bump

---------

Co-authored-by: Alex Bain (Woven by Toyota) <[email protected]>
convexquad and Alex Bain (Woven by Toyota) authored Dec 11, 2024
1 parent 383b83b commit babcf9a
Showing 7 changed files with 433 additions and 94 deletions.
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
0.0.17
0.0.18
12 changes: 12 additions & 0 deletions tests/.wickerconfig.test.json
@@ -8,6 +8,18 @@
"connect_timeout_s": 140
}
},
"filesystem_configs": [
{
"config_name": "filesystem_1",
"prefix_replace_path": "s3://fake_data_1/",
"root_datasets_path": "/mnt/fake_data_1/"
},
{
"config_name": "filesystem_2",
"prefix_replace_path": "s3://fake_data_2/",
"root_datasets_path": "/mnt/fake_data_2/"
}
],
"dynamodb_config": {
"table_name": "fake_db",
"region": "us-west-2"
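For reference, each entry in the new `filesystem_configs` list maps onto the `WickerFileSystemConfig` dataclass added in `wicker/core/config.py` further down in this commit. A minimal sketch of parsing such a list, using the same fake paths as the test config above:

```python
from wicker.core.config import WickerFileSystemConfig

# Parse the test config's filesystem entries into dataclasses. Missing keys
# fall back to empty strings via the defensive .get() defaults in from_json.
configs = WickerFileSystemConfig.from_json_list(
    [
        {
            "config_name": "filesystem_1",
            "prefix_replace_path": "s3://fake_data_1/",
            "root_datasets_path": "/mnt/fake_data_1/",
        },
        {
            "config_name": "filesystem_2",
            "prefix_replace_path": "s3://fake_data_2/",
            "root_datasets_path": "/mnt/fake_data_2/",
        },
    ]
)
assert configs[0].root_datasets_path == "/mnt/fake_data_1/"
```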
99 changes: 95 additions & 4 deletions tests/test_datasets.py
@@ -11,9 +11,17 @@
import pyarrow.parquet as papq # type: ignore

from wicker.core.column_files import ColumnBytesFileWriter
from wicker.core.datasets import S3Dataset
from wicker.core.config import (
FILESYSTEM_CONFIG,
StorageDownloadConfig,
WickerAwsS3Config,
WickerConfig,
WickerFileSystemConfig,
WickerWandBConfig,
)
from wicker.core.datasets import FileSystemDataset, S3Dataset, build_dataset
from wicker.core.definitions import DatasetID, DatasetPartition
from wicker.core.storage import S3PathFactory
from wicker.core.storage import FileSystemDataStorage, S3PathFactory, WickerPathFactory
from wicker.schema import schema, serialization
from wicker.testing.storage import FakeS3DataStorage

@@ -36,6 +44,25 @@
FAKE_DATA = [{"foo": f"bar{i}", "np_arr": np.eye(4)} for i in range(1000)]


def build_mock_wicker_config(tmpdir: str) -> WickerConfig:
"""Helper function to build WickerConfig objects to use as unit test mocks."""
return WickerConfig(
raw={},
aws_s3_config=WickerAwsS3Config.from_json({}),
filesystem_configs=[
WickerFileSystemConfig.from_json(
{
"config_name": "filesystem_1",
"prefix_replace_path": "",
"root_datasets_path": os.path.join(tmpdir, "fake_data"),
}
),
],
storage_download_config=StorageDownloadConfig.from_json({}),
wandb_config=WickerWandBConfig.from_json({}),
)


@contextmanager
def cwd(path):
"""Changes the current working directory, and returns to the previous directory afterwards"""
@@ -47,6 +74,72 @@ def cwd(path):
os.chdir(oldpwd)


class TestFileSystemDataset(unittest.TestCase):
@contextmanager
def _setup_storage(self) -> Iterator[Tuple[FileSystemDataStorage, WickerPathFactory, str]]:
with tempfile.TemporaryDirectory() as tmpdir, cwd(tmpdir):
fake_local_fs_storage = FileSystemDataStorage()
fake_local_path_factory = WickerPathFactory(root_path=os.path.join(tmpdir, "fake_data"))
fake_s3_path_factory = S3PathFactory()
fake_s3_storage = FakeS3DataStorage(tmpdir=tmpdir)
with ColumnBytesFileWriter(
storage=fake_s3_storage,
s3_path_factory=fake_s3_path_factory,
target_file_rowgroup_size=10,
) as writer:
locs = [
writer.add("np_arr", FAKE_NUMPY_CODEC.validate_and_encode_object(data["np_arr"])) # type: ignore
for data in FAKE_DATA
]

arrow_metadata_table = pa.Table.from_pydict(
{"foo": [data["foo"] for data in FAKE_DATA], "np_arr": [loc.to_bytes() for loc in locs]}
)
metadata_table_path = os.path.join(
tmpdir, fake_local_path_factory._get_dataset_partition_path(FAKE_DATASET_PARTITION)
)
os.makedirs(os.path.dirname(metadata_table_path), exist_ok=True)
papq.write_table(arrow_metadata_table, metadata_table_path)

# The mock storage class here actually writes to local storage, so we can use it in the test.
fake_s3_storage.put_object_s3(
serialization.dumps(FAKE_SCHEMA).encode("utf-8"),
fake_local_path_factory._get_dataset_schema_path(FAKE_DATASET_ID),
)
yield fake_local_fs_storage, fake_local_path_factory, tmpdir

def test_filesystem_dataset(self):
with self._setup_storage() as (fake_local_storage, fake_local_path_factory, tmpdir):
ds = FileSystemDataset(
FAKE_NAME,
FAKE_VERSION,
FAKE_PARTITION,
fake_local_path_factory,
fake_local_storage,
)
for i in range(len(FAKE_DATA)):
retrieved = ds[i]
reference = FAKE_DATA[i]
self.assertEqual(retrieved["foo"], reference["foo"])
np.testing.assert_array_equal(retrieved["np_arr"], reference["np_arr"])

# Also double-check that the builder function is working correctly.
with patch("wicker.core.datasets.get_config") as mock_get_config:
mock_get_config.return_value = build_mock_wicker_config(tmpdir)
ds2 = build_dataset(
FILESYSTEM_CONFIG,
FAKE_NAME,
FAKE_VERSION,
FAKE_PARTITION,
config_name="filesystem_1",
)
for i in range(len(FAKE_DATA)):
retrieved = ds2[i]
reference = FAKE_DATA[i]
self.assertEqual(retrieved["foo"], reference["foo"])
np.testing.assert_array_equal(retrieved["np_arr"], reference["np_arr"])


class TestS3Dataset(unittest.TestCase):
@contextmanager
def _setup_storage(self) -> Iterator[Tuple[FakeS3DataStorage, S3PathFactory, str]]:
@@ -144,14 +237,12 @@ def __init__(self) -> None:
def Object(self, bucket: str, key: str) -> FakeResponse:
full_path = os.path.join("s3://", bucket, key)
data = fake_s3_storage.fetch_obj_s3(full_path)

return FakeResponse(content_length=len(data))

def mock_resource_returner(_: Any):
return MockedS3Resource()

with patch("wicker.core.datasets.boto3.resource", mock_resource_returner):

ds = S3Dataset(
FAKE_NAME,
FAKE_VERSION,
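Outside of the tests, the new build_dataset entry point would be driven by the user's Wicker config rather than a patched get_config. A minimal sketch under that assumption, where the dataset name, version, and partition are illustrative and a filesystem config named "filesystem_1" is assumed to exist in the config file:

```python
from wicker.core.config import FILESYSTEM_CONFIG
from wicker.core.datasets import build_dataset

# Illustrative names; the dataset must already be committed under the
# root_datasets_path of the "filesystem_1" entry in the Wicker config.
ds = build_dataset(
    FILESYSTEM_CONFIG,
    "my_dataset",   # dataset name
    "0.0.1",        # dataset version
    "train",        # partition
    config_name="filesystem_1",
)
row = ds[0]  # dict mapping column names to decoded values, e.g. row["np_arr"]
```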
2 changes: 1 addition & 1 deletion tests/test_storage.py
@@ -37,7 +37,7 @@ def test_fetch_file(self) -> None:
# create local file store
local_datastore = FileSystemDataStorage()
# save file to destination
local_datastore.fetch_file(src_dir, dst_path)
local_datastore.fetch_file(src_path, dst_dir)

# verify file exists
assert os.path.exists(dst_path)
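The corrected call passes the source file path first and the destination directory second. A minimal usage sketch under that assumption, with illustrative paths, where fetch_file is assumed to place the file in the destination directory under its original name (which is what the corrected test asserts):

```python
import os
from wicker.core.storage import FileSystemDataStorage

# Illustrative paths; fetch_file(src_path, dst_dir) copies the source file
# into the destination directory, as exercised by the corrected test above.
src_dir, dst_dir = "/tmp/wicker_example/src", "/tmp/wicker_example/dst"
os.makedirs(src_dir, exist_ok=True)
os.makedirs(dst_dir, exist_ok=True)
src_path = os.path.join(src_dir, "data.bin")
with open(src_path, "wb") as f:
    f.write(b"example bytes")

FileSystemDataStorage().fetch_file(src_path, dst_dir)
assert os.path.exists(os.path.join(dst_dir, "data.bin"))
```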
53 changes: 38 additions & 15 deletions wicker/core/config.py
@@ -6,7 +6,10 @@
import dataclasses
import json
import os
from typing import Any, Dict
from typing import Any, Dict, List

AWS_S3_CONFIG = "aws_s3_config"
FILESYSTEM_CONFIG = "filesystem_config"


@dataclasses.dataclass(frozen=True)
@@ -16,7 +19,6 @@ class WickerWandBConfig:

@classmethod
def from_json(cls, data: Dict[str, Any]) -> WickerWandBConfig:
# only load them if they exist, otherwise leave out
return cls(
wandb_api_key=data.get("wandb_api_key", None),
wandb_base_url=data.get("wandb_base_url", None),
@@ -32,9 +34,9 @@ class BotoS3Config:
@classmethod
def from_json(cls, data: Dict[str, Any]) -> BotoS3Config:
return cls(
max_pool_connections=data["max_pool_connections"],
read_timeout_s=data["read_timeout_s"],
connect_timeout_s=data["connect_timeout_s"],
max_pool_connections=data.get("max_pool_connections", 0),
read_timeout_s=data.get("read_timeout_s", 0),
connect_timeout_s=data.get("connect_timeout_s", 0),
)


@@ -48,13 +50,32 @@
@classmethod
def from_json(cls, data: Dict[str, Any]) -> WickerAwsS3Config:
return cls(
s3_datasets_path=data["s3_datasets_path"],
region=data["region"],
boto_config=BotoS3Config.from_json(data["boto_config"]),
s3_datasets_path=data.get("s3_datasets_path", ""),
region=data.get("region", ""),
boto_config=BotoS3Config.from_json(data.get("boto_config", {})),
store_concatenated_bytes_files_in_dataset=data.get("store_concatenated_bytes_files_in_dataset", False),
)


@dataclasses.dataclass(frozen=True)
class WickerFileSystemConfig:
config_name: str
prefix_replace_path: str
root_datasets_path: str

@classmethod
def from_json(cls, data: Dict[str, Any]) -> WickerFileSystemConfig:
return cls(
config_name=data.get("config_name", ""),
prefix_replace_path=data.get("prefix_replace_path", ""),
root_datasets_path=data.get("root_datasets_path", ""),
)

@classmethod
def from_json_list(cls, data: List[Dict[str, Any]]) -> List[WickerFileSystemConfig]:
return [WickerFileSystemConfig.from_json(d) for d in data]


@dataclasses.dataclass(frozen=True)
class StorageDownloadConfig:
retries: int
@@ -65,26 +86,28 @@ class StorageDownloadConfig:
@classmethod
def from_json(cls, data: Dict[str, Any]) -> StorageDownloadConfig:
return cls(
retries=data["retries"],
timeout=data["timeout"],
retry_backoff=data["retry_backoff"],
retry_delay_s=data["retry_delay_s"],
retries=data.get("retries", 0),
timeout=data.get("timeout", 0),
retry_backoff=data.get("retry_backoff", 0),
retry_delay_s=data.get("retry_delay_s", 0),
)


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass()
class WickerConfig:
raw: Dict[str, Any]
aws_s3_config: WickerAwsS3Config
filesystem_configs: List[WickerFileSystemConfig]
storage_download_config: StorageDownloadConfig
wandb_config: WickerWandBConfig

@classmethod
def from_json(cls, data: Dict[str, Any]) -> WickerConfig:
return cls(
raw=data,
aws_s3_config=WickerAwsS3Config.from_json(data["aws_s3_config"]),
storage_download_config=StorageDownloadConfig.from_json(data["storage_download_config"]),
aws_s3_config=WickerAwsS3Config.from_json(data.get(AWS_S3_CONFIG, {})),
filesystem_configs=WickerFileSystemConfig.from_json_list(data.get("filesystem_configs", [])),
storage_download_config=StorageDownloadConfig.from_json(data.get("storage_download_config", {})),
wandb_config=WickerWandBConfig.from_json(data.get("wandb_config", {})),
)

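With the switch from required keys to .get() defaults, a Wicker config no longer needs every section present. A minimal sketch of loading a filesystem-only configuration, with illustrative paths:

```python
from wicker.core.config import WickerConfig

# A filesystem-only config: the aws_s3_config, storage_download_config, and
# wandb_config sections are omitted and fall back to their defaults.
config = WickerConfig.from_json(
    {
        "filesystem_configs": [
            {
                "config_name": "filesystem_1",
                "prefix_replace_path": "s3://my-bucket/datasets/",
                "root_datasets_path": "/mnt/datasets/",
            }
        ]
    }
)
print(config.filesystem_configs[0].config_name)  # filesystem_1
print(config.aws_s3_config.s3_datasets_path)     # "" (default)
```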
