Skip to content

Commit

Permalink
Use PathWrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
carolinscholl committed Dec 5, 2023
1 parent f404005 commit 2e683fd
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 63 deletions.
39 changes: 0 additions & 39 deletions mex/common/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,42 +268,3 @@ def resolve_paths(self) -> "BaseSettings":
setattr(self, name, base_path / value)

return self

@model_validator(mode="before")
@classmethod
def coerce_paths(cls, raw_input: dict[str, Any]) -> dict[str, Any]:
"""Coerce relevant input strings to AssetPath or WorkPaths."""
fields_by_keys = cls.model_fields
fields_by_alias = {
field.validation_alias: field for key, field in cls.model_fields.items()
}

for field_name, field_value in raw_input.items():
if field_name not in fields_by_keys and field_name not in fields_by_alias:
continue
field_type = None
if field_name in fields_by_keys:
field_type = fields_by_keys[field_name].annotation
elif field_name in fields_by_alias:
field_type = fields_by_alias[field_name].annotation
if field_type is not None:
if isinstance(field_type, AssetsPath) or (
hasattr(field_type, "__args__")
and AssetsPath in field_type.__args__
):
try:
raw_input[field_name] = AssetsPath(raw_input[field_name])
# TODO coercing here is not enough to check if input is pathlike
# All input comes in as string and AssetsPath and WorkPath only
# check if the input is a string.
except TypeError:
continue
elif isinstance(field_type, WorkPath) or (
hasattr(field_type, "__args__") and WorkPath in field_type.__args__
):
try:
raw_input[field_name] = WorkPath(raw_input[field_name])
except TypeError:
continue

return raw_input
3 changes: 3 additions & 0 deletions mex/common/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pydantic import BaseModel as PydanticModel

from mex.common.types import Timestamp
from mex.common.types.path import PathWrapper


class MExEncoder(json.JSONEncoder):
Expand All @@ -31,6 +32,8 @@ def default(self, obj: Any) -> Any:
return str(obj)
if isinstance(obj, PurePath):
return obj.as_posix()
if isinstance(obj, PathWrapper):
return str(obj)
return json.JSONEncoder.default(self, obj)


Expand Down
101 changes: 83 additions & 18 deletions mex/common/types/path.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,95 @@
import sys
from pathlib import Path, PosixPath, WindowsPath
from typing import Any
from os import PathLike
from pathlib import Path
from typing import Any, Type, TypeVar, Union
from warnings import warn

if sys.platform == "win32":
path_type = WindowsPath
else:
path_type = PosixPath
from pydantic_core import core_schema

PathWrapperT = TypeVar("PathWrapperT", bound="PathWrapper")

class _DeprecatedResolvedPath(path_type):
"""Class to support the deprecated .resolve() and .raw() methods."""

def resolve(self, *args: Any, **kwargs: Any) -> Path: # type: ignore
"""Return absolute path."""
class PathWrapper(PathLike[str]):
"""Custom path for settings that can be absolute or relative to another setting."""

_path: Path

def __init__(self, path: Union[str, Path, "PathWrapper"]) -> None:
"""Create a new resolved path instance."""
if isinstance(path, str):
path = Path(path)
elif isinstance(path, PathWrapper):
path = path._path
self._path = path

def __fspath__(self) -> str:
"""Return the file system path representation."""
return self._path.__fspath__()

def __truediv__(self, other: str | PathLike[str]) -> Path:
"""Return a joined path on the basis of `/`."""
return self._path.__truediv__(other)

def __str__(self) -> str:
"""Return a string rendering of the resolved path."""
return self._path.as_posix()

def __repr__(self) -> str:
"""Return a representation string of the resolved path."""
return f'{self.__class__.__name__}("{self}")'

def __eq__(self, other: Any) -> bool:
"""Return true for two PathWrappers with equal paths."""
if isinstance(other, PathWrapper):
return self._path.__eq__(other._path)
raise TypeError(f"Can't compare {type(other)} with {type(self)}")

def is_absolute(self) -> bool:
"""True if the underlying path is absolute."""
return self._path.is_absolute()

def resolve(self) -> Path:
"""Return the resolved path which is the underlying path."""
warn("deprecated", DeprecationWarning)
return self._path

def raw(self) -> Path:
"""Return the raw underlying path without resolving it."""
warn("deprecated", DeprecationWarning)
return super().resolve(*args, **kwargs)
return self._path

@classmethod
def __get_pydantic_core_schema__(cls, _source: Type[Any]) -> core_schema.CoreSchema:
"""Set schema to str schema."""
from_str_schema = core_schema.chain_schema(
[
core_schema.str_schema(),
core_schema.no_info_plain_validator_function(
cls.validate,
),
]
)
from_anything_schema = core_schema.chain_schema(
[
core_schema.no_info_plain_validator_function(cls.validate),
core_schema.is_instance_schema(PathWrapper),
]
)
return core_schema.json_or_python_schema(
json_schema=from_str_schema,
python_schema=from_anything_schema,
)

@classmethod
def validate(cls: type[PathWrapperT], value: Any) -> PathWrapperT:
"""Convert a string value to a Text instance."""
if isinstance(value, (str, Path, PathWrapper)):
return cls(value)
raise ValueError(f"Cannot parse {type(value)} as {cls.__name__}")


class AssetsPath(
_DeprecatedResolvedPath
): # TODO: inherit from path_type instead after removal of deprecated class
class AssetsPath(PathWrapper):
"""Custom path for settings that can be absolute or relative to `assets_dir`."""


class WorkPath(
_DeprecatedResolvedPath
): # TODO: inherit from path_type instead after removal of deprecated class
class WorkPath(PathWrapper):
"""Custom path for settings that can be absolute or relative to `work_dir`."""
14 changes: 8 additions & 6 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,13 @@ class DummySettings(BaseSettings):
abs_path=absolute,
work_path=WorkPath(relative),
assets_path=AssetsPath(relative),
assets_dir=AssetsPath(absolute / "assets_dir"),
assets_dir=Path(absolute / "assets_dir"),
)

settings_dir = settings.model_dump(exclude_defaults=True)
assert settings_dir["non_path"] == "blablabla"
assert settings_dir["abs_path"] == absolute
assert settings.work_path == settings.work_dir / relative
assert settings_dir["assets_path"] == absolute / "assets_dir" / relative
settings_dict = settings.model_dump(exclude_defaults=True)
assert settings_dict["non_path"] == "blablabla"
assert settings_dict["abs_path"] == absolute
assert settings_dict["work_path"] == WorkPath(settings.work_dir / relative)
assert settings_dict["assets_path"] == AssetsPath(
absolute / "assets_dir" / relative
)

0 comments on commit 2e683fd

Please sign in to comment.