Skip to content

Commit

Permalink
fix tests, move sqseventops test
Browse files Browse the repository at this point in the history
  • Loading branch information
kenoir committed Dec 20, 2024
1 parent 8b5f393 commit c9699ee
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 48 deletions.
9 changes: 6 additions & 3 deletions common/lambda/src/main/scala/weco/lambda/Downstream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,22 @@ trait Downstream {
}

class SNSDownstream(snsConfig: SNSConfig) extends Downstream {
private val msgSender = new SNSMessageSender(
protected val msgSender = new SNSMessageSender(
snsClient = SnsClient.builder().build(),
snsConfig = snsConfig,
subject = "Sent from relation_embedder"
)

override def notify(workId: String): Try[Unit] = Try(msgSender.send(workId))
override def notify[T](batch: T)(implicit encoder: Encoder[T]): Try[Unit] = msgSender.sendT(batch)
override def notify[T](batch: T)(implicit encoder: Encoder[T]): Try[Unit] =
msgSender.sendT(batch)
}

object STDIODownstream extends Downstream {
override def notify(workId: String): Try[Unit] = Try(println(workId))
override def notify[T](t: T)(implicit encoder: Encoder[T]): Try[Unit] = Try(println(toJson(t)))
override def notify[T](t: T)(implicit encoder: Encoder[T]): Try[Unit] = Try(
println(toJson(t))
)
}

sealed trait DownstreamTarget
Expand Down
13 changes: 10 additions & 3 deletions common/lambda/src/main/scala/weco/lambda/SQSEventOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import ujson.Value
import weco.json.JsonUtil.fromJson

import scala.collection.JavaConverters._
import scala.reflect.ClassTag

object SQSEventOps {

Expand All @@ -21,14 +22,20 @@ object SQSEventOps {
* - a `Message`, which is the actual content we want
*/
implicit class ExtractTFromSqsEvent(event: SQSEvent) {
def extract[T]()(implicit decoder: Decoder[T]) =
def extract[T]()(implicit decoder: Decoder[T], ct: ClassTag[T]) =
event.getRecords.asScala.toList.flatMap(extractFromMessage[T](_))

private def extractFromMessage[T](
message: SQSMessage
)(implicit decoder: Decoder[T]): Option[T] =
)(implicit decoder: Decoder[T], ct: ClassTag[T]): Option[T] =
ujson.read(message.getBody).obj.get("Message").flatMap {
value: Value => fromJson[T](value.str).toOption
value: Value =>
{
ct.runtimeClass match {
case c if c == classOf[String] => Some(value.str.asInstanceOf[T])
case _ => fromJson[T](value.str).toOption
}
}
}
}
}
68 changes: 68 additions & 0 deletions common/lambda/src/test/scala/weco/lambda/SQSEventOpsTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package weco.lambda

import com.amazonaws.services.lambda.runtime.events.SQSEvent
import com.amazonaws.services.lambda.runtime.events.SQSEvent.SQSMessage
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers

import scala.collection.JavaConverters._
import weco.json.JsonUtil._

class SQSEventOpsTest extends AnyFunSpec with Matchers {

import SQSEventOps._

describe("Using the implicit class SQSEventOps") {
it("extracts values from an SQSEvent where the message is a String") {
val fakeMessage = new SQSMessage()
fakeMessage.setBody("{\"Message\":\"A/C\"}")
val fakeSQSEvent = new SQSEvent()
fakeSQSEvent.setRecords(List(fakeMessage).asJava)

val paths = fakeSQSEvent.extract[String]()

paths shouldBe List("A/C")
}

case class TestMessage(value: String)

it("extracts values from an SQSEvent where the message is a JSON object") {
val fakeMessage = new SQSMessage()
fakeMessage.setBody("{\"Message\":\"{\\\"value\\\": \\\"A/C\\\"}\"}")
val fakeSQSEvent = new SQSEvent()
fakeSQSEvent.setRecords(List(fakeMessage).asJava)

val paths = fakeSQSEvent.extract[TestMessage]()

paths shouldBe List(TestMessage("A/C"))
}

it("extracts multiple values from an SQSEvent") {
val fakeMessage1 = new SQSMessage()
fakeMessage1.setBody("{\"Message\":\"A/C\"}")
val fakeMessage2 = new SQSMessage()
fakeMessage2.setBody("{\"Message\":\"A/E\"}")
val fakeSQSEvent = new SQSEvent()
fakeSQSEvent.setRecords(List(fakeMessage1, fakeMessage2).asJava)

val paths = fakeSQSEvent.extract[String]()

paths shouldBe List("A/C", "A/E")
}

it(
"extracts values from an SQSEvent where the message is a JSON object with multiple fields, only taking the ones we want"
) {
val fakeMessage = new SQSMessage()
fakeMessage.setBody(
"{\"Message\":\"{\\\"value\\\": \\\"A/C\\\", \\\"other\\\": \\\"D/E\\\"}\"}"
)
val fakeSQSEvent = new SQSEvent()
fakeSQSEvent.setRecords(List(fakeMessage).asJava)

val paths = fakeSQSEvent.extract[TestMessage]()

paths shouldBe List(TestMessage("A/C"))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ package weco.pipeline.batcher

import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._
import scala.util.Try
import org.apache.pekko.{Done, NotUsed}
import org.apache.pekko.stream.scaladsl._
import org.apache.pekko.stream.Materializer
import software.amazon.awssdk.services.sqs.model.{Message => SQSMessage}
import grizzled.slf4j.Logging
import weco.messaging.MessageSender
import weco.messaging.sns.NotificationMessage
import weco.messaging.sqs.SQSStream
import weco.typesafe.Runnable
Expand All @@ -17,11 +14,10 @@ case class Batch(rootPath: String, selectors: List[Selector])

class BatcherWorkerService[MsgDestination](
msgStream: SQSStream[NotificationMessage],
msgSender: MessageSender[MsgDestination],
pathsProcessor: PathsProcessor,
flushInterval: FiniteDuration,
maxProcessedPaths: Int
)(implicit ec: ExecutionContext, materializer: Materializer)
)(implicit ec: ExecutionContext)
extends Runnable
with Logging {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ object Main extends WellcomeTypesafeApp {

new BatcherWorkerService(
msgStream = SQSBuilder.buildSQSStream[NotificationMessage](config),
msgSender = SNSBuilder
.buildSNSMessageSender(config, subject = "Sent from batcher"),
flushInterval =
config.requireInt("batcher.flush_interval_minutes").minutes,
pathsProcessor = new PathsProcessor(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
package weco.pipeline.batcher

import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.{Failure, Try}
import io.circe.Encoder
import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.concurrent.{Eventually, IntegrationPatience}
import org.scalatest.funspec.AnyFunSpec
import org.scalatest.matchers.should.Matchers
import org.scalatest.concurrent.{Eventually, IntegrationPatience}
import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.time.{Seconds, Span}
import io.circe.Encoder
import weco.fixtures.TestWith
import weco.pekko.fixtures.Pekko
import weco.json.JsonUtil._
import weco.lambda.Downstream
import weco.messaging.fixtures.SQS
import weco.messaging.fixtures.SQS.QueuePair
import weco.messaging.memory.MemoryMessageSender
import weco.messaging.sns.NotificationMessage
import weco.json.JsonUtil._
import SQS.QueuePair
import weco.pekko.fixtures.Pekko

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.util.{Failure, Try}

class BatcherWorkerServiceTest
extends AnyFunSpec
Expand Down Expand Up @@ -147,19 +149,24 @@ class BatcherWorkerServiceTest
withSQSStream[NotificationMessage, R](queuePair.queue) {
msgStream =>
val msgSender = new MessageSender(brokenPaths)
val memoryDownstream = new MemoryDownstream(msgSender)
val pathsProcessor = new PathsProcessor(
downstream = memoryDownstream,
maxBatchSize = maxBatchSize
)
val workerService = new BatcherWorkerService[String](
msgStream = msgStream,
msgSender = msgSender,
flushInterval = flushInterval,
maxProcessedPaths = 1000,
maxBatchSize = maxBatchSize
pathsProcessor = pathsProcessor
)
workerService.run()
testWith((queuePair, msgSender))
}
}
}


class MessageSender(brokenPaths: Set[String] = Set.empty)
extends MemoryMessageSender {
override def sendT[T](t: T)(implicit encoder: Encoder[T]): Try[Unit] = {
Expand All @@ -170,4 +177,9 @@ class BatcherWorkerServiceTest
super.sendT(t)
}
}

class MemoryDownstream(messageSender: MessageSender) extends Downstream {
override def notify(workId: Path): Try[Unit] = ???
override def notify[T](batch: T)(implicit encoder: Encoder[T]): Try[Unit] = messageSender.sendT(batch)
}
}

This file was deleted.

0 comments on commit c9699ee

Please sign in to comment.