Skip to content

Commit

Permalink
refactor scala object deserialization (#657)
Browse files Browse the repository at this point in the history
* refactor scala object deserialization

* Update build.sbt

* Update Classes.scala

* refactor beanintrospector

* Update build.sbt

* Update CaseObjectDeserializerTest.scala
  • Loading branch information
pjfanning authored Dec 1, 2023
1 parent 94062a3 commit b71e3aa
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 23 deletions.
4 changes: 3 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,9 @@ mimaBinaryIssueFilters ++= Seq(
ProblemFilters.exclude[DirectMissingMethodProblem]("com.fasterxml.jackson.module.scala.introspect.ScalaAnnotationIntrospector.findSerializationKeyType"),
ProblemFilters.exclude[DirectMissingMethodProblem]("com.fasterxml.jackson.module.scala.introspect.ScalaAnnotationIntrospector.findSerializationType"),
ProblemFilters.exclude[DirectMissingMethodProblem]("com.fasterxml.jackson.module.scala.introspect.ScalaAnnotationIntrospector.findSerializationInclusionForContent"),
ProblemFilters.exclude[DirectMissingMethodProblem]("com.fasterxml.jackson.module.scala.introspect.ScalaAnnotationIntrospector.findSerializationInclusion")
ProblemFilters.exclude[DirectMissingMethodProblem]("com.fasterxml.jackson.module.scala.introspect.ScalaAnnotationIntrospector.findSerializationInclusion"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("com.fasterxml.jackson.module.scala.util.ClassW.getModuleField"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("com.fasterxml.jackson.module.scala.util.ClassW.com$fasterxml$jackson$module$scala$util$ClassW$$moduleField")
)

def compareVersions(version1: String, version2: String): Int = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,15 @@ import com.fasterxml.jackson.module.scala.util.ClassW
import scala.languageFeature.postfixOps
import scala.util.control.NonFatal

private class ScalaObjectDeserializer(clazz: Class[_]) extends StdDeserializer[Any](classOf[Any]) {
override def deserialize(p: JsonParser, ctxt: DeserializationContext): Any = {
try {
clazz.getField("MODULE$").get(null)
} catch {
case NonFatal(_) => null
}
}
private class ScalaObjectDeserializer(value: Any) extends StdDeserializer[Any](classOf[Any]) {
override def deserialize(p: JsonParser, ctxt: DeserializationContext): Any = value
}

private object ScalaObjectDeserializerResolver extends Deserializers.Base {
override def findBeanDeserializer(javaType: JavaType, config: DeserializationConfig, beanDesc: BeanDescription): JsonDeserializer[_] = {
val clazz = javaType.getRawClass
if (ClassW(clazz).isScalaObject)
new ScalaObjectDeserializer(clazz)
else null
ClassW(javaType.getRawClass).getModuleField.flatMap { field =>
Option(field.get(null))
}.map(new ScalaObjectDeserializer(_)).orNull
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,11 @@ object BeanIntrospector {
//create properties for all appropriate fields
val fields = for {
cls <- hierarchy
scalaCaseObject = isScalaCaseObject(cls)
isScalaObject = ClassW(cls).isScalaObject
field <- cls.getDeclaredFields
isScalaObject = ClassW(cls).isScalaObject || isScalaCaseObject(cls)
name = maybePrivateName(field)
if !name.contains('$')
if (isScalaObject || scalaCaseObject || isAcceptableField(field))
if isScalaObject || isAcceptableField(field)
beanGetter = findBeanGetter(cls, name)
beanSetter = findBeanSetter(cls, name)
} yield PropertyDescriptor(name, findConstructorParam(hierarchy.head, name), Some(field), findGetter(cls, name), findSetter(cls, name), beanGetter, beanSetter)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.fasterxml.jackson.module.scala.util

import java.lang.reflect.Field
import scala.annotation.tailrec
import scala.language.implicitConversions
import scala.reflect.{ScalaLongSignature, ScalaSignature}
Expand Down Expand Up @@ -28,9 +29,11 @@ trait ClassW extends PimpedType[Class[_]] {
hasSigHelper(value)
}

def isScalaObject: Boolean = {
Try(value.getField("MODULE$")).isSuccess
}
def isScalaObject: Boolean = moduleField.isSuccess

def getModuleField: Option[Field] = moduleField.toOption

private lazy val moduleField: Try[Field] = Try(value.getField("MODULE$"))
}

object ClassW {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ class CaseObjectDeserializerTest extends DeserializerTest {
val original = TestObject
val json = mapper.writeValueAsString(original)
val deserialized = mapper.readValue(json, TestObject.getClass)
assert(deserialized == original)
assert(deserialized === original)
}

it should "deserialize Foo and not create a new instance" in {
val mapper = JsonMapper.builder().addModule(DefaultScalaModule).addModule(ScalaObjectDeserializerModule).build()
val original = Foo
val json = mapper.writeValueAsString(original)
val deserialized = mapper.readValue(json, Foo.getClass)
assert(deserialized == original)
assert(deserialized === original)
}

it should "deserialize Foo and not create a new instance (visibility settings)" in {
Expand All @@ -42,7 +42,7 @@ class CaseObjectDeserializerTest extends DeserializerTest {
val original = Foo
val json = mapper.writeValueAsString(original)
val deserialized = mapper.readValue(json, Foo.getClass)
assert(deserialized == original)
assert(deserialized === original)
}

"An ObjectMapper with ClassTagExtensions and DefaultScalaModule" should "deserialize a case object and not create a new instance" in {
Expand All @@ -52,7 +52,7 @@ class CaseObjectDeserializerTest extends DeserializerTest {
val original = TestObject
val json = mapper.writeValueAsString(original)
val deserialized = mapper.readValue[TestObject.type](json)
assert(deserialized == original)
assert(deserialized === original)
}

"An ObjectMapper without ScalaObjectDeserializerModule" should "deserialize a case object but create a new instance" in {
Expand Down

0 comments on commit b71e3aa

Please sign in to comment.