Skip to content

Commit

Permalink
Support discriminators in array items (#1458)
Browse files Browse the repository at this point in the history
* setup usecase

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Support discriminator in array

* Support discriminator in array

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Koudai Aono <[email protected]>
  • Loading branch information
3 people authored Nov 24, 2023
1 parent a31148d commit 5ccc441
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 2 deletions.
7 changes: 6 additions & 1 deletion datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,12 @@ def __collapse_root_models(
model_field.constraints = ConstraintsBase.merge_constraints(
root_type_field.constraints, model_field.constraints
)

if isinstance(
root_type_field, pydantic_model.DataModelField
) and not model_field.extras.get('discriminator'): # no: pragma
discriminator = root_type_field.extras.get('discriminator')
if discriminator: # no: pragma
model_field.extras['discriminator'] = discriminator
data_type.parent.data_types.remove(data_type)
data_type.parent.data_types.append(copied_data_type)

Expand Down
8 changes: 7 additions & 1 deletion datamodel_code_generator/parser/jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,13 @@ def parse_item(
return self.parse_array_fields(
name, item, get_special_path('array', path)
).data_type
elif (
item.discriminator
and parent
and parent.is_array
and (item.oneOf or item.anyOf)
):
return self.parse_root_type(name, item, path)
elif item.anyOf:
return self.data_type(
data_types=self.parse_any_of(
Expand Down Expand Up @@ -1201,7 +1208,6 @@ def parse_array_fields(
data_types.append(
self.parse_enum(name, obj, get_special_path('enum', path))
)

return self.data_model_field_type(
data_type=self.data_type(data_types=data_types),
default=obj.default,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# generated by datamodel-codegen:
# filename: discriminator_in_array.yaml
# timestamp: 2023-07-27T00:00:00+00:00

from __future__ import annotations

from enum import Enum
from typing import List, Optional, Union

from pydantic import BaseModel, Field
from typing_extensions import Literal


class Type(Enum):
my_first_object = 'my_first_object'
my_second_object = 'my_second_object'


class ObjectBase(BaseModel):
name: Optional[str] = Field(None, description='Name of the object')
type: Literal['type1'] = Field(..., description='Object type')


class CreateObjectRequest(ObjectBase):
name: str = Field(..., description='Name of the object')
type: Literal['type2'] = Field(..., description='Object type')


class UpdateObjectRequest(ObjectBase):
type: Literal['type3']


class MyArray(BaseModel):
__root__: Union[ObjectBase, CreateObjectRequest, UpdateObjectRequest] = Field(
..., discriminator='type'
)


class Demo(BaseModel):
myArray: List[MyArray]
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# generated by datamodel-codegen:
# filename: discriminator_in_array.yaml
# timestamp: 2023-07-27T00:00:00+00:00

from __future__ import annotations

from enum import Enum
from typing import List, Optional, Union

from pydantic import BaseModel, Field
from typing_extensions import Literal


class Type(Enum):
my_first_object = 'my_first_object'
my_second_object = 'my_second_object'


class ObjectBase(BaseModel):
name: Optional[str] = Field(None, description='Name of the object')
type: Literal['type1'] = Field(..., description='Object type')


class CreateObjectRequest(ObjectBase):
name: str = Field(..., description='Name of the object')
type: Literal['type2'] = Field(..., description='Object type')


class UpdateObjectRequest(ObjectBase):
type: Literal['type3']


class Demo(BaseModel):
myArray: List[Union[ObjectBase, CreateObjectRequest, UpdateObjectRequest]] = Field(
..., discriminator='type'
)
48 changes: 48 additions & 0 deletions tests/data/openapi/discriminator_in_array_anyof.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
openapi: "3.0.0"
components:
schemas:
ObjectBase:
description: Object schema
type: object
properties:
name:
description: Name of the object
type: string
type:
description: Object type
type: string
enum:
- my_first_object
- my_second_object
CreateObjectRequest:
description: Request schema for object creation
type: object
allOf:
- $ref: '#/components/schemas/ObjectBase'
required:
- name
- type
UpdateObjectRequest:
description: Request schema for object updates
type: object
allOf:
- $ref: '#/components/schemas/ObjectBase'
Demo:
type: object
required:
- myArray
properties:
myArray:
type: array
items:
oneOf:
- $ref: "#/components/schemas/ObjectBase"
- $ref: "#/components/schemas/CreateObjectRequest"
- $ref: "#/components/schemas/UpdateObjectRequest"
discriminator:
propertyName: type
mapping:
type1: "#/components/schemas/ObjectBase"
type2: "#/components/schemas/CreateObjectRequest"
type3: "#/components/schemas/UpdateObjectRequest"

48 changes: 48 additions & 0 deletions tests/data/openapi/discriminator_in_array_oneof.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
openapi: "3.0.0"
components:
schemas:
ObjectBase:
description: Object schema
type: object
properties:
name:
description: Name of the object
type: string
type:
description: Object type
type: string
enum:
- my_first_object
- my_second_object
CreateObjectRequest:
description: Request schema for object creation
type: object
allOf:
- $ref: '#/components/schemas/ObjectBase'
required:
- name
- type
UpdateObjectRequest:
description: Request schema for object updates
type: object
allOf:
- $ref: '#/components/schemas/ObjectBase'
Demo:
type: object
required:
- myArray
properties:
myArray:
type: array
items:
anyOf:
- $ref: "#/components/schemas/ObjectBase"
- $ref: "#/components/schemas/CreateObjectRequest"
- $ref: "#/components/schemas/UpdateObjectRequest"
discriminator:
propertyName: type
mapping:
type1: "#/components/schemas/ObjectBase"
type2: "#/components/schemas/CreateObjectRequest"
type3: "#/components/schemas/UpdateObjectRequest"

43 changes: 43 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5038,6 +5038,49 @@ def test_main_openapi_discriminator(input, output):
)


@freeze_time('2023-07-27')
@pytest.mark.parametrize(
'kind,option, expected',
[
(
'anyOf',
'--collapse-root-models',
'main_openapi_discriminator_in_array_collapse_root_models',
),
(
'oneOf',
'--collapse-root-models',
'main_openapi_discriminator_in_array_collapse_root_models',
),
('anyOf', None, 'main_openapi_discriminator_in_array'),
('oneOf', None, 'main_openapi_discriminator_in_array'),
],
)
def test_main_openapi_discriminator_in_array(kind, option, expected):
with TemporaryDirectory() as output_dir:
output_file: Path = Path(output_dir) / 'output.py'
input_file = f'discriminator_in_array_{kind.lower()}.yaml'
return_code: Exit = main(
[
a
for a in [
'--input',
str(OPEN_API_DATA_PATH / input_file),
'--output',
str(output_file),
'--input-file-type',
'openapi',
option,
]
if a
]
)
assert return_code == Exit.OK
assert output_file.read_text() == (
EXPECTED_MAIN_PATH / expected / 'output.py'
).read_text().replace('discriminator_in_array.yaml', input_file)


@freeze_time('2019-07-26')
def test_main_jsonschema_pattern_properties_by_reference():
with TemporaryDirectory() as output_dir:
Expand Down

0 comments on commit 5ccc441

Please sign in to comment.