Skip to content

Commit

Permalink
Search for discriminator children through multiple levels of allOf ne…
Browse files Browse the repository at this point in the history
…sting (#1460)
  • Loading branch information
krishanjmistry authored Nov 17, 2023
1 parent 6671dd3 commit 0d09ed3
Show file tree
Hide file tree
Showing 36 changed files with 1,322 additions and 21 deletions.
247 changes: 240 additions & 7 deletions services/autorust/codegen/src/codegen_models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ use proc_macro2::{Ident, TokenStream};
use quote::{quote, ToTokens};
use serde_json::Value;
use spec::{get_schema_schema_references, openapi, RefKey};
use std::collections::{HashMap, HashSet};
use std::{
cmp::Reverse,
collections::{BinaryHeap, HashMap, HashSet},
};

#[derive(Clone)]
pub struct PropertyGen {
Expand Down Expand Up @@ -547,6 +550,32 @@ pub struct UnionCode {
pub description: Option<String>,
}

#[derive(Debug)]
struct Depth<T> {
inner: T,
depth: usize,
}

impl<T> Ord for Depth<T> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.depth.cmp(&other.depth)
}
}

impl<T> PartialOrd for Depth<T> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.depth.cmp(&other.depth))
}
}

impl<T> Eq for Depth<T> {}

impl<T> PartialEq for Depth<T> {
fn eq(&self, other: &Self) -> bool {
self.depth == other.depth
}
}

