|
7 | 7 | from enum import Enum, auto
|
8 | 8 | from typing import TYPE_CHECKING, Type, Union
|
9 | 9 |
|
| 10 | +from django.contrib.postgres.fields import ( |
| 11 | + BigIntegerRangeField, |
| 12 | + DateRangeField, |
| 13 | + DateTimeRangeField, |
| 14 | + DecimalRangeField, |
| 15 | + IntegerRangeField, |
| 16 | + RangeOperators, |
| 17 | +) |
10 | 18 | from django.core.exceptions import FieldDoesNotExist
|
11 | 19 | from django.db.backends.postgresql.psycopg_any import (
|
12 | 20 | DateRange,
|
|
26 | 34 | from django_segments.models import AbstractSegment, AbstractSpan
|
27 | 35 |
|
28 | 36 |
|
| 37 | +def get_allowed_postgres_range_field_type_names() -> list[str]: |
| 38 | + """Get the names of all allowed PostgreSQL range field types.""" |
| 39 | + return [type.__name__ for type in POSTGRES_RANGE_FIELDS.keys()] |
| 40 | + |
| 41 | + |
| 42 | +def get_allowed_postgres_range_field_types() -> list[str]: |
| 43 | + """Get the allowed PostgreSQL range field types.""" |
| 44 | + return list(POSTGRES_RANGE_FIELDS.keys()) |
| 45 | + |
| 46 | + |
29 | 47 | class BoundaryType(Enum): # pylint: disable=C0115
|
30 | 48 | LOWER = auto()
|
31 | 49 | UPPER = auto()
|
32 | 50 |
|
33 | 51 |
|
34 | 52 | class BaseHelper: # pylint: disable=R0903
|
35 |
| - """Base class for all segment and span helpers.""" |
| 53 | + """Base class for all segment and span helpers. |
| 54 | +
|
| 55 | + Provides common methods and attributes for all segment and span helpers. It should not be instantiated directly. |
| 56 | + """ |
36 | 57 |
|
37 | 58 | def __init__(self, obj: Union[AbstractSpan, AbstractSegment]):
|
38 | 59 | self.obj = obj
|
39 |
| - self.range_field_type = None |
40 |
| - self.field_value_type = None |
41 |
| - self._initialize_range_field() |
| 60 | + self.range_field_type = obj.range_field_type |
| 61 | + self.validate_range_field_type() |
| 62 | + |
| 63 | + self.value_type = self._get_value_type(self.range_field_type) |
| 64 | + self.delta_value_type = self._get_delta_value_type(self.range_field_type) |
| 65 | + self.range_type = self._get_range_type(self.range_field_type) |
42 | 66 |
|
43 |
| - def _initialize_range_field(self) -> None: |
| 67 | + self.range_field_type_name = "" |
| 68 | + self.field_value_type_name = "" |
| 69 | + self._initialize_type_names() |
| 70 | + |
| 71 | + def _initialize_type_names(self) -> None: |
44 | 72 | """Initialize the range field type and value type."""
|
45 | 73 | for field_name in ["current_range", "segment_range"]:
|
46 | 74 | if hasattr(self.obj, field_name):
|
47 | 75 | range_value = getattr(self.obj, field_name)
|
48 | 76 | range_field = self._get_range_field(field_name)
|
49 | 77 | if range_field:
|
50 |
| - self.range_field_type = range_field.get_internal_type() |
51 |
| - self.field_value_type = type(range_value).__name__ |
| 78 | + self.range_field_type_name = range_field.get_internal_type() |
| 79 | + self.field_value_type_name = type(range_value).__name__ |
52 | 80 | return
|
53 | 81 | raise ValueError("Object must have either a `segment_range` or `current_range` field.")
|
54 | 82 |
|
55 |
| - def _get_range_field(self, field_name: str) -> Type: |
| 83 | + def _get_range_field( |
| 84 | + self, field_name: str |
| 85 | + ) -> Union[IntegerRangeField, BigIntegerRangeField, DecimalRangeField, DateRangeField, DateTimeRangeField]: |
56 | 86 | """Get the range field from the model."""
|
57 | 87 | try:
|
58 | 88 | return self.obj._meta.get_field(field_name) # pylint: disable=W0212
|
59 | 89 | except FieldDoesNotExist as e:
|
60 | 90 | logger.error("FieldDoesNotExist error: %s", e)
|
61 | 91 | return None
|
62 | 92 |
|
| 93 | + def validate_range_field_type(self) -> None: |
| 94 | + """Validate that the range field type is allowed.""" |
| 95 | + if self.range_field_type not in POSTGRES_RANGE_FIELDS: |
| 96 | + raise ValueError( |
| 97 | + f"Unsupported field type for `segment_range` field: " |
| 98 | + f"{self.range_field_type=} not in {POSTGRES_RANGE_FIELDS=}" |
| 99 | + ) |
| 100 | + |
63 | 101 | def validate_value_type(self, value: Union[int, Decimal, date, datetime]) -> None:
|
64 | 102 | """Validate the type of the provided value against the model's range_field_type."""
|
65 | 103 | if value is None:
|
66 | 104 | raise ValueError("Value cannot be None")
|
67 | 105 |
|
68 |
| - if self.range_field_type not in POSTGRES_RANGE_FIELDS: |
| 106 | + expected_value_type = self._get_value_type(self.range_field_type) |
| 107 | + if not isinstance(value, expected_value_type): |
69 | 108 | raise ValueError(
|
70 |
| - f"Unsupported field type for `segment_range` field: " |
71 |
| - f"{self.range_field_type=} not in {POSTGRES_RANGE_FIELDS.keys()=}" |
| 109 | + f"BaseHelper.validate_value_type(): Value must be of type {expected_value_type.__name__}, " |
| 110 | + f"not {type(value).__name__}. Provided value: {value}." |
72 | 111 | )
|
73 | 112 |
|
74 |
| - expected_type = self._get_expected_type(self.range_field_type) |
75 |
| - if not isinstance(value, expected_type): |
| 113 | + def validate_delta_value_type(self, delta_value: Union[int, Decimal, timezone.timedelta]) -> None: |
| 114 | + """Validate the type of the provided delta value against the model's range_field_type.""" |
| 115 | + if delta_value is None: |
| 116 | + raise ValueError("Delta value cannot be None") |
| 117 | + |
| 118 | + expected_delta_value_type = self._get_delta_value_type(self.range_field_type) |
| 119 | + if not isinstance(delta_value, expected_delta_value_type): |
76 | 120 | raise ValueError(
|
77 |
| - f"BaseHelper.validate_value_type(): Value must be of type {expected_type.__name__}, " |
78 |
| - f"not {type(value).__name__}. Provided value: {value}." |
| 121 | + "BaseHelper.validate_delta_value_type(): Delta value must be of type " |
| 122 | + f"{expected_delta_value_type.__name__}, " |
| 123 | + f"not {type(delta_value).__name__}. Provided delta value: {delta_value}." |
79 | 124 | )
|
80 | 125 |
|
81 | 126 | @staticmethod
|
82 |
| - def _get_expected_type(range_field_type: str) -> Type: |
| 127 | + def _get_value_type( |
| 128 | + range_field_type: get_allowed_postgres_range_field_types(), |
| 129 | + ) -> Union[type[int], type[Decimal], type[date], type[datetime]]: |
83 | 130 | """Get the expected type for a given range field type."""
|
84 | 131 | for key, val in POSTGRES_RANGE_FIELDS.items():
|
85 |
| - if key in range_field_type: |
86 |
| - return val.get("type") |
87 |
| - raise ValueError(f"No expected type found for range field type: {range_field_type}") |
| 132 | + if key is range_field_type: |
| 133 | + return val.get("value_type") |
| 134 | + raise ValueError(f"No value type found for range field type: {range_field_type}") |
| 135 | + |
| 136 | + @staticmethod |
| 137 | + def _get_delta_value_type( |
| 138 | + range_field_type: get_allowed_postgres_range_field_types(), |
| 139 | + ) -> Union[type[int], type[Decimal], type[timezone.timedelta]]: |
| 140 | + """Get the expected type for a given range field type.""" |
| 141 | + for key, val in POSTGRES_RANGE_FIELDS.items(): |
| 142 | + if key is range_field_type: |
| 143 | + return val.get("delta_type") |
| 144 | + raise ValueError(f"No delta type found for range field type: {range_field_type}") |
| 145 | + |
| 146 | + @staticmethod |
| 147 | + def _get_range_type(range_field_type: get_allowed_postgres_range_field_types()) -> Type[Range]: |
| 148 | + """Get the range type from the range field type.""" |
| 149 | + for key, val in POSTGRES_RANGE_FIELDS.items(): |
| 150 | + if key is range_field_type: |
| 151 | + print(f"_get_range_type {val=} {val.get('range_type')=}") |
| 152 | + return val.get("range_type") |
| 153 | + raise ValueError(f"No range type found for range field type: {range_field_type}") |
88 | 154 |
|
89 | 155 | def set_boundary(
|
90 |
| - self, range_field: Range, new_boundary: Union[int, Decimal, datetime, date], boundary_type: BoundaryType |
| 156 | + self, *, range_field: Range, new_boundary: Union[int, Decimal, datetime, date], boundary_type: BoundaryType |
91 | 157 | ) -> Range:
|
92 |
| - """Set the boundary of the range field.""" |
| 158 | + """Set the boundary of the model range field.""" |
93 | 159 | return range_field.__class__(
|
94 | 160 | lower=new_boundary if boundary_type == BoundaryType.LOWER else range_field.lower,
|
95 | 161 | upper=new_boundary if boundary_type == BoundaryType.UPPER else range_field.upper,
|
96 | 162 | )
|
97 |
| - |
98 |
| - def validate_range( |
99 |
| - self, |
100 |
| - range_value: Union[Range, DateRange, DateTimeTZRange, NumericRange], |
101 |
| - lower_bound: Union[int, Decimal, datetime, date], |
102 |
| - upper_bound: Union[int, Decimal, datetime, date], |
103 |
| - ) -> None: |
104 |
| - """Validate that the range is within the specified bounds.""" |
105 |
| - if range_value.lower < lower_bound or range_value.upper > upper_bound: |
106 |
| - raise ValueError("Range must be within the specified bounds.") |
|
0 commit comments