Add BaseDataset and FileSystemDataset classes #76

Merged
12 changes: 12 additions & 0 deletions tests/.wickerconfig.test.json
@@ -8,6 +8,18 @@
"connect_timeout_s": 140
}
},
"filesystem_configs": [
{
"config_name": "filesystem_1",
"prefix_replace_path": "s3://fake_data_1/",
"root_datasets_path": "/mnt/fake_data_1/"
},
{
"config_name": "filesystem_2",
"prefix_replace_path": "s3://fake_data_2/",
"root_datasets_path": "/mnt/fake_data_2/"
}
],
"dynamodb_config": {
"table_name": "fake_db",
"region": "us-west-2"
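For context, each entry under filesystem_configs is parsed into a WickerFileSystemConfig by the from_json_list helper added in wicker/core/config.py below. A minimal sketch of loading this fixture, assuming it is run from the repository root:

    import json

    from wicker.core.config import WickerFileSystemConfig

    # Each dict in "filesystem_configs" becomes one WickerFileSystemConfig,
    # exactly as WickerConfig.from_json does further down in this PR.
    with open("tests/.wickerconfig.test.json") as f:
        raw = json.load(f)

    configs = WickerFileSystemConfig.from_json_list(raw.get("filesystem_configs", []))
    assert configs[0].config_name == "filesystem_1"
    assert configs[0].root_datasets_path == "/mnt/fake_data_1/"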
99 changes: 95 additions & 4 deletions tests/test_datasets.py
@@ -11,9 +11,17 @@
import pyarrow.parquet as papq # type: ignore

from wicker.core.column_files import ColumnBytesFileWriter
from wicker.core.datasets import S3Dataset
from wicker.core.config import (
FILESYSTEM_CONFIG,
StorageDownloadConfig,
WickerAwsS3Config,
WickerConfig,
WickerFileSystemConfig,
WickerWandBConfig,
)
from wicker.core.datasets import FileSystemDataset, S3Dataset, build_dataset
from wicker.core.definitions import DatasetID, DatasetPartition
from wicker.core.storage import S3PathFactory
from wicker.core.storage import FileSystemDataStorage, S3PathFactory, WickerPathFactory
from wicker.schema import schema, serialization
from wicker.testing.storage import FakeS3DataStorage

@@ -36,6 +44,25 @@
FAKE_DATA = [{"foo": f"bar{i}", "np_arr": np.eye(4)} for i in range(1000)]


def build_mock_wicker_config(tmpdir: str) -> WickerConfig:
"""Helper function to build WickerConfig objects to use as unit test mocks."""
return WickerConfig(
raw={},
aws_s3_config=WickerAwsS3Config.from_json({}),
filesystem_configs=[
WickerFileSystemConfig.from_json(
{
"config_name": "filesystem_1",
"prefix_replace_path": "",
"root_datasets_path": os.path.join(tmpdir, "fake_data"),
}
),
],
storage_download_config=StorageDownloadConfig.from_json({}),
wandb_config=WickerWandBConfig.from_json({}),
)


@contextmanager
def cwd(path):
"""Changes the current working directory, and returns to the previous directory afterwards"""
@@ -47,6 +74,72 @@ def cwd(path):
os.chdir(oldpwd)


class TestFileSystemDataset(unittest.TestCase):
@contextmanager
def _setup_storage(self) -> Iterator[Tuple[FileSystemDataStorage, WickerPathFactory, str]]:
with tempfile.TemporaryDirectory() as tmpdir, cwd(tmpdir):
fake_local_fs_storage = FileSystemDataStorage()
fake_local_path_factory = WickerPathFactory(root_path=os.path.join(tmpdir, "fake_data"))
fake_s3_path_factory = S3PathFactory()
fake_s3_storage = FakeS3DataStorage(tmpdir=tmpdir)
with ColumnBytesFileWriter(
storage=fake_s3_storage,
s3_path_factory=fake_s3_path_factory,
target_file_rowgroup_size=10,
) as writer:
locs = [
writer.add("np_arr", FAKE_NUMPY_CODEC.validate_and_encode_object(data["np_arr"])) # type: ignore
for data in FAKE_DATA
]

arrow_metadata_table = pa.Table.from_pydict(
{"foo": [data["foo"] for data in FAKE_DATA], "np_arr": [loc.to_bytes() for loc in locs]}
)
metadata_table_path = os.path.join(
tmpdir, fake_local_path_factory._get_dataset_partition_path(FAKE_DATASET_PARTITION)
)
os.makedirs(os.path.dirname(metadata_table_path), exist_ok=True)
papq.write_table(arrow_metadata_table, metadata_table_path)

# The mock storage class here actually writes to local storage, so we can use it in the test.
fake_s3_storage.put_object_s3(
serialization.dumps(FAKE_SCHEMA).encode("utf-8"),
fake_local_path_factory._get_dataset_schema_path(FAKE_DATASET_ID),
)
yield fake_local_fs_storage, fake_local_path_factory, tmpdir

def test_filesystem_dataset(self):
with self._setup_storage() as (fake_local_storage, fake_local_path_factory, tmpdir):
ds = FileSystemDataset(
FAKE_NAME,
FAKE_VERSION,
FAKE_PARTITION,
fake_local_path_factory,
fake_local_storage,
)
for i in range(len(FAKE_DATA)):
retrieved = ds[i]
reference = FAKE_DATA[i]
self.assertEqual(retrieved["foo"], reference["foo"])
np.testing.assert_array_equal(retrieved["np_arr"], reference["np_arr"])

# Also double-check that the builder function is working correctly.
with patch("wicker.core.datasets.get_config") as mock_get_config:
mock_get_config.return_value = build_mock_wicker_config(tmpdir)
ds2 = build_dataset(
FILESYSTEM_CONFIG,
FAKE_NAME,
FAKE_VERSION,
FAKE_PARTITION,
config_name="filesystem_1",
)
for i in range(len(FAKE_DATA)):
retrieved = ds2[i]
reference = FAKE_DATA[i]
self.assertEqual(retrieved["foo"], reference["foo"])
np.testing.assert_array_equal(retrieved["np_arr"], reference["np_arr"])

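Outside the test harness, the same two entry points exercised above apply: construct a FileSystemDataset directly, or let build_dataset resolve the backend from the loaded Wicker config. A hedged sketch; the dataset name, version, and partition strings are placeholders:

    from wicker.core.config import FILESYSTEM_CONFIG
    from wicker.core.datasets import build_dataset

    # build_dataset dispatches on the backend key (FILESYSTEM_CONFIG here) and
    # looks up the named entry in the config's filesystem_configs list.
    ds = build_dataset(
        FILESYSTEM_CONFIG,
        "my_dataset",  # placeholder dataset name
        "0.0.1",       # placeholder version
        "train",       # placeholder partition
        config_name="filesystem_1",
    )
    row = ds[0]  # a decoded example, e.g. {"foo": ..., "np_arr": ...}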

class TestS3Dataset(unittest.TestCase):
@contextmanager
def _setup_storage(self) -> Iterator[Tuple[FakeS3DataStorage, S3PathFactory, str]]:
@@ -144,14 +237,12 @@ def __init__(self) -> None:
def Object(self, bucket: str, key: str) -> FakeResponse:
full_path = os.path.join("s3://", bucket, key)
data = fake_s3_storage.fetch_obj_s3(full_path)

return FakeResponse(content_length=len(data))

def mock_resource_returner(_: Any):
return MockedS3Resource()

with patch("wicker.core.datasets.boto3.resource", mock_resource_returner):

ds = S3Dataset(
FAKE_NAME,
FAKE_VERSION,
2 changes: 1 addition & 1 deletion tests/test_storage.py
@@ -37,7 +37,7 @@ def test_fetch_file(self) -> None:
# create local file store
local_datastore = FileSystemDataStorage()
# save file to destination
local_datastore.fetch_file(src_dir, dst_path)
local_datastore.fetch_file(src_path, dst_dir)

# verify file exists
assert os.path.exists(dst_path)
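The corrected call matches the argument order of fetch_file: source file path first, destination directory second. A small self-contained sketch of that usage, assuming fetch_file writes the source file into the destination directory under its original name:

    import os
    import tempfile

    from wicker.core.storage import FileSystemDataStorage

    storage = FileSystemDataStorage()

    with tempfile.TemporaryDirectory() as tmpdir:
        src_path = os.path.join(tmpdir, "source", "file.bin")
        dst_dir = os.path.join(tmpdir, "destination")
        os.makedirs(os.path.dirname(src_path))
        os.makedirs(dst_dir)
        with open(src_path, "wb") as f:
            f.write(b"payload")

        # fetch_file(source_file, destination_directory)
        storage.fetch_file(src_path, dst_dir)
        assert os.path.exists(os.path.join(dst_dir, "file.bin"))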
53 changes: 38 additions & 15 deletions wicker/core/config.py
@@ -6,7 +6,10 @@
import dataclasses
import json
import os
from typing import Any, Dict
from typing import Any, Dict, List

AWS_S3_CONFIG = "aws_s3_config"
FILESYSTEM_CONFIG = "filesystem_config"


@dataclasses.dataclass(frozen=True)
@@ -16,7 +19,6 @@ class WickerWandBConfig:

@classmethod
def from_json(cls, data: Dict[str, Any]) -> WickerWandBConfig:
# only load them if they exist, otherwise leave out
return cls(
wandb_api_key=data.get("wandb_api_key", None),
wandb_base_url=data.get("wandb_base_url", None),
@@ -32,9 +34,9 @@ class BotoS3Config:
@classmethod
def from_json(cls, data: Dict[str, Any]) -> BotoS3Config:
return cls(
max_pool_connections=data["max_pool_connections"],
read_timeout_s=data["read_timeout_s"],
connect_timeout_s=data["connect_timeout_s"],
max_pool_connections=data.get("max_pool_connections", 0),
read_timeout_s=data.get("read_timeout_s", 0),
connect_timeout_s=data.get("connect_timeout_s", 0),
)


@@ -48,13 +50,32 @@
@classmethod
def from_json(cls, data: Dict[str, Any]) -> WickerAwsS3Config:
return cls(
s3_datasets_path=data["s3_datasets_path"],
region=data["region"],
boto_config=BotoS3Config.from_json(data["boto_config"]),
s3_datasets_path=data.get("s3_datasets_path", ""),
region=data.get("region", ""),
boto_config=BotoS3Config.from_json(data.get("boto_config", {})),
store_concatenated_bytes_files_in_dataset=data.get("store_concatenated_bytes_files_in_dataset", False),
)


@dataclasses.dataclass(frozen=True)
class WickerFileSystemConfig:
config_name: str
prefix_replace_path: str
root_datasets_path: str

@classmethod
def from_json(cls, data: Dict[str, Any]) -> WickerFileSystemConfig:
return cls(
config_name=data.get("config_name", ""),
prefix_replace_path=data.get("prefix_replace_path", ""),
root_datasets_path=data.get("root_datasets_path", ""),
)

@classmethod
def from_json_list(cls, data: List[Dict[str, Any]]) -> List[WickerFileSystemConfig]:
return [WickerFileSystemConfig.from_json(d) for d in data]


@dataclasses.dataclass(frozen=True)
class StorageDownloadConfig:
retries: int
@@ -65,26 +86,28 @@ class StorageDownloadConfig:
@classmethod
def from_json(cls, data: Dict[str, Any]) -> StorageDownloadConfig:
return cls(
retries=data["retries"],
timeout=data["timeout"],
retry_backoff=data["retry_backoff"],
retry_delay_s=data["retry_delay_s"],
retries=data.get("retries", 0),
timeout=data.get("timeout", 0),
retry_backoff=data.get("retry_backoff", 0),
retry_delay_s=data.get("retry_delay_s", 0),
)


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass()
class WickerConfig:
raw: Dict[str, Any]
aws_s3_config: WickerAwsS3Config
filesystem_configs: List[WickerFileSystemConfig]
storage_download_config: StorageDownloadConfig
wandb_config: WickerWandBConfig

@classmethod
def from_json(cls, data: Dict[str, Any]) -> WickerConfig:
return cls(
raw=data,
aws_s3_config=WickerAwsS3Config.from_json(data["aws_s3_config"]),
storage_download_config=StorageDownloadConfig.from_json(data["storage_download_config"]),
aws_s3_config=WickerAwsS3Config.from_json(data.get(AWS_S3_CONFIG, {})),
filesystem_configs=WickerFileSystemConfig.from_json_list(data.get("filesystem_configs", [])),
storage_download_config=StorageDownloadConfig.from_json(data.get("storage_download_config", {})),
wandb_config=WickerWandBConfig.from_json(data.get("wandb_config", {})),
)

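With every section now read via .get(...) and a fallback, a partial JSON document parses without raising KeyError, as the tests above rely on. A short sketch of that behavior, using the defaults visible in this diff (empty strings, zeros, empty sub-configs):

    from wicker.core.config import WickerConfig

    # Only a filesystem backend is configured; the AWS S3, storage download,
    # and W&B sections all fall back to their defaults.
    raw = {
        "filesystem_configs": [
            {
                "config_name": "filesystem_1",
                "prefix_replace_path": "s3://fake_data_1/",
                "root_datasets_path": "/mnt/fake_data_1/",
            }
        ]
    }

    config = WickerConfig.from_json(raw)
    assert config.aws_s3_config.s3_datasets_path == ""
    assert config.storage_download_config.retries == 0
    assert config.filesystem_configs[0].config_name == "filesystem_1"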