impl UnionCode {
fn from_schema(
cg: &CodeGen,
Expand All @@ -558,12 +587,8 @@ impl UnionCode {
) -> Result<Self> {
let mut values = Vec::new();
for (child_ref_key, child_schema) in all_schemas {
if child_schema
.all_of()
.iter()
.any(|all_of_schema| all_of_schema.ref_key.as_ref() == Some(ref_key))
{
if let Some(tag) = child_schema.discriminator_value() {
if let Some(tag) = child_schema.discriminator_value() {
if Self::breadth_first_search_all_of(ref_key, child_schema) {
let name = tag.to_camel_case_ident()?;
let mut type_name = TypeNameCode::from(child_ref_key.name.to_camel_case_ident()?);
cg.set_if_union_type(&mut type_name);
Expand All @@ -584,6 +609,37 @@ impl UnionCode {
description,
})
}

/// Performs a BFS through multiple layers of allOf properties on a provided start schema
fn breadth_first_search_all_of(search_for_ref_key: &RefKey, start_schema: &SchemaGen) -> bool {
let mut heap = BinaryHeap::new();
Self::populate_heap(&mut heap, start_schema, 0);
while !heap.is_empty() {
let Reverse(Depth { inner: schema, .. }) = heap.pop().unwrap();
if schema.ref_key.as_ref() == Some(search_for_ref_key) {
// we have found an all of schema that matches the ref key we are searching for
return true;
}
if schema.discriminator().is_some() {
// if there is another discriminator defined, we can stop searching as the start schema would be a child of this discriminator instead
break;
}
}
false
}

/// Populate a binary heap with all allOf schemas from a provided start schema
fn populate_heap(heap: &mut BinaryHeap<Reverse<Depth<SchemaGen>>>, schema: &SchemaGen, depth: usize) {
if depth != 0 {
heap.push(Reverse(Depth {
inner: schema.clone(),
depth,
}))
};
for referenced_schema in schema.all_of().iter() {
Self::populate_heap(heap, referenced_schema, depth + 1);
}
}
}

impl ToTokens for UnionCode {
Expand Down Expand Up @@ -1419,3 +1475,180 @@ fn create_struct_field_code(
}
}
}

#[cfg(test)]
mod union_code_tests {
use super::*;

/// Helper to create a [RefKey] for testing
fn create_ref_key(name: &str) -> RefKey {
RefKey {
file_path: Utf8PathBuf::from(name),
name: name.to_string(),
}
}

/// Helper to create a [SchemaGen] for testing from a [RefKey]
fn create_schemagen(ref_key: RefKey) -> SchemaGen {
let schema = Schema::default();
SchemaGen::new(Some(ref_key.clone()), schema, Utf8PathBuf::from(ref_key.name))
}

/// Helper to create a [SchemaGen] and a [RefKey] for testing
fn create_schema(name: &str) -> (RefKey, SchemaGen) {
let ref_key = create_ref_key(name);
(ref_key.clone(), create_schemagen(ref_key))
}

const SCHEMA_1A: &str = "schema_1a";
const SCHEMA_1B: &str = "schema_1b";
const SCHEMA_2A: &str = "schema_2a";
const SCHEMA_2B: &str = "schema_2b";
const SCHEMA_2C: &str = "schema_2c";
const SCHEMA_3A: &str = "schema_3a";
const SCHEMA_3B: &str = "schema_3b";
const SCHEMA_3C: &str = "schema_3c";

/// Helper function to setup a scenario to test search functions in [UnionCode]
///
/// level 1:
/// - A: is discriminator
/// - B: nothing special
///
/// level 2:
/// - A: all of over 1A, has discriminator value
/// - B: all of over 1A & 1B, is discriminator, has discriminator value
/// - C: all of over 1B
///
/// level 3:
/// - A: all of over 2A, has discriminator value
/// - B: all of over 2B, has discriminator value
/// - C: all of over 2C
fn setup_scenario() -> HashMap<&'static str, (RefKey, SchemaGen)> {
let mut schema_1a = create_schema(SCHEMA_1A);
let schema_1b = create_schema(SCHEMA_1B);
let mut schema_2a = create_schema(SCHEMA_2A);
let mut schema_2b = create_schema(SCHEMA_2B);
let mut schema_2c = create_schema(SCHEMA_2C);
let mut schema_3a = create_schema(SCHEMA_3A);
let mut schema_3b = create_schema(SCHEMA_3B);
let mut schema_3c = create_schema(SCHEMA_3C);

schema_1a.1.schema.discriminator = Some("schema_1a_discriminator".to_string());

schema_2a.1.all_of = vec![schema_1a.1.clone()];
schema_2a.1.schema.x_ms_discriminator_value = Some("schema_2a_discriminator_value".to_string());

schema_2b.1.all_of = vec![schema_1a.1.clone(), schema_1b.1.clone()];
schema_2b.1.schema.discriminator = Some("schema_2b_discriminator".to_string());
schema_2b.1.schema.x_ms_discriminator_value = Some("schema_2b_discriminator_value".to_string());

schema_2c.1.all_of = vec![schema_1b.1.clone()];

schema_3a.1.all_of = vec![schema_2a.1.clone()];
schema_3a.1.schema.x_ms_discriminator_value = Some("schema_3a_discriminator_value".to_string());
schema_3b.1.all_of = vec![schema_2b.1.clone()];
schema_3b.1.schema.x_ms_discriminator_value = Some("schema_3b_discriminator_value".to_string());
schema_3c.1.all_of = vec![schema_2c.1.clone()];

let mut schemas = HashMap::new();
schemas.insert(SCHEMA_1A, schema_1a);
schemas.insert(SCHEMA_1B, schema_1b);
schemas.insert(SCHEMA_2A, schema_2a);
schemas.insert(SCHEMA_2B, schema_2b);
schemas.insert(SCHEMA_2C, schema_2c);
schemas.insert(SCHEMA_3A, schema_3a);
schemas.insert(SCHEMA_3B, schema_3b);
schemas.insert(SCHEMA_3C, schema_3c);
schemas
}

#[test]
fn test_breadth_first_search_all_of() {
let schemas = setup_scenario();

// Test case 1: Searching for (A) with start schema (A), there are no allOf properties
assert_eq!(
UnionCode::breadth_first_search_all_of(&create_ref_key(SCHEMA_1A), &schemas.get(SCHEMA_1A).unwrap().1),
false
);

// Test case 2: Start schema (A) has allOf properties which includes search value (B)
assert_eq!(
UnionCode::breadth_first_search_all_of(&create_ref_key(SCHEMA_1A), &schemas.get(SCHEMA_2A).unwrap().1),
true
);

// Test case 3: Start schema (A) has allOf properties which includes search value (B), but itself is a discriminator
assert_eq!(
UnionCode::breadth_first_search_all_of(&create_ref_key(SCHEMA_1A), &schemas.get(SCHEMA_2B).unwrap().1),
true
);

// Test case 4: Start schema (A) has allOf properties, where one of those (B) contains a reference to what we're searching for (C)
assert_eq!(
UnionCode::breadth_first_search_all_of(&create_ref_key(SCHEMA_1A), &schemas.get(SCHEMA_3A).unwrap().1),
true
);

// Test case 5: Start schema (A) has allOf properties, where one of those (B) contains a reference to what we're searching for (C), but (B) is a discriminator
// If we search for (B) instead, we should find it on (A)
assert_eq!(
UnionCode::breadth_first_search_all_of(&create_ref_key(SCHEMA_1A), &schemas.get(SCHEMA_3B).unwrap().1),
false
);
assert_eq!(
UnionCode::breadth_first_search_all_of(&create_ref_key(SCHEMA_2B), &schemas.get(SCHEMA_3B).unwrap().1),
true
);
}

#[test]
fn populate_heap_on_schema_with_no_all_of() {
let schemas = setup_scenario();
let schema = schemas.get(SCHEMA_1A).unwrap().1.clone();
let mut heap = BinaryHeap::new();

UnionCode::populate_heap(&mut heap, &schema, 0);
assert_eq!(heap.len(), 0);
}

#[test]
fn populate_heap_on_schema_with_single_all_of() {
let schemas = setup_scenario();
let schema = schemas.get(SCHEMA_2A).unwrap().1.clone();
let mut heap = BinaryHeap::new();

UnionCode::populate_heap(&mut heap, &schema, 0);
// This should include 1A
assert_eq!(heap.len(), 1);
assert_eq!(heap.pop().unwrap().0.depth, 1);
}

#[test]
fn populate_heap_on_schema_with_multiple_all_of() {
let schemas = setup_scenario();
let schema = schemas.get(SCHEMA_2B).unwrap().1.clone();
let mut heap = BinaryHeap::new();

UnionCode::populate_heap(&mut heap, &schema, 0);
// This should include 1A, 1B
assert_eq!(heap.len(), 2);
assert_eq!(heap.pop().unwrap().0.depth, 1);
assert_eq!(heap.pop().unwrap().0.depth, 1);
}

#[test]
fn populate_heap_on_schema_with_nested_all_of() {
let schemas = setup_scenario();
let schema = schemas.get(SCHEMA_3B).unwrap().1.clone();
let mut heap = BinaryHeap::new();

UnionCode::populate_heap(&mut heap, &schema, 0);
// This should include 2B, 1A, 1B
assert_eq!(heap.len(), 3);
assert_eq!(heap.pop().unwrap().0.depth, 1);
assert_eq!(heap.pop().unwrap().0.depth, 2);
assert_eq!(heap.pop().unwrap().0.depth, 2);
}
}
4 changes: 2 additions & 2 deletions services/autorust/openapi/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub enum AdditionalProperties {
/// common fields in both Schema Object & Parameter Object
/// https://github.com/OAI/OpenAPI-Specification/blob/master/versions/2.0.md#schemaObject
/// https://github.com/OAI/OpenAPI-Specification/blob/master/versions/2.0.md#parameter-object
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct SchemaCommon {
#[serde(skip_serializing_if = "Option::is_none")]
Expand Down Expand Up @@ -137,7 +137,7 @@ pub struct SchemaCommon {
}

/// https://github.com/OAI/OpenAPI-Specification/blob/master/versions/2.0.md#schemaObject
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Default)]
#[serde(rename_all = "camelCase")]
pub struct Schema {
#[serde(flatten)]
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 0d09ed3

Please sign in to comment.