From 935bf1cba79fd193f89b494271587887a474d6d2 Mon Sep 17 00:00:00 2001 From: Andriy Plokhotnyuk Date: Tue, 14 Jan 2025 19:05:43 +0100 Subject: [PATCH] Support of Scala 3 union types --- .../scala-3/zio/schema/DeriveSchema.scala | 50 ++++++++++--------- .../VersionSpecificDeriveSchemaSpec.scala | 12 +++++ .../zio/schema/codec/DefaultValueSpec.scala | 38 ++++++++++++++ 3 files changed, 77 insertions(+), 23 deletions(-) diff --git a/zio-schema-derivation/shared/src/main/scala-3/zio/schema/DeriveSchema.scala b/zio-schema-derivation/shared/src/main/scala-3/zio/schema/DeriveSchema.scala index 78e3e9c4c..353bc0afe 100644 --- a/zio-schema-derivation/shared/src/main/scala-3/zio/schema/DeriveSchema.scala +++ b/zio-schema-derivation/shared/src/main/scala-3/zio/schema/DeriveSchema.scala @@ -91,29 +91,35 @@ private case class DeriveSchema()(using val ctx: Quotes) { case Some(mirror) => mirror.mirrorType match { case MirrorType.Sum => - deriveEnum[T](mirror, stack) + val types = mirror.types.toList + val labels = mirror.labels.toList + deriveEnum[T](types, labels, stack, false) case MirrorType.Product => - deriveCaseClass[T](mirror, stack, top) + deriveCaseClass[T](mirror, stack, top) } case None => val sym = typeRepr.typeSymbol if (sym.isClassDef && sym.flags.is(Flags.Module)) { deriveCaseObject[T](stack, top) - } - else { - report.errorAndAbort(s"Deriving schema for ${typeRepr.show} is not supported") + } else { + def collectOrTypeCases(tpe: TypeRepr): List[TypeRepr] = tpe match { + case OrType(left, right) => collectOrTypeCases(left) ++ collectOrTypeCases(right) + case _ => List(tpe) + } + + val types = collectOrTypeCases(typeRepr) + val size = types.size + if (size > 1) { + val labels = (1 to size).map(_.toString).toList + deriveEnum[T](types, labels, stack, true) + } else { + report.errorAndAbort(s"Deriving schema for ${typeRepr.show} is not supported") + } } } } } } - - //println() - //println() - //println(s"RESULT ${typeRepr.show}") - //println(s"------") - //println(s"RESULT ${result.show}") - result } @@ -369,17 +375,13 @@ private case class DeriveSchema()(using val ctx: Quotes) { }.toMap } - def deriveEnum[T: Type](mirror: Mirror, stack: Stack)(using Quotes) = { + def deriveEnum[T: Type](types: List[TypeRepr], labels: List[String], stack: Stack, isUnion: Boolean)(using Quotes) = { val selfRefSymbol = Symbol.newVal(Symbol.spliceOwner, s"derivedSchema${stack.size}", TypeRepr.of[Schema[T]], Flags.Lazy, Symbol.noSymbol) val selfRef = Ref(selfRefSymbol) val newStack = stack.push(selfRef, TypeRepr.of[T]) - val labels = mirror.labels.toList - val types = mirror.types.toList val typesAndLabels = types.zip(labels) - val cases = typesAndLabels.map { case (tpe, label) => deriveCase[T](tpe, label, newStack) } - val numParentFields: Int = TypeRepr.of[T].typeSymbol.declaredFields.length val childrenFields = TypeRepr.of[T].typeSymbol.children.map(_.declaredFields.length) val childrenFieldsConstructor = TypeRepr.of[T].typeSymbol.children.map(_.caseFields.length) @@ -390,10 +392,11 @@ private case class DeriveSchema()(using val ctx: Quotes) { val docstringExpr = Expr(docstring) '{zio.schema.annotation.description(${docstringExpr})} } - val annotationExprs = (isSimpleEnum, hasSimpleEnumAnn) match { - case (true, false) => TypeRepr.of[T].typeSymbol.annotations.filter(filterAnnotation).map(_.asExpr).+:('{zio.schema.annotation.simpleEnum(true)}) - case (false, true) => throw new Exception(s"${TypeRepr.of[T].typeSymbol.name} must be a simple Enum") - case _ => TypeRepr.of[T].typeSymbol.annotations.filter(filterAnnotation).map(_.asExpr) + val annotationExprs = (isUnion, isSimpleEnum, hasSimpleEnumAnn) match { + case (false, true, false) => TypeRepr.of[T].typeSymbol.annotations.filter(filterAnnotation).map(_.asExpr).+:('{zio.schema.annotation.simpleEnum(true)}) + case (false, false, true) => throw new Exception(s"${TypeRepr.of[T].typeSymbol.name} must be a simple Enum") + case (true, _, _) => TypeRepr.of[T].typeSymbol.annotations.filter(filterAnnotation).map(_.asExpr).+:('{zio.schema.annotation.noDiscriminator()}) + case _ => TypeRepr.of[T].typeSymbol.annotations.filter(filterAnnotation).map(_.asExpr) } val genericAnnotations = if (TypeRepr.of[T].classSymbol.exists(_.typeMembers.nonEmpty)){ val typeMembersExpr = Expr.ofSeq(TypeRepr.of[T].classSymbol.get.typeMembers.map { t => Expr(t.name) }) @@ -402,7 +405,9 @@ private case class DeriveSchema()(using val ctx: Quotes) { } else List.empty val annotations = '{ zio.Chunk.fromIterable(${Expr.ofSeq(annotationExprs)}) ++ zio.Chunk.fromIterable(${Expr.ofSeq(docAnnotationExpr.toList)}) ++ zio.Chunk.fromIterable(${Expr.ofSeq(genericAnnotations)}) } - val typeInfo = '{TypeId.parse(${Expr(TypeRepr.of[T].classSymbol.get.fullName.replaceAll("\\$", ""))})} + val typeInfo = + if (isUnion) '{TypeId.fromTypeName("|")} + else '{TypeId.parse(${Expr(TypeRepr.of[T].classSymbol.get.fullName.replaceAll("\\$", ""))})} val applied = if (cases.length <= 22) { val args = List(typeInfo) ++ cases :+ annotations @@ -441,7 +446,6 @@ private case class DeriveSchema()(using val ctx: Quotes) { } } - // Derive Field for a CaseClass def deriveField[T: Type](repr: TypeRepr, name: String, anns: List[Expr[Any]], stack: Stack)(using Quotes) = { import zio.schema.validation.Validation diff --git a/zio-schema-derivation/shared/src/test/scala-3/zio/schema/VersionSpecificDeriveSchemaSpec.scala b/zio-schema-derivation/shared/src/test/scala-3/zio/schema/VersionSpecificDeriveSchemaSpec.scala index 68c7d195d..a6e3e97ae 100644 --- a/zio-schema-derivation/shared/src/test/scala-3/zio/schema/VersionSpecificDeriveSchemaSpec.scala +++ b/zio-schema-derivation/shared/src/test/scala-3/zio/schema/VersionSpecificDeriveSchemaSpec.scala @@ -81,6 +81,18 @@ trait VersionSpecificDeriveSchemaSpec extends ZIOSpecDefault { ) assert(Schema[AutoDerives])(hasSameSchema(expected)) }, + test("correctly assigns noDiscriminator to union") { + val derived: Schema[Int | String | Boolean] = DeriveSchema.gen + derived match { + case Schema.Enum3(id, case1, case2, case3, annotations) => + assertTrue(id.name == "|") && + assertTrue(case1.id == "1") && + assertTrue(case2.id == "2") && + assertTrue(case3.id == "3") && + assertTrue(annotations == Chunk(noDiscriminator())) + case _ => assertTrue(false) + } + }, test("correctly assigns simpleEnum to enum") { val derived: Schema[Colour] = DeriveSchema.gen[Colour] assertTrue(derived.annotations == Chunk(simpleEnum(true))) diff --git a/zio-schema-json/shared/src/test/scala-3/zio/schema/codec/DefaultValueSpec.scala b/zio-schema-json/shared/src/test/scala-3/zio/schema/codec/DefaultValueSpec.scala index 4b2dd94bc..da3d721e6 100644 --- a/zio-schema-json/shared/src/test/scala-3/zio/schema/codec/DefaultValueSpec.scala +++ b/zio-schema-json/shared/src/test/scala-3/zio/schema/codec/DefaultValueSpec.scala @@ -4,6 +4,7 @@ import zio.Console._ import zio._ import zio.json.{ DeriveJsonEncoder, JsonEncoder } import zio.schema._ +import zio.schema.annotation._ import zio.test.Assertion._ import zio.test.TestAspect._ import zio.test.* @@ -21,6 +22,26 @@ object DefaultValueSpec extends ZIOSpecDefault { val result = JsonCodec.jsonDecoder(Schema[WithDefaultValue]).decodeJson("""{"orderId": 1}""") assertTrue(result.isRight) } + ), + suite("union types")( + test("union type of standard types") { + val schema = Schema.chunk(DeriveSchema.gen[Int | String | Boolean]) + val decoder = JsonCodec.jsonDecoder(schema) + val encoder = JsonCodec.jsonEncoder(schema) + val json = """["abc",1,true]""" + val value = Chunk[Int | String | Boolean]("abc", 1, true) + assert(decoder.decodeJson(json))(equalTo(Right(value))) + assert(encoder.encodeJson(value))(equalTo(json)) + }, + test("union type of enums") { + val schema = Schema.chunk(Schema[Result]) + val decoder = JsonCodec.jsonDecoder(schema) + val encoder = JsonCodec.jsonEncoder(schema) + val json = """[{"res":{"Left":"Err1"}},{"res":{"Left":"Err21"}},{"res":{"Right":{"i":1}}}]""" + val value = Chunk[Result](Result(Left(ErrorGroup1.Err1)), Result(Left(ErrorGroup2.Err21)), Result(Right(Value(1)))) + assert(decoder.decodeJson(json))(equalTo(Right(value))) + assert(encoder.encodeJson(value))(equalTo(json)) + } ) ) @@ -30,4 +51,21 @@ object DefaultValueSpec extends ZIOSpecDefault { implicit lazy val schema: Schema[WithDefaultValue] = DeriveSchema.gen[WithDefaultValue] } + enum ErrorGroup1: + case Err1 + case Err2 + case Err3 + + enum ErrorGroup2: + case Err21 + case Err22 + case Err23 + + case class Value(i: Int) + object Value: + given Schema[Value] = DeriveSchema.gen[Value] + + case class Result(res: Either[ErrorGroup1 | ErrorGroup2, Value]) + object Result: + given Schema[Result] = DeriveSchema.gen[Result] }