Skip to content

Commit

Permalink
Merge pull request #63 from commercetools/scala3-codegen
Browse files Browse the repository at this point in the history
Improve support for Scala 3 code generation through Scalameta dialects
  • Loading branch information
sbrunk authored Jan 24, 2025
2 parents 7a057e1 + 01e4a3a commit 7050455
Show file tree
Hide file tree
Showing 37 changed files with 393 additions and 299 deletions.
4 changes: 2 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ lazy val root = (project in file("."))
.settings(
name := "sbt-scraml",
libraryDependencies += "com.commercetools.rmf" % "raml-model" % "0.2.0-20240722205528",
libraryDependencies += "org.scalameta" %% "scalameta" % "4.12.6",
libraryDependencies += "org.scalameta" %% "scalafmt-dynamic" % "3.8.5",
libraryDependencies += "org.scalameta" %% "scalameta" % "4.12.7",
libraryDependencies += "org.scalameta" %% "scalafmt-dynamic" % "3.8.6",
libraryDependencies += "org.typelevel" %% "cats-effect" % "3.5.7",
libraryDependencies += "org.scalatest" %% "scalatest" % "3.2.19" % Test,
libraryDependencies ++= Seq(
Expand Down
2 changes: 1 addition & 1 deletion examples/build.sbt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
scalaVersion := "3.3.4"

val circeVersion = "0.14.7"
val circeVersion = "0.14.10"
val tapirVersion = "1.11.9"

lazy val examples = (project in file("."))
Expand Down
21 changes: 15 additions & 6 deletions src/main/scala/scraml/DefaultModelGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ object DefaultModelGen extends ModelGen {
extendedType = Init(Type.Name(stringType.getName), Name(""), Seq.empty)
enumInstances: List[Stat] = enumInstanceNames.map { instanceName =>
q"""
case object ${Term.Name(instanceName.toUpperCase())} extends $extendedType
case object ${Term.Name(instanceName)} extends $extendedType
"""
}.toList

Expand Down Expand Up @@ -379,11 +379,19 @@ object DefaultModelGen extends ModelGen {
file: File,
source: GeneratedSource,
formatConfig: Option[File],
dialect: Dialect,
formatter: Scalafmt
): IO[GeneratedFile] = {
val sourceString =
s"${source.comment.map(_ + "\n").getOrElse("")}${source.source
.toString()}\n${source.companion.map(_.toString() + "\n").getOrElse("")}\n"

// We're using printSyntaxFor to ensure generating valid syntax for different Scala versions
val sourceString = {
s"""
|${source.comment.map(_ + "\n").getOrElse("")}
|${source.source.printSyntaxFor(dialect)}
|${source.companion.map(_.printSyntaxFor(dialect) + "\n").getOrElse("")}
|""".stripMargin
}

val formattedSource = formatConfig match {
case Some(configFile) if configFile.exists() =>
formatter.format(configFile.toPath, file.toPath, sourceString)
Expand All @@ -407,10 +415,11 @@ object DefaultModelGen extends ModelGen {
packageFile.getParentFile.mkdirs()
packageFile
}
packageStatement = Pkg(packageTerm(s"${params.basePackage}"), Pkg.Body(Nil)).toString()
packageStatement = Pkg(packageTerm(s"${params.basePackage}"), Pkg.Body(Nil))
.printSyntaxFor(params.dialect)
fileWithPackage <- FileUtil.writeToFile(file, s"$packageStatement\n\n")
files <- generatedPackage.sources
.map(appendSource(fileWithPackage, _, params.formatConfig, scalafmt))
.map(appendSource(fileWithPackage, _, params.formatConfig, params.dialect, scalafmt))
.sequence
} yield files
}
Expand Down
29 changes: 25 additions & 4 deletions src/main/scala/scraml/ModelGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ final case class ModelGenParams(
fieldMatchPolicy: FieldMatchPolicy,
defaultTypes: DefaultTypes,
librarySupport: Set[LibrarySupport],
scalaVersion: Option[(Long, Long)] = Some((2, 12)),
scalaVersion: Option[(Long, Long)] = Some((2, 13)),
formatConfig: Option[File] = None,
generateDateCreated: Boolean = false,
logger: Option[ManagedLogger] = None,
Expand All @@ -54,6 +54,12 @@ final case class ModelGenParams(
generateDefaultEnumVariant: Option[String] = None
) {
lazy val allLibraries: List[LibrarySupport] = librarySupport.toList.sorted
def dialect: Dialect = scalaVersion match {
case Some((2, 12)) => dialects.Scala212
case Some((2, 13)) => dialects.Scala213
case Some((3, _)) => dialects.Scala3
case _ => dialects.Scala213
}
}

final case class GeneratedModel(sourceFiles: Seq[GeneratedFile], packageObject: GeneratedFile) {
Expand Down Expand Up @@ -434,8 +440,23 @@ object LibrarySupport {
lib.modifyEnum(enumType, params)(acc.defn, acc.companion)
}

def appendObjectStats(defn: Defn.Object, stats: List[Stat]): Defn.Object =
defn.copy(templ = defn.templ.copy(stats = defn.templ.body.stats ++ stats))
def appendObjectStats(defn: Defn.Object, stats: List[Stat])(implicit
dialect: Dialect
): Defn.Object = {
// copy on Defn.Object and on Template.Body loses dialect information, so we're building them manually,
// even though we currently don't rely on dialects while building the tree
Defn.Object(
mods = defn.mods,
name = defn.name,
templ = defn.templ.copy(
earlyClause = defn.templ.earlyClause,
inits = defn.templ.inits,
body = Template.Body(selfOpt = None, stats = defn.templ.body.stats ++ stats),
derives = defn.templ.derives
)
)
}

def appendPkgObjectStats(packageObject: Pkg.Object, stats: List[Stat]): Pkg.Object =
packageObject.copy(templ =
packageObject.templ.copy(stats = packageObject.templ.body.stats ++ stats)
Expand Down Expand Up @@ -471,7 +492,7 @@ object ModelGen {
def addDefaultEnum(property: StringType): TypeRefDetails = {
Option(property.getDefault).fold(this) { instance =>
val enumType = Term.Name(property.getName)
val enumInstance = Term.Name(instance.getValue.toString.toUpperCase)
val enumInstance = Term.Name(instance.getValue.toString)

copy(defaultValue = Option(q"$enumType.$enumInstance"))
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/scraml/ScramlPlugin.scala
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ object ScramlPlugin extends AutoPlugin {
CrossVersion.partialVersion(scalaVersion.value),
s.log
)

s.log.info(s"generating API model targeting Scala ${scalaVersion.value}")
val generated = ModelGenRunner.run(DefaultModelGen)(params).unsafeRunSync()

s.log.info(s"generated API model for ${definition.raml} in $targetDir")
Expand Down
8 changes: 4 additions & 4 deletions src/main/scala/scraml/libs/CirceJsonSupport.scala
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ class CirceJsonSupport(formats: Map[String, String], imports: Seq[String])
propertyType match {
case Some(stringEnum: StringType) if isEnumType(stringEnum) =>
val enumType = Term.Name(stringEnum.getName)
val enumInstance = Term.Name(stringEnum.getDefault.getValue.toString.toUpperCase)
val enumInstance = Term.Name(stringEnum.getDefault.getValue.toString)

if (isRequired)
Option(q"$enumType.$enumInstance")
Expand Down Expand Up @@ -1080,7 +1080,7 @@ class CirceJsonSupport(formats: Map[String, String], imports: Seq[String])
None,
Term.Apply(
Term.Name("Right"),
Term.ArgClause(List(Term.Name(instance.getValue.toString.toUpperCase)))
Term.ArgClause(List(Term.Name(instance.getValue.toString)))
)
)
)
Expand Down Expand Up @@ -1141,7 +1141,7 @@ class CirceJsonSupport(formats: Map[String, String], imports: Seq[String])
enumType.getEnum.asScala
.map(instance =>
Case(
Term.Name(instance.getValue.toString.toUpperCase),
Term.Name(instance.getValue.toString),
None,
Lit.String(instance.getValue.toString)
)
Expand Down Expand Up @@ -1171,6 +1171,6 @@ class CirceJsonSupport(formats: Map[String, String], imports: Seq[String])
$enumDecode
""".stats

DefnWithCompanion(enumTrait, companion.map(appendObjectStats(_, stats)))
DefnWithCompanion(enumTrait, companion.map(appendObjectStats(_, stats)(params.dialect)))
}
}
36 changes: 16 additions & 20 deletions src/main/scala/scraml/libs/RefinedSupport.scala
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,20 @@ object RefinedSupport extends LibrarySupport {
min: Option[BigDecimal],
max: Option[BigDecimal]
): List[Type.Apply] = {
def toLiteral(number: BigDecimal): Lit = number match {
case n if n.isValidInt => Lit.Int(number.toInt)
case n if n.isValidLong => Lit.Long(number.toInt)
case n => Lit.Double(n.toDouble)
}
(min, max) match {
case (Some(lower), Some(upper)) if upper < lower =>
throw new RuntimeException("invalid min/max number bounds detected")
case (Some(lower), Some(upper)) =>
predicateType("Interval.Closed", lower, upper) :: Nil
predicateType("Interval.Closed", toLiteral(lower), toLiteral(upper)) :: Nil
case (Some(lower), None) =>
predicateType("GreaterEqual", lower) :: Nil
predicateType("GreaterEqual", toLiteral(lower)) :: Nil
case (None, Some(upper)) =>
predicateType("LessEqual", upper) :: Nil
predicateType("LessEqual", toLiteral(upper)) :: Nil
case _ =>
Nil
}
Expand All @@ -184,35 +189,27 @@ object RefinedSupport extends LibrarySupport {
case (Some(lower), Some(upper)) if upper < lower =>
throw new RuntimeException("invalid min/max string bounds detected")
case (Some(lower), Some(upper)) =>
predicateType("MinSize", lower) ::
predicateType("MaxSize", upper) ::
predicateType("MinSize", Lit.Int(lower.toInt)) ::
predicateType("MaxSize", Lit.Int(upper.toInt)) ::
Nil
case (None, Some(upper)) =>
predicateType("MaxSize", upper) :: Nil
predicateType("MaxSize", Lit.Int(upper.toInt)) :: Nil
case (Some(lower), None) =>
predicateType("MinSize", lower) :: Nil
predicateType("MinSize", Lit.Int(lower.toInt)) :: Nil
case _ =>
Nil
}
}

protected def pattern(regex: Option[RegExp]): List[Type.Apply] =
regex.toList.map { re =>
predicateType("MatchesRegex", "\"" + re.toString + "\"")
predicateType("MatchesRegex", Lit.String(re.toString))
}

protected def predicateType(name: String, constants: AnyRef*): Type.Apply =
protected def predicateType(name: String, constants: Lit*): Type.Apply =
Type.Apply(
MetaUtil.typeFromName(name),
Type.ArgClause(constants.toList.map { constant =>
Type.Select(
Term.Select(
Term.Name("Witness"),
Term.Name(constant.toString)
),
Type.Name("T")
)
})
Type.ArgClause(constants.toList)
)

protected def propertyDefinition(
Expand Down Expand Up @@ -996,14 +993,13 @@ object RefinedSupport extends LibrarySupport {
}
}

private def preface(implicit context: ModelGenContext) =
private def preface(implicit context: ModelGenContext): List[Stat] =
q"""
import eu.timepit.refined.api.Refined
import eu.timepit.refined.boolean.And
import eu.timepit.refined.collection._
import eu.timepit.refined.numeric._
import eu.timepit.refined.string._
import shapeless.Witness

""".stats

Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/scraml/libs/SphereJsonSupport.scala
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ object SphereJsonSupport extends LibrarySupport with JsonSupport {
Case(
Lit.String(instance.getValue().toString()),
None,
Term.Select(Term.Name(instance.getValue().toString().toUpperCase()), Term.Name("valid"))
Term.Select(Term.Name(instance.getValue().toString()), Term.Name("valid"))
)
)
.toList ++ List(other)
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/scraml/libs/TapirSupport.scala
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ final class TapirSupport(endpointsObjectName: String) extends LibrarySupport {
Case(
Lit.String(enum.getValue.toString),
None,
q"sttp.tapir.DecodeResult.Value(${Term.Name(enum.getValue.toString.toUpperCase())})"
q"sttp.tapir.DecodeResult.Value(${Term.Name(enum.getValue.toString)})"
)
}
++
Expand Down Expand Up @@ -440,7 +440,7 @@ final class TapirSupport(endpointsObjectName: String) extends LibrarySupport {
${Term.PartialFunction(
enumNames.map { enum =>
Case(
Term.Name(enum.toUpperCase),
Term.Name(enum),
None,
Lit.String(enum)
)
Expand Down
3 changes: 2 additions & 1 deletion src/sbt-test/sbt-scraml/cats/build.sbt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
lazy val root = (project in file("."))
.settings(
scalaVersion := "2.13.15",
scalaVersion := "2.13.16",
crossScalaVersions ++= Seq("3.3.4"),
name := "scraml-cats-test",
version := "0.1",
ramlFile := Some(file("api/simple.raml")),
Expand Down
2 changes: 1 addition & 1 deletion src/sbt-test/sbt-scraml/cats/project/build.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version = 1.10.5
sbt.version = 1.10.7
5 changes: 5 additions & 0 deletions src/sbt-test/sbt-scraml/cats/test
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,8 @@
# make sure the expected output is present
$ exists target/scala-2.13/src_managed/main/scraml/datatypes.scala
$ exists target/scala-2.13/src_managed/main/scraml/package.scala
> clean
> ++3.3
> compile
$ exists target/scala-3.3.4/src_managed/main/scraml/datatypes.scala
$ exists target/scala-3.3.4/src_managed/main/scraml/package.scala
4 changes: 2 additions & 2 deletions src/sbt-test/sbt-scraml/ct-api-sphere/build.sbt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
val circeVersion = "0.14.1"
val circeVersion = "0.14.10"

lazy val root = (project in file("."))
.settings(
scalaVersion := "2.13.15",
scalaVersion := "2.13.16",
name := "scraml-ct-api-sphere-test",
version := "0.1",
ramlFile := Some(file("reference/api-specs/api/api.raml")),
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version = 1.10.5
sbt.version = 1.10.7
14 changes: 7 additions & 7 deletions src/sbt-test/sbt-scraml/ct-api/build.sbt
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
val circeVersion = "0.14.2"
val circeVersion = "0.14.10"
val monocleVersion = "3.1.0"
val refinedVersion = "0.9.27"
val tapirVersion = "1.1.0"
val refinedVersion = "0.11.3"
val tapirVersion = "1.11.9"

lazy val root = (project in file("."))
.settings(
scalaVersion := "2.13.15",
scalaVersion := "2.13.16",
crossScalaVersions ++= Seq("3.3.4"),
name := "scraml-ct-api-circe-test",
version := "0.1",
ramlFile := Some(file("reference/api-specs/api/api.raml")),
Expand All @@ -24,17 +25,16 @@ lazy val root = (project in file("."))
scraml.libs.RefinedSupport
),
Compile / sourceGenerators += runScraml,
libraryDependencies += "com.chuusai" %% "shapeless" % "2.3.7",
libraryDependencies ++= Seq(
"eu.timepit" %% "refined",
"eu.timepit" %% "refined-cats"
).map(_ % refinedVersion),
libraryDependencies ++= Seq(
"io.circe" %% "circe-core",
"io.circe" %% "circe-generic",
"io.circe" %% "circe-parser",
"io.circe" %% "circe-refined"
"io.circe" %% "circe-parser"
).map(_ % circeVersion),
libraryDependencies += "io.circe" %% "circe-refined" % "0.15.1",
libraryDependencies ++= Seq(
"dev.optics" %% "monocle-core",
"dev.optics" %% "monocle-macro"
Expand Down
2 changes: 1 addition & 1 deletion src/sbt-test/sbt-scraml/ct-api/project/build.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version = 1.10.5
sbt.version = 1.10.7
5 changes: 4 additions & 1 deletion src/sbt-test/sbt-scraml/ct-api/test
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
> compile
> compile
> clean
> ++3.3
> compile
12 changes: 8 additions & 4 deletions src/sbt-test/sbt-scraml/json/build.sbt
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
val circeVersion = "0.14.2"
val circeVersion = "0.14.10"

lazy val root = (project in file("."))
.settings(
name := "scraml-json-test",
scalaVersion := "2.13.15",
scalaVersion := "2.13.16",
crossScalaVersions ++= Seq("3.3.4"),
version := "0.1",
ramlFile := Some(file("api/json.raml")),
basePackageName := "scraml",
librarySupport := Set(scraml.libs.CirceJsonSupport(
formats = Map("localDateTime" -> "io.circe.Decoder.decodeLocalDateTime"),
// formats = Map("localDateTime" -> "io.circe.Decoder.decodeLocalDateTime"),
imports = Seq("io.circe.Decoder.decodeLocalDateTime") // alternative to formats to provide custom codecs via import
)),
defaultEnumVariant := Some("Unknown"),
Compile / sourceGenerators += runScraml,
libraryDependencies += "com.commercetools" %% "sphere-json" % "0.12.5",
libraryDependencies ++= (CrossVersion.partialVersion(scalaVersion.value) match {
case Some((2, 13)) => Seq("com.commercetools" %% "sphere-json" % "0.12.5")
case _ => Seq()
}),
libraryDependencies ++= Seq(
"io.circe" %% "circe-core",
"io.circe" %% "circe-generic",
Expand Down
2 changes: 1 addition & 1 deletion src/sbt-test/sbt-scraml/json/project/build.properties
Original file line number Diff line number Diff line change
@@ -1 +1 @@
sbt.version = 1.10.5
sbt.version = 1.10.7
6 changes: 6 additions & 0 deletions src/sbt-test/sbt-scraml/json/test
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,9 @@
$ exists target/scala-2.13/src_managed/main/scraml/datatypes.scala
$ exists target/scala-2.13/src_managed/main/scraml/package.scala
> run
> clean
> ++3.3
> compile
$ exists target/scala-3.3.4/src_managed/main/scraml/datatypes.scala
$ exists target/scala-3.3.4/src_managed/main/scraml/package.scala
> run
Loading

0 comments on commit 7050455

Please sign in to comment.