diff --git a/ingestion/src/metadata/parsers/avro_parser.py b/ingestion/src/metadata/parsers/avro_parser.py index 730e780fe5ff..c06dd57811fb 100644 --- a/ingestion/src/metadata/parsers/avro_parser.py +++ b/ingestion/src/metadata/parsers/avro_parser.py @@ -37,7 +37,9 @@ def _parse_array_children( return f"ARRAY<{display_type}>", children if isinstance(arr_item, UnionSchema): - display_type, children = _parse_union_children(arr_item, cls=cls) + display_type, children = _parse_union_children( + parent=None, union_field=arr_item, cls=cls + ) return f"UNION<{display_type}>", children if isinstance(arr_item, RecordSchema): @@ -104,7 +106,7 @@ def parse_array_fields( def _parse_union_children( - union_field: UnionSchema, cls: ModelMetaclass = FieldModel + parent: Optional[Schema], union_field: UnionSchema, cls: ModelMetaclass = FieldModel ) -> Tuple[str, Optional[Union[FieldModel, Column]]]: non_null_schema = [ (i, schema) @@ -122,11 +124,12 @@ def _parse_union_children( sub_type[non_null_schema[0][0] ^ 1] = "null" return ",".join(sub_type), children + # if the child is a recursive instance of parent we will only process it once if isinstance(field, RecordSchema): children = cls( name=field.name, dataType=str(field.type).upper(), - children=get_avro_fields(field, cls), + children=None if field == parent else get_avro_fields(field, cls), description=field.doc, ) return sub_type, children @@ -155,7 +158,9 @@ def parse_record_fields(field: RecordSchema, cls: ModelMetaclass = FieldModel): def parse_union_fields( - union_field: Schema, cls: ModelMetaclass = FieldModel + parent: Optional[Schema], + union_field: Schema, + cls: ModelMetaclass = FieldModel, ) -> Optional[List[Union[FieldModel, Column]]]: """ Parse union field for avro schema @@ -194,7 +199,9 @@ def parse_union_fields( dataType=str(field_type.type).upper(), description=union_field.doc, ) - sub_type, children = _parse_union_children(field_type, cls) + sub_type, children = _parse_union_children( + union_field=field_type, cls=cls, parent=parent + ) obj.dataTypeDisplay = f"UNION<{sub_type}>" if children and cls == FieldModel: obj.children = [children] @@ -252,7 +259,9 @@ def get_avro_fields( if isinstance(field.type, ArraySchema): field_models.append(parse_array_fields(field, cls=cls)) elif isinstance(field.type, UnionSchema): - field_models.append(parse_union_fields(field, cls=cls)) + field_models.append( + parse_union_fields(union_field=field, cls=cls, parent=parsed_schema) + ) elif isinstance(field.type, RecordSchema): field_models.append(parse_record_fields(field, cls=cls)) else: diff --git a/ingestion/tests/unit/test_avro_parser.py b/ingestion/tests/unit/test_avro_parser.py index e7d3d9ee48c5..455dd8219804 100644 --- a/ingestion/tests/unit/test_avro_parser.py +++ b/ingestion/tests/unit/test_avro_parser.py @@ -76,6 +76,65 @@ } """ +RECURSIVE_AVRO_SCHEMA = """ +{ + "name":"MainRecord", + "type":"record", + "fields":[ + { + "default":"None", + "name":"NestedRecord", + "type":[ + "null", + { + "fields":[ + { + "default":"None", + "name":"FieldA", + "type":[ + "null", + { + "items":{ + "fields":[ + { + "name":"FieldAA", + "type":"string" + }, + { + "default":"None", + "name":"FieldBB", + "type":[ + "null", + "string" + ] + }, + { + "default":"None", + "name":"FieldCC", + "type":[ + "null", + "RecursionIssueRecord" + ] + } + ], + "name":"RecursionIssueRecord", + "type":"record" + }, + "type":"array" + } + ] + } + ], + "name":"FieldInNestedRecord", + "type":"record" + } + ] + } + ] +} +""" + + ARRAY_OF_STR = """ { "type": "record", @@ -647,3 +706,48 @@ def test_nested_record_parsing(self): parsed_record_schema[0].children[2].children[0].children[1].dataType.name, "ARRAY", ) + + def test_recursive_record_parsing(self): + parsed_recursive_schema = parse_avro_schema(RECURSIVE_AVRO_SCHEMA) + + # test that the recursive schema stops processing after 1st occurrence + self.assertEqual( + parsed_recursive_schema[0] + .children[0] + .children[0] + .children[0] + .children[0] + .name.__root__, + "RecursionIssueRecord", + ) + self.assertEqual( + parsed_recursive_schema[0] + .children[0] + .children[0] + .children[0] + .children[0] + .children[2] + .name.__root__, + "FieldCC", + ) + self.assertEqual( + parsed_recursive_schema[0] + .children[0] + .children[0] + .children[0] + .children[0] + .children[2] + .children[0] + .name.__root__, + "RecursionIssueRecord", + ) + self.assertIsNone( + parsed_recursive_schema[0] + .children[0] + .children[0] + .children[0] + .children[0] + .children[2] + .children[0] + .children + )