diff --git a/Directory.Packages.props b/Directory.Packages.props index d674d273f0..0719a40ed8 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -8,7 +8,7 @@ - + diff --git a/benchmark/BDN.benchmark/BDN.benchmark.csproj b/benchmark/BDN.benchmark/BDN.benchmark.csproj index 60753023ab..525b4edc72 100644 --- a/benchmark/BDN.benchmark/BDN.benchmark.csproj +++ b/benchmark/BDN.benchmark/BDN.benchmark.csproj @@ -10,7 +10,7 @@ - + diff --git a/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs b/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs new file mode 100644 index 0000000000..ecd5002bfb --- /dev/null +++ b/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs @@ -0,0 +1,220 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using BenchmarkDotNet.Attributes; +using Embedded.perftest; +using Garnet.server; + +namespace BDN.benchmark.Lua +{ + /// + /// Benchmark for non-script running operations in LuaRunner + /// + [MemoryDiagnoser] + public unsafe class LuaRunnerOperations + { + private const string SmallScript = "return nil"; + + private const string LargeScript = @" +-- based on a real UDF, with some extraneous ops removed + +local userKey = KEYS[1] +local identifier = ARGV[1] +local currentTime = ARGV[2] +local newSession = ARGV[3] + +local userKeyValue = redis.call(""GET"", userKey) + +local updatedValue = nil +local returnValue = -1 + +if userKeyValue then +-- needs to be updated + + local oldestEntry = nil + local oldestEntryUpdateTime = nil + + local match = nil + + local entryCount = 0 + +-- loop over each entry, looking for one to update + for entry in string.gmatch(userKeyValue, ""([^%|]+)"") do + entryCount = entryCount + 1 + + local entryIdentifier = nil + local entrySessionNumber = -1 + local entryRequestCount = -1 + local entryLastSessionUpdateTime = -1 + + local ix = 0 + for part in string.gmatch(entry, ""([^:]+)"") do + if ix == 0 then + entryIdentifier = part + elseif ix == 1 then + entrySessionNumber = tonumber(part) + elseif ix == 2 then + entryRequestCount = tonumber(part) + elseif ix == 3 then + entryLastSessionUpdateTime = tonumber(part) + else +-- malformed, too many parts + return -1 + end + + ix = ix + 1 + end + + if ix ~= 4 then +-- malformed, too few parts + return -2 + end + + if entryIdentifier == identifier then +-- found the one to update + local updatedEntry = nil + + if tonumber(newSession) == 1 then + local updatedSessionNumber = entrySessionNumber + 1 + updatedEntry = entryIdentifier .. "":"" .. tostring(updatedSessionNumber) .. "":1:"" .. tostring(currentTime) + returnValue = 3 + else + local updatedRequestCount = entryRequestCount + 1 + updatedEntry = entryIdentifier .. "":"" .. tostring(entrySessionNumber) .. "":"" .. tostring(updatedRequestCount) .. "":"" .. tostring(currentTime) + returnValue = 2 + end + +-- have to escape the replacement, since Lua doesn't have a literal replace :/ + local escapedEntry = string.gsub(entry, ""%p"", ""%%%1"") + updatedValue = string.gsub(userKeyValue, escapedEntry, updatedEntry) + + break + end + + if oldestEntryUpdateTime == nil or oldestEntryUpdateTime > entryLastSessionUpdateTime then +-- remember the oldest entry, so we can replace it if needed + oldestEntry = entry + oldestEntryUpdateTime = entryLastSessionUpdateTime + end + end + + if updatedValue == nil then +-- we didn't update an existing value, so we need to add it + + local newEntry = identifier .. "":1:1:"" .. tostring(currentTime) + + if entryCount < 20 then +-- there's room, just append it + updatedValue = userKeyValue .. ""|"" .. newEntry + returnValue = 4 + else +-- there isn't room, replace the LRU entry + +-- have to escape the replacement, since Lua doesn't have a literal replace :/ + local escapedOldestEntry = string.gsub(oldestEntry, ""%p"", ""%%%1"") + + updatedValue = string.gsub(userKeyValue, escapedOldestEntry, newEntry) + returnValue = 5 + end + end +else +-- needs to be created + updatedValue = identifier .. "":1:1:"" .. tostring(currentTime) + + returnValue = 1 +end + +redis.call(""SET"", userKey, updatedValue) + +return returnValue +"; + + /// + /// Lua parameters + /// + [ParamsSource(nameof(LuaParamsProvider))] + public LuaParams Params { get; set; } + + /// + /// Lua parameters provider + /// + public IEnumerable LuaParamsProvider() + { + yield return new(); + } + + private EmbeddedRespServer server; + private RespServerSession session; + + private LuaRunner paramsRunner; + + private LuaRunner smallCompileRunner; + private LuaRunner largeCompileRunner; + + [GlobalSetup] + public void GlobalSetup() + { + server = new EmbeddedRespServer(new GarnetServerOptions() { EnableLua = true, QuietMode = true }); + + session = server.GetRespSession(); + paramsRunner = new LuaRunner("return nil"); + + smallCompileRunner = new LuaRunner(SmallScript); + largeCompileRunner = new LuaRunner(LargeScript); + } + + [GlobalCleanup] + public void GlobalCleanup() + { + session.Dispose(); + server.Dispose(); + paramsRunner.Dispose(); + } + + [Benchmark] + public void ResetParametersSmall() + { + // First force up + paramsRunner.ResetParameters(1, 1); + + // Then require a small amount of clearing (1 key, 1 arg) + paramsRunner.ResetParameters(0, 0); + } + + [Benchmark] + public void ResetParametersLarge() + { + // First force up + paramsRunner.ResetParameters(10, 10); + + // Then require a large amount of clearing (10 keys, 10 args) + paramsRunner.ResetParameters(0, 0); + } + + [Benchmark] + public void ConstructSmall() + { + using var runner = new LuaRunner(SmallScript); + } + + [Benchmark] + public void ConstructLarge() + { + using var runner = new LuaRunner(LargeScript); + } + + [Benchmark] + public void CompileForSessionSmall() + { + smallCompileRunner.ResetCompilation(); + smallCompileRunner.CompileForSession(session); + } + + [Benchmark] + public void CompileForSessionLarge() + { + largeCompileRunner.ResetCompilation(); + largeCompileRunner.CompileForSession(session); + } + } +} \ No newline at end of file diff --git a/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs b/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs new file mode 100644 index 0000000000..c804e0d9ac --- /dev/null +++ b/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs @@ -0,0 +1,153 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using BenchmarkDotNet.Attributes; +using Embedded.perftest; +using Garnet.common; +using Garnet.server; +using Garnet.server.Auth; + +namespace BDN.benchmark.Lua +{ + [MemoryDiagnoser] + public class LuaScriptCacheOperations + { + /// + /// Lua parameters + /// + [ParamsSource(nameof(LuaParamsProvider))] + public LuaParams Params { get; set; } + + /// + /// Lua parameters provider + /// + public IEnumerable LuaParamsProvider() + { + yield return new(); + } + + private EmbeddedRespServer server; + private StoreWrapper storeWrapper; + private SessionScriptCache sessionScriptCache; + private RespServerSession session; + + private byte[] outerHitDigest; + private byte[] innerHitDigest; + private byte[] missDigest; + + [GlobalSetup] + public void GlobalSetup() + { + server = new EmbeddedRespServer(new GarnetServerOptions() { EnableLua = true, QuietMode = true }); + storeWrapper = server.StoreWrapper; + sessionScriptCache = new SessionScriptCache(storeWrapper, new GarnetNoAuthAuthenticator()); + session = server.GetRespSession(); + + outerHitDigest = GC.AllocateUninitializedArray(SessionScriptCache.SHA1Len, pinned: true); + sessionScriptCache.GetScriptDigest("return 1"u8, outerHitDigest); + if (!storeWrapper.storeScriptCache.TryAdd(new(outerHitDigest), "return 1"u8.ToArray())) + { + throw new InvalidOperationException("Should have been able to load into global cache"); + } + + innerHitDigest = GC.AllocateUninitializedArray(SessionScriptCache.SHA1Len, pinned: true); + sessionScriptCache.GetScriptDigest("return 1 + 1"u8, innerHitDigest); + if (!storeWrapper.storeScriptCache.TryAdd(new(innerHitDigest), "return 1 + 1"u8.ToArray())) + { + throw new InvalidOperationException("Should have been able to load into global cache"); + } + + missDigest = GC.AllocateUninitializedArray(SessionScriptCache.SHA1Len, pinned: true); + sessionScriptCache.GetScriptDigest("foobar"u8, missDigest); + } + + [GlobalCleanup] + public void GlobalCleanup() + { + session?.Dispose(); + server?.Dispose(); + } + + [IterationSetup] + public void IterationSetup() + { + // Force lookup to do work + sessionScriptCache.Clear(); + + // Make outer hit available for every iteration + if (!sessionScriptCache.TryLoad(session, "return 1"u8, new(outerHitDigest), out _, out _, out var error)) + { + throw new InvalidOperationException($"Should have been able to load: {error}"); + } + } + + [Benchmark] + public void LookupHit() + { + _ = sessionScriptCache.TryGetFromDigest(new(outerHitDigest), out _); + } + + [Benchmark] + public void LookupMiss() + { + _ = sessionScriptCache.TryGetFromDigest(new(missDigest), out _); + } + + [Benchmark] + public void LoadOuterHit() + { + // First if returns true + // + // This is the common case + LoadScript(outerHitDigest); + } + + [Benchmark] + public void LoadInnerHit() + { + // First if returns false, second if returns true + // + // This is expected, but rare + LoadScript(innerHitDigest); + } + + [Benchmark] + public void LoadMiss() + { + // First if returns false, second if returns false + // + // This is extremely unlikely, basically implies an error on the client + LoadScript(missDigest); + } + + [Benchmark] + public void Digest() + { + Span digest = stackalloc byte[SessionScriptCache.SHA1Len]; + sessionScriptCache.GetScriptDigest("return 1 + redis.call('GET', KEYS[1])"u8, digest); + } + + /// + /// The moral equivalent to our cache load operation. + /// + private void LoadScript(Span digest) + { + AsciiUtils.ToLowerInPlace(digest); + + var digestKey = new ScriptHashKey(digest); + + if (!sessionScriptCache.TryGetFromDigest(digestKey, out var runner)) + { + if (storeWrapper.storeScriptCache.TryGetValue(digestKey, out var source)) + { + if (!sessionScriptCache.TryLoad(session, source, digestKey, out runner, out _, out var error)) + { + // TryLoad will have written an error out, it any + + _ = storeWrapper.storeScriptCache.TryRemove(digestKey, out _); + } + } + } + } + } +} \ No newline at end of file diff --git a/benchmark/BDN.benchmark/Lua/LuaScripts.cs b/benchmark/BDN.benchmark/Lua/LuaScripts.cs index 5c6060d2e6..3cff4b30e3 100644 --- a/benchmark/BDN.benchmark/Lua/LuaScripts.cs +++ b/benchmark/BDN.benchmark/Lua/LuaScripts.cs @@ -2,7 +2,6 @@ // Licensed under the MIT license. using BenchmarkDotNet.Attributes; -using BenchmarkDotNet.Columns; using Garnet.server; namespace BDN.benchmark.Lua @@ -11,7 +10,6 @@ namespace BDN.benchmark.Lua /// Benchmark for Lua /// [MemoryDiagnoser] - [HideColumns(Column.Gen0)] public unsafe class LuaScripts { /// @@ -35,13 +33,13 @@ public IEnumerable LuaParamsProvider() public void GlobalSetup() { r1 = new LuaRunner("return"); - r1.Compile(); + r1.CompileForRunner(); r2 = new LuaRunner("return 1 + 1"); - r2.Compile(); + r2.CompileForRunner(); r3 = new LuaRunner("return KEYS[1]"); - r3.Compile(); + r3.CompileForRunner(); r4 = new LuaRunner("return redis.call(KEYS[1])"); - r4.Compile(); + r4.CompileForRunner(); } [GlobalCleanup] @@ -55,18 +53,18 @@ public void GlobalCleanup() [Benchmark] public void Script1() - => r1.Run(); + => r1.RunForRunner(); [Benchmark] public void Script2() - => r2.Run(); + => r2.RunForRunner(); [Benchmark] public void Script3() - => r3.Run(keys, null); + => r3.RunForRunner(keys, null); [Benchmark] public void Script4() - => r4.Run(keys, null); + => r4.RunForRunner(keys, null); } } \ No newline at end of file diff --git a/benchmark/BDN.benchmark/Operations/OperationsBase.cs b/benchmark/BDN.benchmark/Operations/OperationsBase.cs index 8d58631fe9..be4fa09608 100644 --- a/benchmark/BDN.benchmark/Operations/OperationsBase.cs +++ b/benchmark/BDN.benchmark/Operations/OperationsBase.cs @@ -39,7 +39,7 @@ public IEnumerable OperationParamsProvider() /// 25 us = 4 Mops/sec /// 100 us = 1 Mops/sec /// - const int batchSize = 100; + internal const int batchSize = 100; internal EmbeddedRespServer server; internal RespServerSession session; diff --git a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs index f068a7c0ac..26436bfea9 100644 --- a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs +++ b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +using System.Runtime.CompilerServices; +using System.Security.Cryptography; +using System.Text; using BenchmarkDotNet.Attributes; namespace BDN.benchmark.Operations @@ -11,6 +14,124 @@ namespace BDN.benchmark.Operations [MemoryDiagnoser] public unsafe class ScriptOperations : OperationsBase { + // Small script that does 1 operation and no logic + const string SmallScriptText = @"return redis.call('GET', KEYS[1]);"; + + // Large script that does 2 operations and lots of logic + const string LargeScriptText = @" +-- based on a real UDF, with some extraneous ops removed + +local userKey = KEYS[1] +local identifier = ARGV[1] +local currentTime = ARGV[2] +local newSession = ARGV[3] + +local userKeyValue = redis.call(""GET"", userKey) + +local updatedValue = nil +local returnValue = -1 + +if userKeyValue then +-- needs to be updated + + local oldestEntry = nil + local oldestEntryUpdateTime = nil + + local match = nil + + local entryCount = 0 + +-- loop over each entry, looking for one to update + for entry in string.gmatch(userKeyValue, ""([^%|]+)"") do + entryCount = entryCount + 1 + + local entryIdentifier = nil + local entrySessionNumber = -1 + local entryRequestCount = -1 + local entryLastSessionUpdateTime = -1 + + local ix = 0 + for part in string.gmatch(entry, ""([^:]+)"") do + if ix == 0 then + entryIdentifier = part + elseif ix == 1 then + entrySessionNumber = tonumber(part) + elseif ix == 2 then + entryRequestCount = tonumber(part) + elseif ix == 3 then + entryLastSessionUpdateTime = tonumber(part) + else +-- malformed, too many parts + return -1 + end + + ix = ix + 1 + end + + if ix ~= 4 then +-- malformed, too few parts + return -2 + end + + if entryIdentifier == identifier then +-- found the one to update + local updatedEntry = nil + + if tonumber(newSession) == 1 then + local updatedSessionNumber = entrySessionNumber + 1 + updatedEntry = entryIdentifier .. "":"" .. tostring(updatedSessionNumber) .. "":1:"" .. tostring(currentTime) + returnValue = 3 + else + local updatedRequestCount = entryRequestCount + 1 + updatedEntry = entryIdentifier .. "":"" .. tostring(entrySessionNumber) .. "":"" .. tostring(updatedRequestCount) .. "":"" .. tostring(currentTime) + returnValue = 2 + end + +-- have to escape the replacement, since Lua doesn't have a literal replace :/ + local escapedEntry = string.gsub(entry, ""%p"", ""%%%1"") + updatedValue = string.gsub(userKeyValue, escapedEntry, updatedEntry) + + break + end + + if oldestEntryUpdateTime == nil or oldestEntryUpdateTime > entryLastSessionUpdateTime then +-- remember the oldest entry, so we can replace it if needed + oldestEntry = entry + oldestEntryUpdateTime = entryLastSessionUpdateTime + end + end + + if updatedValue == nil then +-- we didn't update an existing value, so we need to add it + + local newEntry = identifier .. "":1:1:"" .. tostring(currentTime) + + if entryCount < 20 then +-- there's room, just append it + updatedValue = userKeyValue .. ""|"" .. newEntry + returnValue = 4 + else +-- there isn't room, replace the LRU entry + +-- have to escape the replacement, since Lua doesn't have a literal replace :/ + local escapedOldestEntry = string.gsub(oldestEntry, ""%p"", ""%%%1"") + + updatedValue = string.gsub(userKeyValue, escapedOldestEntry, newEntry) + returnValue = 5 + end + end +else +-- needs to be created + updatedValue = identifier .. "":1:1:"" .. tostring(currentTime) + + returnValue = 1 +end + +redis.call(""SET"", userKey, updatedValue) + +return returnValue +"; + static ReadOnlySpan SCRIPT_LOAD => "*3\r\n$6\r\nSCRIPT\r\n$4\r\nLOAD\r\n$8\r\nreturn 1\r\n"u8; byte[] scriptLoadRequestBuffer; byte* scriptLoadRequestBufferPointer; @@ -33,6 +154,16 @@ public unsafe class ScriptOperations : OperationsBase byte[] evalShaRequestBuffer; byte* evalShaRequestBufferPointer; + byte[] evalShaSmallScriptBuffer; + byte* evalShaSmallScriptBufferPointer; + + byte[] evalShaLargeScriptBuffer; + byte* evalShaLargeScriptBufferPointer; + + static ReadOnlySpan ARRAY_RETURN => "*3\r\n$4\r\nEVAL\r\n$22\r\nreturn {1, 2, 3, 4, 5}\r\n$1\r\n0\r\n"u8; + byte[] arrayReturnRequestBuffer; + byte* arrayReturnRequestBufferPointer; + public override void GlobalSetup() { base.GlobalSetup(); @@ -50,6 +181,59 @@ public override void GlobalSetup() SetupOperation(ref evalRequestBuffer, ref evalRequestBufferPointer, EVAL); SetupOperation(ref evalShaRequestBuffer, ref evalShaRequestBufferPointer, EVALSHA); + + SetupOperation(ref arrayReturnRequestBuffer, ref arrayReturnRequestBufferPointer, ARRAY_RETURN); + + // Setup small script + var loadSmallScript = $"*3\r\n$6\r\nSCRIPT\r\n$4\r\nLOAD\r\n${SmallScriptText.Length}\r\n{SmallScriptText}\r\n"; + var loadSmallScriptBytes = Encoding.UTF8.GetBytes(loadSmallScript); + fixed (byte* loadPtr = loadSmallScriptBytes) + { + _ = session.TryConsumeMessages(loadPtr, loadSmallScriptBytes.Length); + } + + var smallScriptHash = string.Join("", SHA1.HashData(Encoding.UTF8.GetBytes(SmallScriptText)).Select(static x => x.ToString("x2"))); + var evalShaSmallScript = $"*4\r\n$7\r\nEVALSHA\r\n$40\r\n{smallScriptHash}\r\n$1\r\n1\r\n$3\r\nfoo\r\n"; + evalShaSmallScriptBuffer = GC.AllocateUninitializedArray(evalShaSmallScript.Length * batchSize, pinned: true); + for (var i = 0; i < batchSize; i++) + { + var start = i * evalShaSmallScript.Length; + Encoding.UTF8.GetBytes(evalShaSmallScript, evalShaSmallScriptBuffer.AsSpan().Slice(start, evalShaSmallScript.Length)); + } + evalShaSmallScriptBufferPointer = (byte*)Unsafe.AsPointer(ref evalShaSmallScriptBuffer[0]); + + // Setup large script + var loadLargeScript = $"*3\r\n$6\r\nSCRIPT\r\n$4\r\nLOAD\r\n${LargeScriptText.Length}\r\n{LargeScriptText}\r\n"; + var loadLargeScriptBytes = Encoding.UTF8.GetBytes(loadLargeScript); + fixed (byte* loadPtr = loadLargeScriptBytes) + { + _ = session.TryConsumeMessages(loadPtr, loadLargeScriptBytes.Length); + } + + var largeScriptHash = string.Join("", SHA1.HashData(Encoding.UTF8.GetBytes(LargeScriptText)).Select(static x => x.ToString("x2"))); + var largeScriptEvals = new List(); + for (var i = 0; i < batchSize; i++) + { + var evalShaLargeScript = $"*7\r\n$7\r\nEVALSHA\r\n$40\r\n{largeScriptHash}\r\n$1\r\n1\r\n$5\r\nhello\r\n"; + + var id = Guid.NewGuid().ToString(); + evalShaLargeScript += $"${id.Length}\r\n"; + evalShaLargeScript += $"{id}\r\n"; + + var time = (i * 10).ToString(); + evalShaLargeScript += $"${time.Length}\r\n"; + evalShaLargeScript += $"{time}\r\n"; + + var newSession = i % 2; + evalShaLargeScript += "$1\r\n"; + evalShaLargeScript += $"{newSession}\r\n"; + + var asBytes = Encoding.UTF8.GetBytes(evalShaLargeScript); + largeScriptEvals.AddRange(asBytes); + } + evalShaLargeScriptBuffer = GC.AllocateUninitializedArray(largeScriptEvals.Count, pinned: true); + largeScriptEvals.CopyTo(evalShaLargeScriptBuffer); + evalShaLargeScriptBufferPointer = (byte*)Unsafe.AsPointer(ref evalShaLargeScriptBuffer[0]); } [Benchmark] @@ -81,5 +265,23 @@ public void EvalSha() { _ = session.TryConsumeMessages(evalShaRequestBufferPointer, evalShaRequestBuffer.Length); } + + [Benchmark] + public void SmallScript() + { + _ = session.TryConsumeMessages(evalShaSmallScriptBufferPointer, evalShaSmallScriptBuffer.Length); + } + + [Benchmark] + public void LargeScript() + { + _ = session.TryConsumeMessages(evalShaLargeScriptBufferPointer, evalShaLargeScriptBuffer.Length); + } + + [Benchmark] + public void ArrayReturn() + { + _ = session.TryConsumeMessages(arrayReturnRequestBufferPointer, arrayReturnRequestBuffer.Length); + } } } \ No newline at end of file diff --git a/libs/common/RespReadUtils.cs b/libs/common/RespReadUtils.cs index 8b93be12e0..a82ed7ee87 100644 --- a/libs/common/RespReadUtils.cs +++ b/libs/common/RespReadUtils.cs @@ -957,50 +957,6 @@ public static bool ReadStringArrayWithLengthHeader(out string[] result, ref byte return true; } - /// - /// Read string array with length header - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool ReadStringArrayResponseWithLengthHeader(out string[] result, ref byte* ptr, byte* end) - { - result = null; - - // Parse RESP array header - if (!ReadSignedArrayLength(out var length, ref ptr, end)) - { - return false; - } - - if (length < 0) - { - // NULL value ('*-1\r\n') - return true; - } - - // Parse individual strings in the array - result = new string[length]; - for (var i = 0; i < length; i++) - { - if (*ptr == '$') - { - if (!ReadStringResponseWithLengthHeader(out result[i], ref ptr, end)) - return false; - } - else if (*ptr == '+') - { - if (!ReadSimpleString(out result[i], ref ptr, end)) - return false; - } - else - { - if (!ReadIntegerAsString(out result[i], ref ptr, end)) - return false; - } - } - - return true; - } - /// /// Read double with length header /// diff --git a/libs/host/Garnet.host.csproj b/libs/host/Garnet.host.csproj index a4d034dd85..1c1259fec3 100644 --- a/libs/host/Garnet.host.csproj +++ b/libs/host/Garnet.host.csproj @@ -29,7 +29,7 @@ - + diff --git a/libs/server/ArgSlice/ScratchBufferManager.cs b/libs/server/ArgSlice/ScratchBufferManager.cs index 68861d3985..e2106bfaa7 100644 --- a/libs/server/ArgSlice/ScratchBufferManager.cs +++ b/libs/server/ArgSlice/ScratchBufferManager.cs @@ -7,7 +7,6 @@ using System.Runtime.CompilerServices; using System.Text; using Garnet.common; -using NLua; namespace Garnet.server { @@ -43,6 +42,12 @@ public ScratchBufferManager() /// public void Reset() => scratchBufferOffset = 0; + /// + /// Return the full buffer managed by this . + /// + public Span FullBuffer() + => scratchBuffer; + /// /// Rewind (pop) the last entry of scratch buffer (rewinding the current scratch buffer offset), /// if it contains the given ArgSlice @@ -218,95 +223,62 @@ public ArgSlice FormatScratch(int headerSize, ReadOnlySpan arg) } /// - /// Format specified command with arguments, as a RESP command. Lua state - /// can be specified to handle Lua tables as arguments. + /// Start a RESP array to hold a command and arguments. + /// + /// Fill it with calls to and/or . /// - public ArgSlice FormatCommandAsResp(string cmd, object[] args, Lua state) + public void StartCommand(ReadOnlySpan cmd, int argCount) { if (scratchBuffer == null) ExpandScratchBuffer(64); - scratchBufferOffset += 10; // Reserve space for the array length if it is larger than expected - int commandStartOffset = scratchBufferOffset; - byte* ptr = scratchBufferHead + scratchBufferOffset; + var ptr = scratchBufferHead + scratchBufferOffset; - while (!RespWriteUtils.WriteArrayLength(args.Length + 1, ref ptr, scratchBufferHead + scratchBuffer.Length)) + while (!RespWriteUtils.WriteArrayLength(argCount + 1, ref ptr, scratchBufferHead + scratchBuffer.Length)) { ExpandScratchBuffer(scratchBuffer.Length + 1); ptr = scratchBufferHead + scratchBufferOffset; } scratchBufferOffset = (int)(ptr - scratchBufferHead); - while (!RespWriteUtils.WriteAsciiBulkString(cmd, ref ptr, scratchBufferHead + scratchBuffer.Length)) + while (!RespWriteUtils.WriteBulkString(cmd, ref ptr, scratchBufferHead + scratchBuffer.Length)) { ExpandScratchBuffer(scratchBuffer.Length + 1); ptr = scratchBufferHead + scratchBufferOffset; } scratchBufferOffset = (int)(ptr - scratchBufferHead); + } - int count = 1; - foreach (var item in args) + /// + /// Use to fill a RESP array with arguments after a call to . + /// + public void WriteNullArgument() + { + var ptr = scratchBufferHead + scratchBufferOffset; + + while (!RespWriteUtils.WriteNull(ref ptr, scratchBufferHead + scratchBuffer.Length)) { - if (item is string str) - { - count++; - while (!RespWriteUtils.WriteAsciiBulkString(str, ref ptr, scratchBufferHead + scratchBuffer.Length)) - { - ExpandScratchBuffer(scratchBuffer.Length + 1); - ptr = scratchBufferHead + scratchBufferOffset; - } - scratchBufferOffset = (int)(ptr - scratchBufferHead); - } - else if (item is LuaTable t) - { - var d = state.GetTableDict(t); - foreach (var value in d.Values) - { - count++; - while (!RespWriteUtils.WriteAsciiBulkString(Convert.ToString(value), ref ptr, scratchBufferHead + scratchBuffer.Length)) - { - ExpandScratchBuffer(scratchBuffer.Length + 1); - ptr = scratchBufferHead + scratchBufferOffset; - } - scratchBufferOffset = (int)(ptr - scratchBufferHead); - } - t.Dispose(); - } - else if (item is long i) - { - count++; - while (!RespWriteUtils.WriteIntegerAsBulkString((int)i, ref ptr, scratchBufferHead + scratchBuffer.Length)) - { - ExpandScratchBuffer(scratchBuffer.Length + 1); - ptr = scratchBufferHead + scratchBufferOffset; - } - scratchBufferOffset = (int)(ptr - scratchBufferHead); - } - else - { - count++; - while (!RespWriteUtils.WriteAsciiBulkString(Convert.ToString(item), ref ptr, scratchBufferHead + scratchBuffer.Length)) - { - ExpandScratchBuffer(scratchBuffer.Length + 1); - ptr = scratchBufferHead + scratchBufferOffset; - } - scratchBufferOffset = (int)(ptr - scratchBufferHead); - } + ExpandScratchBuffer(scratchBuffer.Length + 1); + ptr = scratchBufferHead + scratchBufferOffset; } - if (count != args.Length + 1) + + scratchBufferOffset = (int)(ptr - scratchBufferHead); + } + + /// + /// Use to fill a RESP array with arguments after a call to . + /// + public void WriteArgument(ReadOnlySpan arg) + { + var ptr = scratchBufferHead + scratchBufferOffset; + + while (!RespWriteUtils.WriteBulkString(arg, ref ptr, scratchBufferHead + scratchBuffer.Length)) { - int extraSpace = NumUtils.NumDigits(count) - NumUtils.NumDigits(args.Length + 1); - if (commandStartOffset < extraSpace) - throw new InvalidOperationException("Invalid number of arguments"); - commandStartOffset -= extraSpace; - var head = scratchBufferHead + commandStartOffset; - // There should be space as we have reserved it - _ = RespWriteUtils.WriteArrayLength(count, ref head, scratchBufferHead + scratchBuffer.Length); + ExpandScratchBuffer(scratchBuffer.Length + 1); + ptr = scratchBufferHead + scratchBufferOffset; } - var retVal = new ArgSlice(scratchBufferHead + commandStartOffset, scratchBufferOffset - commandStartOffset); - Debug.Assert(scratchBufferOffset <= scratchBuffer.Length); - return retVal; + scratchBufferOffset = (int)(ptr - scratchBufferHead); } /// @@ -351,5 +323,20 @@ public ArgSlice GetSliceFromTail(int length) { return new ArgSlice(scratchBufferHead + scratchBufferOffset - length, length); } + + /// + /// Force backing buffer to grow. + /// + public void GrowBuffer() + { + if (scratchBuffer == null) + { + ExpandScratchBuffer(64); + } + else + { + ExpandScratchBuffer(scratchBuffer.Length + 1); + } + } } } \ No newline at end of file diff --git a/libs/server/Garnet.server.csproj b/libs/server/Garnet.server.csproj index e3abc5b5c2..04f91c2909 100644 --- a/libs/server/Garnet.server.csproj +++ b/libs/server/Garnet.server.csproj @@ -21,7 +21,7 @@ - + \ No newline at end of file diff --git a/libs/server/Lua/LuaCommands.cs b/libs/server/Lua/LuaCommands.cs index f306262719..f9a4df47f8 100644 --- a/libs/server/Lua/LuaCommands.cs +++ b/libs/server/Lua/LuaCommands.cs @@ -2,11 +2,9 @@ // Licensed under the MIT license. using System; -using System.Buffers; +using System.Collections.Generic; using Garnet.common; using Microsoft.Extensions.Logging; -using NLua; -using NLua.Exceptions; using Tsavorite.core; namespace Garnet.server @@ -31,22 +29,27 @@ private unsafe bool TryEVALSHA() } ref var digest = ref parseState.GetArgSliceByRef(0); - AsciiUtils.ToLowerInPlace(digest.Span); - var digestAsSpanByteMem = new SpanByteAndMemory(digest.SpanByte); + LuaRunner runner = null; - var result = false; - if (!sessionScriptCache.TryGetFromDigest(digestAsSpanByteMem, out var runner)) + // Length check is mandatory, as ScriptHashKey assumes correct length + if (digest.length == SessionScriptCache.SHA1Len) { - if (storeWrapper.storeScriptCache.TryGetValue(digestAsSpanByteMem, out var source)) + AsciiUtils.ToLowerInPlace(digest.Span); + + var scriptKey = new ScriptHashKey(digest.Span); + + if (!sessionScriptCache.TryGetFromDigest(scriptKey, out runner)) { - if (!sessionScriptCache.TryLoad(source, digestAsSpanByteMem, out runner, out var error)) + if (storeWrapper.storeScriptCache.TryGetValue(scriptKey, out var source)) { - while (!RespWriteUtils.WriteError(error, ref dcurr, dend)) - SendAndReset(); + if (!sessionScriptCache.TryLoad(this, source, scriptKey, out runner, out _, out var error)) + { + // TryLoad will have written an error out, it any - _ = storeWrapper.storeScriptCache.TryRemove(digestAsSpanByteMem, out _); - return result; + _ = storeWrapper.storeScriptCache.TryRemove(scriptKey, out _); + return true; + } } } } @@ -58,10 +61,10 @@ private unsafe bool TryEVALSHA() } else { - result = ExecuteScript(count - 1, runner); + ExecuteScript(count - 1, runner); } - return result; + return true; } @@ -82,19 +85,16 @@ private unsafe bool TryEVAL() return AbortWithWrongNumberOfArguments("EVAL"); } - var script = parseState.GetArgSliceByRef(0).ToArray(); + ref var script = ref parseState.GetArgSliceByRef(0); // that this is stack allocated is load bearing - if it moves, things will break Span digest = stackalloc byte[SessionScriptCache.SHA1Len]; - sessionScriptCache.GetScriptDigest(script, digest); + sessionScriptCache.GetScriptDigest(script.ReadOnlySpan, digest); - var result = false; - if (!sessionScriptCache.TryLoad(script, new SpanByteAndMemory(SpanByte.FromPinnedSpan(digest)), out var runner, out var error)) + if (!sessionScriptCache.TryLoad(this, script.ReadOnlySpan, new ScriptHashKey(digest), out var runner, out _, out var error)) { - while (!RespWriteUtils.WriteError(error, ref dcurr, dend)) - SendAndReset(); - - return result; + // TryLoad will have written any errors out + return true; } if (runner == null) @@ -104,10 +104,10 @@ private unsafe bool TryEVAL() } else { - result = ExecuteScript(count - 1, runner); + ExecuteScript(count - 1, runner); } - return result; + return true; } /// @@ -125,7 +125,7 @@ private bool NetworkScriptExists() return AbortWithWrongNumberOfArguments("script|exists"); } - // returns an array where each element is a 0 if the script does not exist, and a 1 if it does + // Returns an array where each element is a 0 if the script does not exist, and a 1 if it does while (!RespWriteUtils.WriteArrayLength(parseState.Count, ref dcurr, dend)) SendAndReset(); @@ -133,11 +133,17 @@ private bool NetworkScriptExists() for (var shaIx = 0; shaIx < parseState.Count; shaIx++) { ref var sha1 = ref parseState.GetArgSliceByRef(shaIx); - AsciiUtils.ToLowerInPlace(sha1.Span); + var exists = 0; - var sha1Arg = new SpanByteAndMemory(sha1.SpanByte); + // Length check is required, as ScriptHashKey makes a hard assumption + if (sha1.length == SessionScriptCache.SHA1Len) + { + AsciiUtils.ToLowerInPlace(sha1.Span); + + var sha1Arg = new ScriptHashKey(sha1.Span); - var exists = storeWrapper.storeScriptCache.ContainsKey(sha1Arg) ? 1 : 0; + exists = storeWrapper.storeScriptCache.ContainsKey(sha1Arg) ? 1 : 0; + } while (!RespWriteUtils.WriteArrayItem(exists, ref dcurr, dend)) SendAndReset(); @@ -202,18 +208,26 @@ private bool NetworkScriptLoad() return AbortWithWrongNumberOfArguments("script|load"); } - var source = parseState.GetArgSliceByRef(0).ToArray(); - if (!sessionScriptCache.TryLoad(source, out var digest, out _, out var error)) - { - while (!RespWriteUtils.WriteError(error, ref dcurr, dend)) - SendAndReset(); - } - else + ref var source = ref parseState.GetArgSliceByRef(0); + + Span digest = stackalloc byte[SessionScriptCache.SHA1Len]; + sessionScriptCache.GetScriptDigest(source.Span, digest); + + if (sessionScriptCache.TryLoad(this, source.ReadOnlySpan, new(digest), out _, out var digestOnHeap, out var error)) { + // TryLoad will write any errors out // Add script to the store dictionary - var scriptKey = new SpanByteAndMemory(new ScriptHashOwner(digest.AsMemory()), digest.Length); - _ = storeWrapper.storeScriptCache.TryAdd(scriptKey, source); + if (digestOnHeap == null) + { + var newAlloc = GC.AllocateUninitializedArray(SessionScriptCache.SHA1Len, pinned: true); + digest.CopyTo(newAlloc); + _ = storeWrapper.storeScriptCache.TryAdd(new(newAlloc), source.ToArray()); + } + else + { + _ = storeWrapper.storeScriptCache.TryAdd(digestOnHeap.Value, source.ToArray()); + } while (!RespWriteUtils.WriteBulkString(digest, ref dcurr, dend)) SendAndReset(); @@ -243,120 +257,17 @@ private bool CheckLuaEnabled() /// /// Invoke the execution of a server-side Lua script. /// - /// - /// - /// - private unsafe bool ExecuteScript(int count, LuaRunner scriptRunner) + private void ExecuteScript(int count, LuaRunner scriptRunner) { try { - var scriptResult = scriptRunner.Run(count, parseState); - WriteObject(scriptResult); - } - catch (LuaScriptException ex) - { - logger?.LogError(ex.InnerException ?? ex, "Error executing Lua script callback"); - while (!RespWriteUtils.WriteError("ERR " + (ex.InnerException ?? ex).Message, ref dcurr, dend)) - SendAndReset(); - return true; + scriptRunner.RunForSession(count, this); } catch (Exception ex) { logger?.LogError(ex, "Error executing Lua script"); while (!RespWriteUtils.WriteError("ERR " + ex.Message, ref dcurr, dend)) SendAndReset(); - return true; - } - return true; - } - - void WriteObject(object scriptResult) - { - if (scriptResult != null) - { - if (scriptResult is string s) - { - while (!RespWriteUtils.WriteAsciiBulkString(s, ref dcurr, dend)) - SendAndReset(); - } - else if ((scriptResult as byte?) != null && (byte)scriptResult == 36) //equals to $ - { - while (!RespWriteUtils.WriteDirect((byte[])scriptResult, ref dcurr, dend)) - SendAndReset(); - } - else if (scriptResult is bool b) - { - if (b) - { - while (!RespWriteUtils.WriteInteger(1, ref dcurr, dend)) - SendAndReset(); - } - else - { - while (!RespWriteUtils.WriteDirect(CmdStrings.RESP_ERRNOTFOUND, ref dcurr, dend)) - SendAndReset(); - } - } - else if (scriptResult is long l) - { - while (!RespWriteUtils.WriteInteger(l, ref dcurr, dend)) - SendAndReset(); - } - else if (scriptResult is ArgSlice a) - { - while (!RespWriteUtils.WriteBulkString(a.ReadOnlySpan, ref dcurr, dend)) - SendAndReset(); - } - else if (scriptResult is object[] o) - { - // Two objects one boolean value and the result from the Lua Call - while (!RespWriteUtils.WriteAsciiBulkString(o[1].ToString().AsSpan(), ref dcurr, dend)) - SendAndReset(); - } - else if (scriptResult is LuaTable luaTable) - { - try - { - var retVal = luaTable["err"]; - if (retVal != null) - { - while (!RespWriteUtils.WriteError((string)retVal, ref dcurr, dend)) - SendAndReset(); - } - else - { - retVal = luaTable["ok"]; - if (retVal != null) - { - while (!RespWriteUtils.WriteAsciiBulkString((string)retVal, ref dcurr, dend)) - SendAndReset(); - } - else - { - int count = luaTable.Values.Count; - while (!RespWriteUtils.WriteArrayLength(count, ref dcurr, dend)) - SendAndReset(); - foreach (var value in luaTable.Values) - { - WriteObject(value); - } - } - } - } - finally - { - luaTable.Dispose(); - } - } - else - { - throw new LuaScriptException("Unknown return type", ""); - } - } - else - { - while (!RespWriteUtils.WriteDirect(CmdStrings.RESP_ERRNOTFOUND, ref dcurr, dend)) - SendAndReset(); } } } diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index ea2ef3f0dd..38851209da 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -2,11 +2,14 @@ // Licensed under the MIT license. using System; -using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Text; using Garnet.common; +using KeraLua; using Microsoft.Extensions.Logging; -using NLua; namespace Garnet.server { @@ -15,121 +18,439 @@ namespace Garnet.server /// internal sealed class LuaRunner : IDisposable { - readonly string source; + /// + /// Adapter to allow us to write directly to the network + /// when in Garnet and still keep script runner work. + /// + private unsafe interface IResponseAdapter + { + /// + /// Equivalent to the ref curr we pass into methods. + /// + ref byte* BufferCur { get; } + + /// + /// Equivalent to the end we pass into methods. + /// + byte* BufferEnd { get; } + + /// + /// Equivalent to . + /// + void SendAndReset(); + } + + /// + /// Adapter so script results go directly + /// to the network. + /// + private readonly struct RespResponseAdapter : IResponseAdapter + { + private readonly RespServerSession session; + + internal RespResponseAdapter(RespServerSession session) + { + this.session = session; + } + + /// + public unsafe ref byte* BufferCur + => ref session.dcurr; + + /// + public unsafe byte* BufferEnd + => session.dend; + + /// + public void SendAndReset() + => session.SendAndReset(); + } + + /// + /// For the runner, put output into an array. + /// + private unsafe struct RunnerAdapter : IResponseAdapter + { + private readonly ScratchBufferManager bufferManager; + private byte* origin; + private byte* curHead; + private byte* curEnd; + + internal RunnerAdapter(ScratchBufferManager bufferManager) + { + this.bufferManager = bufferManager; + this.bufferManager.Reset(); + + var scratchSpace = bufferManager.FullBuffer(); + + origin = curHead = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(scratchSpace)); + curEnd = curHead + scratchSpace.Length; + } + +#pragma warning disable CS9084 // Struct member returns 'this' or other instance members by reference + /// + public unsafe ref byte* BufferCur + => ref curHead; +#pragma warning restore CS9084 + + /// + public unsafe byte* BufferEnd + => curEnd; + + /// + /// Gets a span that covers the responses as written so far. + /// + public readonly ReadOnlySpan Response + { + get + { + var len = (int)(curHead - origin); + + var full = bufferManager.FullBuffer(); + + return full[..len]; + } + } + + /// + public void SendAndReset() + { + var len = (int)(curHead - origin); + + // We don't actually send anywhere, we grow the backing array + bufferManager.GrowBuffer(); + + var scratchSpace = bufferManager.FullBuffer(); + + origin = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(scratchSpace)); + curEnd = origin + scratchSpace.Length; + curHead = origin + len; + } + } + + const string LoaderBlock = @" +import = function () end +redis = {} +function redis.call(...) + return garnet_call(...) +end +function redis.status_reply(text) + return text +end +function redis.error_reply(text) + return { err = 'ERR ' .. text } +end +KEYS = {} +ARGV = {} +sandbox_env = { + _G = _G; + _VERSION = _VERSION; + + assert = assert; + collectgarbage = collectgarbage; + coroutine = coroutine; + error = error; + gcinfo = gcinfo; + -- explicitly not allowing getfenv + getmetatable = getmetatable; + ipairs = ipairs; + load = load; + loadstring = loadstring; + math = math; + next = next; + pairs = pairs; + pcall = pcall; + rawequal = rawequal; + rawget = rawget; + rawset = rawset; + redis = redis; + select = select; + -- explicitly not allowing setfenv + string = string; + setmetatable = setmetatable; + table = table; + tonumber = tonumber; + tostring = tostring; + type = type; + unpack = table.unpack; + xpcall = xpcall; + + KEYS = KEYS; + ARGV = ARGV; +} +-- do resets in the Lua side to minimize pinvokes +function reset_keys_and_argv(fromKey, fromArgv) + local keyCount = #KEYS + for i = fromKey, keyCount do + KEYS[i] = nil + end + + local argvCount = #ARGV + for i = fromArgv, argvCount do + ARGV[i] = nil + end +end +-- responsible for sandboxing user provided code +function load_sandboxed(source) + if (not source) then return nil end + local rawFunc, err = load(source, nil, nil, sandbox_env) + + -- compilation error is returned directly + if err then + return rawFunc, err + end + + -- otherwise we wrap the compiled function in a helper + return function() + local rawRet = rawFunc() + + -- handle ok response wrappers without crossing the pinvoke boundary + -- err response wrapper requires a bit more work, but is also rarer + if rawRet and type(rawRet) == ""table"" and rawRet.ok then + return rawRet.ok + end + + return rawRet + end +end +"; + + private static readonly ReadOnlyMemory LoaderBlockBytes = Encoding.UTF8.GetBytes(LoaderBlock); + + // Rooted to keep function pointer alive + readonly LuaFunction garnetCall; + + // References into Registry on the Lua side + // + // These are mix of objects we regularly update, + // constants we want to avoid copying from .NET to Lua, + // and the compiled function definition. + readonly int sandboxEnvRegistryIndex; + readonly int keysTableRegistryIndex; + readonly int argvTableRegistryIndex; + readonly int loadSandboxedRegistryIndex; + readonly int resetKeysAndArgvRegistryIndex; + readonly int okConstStringRegistryIndex; + readonly int errConstStringRegistryIndex; + readonly int noSessionAvailableConstStringRegistryIndex; + readonly int pleaseSpecifyRedisCallConstStringRegistryIndex; + readonly int errNoAuthConstStringRegistryIndex; + readonly int errUnknownConstStringRegistryIndex; + readonly int errBadArgConstStringRegistryIndex; + int functionRegistryIndex; + + readonly ReadOnlyMemory source; readonly ScratchBufferNetworkSender scratchBufferNetworkSender; readonly RespServerSession respServerSession; + readonly ScratchBufferManager scratchBufferManager; readonly ILogger logger; - readonly Lua state; - readonly LuaTable sandbox_env; - LuaFunction function; readonly TxnKeyEntries txnKeyEntries; readonly bool txnMode; - readonly LuaFunction garnetCall; - readonly LuaTable keyTable, argvTable; + + // This cannot be readonly, as it is a mutable struct + LuaStateWrapper state; + int keyLength, argvLength; - Queue disposeQueue; /// /// Creates a new runner with the source of the script /// - public LuaRunner(string source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null) + public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null) { this.source = source; this.txnMode = txnMode; this.respServerSession = respServerSession; this.scratchBufferNetworkSender = scratchBufferNetworkSender; - this.scratchBufferManager = respServerSession?.scratchBufferManager; + this.scratchBufferManager = respServerSession?.scratchBufferManager ?? new(); this.logger = logger; - state = new Lua(); - state.State.Encoding = Encoding.UTF8; + sandboxEnvRegistryIndex = -1; + keysTableRegistryIndex = -1; + argvTableRegistryIndex = -1; + loadSandboxedRegistryIndex = -1; + functionRegistryIndex = -1; + + // TODO: custom allocator? + state = new LuaStateWrapper(new Lua()); + if (txnMode) { - this.txnKeyEntries = new TxnKeyEntries(16, respServerSession.storageSession.lockableContext, respServerSession.storageSession.objectStoreLockableContext); - garnetCall = state.RegisterFunction("garnet_call", this, this.GetType().GetMethod(nameof(garnet_call_txn))); + txnKeyEntries = new TxnKeyEntries(16, respServerSession.storageSession.lockableContext, respServerSession.storageSession.objectStoreLockableContext); + + garnetCall = garnet_call_txn; } else { - garnetCall = state.RegisterFunction("garnet_call", this, this.GetType().GetMethod("garnet_call")); - } - _ = state.DoString(@" - import = function () end - redis = {} - function redis.call(cmd, ...) - return garnet_call(cmd, ...) - end - function redis.status_reply(text) - return text - end - function redis.error_reply(text) - return { err = text } - end - KEYS = {} - ARGV = {} - sandbox_env = { - tostring = tostring; - next = next; - assert = assert; - tonumber = tonumber; - rawequal = rawequal; - collectgarbage = collectgarbage; - coroutine = coroutine; - type = type; - select = select; - unpack = table.unpack; - gcinfo = gcinfo; - pairs = pairs; - loadstring = loadstring; - ipairs = ipairs; - error = error; - redis = redis; - math = math; - table = table; - string = string; - KEYS = KEYS; - ARGV = ARGV; - } - function load_sandboxed(source) - if (not source) then return nil end - return load(source, nil, nil, sandbox_env) - end - "); - sandbox_env = (LuaTable)state["sandbox_env"]; - keyTable = (LuaTable)state["KEYS"]; - argvTable = (LuaTable)state["ARGV"]; + garnetCall = garnet_call; + } + + var loadRes = state.LoadBuffer(LoaderBlockBytes.Span); + if (loadRes != LuaStatus.OK) + { + throw new GarnetException("Could load loader into Lua"); + } + + var sandboxRes = state.PCall(0, -1); + if (sandboxRes != LuaStatus.OK) + { + throw new GarnetException("Could not initialize Lua sandbox state"); + } + + // Register garnet_call in global namespace + state.Register("garnet_call", garnetCall); + + state.GetGlobal(LuaType.Table, "sandbox_env"); + sandboxEnvRegistryIndex = state.Ref(); + + state.GetGlobal(LuaType.Table, "KEYS"); + keysTableRegistryIndex = state.Ref(); + + state.GetGlobal(LuaType.Table, "ARGV"); + argvTableRegistryIndex = state.Ref(); + + state.GetGlobal(LuaType.Function, "load_sandboxed"); + loadSandboxedRegistryIndex = state.Ref(); + + state.GetGlobal(LuaType.Function, "reset_keys_and_argv"); + resetKeysAndArgvRegistryIndex = state.Ref(); + + // Commonly used strings, register them once so we don't have to copy them over each time we need them + okConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_OK); + errConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_err); + noSessionAvailableConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_No_session_available); + pleaseSpecifyRedisCallConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_ERR_Please_specify_at_least_one_argument_for_this_redis_lib_call); + errNoAuthConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.RESP_ERR_NOAUTH); + errUnknownConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_ERR_Unknown_Redis_command_called_from_script); + errBadArgConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_ERR_Lua_redis_lib_command_arguments_must_be_strings_or_integers); + + state.ExpectLuaStackEmpty(); } /// /// Creates a new runner with the source of the script /// - public LuaRunner(ReadOnlySpan source, bool txnMode, RespServerSession respServerSession, ScratchBufferNetworkSender scratchBufferNetworkSender, ILogger logger = null) - : this(Encoding.UTF8.GetString(source), txnMode, respServerSession, scratchBufferNetworkSender, logger) + public LuaRunner(string source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null) + : this(Encoding.UTF8.GetBytes(source), txnMode, respServerSession, scratchBufferNetworkSender, logger) + { + } + + /// + /// Some strings we use a bunch, and copying them to Lua each time is wasteful + /// + /// So instead we stash them in the Registry and load them by index + /// + int ConstantStringToRegistry(ReadOnlySpan str) + { + state.PushBuffer(str); + return state.Ref(); + } + + /// + /// Compile script for running in a .NET host. + /// + /// Errors are raised as exceptions. + /// + public unsafe void CompileForRunner() + { + var adapter = new RunnerAdapter(scratchBufferManager); + CompileCommon(ref adapter); + + var resp = adapter.Response; + var respStart = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(resp)); + var respEnd = respStart + resp.Length; + if (RespReadUtils.TryReadErrorAsSpan(out var errSpan, ref respStart, respEnd)) + { + var errStr = Encoding.UTF8.GetString(errSpan); + throw new GarnetException(errStr); + } + } + + /// + /// Compile script for a . + /// + /// Any errors encountered are written out as Resp errors. + /// + public void CompileForSession(RespServerSession session) { + var adapter = new RespResponseAdapter(session); + CompileCommon(ref adapter); } /// - /// Compile script + /// Drops compiled function, just for benchmarking purposes. /// - public void Compile() + public void ResetCompilation() { + if (functionRegistryIndex != -1) + { + state.Unref(LuaRegistry.Index, functionRegistryIndex); + functionRegistryIndex = -1; + } + } + + /// + /// Compile script, writing errors out to given response. + /// + unsafe void CompileCommon(ref TResponse resp) + where TResponse : struct, IResponseAdapter + { + const int NeededStackSpace = 2; + + Debug.Assert(functionRegistryIndex == -1, "Shouldn't compile multiple times"); + + state.ExpectLuaStackEmpty(); + try { - using var loader = (LuaFunction)state["load_sandboxed"]; - var result = loader.Call(source); - if (result?.Length == 1) + state.ForceMinimumStackCapacity(NeededStackSpace); + + state.PushInteger(loadSandboxedRegistryIndex); + _ = state.RawGet(LuaType.Function, (int)LuaRegistry.Index); + + state.PushBuffer(source.Span); + state.Call(1, -1); // Multiple returns allowed + + var numRets = state.StackTop; + + if (numRets == 0) { - function = result[0] as LuaFunction; + while (!RespWriteUtils.WriteError("Shouldn't happen, no returns from load_sandboxed"u8, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + return; } + else if (numRets == 1) + { + var returnType = state.Type(1); + if (returnType != LuaType.Function) + { + var errStr = $"Could not compile function, got back a {returnType}"; + while (!RespWriteUtils.WriteError(errStr, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); - if (result?.Length == 2) + return; + } + + functionRegistryIndex = state.Ref(); + } + else if (numRets == 2) { - throw new GarnetException($"Compilation error: {(string)result[1]}"); + state.CheckBuffer(2, out var errorBuf); + + var errStr = $"Compilation error: {Encoding.UTF8.GetString(errorBuf)}"; + while (!RespWriteUtils.WriteError(errStr, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + state.Pop(2); + + return; } else { - throw new GarnetException($"Unable to load script"); + state.Pop(numRets); + + throw new GarnetException($"Unexpected error compiling, got too many replies back: reply count = {numRets}"); } } catch (Exception ex) @@ -137,6 +458,10 @@ public void Compile() logger?.LogError(ex, "CreateFunction threw an exception"); throw; } + finally + { + state.ExpectLuaStackEmpty(); + } } /// @@ -144,184 +469,403 @@ public void Compile() /// public void Dispose() { - garnetCall?.Dispose(); - keyTable?.Dispose(); - argvTable?.Dispose(); - sandbox_env?.Dispose(); - function?.Dispose(); - state?.Dispose(); + state.Dispose(); } /// /// Entry point for redis.call method from a Lua script (non-transactional mode) /// - /// - /// Parameters - /// - public object garnet_call(string cmd, params object[] args) - => respServerSession == null ? null : ProcessCommandFromScripting(respServerSession.basicGarnetApi, cmd, args); + public int garnet_call(IntPtr luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + if (respServerSession == null) + { + return NoSessionResponse(); + } + + return ProcessCommandFromScripting(respServerSession.basicGarnetApi); + } /// /// Entry point for redis.call method from a Lua script (transactional mode) /// - /// - /// Parameters - /// - public object garnet_call_txn(string cmd, params object[] args) - => respServerSession == null ? null : ProcessCommandFromScripting(respServerSession.lockableGarnetApi, cmd, args); + public int garnet_call_txn(IntPtr luaStatePtr) + { + state.CallFromLuaEntered(luaStatePtr); + + if (respServerSession == null) + { + return NoSessionResponse(); + } + + return ProcessCommandFromScripting(respServerSession.lockableGarnetApi); + } + + /// + /// Call somehow came in with no valid resp server session. + /// + /// This is used in benchmarking. + /// + int NoSessionResponse() + { + const int NeededStackSpace = 1; + + state.ForceMinimumStackCapacity(NeededStackSpace); + + state.PushNil(); + return 1; + } /// /// Entry point method for executing commands from a Lua Script /// - unsafe object ProcessCommandFromScripting(TGarnetApi api, string cmd, params object[] args) + unsafe int ProcessCommandFromScripting(TGarnetApi api) where TGarnetApi : IGarnetApi { - switch (cmd) + const int AdditionalStackSpace = 1; + + try { + var argCount = state.StackTop; + + if (argCount == 0) + { + return LuaStaticError(pleaseSpecifyRedisCallConstStringRegistryIndex); + } + + state.ForceMinimumStackCapacity(AdditionalStackSpace); + + if (!state.CheckBuffer(1, out var cmdSpan)) + { + return LuaStaticError(errBadArgConstStringRegistryIndex); + } + // We special-case a few performance-sensitive operations to directly invoke via the storage API - case "SET" when args.Length == 2: - case "set" when args.Length == 2: + if (AsciiUtils.EqualsUpperCaseSpanIgnoringCase(cmdSpan, "SET"u8) && argCount == 3) + { + if (!respServerSession.CheckACLPermissions(RespCommand.SET)) { - if (!respServerSession.CheckACLPermissions(RespCommand.SET)) - return Encoding.ASCII.GetString(CmdStrings.RESP_ERR_NOAUTH); - var key = scratchBufferManager.CreateArgSlice(Convert.ToString(args[0])); - var value = scratchBufferManager.CreateArgSlice(Convert.ToString(args[1])); - _ = api.SET(key, value); - return "OK"; + return LuaStaticError(errNoAuthConstStringRegistryIndex); } - case "GET": - case "get": + + if (!state.CheckBuffer(2, out var keySpan) || !state.CheckBuffer(3, out var valSpan)) { - if (!respServerSession.CheckACLPermissions(RespCommand.GET)) - throw new Exception(Encoding.ASCII.GetString(CmdStrings.RESP_ERR_NOAUTH)); - var key = scratchBufferManager.CreateArgSlice(Convert.ToString(args[0])); - var status = api.GET(key, out var value); - if (status == GarnetStatus.OK) - return value.ToString(); - return null; + return LuaStaticError(errBadArgConstStringRegistryIndex); } + + // Note these spans are implicitly pinned, as they're actually on the Lua stack + var key = ArgSlice.FromPinnedSpan(keySpan); + var value = ArgSlice.FromPinnedSpan(valSpan); + + _ = api.SET(key, value); + + state.PushConstantString(okConstStringRegistryIndex); + return 1; + } + else if (AsciiUtils.EqualsUpperCaseSpanIgnoringCase(cmdSpan, "GET"u8) && argCount == 2) + { + if (!respServerSession.CheckACLPermissions(RespCommand.GET)) + { + return LuaStaticError(errNoAuthConstStringRegistryIndex); + } + + if (!state.CheckBuffer(2, out var keySpan)) + { + return LuaStaticError(errBadArgConstStringRegistryIndex); + } + + // Span is (implicitly) pinned since it's actually on the Lua stack + var key = ArgSlice.FromPinnedSpan(keySpan); + var status = api.GET(key, out var value); + if (status == GarnetStatus.OK) + { + state.PushBuffer(value.ReadOnlySpan); + } + else + { + state.PushNil(); + } + + return 1; + } + // As fallback, we use RespServerSession with a RESP-formatted input. This could be optimized // in future to provide parse state directly. - default: + + scratchBufferManager.Reset(); + scratchBufferManager.StartCommand(cmdSpan, argCount - 1); + + for (var i = 0; i < argCount - 1; i++) + { + var argIx = 2 + i; + + var argType = state.Type(argIx); + if (argType == LuaType.Nil) + { + scratchBufferManager.WriteNullArgument(); + } + else if (argType is LuaType.String or LuaType.Number) + { + // KnownStringToBuffer will coerce a number into a string + // + // Redis nominally converts numbers to integers, but in this case just ToStrings things + state.KnownStringToBuffer(argIx, out var span); + + // Span remains pinned so long as we don't pop the stack + scratchBufferManager.WriteArgument(span); + } + else { - var request = scratchBufferManager.FormatCommandAsResp(cmd, args, state); - _ = respServerSession.TryConsumeMessages(request.ptr, request.length); - var response = scratchBufferNetworkSender.GetResponse(); - var result = ProcessResponse(response.ptr, response.length); - scratchBufferNetworkSender.Reset(); - return result; + return LuaStaticError(errBadArgConstStringRegistryIndex); } + } + + var request = scratchBufferManager.ViewFullArgSlice(); + + // Once the request is formatted, we can release all the args on the Lua stack + // + // This keeps the stack size down for processing the response + state.Pop(argCount); + + _ = respServerSession.TryConsumeMessages(request.ptr, request.length); + + var response = scratchBufferNetworkSender.GetResponse(); + var result = ProcessResponse(response.ptr, response.length); + scratchBufferNetworkSender.Reset(); + return result; + } + catch (Exception e) + { + logger?.LogError(e, "During Lua script execution"); + + return state.RaiseError(e.Message); } } /// - /// Process a RESP-formatted response from the RespServerSession + /// Cause a Lua error to be raised with a message previously registered. /// - unsafe object ProcessResponse(byte* ptr, int length) + int LuaStaticError(int constStringRegistryIndex) { + const int NeededStackSize = 1; + + state.ForceMinimumStackCapacity(NeededStackSize); + + state.PushConstantString(constStringRegistryIndex); + return state.RaiseErrorFromStack(); + } + + /// + /// Process a RESP-formatted response from the RespServerSession. + /// + /// Pushes result onto state stack and returns 1, or raises an error and never returns. + /// + unsafe int ProcessResponse(byte* ptr, int length) + { + const int NeededStackSize = 3; + + state.ForceMinimumStackCapacity(NeededStackSize); + switch (*ptr) { case (byte)'+': - if (RespReadUtils.ReadSimpleString(out var resultStr, ref ptr, ptr + length)) - return resultStr; - break; + ptr++; + length--; + if (RespReadUtils.ReadAsSpan(out var resultSpan, ref ptr, ptr + length)) + { + state.PushBuffer(resultSpan); + return 1; + } + goto default; + case (byte)':': if (RespReadUtils.Read64Int(out var number, ref ptr, ptr + length)) - return number; - break; + { + state.PushInteger(number); + return 1; + } + goto default; + case (byte)'-': - if (RespReadUtils.ReadErrorAsString(out resultStr, ref ptr, ptr + length)) - return resultStr; - break; + ptr++; + length--; + if (RespReadUtils.ReadAsSpan(out var errSpan, ref ptr, ptr + length)) + { + if (errSpan.SequenceEqual(CmdStrings.RESP_ERR_GENERIC_UNK_CMD)) + { + // Gets a special response + return LuaStaticError(errUnknownConstStringRegistryIndex); + } + + state.PushBuffer(errSpan); + return state.RaiseErrorFromStack(); + + } + goto default; case (byte)'$': - if (RespReadUtils.ReadStringResponseWithLengthHeader(out resultStr, ref ptr, ptr + length)) - return resultStr; - break; + if (length >= 5 && new ReadOnlySpan(ptr + 1, 4).SequenceEqual("-1\r\n"u8)) + { + // Bulk null strings are mapped to FALSE + // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion + state.PushBoolean(false); + + return 1; + } + else if (RespReadUtils.ReadSpanWithLengthHeader(out var bulkSpan, ref ptr, ptr + length)) + { + state.PushBuffer(bulkSpan); + + return 1; + } + goto default; case (byte)'*': - if (RespReadUtils.ReadStringArrayResponseWithLengthHeader(out var resultArray, ref ptr, ptr + length)) + if (RespReadUtils.ReadUnsignedArrayLength(out var itemCount, ref ptr, ptr + length)) { - // Create return table - var returnValue = (LuaTable)state.DoString("return { }")[0]; - - // Queue up for disposal at the end of the script call - disposeQueue ??= new(); - disposeQueue.Enqueue(returnValue); - - // Populate the table - var i = 1; - foreach (var item in resultArray) - returnValue[i++] = item == null ? false : item; - return returnValue; + // Create the new table + state.CreateTable(itemCount, 0); + + for (var itemIx = 0; itemIx < itemCount; itemIx++) + { + if (*ptr == '$') + { + // Bulk String + if (length >= 4 && new ReadOnlySpan(ptr + 1, 4).SequenceEqual("-1\r\n"u8)) + { + // Null strings are mapped to false + // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion + state.PushBoolean(false); + } + else if (RespReadUtils.ReadSpanWithLengthHeader(out var strSpan, ref ptr, ptr + length)) + { + state.PushBuffer(strSpan); + } + else + { + // Error, drop the table we allocated + state.Pop(1); + goto default; + } + } + else + { + // In practice, we ONLY ever return bulk strings + // So just... not implementing the rest for now + throw new NotImplementedException($"Unexpected sigil: {(char)*ptr}"); + } + + // Stack now has table and value at itemIx on it + state.RawSetInteger(1, itemIx + 1); + } + + return 1; } - break; + goto default; default: throw new Exception("Unexpected response: " + Encoding.UTF8.GetString(new Span(ptr, length)).Replace("\n", "|").Replace("\r", "") + "]"); } - return null; } /// - /// Runs the precompiled Lua function with specified parse state + /// Runs the precompiled Lua function with the given outer session. + /// + /// Response is written directly into the . /// - public object Run(int count, SessionParseState parseState) + public void RunForSession(int count, RespServerSession outerSession) { + const int NeededStackSize = 3; + + state.ForceMinimumStackCapacity(NeededStackSize); + scratchBufferManager.Reset(); - int offset = 1; - int nKeys = parseState.GetInt(offset++); + var parseState = outerSession.parseState; + + var offset = 1; + var nKeys = parseState.GetInt(offset++); count--; ResetParameters(nKeys, count - nKeys); if (nKeys > 0) { - for (int i = 0; i < nKeys; i++) + // Get KEYS on the stack + state.PushInteger(keysTableRegistryIndex); + state.RawGet(LuaType.Table, (int)LuaRegistry.Index); + + for (var i = 0; i < nKeys; i++) { + ref var key = ref parseState.GetArgSliceByRef(offset); + if (txnMode) { - var key = parseState.GetArgSliceByRef(offset); txnKeyEntries.AddKey(key, false, Tsavorite.core.LockType.Exclusive); if (!respServerSession.storageSession.objectStoreLockableContext.IsNull) txnKeyEntries.AddKey(key, true, Tsavorite.core.LockType.Exclusive); } - keyTable[i + 1] = parseState.GetString(offset++); + + // Equivalent to KEYS[i+1] = key + state.PushInteger(i + 1); + state.PushBuffer(key.ReadOnlySpan); + state.RawSet(1); + + offset++; } - count -= nKeys; - //TODO: handle slot verification for Lua script keys - //if (NetworkKeyArraySlotVerify(keys, true)) - //{ - // return true; - //} + // Remove KEYS from the stack + state.Pop(1); + + count -= nKeys; } if (count > 0) { - for (int i = 0; i < count; i++) + // Get ARGV on the stack + state.PushInteger(argvTableRegistryIndex); + state.RawGet(LuaType.Table, (int)LuaRegistry.Index); + + for (var i = 0; i < count; i++) { - argvTable[i + 1] = parseState.GetString(offset++); + ref var argv = ref parseState.GetArgSliceByRef(offset); + + // Equivalent to ARGV[i+1] = argv + state.PushInteger(i + 1); + state.PushBuffer(argv.ReadOnlySpan); + state.RawSet(1); + + offset++; } + + // Remove ARGV from the stack + state.Pop(1); } + var adapter = new RespResponseAdapter(outerSession); + if (txnMode && nKeys > 0) { - return RunTransaction(); + RunInTransaction(ref adapter); } else { - return Run(); + RunCommon(ref adapter); } } /// - /// Runs the precompiled Lua function with specified (keys, argv) state + /// Runs the precompiled Lua function with specified (keys, argv) state. + /// + /// Meant for use from a .NET host rather than in Garnet properly. /// - public object Run(string[] keys = null, string[] argv = null) + public unsafe object RunForRunner(string[] keys = null, string[] argv = null) { scratchBufferManager?.Reset(); - LoadParameters(keys, argv); + LoadParametersForRunner(keys, argv); + + var adapter = new RunnerAdapter(scratchBufferManager); + if (txnMode && keys?.Length > 0) { // Add keys to the transaction @@ -332,15 +876,88 @@ public object Run(string[] keys = null, string[] argv = null) if (!respServerSession.storageSession.objectStoreLockableContext.IsNull) txnKeyEntries.AddKey(_key, true, Tsavorite.core.LockType.Exclusive); } - return RunTransaction(); + + RunInTransaction(ref adapter); } else { - return Run(); + RunCommon(ref adapter); + } + + var resp = adapter.Response; + var respCur = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(resp)); + var respEnd = respCur + resp.Length; + + if (RespReadUtils.TryReadErrorAsSpan(out var errSpan, ref respCur, respEnd)) + { + var errStr = Encoding.UTF8.GetString(errSpan); + throw new GarnetException(errStr); + } + + var ret = MapRespToObject(ref respCur, respEnd); + Debug.Assert(respCur == respEnd, "Should have fully consumed response"); + + return ret; + + static object MapRespToObject(ref byte* cur, byte* end) + { + switch (*cur) + { + case (byte)'+': + var simpleStrRes = RespReadUtils.ReadSimpleString(out var simpleStr, ref cur, end); + Debug.Assert(simpleStrRes, "Should never fail"); + + return simpleStr; + + case (byte)':': + var readIntRes = RespReadUtils.Read64Int(out var int64, ref cur, end); + Debug.Assert(readIntRes, "Should never fail"); + + return int64; + + // Error ('-') is handled before call to MapRespToObject + + case (byte)'$': + var length = end - cur; + + if (length >= 5 && new ReadOnlySpan(cur + 1, 4).SequenceEqual("-1\r\n"u8)) + { + cur += 5; + return null; + } + + var bulkStrRes = RespReadUtils.ReadStringResponseWithLengthHeader(out var bulkStr, ref cur, end); + Debug.Assert(bulkStrRes, "Should never fail"); + + return bulkStr; + + case (byte)'*': + var arrayLengthRes = RespReadUtils.ReadUnsignedArrayLength(out var itemCount, ref cur, end); + Debug.Assert(arrayLengthRes, "Should never fail"); + + if (itemCount == 0) + { + return Array.Empty(); + } + + var array = new object[itemCount]; + for (var i = 0; i < array.Length; i++) + { + array[i] = MapRespToObject(ref cur, end); + } + + return array; + + default: throw new NotImplementedException($"Unexpected sigil {(char)*cur}"); + } } } - object RunTransaction() + /// + /// Calls after setting up appropriate state for a transaction. + /// + void RunInTransaction(ref TResponse response) + where TResponse : struct, IResponseAdapter { try { @@ -349,7 +966,8 @@ object RunTransaction() respServerSession.storageSession.objectStoreLockableContext.BeginLockable(); respServerSession.SetTransactionMode(true); txnKeyEntries.LockAllKeys(); - return Run(); + + RunCommon(ref response); } finally { @@ -361,56 +979,341 @@ object RunTransaction() } } - void ResetParameters(int nKeys, int nArgs) + /// + /// Remove extra keys and args from KEYS and ARGV globals. + /// + internal void ResetParameters(int nKeys, int nArgs) { - if (keyLength > nKeys) + const int NeededStackSize = 3; + + state.ForceMinimumStackCapacity(NeededStackSize); + + if (keyLength > nKeys || argvLength > nArgs) { - _ = state.DoString($"count = #KEYS for i={nKeys + 1}, {keyLength} do KEYS[i]=nil end"); + state.RawGetInteger(LuaType.Function, (int)LuaRegistry.Index, resetKeysAndArgvRegistryIndex); + + state.PushInteger(nKeys + 1); + state.PushInteger(nArgs + 1); + + var resetRes = state.PCall(2, 0); + Debug.Assert(resetRes == LuaStatus.OK, "Resetting should never fail"); } + keyLength = nKeys; - if (argvLength > nArgs) - { - _ = state.DoString($"count = #ARGV for i={nArgs + 1}, {argvLength} do ARGV[i]=nil end"); - } argvLength = nArgs; } - void LoadParameters(string[] keys, string[] argv) + /// + /// Takes .NET strings for keys and args and pushes them into KEYS and ARGV globals. + /// + void LoadParametersForRunner(string[] keys, string[] argv) { + const int NeededStackSize = 2; + + state.ForceMinimumStackCapacity(NeededStackSize); + ResetParameters(keys?.Length ?? 0, argv?.Length ?? 0); + if (keys != null) { - for (int i = 0; i < keys.Length; i++) - keyTable[i + 1] = keys[i]; + // get KEYS on the stack + state.PushInteger(keysTableRegistryIndex); + _ = state.RawGet(LuaType.Table, (int)LuaRegistry.Index); + + for (var i = 0; i < keys.Length; i++) + { + // equivalent to KEYS[i+1] = keys[i] + var key = keys[i]; + PrepareString(key, scratchBufferManager, out var encoded); + state.PushBuffer(encoded); + state.RawSetInteger(1, i + 1); + } + + state.Pop(1); } + if (argv != null) { - for (int i = 0; i < argv.Length; i++) - argvTable[i + 1] = argv[i]; + // get ARGV on the stack + state.PushInteger(argvTableRegistryIndex); + _ = state.RawGet(LuaType.Table, (int)LuaRegistry.Index); + + for (var i = 0; i < argv.Length; i++) + { + // equivalent to ARGV[i+1] = keys[i] + var arg = argv[i]; + PrepareString(arg, scratchBufferManager, out var encoded); + state.PushBuffer(encoded); + state.RawSetInteger(1, i + 1); + } + + state.Pop(1); + } + + static void PrepareString(string raw, ScratchBufferManager buffer, out ReadOnlySpan strBytes) + { + var maxLen = Encoding.UTF8.GetMaxByteCount(raw.Length); + + buffer.Reset(); + var argSlice = buffer.CreateArgSlice(maxLen); + var span = argSlice.Span; + + var written = Encoding.UTF8.GetBytes(raw, span); + strBytes = span[..written]; } } /// - /// Runs the precompiled Lua function + /// Runs the precompiled Lua function. /// - /// - object Run() + unsafe void RunCommon(ref TResponse resp) + where TResponse : struct, IResponseAdapter { - var result = function.Call(); - Cleanup(); - return result?.Length > 0 ? result[0] : null; - } + const int NeededStackSize = 2; - void Cleanup() - { - if (disposeQueue != null) + // TODO: mapping is dependent on Resp2 vs Resp3 settings + // and that's not implemented at all + + try + { + state.ForceMinimumStackCapacity(NeededStackSize); + + state.PushInteger(functionRegistryIndex); + _ = state.RawGet(LuaType.Function, (int)LuaRegistry.Index); + + var callRes = state.PCall(0, 1); + if (callRes == LuaStatus.OK) + { + // The actual call worked, handle the response + + if (state.StackTop == 0) + { + WriteNull(this, ref resp); + return; + } + + var retType = state.Type(1); + var isNullish = retType is LuaType.Nil or LuaType.UserData or LuaType.Function or LuaType.Thread or LuaType.UserData; + + if (isNullish) + { + WriteNull(this, ref resp); + return; + } + else if (retType == LuaType.Number) + { + WriteNumber(this, ref resp); + return; + } + else if (retType == LuaType.String) + { + WriteString(this, ref resp); + return; + } + else if (retType == LuaType.Boolean) + { + WriteBoolean(this, ref resp); + return; + } + else if (retType == LuaType.Table) + { + // Redis does not respect metatables, so RAW access is ok here + + // If the key err is in there, we need to short circuit + state.PushConstantString(errConstStringRegistryIndex); + + var errType = state.RawGet(null, 1); + if (errType == LuaType.String) + { + WriteError(this, ref resp); + + // Remove table from stack + state.Pop(1); + + return; + } + + // Remove whatever we read from the table under the "err" key + state.Pop(1); + + // Map this table to an array + WriteArray(this, ref resp); + } + } + else + { + // An error was raised + + if (state.StackTop == 0) + { + while (!RespWriteUtils.WriteError("ERR An error occurred while invoking a Lua script"u8, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + return; + } + else if (state.StackTop == 1) + { + if (state.CheckBuffer(1, out var errBuf)) + { + while (!RespWriteUtils.WriteError(errBuf, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + } + + state.Pop(1); + + return; + } + else + { + logger?.LogError("Got an unexpected number of values back from a pcall error {callRes}", callRes); + + while (!RespWriteUtils.WriteError("ERR Unexpected error response"u8, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + state.ClearStack(); + + return; + } + } + } + finally + { + state.ExpectLuaStackEmpty(); + } + + // Write a null RESP value, remove the top value on the stack if there is one + static void WriteNull(LuaRunner runner, ref TResponse resp) { - while (disposeQueue.Count > 0) + while (!RespWriteUtils.WriteNull(ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + // The stack _could_ be empty if we're writing a null, so check before popping + if (runner.state.StackTop != 0) { - var table = disposeQueue.Dequeue(); - table.Dispose(); + runner.state.Pop(1); } } + + // Writes the number on the top of the stack, removes it from the stack + static void WriteNumber(LuaRunner runner, ref TResponse resp) + { + Debug.Assert(runner.state.Type(runner.state.StackTop) == LuaType.Number, "Number was not on top of stack"); + + // Redis unconditionally converts all "number" replies to integer replies so we match that + // + // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion + var num = (long)runner.state.CheckNumber(runner.state.StackTop); + + while (!RespWriteUtils.WriteInteger(num, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + runner.state.Pop(1); + } + + // Writes the string on the top of the stack, removes it from the stack + static void WriteString(LuaRunner runner, ref TResponse resp) + { + runner.state.KnownStringToBuffer(runner.state.StackTop, out var buf); + + while (!RespWriteUtils.WriteBulkString(buf, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + runner.state.Pop(1); + } + + // Writes the boolean on the top of the stack, removes it from the stack + static void WriteBoolean(LuaRunner runner, ref TResponse resp) + { + Debug.Assert(runner.state.Type(runner.state.StackTop) == LuaType.Boolean, "Boolean was not on top of stack"); + + // Redis maps Lua false to null, and Lua true to 1 this is strange, but documented + // + // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion + if (runner.state.ToBoolean(runner.state.StackTop)) + { + while (!RespWriteUtils.WriteInteger(1, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + } + else + { + while (!RespWriteUtils.WriteNull(ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + } + + runner.state.Pop(1); + } + + // Writes the string on the top of the stack out as an error, removes the string from the stack + static void WriteError(LuaRunner runner, ref TResponse resp) + { + runner.state.KnownStringToBuffer(runner.state.StackTop, out var errBuff); + + while (!RespWriteUtils.WriteError(errBuff, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + runner.state.Pop(1); + } + + static void WriteArray(LuaRunner runner, ref TResponse resp) + { + // Redis does not respect metatables, so RAW access is ok here + + // 1 for the table, 1 for the pending value + const int AdditonalNeededStackSize = 2; + + Debug.Assert(runner.state.Type(runner.state.StackTop) == LuaType.Table, "Table was not on top of stack"); + + // Lua # operator - this MAY stop at nils, but isn't guaranteed to + // See: https://www.lua.org/manual/5.3/manual.html#3.4.7 + var maxLen = runner.state.RawLen(runner.state.StackTop); + + // Find the TRUE length by scanning for nils + var trueLen = 0; + for (trueLen = 0; trueLen < maxLen; trueLen++) + { + var type = runner.state.RawGetInteger(null, runner.state.StackTop, trueLen + 1); + runner.state.Pop(1); + + if (type == LuaType.Nil) + { + break; + } + } + + while (!RespWriteUtils.WriteArrayLength((int)trueLen, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + for (var i = 1; i <= trueLen; i++) + { + // Push item at index i onto the stack + var type = runner.state.RawGetInteger(null, runner.state.StackTop, i); + + switch (type) + { + case LuaType.String: + WriteString(runner, ref resp); + break; + case LuaType.Number: + WriteNumber(runner, ref resp); + break; + case LuaType.Boolean: + WriteBoolean(runner, ref resp); + break; + case LuaType.Table: + // For tables, we need to recurse - which means we need to check stack sizes again + runner.state.ForceMinimumStackCapacity(AdditonalNeededStackSize); + WriteArray(runner, ref resp); + break; + + // All other Lua types map to nulls + default: + WriteNull(runner, ref resp); + break; + } + } + + runner.state.Pop(1); + } } } } \ No newline at end of file diff --git a/libs/server/Lua/LuaStateWrapper.cs b/libs/server/Lua/LuaStateWrapper.cs new file mode 100644 index 0000000000..51ac6b0e0a --- /dev/null +++ b/libs/server/Lua/LuaStateWrapper.cs @@ -0,0 +1,521 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Text; +using Garnet.common; +using KeraLua; + +namespace Garnet.server +{ + /// + /// For performance purposes, we need to track some additional state alongside the + /// raw Lua runtime. + /// + /// This type does that. + /// + internal struct LuaStateWrapper : IDisposable + { + private const int LUA_MINSTACK = 20; + + private readonly Lua state; + + private int curStackSize; + + internal LuaStateWrapper(Lua state) + { + this.state = state; + + curStackSize = LUA_MINSTACK; + StackTop = 0; + + AssertLuaStackExpected(); + } + + /// + public readonly void Dispose() + { + state.Dispose(); + } + + /// + /// Current top item in the stack. + /// + /// 0 implies the stack is empty. + /// + internal int StackTop { get; private set; } + + /// + /// Call when ambient state indicates that the Lua stack is in fact empty. + /// + /// Maintains to avoid unnecessary p/invokes. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ExpectLuaStackEmpty() + { + StackTop = 0; + AssertLuaStackExpected(); + } + + /// + /// Ensure there's enough space on the Lua stack for more items. + /// + /// Throws if there is not. + /// + /// Maintains to avoid unnecessary p/invokes. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ForceMinimumStackCapacity(int additionalCapacity) + { + var availableSpace = curStackSize - StackTop; + + if (availableSpace >= additionalCapacity) + { + return; + } + + var needed = additionalCapacity - availableSpace; + if (!state.CheckStack(needed)) + { + throw new GarnetException("Could not reserve additional capacity on the Lua stack"); + } + + curStackSize += additionalCapacity; + } + + /// + /// Call when the Lua runtime calls back into .NET code. + /// + /// Figures out the state of the Lua stack once, to avoid unnecessary p/invokes. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void CallFromLuaEntered(IntPtr luaStatePtr) + { + Debug.Assert(luaStatePtr == state.Handle, "Unexpected Lua state presented"); + + StackTop = NativeMethods.GetTop(state.Handle); + curStackSize = StackTop > LUA_MINSTACK ? StackTop : LUA_MINSTACK; + } + + /// + /// This should be used for all CheckBuffer calls into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly bool CheckBuffer(int index, out ReadOnlySpan str) + { + AssertLuaStackIndexInBounds(index); + + return NativeMethods.CheckBuffer(state.Handle, index, out str); + } + + /// + /// This should be used for all Type calls into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly LuaType Type(int stackIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + + return NativeMethods.Type(state.Handle, stackIndex); + } + + /// + /// This should be used for all PushBuffer calls into Lua. + /// + /// If the string is a constant, consider registering it in the constructor and using instead. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void PushBuffer(ReadOnlySpan buffer) + { + AssertLuaStackNotFull(); + + NativeMethods.PushBuffer(state.Handle, buffer); + UpdateStackTop(1); + } + + /// + /// This should be used for all PushNil calls into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void PushNil() + { + AssertLuaStackNotFull(); + + NativeMethods.PushNil(state.Handle); + UpdateStackTop(1); + } + + /// + /// This should be used for all PushInteger calls into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void PushInteger(long number) + { + AssertLuaStackNotFull(); + + NativeMethods.PushInteger(state.Handle, number); + + UpdateStackTop(1); + } + + /// + /// This should be used for all PushBoolean calls into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void PushBoolean(bool b) + { + AssertLuaStackNotFull(); + + NativeMethods.PushBoolean(state.Handle, b); + UpdateStackTop(1); + } + + /// + /// This should be used for all Pop calls into Lua. + /// + /// Maintains to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void Pop(int num) + { + NativeMethods.Pop(state.Handle, num); + + UpdateStackTop(-num); + } + + /// + /// This should be used for all Calls into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void Call(int args, int rets) + { + // We have to copy this off, as once we Call curStackTop could be modified + var oldStackTop = StackTop; + state.Call(args, rets); + + if (rets < 0) + { + StackTop = NativeMethods.GetTop(state.Handle); + AssertLuaStackExpected(); + } + else + { + var newPosition = oldStackTop - (args + 1) + rets; + var update = newPosition - StackTop; + UpdateStackTop(update); + } + } + + /// + /// This should be used for all PCalls into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal LuaStatus PCall(int args, int rets) + { + // We have to copy this off, as once we Call curStackTop could be modified + var oldStackTop = StackTop; + var res = state.PCall(args, rets, 0); + + if (res != LuaStatus.OK || rets < 0) + { + StackTop = NativeMethods.GetTop(state.Handle); + AssertLuaStackExpected(); + } + else + { + var newPosition = oldStackTop - (args + 1) + rets; + var update = newPosition - StackTop; + UpdateStackTop(update); + } + + return res; + } + + /// + /// This should be used for all RawSetIntegers into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void RawSetInteger(int stackIndex, int tableIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + + state.RawSetInteger(stackIndex, tableIndex); + UpdateStackTop(-1); + } + + /// + /// This should be used for all RawSets into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void RawSet(int stackIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + + state.RawSet(stackIndex); + UpdateStackTop(-2); + } + + /// + /// This should be used for all RawGetIntegers into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal LuaType RawGetInteger(LuaType? expectedType, int stackIndex, int tableIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + AssertLuaStackNotFull(); + + var actual = state.RawGetInteger(stackIndex, tableIndex); + Debug.Assert(expectedType == null || actual == expectedType, "Unexpected type received"); + + UpdateStackTop(1); + + return actual; + } + + /// + /// This should be used for all RawGets into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly LuaType RawGet(LuaType? expectedType, int stackIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + + var actual = state.RawGet(stackIndex); + Debug.Assert(expectedType == null || actual == expectedType, "Unexpected type received"); + + AssertLuaStackExpected(); + + return actual; + } + + /// + /// This should be used for all Refs into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal int Ref() + { + var ret = state.Ref(LuaRegistry.Index); + UpdateStackTop(-1); + + return ret; + } + + /// + /// This should be used for all Unrefs into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly void Unref(LuaRegistry registry, int reference) + { + state.Unref(registry, reference); + } + + /// + /// This should be used for all CreateTables into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void CreateTable(int numArr, int numRec) + { + AssertLuaStackNotFull(); + + state.CreateTable(numArr, numRec); + UpdateStackTop(1); + } + + /// + /// This should be used for all GetGlobals into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + internal void GetGlobal(LuaType expectedType, string globalName) + { + AssertLuaStackNotFull(); + + var type = state.GetGlobal(globalName); + Debug.Assert(type == expectedType, "Unexpected type received"); + + UpdateStackTop(1); + } + + /// + /// This should be used for all LoadBuffers into Lua. + /// + /// Note that this is different from pushing a buffer, as the loaded buffer is compiled. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal LuaStatus LoadBuffer(ReadOnlySpan buffer) + { + AssertLuaStackNotFull(); + + var ret = NativeMethods.LoadBuffer(state.Handle, buffer); + + UpdateStackTop(1); + + return ret; + } + + /// + /// Call when value at index is KNOWN to be a string or number + /// + /// only remains valid as long as the buffer remains on the stack, + /// use with care. + /// + /// Note that is changes the value on the stack to be a string if it returns true, regardless of + /// what it was originally. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly void KnownStringToBuffer(int stackIndex, out ReadOnlySpan str) + { + AssertLuaStackIndexInBounds(stackIndex); + + Debug.Assert(NativeMethods.Type(state.Handle, stackIndex) is LuaType.String or LuaType.Number, "Called with non-string, non-number"); + + NativeMethods.KnownStringToBuffer(state.Handle, stackIndex, out str); + } + + /// + /// This should be used for all CheckNumbers into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly double CheckNumber(int stackIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + + return state.CheckNumber(stackIndex); + } + + /// + /// This should be used for all ToBooleans into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly bool ToBoolean(int stackIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + + return NativeMethods.ToBoolean(state.Handle, stackIndex); + } + + /// + /// This should be used for all RawLens into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly long RawLen(int stackIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + + return state.RawLen(stackIndex); + } + + /// + /// Call to register a function in the Lua global namespace. + /// + internal readonly void Register(string name, LuaFunction func) + => state.Register(name, func); + + /// + /// This should be used to push all known constants strings into Lua. + /// + /// This avoids extra copying of data between .NET and Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void PushConstantString(int constStringRegistryIndex, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) + => RawGetInteger(LuaType.String, (int)LuaRegistry.Index, constStringRegistryIndex); + + // Rarely used + + /// + /// Remove everything from the Lua stack. + /// + internal void ClearStack() + { + state.SetTop(0); + StackTop = 0; + + AssertLuaStackExpected(); + } + + /// + /// Clear the stack and raise an error with the given message. + /// + internal int RaiseError(string msg) + { + ClearStack(); + + var b = Encoding.UTF8.GetBytes(msg); + return RaiseErrorFromStack(); + } + + /// + /// Raise an error, where the top of the stack is the error message. + /// + internal readonly int RaiseErrorFromStack() + { + Debug.Assert(StackTop != 0, "Expected error message on the stack"); + + return state.Error(); + } + + /// + /// Helper to update . + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void UpdateStackTop(int by) + { + StackTop += by; + AssertLuaStackExpected(); + } + + // Conditional compilation checks + + /// + /// Check that the given index refers to a valid part of the stack. + /// + [Conditional("DEBUG")] + [MethodImpl(MethodImplOptions.NoInlining)] + private readonly void AssertLuaStackIndexInBounds(int stackIndex) + { + Debug.Assert(stackIndex == (int)LuaRegistry.Index || (stackIndex > 0 && stackIndex <= StackTop), "Lua stack index out of bounds"); + } + + /// + /// Check that the Lua stack top is where expected in DEBUG builds. + /// + [Conditional("DEBUG")] + [MethodImpl(MethodImplOptions.NoInlining)] + private readonly void AssertLuaStackExpected() + { + Debug.Assert(NativeMethods.GetTop(state.Handle) == StackTop, "Lua stack not where expected"); + } + + /// + /// Check that there's space to push some number of elements. + /// + [Conditional("DEBUG")] + [MethodImpl(MethodImplOptions.NoInlining)] + private readonly void AssertLuaStackNotFull(int probe = 1) + { + Debug.Assert((StackTop + probe) <= curStackSize, "Lua stack should have been grown before pushing"); + } + } +} \ No newline at end of file diff --git a/libs/server/Lua/NativeMethods.cs b/libs/server/Lua/NativeMethods.cs new file mode 100644 index 0000000000..cc6bcf085b --- /dev/null +++ b/libs/server/Lua/NativeMethods.cs @@ -0,0 +1,250 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using KeraLua; +using charptr_t = nint; +using lua_State = nint; +using size_t = nuint; + +namespace Garnet.server +{ + /// + /// Lua runtime methods we want that are not provided by . + /// + /// Long term we'll want to try and push these upstreams and move to just using KeraLua, + /// but for now we're just defining them ourselves. + /// + internal static partial class NativeMethods + { + private const string LuaLibraryName = "lua54"; + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_tolstring + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial charptr_t lua_tolstring(lua_State L, int index, out size_t len); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_pushlstring + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial charptr_t lua_pushlstring(lua_State L, charptr_t s, size_t len); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#luaL_loadbufferx + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial LuaStatus luaL_loadbufferx(lua_State luaState, charptr_t buff, size_t sz, charptr_t name, charptr_t mode); + + // GC Transition suppressed - only do this after auditing the Lua method and confirming constant-ish, fast, runtime w/o allocations + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_gettop + /// + /// Does basically nothing, so suppressing GC transition. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_gettop + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + private static partial int lua_gettop(lua_State luaState); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_type + /// + /// Does some very basic ifs and then returns, so suppressing GC transition. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_type + /// And + /// see: https://www.lua.org/source/5.4/lapi.c.html#index2value + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + private static partial LuaType lua_type(lua_State L, int index); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_pushnil + /// + /// Does some very small writes, and stack size is pre-validated, so suppressing GC transition. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_pushnil + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + private static partial void lua_pushnil(lua_State L); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_pushinteger + /// + /// Does some very small writes, and stack size is pre-validated, so suppressing GC transition. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_pushinteger + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + private static partial void lua_pushinteger(lua_State L, long num); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_pushboolean + /// + /// Does some very small writes, and stack size is pre-validated, so suppressing GC transition. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_pushboolean + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + private static partial void lua_pushboolean(lua_State L, int b); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_toboolean + /// + /// Does some very basic ifs and then returns an int, so suppressing GC transition. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_toboolean + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + private static partial int lua_toboolean(lua_State L, int ix); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_settop + /// + /// We aren't pushing complex types, so none of the close logic should run. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_settop + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + private static partial void lua_settop(lua_State L, int num); + + /// + /// Returns true if the given index on the stack holds a string or a number. + /// + /// Sets to the string equivalent if so, otherwise leaves it empty. + /// + /// only remains valid as long as the buffer remains on the stack, + /// use with care. + /// + /// Note that is changes the value on the stack to be a string if it returns true, regardless of + /// what it was originally. + /// + internal static bool CheckBuffer(lua_State luaState, int index, out ReadOnlySpan str) + { + var type = lua_type(luaState, index); + + if (type is not LuaType.String and not LuaType.Number) + { + str = []; + return false; + } + + var start = lua_tolstring(luaState, index, out var len); + unsafe + { + str = new ReadOnlySpan((byte*)start, (int)len); + return true; + } + } + + /// + /// Call when value at index is KNOWN to be a string or number + /// + /// only remains valid as long as the buffer remains on the stack, + /// use with care. + /// + /// Note that is changes the value on the stack to be a string if it returns true, regardless of + /// what it was originally. + /// + internal static void KnownStringToBuffer(lua_State luaState, int index, out ReadOnlySpan str) + { + var start = lua_tolstring(luaState, index, out var len); + unsafe + { + str = new ReadOnlySpan((byte*)start, (int)len); + } + } + + /// + /// Pushes given span to stack as a string. + /// + /// Provided data is copied, and can be reused once this call returns. + /// + internal static unsafe void PushBuffer(lua_State luaState, ReadOnlySpan str) + { + fixed (byte* ptr = str) + { + _ = lua_pushlstring(luaState, (charptr_t)ptr, (size_t)str.Length); + } + } + + /// + /// Push given span to stack, and compiles it. + /// + /// Provided data is copied, and can be reused once this call returns. + /// + internal static unsafe LuaStatus LoadBuffer(lua_State luaState, ReadOnlySpan str) + { + fixed (byte* ptr = str) + { + return luaL_loadbufferx(luaState, (charptr_t)ptr, (size_t)str.Length, (charptr_t)UIntPtr.Zero, (charptr_t)UIntPtr.Zero); + } + } + + /// + /// Get the top index on the stack. + /// + /// 0 indicates empty. + /// + /// Differs from by suppressing GC transition. + /// + internal static int GetTop(lua_State luaState) + => lua_gettop(luaState); + + /// + /// Gets the type of the value at the stack index. + /// + /// Differs from by suppressing GC transition. + /// + internal static LuaType Type(lua_State luaState, int index) + => lua_type(luaState, index); + + /// + /// Pushes a nil value onto the stack. + /// + /// Differs from by suppressing GC transition. + /// + internal static void PushNil(lua_State luaState) + => lua_pushnil(luaState); + + /// + /// Pushes a double onto the stack. + /// + /// Differs from by suppressing GC transition. + /// + internal static void PushInteger(lua_State luaState, long num) + => lua_pushinteger(luaState, num); + + /// + /// Pushes a boolean onto the stack. + /// + /// Differs from by suppressing GC transition. + /// + internal static void PushBoolean(lua_State luaState, bool b) + => lua_pushboolean(luaState, b ? 1 : 0); + + /// + /// Read a boolean off the stack + /// + /// Differs from by suppressing GC transition. + /// + internal static bool ToBoolean(lua_State luaState, int index) + => lua_toboolean(luaState, index) != 0; + + /// + /// Remove some number of items from the stack. + /// + /// Differs form by suppressing GC transition. + /// + internal static void Pop(lua_State luaState, int num) + => lua_settop(luaState, -num - 1); + } +} \ No newline at end of file diff --git a/libs/server/Lua/ScriptHashKey.cs b/libs/server/Lua/ScriptHashKey.cs new file mode 100644 index 0000000000..adcc0a602d --- /dev/null +++ b/libs/server/Lua/ScriptHashKey.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Garnet.server +{ + /// + /// Specialized key type for storing script hashes. + /// + public readonly struct ScriptHashKey : IEquatable + { + // Necessary to keep this alive + private readonly byte[] arrRef; + private readonly unsafe long* ptr; + + internal unsafe ScriptHashKey(ReadOnlySpan stackSpan) + { + Debug.Assert(stackSpan.Length == SessionScriptCache.SHA1Len, "Only one valid length for script hash keys"); + + ptr = (long*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(stackSpan)); + } + + internal unsafe ScriptHashKey(byte[] pohArr) + : this(pohArr.AsSpan()) + { + arrRef = pohArr; + } + + /// + /// Copy key data. + /// + public unsafe void CopyTo(Span into) + { + new Span(ptr, SessionScriptCache.SHA1Len).CopyTo(into); + } + + /// + public unsafe bool Equals(ScriptHashKey other) + { + Debug.Assert(SessionScriptCache.SHA1Len == 40, "Making a hard assumption that we're comparing 40 bytes"); + + var a = ptr; + var b = other.ptr; + + return + *(a++) == *(b++) && + *(a++) == *(b++) && + *(a++) == *(b++) && + *(a++) == *(b++) && + *(a++) == *(b++); + } + + /// + public override unsafe int GetHashCode() + => *(int*)ptr; + + /// + public override bool Equals([NotNullWhen(true)] object obj) + => obj is ScriptHashKey other && Equals(other); + } +} \ No newline at end of file diff --git a/libs/server/Lua/SessionScriptCache.cs b/libs/server/Lua/SessionScriptCache.cs index eddc0d9ffc..4dfd1ac296 100644 --- a/libs/server/Lua/SessionScriptCache.cs +++ b/libs/server/Lua/SessionScriptCache.cs @@ -25,7 +25,7 @@ internal sealed class SessionScriptCache : IDisposable readonly ScratchBufferNetworkSender scratchBufferNetworkSender; readonly StoreWrapper storeWrapper; readonly ILogger logger; - readonly Dictionary scriptCache = new(SpanByteAndMemoryComparer.Instance); + readonly Dictionary scriptCache = []; readonly byte[] hash = new byte[SHA1Len / 2]; public SessionScriptCache(StoreWrapper storeWrapper, IGarnetAuthenticator authenticator, ILogger logger = null) @@ -52,47 +52,48 @@ public void SetUser(User user) /// /// Try get script runner for given digest /// - public bool TryGetFromDigest(SpanByteAndMemory digest, out LuaRunner scriptRunner) + public bool TryGetFromDigest(ScriptHashKey digest, out LuaRunner scriptRunner) => scriptCache.TryGetValue(digest, out scriptRunner); /// - /// Load script into the cache + /// Load script into the cache. + /// + /// If necessary, will be set so the allocation can be reused. /// - public bool TryLoad(byte[] source, out byte[] digest, out LuaRunner runner, out string error) - { - digest = new byte[SHA1Len]; - GetScriptDigest(source, digest); - - return TryLoad(source, new SpanByteAndMemory(new ScriptHashOwner(digest), digest.Length), out runner, out error); - } - - internal bool TryLoad(byte[] source, SpanByteAndMemory digest, out LuaRunner runner, out string error) + internal bool TryLoad(RespServerSession session, ReadOnlySpan source, ScriptHashKey digest, out LuaRunner runner, out ScriptHashKey? digestOnHeap, out string error) { error = null; if (scriptCache.TryGetValue(digest, out runner)) + { + digestOnHeap = null; return true; + } try { - runner = new LuaRunner(source, storeWrapper.serverOptions.LuaTransactionMode, processor, scratchBufferNetworkSender, logger); - runner.Compile(); + var sourceOnHeap = source.ToArray(); + + runner = new LuaRunner(sourceOnHeap, storeWrapper.serverOptions.LuaTransactionMode, processor, scratchBufferNetworkSender, logger); + runner.CompileForSession(session); - // need to make sure the key is on the heap, so move it over if needed - var storeKeyDigest = digest; - if (storeKeyDigest.IsSpanByte) - { - var into = new byte[storeKeyDigest.Length]; - storeKeyDigest.AsReadOnlySpan().CopyTo(into); + // Need to make sure the key is on the heap, so move it over + // + // There's an implicit assumption that all callers are using unmanaged memory. + // If that becomes untrue, there's an optimization opportunity to re-use the + // managed memory here. + var into = GC.AllocateUninitializedArray(SHA1Len, pinned: true); + digest.CopyTo(into); - storeKeyDigest = new SpanByteAndMemory(new ScriptHashOwner(into), into.Length); - } + ScriptHashKey storeKeyDigest = new(into); + digestOnHeap = storeKeyDigest; _ = scriptCache.TryAdd(storeKeyDigest, runner); } catch (Exception ex) { error = ex.Message; + digestOnHeap = null; return false; } diff --git a/libs/server/Resp/CmdStrings.cs b/libs/server/Resp/CmdStrings.cs index 7ec2967e95..22141be4d4 100644 --- a/libs/server/Resp/CmdStrings.cs +++ b/libs/server/Resp/CmdStrings.cs @@ -337,5 +337,13 @@ static partial class CmdStrings public static ReadOnlySpan initiate_replica_sync => "INITIATE_REPLICA_SYNC"u8; public static ReadOnlySpan send_ckpt_file_segment => "SEND_CKPT_FILE_SEGMENT"u8; public static ReadOnlySpan send_ckpt_metadata => "SEND_CKPT_METADATA"u8; + + // Lua scripting strings + public static ReadOnlySpan LUA_OK => "OK"u8; + public static ReadOnlySpan LUA_err => "err"u8; + public static ReadOnlySpan LUA_No_session_available => "No session available"u8; + public static ReadOnlySpan LUA_ERR_Please_specify_at_least_one_argument_for_this_redis_lib_call => "ERR Please specify at least one argument for this redis lib call"u8; + public static ReadOnlySpan LUA_ERR_Unknown_Redis_command_called_from_script => "ERR Unknown Redis command called from script"u8; + public static ReadOnlySpan LUA_ERR_Lua_redis_lib_command_arguments_must_be_strings_or_integers => "ERR Lua redis lib command arguments must be strings or integers"u8; } } \ No newline at end of file diff --git a/libs/server/Resp/RespServerSession.cs b/libs/server/Resp/RespServerSession.cs index 138c8a14b4..f72be029bb 100644 --- a/libs/server/Resp/RespServerSession.cs +++ b/libs/server/Resp/RespServerSession.cs @@ -86,7 +86,7 @@ internal sealed unsafe partial class RespServerSession : ServerSessionBase /// int endReadHead; - byte* dcurr, dend; + internal byte* dcurr, dend; bool toDispose; int opCount; @@ -975,7 +975,7 @@ private unsafe bool Write(int seqNo, ref byte* dst, int length) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void SendAndReset() + internal void SendAndReset() { byte* d = networkSender.GetResponseObjectHead(); if ((int)(dcurr - d) > 0) diff --git a/libs/server/Resp/SpanByteAndMemoryComparer.cs b/libs/server/Resp/SpanByteAndMemoryComparer.cs deleted file mode 100644 index e3ecc4dded..0000000000 --- a/libs/server/Resp/SpanByteAndMemoryComparer.cs +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -using System; -using System.Collections.Generic; -using Tsavorite.core; - -namespace Garnet.server -{ - /// - /// equality comparer. - /// - public sealed class SpanByteAndMemoryComparer : IEqualityComparer - { - /// - /// The default instance. - /// - /// Used to avoid allocating new comparers. - public static readonly SpanByteAndMemoryComparer Instance = new(); - - private SpanByteAndMemoryComparer() { } - - /// - public bool Equals(SpanByteAndMemory left, SpanByteAndMemory right) - => left.AsReadOnlySpan().SequenceEqual(right.AsReadOnlySpan()); - - /// - public unsafe int GetHashCode(SpanByteAndMemory key) - { - var hash = new HashCode(); - hash.AddBytes(key.AsReadOnlySpan()); - - var ret = hash.ToHashCode(); - - return ret; - } - } -} \ No newline at end of file diff --git a/libs/server/StoreWrapper.cs b/libs/server/StoreWrapper.cs index 49b30eff39..428003f335 100644 --- a/libs/server/StoreWrapper.cs +++ b/libs/server/StoreWrapper.cs @@ -99,8 +99,10 @@ public sealed class StoreWrapper internal readonly string run_id; private SingleWriterMultiReaderLock _checkpointTaskLock; - // Lua script cache - public readonly ConcurrentDictionary storeScriptCache; + /// + /// Lua script cache + /// + public readonly ConcurrentDictionary storeScriptCache; public readonly TimeSpan loggingFrequncy; @@ -153,7 +155,7 @@ public StoreWrapper( // Initialize store scripting cache if (serverOptions.EnableLua) - this.storeScriptCache = new(SpanByteAndMemoryComparer.Instance); + this.storeScriptCache = []; if (accessControlList == null) { diff --git a/test/Garnet.test/LuaScriptRunnerTests.cs b/test/Garnet.test/LuaScriptRunnerTests.cs index a9e07c0601..b7a7a7f0f7 100644 --- a/test/Garnet.test/LuaScriptRunnerTests.cs +++ b/test/Garnet.test/LuaScriptRunnerTests.cs @@ -3,7 +3,6 @@ using Garnet.common; using Garnet.server; -using NLua.Exceptions; using NUnit.Framework; using NUnit.Framework.Legacy; @@ -18,48 +17,48 @@ public void CannotRunUnsafeScript() // Try to load an assembly using (var runner = new LuaRunner("luanet.load_assembly('mscorlib')")) { - runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + runner.CompileForRunner(); + var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"luanet.load_assembly('mscorlib')\"]:1: attempt to index a nil value (global 'luanet')", ex.Message); } // Try to call a OS function using (var runner = new LuaRunner("os = require('os'); return os.time();")) { - runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + runner.CompileForRunner(); + var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"os = require('os'); return os.time();\"]:1: attempt to call a nil value (global 'require')", ex.Message); } // Try to execute the input stream using (var runner = new LuaRunner("dofile();")) { - runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + runner.CompileForRunner(); + var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"dofile();\"]:1: attempt to call a nil value (global 'dofile')", ex.Message); } // Try to call a windows executable using (var runner = new LuaRunner("require \"notepad\"")) { - runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + runner.CompileForRunner(); + var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"require \"notepad\"\"]:1: attempt to call a nil value (global 'require')", ex.Message); } // Try to call an OS function using (var runner = new LuaRunner("os.exit();")) { - runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + runner.CompileForRunner(); + var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"os.exit();\"]:1: attempt to index a nil value (global 'os')", ex.Message); } // Try to include a new .net library using (var runner = new LuaRunner("import ('System.Diagnostics');")) { - runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + runner.CompileForRunner(); + var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"import ('System.Diagnostics');\"]:1: attempt to call a nil value (global 'import')", ex.Message); } } @@ -70,14 +69,14 @@ public void CanLoadScript() // Code with error using (var runner = new LuaRunner("local;")) { - var ex = Assert.Throws(runner.Compile); + var ex = Assert.Throws(runner.CompileForRunner); ClassicAssert.AreEqual("Compilation error: [string \"local;\"]:1: expected near ';'", ex.Message); } // Code without error using (var runner = new LuaRunner("local list; list = 1; return list;")) { - runner.Compile(); + runner.CompileForRunner(); } } @@ -90,17 +89,46 @@ public void CanRunScript() // Run code without errors using (var runner = new LuaRunner("local list; list = ARGV[1] ; return list;")) { - runner.Compile(); - var res = runner.Run(keys, args); + runner.CompileForRunner(); + var res = runner.RunForRunner(keys, args); ClassicAssert.AreEqual("arg1", res); } // Run code with errors using (var runner = new LuaRunner("local list; list = ; return list;")) { - var ex = Assert.Throws(runner.Compile); + var ex = Assert.Throws(runner.CompileForRunner); ClassicAssert.AreEqual("Compilation error: [string \"local list; list = ; return list;\"]:1: unexpected symbol near ';'", ex.Message); } } + + [Test] + public void KeysAndArgsCleared() + { + using (var runner = new LuaRunner("return { KEYS[1], ARGV[1], KEYS[2], ARGV[2] }")) + { + runner.CompileForRunner(); + var res1 = runner.RunForRunner(["hello", "world"], ["fizz", "buzz"]); + var obj1 = (object[])res1; + ClassicAssert.AreEqual(4, obj1.Length); + ClassicAssert.AreEqual("hello", (string)obj1[0]); + ClassicAssert.AreEqual("fizz", (string)obj1[1]); + ClassicAssert.AreEqual("world", (string)obj1[2]); + ClassicAssert.AreEqual("buzz", (string)obj1[3]); + + var res2 = runner.RunForRunner(["abc"], ["def"]); + var obj2 = (object[])res2; + ClassicAssert.AreEqual(2, obj2.Length); + ClassicAssert.AreEqual("abc", (string)obj2[0]); + ClassicAssert.AreEqual("def", (string)obj2[1]); + + var res3 = runner.RunForRunner(["012", "345"], ["678"]); + var obj3 = (object[])res3; + ClassicAssert.AreEqual(3, obj3.Length); + ClassicAssert.AreEqual("012", (string)obj3[0]); + ClassicAssert.AreEqual("678", (string)obj3[1]); + ClassicAssert.AreEqual("345", (string)obj3[2]); + } + } } } \ No newline at end of file diff --git a/test/Garnet.test/LuaScriptTests.cs b/test/Garnet.test/LuaScriptTests.cs index 3d326f2a34..ace0593340 100644 --- a/test/Garnet.test/LuaScriptTests.cs +++ b/test/Garnet.test/LuaScriptTests.cs @@ -143,7 +143,7 @@ public void CanDoEvalShaWithZAddMultiPairSE() var db = redis.GetDatabase(0); // Create a sorted set - var script = "local ptable = {100, \"value1\", 200, \"value2\"}; return redis.call('zadd', KEYS[1], ptable)"; + var script = "return redis.call('zadd', KEYS[1], 100, \"value1\", 200, \"value2\")"; var result = db.ScriptEvaluate(script, [(RedisKey)"mysskey"]); ClassicAssert.IsTrue(result.ToString() == "2"); @@ -156,7 +156,7 @@ public void CanDoEvalShaWithZAddMultiPairSE() ClassicAssert.IsTrue(result.ToString() == "0"); // Add more pairs - script = "local ptable = {300, \"value3\", 400, \"value4\"}; return redis.call('zadd', KEYS[1], ptable)"; + script = "return redis.call('zadd', KEYS[1], 300, \"value3\", 400, \"value4\")"; result = db.ScriptEvaluate(script, [(RedisKey)"mysskey"]); ClassicAssert.IsTrue(result.ToString() == "2"); @@ -330,23 +330,13 @@ public void FailureStatusReturn() using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); var db = redis.GetDatabase(0); var statusReplyScript = "return redis.error_reply('Failure')"; - try - { - _ = db.ScriptEvaluate(statusReplyScript); - } - catch (RedisServerException ex) - { - ClassicAssert.AreEqual(ex.Message, "Failure"); - } + + var excReply = ClassicAssert.Throws(() => db.ScriptEvaluate(statusReplyScript)); + ClassicAssert.AreEqual("ERR Failure", excReply.Message); + var directReplyScript = "return { err = 'Failure' }"; - try - { - _ = db.ScriptEvaluate(directReplyScript); - } - catch (RedisServerException ex) - { - ClassicAssert.AreEqual(ex.Message, "Failure"); - } + var excDirect = ClassicAssert.Throws(() => db.ScriptEvaluate(directReplyScript)); + ClassicAssert.AreEqual("Failure", excDirect.Message); } [Test] @@ -430,10 +420,13 @@ local function callgetrange() using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); var db = redis.GetDatabase(0); + var response1 = db.ScriptEvaluate(script1, ["key1", "key2"], ["foo", 3, 60_000]); ClassicAssert.AreEqual("OK", (string)response1); + var response2 = db.ScriptEvaluate(script2, ["key3"], ["foo"]); ClassicAssert.AreEqual(false, (bool)response2); + var response3 = db.ScriptEvaluate(script2, ["key1", "key2"], ["foo"]); ClassicAssert.AreEqual("OK", (string)response3); } @@ -442,8 +435,10 @@ local function callgetrange() public void ComplexLuaTest3() { var script1 = """ -return redis.call("mget", unpack(KEYS)) -"""; + return redis.call("mget", unpack(KEYS)) + """; + + //var script1 = "return KEYS"; using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); var db = redis.GetDatabase(0); @@ -538,5 +533,257 @@ public void ScriptExistsMultiple() ClassicAssert.AreEqual(0, (long)exists[2]); } } + + [Test] + public void RedisCallErrors() + { + // Testing that our error replies for redis.call match Redis behavior + // + // TODO: exact matching of the hash and line number would also be nice, but that is trickier + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + // No args + { + var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call()")); + ClassicAssert.IsTrue(exc.Message.StartsWith("ERR Please specify at least one argument for this redis lib call")); + } + + // Unknown command + { + var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call('123')")); + ClassicAssert.IsTrue(exc.Message.StartsWith("ERR Unknown Redis command called from script")); + } + + // Bad command type + { + var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call({ foo = 'bar'})")); + ClassicAssert.IsTrue(exc.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); + } + + // GET bad arg type + { + var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call('GET', { foo = 'bar' })")); + ClassicAssert.IsTrue(exc.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); + } + + // SET bad arg types + { + var exc1 = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call('SET', 'hello', { foo = 'bar' })")); + ClassicAssert.IsTrue(exc1.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); + + var exc2 = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call('SET', { foo = 'bar' }, 'world')")); + ClassicAssert.IsTrue(exc2.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); + } + + // Other bad arg types + { + var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call('DEL', { foo = 'bar' })")); + ClassicAssert.IsTrue(exc.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); + } + } + + [Test] + public void BinaryValuesInScripts() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var trickyKey = new byte[] { 0, 1, 2, 3, 4 }; + var trickyValue = new byte[] { 5, 6, 7, 8, 9, 0 }; + var trickyValue2 = new byte[] { 0, 1, 0, 1, 0, 1, 255 }; + + ClassicAssert.IsTrue(db.StringSet(trickyKey, trickyValue)); + + var luaEscapeKeyString = $"{string.Join("", trickyKey.Select(x => $"\\{x:X2}"))}"; + + var readDirectKeyRaw = db.ScriptEvaluate($"return redis.call('GET', '{luaEscapeKeyString}')", [(RedisKey)trickyKey]); + var readDirectKeyBytes = (byte[])readDirectKeyRaw; + ClassicAssert.IsTrue(trickyValue.AsSpan().SequenceEqual(readDirectKeyBytes)); + + var setKey = db.ScriptEvaluate("return redis.call('SET', KEYS[1], ARGV[1])", [(RedisKey)trickyKey], [(RedisValue)trickyValue2]); + ClassicAssert.AreEqual("OK", (string)setKey); + + var readTrickyValue2Raw = db.StringGet(trickyKey); + var readTrickyValue2 = (byte[])readTrickyValue2Raw; + ClassicAssert.IsTrue(trickyValue2.AsSpan().SequenceEqual(readTrickyValue2)); + } + + [Test] + public void NumberArgumentCoercion() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + db.StringSet("2", "hello"); + db.StringSet("2.1", "world"); + + var res = (string)db.ScriptEvaluate("return redis.call('GET', 2.1)"); + ClassicAssert.AreEqual("world", res); + } + + [Test] + public void ComplexLuaReturns() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + // Relatively complicated + { + var res1 = (RedisResult[])db.ScriptEvaluate("return { 1, 'hello', { true, false }, { fizz = 'buzz', hello = 'world' }, 5, 6 }"); + ClassicAssert.AreEqual(6, res1.Length); + ClassicAssert.AreEqual(1, (long)res1[0]); + ClassicAssert.AreEqual("hello", (string)res1[1]); + var res1Sub1 = (RedisResult[])res1[2]; + ClassicAssert.AreEqual(2, res1Sub1.Length); + ClassicAssert.AreEqual(1, (long)res1Sub1[0]); + ClassicAssert.IsTrue(res1Sub1[1].IsNull); + var res1Sub2 = (RedisResult[])res1[2]; + ClassicAssert.AreEqual(2, res1Sub2.Length); + ClassicAssert.IsTrue((bool)res1Sub2[0]); + ClassicAssert.IsTrue(res1Sub2[1].IsNull); + var res1Sub3 = (RedisResult[])res1[3]; + ClassicAssert.AreEqual(0, res1Sub3.Length); + ClassicAssert.AreEqual(5, (long)res1[4]); + ClassicAssert.AreEqual(6, (long)res1[5]); + } + + // Only indexable will be included + { + var res2 = (RedisResult[])db.ScriptEvaluate("return { 1, 2, fizz='buzz' }"); + ClassicAssert.AreEqual(2, res2.Length); + ClassicAssert.AreEqual(1, (long)res2[0]); + ClassicAssert.AreEqual(2, (long)res2[1]); + } + + // Non-string, non-number, are nullish + { + var res3 = (RedisResult[])db.ScriptEvaluate("return { 1, function() end, 3 }"); + ClassicAssert.AreEqual(3, res3.Length); + ClassicAssert.AreEqual(1, (long)res3[0]); + ClassicAssert.IsTrue(res3[1].IsNull); + ClassicAssert.AreEqual(3, (long)res3[2]); + } + + // Nil stops return of subsequent values + { + var res4 = (RedisResult[])db.ScriptEvaluate("return { 1, nil, 3 }"); + ClassicAssert.AreEqual(1, res4.Length); + ClassicAssert.AreEqual(1, (long)res4[0]); + } + + // Incredibly deeply nested return + { + const int Depth = 100; + + var tableDepth = new StringBuilder(); + for (var i = 1; i <= Depth; i++) + { + if (i != 1) + { + tableDepth.Append(", "); + } + tableDepth.Append("{ "); + tableDepth.Append(i); + } + for (var i = 1; i <= Depth; i++) + { + tableDepth.Append(" }"); + } + + var script = "return " + tableDepth.ToString(); + + var res5 = db.ScriptEvaluate(script); + + var cur = res5; + for (var i = 1; i < Depth; i++) + { + var top = (RedisResult[])cur; + ClassicAssert.AreEqual(2, top.Length); + ClassicAssert.AreEqual(i, (long)top[0]); + + cur = top[1]; + } + + // Remainder should have a single element + var remainder = (RedisResult[])cur; + ClassicAssert.AreEqual(1, remainder.Length); + ClassicAssert.AreEqual(Depth, (long)remainder[0]); + } + + // Incredibly wide + { + const int Width = 100_000; + + var tableDepth = new StringBuilder(); + for (var i = 1; i <= Width; i++) + { + if (i != 1) + { + tableDepth.Append(", "); + } + tableDepth.Append("{ " + i + " }"); + } + + var script = "return { " + tableDepth.ToString() + " }"; + + var res5 = (RedisResult[])db.ScriptEvaluate(script); + for (var i = 0; i < Width; i++) + { + var elem = (RedisResult[])res5[i]; + ClassicAssert.AreEqual(1, elem.Length); + ClassicAssert.AreEqual(i + 1, (long)elem[0]); + } + } + } + + [Test] + public void MetatableReturn() + { + const string Script = @" + local table = { abc = 'def', ghi = 'jkl' } + local ret = setmetatable( + table, + { + __len = function (self) + return 4 + end, + __index = function(self, key) + if key == 1 then + return KEYS[1] + end + + if key == 2 then + return ARGV[1] + end + + if key == 3 then + return self.ghi + end + + if key == 4 then + return self.abc + end + + return nil + end + } + ) + + -- prove that metatables WORK but also that we don't deconstruct them for returns + return { ret, ret[1], ret[2], ret[3], ret[4] }"; + + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var ret = (RedisResult[])db.ScriptEvaluate(Script, [(RedisKey)"foo"], [(RedisValue)"bar"]); + ClassicAssert.AreEqual(5, ret.Length); + var firstEmpty = (RedisResult[])ret[0]; + ClassicAssert.AreEqual(0, firstEmpty.Length); + ClassicAssert.AreEqual("foo", (string)ret[1]); + ClassicAssert.AreEqual("bar", (string)ret[2]); + ClassicAssert.AreEqual("jkl", (string)ret[3]); + ClassicAssert.AreEqual("def", (string)ret[4]); + } } } \ No newline at end of file