diff --git a/common/lambda/src/main/scala/weco/lambda/Downstream.scala b/common/lambda/src/main/scala/weco/lambda/Downstream.scala index 2c17918bea..fed30504d5 100644 --- a/common/lambda/src/main/scala/weco/lambda/Downstream.scala +++ b/common/lambda/src/main/scala/weco/lambda/Downstream.scala @@ -13,19 +13,22 @@ trait Downstream { } class SNSDownstream(snsConfig: SNSConfig) extends Downstream { - private val msgSender = new SNSMessageSender( + protected val msgSender = new SNSMessageSender( snsClient = SnsClient.builder().build(), snsConfig = snsConfig, subject = "Sent from relation_embedder" ) override def notify(workId: String): Try[Unit] = Try(msgSender.send(workId)) - override def notify[T](batch: T)(implicit encoder: Encoder[T]): Try[Unit] = msgSender.sendT(batch) + override def notify[T](batch: T)(implicit encoder: Encoder[T]): Try[Unit] = + msgSender.sendT(batch) } object STDIODownstream extends Downstream { override def notify(workId: String): Try[Unit] = Try(println(workId)) - override def notify[T](t: T)(implicit encoder: Encoder[T]): Try[Unit] = Try(println(toJson(t))) + override def notify[T](t: T)(implicit encoder: Encoder[T]): Try[Unit] = Try( + println(toJson(t)) + ) } sealed trait DownstreamTarget diff --git a/common/lambda/src/main/scala/weco/lambda/SQSEventOps.scala b/common/lambda/src/main/scala/weco/lambda/SQSEventOps.scala index 052a1203c4..aa2bbe2d70 100644 --- a/common/lambda/src/main/scala/weco/lambda/SQSEventOps.scala +++ b/common/lambda/src/main/scala/weco/lambda/SQSEventOps.scala @@ -7,6 +7,7 @@ import ujson.Value import weco.json.JsonUtil.fromJson import scala.collection.JavaConverters._ +import scala.reflect.ClassTag object SQSEventOps { @@ -21,14 +22,20 @@ object SQSEventOps { * - a `Message`, which is the actual content we want */ implicit class ExtractTFromSqsEvent(event: SQSEvent) { - def extract[T]()(implicit decoder: Decoder[T]) = + def extract[T]()(implicit decoder: Decoder[T], ct: ClassTag[T]) = event.getRecords.asScala.toList.flatMap(extractFromMessage[T](_)) private def extractFromMessage[T]( message: SQSMessage - )(implicit decoder: Decoder[T]): Option[T] = + )(implicit decoder: Decoder[T], ct: ClassTag[T]): Option[T] = ujson.read(message.getBody).obj.get("Message").flatMap { - value: Value => fromJson[T](value.str).toOption + value: Value => + { + ct.runtimeClass match { + case c if c == classOf[String] => Some(value.str.asInstanceOf[T]) + case _ => fromJson[T](value.str).toOption + } + } } } } diff --git a/common/lambda/src/test/scala/weco/lambda/SQSEventOpsTest.scala b/common/lambda/src/test/scala/weco/lambda/SQSEventOpsTest.scala new file mode 100644 index 0000000000..6805242e4f --- /dev/null +++ b/common/lambda/src/test/scala/weco/lambda/SQSEventOpsTest.scala @@ -0,0 +1,68 @@ +package weco.lambda + +import com.amazonaws.services.lambda.runtime.events.SQSEvent +import com.amazonaws.services.lambda.runtime.events.SQSEvent.SQSMessage +import org.scalatest.funspec.AnyFunSpec +import org.scalatest.matchers.should.Matchers + +import scala.collection.JavaConverters._ +import weco.json.JsonUtil._ + +class SQSEventOpsTest extends AnyFunSpec with Matchers { + + import SQSEventOps._ + + describe("Using the implicit class SQSEventOps") { + it("extracts values from an SQSEvent where the message is a String") { + val fakeMessage = new SQSMessage() + fakeMessage.setBody("{\"Message\":\"A/C\"}") + val fakeSQSEvent = new SQSEvent() + fakeSQSEvent.setRecords(List(fakeMessage).asJava) + + val paths = fakeSQSEvent.extract[String]() + + paths shouldBe List("A/C") + } + + case class TestMessage(value: String) + + it("extracts values from an SQSEvent where the message is a JSON object") { + val fakeMessage = new SQSMessage() + fakeMessage.setBody("{\"Message\":\"{\\\"value\\\": \\\"A/C\\\"}\"}") + val fakeSQSEvent = new SQSEvent() + fakeSQSEvent.setRecords(List(fakeMessage).asJava) + + val paths = fakeSQSEvent.extract[TestMessage]() + + paths shouldBe List(TestMessage("A/C")) + } + + it("extracts multiple values from an SQSEvent") { + val fakeMessage1 = new SQSMessage() + fakeMessage1.setBody("{\"Message\":\"A/C\"}") + val fakeMessage2 = new SQSMessage() + fakeMessage2.setBody("{\"Message\":\"A/E\"}") + val fakeSQSEvent = new SQSEvent() + fakeSQSEvent.setRecords(List(fakeMessage1, fakeMessage2).asJava) + + val paths = fakeSQSEvent.extract[String]() + + paths shouldBe List("A/C", "A/E") + } + + it( + "extracts values from an SQSEvent where the message is a JSON object with multiple fields, only taking the ones we want" + ) { + val fakeMessage = new SQSMessage() + fakeMessage.setBody( + "{\"Message\":\"{\\\"value\\\": \\\"A/C\\\", \\\"other\\\": \\\"D/E\\\"}\"}" + ) + val fakeSQSEvent = new SQSEvent() + fakeSQSEvent.setRecords(List(fakeMessage).asJava) + + val paths = fakeSQSEvent.extract[TestMessage]() + + paths shouldBe List(TestMessage("A/C")) + } + } +} diff --git a/pipeline/relation_embedder/batcher/src/main/scala/weco/pipeline/batcher/BatcherWorkerService.scala b/pipeline/relation_embedder/batcher/src/main/scala/weco/pipeline/batcher/BatcherWorkerService.scala index 64dbcca792..35b243aa82 100644 --- a/pipeline/relation_embedder/batcher/src/main/scala/weco/pipeline/batcher/BatcherWorkerService.scala +++ b/pipeline/relation_embedder/batcher/src/main/scala/weco/pipeline/batcher/BatcherWorkerService.scala @@ -2,13 +2,10 @@ package weco.pipeline.batcher import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ -import scala.util.Try import org.apache.pekko.{Done, NotUsed} import org.apache.pekko.stream.scaladsl._ -import org.apache.pekko.stream.Materializer import software.amazon.awssdk.services.sqs.model.{Message => SQSMessage} import grizzled.slf4j.Logging -import weco.messaging.MessageSender import weco.messaging.sns.NotificationMessage import weco.messaging.sqs.SQSStream import weco.typesafe.Runnable @@ -17,11 +14,10 @@ case class Batch(rootPath: String, selectors: List[Selector]) class BatcherWorkerService[MsgDestination]( msgStream: SQSStream[NotificationMessage], - msgSender: MessageSender[MsgDestination], pathsProcessor: PathsProcessor, flushInterval: FiniteDuration, maxProcessedPaths: Int -)(implicit ec: ExecutionContext, materializer: Materializer) +)(implicit ec: ExecutionContext) extends Runnable with Logging { diff --git a/pipeline/relation_embedder/batcher/src/main/scala/weco/pipeline/batcher/Main.scala b/pipeline/relation_embedder/batcher/src/main/scala/weco/pipeline/batcher/Main.scala index e50d47804b..6b3d90c475 100644 --- a/pipeline/relation_embedder/batcher/src/main/scala/weco/pipeline/batcher/Main.scala +++ b/pipeline/relation_embedder/batcher/src/main/scala/weco/pipeline/batcher/Main.scala @@ -20,8 +20,6 @@ object Main extends WellcomeTypesafeApp { new BatcherWorkerService( msgStream = SQSBuilder.buildSQSStream[NotificationMessage](config), - msgSender = SNSBuilder - .buildSNSMessageSender(config, subject = "Sent from batcher"), flushInterval = config.requireInt("batcher.flush_interval_minutes").minutes, pathsProcessor = new PathsProcessor( diff --git a/pipeline/relation_embedder/batcher/src/test/scala/weco/pipeline/batcher/BatcherWorkerServiceTest.scala b/pipeline/relation_embedder/batcher/src/test/scala/weco/pipeline/batcher/BatcherWorkerServiceTest.scala index 228f2ad98a..86a38218d2 100644 --- a/pipeline/relation_embedder/batcher/src/test/scala/weco/pipeline/batcher/BatcherWorkerServiceTest.scala +++ b/pipeline/relation_embedder/batcher/src/test/scala/weco/pipeline/batcher/BatcherWorkerServiceTest.scala @@ -1,21 +1,23 @@ package weco.pipeline.batcher -import scala.concurrent.duration._ -import scala.concurrent.ExecutionContext.Implicits.global -import scala.util.{Failure, Try} +import io.circe.Encoder +import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.concurrent.{Eventually, IntegrationPatience} import org.scalatest.funspec.AnyFunSpec import org.scalatest.matchers.should.Matchers -import org.scalatest.concurrent.{Eventually, IntegrationPatience} -import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.time.{Seconds, Span} -import io.circe.Encoder import weco.fixtures.TestWith -import weco.pekko.fixtures.Pekko +import weco.json.JsonUtil._ +import weco.lambda.Downstream import weco.messaging.fixtures.SQS +import weco.messaging.fixtures.SQS.QueuePair import weco.messaging.memory.MemoryMessageSender import weco.messaging.sns.NotificationMessage -import weco.json.JsonUtil._ -import SQS.QueuePair +import weco.pekko.fixtures.Pekko + +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration._ +import scala.util.{Failure, Try} class BatcherWorkerServiceTest extends AnyFunSpec @@ -147,12 +149,16 @@ class BatcherWorkerServiceTest withSQSStream[NotificationMessage, R](queuePair.queue) { msgStream => val msgSender = new MessageSender(brokenPaths) + val memoryDownstream = new MemoryDownstream(msgSender) + val pathsProcessor = new PathsProcessor( + downstream = memoryDownstream, + maxBatchSize = maxBatchSize + ) val workerService = new BatcherWorkerService[String]( msgStream = msgStream, - msgSender = msgSender, flushInterval = flushInterval, maxProcessedPaths = 1000, - maxBatchSize = maxBatchSize + pathsProcessor = pathsProcessor ) workerService.run() testWith((queuePair, msgSender)) @@ -160,6 +166,7 @@ class BatcherWorkerServiceTest } } + class MessageSender(brokenPaths: Set[String] = Set.empty) extends MemoryMessageSender { override def sendT[T](t: T)(implicit encoder: Encoder[T]): Try[Unit] = { @@ -170,4 +177,9 @@ class BatcherWorkerServiceTest super.sendT(t) } } + + class MemoryDownstream(messageSender: MessageSender) extends Downstream { + override def notify(workId: Path): Try[Unit] = ??? + override def notify[T](batch: T)(implicit encoder: Encoder[T]): Try[Unit] = messageSender.sendT(batch) + } } diff --git a/pipeline/relation_embedder/batcher/src/test/scala/weco/pipeline/batcher/SQSEventOpsTest.scala b/pipeline/relation_embedder/batcher/src/test/scala/weco/pipeline/batcher/SQSEventOpsTest.scala deleted file mode 100644 index 6fa114349c..0000000000 --- a/pipeline/relation_embedder/batcher/src/test/scala/weco/pipeline/batcher/SQSEventOpsTest.scala +++ /dev/null @@ -1,24 +0,0 @@ -package weco.pipeline.batcher - -import com.amazonaws.services.lambda.runtime.events.SQSEvent -import com.amazonaws.services.lambda.runtime.events.SQSEvent.SQSMessage -import org.scalatest.funspec.AnyFunSpec -import org.scalatest.matchers.should.Matchers -import scala.collection.JavaConverters._ - -class SQSEventOpsTest extends AnyFunSpec with Matchers { - import lib.SQSEventOps._ - - describe("Using the implicit class SQSEventOps") { - it("extracts paths from an SQSEvent") { - val fakeMessage = new SQSMessage() - fakeMessage.setBody("{\"Message\":\"A/C\"}") - val fakeSQSEvent = new SQSEvent() - fakeSQSEvent.setRecords(List(fakeMessage).asJava) - - val paths = fakeSQSEvent.extractPaths - - paths shouldBe List("A/C") - } - } -}