From 425327d69ab4134bf221986752d1c1a0d13b293b Mon Sep 17 00:00:00 2001 From: Zac Sweers Date: Mon, 4 Dec 2023 16:56:34 -0500 Subject: [PATCH] Add KSP support to InjectConstructorFactory, MembersInjectorCodeGen, and assisted injection (#795) These were done together as they are linked. Tried to share code where possible, but there are some meaty parts in dagger generation util that I couldn't really easily share much. Note that one test is disabled in KSP until binding module generation supports KSP, and I felt it best to save that for a later PR since that generator requires some reworking to avoid class merging during generation. I also fixed a few mis-named files along the way. Ref: #751 --- .../compiler/internal/KotlinPoetUtils.kt | 49 ++ .../java/com/squareup/anvil/compiler/Utils.kt | 3 + .../codegen/ClassReferenceExtensions.kt | 5 +- .../codegen/dagger/AssistedFactoryCodeGen.kt | 629 ++++++++++++++++++ .../dagger/AssistedFactoryGenerator.kt | 339 ---------- ...tGenerator.kt => AssistedInjectCodeGen.kt} | 175 +++-- .../codegen/dagger/DaggerGenerationUtils.kt | 244 ++++++- ....kt => InjectConstructorFactoryCodeGen.kt} | 165 +++-- ...orGenerator.kt => MapKeyCreatorCodeGen.kt} | 9 +- ...Generator.kt => MembersInjectorCodeGen.kt} | 155 +++-- .../anvil/compiler/codegen/ksp/KspUtil.kt | 94 +++ .../com/squareup/anvil/compiler/TestUtils.kt | 39 ++ .../dagger/AssistedFactoryGeneratorTest.kt | 25 +- .../dagger/AssistedInjectGeneratorTest.kt | 18 +- .../anvil/compiler/dagger/DaggerTestUtils.kt | 5 + .../InjectConstructorFactoryGeneratorTest.kt | 15 +- .../dagger/MapKeyCreatorGeneratorTest.kt | 19 +- .../dagger/MembersInjectorGeneratorTest.kt | 11 +- 18 files changed, 1477 insertions(+), 522 deletions(-) create mode 100644 compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedFactoryCodeGen.kt delete mode 100644 compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedFactoryGenerator.kt rename compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/{AssistedInjectGenerator.kt => AssistedInjectCodeGen.kt} (54%) rename compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/{InjectConstructorFactoryGenerator.kt => InjectConstructorFactoryCodeGen.kt} (51%) rename compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/{MapKeyCreatorGenerator.kt => MapKeyCreatorCodeGen.kt} (97%) rename compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/{MembersInjectorGenerator.kt => MembersInjectorCodeGen.kt} (55%) diff --git a/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/KotlinPoetUtils.kt b/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/KotlinPoetUtils.kt index 2740fd7f2..01c4775a8 100644 --- a/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/KotlinPoetUtils.kt +++ b/compiler-utils/src/main/java/com/squareup/anvil/compiler/internal/KotlinPoetUtils.kt @@ -7,10 +7,16 @@ import com.squareup.anvil.compiler.api.AnvilCompilationException import com.squareup.anvil.compiler.internal.reference.AnnotatedReference import com.squareup.anvil.compiler.internal.reference.TypeReference import com.squareup.anvil.compiler.internal.reference.canResolveFqName +import com.squareup.kotlinpoet.ANY import com.squareup.kotlinpoet.AnnotationSpec import com.squareup.kotlinpoet.ClassName +import com.squareup.kotlinpoet.Dynamic import com.squareup.kotlinpoet.FileSpec +import com.squareup.kotlinpoet.LambdaTypeName +import com.squareup.kotlinpoet.ParameterizedTypeName import com.squareup.kotlinpoet.TypeName +import com.squareup.kotlinpoet.TypeVariableName +import com.squareup.kotlinpoet.WildcardTypeName import com.squareup.kotlinpoet.jvm.jvmSuppressWildcards import org.jetbrains.kotlin.descriptors.ClassDescriptor import org.jetbrains.kotlin.descriptors.ModuleDescriptor @@ -140,3 +146,46 @@ public fun FileSpec.Companion.createAnvilSpec( } .addFileComment(generatorComment) .build() + +/** + * For `Map` this will return [`String`, `Int`]. For star projections like + * `List<*>` the result will be mapped to [Any]. + */ +public val TypeName.unwrappedTypes: List get() { + return when (this) { + is ParameterizedTypeName -> typeArguments + else -> emptyList() + } +} + +/** + * Returns the result of [findRawType] or throws. + */ +public fun TypeName.requireRawType(): ClassName { + return findRawType() ?: error("Cannot get raw type from $this") +} + +/** + * Returns the raw type for this [TypeName] or null if one can't be resolved. + */ +public fun TypeName.findRawType(): ClassName? { + return when (this) { + is ClassName -> this + is ParameterizedTypeName -> rawType + is TypeVariableName -> ANY + is WildcardTypeName -> outTypes.first().findRawType() + is LambdaTypeName -> { + var count = parameters.size + if (receiver != null) { + count++ + } + val functionSimpleName = if (count >= 23) { + "FunctionN" + } else { + "Function$count" + } + ClassName("kotlin.jvm.functions", functionSimpleName) + } + Dynamic -> null + } +} diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/Utils.kt b/compiler/src/main/java/com/squareup/anvil/compiler/Utils.kt index e79f4d120..75df8dd39 100644 --- a/compiler/src/main/java/com/squareup/anvil/compiler/Utils.kt +++ b/compiler/src/main/java/com/squareup/anvil/compiler/Utils.kt @@ -11,6 +11,7 @@ import com.squareup.anvil.annotations.compat.MergeModules import com.squareup.anvil.compiler.internal.fqName import com.squareup.anvil.compiler.internal.reference.ClassReference import com.squareup.anvil.compiler.internal.reference.toClassReferenceOrNull +import com.squareup.kotlinpoet.asClassName import dagger.Binds import dagger.Component import dagger.Lazy @@ -49,6 +50,7 @@ internal val daggerModuleFqName = Module::class.fqName internal val daggerBindsFqName = Binds::class.fqName internal val daggerProvidesFqName = Provides::class.fqName internal val daggerLazyFqName = Lazy::class.fqName +internal val daggerLazyClassName = Lazy::class.asClassName() internal val injectFqName = Inject::class.fqName internal val qualifierFqName = Qualifier::class.fqName internal val mapKeyFqName = MapKey::class.fqName @@ -56,6 +58,7 @@ internal val assistedFqName = Assisted::class.fqName internal val assistedFactoryFqName = AssistedFactory::class.fqName internal val assistedInjectFqName = AssistedInject::class.fqName internal val providerFqName = Provider::class.fqName +internal val providerClassName = Provider::class.asClassName() internal val jvmSuppressWildcardsFqName = JvmSuppressWildcards::class.fqName internal val jvmFieldFqName = JvmField::class.fqName internal val publishedApiFqName = PublishedApi::class.fqName diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/ClassReferenceExtensions.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/ClassReferenceExtensions.kt index b393f76ec..e622b5916 100644 --- a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/ClassReferenceExtensions.kt +++ b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/ClassReferenceExtensions.kt @@ -127,7 +127,10 @@ internal fun Collection.injectConstructor(): T? constructor.annotations.joinToString(" ", postfix = " ") // We special-case @Inject to match Dagger using the non-fully-qualified name .replace("@javax.inject.Inject", "@Inject") + - constructor.fqName.toString().replace(".", "") + constructor.fqName.toString().replace(".", "") + + constructor.parameters.joinToString(", ", prefix = "(", postfix = ")") { param -> + param.type().asClassReference().shortName + } }.joinToString() throw AnvilCompilationExceptionClassReference( classReference = constructors[0].declaringClass, diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedFactoryCodeGen.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedFactoryCodeGen.kt new file mode 100644 index 000000000..646e4370e --- /dev/null +++ b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedFactoryCodeGen.kt @@ -0,0 +1,629 @@ +package com.squareup.anvil.compiler.codegen.dagger + +import com.google.auto.service.AutoService +import com.google.devtools.ksp.KspExperimental +import com.google.devtools.ksp.getAllSuperTypes +import com.google.devtools.ksp.getAnnotationsByType +import com.google.devtools.ksp.getConstructors +import com.google.devtools.ksp.getDeclaredFunctions +import com.google.devtools.ksp.getVisibility +import com.google.devtools.ksp.processing.Resolver +import com.google.devtools.ksp.processing.SymbolProcessorEnvironment +import com.google.devtools.ksp.processing.SymbolProcessorProvider +import com.google.devtools.ksp.symbol.KSAnnotated +import com.google.devtools.ksp.symbol.KSClassDeclaration +import com.google.devtools.ksp.symbol.KSFunction +import com.google.devtools.ksp.symbol.KSFunctionDeclaration +import com.google.devtools.ksp.symbol.KSNode +import com.google.devtools.ksp.symbol.KSValueParameter +import com.google.devtools.ksp.symbol.Visibility.PROTECTED +import com.google.devtools.ksp.symbol.Visibility.PUBLIC +import com.squareup.anvil.compiler.api.AnvilApplicabilityChecker +import com.squareup.anvil.compiler.api.AnvilCompilationException +import com.squareup.anvil.compiler.api.AnvilContext +import com.squareup.anvil.compiler.api.CodeGenerator +import com.squareup.anvil.compiler.api.GeneratedFile +import com.squareup.anvil.compiler.api.createGeneratedFile +import com.squareup.anvil.compiler.assistedFactoryFqName +import com.squareup.anvil.compiler.assistedFqName +import com.squareup.anvil.compiler.assistedInjectFqName +import com.squareup.anvil.compiler.codegen.PrivateCodeGenerator +import com.squareup.anvil.compiler.codegen.dagger.AssistedFactoryCodeGen.AssistedParameterKey.Companion.toAssistedParameterKey +import com.squareup.anvil.compiler.codegen.dagger.AssistedFactoryCodeGen.Embedded.AssistedFactoryFunction.Companion.toAssistedFactoryFunction +import com.squareup.anvil.compiler.codegen.dagger.AssistedFactoryCodeGen.KspGenerator.AssistedFactoryFunction.Companion.toAssistedFactoryFunction +import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessor +import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessorProvider +import com.squareup.anvil.compiler.codegen.ksp.KspAnvilException +import com.squareup.anvil.compiler.codegen.ksp.argumentAt +import com.squareup.anvil.compiler.codegen.ksp.isAnnotationPresent +import com.squareup.anvil.compiler.codegen.ksp.isInterface +import com.squareup.anvil.compiler.codegen.ksp.resolveKSClassDeclaration +import com.squareup.anvil.compiler.internal.createAnvilSpec +import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionClassReference +import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionFunctionReference +import com.squareup.anvil.compiler.internal.reference.ClassReference +import com.squareup.anvil.compiler.internal.reference.MemberFunctionReference +import com.squareup.anvil.compiler.internal.reference.ParameterReference +import com.squareup.anvil.compiler.internal.reference.Visibility +import com.squareup.anvil.compiler.internal.reference.allSuperTypeClassReferences +import com.squareup.anvil.compiler.internal.reference.argumentAt +import com.squareup.anvil.compiler.internal.reference.asClassName +import com.squareup.anvil.compiler.internal.reference.classAndInnerClassReferences +import com.squareup.anvil.compiler.internal.reference.generateClassName +import com.squareup.kotlinpoet.ClassName +import com.squareup.kotlinpoet.FileSpec +import com.squareup.kotlinpoet.FunSpec +import com.squareup.kotlinpoet.KModifier.OVERRIDE +import com.squareup.kotlinpoet.KModifier.PRIVATE +import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy +import com.squareup.kotlinpoet.PropertySpec +import com.squareup.kotlinpoet.TypeName +import com.squareup.kotlinpoet.TypeSpec +import com.squareup.kotlinpoet.TypeVariableName +import com.squareup.kotlinpoet.asClassName +import com.squareup.kotlinpoet.jvm.jvmStatic +import com.squareup.kotlinpoet.ksp.TypeParameterResolver +import com.squareup.kotlinpoet.ksp.toClassName +import com.squareup.kotlinpoet.ksp.toTypeName +import com.squareup.kotlinpoet.ksp.toTypeParameterResolver +import com.squareup.kotlinpoet.ksp.toTypeVariableName +import com.squareup.kotlinpoet.ksp.writeTo +import dagger.assisted.Assisted +import dagger.assisted.AssistedInject +import dagger.internal.InstanceFactory +import org.jetbrains.kotlin.descriptors.ModuleDescriptor +import org.jetbrains.kotlin.psi.KtFile +import java.io.File +import javax.inject.Provider + +object AssistedFactoryCodeGen : AnvilApplicabilityChecker { + + override fun isApplicable(context: AnvilContext) = context.generateFactories + + internal class KspGenerator( + override val env: SymbolProcessorEnvironment, + ) : AnvilSymbolProcessor() { + @AutoService(SymbolProcessorProvider::class) + class Provider : AnvilSymbolProcessorProvider(AssistedFactoryCodeGen, ::KspGenerator) + + override fun processChecked(resolver: Resolver): List { + resolver.getSymbolsWithAnnotation(assistedFactoryFqName.asString()) + .filterIsInstance() + .forEach { clazz -> + generateFactoryClass(clazz) + .writeTo(env.codeGenerator, aggregating = false, listOf(clazz.containingFile!!)) + } + return emptyList() + } + + private fun generateFactoryClass( + clazz: KSClassDeclaration, + ): FileSpec { + val typeParameterResolver = clazz.typeParameters.toTypeParameterResolver() + val function = clazz.requireSingleAbstractFunction(typeParameterResolver) + + val returnType = try { + function.returnType + } catch (e: Exception) { + // Catch the exception and throw the same error that Dagger would. + throw KspAnvilException( + message = "Invalid return type: ${clazz.qualifiedName?.asString()}. An assisted factory's " + + "abstract method must return a type with an @AssistedInject-annotated constructor.", + node = function.node, + cause = e, + ) + } + + // The return type of the function must have an @AssistedInject constructor. + val constructor = returnType + .getConstructors() + .singleOrNull { + it.isAnnotationPresent() + } + ?: throw KspAnvilException( + message = "Invalid return type: ${returnType.qualifiedName?.asString()}. An assisted factory's abstract " + + "method must return a type with an @AssistedInject-annotated constructor.", + node = clazz, + ) + + val functionParameters = function.parameterKeys + val assistedParameters = constructor.parameters.filter { parameter -> + parameter.isAnnotationPresent() + } + + // Check that the parameters of the function match the @Assisted parameters of the constructor. + if (assistedParameters.size != functionParameters.size) { + throw KspAnvilException( + message = "The parameters in the factory method must match the @Assisted parameters in " + + "${returnType.qualifiedName?.asString()}.", + node = clazz, + ) + } + + // Compute for each parameter its key. + val functionParameterKeys = function.parameterKeys + val assistedParameterKeys = assistedParameters.map { + it.toAssistedParameterKey(it.type.toTypeName(typeParameterResolver)) + } + + // The factory function may not have two or more parameters with the same key. + val duplicateKeys = functionParameterKeys + .groupBy { it.key } + .filter { it.value.size > 1 } + .values + .flatten() + + if (duplicateKeys.isNotEmpty()) { + // Complain about the first duplicate key that occurs, similar to Dagger. + val key = functionParameterKeys.first { it in duplicateKeys } + + throw KspAnvilException( + message = buildString { + append("@AssistedFactory method has duplicate @Assisted types: ") + if (key.identifier.isNotEmpty()) { + append("@Assisted(\"${key.identifier}\") ") + } + append(key.typeName) + }, + node = clazz, + ) + } + + // Check that for each parameter of the factory function there is a parameter with the same + // key in the @AssistedInject constructor. + val notMatchingKeys = (functionParameterKeys + assistedParameterKeys) + .groupBy { it.key } + .filter { it.value.size == 1 } + .values + .flatten() + + if (notMatchingKeys.isNotEmpty()) { + throw KspAnvilException( + message = "The parameters in the factory method must match the @Assisted parameters in " + + "${returnType.qualifiedName?.asString()}.", + node = clazz, + ) + } + + val typeParameters = clazz.typeParameters + + val functionName = function.simpleName + val baseFactoryIsInterface = clazz.isInterface() + val functionParameterPairs = function.parameterPairs + + val spec = buildSpec( + originClassNAme = clazz.toClassName(), + targetType = returnType.toClassName(), + functionName = functionName, + typeParameters = typeParameters.map { it.toTypeVariableName(typeParameterResolver) }, + assistedParameterKeys = assistedParameterKeys, + baseFactoryIsInterface = baseFactoryIsInterface, + functionParameterPairs = functionParameterPairs.map { (ref, typeName) -> + ref.name!!.asString() to typeName + }, + functionParameterKeys = functionParameterKeys, + ) + + return spec + } + + private fun KSClassDeclaration.requireSingleAbstractFunction( + typeParameterResolver: TypeParameterResolver, + ): AssistedFactoryFunction { + val implementingType = asType(emptyList()) + + // `clazz` must be first in the list because of `distinctBy { ... }`, which keeps the first + // matched element. If the function's inherited, it can be overridden as well. Prioritizing + // the version from the file we're parsing ensures the correct variance of the referenced types. + // TODO can't use getAllFunctions() yet due to https://github.com/google/ksp/issues/1619 + val assistedFunctions = sequenceOf(this) + .plus(getAllSuperTypes().mapNotNull { it.resolveKSClassDeclaration() }) + .distinctBy { it.qualifiedName?.asString() } + .flatMap { clazz -> + clazz.getDeclaredFunctions() + .filter { + it.isAbstract && + (it.getVisibility() == PUBLIC || it.getVisibility() == PROTECTED) + } + } + .distinctBy { it.simpleName.asString() } + .map { + it.asMemberOf(implementingType) + .toAssistedFactoryFunction(it, typeParameterResolver) + } + .toList() + + // Check for exact number of functions. + return when (assistedFunctions.size) { + 0 -> throw KspAnvilException( + message = "The @AssistedFactory-annotated type is missing an abstract, non-default " + + "method whose return type matches the assisted injection type.", + node = this, + ) + + 1 -> assistedFunctions[0] + else -> { + val foundFunctions = assistedFunctions + .sortedBy { it.simpleName } + .joinToString { func -> + "${func.qualifiedName}(${func.parameterPairs.map { it.first.name }})" + } + throw KspAnvilException( + message = "The @AssistedFactory-annotated type should contain a single abstract, " + + "non-default method but found multiple: [$foundFunctions]", + node = this, + ) + } + } + } + + /** + * Represents a parsed function in an `@AssistedInject.Factory`-annotated interface. + */ + private data class AssistedFactoryFunction( + val simpleName: String, + val qualifiedName: String, + val returnType: KSClassDeclaration, + val node: KSNode, + val parameterKeys: List, + /** + * Pair of parameter reference to parameter type. + */ + val parameterPairs: List>, + ) { + + companion object { + fun KSFunction.toAssistedFactoryFunction( + originalDeclaration: KSFunctionDeclaration, + typeParameterResolver: TypeParameterResolver, + ): AssistedFactoryFunction { + return AssistedFactoryFunction( + simpleName = originalDeclaration.simpleName.asString(), + qualifiedName = originalDeclaration.qualifiedName!!.asString(), + returnType = returnType!!.resolveKSClassDeclaration()!!, + node = originalDeclaration, + parameterKeys = originalDeclaration.parameters.mapIndexed { index, param -> + param.toAssistedParameterKey( + parameterTypes[index]!!.toTypeName(typeParameterResolver), + ) + }, + parameterPairs = originalDeclaration.parameters.mapIndexed { index, param -> + param to parameterTypes[index]!!.toTypeName(typeParameterResolver) + }, + ) + } + } + } + } + + @AutoService(CodeGenerator::class) + internal class Embedded : PrivateCodeGenerator() { + + override fun isApplicable(context: AnvilContext) = AssistedFactoryCodeGen.isApplicable(context) + + override fun generateCodePrivate( + codeGenDir: File, + module: ModuleDescriptor, + projectFiles: Collection, + ) { + projectFiles + .classAndInnerClassReferences(module) + .filter { it.isAnnotatedWith(assistedFactoryFqName) } + .forEach { clazz -> + generateFactoryClass(codeGenDir, clazz) + } + } + + private fun generateFactoryClass( + codeGenDir: File, + clazz: ClassReference.Psi, + ): GeneratedFile { + val function = clazz.requireSingleAbstractFunction() + + val returnType = try { + function.function.resolveGenericReturnType(clazz) + } catch (e: AnvilCompilationException) { + // Catch the exception and throw the same error that Dagger would. + throw AnvilCompilationExceptionFunctionReference( + message = "Invalid return type: ${clazz.fqName}. An assisted factory's " + + "abstract method must return a type with an @AssistedInject-annotated constructor.", + functionReference = function.function, + cause = e, + ) + } + + // The return type of the function must have an @AssistedInject constructor. + val constructor = returnType + .constructors + .singleOrNull { it.isAnnotatedWith(assistedInjectFqName) } + ?: throw AnvilCompilationExceptionClassReference( + message = "Invalid return type: ${returnType.fqName}. An assisted factory's abstract " + + "method must return a type with an @AssistedInject-annotated constructor.", + classReference = clazz, + ) + + val functionParameters = function.parameterKeys + val assistedParameters = constructor.parameters.filter { parameter -> + parameter.annotations.any { it.fqName == assistedFqName } + } + + // Check that the parameters of the function match the @Assisted parameters of the constructor. + if (assistedParameters.size != functionParameters.size) { + throw AnvilCompilationExceptionClassReference( + message = "The parameters in the factory method must match the @Assisted parameters in " + + "${returnType.fqName}.", + classReference = clazz, + ) + } + + // Compute for each parameter its key. + val functionParameterKeys = function.parameterKeys + val assistedParameterKeys = assistedParameters.map { it.toAssistedParameterKey(clazz) } + + // The factory function may not have two or more parameters with the same key. + val duplicateKeys = functionParameterKeys + .groupBy { it.key } + .filter { it.value.size > 1 } + .values + .flatten() + + if (duplicateKeys.isNotEmpty()) { + // Complain about the first duplicate key that occurs, similar to Dagger. + val key = functionParameterKeys.first { it in duplicateKeys } + + throw AnvilCompilationExceptionClassReference( + message = buildString { + append("@AssistedFactory method has duplicate @Assisted types: ") + if (key.identifier.isNotEmpty()) { + append("@Assisted(\"${key.identifier}\") ") + } + append(key.typeName) + }, + classReference = clazz, + ) + } + + // Check that for each parameter of the factory function there is a parameter with the same + // key in the @AssistedInject constructor. + val notMatchingKeys = (functionParameterKeys + assistedParameterKeys) + .groupBy { it.key } + .filter { it.value.size == 1 } + .values + .flatten() + + if (notMatchingKeys.isNotEmpty()) { + throw AnvilCompilationExceptionClassReference( + message = "The parameters in the factory method must match the @Assisted parameters in " + + "${returnType.fqName}.", + classReference = clazz, + ) + } + + val typeParameters = clazz.typeParameters + + val functionName = function.function.name + val baseFactoryIsInterface = clazz.isInterface() + val functionParameterPairs = function.parameterPairs + + val spec = buildSpec( + originClassNAme = clazz.asClassName(), + targetType = returnType.asClassName(), + functionName = functionName, + typeParameters = typeParameters.map { it.typeVariableName }, + assistedParameterKeys = assistedParameterKeys, + baseFactoryIsInterface = baseFactoryIsInterface, + functionParameterPairs = functionParameterPairs.map { (ref, typeName) -> + ref.name to typeName + }, + functionParameterKeys = functionParameterKeys, + ) + + return createGeneratedFile(codeGenDir, spec.packageName, spec.name, spec.toString()) + } + + private fun ClassReference.Psi.requireSingleAbstractFunction(): AssistedFactoryFunction { + // `clazz` must be first in the list because of `distinctBy { ... }`, which keeps the first + // matched element. If the function's inherited, it can be overridden as well. Prioritizing + // the version from the file we're parsing ensures the correct variance of the referenced types. + val assistedFunctions = allSuperTypeClassReferences(includeSelf = true) + .distinctBy { it.fqName } + .flatMap { clazz -> + clazz.functions + .filter { + it.isAbstract() && + (it.visibility() == Visibility.PUBLIC || it.visibility() == Visibility.PROTECTED) + } + } + .distinctBy { it.name } + .map { it.toAssistedFactoryFunction(this) } + .toList() + + // Check for exact number of functions. + return when (assistedFunctions.size) { + 0 -> throw AnvilCompilationExceptionClassReference( + message = "The @AssistedFactory-annotated type is missing an abstract, non-default " + + "method whose return type matches the assisted injection type.", + classReference = this, + ) + + 1 -> assistedFunctions[0] + else -> { + val foundFunctions = assistedFunctions + .sortedBy { it.function.name } + .joinToString { func -> + "${func.function.fqName}(${func.parameterPairs.map { it.first.name }})" + } + throw AnvilCompilationExceptionClassReference( + message = "The @AssistedFactory-annotated type should contain a single abstract, " + + "non-default method but found multiple: [$foundFunctions]", + classReference = this, + ) + } + } + } + + /** + * Represents a parsed function in an `@AssistedInject.Factory`-annotated interface. + */ + private data class AssistedFactoryFunction( + val function: MemberFunctionReference, + val parameterKeys: List, + /** + * Pair of parameter reference to parameter type. + */ + val parameterPairs: List>, + ) { + companion object { + fun MemberFunctionReference.toAssistedFactoryFunction( + factoryClass: ClassReference.Psi, + ): AssistedFactoryFunction { + return AssistedFactoryFunction( + function = this, + parameterKeys = parameters.map { it.toAssistedParameterKey(factoryClass) }, + parameterPairs = parameters.map { it to it.resolveTypeName(factoryClass) }, + ) + } + } + } + } + + private const val DELEGATE_FACTORY_NAME = "delegateFactory" + + private fun buildSpec( + originClassNAme: ClassName, + targetType: ClassName, + functionName: String, + typeParameters: List, + assistedParameterKeys: List, + baseFactoryIsInterface: Boolean, + functionParameterPairs: List>, + functionParameterKeys: List, + ): FileSpec { + val generatedFactoryTypeName = targetType.generateClassName(suffix = "_Factory") + .optionallyParameterizedByNames(typeParameters) + + val baseFactoryTypeName = originClassNAme.optionallyParameterizedByNames(typeParameters) + val returnTypeName = targetType.optionallyParameterizedByNames(typeParameters) + val implClassName = originClassNAme.generateClassName(suffix = "_Impl") + val implParameterizedTypeName = implClassName.optionallyParameterizedByNames(typeParameters) + + return FileSpec.createAnvilSpec(implClassName.packageName, implClassName.simpleName) { + TypeSpec.classBuilder(implClassName) + .apply { + addTypeVariables(typeParameters) + + if (baseFactoryIsInterface) { + addSuperinterface(baseFactoryTypeName) + } else { + superclass(baseFactoryTypeName) + } + + primaryConstructor( + FunSpec.constructorBuilder() + .addParameter(DELEGATE_FACTORY_NAME, generatedFactoryTypeName) + .build(), + ) + + addProperty( + PropertySpec.builder(DELEGATE_FACTORY_NAME, generatedFactoryTypeName) + .initializer(DELEGATE_FACTORY_NAME) + .addModifiers(PRIVATE) + .build(), + ) + } + .addFunction( + FunSpec.builder(functionName) + .addModifiers(OVERRIDE) + .returns(returnTypeName) + .apply { + functionParameterPairs.forEach { parameter -> + addParameter(parameter.first, parameter.second) + } + + // We call the @AssistedInject constructor. Therefore, find for each assisted + // parameter the function parameter where the keys match. + val argumentList = assistedParameterKeys.joinToString { assistedParameterKey -> + val functionIndex = functionParameterKeys.indexOfFirst { + it.key == assistedParameterKey.key + } + check(functionIndex >= 0) { + // Sanity check, this should not happen with the noMatchingKeys list check above. + "Unexpected assistedIndex." + } + + functionParameterPairs[functionIndex].first + } + + addStatement("return $DELEGATE_FACTORY_NAME.get($argumentList)") + } + .build(), + ) + .apply { + TypeSpec.companionObjectBuilder() + .addFunction( + FunSpec.builder("create") + .jvmStatic() + .addTypeVariables(typeParameters) + .addParameter(DELEGATE_FACTORY_NAME, generatedFactoryTypeName) + .returns(Provider::class.asClassName().parameterizedBy(baseFactoryTypeName)) + .addStatement( + "return %T.create(%T($DELEGATE_FACTORY_NAME))", + InstanceFactory::class, + implParameterizedTypeName, + ) + .build(), + ) + .build() + .let { + addType(it) + } + } + .build() + .let { addType(it) } + } + } + + // Dagger matches parameters of the factory function with the parameters of the @AssistedInject + // constructor through a key. Initially, they used the order of parameters, but that has changed. + // The key is a combination of the type and identifier (value parameter) of the + // @Assisted("...") annotation. For each parameter the key must be unique. + private data class AssistedParameterKey( + val typeName: TypeName, + val identifier: String, + ) { + + // Key value is similar to a hash function. There used to be a special case for KotlinTypes + // which were parameterized, but this is now handled by KotlinPoet's TypeName. + // `MyType` and `MyType` now generate different hashCodes. + val key: Int = identifier.hashCode() * 31 + typeName.hashCode() + + companion object { + @OptIn(KspExperimental::class) + fun KSValueParameter.toAssistedParameterKey( + typeName: TypeName, + ): AssistedParameterKey { + return AssistedParameterKey( + typeName, + getAnnotationsByType(Assisted::class) + .singleOrNull() + ?.value + .orEmpty(), + ) + } + + fun ParameterReference.toAssistedParameterKey( + factoryClass: ClassReference.Psi, + ): AssistedParameterKey { + return AssistedParameterKey( + typeName = resolveTypeName(factoryClass), + identifier = annotations + .singleOrNull { it.fqName == assistedFqName } + ?.let { annotation -> + annotation.argumentAt("value", index = 0)?.value() + } + .orEmpty(), + ) + } + } + } +} diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedFactoryGenerator.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedFactoryGenerator.kt deleted file mode 100644 index 0232785c8..000000000 --- a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedFactoryGenerator.kt +++ /dev/null @@ -1,339 +0,0 @@ -package com.squareup.anvil.compiler.codegen.dagger - -import com.google.auto.service.AutoService -import com.squareup.anvil.compiler.api.AnvilCompilationException -import com.squareup.anvil.compiler.api.AnvilContext -import com.squareup.anvil.compiler.api.CodeGenerator -import com.squareup.anvil.compiler.api.GeneratedFile -import com.squareup.anvil.compiler.api.createGeneratedFile -import com.squareup.anvil.compiler.assistedFactoryFqName -import com.squareup.anvil.compiler.assistedFqName -import com.squareup.anvil.compiler.assistedInjectFqName -import com.squareup.anvil.compiler.codegen.PrivateCodeGenerator -import com.squareup.anvil.compiler.codegen.dagger.AssistedFactoryGenerator.AssistedFactoryFunction.Companion.toAssistedFactoryFunction -import com.squareup.anvil.compiler.codegen.dagger.AssistedFactoryGenerator.AssistedParameterKey.Companion.toAssistedParameterKey -import com.squareup.anvil.compiler.internal.asClassName -import com.squareup.anvil.compiler.internal.buildFile -import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionClassReference -import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionFunctionReference -import com.squareup.anvil.compiler.internal.reference.ClassReference -import com.squareup.anvil.compiler.internal.reference.MemberFunctionReference -import com.squareup.anvil.compiler.internal.reference.ParameterReference -import com.squareup.anvil.compiler.internal.reference.Visibility -import com.squareup.anvil.compiler.internal.reference.allSuperTypeClassReferences -import com.squareup.anvil.compiler.internal.reference.argumentAt -import com.squareup.anvil.compiler.internal.reference.asClassName -import com.squareup.anvil.compiler.internal.reference.classAndInnerClassReferences -import com.squareup.anvil.compiler.internal.reference.generateClassName -import com.squareup.anvil.compiler.internal.safePackageString -import com.squareup.kotlinpoet.FileSpec -import com.squareup.kotlinpoet.FunSpec -import com.squareup.kotlinpoet.KModifier.OVERRIDE -import com.squareup.kotlinpoet.KModifier.PRIVATE -import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy -import com.squareup.kotlinpoet.PropertySpec -import com.squareup.kotlinpoet.TypeName -import com.squareup.kotlinpoet.TypeSpec -import com.squareup.kotlinpoet.asClassName -import com.squareup.kotlinpoet.jvm.jvmStatic -import dagger.internal.InstanceFactory -import org.jetbrains.kotlin.descriptors.ModuleDescriptor -import org.jetbrains.kotlin.psi.KtFile -import java.io.File -import javax.inject.Provider - -@AutoService(CodeGenerator::class) -internal class AssistedFactoryGenerator : PrivateCodeGenerator() { - - override fun isApplicable(context: AnvilContext) = context.generateFactories - - override fun generateCodePrivate( - codeGenDir: File, - module: ModuleDescriptor, - projectFiles: Collection, - ) { - projectFiles - .classAndInnerClassReferences(module) - .filter { it.isAnnotatedWith(assistedFactoryFqName) } - .forEach { clazz -> - generateFactoryClass(codeGenDir, clazz) - } - } - - private fun generateFactoryClass( - codeGenDir: File, - clazz: ClassReference.Psi, - ): GeneratedFile { - val packageName = clazz.packageFqName.safePackageString() - val delegateFactoryName = "delegateFactory" - - val function = clazz.requireSingleAbstractFunction() - - val returnType = try { - function.function.resolveGenericReturnType(clazz) - } catch (e: AnvilCompilationException) { - // Catch the exception and throw the same error that Dagger would. - throw AnvilCompilationExceptionFunctionReference( - message = "Invalid return type: ${clazz.fqName}. An assisted factory's " + - "abstract method must return a type with an @AssistedInject-annotated constructor.", - functionReference = function.function, - cause = e, - ) - } - - // The return type of the function must have an @AssistedInject constructor. - val constructor = returnType - .constructors - .singleOrNull { it.isAnnotatedWith(assistedInjectFqName) } - ?: throw AnvilCompilationExceptionClassReference( - message = "Invalid return type: ${returnType.fqName}. An assisted factory's abstract " + - "method must return a type with an @AssistedInject-annotated constructor.", - classReference = clazz, - ) - - val functionParameters = function.parameterKeys - val assistedParameters = constructor.parameters.filter { parameter -> - parameter.annotations.any { it.fqName == assistedFqName } - } - - // Check that the parameters of the function match the @Assisted parameters of the constructor. - if (assistedParameters.size != functionParameters.size) { - throw AnvilCompilationExceptionClassReference( - message = "The parameters in the factory method must match the @Assisted parameters in " + - "${returnType.fqName}.", - classReference = clazz, - ) - } - - // Compute for each parameter its key. - val functionParameterKeys = function.parameterKeys - val assistedParameterKeys = assistedParameters.map { it.toAssistedParameterKey(clazz) } - - // The factory function may not have two or more parameters with the same key. - val duplicateKeys = functionParameterKeys - .groupBy { it.key } - .filter { it.value.size > 1 } - .values - .flatten() - - if (duplicateKeys.isNotEmpty()) { - // Complain about the first duplicate key that occurs, similar to Dagger. - val key = functionParameterKeys.first { it in duplicateKeys } - - throw AnvilCompilationExceptionClassReference( - message = buildString { - append("@AssistedFactory method has duplicate @Assisted types: ") - if (key.identifier.isNotEmpty()) { - append("@Assisted(\"${key.identifier}\") ") - } - append(key.typeName) - }, - classReference = clazz, - ) - } - - // Check that for each parameter of the factory function there is a parameter with the same - // key in the @AssistedInject constructor. - val notMatchingKeys = (functionParameterKeys + assistedParameterKeys) - .groupBy { it.key } - .filter { it.value.size == 1 } - .values - .flatten() - - if (notMatchingKeys.isNotEmpty()) { - throw AnvilCompilationExceptionClassReference( - message = "The parameters in the factory method must match the @Assisted parameters in " + - "${returnType.fqName}.", - classReference = clazz, - ) - } - - val typeParameters = clazz.typeParameters - - val generatedFactoryTypeName = returnType.generateClassName(suffix = "_Factory") - .asClassName() - .optionallyParameterizedBy(typeParameters) - - val baseFactoryTypeName = clazz.asClassName().optionallyParameterizedBy(typeParameters) - - val returnTypeName = returnType.asClassName().optionallyParameterizedBy(typeParameters) - - val implClassName = clazz.generateClassName(suffix = "_Impl").asClassName() - val implParameterizedTypeName = implClassName.optionallyParameterizedBy(typeParameters) - val functionName = function.function.name - val baseFactoryIsInterface = clazz.isInterface() - val functionParameterPairs = function.parameterPairs - - val content = FileSpec.buildFile(packageName, implClassName.simpleName) { - TypeSpec.classBuilder(implClassName) - .apply { - typeParameters.forEach { addTypeVariable(it.typeVariableName) } - - if (baseFactoryIsInterface) { - addSuperinterface(baseFactoryTypeName) - } else { - superclass(baseFactoryTypeName) - } - - primaryConstructor( - FunSpec.constructorBuilder() - .addParameter(delegateFactoryName, generatedFactoryTypeName) - .build(), - ) - - addProperty( - PropertySpec.builder(delegateFactoryName, generatedFactoryTypeName) - .initializer(delegateFactoryName) - .addModifiers(PRIVATE) - .build(), - ) - } - .addFunction( - FunSpec.builder(functionName) - .addModifiers(OVERRIDE) - .returns(returnTypeName) - .apply { - functionParameterPairs.forEach { parameter -> - addParameter(parameter.first.name, parameter.second) - } - - // We call the @AssistedInject constructor. Therefore, find for each assisted - // parameter the function parameter where the keys match. - val argumentList = assistedParameterKeys.joinToString { assistedParameterKey -> - val functionIndex = functionParameterKeys.indexOfFirst { - it.key == assistedParameterKey.key - } - check(functionIndex >= 0) { - // Sanity check, this should not happen with the noMatchingKeys list check above. - "Unexpected assistedIndex." - } - - functionParameterPairs[functionIndex].first.name - } - - addStatement("return $delegateFactoryName.get($argumentList)") - } - .build(), - ) - .apply { - TypeSpec.companionObjectBuilder() - .addFunction( - FunSpec.builder("create") - .jvmStatic() - .addTypeVariables(typeParameters.map { it.typeVariableName }) - .addParameter(delegateFactoryName, generatedFactoryTypeName) - .returns(Provider::class.asClassName().parameterizedBy(baseFactoryTypeName)) - .addStatement( - "return %T.create(%T($delegateFactoryName))", - InstanceFactory::class, - implParameterizedTypeName, - ) - .build(), - ) - .build() - .let { - addType(it) - } - } - .build() - .let { addType(it) } - } - - return createGeneratedFile(codeGenDir, packageName, implClassName.simpleName, content) - } - - private fun ClassReference.Psi.requireSingleAbstractFunction(): AssistedFactoryFunction { - // `clazz` must be first in the list because of `distinctBy { ... }`, which keeps the first - // matched element. If the function's inherited, it can be overridden as well. Prioritizing - // the version from the file we're parsing ensures the correct variance of the referenced types. - val assistedFunctions = allSuperTypeClassReferences(includeSelf = true) - .distinctBy { it.fqName } - .flatMap { clazz -> - clazz.functions - .filter { - it.isAbstract() && - (it.visibility() == Visibility.PUBLIC || it.visibility() == Visibility.PROTECTED) - } - } - .distinctBy { it.name } - .map { it.toAssistedFactoryFunction(this) } - .toList() - - // Check for exact number of functions. - return when (assistedFunctions.size) { - 0 -> throw AnvilCompilationExceptionClassReference( - message = "The @AssistedFactory-annotated type is missing an abstract, non-default " + - "method whose return type matches the assisted injection type.", - classReference = this, - ) - 1 -> assistedFunctions[0] - else -> { - val foundFunctions = assistedFunctions - .sortedBy { it.function.name } - .joinToString { func -> - "${func.function.fqName}(${func.parameterPairs.map { it.first.name }})" - } - throw AnvilCompilationExceptionClassReference( - message = "The @AssistedFactory-annotated type should contain a single abstract, " + - "non-default method but found multiple: [$foundFunctions]", - classReference = this, - ) - } - } - } - - /** - * Represents a parsed function in an `@AssistedInject.Factory`-annotated interface. - */ - private data class AssistedFactoryFunction( - val function: MemberFunctionReference, - val parameterKeys: List, - /** - * Pair of parameter reference to parameter type. - */ - val parameterPairs: List>, - ) { - companion object { - fun MemberFunctionReference.toAssistedFactoryFunction( - factoryClass: ClassReference.Psi, - ): AssistedFactoryFunction { - return AssistedFactoryFunction( - function = this, - parameterKeys = parameters.map { it.toAssistedParameterKey(factoryClass) }, - parameterPairs = parameters.map { it to it.resolveTypeName(factoryClass) }, - ) - } - } - } - - // Dagger matches parameters of the factory function with the parameters of the @AssistedInject - // constructor through a key. Initially, they used the order of parameters, but that has changed. - // The key is a combination of the type and identifier (value parameter) of the - // @Assisted("...") annotation. For each parameter the key must be unique. - private data class AssistedParameterKey( - val typeName: TypeName, - val identifier: String, - ) { - - // Key value is similar to a hash function. There used to be a special case for KotlinTypes - // which were parameterized, but this is now handled by KotlinPoet's TypeName. - // `MyType` and `MyType` now generate different hashCodes. - val key: Int = identifier.hashCode() * 31 + typeName.hashCode() - - companion object { - fun ParameterReference.toAssistedParameterKey( - factoryClass: ClassReference.Psi, - ): AssistedParameterKey { - return AssistedParameterKey( - typeName = resolveTypeName(factoryClass), - identifier = annotations - .singleOrNull { it.fqName == assistedFqName } - ?.let { annotation -> - annotation.argumentAt("value", index = 0)?.value() - } - .orEmpty(), - ) - } - } - } -} diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedInjectGenerator.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedInjectCodeGen.kt similarity index 54% rename from compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedInjectGenerator.kt rename to compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedInjectCodeGen.kt index 23a886c96..741f83931 100644 --- a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedInjectGenerator.kt +++ b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/AssistedInjectCodeGen.kt @@ -1,6 +1,13 @@ package com.squareup.anvil.compiler.codegen.dagger import com.google.auto.service.AutoService +import com.google.devtools.ksp.processing.Resolver +import com.google.devtools.ksp.processing.SymbolProcessorEnvironment +import com.google.devtools.ksp.processing.SymbolProcessorProvider +import com.google.devtools.ksp.symbol.KSAnnotated +import com.google.devtools.ksp.symbol.KSClassDeclaration +import com.google.devtools.ksp.symbol.KSFunctionDeclaration +import com.squareup.anvil.compiler.api.AnvilApplicabilityChecker import com.squareup.anvil.compiler.api.AnvilContext import com.squareup.anvil.compiler.api.CodeGenerator import com.squareup.anvil.compiler.api.GeneratedFile @@ -8,21 +15,31 @@ import com.squareup.anvil.compiler.api.createGeneratedFile import com.squareup.anvil.compiler.assistedInjectFqName import com.squareup.anvil.compiler.codegen.PrivateCodeGenerator import com.squareup.anvil.compiler.codegen.injectConstructor -import com.squareup.anvil.compiler.internal.asClassName -import com.squareup.anvil.compiler.internal.buildFile +import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessor +import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessorProvider +import com.squareup.anvil.compiler.codegen.ksp.KspAnvilException +import com.squareup.anvil.compiler.codegen.ksp.injectConstructors +import com.squareup.anvil.compiler.codegen.ksp.isAnnotationPresent +import com.squareup.anvil.compiler.internal.createAnvilSpec import com.squareup.anvil.compiler.internal.reference.AnvilCompilationExceptionClassReference import com.squareup.anvil.compiler.internal.reference.ClassReference import com.squareup.anvil.compiler.internal.reference.MemberFunctionReference import com.squareup.anvil.compiler.internal.reference.asClassName import com.squareup.anvil.compiler.internal.reference.classAndInnerClassReferences import com.squareup.anvil.compiler.internal.reference.generateClassName -import com.squareup.anvil.compiler.internal.safePackageString +import com.squareup.kotlinpoet.ClassName import com.squareup.kotlinpoet.FileSpec import com.squareup.kotlinpoet.FunSpec import com.squareup.kotlinpoet.KModifier.PRIVATE import com.squareup.kotlinpoet.PropertySpec import com.squareup.kotlinpoet.TypeSpec +import com.squareup.kotlinpoet.TypeVariableName import com.squareup.kotlinpoet.jvm.jvmStatic +import com.squareup.kotlinpoet.ksp.toClassName +import com.squareup.kotlinpoet.ksp.toTypeParameterResolver +import com.squareup.kotlinpoet.ksp.toTypeVariableName +import com.squareup.kotlinpoet.ksp.writeTo +import dagger.assisted.AssistedInject import org.jetbrains.kotlin.descriptors.ModuleDescriptor import org.jetbrains.kotlin.psi.KtFile import java.io.File @@ -37,56 +54,130 @@ import java.io.File * class AssistedService_Factory { .. } * ``` */ -@AutoService(CodeGenerator::class) -internal class AssistedInjectGenerator : PrivateCodeGenerator() { +object AssistedInjectCodeGen : AnvilApplicabilityChecker { override fun isApplicable(context: AnvilContext) = context.generateFactories - override fun generateCodePrivate( - codeGenDir: File, - module: ModuleDescriptor, - projectFiles: Collection, - ) { - projectFiles - .classAndInnerClassReferences(module) - .forEach { clazz -> - clazz.constructors - .injectConstructor() - ?.takeIf { it.isAnnotatedWith(assistedInjectFqName) } - ?.let { - generateFactoryClass(codeGenDir, clazz, it) + internal class KspGenerator( + override val env: SymbolProcessorEnvironment, + ) : AnvilSymbolProcessor() { + @AutoService(SymbolProcessorProvider::class) + class Provider : AnvilSymbolProcessorProvider(AssistedInjectCodeGen, ::KspGenerator) + + override fun processChecked(resolver: Resolver): List { + resolver.injectConstructors() + .forEach { (clazz, constructor) -> + if (!constructor.isAnnotationPresent()) { + // Only generating @AssistedInject constructors + return@forEach } - } + generateFactoryClass( + clazz = clazz, + constructor = constructor, + ) + .writeTo(env.codeGenerator, aggregating = false, listOf(constructor.containingFile!!)) + } + return emptyList() + } + + private fun generateFactoryClass( + clazz: KSClassDeclaration, + constructor: KSFunctionDeclaration, + ): FileSpec { + val typeParameterResolver = clazz.typeParameters.toTypeParameterResolver() + val constructorParameters = constructor.parameters + .mapToConstructorParameters(typeParameterResolver) + val memberInjectParameters = clazz.memberInjectParameters() + val typeParameters = clazz.typeParameters + + val spec = generateFactoryClass( + clazz = clazz.toClassName(), + memberInjectParameters = memberInjectParameters, + typeParameters = typeParameters.map { it.toTypeVariableName(typeParameterResolver) }, + constructorParameters = constructorParameters, + onError = { message -> + throw KspAnvilException( + message = message, + node = constructor, + ) + }, + ) + + return spec + } } - private fun generateFactoryClass( - codeGenDir: File, - clazz: ClassReference.Psi, - constructor: MemberFunctionReference.Psi, - ): GeneratedFile { - val packageName = clazz.packageFqName.safePackageString() - val classIdName = clazz.generateClassName(suffix = "_Factory") - val className = classIdName.relativeClassName.asString() + @AutoService(CodeGenerator::class) + internal class Embedded : PrivateCodeGenerator() { - val constructorParameters = constructor.parameters.mapToConstructorParameters() - val memberInjectParameters = clazz.memberInjectParameters() + override fun isApplicable(context: AnvilContext) = AssistedInjectCodeGen.isApplicable(context) + + override fun generateCodePrivate( + codeGenDir: File, + module: ModuleDescriptor, + projectFiles: Collection, + ) { + projectFiles + .classAndInnerClassReferences(module) + .forEach { clazz -> + clazz.constructors + .injectConstructor() + ?.takeIf { it.isAnnotatedWith(assistedInjectFqName) } + ?.let { + generateFactoryClass(codeGenDir, clazz, it) + } + } + } + + private fun generateFactoryClass( + codeGenDir: File, + clazz: ClassReference.Psi, + constructor: MemberFunctionReference.Psi, + ): GeneratedFile { + val constructorParameters = constructor.parameters.mapToConstructorParameters() + val memberInjectParameters = clazz.memberInjectParameters() + val typeParameters = clazz.typeParameters + + val spec = generateFactoryClass( + clazz = clazz.asClassName(), + memberInjectParameters = memberInjectParameters, + typeParameters = typeParameters.map { it.typeVariableName }, + constructorParameters = constructorParameters, + onError = { message -> + throw AnvilCompilationExceptionClassReference( + message = message, + classReference = clazz, + ) + }, + ) + + return createGeneratedFile(codeGenDir, spec.packageName, spec.name, spec.toString()) + } + } + + private fun generateFactoryClass( + clazz: ClassName, + memberInjectParameters: List, + typeParameters: List, + constructorParameters: List, + onError: (String) -> Nothing, + ): FileSpec { + val packageName = clazz.packageName + val factoryClass = clazz.generateClassName(suffix = "_Factory") val parameters = constructorParameters + memberInjectParameters val parametersAssisted = parameters.filter { it.isAssisted } val parametersNotAssisted = parameters.filterNot { it.isAssisted } - checkAssistedParametersAreDistinct(clazz, parametersAssisted) - - val typeParameters = clazz.typeParameters + checkAssistedParametersAreDistinct(parametersAssisted, onError) - val factoryClass = classIdName.asClassName() - val factoryClassParameterized = factoryClass.optionallyParameterizedBy(typeParameters) - val classType = clazz.asClassName().optionallyParameterizedBy(typeParameters) + val factoryClassParameterized = factoryClass.optionallyParameterizedByNames(typeParameters) + val classType = clazz.optionallyParameterizedByNames(typeParameters) - val content = FileSpec.buildFile(packageName, className) { + val spec = FileSpec.createAnvilSpec(packageName, factoryClass.simpleName) { TypeSpec.classBuilder(factoryClass) .apply { - typeParameters.forEach { addTypeVariable(it.typeVariableName) } + addTypeVariables(typeParameters) primaryConstructor( FunSpec.constructorBuilder() @@ -138,7 +229,7 @@ internal class AssistedInjectGenerator : PrivateCodeGenerator() { .jvmStatic() .apply { if (typeParameters.isNotEmpty()) { - addTypeVariables(typeParameters.map { it.typeVariableName }) + addTypeVariables(typeParameters) } parametersNotAssisted.forEach { parameter -> addParameter(parameter.name, parameter.providerTypeName) @@ -162,7 +253,7 @@ internal class AssistedInjectGenerator : PrivateCodeGenerator() { .jvmStatic() .apply { if (typeParameters.isNotEmpty()) { - addTypeVariables(typeParameters.map { it.typeVariableName }) + addTypeVariables(typeParameters) } constructorParameters.forEach { parameter -> addParameter( @@ -186,12 +277,12 @@ internal class AssistedInjectGenerator : PrivateCodeGenerator() { .let { addType(it) } } - return createGeneratedFile(codeGenDir, packageName, className, content) + return spec } private fun checkAssistedParametersAreDistinct( - clazz: ClassReference, parameters: List, + onError: (String) -> Nothing, ) { // Parameters are identical, if there types and identifier match. val duplicateAssistedParameters = parameters @@ -219,6 +310,6 @@ internal class AssistedInjectGenerator : PrivateCodeGenerator() { append(parameter.typeName) } - throw AnvilCompilationExceptionClassReference(message = errorMessage, classReference = clazz) + onError(errorMessage) } } diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/DaggerGenerationUtils.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/DaggerGenerationUtils.kt index 6b16dca5d..7c3a5fd9a 100644 --- a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/DaggerGenerationUtils.kt +++ b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/DaggerGenerationUtils.kt @@ -1,7 +1,27 @@ package com.squareup.anvil.compiler.codegen.dagger +import com.google.devtools.ksp.KspExperimental +import com.google.devtools.ksp.getAllSuperTypes +import com.google.devtools.ksp.getAnnotationsByType +import com.google.devtools.ksp.getDeclaredProperties +import com.google.devtools.ksp.getVisibility +import com.google.devtools.ksp.symbol.KSClassDeclaration +import com.google.devtools.ksp.symbol.KSPropertyDeclaration +import com.google.devtools.ksp.symbol.KSType +import com.google.devtools.ksp.symbol.KSValueParameter +import com.google.devtools.ksp.symbol.Visibility import com.squareup.anvil.compiler.assistedFqName +import com.squareup.anvil.compiler.codegen.ksp.KspAnvilException +import com.squareup.anvil.compiler.codegen.ksp.argumentAt +import com.squareup.anvil.compiler.codegen.ksp.getKSAnnotationsByType +import com.squareup.anvil.compiler.codegen.ksp.isAnnotationPresent +import com.squareup.anvil.compiler.codegen.ksp.isInterface +import com.squareup.anvil.compiler.codegen.ksp.isLateInit +import com.squareup.anvil.compiler.codegen.ksp.isQualifier +import com.squareup.anvil.compiler.codegen.ksp.resolveKSClassDeclaration +import com.squareup.anvil.compiler.codegen.ksp.withJvmSuppressWildcardsIfNeeded import com.squareup.anvil.compiler.daggerDoubleCheckFqNameString +import com.squareup.anvil.compiler.daggerLazyClassName import com.squareup.anvil.compiler.daggerLazyFqName import com.squareup.anvil.compiler.injectFqName import com.squareup.anvil.compiler.internal.capitalize @@ -18,8 +38,11 @@ import com.squareup.anvil.compiler.internal.reference.allSuperTypeClassReference import com.squareup.anvil.compiler.internal.reference.argumentAt import com.squareup.anvil.compiler.internal.reference.asClassName import com.squareup.anvil.compiler.internal.reference.generateClassName +import com.squareup.anvil.compiler.internal.requireRawType +import com.squareup.anvil.compiler.internal.unwrappedTypes import com.squareup.anvil.compiler.internal.withJvmSuppressWildcardsIfNeeded import com.squareup.anvil.compiler.jvmFieldFqName +import com.squareup.anvil.compiler.providerClassName import com.squareup.anvil.compiler.providerFqName import com.squareup.kotlinpoet.ClassName import com.squareup.kotlinpoet.FunSpec @@ -27,8 +50,16 @@ import com.squareup.kotlinpoet.ParameterizedTypeName import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy import com.squareup.kotlinpoet.TypeName import com.squareup.kotlinpoet.asClassName +import com.squareup.kotlinpoet.ksp.TypeParameterResolver +import com.squareup.kotlinpoet.ksp.toAnnotationSpec +import com.squareup.kotlinpoet.ksp.toClassName +import com.squareup.kotlinpoet.ksp.toTypeName +import com.squareup.kotlinpoet.ksp.toTypeParameterResolver import dagger.Lazy +import dagger.assisted.Assisted import dagger.internal.ProviderOfLazy +import org.jetbrains.kotlin.name.FqName +import javax.inject.Inject import javax.inject.Provider internal fun TypeName.wrapInProvider(): ParameterizedTypeName { @@ -82,6 +113,57 @@ private fun ParameterReference.toConstructorParameter( ) } +@JvmName("mapToConstructorParametersKsp") +internal fun List.mapToConstructorParameters( + typeParameterResolver: TypeParameterResolver, +): List { + return fold(listOf()) { acc, callableReference -> + acc + callableReference.toConstructorParameter(callableReference.name!!.asString().uniqueParameterName(acc), typeParameterResolver) + } +} + +@OptIn(KspExperimental::class) +private fun KSValueParameter.toConstructorParameter( + uniqueName: String, + typeParameterResolver: TypeParameterResolver, +): ConstructorParameter { + val type = type.resolve() + val paramTypeName = type.toTypeName(typeParameterResolver) + val rawType = paramTypeName.requireRawType() + + val isWrappedInProvider = rawType == providerClassName + val isWrappedInLazy = rawType == daggerLazyClassName + val isLazyWrappedInProvider = isWrappedInProvider && + (paramTypeName.unwrappedTypes.first().requireRawType()) == daggerLazyClassName + + val typeName = when { + isLazyWrappedInProvider -> paramTypeName.unwrappedTypes.first().unwrappedTypes.first() + isWrappedInProvider || isWrappedInLazy -> paramTypeName.unwrappedTypes.first() + else -> paramTypeName + }.withJvmSuppressWildcardsIfNeeded(this, type) + + val assistedAnnotation = getKSAnnotationsByType(Assisted::class) + .singleOrNull() + + val assistedIdentifier = getAnnotationsByType(Assisted::class) + .singleOrNull() + ?.value + .orEmpty() + + return ConstructorParameter( + name = uniqueName, + originalName = name!!.asString(), + typeName = typeName, + providerTypeName = typeName.wrapInProvider(), + lazyTypeName = typeName.wrapInLazy(), + isWrappedInProvider = isWrappedInProvider, + isWrappedInLazy = isWrappedInLazy, + isLazyWrappedInProvider = isLazyWrappedInProvider, + isAssisted = assistedAnnotation != null, + assistedIdentifier = assistedIdentifier, + ) +} + internal fun FunSpec.Builder.addMemberInjection( memberInjectParameters: List, instanceName: String, @@ -147,6 +229,60 @@ private fun ClassReference.declaredMemberInjectParameters( } } +/** + * Returns all member-injected parameters for the receiver class *and any superclasses*. + * + * Order is important. Dagger expects the properties of the most-upstream class to be listed first + * in a factory's constructor. + * + * Given the hierarchy: + * Impl -> Middle -> Base + * The order of dependencies in `Impl_Factory`'s constructor should be: + * Base -> Middle -> Impl + */ +internal fun KSClassDeclaration.memberInjectParameters(): List { + // TODO can we use getAllProperties() after https://github.com/google/ksp/issues/1619? + return sequenceOf(asType(emptyList())) + .plus(getAllSuperTypes()) + .mapNotNull { + it.resolveKSClassDeclaration() + } + .filterNot { + it.isInterface() + } + .toList() + .foldRight(listOf()) { classDeclaration, acc -> + acc + classDeclaration.declaredMemberInjectParameters(acc, this) + } +} + +/** + * @param superParameters injected parameters from any super-classes, regardless of whether they're + * overridden by the receiver class + * @return the member-injected parameters for this class only, not including any super-classes + */ +private fun KSClassDeclaration.declaredMemberInjectParameters( + superParameters: List, + implementingClass: KSClassDeclaration, +): List { + val implementingType = implementingClass.asType(emptyList()) + return getDeclaredProperties() + .filter { + it.isAnnotationPresent() || + it.setter?.isAnnotationPresent() == true + } + .filter { it.getVisibility() != Visibility.PRIVATE } + .fold(listOf()) { acc, property -> + val uniqueName = property.simpleName.asString().uniqueParameterName(superParameters, acc) + acc + property.toMemberInjectParameter( + uniqueName = uniqueName, + declaringClass = this@declaredMemberInjectParameters, + implementingType = implementingType, + implementingClass = implementingClass, + ) + } +} + /** * Converts the parameter list to comma separated argument list that can be used to call other * functions, e.g. @@ -296,13 +432,119 @@ private fun MemberPropertyReference.toMemberInjectParameter( ) } +@OptIn(KspExperimental::class) +private fun KSPropertyDeclaration.toMemberInjectParameter( + uniqueName: String, + declaringClass: KSClassDeclaration, + implementingType: KSType, + implementingClass: KSClassDeclaration, +): MemberInjectParameter { + if ( + !isLateInit() && + !isAnnotationPresent() && + setter?.isAnnotationPresent() != true + ) { + // Technically this works with Anvil and we could remove this check. But we prefer consistency + // with Dagger. + throw KspAnvilException( + message = "Dagger does not support injection into private fields. Either use a " + + "'lateinit var' or '@JvmField'.", + node = this, + ) + } + + val originalName = simpleName.asString() + val classParams = implementingClass.typeParameters.toTypeParameterResolver() + val resolvedType = asMemberOf(implementingType) + // TODO do we want to convert function types to lambdas? + val propertyTypeName = resolvedType.toTypeName(classParams) + val rawType = propertyTypeName.requireRawType() + + val isWrappedInProvider = rawType == providerClassName + val isWrappedInLazy = rawType == daggerLazyClassName + val isLazyWrappedInProvider = isWrappedInProvider && + (propertyTypeName.unwrappedTypes.first().requireRawType()) == daggerLazyClassName + + val unwrappedType = when { + isLazyWrappedInProvider -> propertyTypeName.unwrappedTypes.first().unwrappedTypes.first() + isWrappedInProvider || isWrappedInLazy -> propertyTypeName.unwrappedTypes.first() + else -> propertyTypeName + } + + val typeName = unwrappedType.withJvmSuppressWildcardsIfNeeded(this, resolvedType) + + val resolvedTypeName = if ((resolvedType.declaration as? KSClassDeclaration)?.typeParameters.orEmpty().isNotEmpty()) { + unwrappedType.requireRawType() + .optionallyParameterizedByNames( + unwrappedType.unwrappedTypes, + ) + .withJvmSuppressWildcardsIfNeeded(this, resolvedType) + } else { + null + } + + val assistedAnnotation = getAnnotationsByType(Assisted::class) + .singleOrNull() + + val assistedIdentifier = assistedAnnotation + ?.value + .orEmpty() + + val implementingClassName = declaringClass + .toClassName() + val memberInjectorClassName = implementingClassName + .generateClassName(separator = "_", suffix = "_MembersInjector") + .simpleNames + .joinToString(".") + + val memberInjectorClass = ClassName( + implementingClassName.packageName, + memberInjectorClassName, + ) + + val isSetterInjected = this.setter?.isAnnotationPresent() == true + + // setter delegates require a "set" prefix for their inject function + val accessName = if (isSetterInjected) { + "set${originalName.capitalize()}" + } else { + originalName + } + + val qualifierAnnotations = annotations + .filter { it.isQualifier() } + .map { it.toAnnotationSpec() } + .toList() + + val providerTypeName = typeName.wrapInProvider() + + return MemberInjectParameter( + name = uniqueName, + originalName = originalName, + typeName = typeName, + providerTypeName = providerTypeName, + lazyTypeName = typeName.wrapInLazy(), + isWrappedInProvider = isWrappedInProvider, + isWrappedInLazy = isWrappedInLazy, + isLazyWrappedInProvider = isLazyWrappedInProvider, + isAssisted = assistedAnnotation != null, + assistedIdentifier = assistedIdentifier, + memberInjectorClassName = memberInjectorClass, + isSetterInjected = isSetterInjected, + accessName = accessName, + qualifierAnnotationSpecs = qualifierAnnotations, + injectedFieldSignature = FqName(qualifiedName!!.asString()), + resolvedProviderTypeName = resolvedTypeName?.wrapInProvider() ?: providerTypeName, + ) +} + private fun TypeReference.isGenericExcludingTypeAliases(): Boolean { // A TypeReference for 'typealias StringList = List would still show up as generic but // would have no unwrapped types available (String). return isGenericType() && unwrappedTypes.isNotEmpty() } -private fun ClassName.optionallyParameterizedByNames( +internal fun ClassName.optionallyParameterizedByNames( typeNames: List, ): TypeName { return if (typeNames.isEmpty()) { diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/InjectConstructorFactoryGenerator.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/InjectConstructorFactoryCodeGen.kt similarity index 51% rename from compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/InjectConstructorFactoryGenerator.kt rename to compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/InjectConstructorFactoryCodeGen.kt index 5378072a5..16b6e7ff1 100644 --- a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/InjectConstructorFactoryGenerator.kt +++ b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/InjectConstructorFactoryCodeGen.kt @@ -1,21 +1,31 @@ package com.squareup.anvil.compiler.codegen.dagger import com.google.auto.service.AutoService +import com.google.devtools.ksp.processing.Resolver +import com.google.devtools.ksp.processing.SymbolProcessorEnvironment +import com.google.devtools.ksp.processing.SymbolProcessorProvider +import com.google.devtools.ksp.symbol.KSAnnotated +import com.google.devtools.ksp.symbol.KSClassDeclaration +import com.google.devtools.ksp.symbol.KSFunctionDeclaration +import com.squareup.anvil.compiler.api.AnvilApplicabilityChecker import com.squareup.anvil.compiler.api.AnvilContext import com.squareup.anvil.compiler.api.CodeGenerator import com.squareup.anvil.compiler.api.GeneratedFile import com.squareup.anvil.compiler.api.createGeneratedFile import com.squareup.anvil.compiler.codegen.PrivateCodeGenerator import com.squareup.anvil.compiler.codegen.injectConstructor +import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessor +import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessorProvider +import com.squareup.anvil.compiler.codegen.ksp.injectConstructors +import com.squareup.anvil.compiler.codegen.ksp.isAnnotationPresent import com.squareup.anvil.compiler.injectFqName -import com.squareup.anvil.compiler.internal.asClassName -import com.squareup.anvil.compiler.internal.buildFile +import com.squareup.anvil.compiler.internal.createAnvilSpec import com.squareup.anvil.compiler.internal.reference.ClassReference import com.squareup.anvil.compiler.internal.reference.MemberFunctionReference import com.squareup.anvil.compiler.internal.reference.asClassName import com.squareup.anvil.compiler.internal.reference.classAndInnerClassReferences import com.squareup.anvil.compiler.internal.reference.generateClassName -import com.squareup.anvil.compiler.internal.safePackageString +import com.squareup.kotlinpoet.ClassName import com.squareup.kotlinpoet.FileSpec import com.squareup.kotlinpoet.FunSpec import com.squareup.kotlinpoet.KModifier.OVERRIDE @@ -23,64 +33,135 @@ import com.squareup.kotlinpoet.KModifier.PRIVATE import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy import com.squareup.kotlinpoet.PropertySpec import com.squareup.kotlinpoet.TypeSpec +import com.squareup.kotlinpoet.TypeVariableName import com.squareup.kotlinpoet.asClassName import com.squareup.kotlinpoet.jvm.jvmStatic +import com.squareup.kotlinpoet.ksp.toClassName +import com.squareup.kotlinpoet.ksp.toTypeParameterResolver +import com.squareup.kotlinpoet.ksp.toTypeVariableName +import com.squareup.kotlinpoet.ksp.writeTo import dagger.internal.Factory import org.jetbrains.kotlin.descriptors.ModuleDescriptor import org.jetbrains.kotlin.psi.KtFile import java.io.File +import javax.inject.Inject -@AutoService(CodeGenerator::class) -internal class InjectConstructorFactoryGenerator : PrivateCodeGenerator() { - +object InjectConstructorFactoryCodeGen : AnvilApplicabilityChecker { override fun isApplicable(context: AnvilContext) = context.generateFactories - override fun generateCodePrivate( - codeGenDir: File, - module: ModuleDescriptor, - projectFiles: Collection, - ) { - projectFiles - .classAndInnerClassReferences(module) - .forEach { clazz -> - clazz.constructors - .injectConstructor() - ?.takeIf { it.isAnnotatedWith(injectFqName) } - ?.let { - generateFactoryClass(codeGenDir, clazz, it) + internal class KspGenerator( + override val env: SymbolProcessorEnvironment, + ) : AnvilSymbolProcessor() { + @AutoService(SymbolProcessorProvider::class) + class Provider : AnvilSymbolProcessorProvider(InjectConstructorFactoryCodeGen, ::KspGenerator) + + override fun processChecked(resolver: Resolver): List { + resolver.injectConstructors() + .forEach { (_, constructor) -> + if (!constructor.isAnnotationPresent()) { + // Only generating @Inject constructors + return@forEach } - } + + generateFactoryClass(constructor) + .writeTo( + env.codeGenerator, + aggregating = false, + originatingKSFiles = listOf(constructor.containingFile!!), + ) + } + + return emptyList() + } + + private fun generateFactoryClass( + constructor: KSFunctionDeclaration, + ): FileSpec { + val clazz = constructor.parentDeclaration as KSClassDeclaration + val constructorParameters = constructor.parameters.mapToConstructorParameters( + clazz.typeParameters.toTypeParameterResolver(), + ) + val memberInjectParameters = clazz.memberInjectParameters() + val typeParameters = clazz.typeParameters.map { it.toTypeVariableName() } + + return generateFactoryClass( + injectedClassName = clazz.toClassName(), + typeParameters = typeParameters, + constructorParameters = constructorParameters, + memberInjectParameters = memberInjectParameters, + ) + } } - private fun generateFactoryClass( - codeGenDir: File, - clazz: ClassReference.Psi, - constructor: MemberFunctionReference.Psi, - ): GeneratedFile { - val classId = clazz.generateClassName(suffix = "_Factory") + @AutoService(CodeGenerator::class) + internal class Embedded : PrivateCodeGenerator() { - val packageName = classId.packageFqName.safePackageString() - val className = classId.relativeClassName.asString() + override fun isApplicable(context: AnvilContext) = InjectConstructorFactoryCodeGen.isApplicable( + context, + ) - val constructorParameters = constructor.parameters.mapToConstructorParameters() - val memberInjectParameters = clazz.memberInjectParameters() + override fun generateCodePrivate( + codeGenDir: File, + module: ModuleDescriptor, + projectFiles: Collection, + ) { + projectFiles + .classAndInnerClassReferences(module) + .forEach { clazz -> + clazz.constructors + .injectConstructor() + ?.takeIf { it.isAnnotatedWith(injectFqName) } + ?.let { + generateFactoryClass(codeGenDir, clazz, it) + } + } + } - val allParameters = constructorParameters + memberInjectParameters + private fun generateFactoryClass( + codeGenDir: File, + clazz: ClassReference.Psi, + constructor: MemberFunctionReference.Psi, + ): GeneratedFile { + val constructorParameters = constructor.parameters.mapToConstructorParameters() + val memberInjectParameters = clazz.memberInjectParameters() + val typeParameters = clazz.typeParameters.map { it.typeVariableName } + + val spec = generateFactoryClass( + injectedClassName = clazz.asClassName(), + typeParameters = typeParameters, + constructorParameters = constructorParameters, + memberInjectParameters = memberInjectParameters, + ) + + return createGeneratedFile(codeGenDir, spec.packageName, spec.name, spec.toString()) + } + } + + private fun generateFactoryClass( + injectedClassName: ClassName, + typeParameters: List, + constructorParameters: List, + memberInjectParameters: List, + ): FileSpec { + val generatedClassName = injectedClassName.generateClassName(suffix = "_Factory") + + val packageName = injectedClassName.packageName - val typeParameters = clazz.typeParameters + val allParameters = constructorParameters + memberInjectParameters - val factoryClass = classId.asClassName() - val factoryClassParameterized = factoryClass.optionallyParameterizedBy(typeParameters) - val classType = clazz.asClassName().optionallyParameterizedBy(typeParameters) + val factoryClassParameterized = generatedClassName.optionallyParameterizedByNames( + typeParameters, + ) + val classType = injectedClassName.optionallyParameterizedByNames(typeParameters) - val content = FileSpec.buildFile(packageName, className) { + val spec = FileSpec.createAnvilSpec(packageName, generatedClassName.simpleName) { val canGenerateAnObject = allParameters.isEmpty() && typeParameters.isEmpty() val classBuilder = if (canGenerateAnObject) { - TypeSpec.objectBuilder(factoryClass) + TypeSpec.objectBuilder(generatedClassName) } else { - TypeSpec.classBuilder(factoryClass) + TypeSpec.classBuilder(generatedClassName) } - typeParameters.forEach { classBuilder.addTypeVariable(it.typeVariableName) } + typeParameters.forEach { classBuilder.addTypeVariable(it) } classBuilder .addSuperinterface(Factory::class.asClassName().parameterizedBy(classType)) @@ -135,7 +216,7 @@ internal class InjectConstructorFactoryGenerator : PrivateCodeGenerator() { .jvmStatic() .apply { if (typeParameters.isNotEmpty()) { - addTypeVariables(typeParameters.map { it.typeVariableName }) + addTypeVariables(typeParameters) } if (canGenerateAnObject) { addStatement("return this") @@ -163,7 +244,7 @@ internal class InjectConstructorFactoryGenerator : PrivateCodeGenerator() { .jvmStatic() .apply { if (typeParameters.isNotEmpty()) { - addTypeVariables(typeParameters.map { it.typeVariableName }) + addTypeVariables(typeParameters) } constructorParameters.forEach { parameter -> addParameter( @@ -189,6 +270,6 @@ internal class InjectConstructorFactoryGenerator : PrivateCodeGenerator() { .let { addType(it) } } - return createGeneratedFile(codeGenDir, packageName, className, content) + return spec } } diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/MapKeyCreatorGenerator.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/MapKeyCreatorCodeGen.kt similarity index 97% rename from compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/MapKeyCreatorGenerator.kt rename to compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/MapKeyCreatorCodeGen.kt index 13b9d3aee..87dafe0f5 100644 --- a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/MapKeyCreatorGenerator.kt +++ b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/MapKeyCreatorCodeGen.kt @@ -1,6 +1,8 @@ package com.squareup.anvil.compiler.codegen.dagger import com.google.auto.service.AutoService +import com.google.devtools.ksp.KspExperimental +import com.google.devtools.ksp.getAnnotationsByType import com.google.devtools.ksp.getDeclaredProperties import com.google.devtools.ksp.processing.Resolver import com.google.devtools.ksp.processing.SymbolProcessorEnvironment @@ -80,14 +82,15 @@ object MapKeyCreatorCodeGen : AnvilApplicabilityChecker { @AutoService(SymbolProcessorProvider::class) class Provider : AnvilSymbolProcessorProvider(MapKeyCreatorCodeGen, ::KspGenerator) + @OptIn(KspExperimental::class) override fun processChecked(resolver: Resolver): List { resolver.getSymbolsWithAnnotation(mapKeyFqName.asString()) .filterIsInstance() .filter { clazz -> - val mapKey = clazz.annotations.find { it.shortName.asString() == "MapKey" } + val mapKey = clazz.getAnnotationsByType(MapKey::class) + .singleOrNull() ?: return@filter false - val unwrapValue = mapKey.argumentAt("unwrapValue")?.value as? Boolean ?: true - return@filter !unwrapValue + return@filter !mapKey.unwrapValue } .forEach { clazz -> generateCreatorClass(clazz) diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/MembersInjectorGenerator.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/MembersInjectorCodeGen.kt similarity index 55% rename from compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/MembersInjectorGenerator.kt rename to compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/MembersInjectorCodeGen.kt index 58603651d..10775269c 100644 --- a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/MembersInjectorGenerator.kt +++ b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/dagger/MembersInjectorCodeGen.kt @@ -1,15 +1,24 @@ package com.squareup.anvil.compiler.codegen.dagger import com.google.auto.service.AutoService +import com.google.devtools.ksp.isPrivate +import com.google.devtools.ksp.processing.Resolver +import com.google.devtools.ksp.processing.SymbolProcessorEnvironment +import com.google.devtools.ksp.processing.SymbolProcessorProvider +import com.google.devtools.ksp.symbol.KSAnnotated +import com.google.devtools.ksp.symbol.KSClassDeclaration +import com.google.devtools.ksp.symbol.KSPropertyDeclaration +import com.squareup.anvil.compiler.api.AnvilApplicabilityChecker import com.squareup.anvil.compiler.api.AnvilContext import com.squareup.anvil.compiler.api.CodeGenerator import com.squareup.anvil.compiler.api.GeneratedFile import com.squareup.anvil.compiler.api.createGeneratedFile import com.squareup.anvil.compiler.codegen.PrivateCodeGenerator +import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessor +import com.squareup.anvil.compiler.codegen.ksp.AnvilSymbolProcessorProvider import com.squareup.anvil.compiler.injectFqName -import com.squareup.anvil.compiler.internal.asClassName -import com.squareup.anvil.compiler.internal.buildFile import com.squareup.anvil.compiler.internal.capitalize +import com.squareup.anvil.compiler.internal.createAnvilSpec import com.squareup.anvil.compiler.internal.reference.ClassReference import com.squareup.anvil.compiler.internal.reference.Visibility import com.squareup.anvil.compiler.internal.reference.asClassName @@ -17,6 +26,7 @@ import com.squareup.anvil.compiler.internal.reference.classAndInnerClassReferenc import com.squareup.anvil.compiler.internal.reference.generateClassName import com.squareup.anvil.compiler.internal.safePackageString import com.squareup.kotlinpoet.AnnotationSpec +import com.squareup.kotlinpoet.ClassName import com.squareup.kotlinpoet.FileSpec import com.squareup.kotlinpoet.FunSpec import com.squareup.kotlinpoet.KModifier.OVERRIDE @@ -24,58 +34,115 @@ import com.squareup.kotlinpoet.KModifier.PRIVATE import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy import com.squareup.kotlinpoet.PropertySpec import com.squareup.kotlinpoet.TypeSpec +import com.squareup.kotlinpoet.TypeVariableName import com.squareup.kotlinpoet.asClassName import com.squareup.kotlinpoet.jvm.jvmStatic +import com.squareup.kotlinpoet.ksp.toClassName +import com.squareup.kotlinpoet.ksp.toTypeVariableName +import com.squareup.kotlinpoet.ksp.writeTo import dagger.MembersInjector import dagger.internal.InjectedFieldSignature import org.jetbrains.kotlin.descriptors.ModuleDescriptor import org.jetbrains.kotlin.psi.KtFile import java.io.File -@AutoService(CodeGenerator::class) -internal class MembersInjectorGenerator : PrivateCodeGenerator() { - +internal object MembersInjectorCodeGen : AnvilApplicabilityChecker { override fun isApplicable(context: AnvilContext) = context.generateFactories - override fun generateCodePrivate( - codeGenDir: File, - module: ModuleDescriptor, - projectFiles: Collection, - ) { - projectFiles - .classAndInnerClassReferences(module) - .filterNot { it.isInterface() } - .forEach { clazz -> - // Only generate a MembersInjector if the target class declares its own member-injected - // properties. If it does, then any properties from superclasses must be added as well - // (clazz.memberInjectParameters() will do this). - clazz.properties - .filter { it.visibility() != Visibility.PRIVATE } - .filter { it.isAnnotatedWith(injectFqName) } - .ifEmpty { return@forEach } - - generateMembersInjectorClass( - codeGenDir = codeGenDir, - clazz = clazz, - parameters = clazz.memberInjectParameters(), - ) - } + internal class KspGenerator( + override val env: SymbolProcessorEnvironment, + ) : AnvilSymbolProcessor() { + @AutoService(SymbolProcessorProvider::class) + class Provider : AnvilSymbolProcessorProvider(MembersInjectorCodeGen, ::KspGenerator) + + override fun processChecked(resolver: Resolver): List { + resolver.getSymbolsWithAnnotation(injectFqName.asString()) + .filterIsInstance() + .filterNot { it.isPrivate() } + .filter { it.parentDeclaration is KSClassDeclaration } + .groupBy { it.parentDeclaration as KSClassDeclaration } + .forEach { (clazz, _) -> + val typeParameters = clazz.typeParameters + .map { it.toTypeVariableName() } + val isGeneric = typeParameters.isNotEmpty() + + generateMembersInjectorClass( + origin = clazz.toClassName(), + isGeneric = isGeneric, + typeParameters = typeParameters, + parameters = clazz.memberInjectParameters(), + ) + .writeTo(env.codeGenerator, aggregating = false, listOf(clazz.containingFile!!)) + } + + return emptyList() + } + } + + @AutoService(CodeGenerator::class) + internal class Embedded : PrivateCodeGenerator() { + + override fun isApplicable(context: AnvilContext) = MembersInjectorCodeGen.isApplicable(context) + + override fun generateCodePrivate( + codeGenDir: File, + module: ModuleDescriptor, + projectFiles: Collection, + ) { + projectFiles + .classAndInnerClassReferences(module) + .filterNot { it.isInterface() } + .forEach { clazz -> + // Only generate a MembersInjector if the target class declares its own member-injected + // properties. If it does, then any properties from superclasses must be added as well + // (clazz.memberInjectParameters() will do this). + clazz.properties + .filter { it.visibility() != Visibility.PRIVATE } + .filter { it.isAnnotatedWith(injectFqName) } + .ifEmpty { return@forEach } + + generateMembersInjectorClass( + codeGenDir = codeGenDir, + clazz = clazz, + parameters = clazz.memberInjectParameters(), + ) + } + } + + private fun generateMembersInjectorClass( + codeGenDir: File, + clazz: ClassReference.Psi, + parameters: List, + ): GeneratedFile { + val isGeneric = clazz.isGenericClass() + val typeParameters = clazz.typeParameters + .map { it.typeVariableName } + + val spec = generateMembersInjectorClass( + origin = clazz.asClassName(), + isGeneric = isGeneric, + typeParameters = typeParameters, + parameters = parameters, + ) + + return createGeneratedFile(codeGenDir, spec.packageName, spec.name, spec.toString()) + } } private fun generateMembersInjectorClass( - codeGenDir: File, - clazz: ClassReference.Psi, + origin: ClassName, + isGeneric: Boolean, + typeParameters: List, parameters: List, - ): GeneratedFile { - val classId = clazz.generateClassName(suffix = "_MembersInjector") - val packageName = classId.packageFqName.safePackageString() - val className = classId.relativeClassName.asString() - val typeParameters = clazz.typeParameters + ): FileSpec { + val memberInjectorClass = origin.generateClassName(suffix = "_MembersInjector") + val packageName = memberInjectorClass.packageName.safePackageString() + val fileName = memberInjectorClass.simpleName - val classType = clazz.asClassName() + val classType = origin .let { - if (clazz.isGenericClass()) { - it.parameterizedBy(typeParameters.map { typeParameter -> typeParameter.typeVariableName }) + if (isGeneric) { + it.parameterizedBy(typeParameters) } else { it } @@ -92,16 +159,14 @@ internal class MembersInjectorGenerator : PrivateCodeGenerator() { .joinToString() } - val memberInjectorClass = classId.asClassName() - - val content = FileSpec.buildFile(packageName, className) { + val spec = FileSpec.createAnvilSpec(packageName, fileName) { addType( TypeSpec .classBuilder(memberInjectorClass) .addSuperinterface(membersInjectorType) .apply { typeParameters.forEach { typeParameter -> - addTypeVariable(typeParameter.typeVariableName) + addTypeVariable(typeParameter) } primaryConstructor( FunSpec.constructorBuilder() @@ -137,7 +202,7 @@ internal class MembersInjectorGenerator : PrivateCodeGenerator() { .jvmStatic() .apply { typeParameters.forEach { typeParameter -> - addTypeVariable(typeParameter.typeVariableName) + addTypeVariable(typeParameter) } parameters.forEach { parameter -> @@ -165,7 +230,7 @@ internal class MembersInjectorGenerator : PrivateCodeGenerator() { .jvmStatic() .apply { typeParameters.forEach { typeParameter -> - addTypeVariable(typeParameter.typeVariableName) + addTypeVariable(typeParameter) } // Don't add @InjectedFieldSignature when it's calling a setter method @@ -191,6 +256,6 @@ internal class MembersInjectorGenerator : PrivateCodeGenerator() { ) } - return createGeneratedFile(codeGenDir, packageName, className, content) + return spec } } diff --git a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/ksp/KspUtil.kt b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/ksp/KspUtil.kt index 589a4ab08..e73a9d0d0 100644 --- a/compiler/src/main/java/com/squareup/anvil/compiler/codegen/ksp/KspUtil.kt +++ b/compiler/src/main/java/com/squareup/anvil/compiler/codegen/ksp/KspUtil.kt @@ -1,9 +1,22 @@ package com.squareup.anvil.compiler.codegen.ksp +import com.google.devtools.ksp.KspExperimental +import com.google.devtools.ksp.isAnnotationPresent +import com.google.devtools.ksp.isConstructor +import com.google.devtools.ksp.processing.Resolver import com.google.devtools.ksp.symbol.ClassKind.ANNOTATION_CLASS import com.google.devtools.ksp.symbol.KSAnnotated import com.google.devtools.ksp.symbol.KSAnnotation import com.google.devtools.ksp.symbol.KSClassDeclaration +import com.google.devtools.ksp.symbol.KSFunctionDeclaration +import com.google.devtools.ksp.symbol.KSModifierListOwner +import com.google.devtools.ksp.symbol.KSType +import com.google.devtools.ksp.symbol.KSTypeAlias +import com.google.devtools.ksp.symbol.Modifier +import com.squareup.anvil.compiler.assistedInjectFqName +import com.squareup.anvil.compiler.injectFqName +import com.squareup.kotlinpoet.TypeName +import com.squareup.kotlinpoet.jvm.jvmSuppressWildcards import kotlin.reflect.KClass /** @@ -36,4 +49,85 @@ internal fun KSAnnotated.getKSAnnotationsByQualifiedName( internal fun KSAnnotated.isAnnotationPresent(qualifiedName: String): Boolean = getKSAnnotationsByQualifiedName(qualifiedName).firstOrNull() != null +internal inline fun KSAnnotated.isAnnotationPresent(): Boolean { + return isAnnotationPresent(T::class) +} + +internal fun KSAnnotated.isAnnotationPresent(klass: KClass<*>): Boolean { + val fqcn = klass.qualifiedName ?: return false + return getKSAnnotationsByQualifiedName(fqcn).firstOrNull() != null +} + internal fun KSClassDeclaration.isAnnotationClass(): Boolean = classKind == ANNOTATION_CLASS +internal fun KSModifierListOwner.isLateInit(): Boolean = Modifier.LATEINIT in modifiers + +@OptIn(KspExperimental::class) +internal fun TypeName.withJvmSuppressWildcardsIfNeeded( + annotatedReference: KSAnnotated, + type: KSType, +): TypeName { + // If the parameter is annotated with @JvmSuppressWildcards, then add the annotation + // to our type so that this information is forwarded when our Factory is compiled. + val hasJvmSuppressWildcards = annotatedReference.isAnnotationPresent(JvmSuppressWildcards::class) + + // Add the @JvmSuppressWildcards annotation even for simple generic return types like + // Set. This avoids some edge cases where Dagger chokes. + val isGenericType = (type.declaration as? KSClassDeclaration)?.typeParameters?.isNotEmpty() == true + + // Same for functions. + val isFunctionType = type.isFunctionType + + return when { + hasJvmSuppressWildcards || isGenericType -> this.jvmSuppressWildcards() + isFunctionType -> this.jvmSuppressWildcards() + else -> this + } +} + +/** + * Resolves the [KSClassDeclaration] for this type, including following typealiases as needed. + */ +internal tailrec fun KSType.resolveKSClassDeclaration(): KSClassDeclaration? { + return when (val declaration = declaration) { + is KSClassDeclaration -> declaration + is KSTypeAlias -> declaration.type.resolve().resolveKSClassDeclaration() + else -> error("Unrecognized declaration type: $declaration") + } +} + +/** + * Returns a sequence of all `@Inject` and `@AssistedInject` constructors visible to this resolver + */ +internal fun Resolver.injectConstructors(): List> { + return getSymbolsWithAnnotation(injectFqName.asString()) + .plus(getSymbolsWithAnnotation(assistedInjectFqName.asString())) + .filterIsInstance() + .filter { it.isConstructor() } + .groupBy { + it.parentDeclaration as KSClassDeclaration + } + .mapNotNull { (clazz, constructors) -> + if (constructors.size != 1) { + val constructorsErrorMessage = constructors.joinToString { constructor -> + constructor.annotations.joinToString(" ", postfix = " ") { annotation -> + "@${annotation.annotationType.resolve().declaration.qualifiedName!!.asString()}" + } + .replace("@javax.inject.Inject", "@Inject") + + clazz.qualifiedName!!.asString() + constructor.parameters.joinToString( + ", ", + prefix = "(", + postfix = ")", + ) { param -> + param.type.resolve().resolveKSClassDeclaration()!!.simpleName.getShortName() + } + } + throw KspAnvilException( + node = clazz, + message = "Type ${clazz.qualifiedName!!.asString()} may only contain one injected " + + "constructor. Found: [$constructorsErrorMessage]", + ) + } + + clazz to constructors[0] + } +} diff --git a/compiler/src/test/java/com/squareup/anvil/compiler/TestUtils.kt b/compiler/src/test/java/com/squareup/anvil/compiler/TestUtils.kt index 1e341cd66..891ce8524 100644 --- a/compiler/src/test/java/com/squareup/anvil/compiler/TestUtils.kt +++ b/compiler/src/test/java/com/squareup/anvil/compiler/TestUtils.kt @@ -1,15 +1,19 @@ package com.squareup.anvil.compiler +import com.google.common.collect.Lists.cartesianProduct import com.google.common.truth.ComparableSubject import com.google.common.truth.Truth.assertThat import com.squareup.anvil.annotations.MergeComponent import com.squareup.anvil.compiler.api.CodeGenerator import com.squareup.anvil.compiler.internal.capitalize import com.squareup.anvil.compiler.internal.testing.AnvilCompilationMode +import com.squareup.anvil.compiler.internal.testing.AnvilCompilationMode.Embedded +import com.squareup.anvil.compiler.internal.testing.AnvilCompilationMode.Ksp import com.squareup.anvil.compiler.internal.testing.compileAnvil import com.squareup.anvil.compiler.internal.testing.generatedClassesString import com.squareup.anvil.compiler.internal.testing.packageName import com.squareup.anvil.compiler.internal.testing.use +import com.tschuchort.compiletesting.CompilationResult import com.tschuchort.compiletesting.JvmCompilationResult import com.tschuchort.compiletesting.KotlinCompilation.ExitCode import com.tschuchort.compiletesting.KotlinCompilation.ExitCode.COMPILATION_ERROR @@ -188,3 +192,38 @@ internal fun JvmCompilationResult.walkGeneratedFiles(mode: AnvilCompilationMode) return dirToSearch.walkTopDown() .filter { it.isFile && it.extension == "kt" } } + +/** + * Parameters for configuring [AnvilCompilationMode] and whether to run a full test run or not. + */ +internal fun useDaggerAndKspParams( + embeddedCreator: () -> Embedded? = { Embedded() }, + kspCreator: () -> Ksp? = { Ksp() }, +): Collection { + return cartesianProduct( + listOf( + isFullTestRun(), + false, + ), + listOfNotNull( + embeddedCreator(), + kspCreator(), + ), + ).mapNotNull { (useDagger, mode) -> + if (useDagger == true && mode is Ksp) { + // TODO Dagger is not supported with KSP in Anvil's tests yet + null + } else { + arrayOf(useDagger, mode) + } + }.distinct() +} + +/** In any failing compilation in KSP, it always prints this error line first. */ +private const val KSP_ERROR_HEADER = "e: Error occurred in KSP, check log for detail" + +internal fun CompilationResult.compilationErrorLine(): String { + return messages + .lineSequence() + .first { it.startsWith("e:") && KSP_ERROR_HEADER !in it } +} diff --git a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/AssistedFactoryGeneratorTest.kt b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/AssistedFactoryGeneratorTest.kt index 7b65d7649..2e51eb719 100644 --- a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/AssistedFactoryGeneratorTest.kt +++ b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/AssistedFactoryGeneratorTest.kt @@ -4,8 +4,11 @@ import com.google.common.truth.Truth.assertThat import com.squareup.anvil.compiler.WARNINGS_AS_ERRORS import com.squareup.anvil.compiler.assistedService import com.squareup.anvil.compiler.assistedServiceFactory +import com.squareup.anvil.compiler.compilationErrorLine import com.squareup.anvil.compiler.daggerModule1 import com.squareup.anvil.compiler.internal.testing.AnvilCompilation +import com.squareup.anvil.compiler.internal.testing.AnvilCompilationMode +import com.squareup.anvil.compiler.internal.testing.AnvilCompilationMode.Embedded import com.squareup.anvil.compiler.internal.testing.createInstance import com.squareup.anvil.compiler.internal.testing.factoryClass import com.squareup.anvil.compiler.internal.testing.getPropertyValue @@ -14,10 +17,11 @@ import com.squareup.anvil.compiler.internal.testing.isStatic import com.squareup.anvil.compiler.internal.testing.moduleFactoryClass import com.squareup.anvil.compiler.internal.testing.use import com.squareup.anvil.compiler.isError -import com.squareup.anvil.compiler.isFullTestRun +import com.squareup.anvil.compiler.useDaggerAndKspParams import com.tschuchort.compiletesting.JvmCompilationResult import com.tschuchort.compiletesting.KotlinCompilation.ExitCode.OK import org.intellij.lang.annotations.Language +import org.junit.Assume.assumeTrue import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.Parameterized @@ -27,14 +31,13 @@ import javax.inject.Provider @RunWith(Parameterized::class) class AssistedFactoryGeneratorTest( private val useDagger: Boolean, + private val mode: AnvilCompilationMode, ) { companion object { - @Parameters(name = "Use Dagger: {0}") + @Parameters(name = "Use Dagger: {0}, mode: {1}") @JvmStatic - fun useDagger(): Collection { - return listOf(isFullTestRun(), false).distinct() - } + fun params() = useDaggerAndKspParams() } @Test fun `an implementation for a factory class is generated`() { @@ -1639,8 +1642,7 @@ public final class AssistedServiceFactory_Impl implements AssistedServiceFactory ) { assertThat(exitCode).isError() assertThat( - messages.lines() - .first { it.startsWith("e:") } + compilationErrorLine() .removeParametersAndSort() .removeNullabilityAnnotations(), ).contains( @@ -1677,8 +1679,7 @@ public final class AssistedServiceFactory_Impl implements AssistedServiceFactory ) { assertThat(exitCode).isError() assertThat( - messages.lines() - .first { it.startsWith("e:") } + compilationErrorLine() .removeParametersAndSort() .removeNullabilityAnnotations(), ).contains( @@ -1712,8 +1713,7 @@ public final class AssistedServiceFactory_Impl implements AssistedServiceFactory ) { assertThat(exitCode).isError() assertThat( - messages.lines() - .first { it.startsWith("e:") } + compilationErrorLine() .removeParametersAndSort() .removeNullabilityAnnotations(), ).contains( @@ -1813,6 +1813,8 @@ public final class AssistedServiceFactory_Impl implements AssistedServiceFactory } @Test fun `assisted injections can be provided with a qualifier`() { + // TODO enable on KSP after BindingModuleGenerator supports KSP + assumeTrue(mode is Embedded) compile( """ package com.squareup.test @@ -2064,6 +2066,7 @@ public final class AssistedServiceFactory_Impl implements AssistedServiceFactory .configureAnvil( enableDaggerAnnotationProcessor = useDagger, generateDaggerFactories = !useDagger, + mode = mode, ) } diff --git a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/AssistedInjectGeneratorTest.kt b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/AssistedInjectGeneratorTest.kt index 1639ccb86..00b15b07b 100644 --- a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/AssistedInjectGeneratorTest.kt +++ b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/AssistedInjectGeneratorTest.kt @@ -3,12 +3,14 @@ package com.squareup.anvil.compiler.dagger import com.google.common.truth.Truth.assertThat import com.squareup.anvil.compiler.WARNINGS_AS_ERRORS import com.squareup.anvil.compiler.assistedService +import com.squareup.anvil.compiler.compilationErrorLine +import com.squareup.anvil.compiler.internal.testing.AnvilCompilationMode import com.squareup.anvil.compiler.internal.testing.compileAnvil import com.squareup.anvil.compiler.internal.testing.factoryClass import com.squareup.anvil.compiler.internal.testing.invokeGet import com.squareup.anvil.compiler.internal.testing.isStatic import com.squareup.anvil.compiler.isError -import com.squareup.anvil.compiler.isFullTestRun +import com.squareup.anvil.compiler.useDaggerAndKspParams import com.tschuchort.compiletesting.JvmCompilationResult import org.intellij.lang.annotations.Language import org.junit.Test @@ -20,14 +22,13 @@ import javax.inject.Provider @RunWith(Parameterized::class) class AssistedInjectGeneratorTest( private val useDagger: Boolean, + private val mode: AnvilCompilationMode, ) { companion object { - @Parameters(name = "Use Dagger: {0}") + @Parameters(name = "Use Dagger: {0}, mode: {1}") @JvmStatic - fun useDagger(): Collection { - return listOf(isFullTestRun(), false).distinct() - } + fun params() = useDaggerAndKspParams() } @Test fun `a factory class is generated with one assisted parameter`() { @@ -632,8 +633,7 @@ public final class AssistedService_Factory { ) { assertThat(exitCode).isError() assertThat( - messages.lines() - .first { it.startsWith("e:") } + compilationErrorLine() .removeParametersAndSort(), ).contains( "Type com.squareup.test.AssistedService may only contain one injected constructor. " + @@ -662,8 +662,7 @@ public final class AssistedService_Factory { ) { assertThat(exitCode).isError() assertThat( - messages.lines() - .first { it.startsWith("e:") } + compilationErrorLine() .removeParametersAndSort(), ).contains( "Type com.squareup.test.AssistedService may only contain one injected constructor. " + @@ -681,6 +680,7 @@ public final class AssistedService_Factory { enableDaggerAnnotationProcessor = useDagger, generateDaggerFactories = !useDagger, allWarningsAsErrors = WARNINGS_AS_ERRORS, + mode = mode, block = block, ) } diff --git a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/DaggerTestUtils.kt b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/DaggerTestUtils.kt index 37c0bdf63..288733b0e 100644 --- a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/DaggerTestUtils.kt +++ b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/DaggerTestUtils.kt @@ -1,5 +1,7 @@ package com.squareup.anvil.compiler.dagger +private const val KSP_PREFIX = "e: [ksp]" + /** * Removes parameters of the functions in a String like * ``` @@ -12,6 +14,9 @@ package com.squareup.anvil.compiler.dagger * Dagger also doesn't guarantee any order of functions. */ internal fun String.removeParametersAndSort(): String { + if (startsWith(KSP_PREFIX)) { + return removePrefix(KSP_PREFIX).removeParametersAndSort() + } val start = 1 + (indexOf('[').takeIf { it >= 0 } ?: return this) val end = indexOfLast { it == ']' }.takeIf { it >= 0 } ?: return this diff --git a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/InjectConstructorFactoryGeneratorTest.kt b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/InjectConstructorFactoryGeneratorTest.kt index 4dff9c490..cd5c8fc04 100644 --- a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/InjectConstructorFactoryGeneratorTest.kt +++ b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/InjectConstructorFactoryGeneratorTest.kt @@ -1,14 +1,16 @@ package com.squareup.anvil.compiler.dagger import com.google.common.truth.Truth.assertThat +import com.squareup.anvil.compiler.compilationErrorLine import com.squareup.anvil.compiler.injectClass +import com.squareup.anvil.compiler.internal.testing.AnvilCompilationMode import com.squareup.anvil.compiler.internal.testing.compileAnvil import com.squareup.anvil.compiler.internal.testing.createInstance import com.squareup.anvil.compiler.internal.testing.factoryClass import com.squareup.anvil.compiler.internal.testing.getPropertyValue import com.squareup.anvil.compiler.internal.testing.isStatic import com.squareup.anvil.compiler.isError -import com.squareup.anvil.compiler.isFullTestRun +import com.squareup.anvil.compiler.useDaggerAndKspParams import com.tschuchort.compiletesting.JvmCompilationResult import com.tschuchort.compiletesting.KotlinCompilation.ExitCode.OK import dagger.Lazy @@ -25,14 +27,13 @@ import javax.inject.Provider @RunWith(Parameterized::class) class InjectConstructorFactoryGeneratorTest( private val useDagger: Boolean, + private val mode: AnvilCompilationMode, ) { companion object { - @Parameters(name = "Use Dagger: {0}") + @Parameters(name = "Use Dagger: {0}, mode: {1}") @JvmStatic - fun useDagger(): Collection { - return listOf(isFullTestRun(), false).distinct() - } + fun params() = useDaggerAndKspParams() } @Test fun `a factory class is generated for an inject constructor without arguments`() { @@ -2552,8 +2553,7 @@ public class InjectClass_Factory>( ) { assertThat(exitCode).isError() assertThat( - messages.lines() - .first { it.startsWith("e:") } + compilationErrorLine() .removeParametersAndSort(), ).contains( "Type com.squareup.test.InjectClass may only contain one injected constructor. " + @@ -2745,6 +2745,7 @@ public final class InjectClass_Factory implements Factory { // Many constructor parameters are unused. allWarningsAsErrors = false, previousCompilationResult = previousCompilationResult, + mode = mode, block = block, ) } diff --git a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/MapKeyCreatorGeneratorTest.kt b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/MapKeyCreatorGeneratorTest.kt index d3912faf4..997905e2c 100644 --- a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/MapKeyCreatorGeneratorTest.kt +++ b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/MapKeyCreatorGeneratorTest.kt @@ -1,14 +1,11 @@ package com.squareup.anvil.compiler.dagger -import com.google.common.collect.Lists.cartesianProduct import com.google.common.truth.Truth.assertThat import com.squareup.anvil.compiler.WARNINGS_AS_ERRORS import com.squareup.anvil.compiler.internal.testing.AnvilCompilationMode -import com.squareup.anvil.compiler.internal.testing.AnvilCompilationMode.Embedded -import com.squareup.anvil.compiler.internal.testing.AnvilCompilationMode.Ksp import com.squareup.anvil.compiler.internal.testing.compileAnvil import com.squareup.anvil.compiler.internal.testing.isStatic -import com.squareup.anvil.compiler.isFullTestRun +import com.squareup.anvil.compiler.useDaggerAndKspParams import com.tschuchort.compiletesting.JvmCompilationResult import org.intellij.lang.annotations.Language import org.jetbrains.kotlin.descriptors.runtime.components.tryLoadClass @@ -27,19 +24,7 @@ class MapKeyCreatorGeneratorTest( companion object { @Parameters(name = "Use Dagger: {0}, mode: {1}") @JvmStatic - fun useDagger(): Collection { - return cartesianProduct( - listOf(isFullTestRun(), false), - listOf(Embedded(), Ksp()), - ).mapNotNull { (useDagger, mode) -> - if (useDagger == true && mode is Ksp) { - // TODO Dagger is not supported with KSP in Anvil's tests yet - null - } else { - arrayOf(useDagger, mode) - } - }.distinct() - } + fun params() = useDaggerAndKspParams() } @Test fun `a creator class is generated`() { diff --git a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/MembersInjectorGeneratorTest.kt b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/MembersInjectorGeneratorTest.kt index 27d8a2c75..7e7784942 100644 --- a/compiler/src/test/java/com/squareup/anvil/compiler/dagger/MembersInjectorGeneratorTest.kt +++ b/compiler/src/test/java/com/squareup/anvil/compiler/dagger/MembersInjectorGeneratorTest.kt @@ -4,6 +4,7 @@ import com.google.common.truth.Truth.assertThat import com.squareup.anvil.compiler.WARNINGS_AS_ERRORS import com.squareup.anvil.compiler.injectClass import com.squareup.anvil.compiler.internal.capitalize +import com.squareup.anvil.compiler.internal.testing.AnvilCompilationMode import com.squareup.anvil.compiler.internal.testing.compileAnvil import com.squareup.anvil.compiler.internal.testing.createInstance import com.squareup.anvil.compiler.internal.testing.getPropertyValue @@ -11,8 +12,8 @@ import com.squareup.anvil.compiler.internal.testing.getValue import com.squareup.anvil.compiler.internal.testing.isStatic import com.squareup.anvil.compiler.internal.testing.membersInjector import com.squareup.anvil.compiler.isError -import com.squareup.anvil.compiler.isFullTestRun import com.squareup.anvil.compiler.nestedInjectClass +import com.squareup.anvil.compiler.useDaggerAndKspParams import com.tschuchort.compiletesting.JvmCompilationResult import com.tschuchort.compiletesting.KotlinCompilation.ExitCode.OK import dagger.Lazy @@ -33,14 +34,13 @@ import kotlin.test.assertFailsWith @RunWith(Parameterized::class) class MembersInjectorGeneratorTest( private val useDagger: Boolean, + private val mode: AnvilCompilationMode, ) { companion object { - @Parameters(name = "Use Dagger: {0}") + @Parameters(name = "Use Dagger: {0}, mode: {1}") @JvmStatic - fun useDagger(): Collection { - return listOf(isFullTestRun(), false).distinct() - } + fun params() = useDaggerAndKspParams() } @Test fun `a factory class is generated for a field injection`() { @@ -2513,6 +2513,7 @@ public final class InjectClass_MembersInjector implements MembersInject generateDaggerFactories = !useDagger, allWarningsAsErrors = WARNINGS_AS_ERRORS, previousCompilationResult = previousCompilationResult, + mode = mode, block = block, ) }