diff --git a/mirai-core/src/commonMain/kotlin/message/protocol/encode/MessageLengthVerifier.kt b/mirai-core/src/commonMain/kotlin/message/protocol/encode/MessageLengthVerifier.kt new file mode 100644 index 00000000000..ed441e6373d --- /dev/null +++ b/mirai-core/src/commonMain/kotlin/message/protocol/encode/MessageLengthVerifier.kt @@ -0,0 +1,203 @@ +/* + * Copyright 2019-2022 Mamoe Technologies and contributors. + * + * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证. + * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link. + * + * https://github.com/mamoe/mirai/blob/dev/LICENSE + */ + +package net.mamoe.mirai.internal.message.protocol.encode + +import net.mamoe.mirai.contact.Contact +import net.mamoe.mirai.contact.Group +import net.mamoe.mirai.contact.getMember +import net.mamoe.mirai.contact.nameCardOrNick +import net.mamoe.mirai.message.data.* +import net.mamoe.mirai.message.data.visitor.MessageVisitor +import net.mamoe.mirai.message.data.visitor.RecursiveMessageVisitor +import net.mamoe.mirai.message.data.visitor.accept +import net.mamoe.mirai.utils.toLongUnsigned +import kotlin.jvm.JvmStatic +import kotlin.math.log10 + + +/** + * An object that stores these length properties. + * @see MessageLengthVerifier + */ +internal interface MessageLengthTokens { + val uiChars: Long + val uiImages: Long + val uiForwardNodes: Long +// val protocolTotal: Long + + companion object { + val comparator: Comparator = + compareBy { it.uiChars } + .then(compareBy { it.uiImages }) + .then(compareBy { it.uiForwardNodes }) +// .then(compareBy { it.protocolTotal }) + } +} + +/** + * A [MessageVisitor] that calculates and verifies length of a message. + * + * Can be applied to any [Message] by calling[Message.accept] passing this visitor. + * + * Applying to [ForwardMessage] verifies the [ForwardMessage] itself and its nodes recursively. + * + * Use properties from [MessageLengthTokens] to retrieve calculation results. + * @since 2.14 + */ +internal interface MessageLengthVerifier : MessageVisitor, MessageLengthTokens { + val nestedVerifiers: List + + fun isLengthValid(): Boolean +} + +/** + * Gets an [MessageLengthVerifier] with specified configuration [lengthTokens] and [context]. + */ +internal fun MessageLengthVerifier( + context: Contact?, + lengthTokens: MessageLengthLimits, + failfast: Boolean, +): MessageLengthVerifier { + return MessageLengthVerifierImpl(context, lengthTokens, failfast) +} + +/** + * Specifies length limits for [MessageLengthVerifier] + * @sample net.mamoe.mirai.internal.message.protocol.encode.MessageLengthVerifierTest + */ +internal class MessageLengthLimits( + override val uiChars: Long = 5000,// 5000 chars + override val uiImages: Long = 50, + override val uiForwardNodes: Long = 200, // 200 nodes for each forward message + +// override val protocolTotal: Long = 1 * 1000 * 1000, // 1 MB +) : MessageLengthTokens { + companion object { + @JvmStatic + val DEFAULT = MessageLengthLimits() + } +} + +/////////////////////////////////////////////////////////////////////////// +// IMPLEMENTATION +/////////////////////////////////////////////////////////////////////////// + +private inline operator fun MessageLengthTokens.compareTo(other: MessageLengthTokens): Int = + MessageLengthTokens.comparator.compare(this, other) + +internal class MessageLengthVerifierImpl constructor( + private val context: Contact?, + private val limits: MessageLengthLimits, + private val failfast: Boolean, +) : RecursiveMessageVisitor(), MessageLengthVerifier { + override val nestedVerifiers: MutableList = mutableListOf() + private var hasInvalidNested: Boolean = false + + /** + * 展示在 UI 的字符长度. + * @see MessageLengthLimits.uiChars + */ + override var uiChars: Long = 0 + private set + + override var uiImages: Long = 0 + private set + + override var uiForwardNodes: Long = 0 + private set + + override fun isFinished(): Boolean { + if (!failfast) return false + return !isLengthValid() + } + + override fun isLengthValid(): Boolean = this <= limits && !hasInvalidNested + + override fun visitPlainText(message: PlainText, data: Unit) { + uiChars += message.content.length + } + + override fun visitAt(message: At, data: Unit) { + val length = message.displayInGroup() + ?: message.target.numberOfDigitsInDecimal + uiChars += length + 1 // + `@` + } + + private fun At.displayInGroup(): Long? { + return if (context is Group) { + context.getMember(target)?.nameCardOrNick?.length?.toLongUnsigned() + } else { + null + } + } + + override fun visitAtAll(message: AtAll, data: Unit) { + uiChars += message.content.length + } + + override fun visitFace(message: Face, data: Unit) { + uiChars += 4 + } + + override fun visitImage(message: Image, data: Unit) { + uiImages++ +// protocolTotal = TypicalMessageSize.image + } + + override fun visitFlashImage(message: FlashImage, data: Unit) { + visitImage(message.image, data) + } + + override fun visitQuoteReply(message: QuoteReply, data: Unit) { + message.source.originalMessage.accept(this) + } + + override fun visitForwardMessage(message: ForwardMessage, data: Unit) { + val nested = MessageLengthVerifierImpl(context, limits, failfast) + nestedVerifiers.add(nested) + + for (node in message.nodeList) { + if (nested.isFinished()) break + nested.visitForwardMessageNode(node) + } + + if (!nested.isLengthValid()) { + hasInvalidNested = true + } + } + + fun visitForwardMessageNode(node: ForwardMessage.INode) { + uiForwardNodes++ + node.messageChain.accept(this) + } + + companion object { + val Long.numberOfDigitsInDecimal: Long + get() = if (this == 0L) 1 else 1 + log10(this.toDouble()).toLong() + } +} + + +//private object TypicalMessageSize { +// @Serializable +// private class Elements( +// @ProtoNumber(1) val elements: List +// ) +// +// val image: Long = kotlin.run { +// val elems = MessageProtocolFacade.encode( +// chain = Image("{01E9451B-70ED-EAE3-B37C-101F1EEBF5B5}.jpg").toMessageChain(), +// messageTarget = null, +// withGeneralFlags = false, +// isForward = false +// ) +// ProtoBuf.encodeToByteArray(Elements.serializer(), Elements(elems)).size.toLongUnsigned() +// } +//} \ No newline at end of file diff --git a/mirai-core/src/commonMain/kotlin/message/protocol/impl/ForwardMessageProtocol.kt b/mirai-core/src/commonMain/kotlin/message/protocol/impl/ForwardMessageProtocol.kt index ad4c1ec2e9e..9943c7ad6d6 100644 --- a/mirai-core/src/commonMain/kotlin/message/protocol/impl/ForwardMessageProtocol.kt +++ b/mirai-core/src/commonMain/kotlin/message/protocol/impl/ForwardMessageProtocol.kt @@ -63,7 +63,7 @@ internal class ForwardMessageProtocol : MessageProtocol() { forward: ForwardMessage, contact: AbstractContact ) { - check(forward.nodeList.size <= 200) { + check(forward.nodeList.size <= MAX_NODES) { throw MessageTooLargeException( contact, forward, forward, "ForwardMessage allows up to 200 nodes, but found ${forward.nodeList.size}" @@ -71,4 +71,8 @@ internal class ForwardMessageProtocol : MessageProtocol() { } } } + + companion object { + const val MAX_NODES : Int = 200 + } } \ No newline at end of file diff --git a/mirai-core/src/commonTest/kotlin/message/protocol/encode/MessageLengthVerifierTest.kt b/mirai-core/src/commonTest/kotlin/message/protocol/encode/MessageLengthVerifierTest.kt new file mode 100644 index 00000000000..4846196e133 --- /dev/null +++ b/mirai-core/src/commonTest/kotlin/message/protocol/encode/MessageLengthVerifierTest.kt @@ -0,0 +1,258 @@ +/* + * Copyright 2019-2022 Mamoe Technologies and contributors. + * + * 此源代码的使用受 GNU AFFERO GENERAL PUBLIC LICENSE version 3 许可证的约束, 可以在以下链接找到该许可证. + * Use of this source code is governed by the GNU AGPLv3 license that can be found through the following link. + * + * https://github.com/mamoe/mirai/blob/dev/LICENSE + */ + +package net.mamoe.mirai.internal.message.protocol.encode + +import net.mamoe.mirai.contact.MemberPermission +import net.mamoe.mirai.internal.MockBot +import net.mamoe.mirai.internal.message.protocol.encode.MessageLengthVerifierImpl.Companion.numberOfDigitsInDecimal +import net.mamoe.mirai.internal.notice.processors.GroupExtensions +import net.mamoe.mirai.internal.test.AbstractTest +import net.mamoe.mirai.message.data.* +import net.mamoe.mirai.message.data.visitor.accept +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +/** + * @see MessageLengthVerifier + */ +internal class MessageLengthVerifierTest : AbstractTest(), GroupExtensions { + private val bot = MockBot { } + private val group = bot.addGroup(123L, 1111L).apply { + addMember(1111L, permission = MemberPermission.OWNER) + } + + private companion object { + private val fiveThousandChars = PlainText("a".repeat(5000)) + private val anImage = + // [mirai:image:{9D97AF44-0007-5F86-6567-C0BD3F6A5C5C}.gif, width=211, height=243, size=108292, type=GIF, isEmoji=true] + Image("{9D97AF44-0007-5F86-6567-C0BD3F6A5C5C}.gif") { // guess what it is? + width = 211 + height = 243 + size = 108292 + type = ImageType.GIF + isEmoji = true + } + } + + @Test + fun numberOfDigitsInDecimal() { + assertEquals(1, 0L.numberOfDigitsInDecimal) + assertEquals(2, 10L.numberOfDigitsInDecimal) + assertEquals(2, 11L.numberOfDigitsInDecimal) + assertEquals(4, 1000L.numberOfDigitsInDecimal) + assertEquals(4, 1001L.numberOfDigitsInDecimal) + assertEquals(4, 1999L.numberOfDigitsInDecimal) + } + + @Test + fun `initial values`() { + val limits = MessageLengthLimits( + uiChars = 5000, + uiImages = 50, + uiForwardNodes = 200, + ) + val verifier = MessageLengthVerifier(null, limits, failfast = false) + + assertEquals(0, verifier.uiChars) + assertEquals(0, verifier.uiImages) + assertEquals(0, verifier.uiForwardNodes) + assertTrue(verifier.isLengthValid()) + } + + @Test + fun `count PlainTexts`() { + val limits = MessageLengthLimits( + uiChars = 5000, + uiImages = 50, + uiForwardNodes = 200, + ) + val verifier = MessageLengthVerifier(null, limits, failfast = false) + + val chain = messageChainOf(fiveThousandChars) + chain.accept(verifier) + assertEquals(5000, verifier.uiChars) + assertEquals(0, verifier.uiImages) + assertEquals(0, verifier.uiForwardNodes) + } + + @Test + fun `count Images`() { + val limits = MessageLengthLimits( + uiChars = 5000, + uiImages = 50, + uiForwardNodes = 200, + ) + val verifier = MessageLengthVerifier(null, limits, failfast = false) + + val chain = messageChainOf(anImage, anImage) + chain.accept(verifier) + assertEquals(0, verifier.uiChars) + assertEquals(2, verifier.uiImages) + assertEquals(0, verifier.uiForwardNodes) + } + + @Test + fun `count Images and PlainTexts`() { + val limits = MessageLengthLimits( + uiChars = 5000, + uiImages = 50, + uiForwardNodes = 200, + ) + val verifier = MessageLengthVerifier(null, limits, failfast = false) + + val chain = messageChainOf(fiveThousandChars, anImage) + chain.accept(verifier) + assertEquals(5000, verifier.uiChars) + assertEquals(1, verifier.uiImages) + assertEquals(0, verifier.uiForwardNodes) + } + + @Test + fun failfast() { + val limits = MessageLengthLimits( + uiChars = 5000, + uiImages = 50, + uiForwardNodes = 200, + ) + val verifier = MessageLengthVerifier(null, limits, failfast = true) + + val chain = messageChainOf(fiveThousandChars, anImage, fiveThousandChars, fiveThousandChars, anImage) + chain.accept(verifier) + assertEquals(fiveThousandChars.content.length * 2L, verifier.uiChars) + assertEquals(1, verifier.uiImages) + assertEquals(0, verifier.uiForwardNodes) + assertFalse(verifier.isLengthValid()) + } + + @Test + fun `disable failfast`() { + val limits = MessageLengthLimits( + uiChars = 5000, + uiImages = 50, + uiForwardNodes = 200, + ) + val verifier = MessageLengthVerifier(null, limits, failfast = false) + + val chain = messageChainOf(fiveThousandChars, anImage, fiveThousandChars, fiveThousandChars, anImage) + chain.accept(verifier) + assertEquals(fiveThousandChars.content.length * 3L, verifier.uiChars) + assertEquals(2, verifier.uiImages) + assertEquals(0, verifier.uiForwardNodes) + assertFalse(verifier.isLengthValid()) + } + + @Test + fun `limits are inclusive`() { + val limits = MessageLengthLimits( + uiChars = 5000, + uiImages = 50, + uiForwardNodes = 200, + ) + val verifier = MessageLengthVerifier(null, limits, failfast = true) + + val chain = messageChainOf(fiveThousandChars, anImage) + chain.accept(verifier) + assertEquals(5000, verifier.uiChars) + assertEquals(1, verifier.uiImages) + assertEquals(0, verifier.uiForwardNodes) + assertTrue(verifier.isLengthValid()) + } + + + @Test + fun `count recursively ForwardMessage nodes`() { + val limits = MessageLengthLimits( + uiChars = 5000, + uiImages = 50, + uiForwardNodes = 200, + ) + val verifier = MessageLengthVerifier(null, limits, failfast = false) + + val chain = messageChainOf(buildForwardMessage(group) { + 1111 says fiveThousandChars + 1111 says anImage + 1111 says fiveThousandChars + 1111 says fiveThousandChars + 1111 says anImage + }) + + chain.accept(verifier) + assertEquals(fiveThousandChars.content.length * 3L, verifier.uiChars) + assertEquals(2, verifier.uiImages) + assertEquals(5, verifier.uiForwardNodes) + assertFalse(verifier.isLengthValid()) + } + + @Test + fun `count deeply recursively ForwardMessage nodes`() { + val limits = MessageLengthLimits( + uiChars = 5000, + uiImages = 50, + uiForwardNodes = 200, + ) + val verifier = MessageLengthVerifier(null, limits, failfast = false) + + val chain = messageChainOf(buildForwardMessage(group) { + 1111 says fiveThousandChars + 1111 says anImage + 1111 says fiveThousandChars + 1111 says fiveThousandChars + 1111 says anImage + + 1111 says buildForwardMessage(group) { + 1111 says fiveThousandChars + 1111 says anImage + 1111 says fiveThousandChars + 1111 says fiveThousandChars + 1111 says anImage + } + }) + + chain.accept(verifier) + assertEquals(fiveThousandChars.content.length * 3L * 2, verifier.uiChars) + assertEquals(2 * 2, verifier.uiImages) + assertEquals(6 + 5, verifier.uiForwardNodes) + assertFalse(verifier.isLengthValid()) + } + + @Test + fun `count deeply recursively ForwardMessage nodes failfast`() { + val limits = MessageLengthLimits( + uiChars = 5000, + uiImages = 50, + uiForwardNodes = 200, + ) + val verifier = MessageLengthVerifier(null, limits, failfast = true) + + val chain = messageChainOf(buildForwardMessage(group) { + 1111 says fiveThousandChars + 1111 says anImage + 1111 says fiveThousandChars + 1111 says fiveThousandChars + 1111 says anImage + + 1111 says buildForwardMessage(group) { + 1111 says fiveThousandChars + 1111 says anImage + 1111 says fiveThousandChars + 1111 says fiveThousandChars + 1111 says anImage + } + }) + + chain.accept(verifier) + assertEquals(fiveThousandChars.content.length * 1L, verifier.uiChars) + assertEquals(1, verifier.uiImages) + assertEquals(2, verifier.uiForwardNodes) + assertFalse(verifier.isLengthValid()) + } +} \ No newline at end of file