Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support of Scala 3 union types as Schema.Enum #787

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,29 +91,35 @@ private case class DeriveSchema()(using val ctx: Quotes) {
case Some(mirror) =>
mirror.mirrorType match {
case MirrorType.Sum =>
deriveEnum[T](mirror, stack)
val types = mirror.types.toList
val labels = mirror.labels.toList
deriveEnum[T](types, labels, stack, false)
case MirrorType.Product =>
deriveCaseClass[T](mirror, stack, top)
deriveCaseClass[T](mirror, stack, top)
}
case None =>
val sym = typeRepr.typeSymbol
if (sym.isClassDef && sym.flags.is(Flags.Module)) {
deriveCaseObject[T](stack, top)
}
else {
report.errorAndAbort(s"Deriving schema for ${typeRepr.show} is not supported")
} else {
def collectOrTypeCases(tpe: TypeRepr): List[TypeRepr] = tpe match {
case OrType(left, right) => collectOrTypeCases(left) ++ collectOrTypeCases(right)
case _ => List(tpe)
}

val types = collectOrTypeCases(typeRepr)
val size = types.size
if (size > 1) {
val labels = (1 to size).map(_.toString).toList
deriveEnum[T](types, labels, stack, true)
} else {
report.errorAndAbort(s"Deriving schema for ${typeRepr.show} is not supported")
}
}
}
}
}
}

//println()
//println()
//println(s"RESULT ${typeRepr.show}")
//println(s"------")
//println(s"RESULT ${result.show}")

result
}

Expand Down Expand Up @@ -369,17 +375,13 @@ private case class DeriveSchema()(using val ctx: Quotes) {
}.toMap
}

