Skip to content

Commit 7852bc8

Browse files
committed
OpenAPI gen support for all kinds of enums with(out) discriminators
OpenAPI gen support for default values, optional and transient fields
1 parent 6e00a81 commit 7852bc8

File tree

3 files changed

+1058
-39
lines changed

3 files changed

+1058
-39
lines changed

zio-http/src/main/scala/zio/http/endpoint/openapi/JsonSchema.scala

+231-32
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package zio.http.endpoint.openapi
33
import zio._
44
import zio.json.ast.Json
55

6+
import zio.schema.Schema.CaseClass0
67
import zio.schema._
78
import zio.schema.annotation._
89
import zio.schema.codec._
@@ -30,6 +31,7 @@ private[openapi] case class SerializableJsonSchema(
3031
deprecated: Option[Boolean] = None,
3132
contentEncoding: Option[String] = None,
3233
contentMediaType: Option[String] = None,
34+
default: Option[Json] = None,
3335
) {
3436
def asNullableType(nullable: Boolean): SerializableJsonSchema =
3537
if (nullable && schemaType.isDefined)
@@ -104,6 +106,12 @@ private[openapi] object TypeOrTypes {
104106
}
105107
}
106108

109+
final case class JsonSchemas(
110+
root: JsonSchema,
111+
rootRef: Option[String],
112+
children: Map[String, JsonSchema],
113+
)
114+
107115
sealed trait JsonSchema extends Product with Serializable { self =>
108116

109117
lazy val toJsonBytes: Chunk[Byte] = JsonCodec.schemaBasedBinaryCodec[JsonSchema].encode(self)
@@ -129,6 +137,15 @@ sealed trait JsonSchema extends Product with Serializable { self =>
129137
def examples(examples: Chunk[Json]): JsonSchema =
130138
JsonSchema.AnnotatedSchema(self, JsonSchema.MetaData.Examples(examples))
131139

140+
def default(default: Option[Json]): JsonSchema =
141+
default match {
142+
case Some(value) => JsonSchema.AnnotatedSchema(self, JsonSchema.MetaData.Default(value))
143+
case None => self
144+
}
145+
146+
def default(default: Json): JsonSchema =
147+
JsonSchema.AnnotatedSchema(self, JsonSchema.MetaData.Default(default))
148+
132149
def description(description: String): JsonSchema =
133150
JsonSchema.AnnotatedSchema(self, JsonSchema.MetaData.Description(description))
134151

@@ -173,6 +190,13 @@ object JsonSchema {
173190

174191
private[openapi] val codec = JsonCodec.schemaBasedBinaryCodec[JsonSchema]
175192

193+
private def toJsonAst(schema: Schema[_], v: Any): Json =
194+
JsonCodec
195+
.jsonEncoder(schema.asInstanceOf[Schema[Any]])
196+
.toJsonAST(v)
197+
.toOption
198+
.get
199+
176200
private def fromSerializableSchema(schema: SerializableJsonSchema): JsonSchema = {
177201
val additionalProperties = schema.additionalProperties match {
178202
case Some(BoolOrSchema.BooleanWrapper(false)) => Left(false)
@@ -244,6 +268,8 @@ object JsonSchema {
244268
case None => ()
245269
}
246270

271+
jsonSchema = jsonSchema.default(schema.default)
272+
247273
jsonSchema = jsonSchema.deprecated(schema.deprecated.getOrElse(false))
248274

249275
jsonSchema
@@ -295,12 +321,158 @@ object JsonSchema {
295321
},
296322
)
297323

324+
def fromZSchemaMulti(schema: Schema[_], refType: SchemaStyle = SchemaStyle.Inline): JsonSchemas = {
325+
val ref = nominal(schema, refType)
326+
schema match {
327+
case enum0: Schema.Enum[_] if enum0.cases.forall(_.schema.isInstanceOf[CaseClass0[_]]) =>
328+
JsonSchemas(fromZSchema(enum0, SchemaStyle.Inline), ref, Map.empty)
329+
case enum0: Schema.Enum[_] =>
330+
JsonSchemas(
331+
fromZSchema(enum0, SchemaStyle.Inline),
332+
ref,
333+
enum0.cases
334+
.filterNot(_.annotations.exists(_.isInstanceOf[transientCase]))
335+
.flatMap { c =>
336+
val key =
337+
nominal(c.schema, refType)
338+
.orElse(nominal(c.schema, SchemaStyle.Compact))
339+
.getOrElse(throw new Exception(s"Unsupported enum case schema: ${c.schema}"))
340+
val nested = fromZSchemaMulti(
341+
c.schema,
342+
refType,
343+
)
344+
nested.children + (key -> nested.root)
345+
}
346+
.toMap,
347+
)
348+
case record: Schema.Record[_] =>
349+
val children = record.fields
350+
.filterNot(_.annotations.exists(_.isInstanceOf[transientField]))
351+
.flatMap { field =>
352+
val key = nominal(field.schema, refType).orElse(nominal(field.schema, SchemaStyle.Compact))
353+
val nested = fromZSchemaMulti(
354+
field.schema,
355+
refType,
356+
)
357+
key.map(k => nested.children + (k -> nested.root)).getOrElse(nested.children)
358+
}
359+
.toMap
360+
JsonSchemas(fromZSchema(record, SchemaStyle.Inline), ref, children)
361+
case collection: Schema.Collection[_, _] =>
362+
collection match {
363+
case Schema.Sequence(elementSchema, _, _, _, _) =>
364+
arraySchemaMulti(refType, ref, elementSchema)
365+
case Schema.Map(_, valueSchema, _) =>
366+
val nested = fromZSchemaMulti(valueSchema, refType)
367+
if (valueSchema.isInstanceOf[Schema.Primitive[_]]) {
368+
JsonSchemas(
369+
JsonSchema.Object(
370+
Map.empty,
371+
Right(nested.root),
372+
Chunk.empty,
373+
),
374+
ref,
375+
nested.children,
376+
)
377+
} else {
378+
JsonSchemas(
379+
JsonSchema.Object(
380+
Map.empty,
381+
Right(nested.root),
382+
Chunk.empty,
383+
),
384+
ref,
385+
nested.children + (nested.rootRef.get -> nested.root),
386+
)
387+
}
388+
case Schema.Set(elementSchema, _) =>
389+
arraySchemaMulti(refType, ref, elementSchema)
390+
}
391+
case Schema.Transform(schema, _, _, _, _) =>
392+
fromZSchemaMulti(schema, refType)
393+
case Schema.Primitive(_, _) =>
394+
JsonSchemas(fromZSchema(schema, SchemaStyle.Inline), ref, Map.empty)
395+
case Schema.Optional(schema, _) =>
396+
fromZSchemaMulti(schema, refType)
397+
case Schema.Fail(_, _) =>
398+
throw new IllegalArgumentException("Fail schema is not supported.")
399+
case Schema.Tuple2(left, right, _) =>
400+
val leftSchema = fromZSchemaMulti(left, refType)
401+
val rightSchema = fromZSchemaMulti(right, refType)
402+
JsonSchemas(
403+
AllOfSchema(Chunk(leftSchema.root, rightSchema.root)),
404+
ref,
405+
leftSchema.children ++ rightSchema.children,
406+
)
407+
case Schema.Either(left, right, _) =>
408+
val leftSchema = fromZSchemaMulti(left, refType)
409+
val rightSchema = fromZSchemaMulti(right, refType)
410+
JsonSchemas(
411+
OneOfSchema(Chunk(leftSchema.root, rightSchema.root)),
412+
ref,
413+
leftSchema.children ++ rightSchema.children,
414+
)
415+
case Schema.Lazy(schema0) =>
416+
fromZSchemaMulti(schema0(), refType)
417+
case Schema.Dynamic(_) =>
418+
throw new IllegalArgumentException("Dynamic schema is not supported.")
419+
}
420+
}
421+
422+
private def arraySchemaMulti(
423+
refType: SchemaStyle,
424+
ref: Option[String],
425+
elementSchema: Schema[_],
426+
): JsonSchemas = {
427+
val nested = fromZSchemaMulti(elementSchema, refType)
428+
if (elementSchema.isInstanceOf[Schema.Primitive[_]]) {
429+
JsonSchemas(
430+
JsonSchema.ArrayType(Some(nested.root)),
431+
ref,
432+
nested.children,
433+
)
434+
} else {
435+
JsonSchemas(
436+
JsonSchema.ArrayType(Some(nested.root)),
437+
ref,
438+
nested.children + (nested.rootRef.get -> nested.root),
439+
)
440+
}
441+
}
442+
298443
def fromZSchema(schema: Schema[_], refType: SchemaStyle = SchemaStyle.Inline): JsonSchema =
299444
schema match {
300445
case enum0: Schema.Enum[_] if refType != SchemaStyle.Inline && nominal(enum0).isDefined =>
301446
JsonSchema.RefSchema(nominal(enum0, refType).get)
302-
case enum0: Schema.Enum[_] =>
447+
case enum0: Schema.Enum[_] if enum0.cases.forall(_.schema.isInstanceOf[CaseClass0[_]]) =>
303448
JsonSchema.Enum(enum0.cases.map(c => EnumValue.Str(c.id)))
449+
case enum0: Schema.Enum[_] =>
450+
val noDiscriminator = enum0.annotations.exists(_.isInstanceOf[noDiscriminator])
451+
val discriminatorName0 =
452+
enum0.annotations.collectFirst { case discriminatorName(name) => name }
453+
val nonTransientCases = enum0.cases.filterNot(_.annotations.exists(_.isInstanceOf[transientCase]))
454+
if (noDiscriminator) {
455+
JsonSchema
456+
.OneOfSchema(nonTransientCases.map(c => fromZSchema(c.schema, SchemaStyle.Compact)))
457+
} else if (discriminatorName0.isDefined) {
458+
JsonSchema
459+
.OneOfSchema(nonTransientCases.map(c => fromZSchema(c.schema, SchemaStyle.Compact)))
460+
.discriminator(
461+
OpenAPI.Discriminator(
462+
propertyName = discriminatorName0.get,
463+
mapping = nonTransientCases.map { c =>
464+
val name = c.annotations.collectFirst { case caseName(name) => name }.getOrElse(c.id)
465+
name -> nominal(c.schema, refType).orElse(nominal(c.schema, SchemaStyle.Compact)).get
466+
}.toMap,
467+
),
468+
)
469+
} else {
470+
JsonSchema
471+
.OneOfSchema(nonTransientCases.map { c =>
472+
val name = c.annotations.collectFirst { case caseName(name) => name }.getOrElse(c.id)
473+
Object(Map(name -> fromZSchema(c.schema, SchemaStyle.Compact)), Left(false), Chunk(name))
474+
})
475+
}
304476
case record: Schema.Record[_] if refType != SchemaStyle.Inline && nominal(record).isDefined =>
305477
JsonSchema.RefSchema(nominal(record, refType).get)
306478
case record: Schema.Record[_] =>
@@ -310,17 +482,28 @@ object JsonSchema {
310482
} else {
311483
Left(true)
312484
}
485+
val nonTransientFields =
486+
record.fields.filterNot(_.annotations.exists(_.isInstanceOf[transientField]))
313487
JsonSchema
314488
.Object(
315489
Map.empty,
316490
additionalProperties,
317491
Chunk.empty,
318492
)
319-
.addAll(record.fields.map { field =>
493+
.addAll(nonTransientFields.map { field =>
320494
field.name ->
321-
fromZSchema(field.schema, refType).deprecated(deprecated(field.schema))
495+
fromZSchema(field.schema, refType)
496+
.deprecated(deprecated(field.schema))
497+
.description(fieldDoc(field))
498+
.default(fieldDefault(field))
322499
})
323-
.required(record.fields.filterNot(_.schema.isInstanceOf[Schema.Optional[_]]).map(_.name))
500+
.required(
501+
nonTransientFields
502+
.filterNot(_.schema.isInstanceOf[Schema.Optional[_]])
503+
.filterNot(_.annotations.exists(_.isInstanceOf[fieldDefaultValue[_]]))
504+
.filterNot(_.annotations.exists(_.isInstanceOf[optionalField]))
505+
.map(_.name),
506+
)
324507
.deprecated(deprecated(record))
325508
case collection: Schema.Collection[_, _] =>
326509
collection match {
@@ -339,34 +522,34 @@ object JsonSchema {
339522
fromZSchema(schema, refType)
340523
case Schema.Primitive(standardType, _) =>
341524
standardType match {
342-
case StandardType.UnitType => JsonSchema.Null // is this null or empty object?
343-
case StandardType.StringType => JsonSchema.String
344-
case StandardType.BoolType => JsonSchema.Boolean
345-
case StandardType.ByteType => JsonSchema.String
346-
case StandardType.ShortType => JsonSchema.Integer(IntegerFormat.Int32)
347-
case StandardType.IntType => JsonSchema.Integer(IntegerFormat.Int32)
348-
case StandardType.LongType => JsonSchema.Integer(IntegerFormat.Int64)
349-
case StandardType.FloatType => JsonSchema.Number(NumberFormat.Float)
350-
case StandardType.DoubleType => JsonSchema.Number(NumberFormat.Double)
351-
case StandardType.BinaryType => JsonSchema.String
352-
case StandardType.CharType => JsonSchema.String
353-
case StandardType.UUIDType => JsonSchema.String
354-
case StandardType.BigDecimalType => JsonSchema.Number(NumberFormat.Double) // TODO: Is this correct?
355-
case StandardType.BigIntegerType => JsonSchema.Integer(IntegerFormat.Int64)
356-
case StandardType.DayOfWeekType => JsonSchema.String
357-
case StandardType.MonthType => JsonSchema.String
358-
case StandardType.MonthDayType => JsonSchema.String
359-
case StandardType.PeriodType => JsonSchema.String
360-
case StandardType.YearType => JsonSchema.String
361-
case StandardType.YearMonthType => JsonSchema.String
362-
case StandardType.ZoneIdType => JsonSchema.String
363-
case StandardType.ZoneOffsetType => JsonSchema.String
364-
case StandardType.DurationType => JsonSchema.String
365-
case StandardType.InstantType => JsonSchema.String
366-
case StandardType.LocalDateType => JsonSchema.String
367-
case StandardType.LocalTimeType => JsonSchema.String
368-
case StandardType.LocalDateTimeType => JsonSchema.String
369-
case StandardType.OffsetTimeType => JsonSchema.String
525+
case StandardType.UnitType => JsonSchema.Null
526+
case StandardType.StringType => JsonSchema.String
527+
case StandardType.BoolType => JsonSchema.Boolean
528+
case StandardType.ByteType => JsonSchema.String
529+
case StandardType.ShortType => JsonSchema.Integer(IntegerFormat.Int32)
530+
case StandardType.IntType => JsonSchema.Integer(IntegerFormat.Int32)
531+
case StandardType.LongType => JsonSchema.Integer(IntegerFormat.Int64)
532+
case StandardType.FloatType => JsonSchema.Number(NumberFormat.Float)
533+
case StandardType.DoubleType => JsonSchema.Number(NumberFormat.Double)
534+
case StandardType.BinaryType => JsonSchema.String
535+
case StandardType.CharType => JsonSchema.String
536+
case StandardType.UUIDType => JsonSchema.String
537+
case StandardType.BigDecimalType => JsonSchema.Number(NumberFormat.Double) // TODO: Is this correct?
538+
case StandardType.BigIntegerType => JsonSchema.Integer(IntegerFormat.Int64)
539+
case StandardType.DayOfWeekType => JsonSchema.String
540+
case StandardType.MonthType => JsonSchema.String
541+
case StandardType.MonthDayType => JsonSchema.String
542+
case StandardType.PeriodType => JsonSchema.String
543+
case StandardType.YearType => JsonSchema.String
544+
case StandardType.YearMonthType => JsonSchema.String
545+
case StandardType.ZoneIdType => JsonSchema.String
546+
case StandardType.ZoneOffsetType => JsonSchema.String
547+
case StandardType.DurationType => JsonSchema.String
548+
case StandardType.InstantType => JsonSchema.String
549+
case StandardType.LocalDateType => JsonSchema.String
550+
case StandardType.LocalTimeType => JsonSchema.String
551+
case StandardType.LocalDateTimeType => JsonSchema.String
552+
case StandardType.OffsetTimeType => JsonSchema.String
370553
case StandardType.OffsetDateTimeType => JsonSchema.String
371554
case StandardType.ZonedDateTimeType => JsonSchema.String
372555
}
@@ -406,6 +589,19 @@ object JsonSchema {
406589
private def deprecated(schema: Schema[_]): Boolean =
407590
schema.annotations.exists(_.isInstanceOf[scala.deprecated])
408591

592+
private def fieldDoc(schema: Schema.Field[_, _]): Option[String] = {
593+
val description0 = schema.annotations.collectFirst { case description(value) => value }
594+
val defaultValue = schema.annotations.collectFirst { case fieldDefaultValue(value) => value }.map { _ =>
595+
s"${if (description0.isDefined) "\n" else ""}If not set, this field defaults to the value of the default annotation."
596+
}
597+
Some(description0.getOrElse("") + defaultValue.getOrElse(""))
598+
.filter(_.nonEmpty)
599+
}
600+
601+
private def fieldDefault(schema: Schema.Field[_, _]): Option[Json] =
602+
schema.annotations.collectFirst { case fieldDefaultValue(value) => value }
603+
.map(toJsonAst(schema.schema, _))
604+
409605
private def nominal(schema: Schema[_], referenceType: SchemaStyle = SchemaStyle.Reference): Option[String] =
410606
schema match {
411607
case enumSchema: Schema.Enum[_] => refForTypeId(enumSchema.id, referenceType)
@@ -447,13 +643,16 @@ object JsonSchema {
447643
schema.toSerializableSchema.copy(contentMediaType = Some(mediaType))
448644
case MetaData.Deprecated =>
449645
schema.toSerializableSchema.copy(deprecated = Some(true))
646+
case MetaData.Default(default) =>
647+
schema.toSerializableSchema.copy(default = Some(default))
450648
}
451649
}
452650
}
453651

454652
sealed trait MetaData extends Product with Serializable
455653
object MetaData {
456654
final case class Examples(chunk: Chunk[Json]) extends MetaData
655+
final case class Default(default: Json) extends MetaData
457656
final case class Discriminator(discriminator: OpenAPI.Discriminator) extends MetaData
458657
final case class Nullable(nullable: Boolean) extends MetaData
459658
final case class Description(description: String) extends MetaData

0 commit comments

Comments
 (0)