diff --git a/benchmark/BDN.benchmark/Resp/RespIntegerReadBenchmarks.cs b/benchmark/BDN.benchmark/Resp/RespIntegerReadBenchmarks.cs index 5d76ab983c..6c4ff0d10e 100644 --- a/benchmark/BDN.benchmark/Resp/RespIntegerReadBenchmarks.cs +++ b/benchmark/BDN.benchmark/Resp/RespIntegerReadBenchmarks.cs @@ -10,13 +10,13 @@ namespace BDN.benchmark.Resp public unsafe class RespIntegerReadBenchmarks { [Benchmark] - [ArgumentsSource(nameof(SignedInt32EncodedValues))] - public int ReadInt32(AsciiTestCase testCase) + [ArgumentsSource(nameof(LengthHeaderValues))] + public int ReadLengthHeader(AsciiTestCase testCase) { fixed (byte* inputPtr = testCase.Bytes) { var start = inputPtr; - RespReadUtils.ReadInt(out var value, ref start, start + testCase.Bytes.Length); + RespReadUtils.ReadLengthHeader(out var value, ref start, start + testCase.Bytes.Length, allowNull: true); return value; } } @@ -72,6 +72,9 @@ public ulong ReadULongWithLengthHeader(AsciiTestCase testCase) public static IEnumerable SignedInt32EncodedValues => ToRespIntegerTestCases(RespIntegerWriteBenchmarks.SignedInt32Values); + public static IEnumerable LengthHeaderValues + => ToRespLengthHeaderTestCases(RespIntegerWriteBenchmarks.SignedInt32Values); + public static IEnumerable SignedInt64EncodedValues => ToRespIntegerTestCases(RespIntegerWriteBenchmarks.SignedInt64Values); @@ -90,6 +93,9 @@ public static IEnumerable UnsignedInt64EncodedValuesWithLengthHeader public static IEnumerable ToRespIntegerTestCases(T[] integerValues) where T : struct => integerValues.Select(testCase => new AsciiTestCase($":{testCase}\r\n")); + public static IEnumerable ToRespLengthHeaderTestCases(T[] integerValues) where T : struct + => integerValues.Select(testCase => new AsciiTestCase($"${testCase}\r\n")); + public static IEnumerable ToRespIntegerWithLengthHeader(T[] integerValues) where T : struct => integerValues.Select(testCase => new AsciiTestCase($"${testCase.ToString()?.Length ?? 0}\r\n{testCase}\r\n")); diff --git a/libs/client/GarnetClientProcessReplies.cs b/libs/client/GarnetClientProcessReplies.cs index 1b2d4dc783..a157272221 100644 --- a/libs/client/GarnetClientProcessReplies.cs +++ b/libs/client/GarnetClientProcessReplies.cs @@ -46,7 +46,7 @@ unsafe bool ProcessReplyAsString(ref byte* ptr, byte* end, out string result, ou break; case (byte)'$': - if (!RespReadUtils.ReadStringWithLengthHeader(out result, ref ptr, end)) + if (!RespReadUtils.ReadStringWithLengthHeader(out result, ref ptr, end, allowNull: true)) return false; break; diff --git a/libs/cluster/Session/ClusterSession.cs b/libs/cluster/Session/ClusterSession.cs index caf3f050d4..f66cfd7a79 100644 --- a/libs/cluster/Session/ClusterSession.cs +++ b/libs/cluster/Session/ClusterSession.cs @@ -4,7 +4,9 @@ using System; using System.Diagnostics; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using Garnet.common; +using Garnet.common.Parsing; using Garnet.networking; using Garnet.server; using Garnet.server.ACL; @@ -225,38 +227,35 @@ bool CheckACLAdminPermissions() ReadOnlySpan GetCommand(ReadOnlySpan bufSpan, out bool success) { - if (bytesRead - readHead < 6) + success = false; + + var ptr = recvBufferPtr + readHead; + var end = recvBufferPtr + bytesRead; + + // Try to read the command length + if (!RespReadUtils.ReadLengthHeader(out int length, ref ptr, end)) { - success = false; return default; } - Debug.Assert(*(recvBufferPtr + readHead) == '$'); - int psize = *(recvBufferPtr + readHead + 1) - '0'; - readHead += 2; - while (*(recvBufferPtr + readHead) != '\r') - { - psize = psize * 10 + *(recvBufferPtr + readHead) - '0'; - if (bytesRead - readHead < 1) - { - success = false; - return default; - } - readHead++; - } - if (bytesRead - readHead < 2 + psize + 2) + readHead = (int)(ptr - recvBufferPtr); + + // Try to read the command value + ptr += length; + if (ptr + 2 > end) { - success = false; return default; } - Debug.Assert(*(recvBufferPtr + readHead + 1) == '\n'); - var result = bufSpan.Slice(readHead + 2, psize); - Debug.Assert(*(recvBufferPtr + readHead + 2 + psize) == '\r'); - Debug.Assert(*(recvBufferPtr + readHead + 2 + psize + 1) == '\n'); + if (*(ushort*)ptr != MemoryMarshal.Read("\r\n"u8)) + { + RespParsingException.ThrowUnexpectedToken(*ptr); + } - readHead += 2 + psize + 2; success = true; + var result = bufSpan.Slice(readHead, length); + readHead += length + 2; + return result; } } diff --git a/libs/common/Parsing/RespParsingException.cs b/libs/common/Parsing/RespParsingException.cs new file mode 100644 index 0000000000..c7396c226a --- /dev/null +++ b/libs/common/Parsing/RespParsingException.cs @@ -0,0 +1,85 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Text; + +namespace Garnet.common.Parsing +{ + /// + /// Exception wrapper for RESP parsing errors. + /// + public class RespParsingException : GarnetException + { + /// + /// Construct a new RESP parsing exception with the given message. + /// + /// Message that described the exception that has occurred. + RespParsingException(string message) : base(message) + { + // Nothing... + } + + /// + /// Throw an "Unexcepted Token" exception. + /// + /// The character that was unexpected. + [DoesNotReturn] + public static void ThrowUnexpectedToken(byte token) + { + var c = (char)token; + var escaped = char.IsControl(c) ? $"\\x{token:x2}" : c.ToString(); + Throw($"Unexpected character '{escaped}'."); + } + + /// + /// Throw an invalid string length exception. + /// + /// The invalid string length. + [DoesNotReturn] + public static void ThrowInvalidStringLength(long len) + { + Throw($"Invalid string length '{len}'."); + } + + /// + /// Throw an invalid length exception. + /// + /// The invalid length. + [DoesNotReturn] + public static void ThrowInvalidLength(long len) + { + Throw($"Invalid length '{len}'."); + } + + /// + /// Throw NaN (not a number) exception. + /// + /// Pointer to an ASCII-encoded byte buffer containing the string that could not be converted. + /// Length of the buffer. + [DoesNotReturn] + public static unsafe void ThrowNotANumber(byte* buffer, int length) + { + Throw($"Unable to parse number: {Encoding.ASCII.GetString(buffer, length)}"); + } + + /// + /// Throw a exception indicating that an integer overflow has occurred. + /// + /// Pointer to an ASCII-encoded byte buffer containing the string that caused the overflow. + /// Length of the buffer. + [DoesNotReturn] + public static unsafe void ThrowIntegerOverflow(byte* buffer, int length) + { + Throw($"Unable to parse integer. The given number is larger than allowed: {Encoding.ASCII.GetString(buffer, length)}"); + } + + /// + /// Throw helper that throws a RespParsingException. + /// + /// Exception message. + [DoesNotReturn] + public static void Throw(string message) => + throw new RespParsingException(message); + } +} \ No newline at end of file diff --git a/libs/common/RespReadUtils.cs b/libs/common/RespReadUtils.cs index c33a324e69..e66f339cc0 100644 --- a/libs/common/RespReadUtils.cs +++ b/libs/common/RespReadUtils.cs @@ -4,156 +4,375 @@ using System; using System.Buffers; using System.Buffers.Text; -using System.Diagnostics; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Text; +using Garnet.common.Parsing; namespace Garnet.common { /// - /// Utilities for reading RESP protocol + /// Utilities for reading RESP protocol messages. /// public static unsafe class RespReadUtils { /// - /// Get Header length + /// Tries to read the leading sign of the given ASCII-encoded number. /// - /// - /// - /// - /// - public static bool ReadHeaderLength(out int len, ref byte* ptr, byte* end) + /// String to try reading sign from. + /// Whether the sign is '-'. + /// True if either '+' or '-' was found, false otherwise. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool TryReadSign(byte* ptr, out bool negative) { - len = -1; - if (ptr + 3 >= end) - return false; + negative = (*ptr == '-'); + return negative || (*ptr == '+'); + } - Debug.Assert(*ptr == '$'); - ptr++; - bool neg = *ptr == '-'; - int ksize = *ptr++ - '0'; - while (*ptr != '\r') + /// + /// Tries to read an unsigned 64-bit integer from a given ASCII-encoded input stream. + /// The input may include leading zeros. + /// + /// Pointer to the beginning of the ASCII encoded input string. + /// The end of the string to parse. + /// If parsing was successful, contains the parsed ulong value. + /// If parsing was successful, contains the number of bytes that were parsed. + /// + /// True if a ulong was successfully parsed, false if the input string did not start with + /// a valid integer or the end of the string was reached before finishing parsing. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool TryReadUlong(ref byte* ptr, byte* end, out ulong value, out ulong bytesRead) + { + bytesRead = 0; + value = 0; + var readHead = ptr; + + // Fast path for the first 19 digits. + // NOTE: UINT64 overflows can only happen on digit 20 or later (if integer contains leading zeros). + var fastPathEnd = ptr + 19; + while (readHead < fastPathEnd) { - Debug.Assert(*ptr >= '0' && *ptr <= '9'); - ksize = ksize * 10 + *ptr++ - '0'; - if (ptr >= end) + if (readHead > end) + { return false; + } + + var nextDigit = (uint)(*readHead - '0'); + if (nextDigit > 9) + { + goto Done; + } + + value = (10 * value) + nextDigit; + + readHead++; } - ptr += 2; - if (ptr > end) - return false; - Debug.Assert(*(ptr - 2) == '\r'); - Debug.Assert(*(ptr - 1) == '\n'); - len = neg ? -ksize : ksize; + // Parse remaining digits, while checking for overflows. + while (true) + { + if (readHead > end) + { + return false; + } + + var nextDigit = (uint)(*readHead - '0'); + if (nextDigit > 9) + { + goto Done; + } + + if ((value == 1844674407370955161UL && ((int)nextDigit > 5)) || (value > 1844674407370955161UL)) + { + RespParsingException.ThrowIntegerOverflow(ptr, (int)(readHead - ptr)); + } + + value = (10 * value) + nextDigit; + + readHead++; + } + + Done: + bytesRead = (ulong)(readHead - ptr); + ptr = readHead; + return true; } + /// - /// Read int + /// Tries to read a signed 64-bit integer from a given ASCII-encoded input stream. /// - public static bool ReadInt(out int number, ref byte* ptr, byte* end) + /// Pointer to the beginning of the ASCII encoded input string. + /// The end of the string to parse. + /// If parsing was successful, contains the parsed long value. + /// If parsing was successful, contains the number of bytes that were parsed. + /// + /// True if a long was successfully parsed, false if the input string did not start with + /// a valid integer or the end of the string was reached before finishing parsing. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool TryReadLong(ref byte* ptr, byte* end, out long value, out ulong bytesRead) { - number = 0; - if (ptr + 3 >= end) + bytesRead = 0; + value = 0; + + // Parse optional leading sign + if (TryReadSign(ptr, out var negative)) + { + ptr++; + bytesRead = 1; + } + + // Parse digits as ulong + if (!TryReadUlong(ref ptr, end, out var number, out var digitsRead)) + { return false; + } - Debug.Assert(*ptr == '$'); + // Check for overflows and convert digits to long, if possible + if (negative) + { + if (number > ((ulong)long.MaxValue) + 1) + { + RespParsingException.ThrowIntegerOverflow(ptr - digitsRead, (int)digitsRead); + } - ptr++; - number = *ptr++ - '0'; - while (*ptr != '\r') + value = -1 - (long)(number - 1); + } + else { - Debug.Assert(*ptr >= '0' && *ptr <= '9'); - number = number * 10 + *ptr++ - '0'; - if (ptr >= end) - return false; + if (number > long.MaxValue) + { + RespParsingException.ThrowIntegerOverflow(ptr - digitsRead, (int)digitsRead); + } + value = (long)number; } - ptr += 2; - if (ptr > end) + + bytesRead += digitsRead; + + return true; + } + + /// + /// Tries to read a signed 32-bit integer from a given ASCII-encoded input stream. + /// + /// Pointer to the beginning of the ASCII encoded input string. + /// The end of the string to parse. + /// If parsing was successful, contains the parsed int value. + /// If parsing was successful, contains the number of bytes that were parsed. + /// + /// True if an int was successfully parsed, false if the input string did not start with + /// a valid integer or the end of the string was reached before finishing parsing. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool TryReadInt(ref byte* ptr, byte* end, out int value, out ulong bytesRead) + { + bytesRead = 0; + value = 0; + + // Parse optional leading sign + if (TryReadSign(ptr, out var negative)) + { + ptr++; + bytesRead = 1; + } + + // Parse digits as ulong + if (!TryReadUlong(ref ptr, end, out var number, out var digitsRead)) + { return false; + } + + // Check for overflows and convert digits to int, if possible + if (negative) + { + if (number > ((ulong)int.MaxValue) + 1) + { + RespParsingException.ThrowIntegerOverflow(ptr - digitsRead, (int)digitsRead); + } + + value = (int)(0 - (long)number); + } + else + { + if (number > int.MaxValue) + { + RespParsingException.ThrowIntegerOverflow(ptr - digitsRead, (int)digitsRead); + } + value = (int)number; + } + + bytesRead += digitsRead; - Debug.Assert(*(ptr - 1) == '\n'); return true; } /// - /// Read signed 64 bit number + /// Tries to read a RESP length header from the given ASCII-encoded RESP string + /// and, if successful, moves the given ptr to the end of the length header. /// - public static bool Read64Int(out long number, ref byte* ptr, byte* end) + /// If parsing was successful, contains the extracted length from the header. + /// The starting position in the RESP string. Will be advanced if parsing is successful. + /// The current end of the RESP string. + /// Whether to allow special null length header ($-1\r\n). + /// Whether to parse an array length header ('*...\r\n') or a string length header ('$...\r\n'). + /// True if a length header was successfully read. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool ReadLengthHeader(out int length, ref byte* ptr, byte* end, bool allowNull = false, bool isArray = false) { - number = 0; - if (ptr + 3 >= end) + length = -1; + if (ptr + 3 > end) return false; - Debug.Assert(*ptr == ':'); + var readHead = ptr + 1; + var negative = *readHead == '-'; - ptr++; - number = *ptr++ - '0'; - while (*ptr != '\r') + // String length headers must start with a '$', array headers with '*' + if (*ptr != (isArray ? '*' : '$')) { - Debug.Assert(*ptr >= '0' && *ptr <= '9'); - number = number * 10 + *ptr++ - '0'; - if (ptr >= end) + RespParsingException.ThrowUnexpectedToken(*ptr); + } + + // Special case: '$-1' (NULL value) + if (negative) + { + if (readHead + 4 > end) + { return false; + } + + if (allowNull && (*(uint*)readHead == MemoryMarshal.Read("-1\r\n"u8))) + { + ptr = readHead + 4; + return true; + } + readHead++; } - ptr += 2; + + // Parse length + if (!TryReadUlong(ref readHead, end, out var value, out var digitsRead)) + { + return false; + } + + if (digitsRead == 0) + { + RespParsingException.ThrowUnexpectedToken(*readHead); + } + + // Validate length + length = (int)value; + + if (negative) + { + RespParsingException.ThrowInvalidStringLength(-length); + } + + if (value > int.MaxValue) + { + RespParsingException.ThrowIntegerOverflow(readHead - digitsRead, (int)digitsRead); + } + + // Ensure terminator has been received + ptr = readHead + 2; if (ptr > end) + { return false; + } + + if (*(ushort*)readHead != MemoryMarshal.Read("\r\n"u8)) + { + RespParsingException.ThrowUnexpectedToken(*ptr); + } - Debug.Assert(*(ptr - 1) == '\n'); return true; } - /// - /// Read the length of an array of bulk strings + /// Read signed 64 bit number /// - /// - /// - /// - /// - public static bool ReadArrayLength(out int number, ref byte* ptr, byte* end) + public static bool Read64Int(out long number, ref byte* ptr, byte* end) { number = 0; if (ptr + 3 >= end) return false; - Debug.Assert(*ptr == '*'); + // Integer header must start with ':' + if (*ptr++ != ':') + { + RespParsingException.ThrowUnexpectedToken(*ptr); + } - ptr++; - number = *ptr++ - '0'; - while (*ptr != '\r') + // Parse length + if (!TryReadLong(ref ptr, end, out number, out var bytesRead)) { - Debug.Assert(*ptr >= '0' && *ptr <= '9'); - number = number * 10 + *ptr++ - '0'; - if (ptr >= end) - return false; + return false; } + + // Ensure terminator has been received ptr += 2; if (ptr > end) + { return false; + } + + if (*(ushort*)(ptr - 2) != MemoryMarshal.Read("\r\n"u8)) + { + RespParsingException.ThrowUnexpectedToken(*ptr); + } - Debug.Assert(*(ptr - 1) == '\n'); return true; } + /// + /// Tries to read a RESP array length header from the given ASCII-encoded RESP string + /// and, if successful, moves the given ptr to the end of the length header. + /// + /// If parsing was successful, contains the extracted length from the header. + /// The starting position in the RESP string. Will be advanced if parsing is successful. + /// The current end of the RESP string. + /// True if a length header was successfully read. + + public static bool ReadArrayLength(out int length, ref byte* ptr, byte* end) + => ReadLengthHeader(out length, ref ptr, end, isArray: true); + + /// /// Read int with length header /// public static bool ReadIntWithLengthHeader(out int number, ref byte* ptr, byte* end) { number = 0; - if (!ReadInt(out int numberLength, ref ptr, end)) + + // Parse RESP string header + if (!ReadLengthHeader(out var numberLength, ref ptr, end)) return false; if (ptr + numberLength + 2 > end) return false; - number = (int)NumUtils.BytesToLong(numberLength, ptr); - ptr += numberLength + 2; - Debug.Assert(*(ptr - 2) == '\r'); - Debug.Assert(*(ptr - 1) == '\n'); + // Parse associated integer value + var numberStart = ptr; + if (!TryReadInt(ref ptr, end, out number, out var bytesRead)) + { + return false; + } + + if ((int)bytesRead != numberLength) + { + RespParsingException.ThrowNotANumber(numberStart, numberLength); + } + + // Ensure terminator has been received + if (*(ushort*)ptr != MemoryMarshal.Read("\r\n"u8)) + { + RespParsingException.ThrowUnexpectedToken(*ptr); + } + + ptr += 2; + return true; } @@ -163,16 +382,34 @@ public static bool ReadIntWithLengthHeader(out int number, ref byte* ptr, byte* public static bool ReadLongWithLengthHeader(out long number, ref byte* ptr, byte* end) { number = 0; - if (!ReadInt(out int numberLength, ref ptr, end)) + + // Parse RESP string header + if (!ReadLengthHeader(out var numberLength, ref ptr, end)) return false; if (ptr + numberLength + 2 > end) return false; - number = NumUtils.BytesToLong(numberLength, ptr); - ptr += numberLength + 2; - Debug.Assert(*(ptr - 2) == '\r'); - Debug.Assert(*(ptr - 1) == '\n'); + // Parse associated integer value + var numberStart = ptr; + if (!TryReadLong(ref ptr, end, out number, out var bytesRead)) + { + return false; + } + + if ((int)bytesRead != numberLength) + { + RespParsingException.ThrowNotANumber(numberStart, numberLength); + } + + // Ensure terminator has been received + if (*(ushort*)ptr != MemoryMarshal.Read("\r\n"u8)) + { + RespParsingException.ThrowUnexpectedToken(*ptr); + } + + ptr += 2; + return true; } @@ -182,16 +419,34 @@ public static bool ReadLongWithLengthHeader(out long number, ref byte* ptr, byte public static bool ReadULongWithLengthHeader(out ulong number, ref byte* ptr, byte* end) { number = 0; - if (!ReadInt(out int numberLength, ref ptr, end)) + + // Parse RESP string header + if (!ReadLengthHeader(out var numberLength, ref ptr, end)) return false; if (ptr + numberLength + 2 > end) return false; - number = NumUtils.BytesToULong(numberLength, ptr); - ptr += numberLength + 2; - Debug.Assert(*(ptr - 2) == '\r'); - Debug.Assert(*(ptr - 1) == '\n'); + // Parse associated integer value + var numberStart = ptr; + if (!TryReadUlong(ref ptr, end, out number, out var bytesRead)) + { + return false; + } + + if ((int)bytesRead != numberLength) + { + RespParsingException.ThrowNotANumber(numberStart, numberLength); + } + + // Ensure terminator has been received + if (*(ushort*)ptr != MemoryMarshal.Read("\r\n"u8)) + { + RespParsingException.ThrowUnexpectedToken(*ptr); + } + + ptr += 2; + return true; } @@ -201,100 +456,114 @@ public static bool ReadULongWithLengthHeader(out ulong number, ref byte* ptr, by public static bool ReadByteArrayWithLengthHeader(out byte[] result, ref byte* ptr, byte* end) { result = null; - if (ptr + 3 >= end) + + // Parse RESP string header + if (!ReadLengthHeader(out var length, ref ptr, end)) return false; - Debug.Assert(*ptr == '$'); - ptr++; - bool neg = *ptr == '-'; - int ksize = *ptr++ - '0'; - while (*ptr != '\r') - { - Debug.Assert(*ptr >= '0' && *ptr <= '9'); - ksize = ksize * 10 + *ptr++ - '0'; - if (ptr >= end) - return false; - } + // Advance read pointer to the end of the array (including terminator) + var keyPtr = ptr; - if (neg) - { - ptr += 2; - if (ptr > end) return false; - return true; - } + ptr += length + 2; - var keyPtr = ptr + 2; - ptr = ptr + 2 + ksize + 2; // for \r\n + key + \r\n if (ptr > end) return false; - Debug.Assert(*(ptr + 1 - (2 + ksize + 2)) == '\n'); - Debug.Assert(*(ptr - 2) == '\r'); - Debug.Assert(*(ptr - 1) == '\n'); + // Ensure terminator has been received + if (*(ushort*)(ptr - 2) != MemoryMarshal.Read("\r\n"u8)) + { + RespParsingException.ThrowUnexpectedToken(*(ptr - 2)); + } + + result = new Span(keyPtr, length).ToArray(); - result = new Span(keyPtr, ksize).ToArray(); return true; } /// - /// Read string with length header + /// Read boolean value with length header /// public static bool ReadBoolWithLengthHeader(out bool result, ref byte* ptr, byte* end) { - //$1\r\n1\r\n - //$1\r\n0\r\n result = false; - if (ptr + 7 >= end) + + if (ptr + 7 > end) return false; - Debug.Assert(*ptr == '$'); - Debug.Assert(*(ptr + 1) == '1'); - Debug.Assert(*(ptr + 2) == '\r'); - Debug.Assert(*(ptr + 3) == '\n'); - ptr += 4; - result = *ptr == '1' ? true : false; - ptr += 3; + + // Fast path: RESP string header should have length 1 + if (*(uint*)ptr == MemoryMarshal.Read("$1\r\n"u8)) + { + ptr += 4; + } + else + { + // Parse malformed RESP string header + if (!ReadLengthHeader(out var length, ref ptr, end)) + return false; + + if (length != 1) + { + RespParsingException.ThrowInvalidLength(length); + } + } + + // Parse contents (needs to be 1 character) + result = (*ptr++ == '1'); + + // Ensure terminator has been received + if (*(ushort*)ptr != MemoryMarshal.Read("\r\n"u8)) + { + RespParsingException.ThrowUnexpectedToken(*ptr); + } + + ptr += 2; + return true; } /// - /// Read string with length header + /// Tries to read a RESP-formatted string including its length header from the given ASCII-encoded + /// RESP message and, if successful, moves the given ptr to the end of the string value. /// - public static bool ReadStringWithLengthHeader(out string result, ref byte* ptr, byte* end) + /// If parsing was successful, contains the extracted string value. + /// The starting position in the RESP message. Will be advanced if parsing is successful. + /// The current end of the RESP message. + /// Whether to allow the RESP null value ($-1\r\n) + /// True if a RESP string was successfully read. + public static bool ReadStringWithLengthHeader(out string result, ref byte* ptr, byte* end, bool allowNull = false) { result = null; - if (ptr + 3 >= end) + + if (ptr + 3 > end) return false; - Debug.Assert(*ptr == '$'); - ptr++; - bool neg = *ptr == '-'; - int ksize = *ptr++ - '0'; - while (*ptr != '\r') - { - Debug.Assert(*ptr >= '0' && *ptr <= '9'); - ksize = ksize * 10 + *ptr++ - '0'; - if (ptr >= end) - return false; - } + // Parse RESP string header + if (!ReadLengthHeader(out var length, ref ptr, end, allowNull: allowNull)) + return false; - if (neg) + if (allowNull && length < 0) { - ptr += 2; - if (ptr > end) return false; + // NULL value ('$-1\r\n') return true; } - var keyPtr = ptr + 2; - ptr = ptr + 2 + ksize + 2; // for \r\n + key + \r\n + // Extract string content + '\r\n' terminator + var keyPtr = ptr; + + ptr += length + 2; + if (ptr > end) return false; - Debug.Assert(*(ptr + 1 - (2 + ksize + 2)) == '\n'); - Debug.Assert(*(ptr - 2) == '\r'); - Debug.Assert(*(ptr - 1) == '\n'); + // Ensure terminator has been received + if (*(ushort*)(ptr - 2) != MemoryMarshal.Read("\r\n"u8)) + { + RespParsingException.ThrowUnexpectedToken(*(ptr - 2)); + } + + result = Encoding.UTF8.GetString(new Span(keyPtr, length)); - result = Encoding.UTF8.GetString(new Span(keyPtr, ksize)); return true; } @@ -304,39 +573,35 @@ public static bool ReadStringWithLengthHeader(out string result, ref byte* ptr, public static bool ReadStringWithLengthHeader(MemoryPool pool, out MemoryResult result, ref byte* ptr, byte* end) { result = default; - if (ptr + 3 >= end) + if (ptr + 3 > end) return false; - Debug.Assert(*ptr == '$'); - ptr++; - bool neg = *ptr == '-'; - int ksize = *ptr++ - '0'; - while (*ptr != '\r') - { - Debug.Assert(*ptr >= '0' && *ptr <= '9'); - ksize = ksize * 10 + *ptr++ - '0'; - if (ptr >= end) - return false; - } + // Parse RESP string header + if (!ReadLengthHeader(out var length, ref ptr, end)) + return false; - if (neg) + if (length < 0) { - ptr += 2; - if (ptr > end) return false; + // NULL value ('$-1\r\n') return true; } - var keyPtr = ptr + 2; - ptr = ptr + 2 + ksize + 2; // for \r\n + key + \r\n + // Extract string content + '\r\n' terminator + var keyPtr = ptr; + + ptr += length + 2; + if (ptr > end) return false; - Debug.Assert(*(ptr + 1 - (2 + ksize + 2)) == '\n'); - Debug.Assert(*(ptr - 2) == '\r'); - Debug.Assert(*(ptr - 1) == '\n'); + // Ensure terminator has been received + if (*(ushort*)(ptr - 2) != MemoryMarshal.Read("\r\n"u8)) + { + RespParsingException.ThrowUnexpectedToken(*(ptr - 2)); + } - result = MemoryResult.Create(pool, ksize); - new ReadOnlySpan(keyPtr, ksize).CopyTo(result.Span); + result = MemoryResult.Create(pool, length); + new ReadOnlySpan(keyPtr, length).CopyTo(result.Span); return true; } @@ -346,10 +611,16 @@ public static bool ReadStringWithLengthHeader(MemoryPool pool, out MemoryR public static bool ReadSimpleString(out string result, ref byte* ptr, byte* end) { result = null; + if (ptr + 2 >= end) return false; - Debug.Assert(*ptr == '+'); + // Simple strings need to start with a '+' + if (*ptr != '+') + { + RespParsingException.ThrowUnexpectedToken(*ptr); + } + ptr++; return ReadString(out result, ref ptr, end); @@ -364,7 +635,12 @@ public static bool ReadErrorAsString(out string result, ref byte* ptr, byte* end if (ptr + 2 >= end) return false; - Debug.Assert(*ptr == '-'); + // Error strings need to start with a '-' + if (*ptr != '-') + { + RespParsingException.ThrowUnexpectedToken(*ptr); + } + ptr++; return ReadString(out result, ref ptr, end); @@ -379,7 +655,12 @@ public static bool ReadIntegerAsString(out string result, ref byte* ptr, byte* e if (ptr + 2 >= end) return false; - Debug.Assert(*ptr == ':'); + // Integer strings need to start with a ':' + if (*ptr != ':') + { + RespParsingException.ThrowUnexpectedToken(*ptr); + } + ptr++; return ReadString(out result, ref ptr, end); @@ -394,7 +675,12 @@ public static bool ReadSimpleString(MemoryPool pool, out MemoryResult= end) return false; - Debug.Assert(*ptr == '+'); + // Simple strings need to start with a '+' + if (*ptr != '+') + { + RespParsingException.ThrowUnexpectedToken(*ptr); + } + ptr++; return ReadString(pool, out result, ref ptr, end); @@ -409,7 +695,12 @@ public static bool ReadErrorAsString(MemoryPool pool, out MemoryResult= end) return false; - Debug.Assert(*ptr == '-'); + // Error strings need to start with a '-' + if (*ptr != '-') + { + RespParsingException.ThrowUnexpectedToken(*ptr); + } + ptr++; return ReadString(pool, out result, ref ptr, end); @@ -424,7 +715,12 @@ public static bool ReadIntegerAsString(MemoryPool pool, out MemoryResult= end) return false; - Debug.Assert(*ptr == ':'); + // Integer strings need to start with a ':' + if (*ptr != ':') + { + RespParsingException.ThrowUnexpectedToken(*ptr); + } + ptr++; return ReadString(pool, out result, ref ptr, end); @@ -436,38 +732,31 @@ public static bool ReadIntegerAsString(MemoryPool pool, out MemoryResult= end) - return false; - Debug.Assert(*ptr == '*'); - ptr++; - bool neg = *ptr == '-'; - int asize = *ptr++ - '0'; - while (*ptr != '\r') + // Parse RESP array header + if (!ReadArrayLength(out var length, ref ptr, end)) { - Debug.Assert(*ptr >= '0' && *ptr <= '9'); - asize = asize * 10 + *ptr++ - '0'; - if (ptr >= end) - return false; - } - ptr += 2; // for \r\n - if (ptr > end) return false; + } - if (neg) + if (length < 0) + { + // NULL value ('*-1\r\n') return true; + } - result = new string[asize]; - for (int z = 0; z < asize; z++) + // Parse individual strings in the array + result = new string[length]; + for (var i = 0; i < length; i++) { if (*ptr == '$') { - if (!ReadStringWithLengthHeader(out result[z], ref ptr, end)) + if (!ReadStringWithLengthHeader(out result[i], ref ptr, end)) return false; } else { - if (!ReadIntegerAsString(out result[z], ref ptr, end)) + if (!ReadIntegerAsString(out result[i], ref ptr, end)) return false; } } @@ -481,38 +770,30 @@ public static bool ReadStringArrayWithLengthHeader(out string[] result, ref byte public static bool ReadStringArrayWithLengthHeader(MemoryPool pool, out MemoryResult[] result, ref byte* ptr, byte* end) { result = null; - if (ptr + 3 >= end) - return false; - - Debug.Assert(*ptr == '*'); - ptr++; - bool neg = *ptr == '-'; - int asize = *ptr++ - '0'; - while (*ptr != '\r') + // Parse RESP array header + if (!ReadArrayLength(out var length, ref ptr, end)) { - Debug.Assert(*ptr >= '0' && *ptr <= '9'); - asize = asize * 10 + *ptr++ - '0'; - if (ptr >= end) - return false; - } - ptr += 2; // for \r\n - if (ptr > end) return false; + } - if (neg) + if (length < 0) + { + // NULL value ('*-1\r\n') return true; + } - result = new MemoryResult[asize]; - for (int z = 0; z < asize; z++) + // Parse individual strings in the array + result = new MemoryResult[length]; + for (var i = 0; i < length; i++) { if (*ptr == '$') { - if (!ReadStringWithLengthHeader(pool, out result[z], ref ptr, end)) + if (!ReadStringWithLengthHeader(pool, out result[i], ref ptr, end)) return false; } else { - if (!ReadIntegerAsString(pool, out result[z], ref ptr, end)) + if (!ReadIntegerAsString(pool, out result[i], ref ptr, end)) return false; } } @@ -525,12 +806,13 @@ public static bool ReadStringArrayWithLengthHeader(MemoryPool pool, out Me /// public static bool ReadDoubleWithLengthHeader(out double result, out bool parsed, ref byte* ptr, byte* end) { - parsed = false; if (!ReadByteArrayWithLengthHeader(out var resultBytes, ref ptr, end)) { result = 0; + parsed = false; return false; } + parsed = Utf8Parser.TryParse(resultBytes, out result, out var bytesConsumed, default) && bytesConsumed == resultBytes.Length; return true; @@ -541,76 +823,87 @@ public static bool ReadDoubleWithLengthHeader(out double result, out bool parsed /// public static bool ReadSpanByteWithLengthHeader(ref Span result, ref byte* ptr, byte* end) { - if (ptr + 3 >= end) + // Parse RESP string header + if (!ReadLengthHeader(out var len, ref ptr, end)) + { return false; + } - Debug.Assert(*ptr == '$'); - ptr++; - int ksize = *ptr++ - '0'; - while (*ptr != '\r') + if (len < 0) { - Debug.Assert(*ptr >= '0' && *ptr <= '9'); - ksize = ksize * 10 + *ptr++ - '0'; - if (ptr >= end) - return false; + // NULL value ('$-1\r\n') + result = null; + return true; } - var keyPtr = ptr + 2; - ptr = ptr + 2 + ksize + 2; // for \r\n + key + \r\n + + var keyPtr = ptr; + + // Parse content: ensure that input contains key + '\r\n' + ptr += len + 2; if (ptr > end) + { return false; + } - Debug.Assert(*(ptr + 1 - (2 + ksize + 2)) == '\n'); - Debug.Assert(*(ptr - 2) == '\r'); - Debug.Assert(*(ptr - 1) == '\n'); + if (*(ushort*)(ptr - 2) != MemoryMarshal.Read("\r\n"u8)) + { + RespParsingException.ThrowUnexpectedToken(*(ptr - 2)); + } - result = new Span(keyPtr, ksize); + result = new Span(keyPtr, len); return true; } /// - /// Read pointer to byte array, with length header + /// Read pointer to byte array, with length header. /// + /// Pointer to the beginning of the read byte array (including empty). + /// Length of byte array. + /// Current read head of the input RESP stream. + /// Current end of the input RESP stream. + /// True if input was complete, otherwise false. + /// Thrown if array length was invalid. [MethodImpl(MethodImplOptions.AggressiveInlining)] public static bool ReadPtrWithLengthHeader(ref byte* result, ref int len, ref byte* ptr, byte* end) { - if (ptr + 3 >= end) // we need at least 3 characters: [$0\r] - return false; - - Debug.Assert(*ptr == '$'); - ptr++; - len = *ptr++ - '0'; - while (*ptr != '\r') + // Parse RESP string header + if (!ReadLengthHeader(out len, ref ptr, end)) { - Debug.Assert(*ptr >= '0' && *ptr <= '9'); - len = len * 10 + *ptr++ - '0'; - if (ptr >= end) - return false; + return false; } - result = ptr + 2; - ptr = ptr + 2 + len + 2; // for \r\n + key + \r\n + + result = ptr; + + // Parse content: ensure that input contains key + '\r\n' + ptr += len + 2; if (ptr > end) + { return false; + } - Debug.Assert(*(ptr + 1 - (2 + len + 2)) == '\n'); - Debug.Assert(*(ptr - 2) == '\r'); - Debug.Assert(*(ptr - 1) == '\n'); + if (*(ushort*)(ptr - 2) != MemoryMarshal.Read("\r\n"u8)) + { + RespParsingException.ThrowUnexpectedToken(*(ptr - 2)); + } return true; } /// - /// Read string + /// Read ASCII string without header until string terminator ('\r\n'). /// private static bool ReadString(out string result, ref byte* ptr, byte* end) { result = null; + if (ptr + 1 >= end) return false; - byte* start = ptr; + var start = ptr; + while (ptr < end - 1) { - if (*ptr == (byte)'\r' && *(ptr + 1) == (byte)'\n') + if (*(ushort*)ptr == MemoryMarshal.Read("\r\n"u8)) { result = Encoding.UTF8.GetString(new ReadOnlySpan(start, (int)(ptr - start))); ptr += 2; @@ -623,7 +916,7 @@ private static bool ReadString(out string result, ref byte* ptr, byte* end) } /// - /// Read string + /// Read ASCII string without header until string terminator ('\r\n'). /// private static bool ReadString(MemoryPool pool, out MemoryResult result, ref byte* ptr, byte* end) { @@ -631,10 +924,10 @@ private static bool ReadString(MemoryPool pool, out MemoryResult res if (ptr + 1 >= end) return false; - byte* start = ptr; + var start = ptr; while (ptr < end - 1) { - if (*ptr == (byte)'\r' && *(ptr + 1) == (byte)'\n') + if (*(ushort*)ptr == MemoryMarshal.Read("\r\n"u8)) { result = MemoryResult.Create(pool, (int)(ptr - start)); new ReadOnlySpan(start, result.Length).CopyTo(result.Span); @@ -655,7 +948,7 @@ public static bool ReadSerializedSpanByte(ref byte* keyPtr, ref byte keyMetaData //1. safe read ksize if (ptr + sizeof(int) > end) return false; - int ksize = *(int*)(ptr); + var ksize = *(int*)ptr; ptr += sizeof(int); //2. safe read key bytes @@ -668,7 +961,7 @@ public static bool ReadSerializedSpanByte(ref byte* keyPtr, ref byte keyMetaData //3. safe read vsize if (ptr + 4 > end) return false; - int vsize = *(int*)(ptr); + var vsize = *(int*)ptr; ptr += sizeof(int); //4. safe read value bytes @@ -717,7 +1010,7 @@ public static bool ReadSerializedData(out byte[] key, out byte[] value, out long //5. safe read expiration info if (ptr + 8 > end) return false; - expiration = *(long*)(ptr); + expiration = *(long*)ptr; ptr += 8; key = new byte[keyLen]; diff --git a/libs/server/Metrics/Latency/GarnetLatencyMetrics.cs b/libs/server/Metrics/Latency/GarnetLatencyMetrics.cs index b4e1480a5b..cb532c2757 100644 --- a/libs/server/Metrics/Latency/GarnetLatencyMetrics.cs +++ b/libs/server/Metrics/Latency/GarnetLatencyMetrics.cs @@ -5,7 +5,6 @@ using System.Collections.Generic; using System.Diagnostics; using System.Globalization; -using System.Linq; using Garnet.common; using HdrHistogram; diff --git a/libs/server/Metrics/Latency/GarnetLatencyMetricsSession.cs b/libs/server/Metrics/Latency/GarnetLatencyMetricsSession.cs index d7f8c31a90..93559f296c 100644 --- a/libs/server/Metrics/Latency/GarnetLatencyMetricsSession.cs +++ b/libs/server/Metrics/Latency/GarnetLatencyMetricsSession.cs @@ -2,7 +2,6 @@ // Licensed under the MIT license. using System; -using System.Linq; using System.Runtime.CompilerServices; using Garnet.common; diff --git a/libs/server/Objects/Set/SetObjectImpl.cs b/libs/server/Objects/Set/SetObjectImpl.cs index 68ca7226ac..89479ac0a9 100644 --- a/libs/server/Objects/Set/SetObjectImpl.cs +++ b/libs/server/Objects/Set/SetObjectImpl.cs @@ -3,7 +3,6 @@ using System; using System.Buffers; -using System.Collections.Generic; using System.Linq; using System.Security.Cryptography; using Garnet.common; diff --git a/libs/server/Resp/BasicCommands.cs b/libs/server/Resp/BasicCommands.cs index 90893478f5..2643d3f660 100644 --- a/libs/server/Resp/BasicCommands.cs +++ b/libs/server/Resp/BasicCommands.cs @@ -724,6 +724,7 @@ private bool NetworkIncrement(byte* ptr, RespCommand cmd, ref TGarne { Debug.Assert(cmd == RespCommand.INCRBY || cmd == RespCommand.DECRBY || cmd == RespCommand.INCR || cmd == RespCommand.DECR); + // Parse key argument byte* keyPtr = null; int ksize = 0; @@ -733,6 +734,8 @@ private bool NetworkIncrement(byte* ptr, RespCommand cmd, ref TGarne ArgSlice input = default; if (cmd == RespCommand.INCRBY || cmd == RespCommand.DECRBY) { + // Parse value argument + // NOTE: Parse empty strings for better error messages through storageApi.Increment byte* valPtr = null; int vsize = 0; if (!RespReadUtils.ReadPtrWithLengthHeader(ref valPtr, ref vsize, ref ptr, recvBufferPtr + bytesRead)) diff --git a/libs/server/Resp/Objects/ListCommands.cs b/libs/server/Resp/Objects/ListCommands.cs index 375004bb8b..1279189723 100644 --- a/libs/server/Resp/Objects/ListCommands.cs +++ b/libs/server/Resp/Objects/ListCommands.cs @@ -2,7 +2,6 @@ // Licensed under the MIT license. using System; -using System.Linq; using Garnet.common; using Tsavorite.core; diff --git a/libs/server/Resp/RespServerSession.cs b/libs/server/Resp/RespServerSession.cs index 4a8fafc8c7..17bbf0b6b1 100644 --- a/libs/server/Resp/RespServerSession.cs +++ b/libs/server/Resp/RespServerSession.cs @@ -8,6 +8,7 @@ using System.Runtime.InteropServices; using System.Text; using Garnet.common; +using Garnet.common.Parsing; using Garnet.networking; using Garnet.server.ACL; using Garnet.server.Auth; @@ -210,13 +211,32 @@ public override int TryConsumeMessages(byte* reqBuffer, int bytesReceived) clusterSession?.AcquireCurrentEpoch(); recvBufferPtr = reqBuffer; networkSender.GetResponseObject(); - ProcessMessages(); + + try + { + ProcessMessages(); + } + catch (RespParsingException ex) + { + logger?.LogCritical($"Aborting open session due to RESP parsing error: {ex.Message}"); + logger?.LogDebug(ex, "RespParsingException in ProcessMessages:"); + + // Forward parsing error as RESP error + while (!RespWriteUtils.WriteError($"ERR Protocol Error: {ex.Message}", ref dcurr, dend)) + SendAndReset(); + + // Send message and dispose the network sender to end the session + Send(networkSender.GetResponseObjectHead()); + networkSender.Dispose(); + } recvBufferPtr = null; } + + catch (Exception ex) { sessionMetrics?.incr_total_number_resp_server_session_exceptions(1); - logger?.LogCritical(ex, "ProcessMessages threw exception"); + logger?.LogCritical(ex, "ProcessMessages threw exception:"); // The session is no longer usable, dispose it networkSender.Dispose(); } @@ -639,38 +659,35 @@ private bool ProcessOtherCommands(RespCommand command, byte subcmd, ReadOnlySpan GetCommand(ReadOnlySpan bufSpan, out bool success) { - if (bytesRead - readHead < 6) + var ptr = recvBufferPtr + readHead; + var end = recvBufferPtr + bytesRead; + + // Try the command length + if (!RespReadUtils.ReadLengthHeader(out int length, ref ptr, end)) { success = false; return default; } - Debug.Assert(*(recvBufferPtr + readHead) == '$'); - int psize = *(recvBufferPtr + readHead + 1) - '0'; - readHead += 2; - while (*(recvBufferPtr + readHead) != '\r') - { - psize = psize * 10 + *(recvBufferPtr + readHead) - '0'; - if (bytesRead - readHead < 1) - { - success = false; - return default; - } - readHead++; - } - if (bytesRead - readHead < 2 + psize + 2) + readHead = (int)(ptr - recvBufferPtr); + + // Try to read the command value + ptr += length; + if (ptr + 2 > end) { success = false; return default; } - Debug.Assert(*(recvBufferPtr + readHead + 1) == '\n'); - var result = bufSpan.Slice(readHead + 2, psize); - Debug.Assert(*(recvBufferPtr + readHead + 2 + psize) == '\r'); - Debug.Assert(*(recvBufferPtr + readHead + 2 + psize + 1) == '\n'); + if (*(ushort*)ptr != MemoryMarshal.Read("\r\n"u8)) + { + RespParsingException.ThrowUnexpectedToken(*ptr); + } - readHead += 2 + psize + 2; + var result = bufSpan.Slice(readHead, length); + readHead += length + 2; success = true; + return result; } diff --git a/libs/server/Storage/Session/MainStore/MainStoreOps.cs b/libs/server/Storage/Session/MainStore/MainStoreOps.cs index a3a2e7eb83..de79ae37d2 100644 --- a/libs/server/Storage/Session/MainStore/MainStoreOps.cs +++ b/libs/server/Storage/Session/MainStore/MainStoreOps.cs @@ -544,7 +544,7 @@ public unsafe GarnetStatus RENAME(ArgSlice oldKeySlice, ArgSlice newKeySlice, St var memoryHandle = o.Memory.Memory.Pin(); var ptrVal = (byte*)memoryHandle.Pointer; - RespReadUtils.ReadHeaderLength(out var headerLength, ref ptrVal, ptrVal + o.Length); + RespReadUtils.ReadLengthHeader(out var headerLength, ref ptrVal, ptrVal + o.Length); var value = SpanByte.FromPinnedPointer(ptrVal, headerLength); SET(ref newKey, ref value, ref context); diff --git a/test/Garnet.test/GarnetBitmapTests.cs b/test/Garnet.test/GarnetBitmapTests.cs index 082e78bf79..ed610ea51a 100644 --- a/test/Garnet.test/GarnetBitmapTests.cs +++ b/test/Garnet.test/GarnetBitmapTests.cs @@ -762,15 +762,15 @@ public unsafe void BitmapSimpleBITOP_PCT(int bytesPerSend) { case Bitwise.And: dst = srcA & srcB & srcC; - response = lightClientRequest.SendCommandChunks("BITOP AND " + d + " " + a + " " + b + " " + " " + c, bytesPerSend); + response = lightClientRequest.SendCommandChunks("BITOP AND " + d + " " + a + " " + b + " " + c, bytesPerSend); break; case Bitwise.Or: dst = srcA | srcB | srcC; - response = lightClientRequest.SendCommandChunks("BITOP OR " + d + " " + a + " " + b + " " + " " + c, bytesPerSend); + response = lightClientRequest.SendCommandChunks("BITOP OR " + d + " " + a + " " + b + " " + c, bytesPerSend); break; case Bitwise.Xor: dst = srcA ^ srcB ^ srcC; - response = lightClientRequest.SendCommandChunks("BITOP XOR " + d + " " + a + " " + b + " " + " " + c, bytesPerSend); + response = lightClientRequest.SendCommandChunks("BITOP XOR " + d + " " + a + " " + b + " " + c, bytesPerSend); break; } diff --git a/test/Garnet.test/GarnetClientTests.cs b/test/Garnet.test/GarnetClientTests.cs index c723ba2806..2bd0f53ca4 100644 --- a/test/Garnet.test/GarnetClientTests.cs +++ b/test/Garnet.test/GarnetClientTests.cs @@ -2,7 +2,6 @@ // Licensed under the MIT license. using System; -using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading; diff --git a/test/Garnet.test/Resp/RespReadUtilsTests.cs b/test/Garnet.test/Resp/RespReadUtilsTests.cs new file mode 100644 index 0000000000..cbb8cb8be6 --- /dev/null +++ b/test/Garnet.test/Resp/RespReadUtilsTests.cs @@ -0,0 +1,291 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System.Text; +using Garnet.common; +using Garnet.common.Parsing; +using NUnit.Framework; + +namespace Garnet.test.Resp +{ + /// + /// Tests for RespReadUtils parsing functions. + /// + unsafe class RespReadUtilsTests + { + /// + /// Tests that ReadLengthHeader successfully parses valid numbers. + /// + /// Header length encoded as an ASCII string. + /// Expected parsed header length as int. + [TestCase("0", 0)] + [TestCase("-1", -1)] + [TestCase("2147483647", 2147483647)] + public static unsafe void ReadLengthHeaderTest(string text, int expected) + { + var bytes = Encoding.ASCII.GetBytes($"${text}\r\n"); + fixed (byte* ptr = bytes) + { + var start = ptr; + var end = ptr + bytes.Length; + var success = RespReadUtils.ReadLengthHeader(out var length, ref start, end, allowNull: true); + + Assert.IsTrue(success); + Assert.AreEqual(expected, length); + Assert.IsTrue(start == end); + } + } + + /// + /// Tests that ReadLengthHeader throws exceptions for invalid inputs. + /// + /// Invalid ASCII-encoded string length header (including '$'). + [TestCase("$\r\n\r\n")] // Empty input length + [TestCase("$-1\r\n")] // NULL should be disallowed + [TestCase("123\r\n")] // Missing $ + [TestCase("$-2147483648\r\n")] // Valid Int32 value but negative (not allowed) + [TestCase("$-2\r\n")] // -1 should be legal, but -2 should not be + [TestCase("$2147483648\r\n")] // Should cause an overflow + [TestCase("$123ab\r\n")] // Not a number + [TestCase("$123ab")] // Missing "\r\n" + public static unsafe void ReadLengthHeaderExceptionsTest(string text) + { + var bytes = Encoding.ASCII.GetBytes(text); + _ = Assert.Throws(() => + { + fixed (byte* ptr = bytes) + { + var start = ptr; + _ = RespReadUtils.ReadLengthHeader(out var length, ref start, ptr + bytes.Length, allowNull: false); + } + }); + } + + /// + /// Tests that ReadArrayLength successfully parses valid numbers. + /// + /// Header length encoded as an ASCII string. + /// Expected parsed header length as int. + [TestCase("0", 0)] + [TestCase("2147483647", 2147483647)] + public static unsafe void ReadArrayLengthTest(string text, int expected) + { + var bytes = Encoding.ASCII.GetBytes($"*{text}\r\n"); + fixed (byte* ptr = bytes) + { + var start = ptr; + var end = ptr + bytes.Length; + var success = RespReadUtils.ReadArrayLength(out var length, ref start, end); + + Assert.IsTrue(success); + Assert.AreEqual(expected, length); + Assert.IsTrue(start == end); + } + } + + /// + /// Tests that ReadArrayLength throws exceptions for invalid inputs. + /// + /// Invalid ASCII-encoded array length header (including '*'). + [TestCase("*\r\n\r\n")] // Empty input length + [TestCase("123\r\n")] // Missing * + [TestCase("*-2147483648\r\n")] // Valid Int32 value but negative (not allowed) + [TestCase("*-2\r\n")] // -1 should be legal, but -2 should not be + [TestCase("*2147483648\r\n")] // Should cause an overflow + [TestCase("*123ab\r\n")] // Not a number + [TestCase("*123ab")] // Missing "\r\n" + public static unsafe void ReadArrayLengthExceptionsTest(string text) + { + var bytes = Encoding.ASCII.GetBytes(text); + _ = Assert.Throws(() => + { + fixed (byte* ptr = bytes) + { + var start = ptr; + _ = RespReadUtils.ReadArrayLength(out var length, ref start, ptr + bytes.Length); + } + }); + } + + /// + /// Tests that ReadIntWithLengthHeader successfully parses valid integers. + /// + /// Int encoded as an ASCII string. + /// Expected parsed value. + [TestCase("0", 0)] + [TestCase("-2147483648", -2147483648)] + [TestCase("2147483647", 2147483647)] + public static unsafe void ReadIntWithLengthHeaderTest(string text, int expected) + { + var bytes = Encoding.ASCII.GetBytes($"${text.Length}\r\n{text}\r\n"); + fixed (byte* ptr = bytes) + { + var start = ptr; + var end = ptr + bytes.Length; + var success = RespReadUtils.ReadIntWithLengthHeader(out var length, ref start, end); + + Assert.IsTrue(success); + Assert.AreEqual(expected, length); + Assert.IsTrue(start == end); + } + } + + /// + /// Tests that ReadIntWithLengthHeader throws exceptions for invalid inputs. + /// + /// Invalid ASCII-encoded input number. + [TestCase("2147483648")] // Should cause overflow + [TestCase("-2147483649")] // Should cause overflow + [TestCase("123abc")] // Not a number + [TestCase("abc121cba")] // Not a number + public static unsafe void ReadIntWithLengthHeaderExceptionsTest(string text) + { + var bytes = Encoding.ASCII.GetBytes($"${text.Length}\r\n{text}\r\n"); + + _ = Assert.Throws(() => + { + fixed (byte* ptr = bytes) + { + var start = ptr; + _ = RespReadUtils.ReadIntWithLengthHeader(out var length, ref start, ptr + bytes.Length); + } + }); + } + + /// + /// Tests that ReadLongWithLengthHeader successfully parses valid longs. + /// + /// Long int encoded as an ASCII string. + /// Expected parsed value. + [TestCase("0", 0L)] + [TestCase("-9223372036854775808", -9223372036854775808L)] + [TestCase("9223372036854775807", 9223372036854775807L)] + public static unsafe void ReadLongWithLengthHeaderTest(string text, long expected) + { + var bytes = Encoding.ASCII.GetBytes($"${text.Length}\r\n{text}\r\n"); + fixed (byte* ptr = bytes) + { + var start = ptr; + var end = ptr + bytes.Length; + var success = RespReadUtils.ReadLongWithLengthHeader(out var length, ref start, end); + + Assert.IsTrue(success); + Assert.AreEqual(expected, length); + Assert.IsTrue(start == end); + } + } + + /// + /// Tests that ReadLongWithLengthHeader throws exceptions for invalid inputs. + /// + /// Invalid ASCII-encoded input number. + [TestCase("9223372036854775808")] // Should cause overflow + [TestCase("-9223372036854775809")] // Should cause overflow + [TestCase("10000000000000000000")] // Should cause overflow + [TestCase("123abc")] // Not a number + [TestCase("abc121cba")] // Not a number + public static unsafe void ReadLongWithLengthHeaderExceptionsTest(string text) + { + var bytes = Encoding.ASCII.GetBytes($"${text.Length}\r\n{text}\r\n"); + + _ = Assert.Throws(() => + { + fixed (byte* ptr = bytes) + { + var start = ptr; + _ = RespReadUtils.ReadLongWithLengthHeader(out var length, ref start, ptr + bytes.Length); + } + }); + } + + /// + /// Tests that ReadULongWithLengthHeader successfully parses valid ulong integers. + /// + /// Unsigned long int encoded as an ASCII string. + /// Expected parsed value. + [TestCase("0", 0UL)] + [TestCase("18446744073709551615", 18446744073709551615UL)] + public static unsafe void ReadULongWithLengthHeaderTest(string text, ulong expected) + { + var bytes = Encoding.ASCII.GetBytes($"${text.Length}\r\n{text}\r\n"); + fixed (byte* ptr = bytes) + { + var start = ptr; + var end = ptr + bytes.Length; + var success = RespReadUtils.ReadULongWithLengthHeader(out var length, ref start, end); + + Assert.IsTrue(success); + Assert.AreEqual(expected, length); + Assert.IsTrue(start == end); + } + } + + /// + /// Tests that ReadULongWithLengthHeader throws exceptions for invalid inputs. + /// + /// Invalid ASCII-encoded input number. + [TestCase("18446744073709551616")] // Should cause overflow + [TestCase("-1")] // Negative numbers are not allowed + [TestCase("123abc")] // Not a number + [TestCase("abc121cba")] // Not a number + public static unsafe void ReadULongWithLengthHeaderExceptionsTest(string text) + { + var bytes = Encoding.ASCII.GetBytes($"${text.Length}\r\n{text}\r\n"); + + _ = Assert.Throws(() => + { + fixed (byte* ptr = bytes) + { + var start = ptr; + _ = RespReadUtils.ReadULongWithLengthHeader(out var length, ref start, ptr + bytes.Length); + } + }); + } + + /// + /// Tests that ReadPtrWithLengthHeader successfully parses simple strings. + /// + /// Input ASCII string. + [TestCase("test")] + [TestCase("")] + public static unsafe void ReadPtrWithLengthHeaderTest(string text) + { + var bytes = Encoding.ASCII.GetBytes($"${text.Length}\r\n{text}\r\n"); + fixed (byte* ptr = bytes) + { + byte* result = null; + var length = -1; + var start = ptr; + var end = ptr + bytes.Length; + var success = RespReadUtils.ReadPtrWithLengthHeader(ref result, ref length, ref start, end); + + Assert.IsTrue(success); + Assert.IsTrue(result != null); + Assert.IsTrue(start == end); + Assert.IsTrue(length == text.Length); + } + } + + /// + /// Tests that ReadBoolWithLengthHeader successfully parses valid inputs. + /// + /// Int encoded as an ASCII string. + /// Expected parsed value. + [TestCase("1", true)] + [TestCase("0", false)] + public static unsafe void ReadBoolWithLengthHeaderTest(string text, bool expected) + { + var bytes = Encoding.ASCII.GetBytes($"${text.Length}\r\n{text}\r\n"); + fixed (byte* ptr = bytes) + { + var start = ptr; + var end = ptr + bytes.Length; + var success = RespReadUtils.ReadBoolWithLengthHeader(out var result, ref start, end); + + Assert.IsTrue(success); + Assert.AreEqual(expected, result); + Assert.IsTrue(start == end); + } + } + } +} \ No newline at end of file diff --git a/test/Garnet.test/RespTests.cs b/test/Garnet.test/RespTests.cs index 35ae5267a6..09d97e53bc 100644 --- a/test/Garnet.test/RespTests.cs +++ b/test/Garnet.test/RespTests.cs @@ -657,22 +657,23 @@ public void SimpleIncrementInvalidValue(RespCommand cmd, bool initialize) var db = redis.GetDatabase(0); string[] values = ["", "7 3", "02+(34", "笑い男", "01", "-01", "7ab"]; - foreach (var value in values) + for (var i = 0; i < values.Length; i++) { + var key = $"key{i}"; var exception = false; if (initialize) { - var resp = db.StringSet(value, value); + var resp = db.StringSet(key, values[i]); Assert.AreEqual(true, resp); } try { _ = cmd switch { - RespCommand.INCR => db.StringIncrement(value), - RespCommand.DECR => db.StringDecrement(value), - RespCommand.INCRBY => (initialize ? db.StringIncrement(value, 10L) : (long)db.Execute("INCRBY", [value, value])), - RespCommand.DECRBY => (initialize ? db.StringDecrement(value, 10L) : (long)db.Execute("DECRBY", [value, value])), + RespCommand.INCR => db.StringIncrement(key), + RespCommand.DECR => db.StringDecrement(key), + RespCommand.INCRBY => initialize ? db.StringIncrement(key, 10L) : (long)db.Execute("INCRBY", [key, values[i]]), + RespCommand.DECRBY => initialize ? db.StringDecrement(key, 10L) : (long)db.Execute("DECRBY", [key, values[i]]), _ => throw new Exception($"Command {cmd} not supported!"), }; }