def deriveEnum[T: Type](mirror: Mirror, stack: Stack)(using Quotes) = {
def deriveEnum[T: Type](types: List[TypeRepr], labels: List[String], stack: Stack, isUnion: Boolean)(using Quotes) = {
val selfRefSymbol = Symbol.newVal(Symbol.spliceOwner, s"derivedSchema${stack.size}", TypeRepr.of[Schema[T]], Flags.Lazy, Symbol.noSymbol)
val selfRef = Ref(selfRefSymbol)
val newStack = stack.push(selfRef, TypeRepr.of[T])

val labels = mirror.labels.toList
val types = mirror.types.toList
val typesAndLabels = types.zip(labels)

val cases = typesAndLabels.map { case (tpe, label) => deriveCase[T](tpe, label, newStack) }

val numParentFields: Int = TypeRepr.of[T].typeSymbol.declaredFields.length
val childrenFields = TypeRepr.of[T].typeSymbol.children.map(_.declaredFields.length)
val childrenFieldsConstructor = TypeRepr.of[T].typeSymbol.children.map(_.caseFields.length)
Expand All @@ -390,10 +392,11 @@ private case class DeriveSchema()(using val ctx: Quotes) {
val docstringExpr = Expr(docstring)
'{zio.schema.annotation.description(${docstringExpr})}
}
val annotationExprs = (isSimpleEnum, hasSimpleEnumAnn) match {
case (true, false) => TypeRepr.of[T].typeSymbol.annotations.filter(filterAnnotation).map(_.asExpr).+:('{zio.schema.annotation.simpleEnum(true)})
case (false, true) => throw new Exception(s"${TypeRepr.of[T].typeSymbol.name} must be a simple Enum")
case _ => TypeRepr.of[T].typeSymbol.annotations.filter(filterAnnotation).map(_.asExpr)
val annotationExprs = (isUnion, isSimpleEnum, hasSimpleEnumAnn) match {
case (false, true, false) => TypeRepr.of[T].typeSymbol.annotations.filter(filterAnnotation).map(_.asExpr).+:('{zio.schema.annotation.simpleEnum(true)})
case (false, false, true) => throw new Exception(s"${TypeRepr.of[T].typeSymbol.name} must be a simple Enum")
case (true, _, _) => TypeRepr.of[T].typeSymbol.annotations.filter(filterAnnotation).map(_.asExpr).+:('{zio.schema.annotation.noDiscriminator()})
case _ => TypeRepr.of[T].typeSymbol.annotations.filter(filterAnnotation).map(_.asExpr)
}
val genericAnnotations = if (TypeRepr.of[T].classSymbol.exists(_.typeMembers.nonEmpty)){
val typeMembersExpr = Expr.ofSeq(TypeRepr.of[T].classSymbol.get.typeMembers.map { t => Expr(t.name) })
Expand All @@ -402,7 +405,9 @@ private case class DeriveSchema()(using val ctx: Quotes) {
} else List.empty
val annotations = '{ zio.Chunk.fromIterable(${Expr.ofSeq(annotationExprs)}) ++ zio.Chunk.fromIterable(${Expr.ofSeq(docAnnotationExpr.toList)}) ++ zio.Chunk.fromIterable(${Expr.ofSeq(genericAnnotations)}) }

val typeInfo = '{TypeId.parse(${Expr(TypeRepr.of[T].classSymbol.get.fullName.replaceAll("\\$", ""))})}
val typeInfo =
if (isUnion) '{TypeId.fromTypeName("|")}
else '{TypeId.parse(${Expr(TypeRepr.of[T].classSymbol.get.fullName.replaceAll("\\$", ""))})}

val applied = if (cases.length <= 22) {
val args = List(typeInfo) ++ cases :+ annotations
Expand Down Expand Up @@ -441,7 +446,6 @@ private case class DeriveSchema()(using val ctx: Quotes) {
}
}


// Derive Field for a CaseClass
def deriveField[T: Type](repr: TypeRepr, name: String, anns: List[Expr[Any]], stack: Stack)(using Quotes) = {
import zio.schema.validation.Validation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@ trait VersionSpecificDeriveSchemaSpec extends ZIOSpecDefault {
)
assert(Schema[AutoDerives])(hasSameSchema(expected))
},
test("correctly assigns noDiscriminator to union") {
val derived: Schema[Int | String | Boolean] = DeriveSchema.gen
derived match {
case Schema.Enum3(id, case1, case2, case3, annotations) =>
assertTrue(id.name == "|") &&
assertTrue(case1.id == "1") &&
assertTrue(case2.id == "2") &&
assertTrue(case3.id == "3") &&
assertTrue(annotations == Chunk(noDiscriminator()))
case _ => assertTrue(false)
}
},
test("correctly assigns simpleEnum to enum") {
val derived: Schema[Colour] = DeriveSchema.gen[Colour]
assertTrue(derived.annotations == Chunk(simpleEnum(true)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import zio.Console._
import zio._
import zio.json.{ DeriveJsonEncoder, JsonEncoder }
import zio.schema._
import zio.schema.annotation._
import zio.test.Assertion._
import zio.test.TestAspect._
import zio.test.*
Expand All @@ -21,6 +22,26 @@ object DefaultValueSpec extends ZIOSpecDefault {
val result = JsonCodec.jsonDecoder(Schema[WithDefaultValue]).decodeJson("""{"orderId": 1}""")
assertTrue(result.isRight)
}
),
suite("union types")(
test("union type of standard types") {
val schema = Schema.chunk(DeriveSchema.gen[Int | String | Boolean])
val decoder = JsonCodec.jsonDecoder(schema)
val encoder = JsonCodec.jsonEncoder(schema)
val json = """["abc",1,true]"""
val value = Chunk[Int | String | Boolean]("abc", 1, true)
assert(decoder.decodeJson(json))(equalTo(Right(value)))
assert(encoder.encodeJson(value))(equalTo(json))
},
test("union type of enums") {
val schema = Schema.chunk(Schema[Result])
val decoder = JsonCodec.jsonDecoder(schema)
val encoder = JsonCodec.jsonEncoder(schema)
val json = """[{"res":{"Left":"Err1"}},{"res":{"Left":"Err21"}},{"res":{"Right":{"i":1}}}]"""
val value = Chunk[Result](Result(Left(ErrorGroup1.Err1)), Result(Left(ErrorGroup2.Err21)), Result(Right(Value(1))))
assert(decoder.decodeJson(json))(equalTo(Right(value)))
assert(encoder.encodeJson(value))(equalTo(json))
}
)
)

Expand All @@ -30,4 +51,21 @@ object DefaultValueSpec extends ZIOSpecDefault {
implicit lazy val schema: Schema[WithDefaultValue] = DeriveSchema.gen[WithDefaultValue]
}

enum ErrorGroup1:
case Err1
case Err2
case Err3

enum ErrorGroup2:
case Err21
case Err22
case Err23

case class Value(i: Int)
object Value:
given Schema[Value] = DeriveSchema.gen[Value]

case class Result(res: Either[ErrorGroup1 | ErrorGroup2, Value])
object Result:
given Schema[Result] = DeriveSchema.gen[Result]
}
Loading