Skip to content

Commit

Permalink
Improve error when service is missing protocol
Browse files Browse the repository at this point in the history
resolves smithy-lang#2452
----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
  • Loading branch information
djedward committed Apr 3, 2024
1 parent 09ba40e commit 66445a9
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,27 @@ import software.amazon.smithy.model.traits.Trait
import software.amazon.smithy.rust.codegen.core.smithy.CodegenContext

open class ProtocolLoader<T, C : CodegenContext>(private val supportedProtocols: ProtocolMap<T, C>) {
private fun formatProtocols(): String {
return supportedProtocols.keys.joinToString(
prefix = "\t",
separator = "\n\t"
)
}

fun protocolFor(
model: Model,
serviceShape: ServiceShape,
): Pair<ShapeId, ProtocolGeneratorFactory<T, C>> {
val protocols: MutableMap<ShapeId, Trait> = ServiceIndex.of(model).getProtocols(serviceShape)
if (protocols.isEmpty()) {
throw CodegenException("Service must have a protocol trait. Available protocols:\n${formatProtocols()}")
}

val matchingProtocols =
protocols.keys.mapNotNull { protocolId -> supportedProtocols[protocolId]?.let { protocolId to it } }
if (matchingProtocols.isEmpty()) {
throw CodegenException("No matching protocol — service offers: ${protocols.keys}. We offer: ${supportedProtocols.keys}")
val specified = protocols.keys.joinToString(", ")
throw CodegenException("Unable to find a matching protocol. Model specifies ${specified}, but must match an available protocol:\n${formatProtocols()}")
}
return matchingProtocols.first()
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package software.amazon.smithy.rust.codegen.server.smithy.protocols

import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldContain
import org.junit.jupiter.api.Test
import software.amazon.smithy.aws.traits.protocols.AwsJson1_1Trait
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.rust.codegen.core.smithy.protocols.AwsJsonVersion
import software.amazon.smithy.rust.codegen.core.testutil.asSmithyModel


class ServerProtocolLoaderTest {
private val testModel =
"""
${"$"}version: "2"
namespace test
use aws.api#service
use aws.protocols#awsJson1_0
@awsJson1_0
@service(
sdkId: "Test",
arnNamespace: "test"
)
service TestService {
version: "2024-04-01"
}
""".asSmithyModel(smithyVersion = "2.0")

private val testModelNoProtocol =
"""
${"$"}version: "2"
namespace test
use aws.api#service
@service(
sdkId: "Test",
arnNamespace: "test"
)
service TestService {
version: "2024-04-01"
}
""".asSmithyModel(smithyVersion = "2.0")

@Test
fun `ensures protocols are matched`() {
val loader = ServerProtocolLoader(ServerProtocolLoader.DefaultProtocols)

val (shape, _) = loader.protocolFor(testModel, testModel.serviceShapes.first())

shape.name shouldBe "awsJson1_0"
}

@Test
fun `ensures unmatched service protocol fails`() {
val loader = ServerProtocolLoader(
mapOf(
RestJson1Trait.ID to
ServerRestJsonFactory(
additionalServerHttpBoundProtocolCustomizations =
listOf(
StreamPayloadSerializerCustomization(),
),
),
RestXmlTrait.ID to
ServerRestXmlFactory(
additionalServerHttpBoundProtocolCustomizations =
listOf(
StreamPayloadSerializerCustomization(),
),
),
AwsJson1_1Trait.ID to
ServerAwsJsonFactory(
AwsJsonVersion.Json11,
additionalServerHttpBoundProtocolCustomizations = listOf(StreamPayloadSerializerCustomization()),
),
)
)
val exception = shouldThrow<CodegenException> {
loader.protocolFor(testModel, testModel.serviceShapes.first())
}
exception.message shouldContain("Unable to find a matching protocol")
}

@Test
fun `ensures service without protocol fails`() {
val loader = ServerProtocolLoader(ServerProtocolLoader.DefaultProtocols)
val exception = shouldThrow<CodegenException> {
loader.protocolFor(testModelNoProtocol, testModelNoProtocol.serviceShapes.first())
}
exception.message shouldContain("Service must have a protocol trait")
}
}

0 comments on commit 66445a9

Please sign in to comment.