From 184f214b3eb9d1c5f20ef6f07efe4a9eeafb8d15 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Mon, 9 Dec 2024 14:48:04 -0500 Subject: [PATCH 01/51] NLua -> KeraLua; literally nothing compiles --- Directory.Packages.props | 2 +- benchmark/BDN.benchmark/BDN.benchmark.csproj | 2 +- libs/host/Garnet.host.csproj | 2 +- libs/server/ArgSlice/ScratchBufferManager.cs | 1 - libs/server/Garnet.server.csproj | 2 +- libs/server/Lua/LuaCommands.cs | 3 --- libs/server/Lua/LuaRunner.cs | 1 - test/Garnet.test/LuaScriptRunnerTests.cs | 1 - 8 files changed, 4 insertions(+), 10 deletions(-) diff --git a/Directory.Packages.props b/Directory.Packages.props index d674d273f0..0719a40ed8 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -8,7 +8,7 @@ - + diff --git a/benchmark/BDN.benchmark/BDN.benchmark.csproj b/benchmark/BDN.benchmark/BDN.benchmark.csproj index 60753023ab..525b4edc72 100644 --- a/benchmark/BDN.benchmark/BDN.benchmark.csproj +++ b/benchmark/BDN.benchmark/BDN.benchmark.csproj @@ -10,7 +10,7 @@ - + diff --git a/libs/host/Garnet.host.csproj b/libs/host/Garnet.host.csproj index a4d034dd85..1c1259fec3 100644 --- a/libs/host/Garnet.host.csproj +++ b/libs/host/Garnet.host.csproj @@ -29,7 +29,7 @@ - + diff --git a/libs/server/ArgSlice/ScratchBufferManager.cs b/libs/server/ArgSlice/ScratchBufferManager.cs index 68861d3985..416dc6d560 100644 --- a/libs/server/ArgSlice/ScratchBufferManager.cs +++ b/libs/server/ArgSlice/ScratchBufferManager.cs @@ -7,7 +7,6 @@ using System.Runtime.CompilerServices; using System.Text; using Garnet.common; -using NLua; namespace Garnet.server { diff --git a/libs/server/Garnet.server.csproj b/libs/server/Garnet.server.csproj index e3abc5b5c2..04f91c2909 100644 --- a/libs/server/Garnet.server.csproj +++ b/libs/server/Garnet.server.csproj @@ -21,7 +21,7 @@ - + \ No newline at end of file diff --git a/libs/server/Lua/LuaCommands.cs b/libs/server/Lua/LuaCommands.cs index f306262719..3163aa6ba5 100644 --- a/libs/server/Lua/LuaCommands.cs +++ b/libs/server/Lua/LuaCommands.cs @@ -2,11 +2,8 @@ // Licensed under the MIT license. using System; -using System.Buffers; using Garnet.common; using Microsoft.Extensions.Logging; -using NLua; -using NLua.Exceptions; using Tsavorite.core; namespace Garnet.server diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index ea2ef3f0dd..7c8dca9ff5 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -6,7 +6,6 @@ using System.Text; using Garnet.common; using Microsoft.Extensions.Logging; -using NLua; namespace Garnet.server { diff --git a/test/Garnet.test/LuaScriptRunnerTests.cs b/test/Garnet.test/LuaScriptRunnerTests.cs index a9e07c0601..252296c884 100644 --- a/test/Garnet.test/LuaScriptRunnerTests.cs +++ b/test/Garnet.test/LuaScriptRunnerTests.cs @@ -3,7 +3,6 @@ using Garnet.common; using Garnet.server; -using NLua.Exceptions; using NUnit.Framework; using NUnit.Framework.Legacy; From ced705e5323cb68127d6a50a7f03cc1424b7aac4 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Mon, 9 Dec 2024 15:51:58 -0500 Subject: [PATCH 02/51] blind idiot translation into KeraLua; nothing works --- libs/server/ArgSlice/ScratchBufferManager.cs | 182 +++---- libs/server/Lua/LuaCommands.cs | 106 ++-- libs/server/Lua/LuaRunner.cs | 492 ++++++++++++++----- test/Garnet.test/LuaScriptRunnerTests.cs | 12 +- 4 files changed, 510 insertions(+), 282 deletions(-) diff --git a/libs/server/ArgSlice/ScratchBufferManager.cs b/libs/server/ArgSlice/ScratchBufferManager.cs index 416dc6d560..69451e7c78 100644 --- a/libs/server/ArgSlice/ScratchBufferManager.cs +++ b/libs/server/ArgSlice/ScratchBufferManager.cs @@ -216,97 +216,97 @@ public ArgSlice FormatScratch(int headerSize, ReadOnlySpan arg) return retVal; } - /// - /// Format specified command with arguments, as a RESP command. Lua state - /// can be specified to handle Lua tables as arguments. - /// - public ArgSlice FormatCommandAsResp(string cmd, object[] args, Lua state) - { - if (scratchBuffer == null) - ExpandScratchBuffer(64); - - scratchBufferOffset += 10; // Reserve space for the array length if it is larger than expected - int commandStartOffset = scratchBufferOffset; - byte* ptr = scratchBufferHead + scratchBufferOffset; - - while (!RespWriteUtils.WriteArrayLength(args.Length + 1, ref ptr, scratchBufferHead + scratchBuffer.Length)) - { - ExpandScratchBuffer(scratchBuffer.Length + 1); - ptr = scratchBufferHead + scratchBufferOffset; - } - scratchBufferOffset = (int)(ptr - scratchBufferHead); - - while (!RespWriteUtils.WriteAsciiBulkString(cmd, ref ptr, scratchBufferHead + scratchBuffer.Length)) - { - ExpandScratchBuffer(scratchBuffer.Length + 1); - ptr = scratchBufferHead + scratchBufferOffset; - } - scratchBufferOffset = (int)(ptr - scratchBufferHead); - - int count = 1; - foreach (var item in args) - { - if (item is string str) - { - count++; - while (!RespWriteUtils.WriteAsciiBulkString(str, ref ptr, scratchBufferHead + scratchBuffer.Length)) - { - ExpandScratchBuffer(scratchBuffer.Length + 1); - ptr = scratchBufferHead + scratchBufferOffset; - } - scratchBufferOffset = (int)(ptr - scratchBufferHead); - } - else if (item is LuaTable t) - { - var d = state.GetTableDict(t); - foreach (var value in d.Values) - { - count++; - while (!RespWriteUtils.WriteAsciiBulkString(Convert.ToString(value), ref ptr, scratchBufferHead + scratchBuffer.Length)) - { - ExpandScratchBuffer(scratchBuffer.Length + 1); - ptr = scratchBufferHead + scratchBufferOffset; - } - scratchBufferOffset = (int)(ptr - scratchBufferHead); - } - t.Dispose(); - } - else if (item is long i) - { - count++; - while (!RespWriteUtils.WriteIntegerAsBulkString((int)i, ref ptr, scratchBufferHead + scratchBuffer.Length)) - { - ExpandScratchBuffer(scratchBuffer.Length + 1); - ptr = scratchBufferHead + scratchBufferOffset; - } - scratchBufferOffset = (int)(ptr - scratchBufferHead); - } - else - { - count++; - while (!RespWriteUtils.WriteAsciiBulkString(Convert.ToString(item), ref ptr, scratchBufferHead + scratchBuffer.Length)) - { - ExpandScratchBuffer(scratchBuffer.Length + 1); - ptr = scratchBufferHead + scratchBufferOffset; - } - scratchBufferOffset = (int)(ptr - scratchBufferHead); - } - } - if (count != args.Length + 1) - { - int extraSpace = NumUtils.NumDigits(count) - NumUtils.NumDigits(args.Length + 1); - if (commandStartOffset < extraSpace) - throw new InvalidOperationException("Invalid number of arguments"); - commandStartOffset -= extraSpace; - var head = scratchBufferHead + commandStartOffset; - // There should be space as we have reserved it - _ = RespWriteUtils.WriteArrayLength(count, ref head, scratchBufferHead + scratchBuffer.Length); - } - - var retVal = new ArgSlice(scratchBufferHead + commandStartOffset, scratchBufferOffset - commandStartOffset); - Debug.Assert(scratchBufferOffset <= scratchBuffer.Length); - return retVal; - } + ///// + ///// Format specified command with arguments, as a RESP command. Lua state + ///// can be specified to handle Lua tables as arguments. + ///// + //public ArgSlice FormatCommandAsResp(string cmd, object[] args, Lua state) + //{ + // if (scratchBuffer == null) + // ExpandScratchBuffer(64); + + // scratchBufferOffset += 10; // Reserve space for the array length if it is larger than expected + // int commandStartOffset = scratchBufferOffset; + // byte* ptr = scratchBufferHead + scratchBufferOffset; + + // while (!RespWriteUtils.WriteArrayLength(args.Length + 1, ref ptr, scratchBufferHead + scratchBuffer.Length)) + // { + // ExpandScratchBuffer(scratchBuffer.Length + 1); + // ptr = scratchBufferHead + scratchBufferOffset; + // } + // scratchBufferOffset = (int)(ptr - scratchBufferHead); + + // while (!RespWriteUtils.WriteAsciiBulkString(cmd, ref ptr, scratchBufferHead + scratchBuffer.Length)) + // { + // ExpandScratchBuffer(scratchBuffer.Length + 1); + // ptr = scratchBufferHead + scratchBufferOffset; + // } + // scratchBufferOffset = (int)(ptr - scratchBufferHead); + + // int count = 1; + // foreach (var item in args) + // { + // if (item is string str) + // { + // count++; + // while (!RespWriteUtils.WriteAsciiBulkString(str, ref ptr, scratchBufferHead + scratchBuffer.Length)) + // { + // ExpandScratchBuffer(scratchBuffer.Length + 1); + // ptr = scratchBufferHead + scratchBufferOffset; + // } + // scratchBufferOffset = (int)(ptr - scratchBufferHead); + // } + // else if (item is LuaTable t) + // { + // var d = state.GetTableDict(t); + // foreach (var value in d.Values) + // { + // count++; + // while (!RespWriteUtils.WriteAsciiBulkString(Convert.ToString(value), ref ptr, scratchBufferHead + scratchBuffer.Length)) + // { + // ExpandScratchBuffer(scratchBuffer.Length + 1); + // ptr = scratchBufferHead + scratchBufferOffset; + // } + // scratchBufferOffset = (int)(ptr - scratchBufferHead); + // } + // t.Dispose(); + // } + // else if (item is long i) + // { + // count++; + // while (!RespWriteUtils.WriteIntegerAsBulkString((int)i, ref ptr, scratchBufferHead + scratchBuffer.Length)) + // { + // ExpandScratchBuffer(scratchBuffer.Length + 1); + // ptr = scratchBufferHead + scratchBufferOffset; + // } + // scratchBufferOffset = (int)(ptr - scratchBufferHead); + // } + // else + // { + // count++; + // while (!RespWriteUtils.WriteAsciiBulkString(Convert.ToString(item), ref ptr, scratchBufferHead + scratchBuffer.Length)) + // { + // ExpandScratchBuffer(scratchBuffer.Length + 1); + // ptr = scratchBufferHead + scratchBufferOffset; + // } + // scratchBufferOffset = (int)(ptr - scratchBufferHead); + // } + // } + // if (count != args.Length + 1) + // { + // int extraSpace = NumUtils.NumDigits(count) - NumUtils.NumDigits(args.Length + 1); + // if (commandStartOffset < extraSpace) + // throw new InvalidOperationException("Invalid number of arguments"); + // commandStartOffset -= extraSpace; + // var head = scratchBufferHead + commandStartOffset; + // // There should be space as we have reserved it + // _ = RespWriteUtils.WriteArrayLength(count, ref head, scratchBufferHead + scratchBuffer.Length); + // } + + // var retVal = new ArgSlice(scratchBufferHead + commandStartOffset, scratchBufferOffset - commandStartOffset); + // Debug.Assert(scratchBufferOffset <= scratchBuffer.Length); + // return retVal; + //} /// /// Get length of a RESP Bulk-String formatted version of the specified ArgSlice diff --git a/libs/server/Lua/LuaCommands.cs b/libs/server/Lua/LuaCommands.cs index 3163aa6ba5..c60fdb61c4 100644 --- a/libs/server/Lua/LuaCommands.cs +++ b/libs/server/Lua/LuaCommands.cs @@ -250,13 +250,6 @@ private unsafe bool ExecuteScript(int count, LuaRunner scriptRunner) var scriptResult = scriptRunner.Run(count, parseState); WriteObject(scriptResult); } - catch (LuaScriptException ex) - { - logger?.LogError(ex.InnerException ?? ex, "Error executing Lua script callback"); - while (!RespWriteUtils.WriteError("ERR " + (ex.InnerException ?? ex).Message, ref dcurr, dend)) - SendAndReset(); - return true; - } catch (Exception ex) { logger?.LogError(ex, "Error executing Lua script"); @@ -299,56 +292,61 @@ void WriteObject(object scriptResult) while (!RespWriteUtils.WriteInteger(l, ref dcurr, dend)) SendAndReset(); } - else if (scriptResult is ArgSlice a) - { - while (!RespWriteUtils.WriteBulkString(a.ReadOnlySpan, ref dcurr, dend)) - SendAndReset(); - } - else if (scriptResult is object[] o) - { - // Two objects one boolean value and the result from the Lua Call - while (!RespWriteUtils.WriteAsciiBulkString(o[1].ToString().AsSpan(), ref dcurr, dend)) - SendAndReset(); - } - else if (scriptResult is LuaTable luaTable) - { - try - { - var retVal = luaTable["err"]; - if (retVal != null) - { - while (!RespWriteUtils.WriteError((string)retVal, ref dcurr, dend)) - SendAndReset(); - } - else - { - retVal = luaTable["ok"]; - if (retVal != null) - { - while (!RespWriteUtils.WriteAsciiBulkString((string)retVal, ref dcurr, dend)) - SendAndReset(); - } - else - { - int count = luaTable.Values.Count; - while (!RespWriteUtils.WriteArrayLength(count, ref dcurr, dend)) - SendAndReset(); - foreach (var value in luaTable.Values) - { - WriteObject(value); - } - } - } - } - finally - { - luaTable.Dispose(); - } - } else { - throw new LuaScriptException("Unknown return type", ""); + // todo: this should all go away + throw new NotImplementedException(); } + //else if (scriptResult is ArgSlice a) + //{ + // while (!RespWriteUtils.WriteBulkString(a.ReadOnlySpan, ref dcurr, dend)) + // SendAndReset(); + //} + //else if (scriptResult is object[] o) + //{ + // // Two objects one boolean value and the result from the Lua Call + // while (!RespWriteUtils.WriteAsciiBulkString(o[1].ToString().AsSpan(), ref dcurr, dend)) + // SendAndReset(); + //} + //else if (scriptResult is LuaTable luaTable) + //{ + // try + // { + // var retVal = luaTable["err"]; + // if (retVal != null) + // { + // while (!RespWriteUtils.WriteError((string)retVal, ref dcurr, dend)) + // SendAndReset(); + // } + // else + // { + // retVal = luaTable["ok"]; + // if (retVal != null) + // { + // while (!RespWriteUtils.WriteAsciiBulkString((string)retVal, ref dcurr, dend)) + // SendAndReset(); + // } + // else + // { + // int count = luaTable.Values.Count; + // while (!RespWriteUtils.WriteArrayLength(count, ref dcurr, dend)) + // SendAndReset(); + // foreach (var value in luaTable.Values) + // { + // WriteObject(value); + // } + // } + // } + // } + // finally + // { + // luaTable.Dispose(); + // } + //} + //else + //{ + // throw new LuaScriptException("Unknown return type", ""); + //} } else { diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 7c8dca9ff5..cc676a67e3 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -3,8 +3,10 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Text; using Garnet.common; +using KeraLua; using Microsoft.Extensions.Logging; namespace Garnet.server @@ -14,18 +16,25 @@ namespace Garnet.server /// internal sealed class LuaRunner : IDisposable { + // rooted to keep function pointer alive + readonly LuaFunction garnetCall; + + // references into Registry on the Lua side + readonly int sandboxEnvRegistryIndex; + readonly int keysTableRegistryIndex; + readonly int argvTableRegistryIndex; + readonly int loadSandboxedRegistryIndex; + int functionRegistryIndex; + readonly string source; readonly ScratchBufferNetworkSender scratchBufferNetworkSender; readonly RespServerSession respServerSession; readonly ScratchBufferManager scratchBufferManager; readonly ILogger logger; readonly Lua state; - readonly LuaTable sandbox_env; - LuaFunction function; readonly TxnKeyEntries txnKeyEntries; readonly bool txnMode; - readonly LuaFunction garnetCall; - readonly LuaTable keyTable, argvTable; + int keyLength, argvLength; Queue disposeQueue; @@ -41,18 +50,28 @@ public LuaRunner(string source, bool txnMode = false, RespServerSession respServ this.scratchBufferManager = respServerSession?.scratchBufferManager; this.logger = logger; + sandboxEnvRegistryIndex = -1; + keysTableRegistryIndex = -1; + argvTableRegistryIndex = -1; + loadSandboxedRegistryIndex = -1; + functionRegistryIndex = -1; + + // todo: custom allocator? state = new Lua(); - state.State.Encoding = Encoding.UTF8; + Debug.Assert(state.GetTop() == 0, "Stack should be empty at allocation"); + if (txnMode) { this.txnKeyEntries = new TxnKeyEntries(16, respServerSession.storageSession.lockableContext, respServerSession.storageSession.objectStoreLockableContext); - garnetCall = state.RegisterFunction("garnet_call", this, this.GetType().GetMethod(nameof(garnet_call_txn))); + + garnetCall = garnet_call_txn; } else { - garnetCall = state.RegisterFunction("garnet_call", this, this.GetType().GetMethod("garnet_call")); + garnetCall = garnet_call; } - _ = state.DoString(@" + + var sandboxRes = state.DoString(@" import = function () end redis = {} function redis.call(cmd, ...) @@ -94,9 +113,31 @@ function load_sandboxed(source) return load(source, nil, nil, sandbox_env) end "); - sandbox_env = (LuaTable)state["sandbox_env"]; - keyTable = (LuaTable)state["KEYS"]; - argvTable = (LuaTable)state["ARGV"]; + if (!sandboxRes) + { + throw new GarnetException("Could not initialize Lua sandbox state"); + } + + // register garnet_call in global namespace + state.Register("garnet_call", garnetCall); + + var sandboxEnvType = state.GetGlobal("sandbox_env"); + Debug.Assert(sandboxEnvType == LuaType.Table, "Unexpected sandbox_env type"); + sandboxEnvRegistryIndex = state.Ref(LuaRegistry.Index); + + var keyTableType = state.GetGlobal("KEYS"); + Debug.Assert(keyTableType == LuaType.Table, "Unexpected KEYS type"); + keysTableRegistryIndex = state.Ref(LuaRegistry.Index); + + var argvTableType = state.GetGlobal("ARGV"); + Debug.Assert(argvTableType == LuaType.Table, "Unexpected ARGV type"); + argvTableRegistryIndex = state.Ref(LuaRegistry.Index); + + var loadSandboxedType = state.GetGlobal("load_sandboxed"); + Debug.Assert(loadSandboxedType == LuaType.Function, "Unexpected load_sandboxed type"); + loadSandboxedRegistryIndex = state.Ref(LuaRegistry.Index); + + Debug.Assert(state.GetTop() == 0, "Stack should be empty after initialization"); } /// @@ -112,23 +153,46 @@ public LuaRunner(ReadOnlySpan source, bool txnMode, RespServerSession resp /// public void Compile() { + Debug.Assert(functionRegistryIndex == -1, "Shouldn't compile multiple times"); + try { - using var loader = (LuaFunction)state["load_sandboxed"]; - var result = loader.Call(source); - if (result?.Length == 1) + if (!state.CheckStack(2)) { - function = result[0] as LuaFunction; - return; + throw new GarnetException("Insufficient stack space to compile function"); } - if (result?.Length == 2) + Debug.Assert(state.GetTop() == 0, "Stack should be empty before compilation"); + + state.PushNumber(loadSandboxedRegistryIndex); + state.GetTable(LuaRegistry.Index); + state.PushString(source); + state.Call(1, -1); // multiple returns allowed + + var numRets = state.GetTop(); + if (numRets == 0) + { + throw new GarnetException("Shouldn't happen, no returns from load_sandboxed"); + } + else if (numRets == 1) + { + var returnType = state.Type(1); + if (returnType != LuaType.Function) + { + throw new GarnetException($"Could not compile function, got back a {returnType}"); + } + + functionRegistryIndex = state.Ref(LuaRegistry.Index); + } + else if (numRets == 2) { - throw new GarnetException($"Compilation error: {(string)result[1]}"); + var error = state.CheckString(2); + + throw new GarnetException($"Compilation error: {error}"); } else { - throw new GarnetException($"Unable to load script"); + throw new GarnetException($"Unexpected error compiling, got too many replies back: reply count = {numRets}"); } } catch (Exception ex) @@ -136,6 +200,10 @@ public void Compile() logger?.LogError(ex, "CreateFunction threw an exception"); throw; } + finally + { + Debug.Assert(state.GetTop() == 0, "Stack should be empty after compilation"); + } } /// @@ -143,130 +211,207 @@ public void Compile() /// public void Dispose() { - garnetCall?.Dispose(); - keyTable?.Dispose(); - argvTable?.Dispose(); - sandbox_env?.Dispose(); - function?.Dispose(); state?.Dispose(); } /// /// Entry point for redis.call method from a Lua script (non-transactional mode) /// - /// - /// Parameters - /// - public object garnet_call(string cmd, params object[] args) - => respServerSession == null ? null : ProcessCommandFromScripting(respServerSession.basicGarnetApi, cmd, args); + public int garnet_call(IntPtr luaStatePtr) + { + Debug.Assert(state.Handle == luaStatePtr, "Unexpected state provided in call"); + + if (respServerSession == null) + { + return NoSessionError(); + } + + return ProcessCommandFromScripting(respServerSession.basicGarnetApi); + } /// /// Entry point for redis.call method from a Lua script (transactional mode) /// - /// - /// Parameters - /// - public object garnet_call_txn(string cmd, params object[] args) - => respServerSession == null ? null : ProcessCommandFromScripting(respServerSession.lockableGarnetApi, cmd, args); + public int garnet_call_txn(IntPtr luaStatePtr) + { + Debug.Assert(state.Handle == luaStatePtr, "Unexpected state provided in call"); + + if (respServerSession == null) + { + return NoSessionError(); + } + + return ProcessCommandFromScripting(respServerSession.lockableGarnetApi); + } /// - /// Entry point method for executing commands from a Lua Script + /// Call somehow came in with no valid resp server session. + /// + /// Raise an error. /// - unsafe object ProcessCommandFromScripting(TGarnetApi api, string cmd, params object[] args) - where TGarnetApi : IGarnetApi + /// + int NoSessionError() { - switch (cmd) - { - // We special-case a few performance-sensitive operations to directly invoke via the storage API - case "SET" when args.Length == 2: - case "set" when args.Length == 2: - { - if (!respServerSession.CheckACLPermissions(RespCommand.SET)) - return Encoding.ASCII.GetString(CmdStrings.RESP_ERR_NOAUTH); - var key = scratchBufferManager.CreateArgSlice(Convert.ToString(args[0])); - var value = scratchBufferManager.CreateArgSlice(Convert.ToString(args[1])); - _ = api.SET(key, value); - return "OK"; - } - case "GET": - case "get": - { - if (!respServerSession.CheckACLPermissions(RespCommand.GET)) - throw new Exception(Encoding.ASCII.GetString(CmdStrings.RESP_ERR_NOAUTH)); - var key = scratchBufferManager.CreateArgSlice(Convert.ToString(args[0])); - var status = api.GET(key, out var value); - if (status == GarnetStatus.OK) - return value.ToString(); - return null; - } - // As fallback, we use RespServerSession with a RESP-formatted input. This could be optimized - // in future to provide parse state directly. - default: - { - var request = scratchBufferManager.FormatCommandAsResp(cmd, args, state); - _ = respServerSession.TryConsumeMessages(request.ptr, request.length); - var response = scratchBufferNetworkSender.GetResponse(); - var result = ProcessResponse(response.ptr, response.length); - scratchBufferNetworkSender.Reset(); - return result; - } - } + logger?.LogError("Lua call came in without a valid resp session"); + + state.PushString("No session available"); + + // this will never return, but we can pretend it does + return state.Error(); } /// - /// Process a RESP-formatted response from the RespServerSession + /// Entry point method for executing commands from a Lua Script /// - unsafe object ProcessResponse(byte* ptr, int length) + unsafe int ProcessCommandFromScripting(TGarnetApi api) + where TGarnetApi : IGarnetApi { - switch (*ptr) + try { - case (byte)'+': - if (RespReadUtils.ReadSimpleString(out var resultStr, ref ptr, ptr + length)) - return resultStr; - break; - case (byte)':': - if (RespReadUtils.Read64Int(out var number, ref ptr, ptr + length)) - return number; - break; - case (byte)'-': - if (RespReadUtils.ReadErrorAsString(out resultStr, ref ptr, ptr + length)) - return resultStr; - break; - - case (byte)'$': - if (RespReadUtils.ReadStringResponseWithLengthHeader(out resultStr, ref ptr, ptr + length)) - return resultStr; - break; - - case (byte)'*': - if (RespReadUtils.ReadStringArrayResponseWithLengthHeader(out var resultArray, ref ptr, ptr + length)) - { - // Create return table - var returnValue = (LuaTable)state.DoString("return { }")[0]; - - // Queue up for disposal at the end of the script call - disposeQueue ??= new(); - disposeQueue.Enqueue(returnValue); - - // Populate the table - var i = 1; - foreach (var item in resultArray) - returnValue[i++] = item == null ? false : item; - return returnValue; - } - break; + var argCount = state.GetTop(); + + if (argCount == 0) + { + return state.Error("Please specify at least one argument for this redis lib call script"); + } + + // todo: no alloc + var cmd = state.CheckString(0).ToUpperInvariant(); + + switch (cmd) + { + // We special-case a few performance-sensitive operations to directly invoke via the storage API + case "SET" when argCount == 3: + { + if (!respServerSession.CheckACLPermissions(RespCommand.SET)) + { + // todo: no alloc + return state.Error(Encoding.UTF8.GetString(CmdStrings.RESP_ERR_NOAUTH)); + } + + // todo: no alloc + var keyBuf = state.CheckBuffer(1); + var valBuf = state.CheckBuffer(2); + + var key = scratchBufferManager.CreateArgSlice(keyBuf); + var value = scratchBufferManager.CreateArgSlice(valBuf); + _ = api.SET(key, value); - default: - throw new Exception("Unexpected response: " + Encoding.UTF8.GetString(new Span(ptr, length)).Replace("\n", "|").Replace("\r", "") + "]"); + state.PushString("OK"); + return 1; + } + case "GET" when argCount == 2: + { + if (!respServerSession.CheckACLPermissions(RespCommand.GET)) + { + // todo: no alloc + return state.Error(Encoding.UTF8.GetString(CmdStrings.RESP_ERR_NOAUTH)); + } + + // todo: no alloc + var keyBuf = state.CheckBuffer(1); + + var key = scratchBufferManager.CreateArgSlice(keyBuf); + var status = api.GET(key, out var value); + if (status == GarnetStatus.OK) + { + // todo: no alloc + state.PushBuffer(value.ToArray()); + } + else + { + state.PushNil(); + } + + return 1; + } + + // todo: implement + default: throw new NotImplementedException(); + + //// As fallback, we use RespServerSession with a RESP-formatted input. This could be optimized + //// in future to provide parse state directly. + //default: + // { + // var request = scratchBufferManager.FormatCommandAsResp(cmd, args, state); + // _ = respServerSession.TryConsumeMessages(request.ptr, request.length); + // var response = scratchBufferNetworkSender.GetResponse(); + // var result = ProcessResponse(response.ptr, response.length); + // scratchBufferNetworkSender.Reset(); + // return result; + // } + } + } + catch (Exception e) + { + logger?.LogError(e, "During Lua script execution"); + + state.PushString(e.Message); + return state.Error(); } - return null; } + ///// + ///// Process a RESP-formatted response from the RespServerSession + ///// + //unsafe object ProcessResponse(byte* ptr, int length) + //{ + // switch (*ptr) + // { + // case (byte)'+': + // if (RespReadUtils.ReadSimpleString(out var resultStr, ref ptr, ptr + length)) + // return resultStr; + // break; + // case (byte)':': + // if (RespReadUtils.Read64Int(out var number, ref ptr, ptr + length)) + // return number; + // break; + // case (byte)'-': + // if (RespReadUtils.ReadErrorAsString(out resultStr, ref ptr, ptr + length)) + // return resultStr; + // break; + + // case (byte)'$': + // if (RespReadUtils.ReadStringResponseWithLengthHeader(out resultStr, ref ptr, ptr + length)) + // return resultStr; + // break; + + // case (byte)'*': + // if (RespReadUtils.ReadStringArrayResponseWithLengthHeader(out var resultArray, ref ptr, ptr + length)) + // { + // // Create return table + // var returnValue = (LuaTable)state.DoString("return { }")[0]; + + // // Queue up for disposal at the end of the script call + // disposeQueue ??= new(); + // disposeQueue.Enqueue(returnValue); + + // // Populate the table + // var i = 1; + // foreach (var item in resultArray) + // returnValue[i++] = item == null ? false : item; + // return returnValue; + // } + // break; + + // default: + // throw new Exception("Unexpected response: " + Encoding.UTF8.GetString(new Span(ptr, length)).Replace("\n", "|").Replace("\r", "") + "]"); + // } + // return null; + //} + /// /// Runs the precompiled Lua function with specified parse state /// public object Run(int count, SessionParseState parseState) { + Debug.Assert(state.GetTop() == 0, "Stack should be empty at invocation start"); + + if (!state.CheckStack(2)) + { + throw new GarnetException("Insufficient stack space to run script"); + } + scratchBufferManager.Reset(); int offset = 1; @@ -278,32 +423,47 @@ public object Run(int count, SessionParseState parseState) { for (int i = 0; i < nKeys; i++) { + ref var key = ref parseState.GetArgSliceByRef(offset); + if (txnMode) { - var key = parseState.GetArgSliceByRef(offset); txnKeyEntries.AddKey(key, false, Tsavorite.core.LockType.Exclusive); if (!respServerSession.storageSession.objectStoreLockableContext.IsNull) txnKeyEntries.AddKey(key, true, Tsavorite.core.LockType.Exclusive); } - keyTable[i + 1] = parseState.GetString(offset++); + + // todo: no alloc + // todo: encoding is wrong here + + // equivalent to KEYS[i+1] = key.ToString() + state.PushNumber(i + 1); + state.PushString(key.ToString()); + state.SetTable(keysTableRegistryIndex); + + offset++; } - count -= nKeys; - //TODO: handle slot verification for Lua script keys - //if (NetworkKeyArraySlotVerify(keys, true)) - //{ - // return true; - //} + count -= nKeys; } if (count > 0) { for (int i = 0; i < count; i++) { - argvTable[i + 1] = parseState.GetString(offset++); + // todo: no alloc + // todo encoding is wrong here + + // equivalent to ARGV[i+1] = parseState.GetString(offset); + state.PushNumber(i + 1); + state.PushString(parseState.GetString(offset)); + state.SetTable(argvTableRegistryIndex); + + offset++; } } + Debug.Assert(state.GetTop() == 0, "Stack should be empty before running function"); + if (txnMode && nKeys > 0) { return RunTransaction(); @@ -364,40 +524,110 @@ void ResetParameters(int nKeys, int nArgs) { if (keyLength > nKeys) { - _ = state.DoString($"count = #KEYS for i={nKeys + 1}, {keyLength} do KEYS[i]=nil end"); + var keyResetRes = state.DoString($"count = #KEYS for i={nKeys + 1}, {keyLength} do KEYS[i]=nil end"); + + if (keyResetRes) + { + throw new GarnetException("Couldn't reset KEYS to run script"); + } } + keyLength = nKeys; + if (argvLength > nArgs) { - _ = state.DoString($"count = #ARGV for i={nArgs + 1}, {argvLength} do ARGV[i]=nil end"); + var argvResetRes = state.DoString($"count = #ARGV for i={nArgs + 1}, {argvLength} do ARGV[i]=nil end"); + + if (argvResetRes) + { + throw new GarnetException("Couldn't reset ARGV to run script"); + } } + argvLength = nArgs; } void LoadParameters(string[] keys, string[] argv) { + Debug.Assert(state.GetTop() == 0, "Stack should be empty before invocation starts"); + + if (!state.CheckStack(2)) + { + throw new GarnetException("Insufficient stack space to call function"); + } + ResetParameters(keys?.Length ?? 0, argv?.Length ?? 0); if (keys != null) { for (int i = 0; i < keys.Length; i++) - keyTable[i + 1] = keys[i]; + { + // equivalent to KEYS[i+1] = keys[i] + state.PushNumber(i + 1); + state.PushString(keys[i]); + state.SetTable(keysTableRegistryIndex); + } } if (argv != null) { for (int i = 0; i < argv.Length; i++) - argvTable[i + 1] = argv[i]; + { + // equivalent to ARGV[i+1] = keys[i] + state.PushNumber(i + 1); + state.PushString(argv[i]); + state.SetTable(argvTableRegistryIndex); + } } } /// /// Runs the precompiled Lua function /// - /// object Run() { - var result = function.Call(); - Cleanup(); - return result?.Length > 0 ? result[0] : null; + // todo: this shouldn't read the result, it should write the response out + + Debug.Assert(state.GetTop() == 0, "Stack should be empty at start of invocation"); + + if (!state.CheckStack(1)) + { + throw new GarnetException("Insufficient stack space to run function"); + } + + try + { + state.PushNumber(functionRegistryIndex); + state.GetTable(LuaRegistry.Index); + state.Call(0, 1); + + if (state.GetTop() == 0) + { + return null; + } + + var retType = state.Type(1); + if (retType == LuaType.Nil) + { + return null; + } + else if (retType == LuaType.Number) + { + return state.CheckNumber(1); + } + else if (retType == LuaType.String) + { + return state.CheckString(1); + } + else + { + // todo: implement + throw new NotImplementedException(); + } + } + finally + { + // FORCE the stack to be empty now + state.SetTop(0); + } } void Cleanup() diff --git a/test/Garnet.test/LuaScriptRunnerTests.cs b/test/Garnet.test/LuaScriptRunnerTests.cs index 252296c884..a1f5c945d1 100644 --- a/test/Garnet.test/LuaScriptRunnerTests.cs +++ b/test/Garnet.test/LuaScriptRunnerTests.cs @@ -18,7 +18,7 @@ public void CannotRunUnsafeScript() using (var runner = new LuaRunner("luanet.load_assembly('mscorlib')")) { runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + var ex = Assert.Throws(() => runner.Run()); ClassicAssert.AreEqual("[string \"luanet.load_assembly('mscorlib')\"]:1: attempt to index a nil value (global 'luanet')", ex.Message); } @@ -26,7 +26,7 @@ public void CannotRunUnsafeScript() using (var runner = new LuaRunner("os = require('os'); return os.time();")) { runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + var ex = Assert.Throws(() => runner.Run()); ClassicAssert.AreEqual("[string \"os = require('os'); return os.time();\"]:1: attempt to call a nil value (global 'require')", ex.Message); } @@ -34,7 +34,7 @@ public void CannotRunUnsafeScript() using (var runner = new LuaRunner("dofile();")) { runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + var ex = Assert.Throws(() => runner.Run()); ClassicAssert.AreEqual("[string \"dofile();\"]:1: attempt to call a nil value (global 'dofile')", ex.Message); } @@ -42,7 +42,7 @@ public void CannotRunUnsafeScript() using (var runner = new LuaRunner("require \"notepad\"")) { runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + var ex = Assert.Throws(() => runner.Run()); ClassicAssert.AreEqual("[string \"require \"notepad\"\"]:1: attempt to call a nil value (global 'require')", ex.Message); } @@ -50,7 +50,7 @@ public void CannotRunUnsafeScript() using (var runner = new LuaRunner("os.exit();")) { runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + var ex = Assert.Throws(() => runner.Run()); ClassicAssert.AreEqual("[string \"os.exit();\"]:1: attempt to index a nil value (global 'os')", ex.Message); } @@ -58,7 +58,7 @@ public void CannotRunUnsafeScript() using (var runner = new LuaRunner("import ('System.Diagnostics');")) { runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + var ex = Assert.Throws(() => runner.Run()); ClassicAssert.AreEqual("[string \"import ('System.Diagnostics');\"]:1: attempt to call a nil value (global 'import')", ex.Message); } } From d9df977873ab7470d9080f45f3d67088306c76b8 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Mon, 9 Dec 2024 16:10:27 -0500 Subject: [PATCH 03/51] very basic functionality restored --- libs/server/Lua/LuaRunner.cs | 36 +++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index cc676a67e3..f05aa94623 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -113,7 +113,7 @@ function load_sandboxed(source) return load(source, nil, nil, sandbox_env) end "); - if (!sandboxRes) + if (sandboxRes) { throw new GarnetException("Could not initialize Lua sandbox state"); } @@ -276,7 +276,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) } // todo: no alloc - var cmd = state.CheckString(0).ToUpperInvariant(); + var cmd = state.CheckString(1).ToUpperInvariant(); switch (cmd) { @@ -290,8 +290,8 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) } // todo: no alloc - var keyBuf = state.CheckBuffer(1); - var valBuf = state.CheckBuffer(2); + var keyBuf = state.CheckBuffer(2); + var valBuf = state.CheckBuffer(3); var key = scratchBufferManager.CreateArgSlice(keyBuf); var value = scratchBufferManager.CreateArgSlice(valBuf); @@ -309,7 +309,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) } // todo: no alloc - var keyBuf = state.CheckBuffer(1); + var keyBuf = state.CheckBuffer(2); var key = scratchBufferManager.CreateArgSlice(keyBuf); var status = api.GET(key, out var value); @@ -407,7 +407,7 @@ public object Run(int count, SessionParseState parseState) { Debug.Assert(state.GetTop() == 0, "Stack should be empty at invocation start"); - if (!state.CheckStack(2)) + if (!state.CheckStack(3)) { throw new GarnetException("Insufficient stack space to run script"); } @@ -421,7 +421,12 @@ public object Run(int count, SessionParseState parseState) if (nKeys > 0) { - for (int i = 0; i < nKeys; i++) + // get KEYS on the stack + state.PushNumber(keysTableRegistryIndex); + var loadedType = state.RawGet(LuaRegistry.Index); + Debug.Assert(loadedType == LuaType.Table, "Unexpected type loaded when expecting KEYS"); + + for (var i = 0; i < nKeys; i++) { ref var key = ref parseState.GetArgSliceByRef(offset); @@ -437,18 +442,25 @@ public object Run(int count, SessionParseState parseState) // equivalent to KEYS[i+1] = key.ToString() state.PushNumber(i + 1); - state.PushString(key.ToString()); - state.SetTable(keysTableRegistryIndex); + state.PushString(parseState.GetString(offset)); + state.RawSet(1); offset++; } + state.Pop(1); + count -= nKeys; } if (count > 0) { - for (int i = 0; i < count; i++) + // GET ARGV on the stack + state.PushNumber(argvTableRegistryIndex); + var loadedType = state.RawGet(LuaRegistry.Index); + Debug.Assert(loadedType == LuaType.Table, "Unexpected type loaded when expecting ARGV"); + + for (var i = 0; i < count; i++) { // todo: no alloc // todo encoding is wrong here @@ -456,10 +468,12 @@ public object Run(int count, SessionParseState parseState) // equivalent to ARGV[i+1] = parseState.GetString(offset); state.PushNumber(i + 1); state.PushString(parseState.GetString(offset)); - state.SetTable(argvTableRegistryIndex); + state.RawSet(1); offset++; } + + state.Pop(1); } Debug.Assert(state.GetTop() == 0, "Stack should be empty before running function"); From 61ba50deff90ca82a0b3074d73e1e9c73b43f756 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Mon, 9 Dec 2024 16:57:17 -0500 Subject: [PATCH 04/51] more complicated functionality sketched out; far from ideal, will need to upstream some KeraLua additions; also horribly broken --- libs/common/RespReadUtils.cs | 21 +- libs/server/ArgSlice/ScratchBufferManager.cs | 147 +++++------- libs/server/Lua/LuaRunner.cs | 232 ++++++++++++++----- 3 files changed, 240 insertions(+), 160 deletions(-) diff --git a/libs/common/RespReadUtils.cs b/libs/common/RespReadUtils.cs index 8b93be12e0..9df08a48fe 100644 --- a/libs/common/RespReadUtils.cs +++ b/libs/common/RespReadUtils.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. using System; +using System.Buffers; using System.Buffers.Text; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -958,12 +959,14 @@ public static bool ReadStringArrayWithLengthHeader(out string[] result, ref byte } /// - /// Read string array with length header + /// Read string array with length header. + /// + /// result will be backed by an empty array or one rented from the given pool upon return. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool ReadStringArrayResponseWithLengthHeader(out string[] result, ref byte* ptr, byte* end) + public static bool ReadRentedStringArrayResponseWithLengthHeader(ArrayPool pool, out Memory result, ref byte* ptr, byte* end) { - result = null; + result = Array.Empty(); // Parse RESP array header if (!ReadSignedArrayLength(out var length, ref ptr, end)) @@ -978,22 +981,26 @@ public static bool ReadStringArrayResponseWithLengthHeader(out string[] result, } // Parse individual strings in the array - result = new string[length]; + result = ArrayPool.Shared.Rent(length); + result = result[..length]; + + var resultSpan = result.Span; + for (var i = 0; i < length; i++) { if (*ptr == '$') { - if (!ReadStringResponseWithLengthHeader(out result[i], ref ptr, end)) + if (!ReadStringResponseWithLengthHeader(out resultSpan[i], ref ptr, end)) return false; } else if (*ptr == '+') { - if (!ReadSimpleString(out result[i], ref ptr, end)) + if (!ReadSimpleString(out resultSpan[i], ref ptr, end)) return false; } else { - if (!ReadIntegerAsString(out result[i], ref ptr, end)) + if (!ReadIntegerAsString(out resultSpan[i], ref ptr, end)) return false; } } diff --git a/libs/server/ArgSlice/ScratchBufferManager.cs b/libs/server/ArgSlice/ScratchBufferManager.cs index 69451e7c78..423f75efc7 100644 --- a/libs/server/ArgSlice/ScratchBufferManager.cs +++ b/libs/server/ArgSlice/ScratchBufferManager.cs @@ -216,97 +216,62 @@ public ArgSlice FormatScratch(int headerSize, ReadOnlySpan arg) return retVal; } - ///// - ///// Format specified command with arguments, as a RESP command. Lua state - ///// can be specified to handle Lua tables as arguments. - ///// - //public ArgSlice FormatCommandAsResp(string cmd, object[] args, Lua state) - //{ - // if (scratchBuffer == null) - // ExpandScratchBuffer(64); - - // scratchBufferOffset += 10; // Reserve space for the array length if it is larger than expected - // int commandStartOffset = scratchBufferOffset; - // byte* ptr = scratchBufferHead + scratchBufferOffset; - - // while (!RespWriteUtils.WriteArrayLength(args.Length + 1, ref ptr, scratchBufferHead + scratchBuffer.Length)) - // { - // ExpandScratchBuffer(scratchBuffer.Length + 1); - // ptr = scratchBufferHead + scratchBufferOffset; - // } - // scratchBufferOffset = (int)(ptr - scratchBufferHead); - - // while (!RespWriteUtils.WriteAsciiBulkString(cmd, ref ptr, scratchBufferHead + scratchBuffer.Length)) - // { - // ExpandScratchBuffer(scratchBuffer.Length + 1); - // ptr = scratchBufferHead + scratchBufferOffset; - // } - // scratchBufferOffset = (int)(ptr - scratchBufferHead); - - // int count = 1; - // foreach (var item in args) - // { - // if (item is string str) - // { - // count++; - // while (!RespWriteUtils.WriteAsciiBulkString(str, ref ptr, scratchBufferHead + scratchBuffer.Length)) - // { - // ExpandScratchBuffer(scratchBuffer.Length + 1); - // ptr = scratchBufferHead + scratchBufferOffset; - // } - // scratchBufferOffset = (int)(ptr - scratchBufferHead); - // } - // else if (item is LuaTable t) - // { - // var d = state.GetTableDict(t); - // foreach (var value in d.Values) - // { - // count++; - // while (!RespWriteUtils.WriteAsciiBulkString(Convert.ToString(value), ref ptr, scratchBufferHead + scratchBuffer.Length)) - // { - // ExpandScratchBuffer(scratchBuffer.Length + 1); - // ptr = scratchBufferHead + scratchBufferOffset; - // } - // scratchBufferOffset = (int)(ptr - scratchBufferHead); - // } - // t.Dispose(); - // } - // else if (item is long i) - // { - // count++; - // while (!RespWriteUtils.WriteIntegerAsBulkString((int)i, ref ptr, scratchBufferHead + scratchBuffer.Length)) - // { - // ExpandScratchBuffer(scratchBuffer.Length + 1); - // ptr = scratchBufferHead + scratchBufferOffset; - // } - // scratchBufferOffset = (int)(ptr - scratchBufferHead); - // } - // else - // { - // count++; - // while (!RespWriteUtils.WriteAsciiBulkString(Convert.ToString(item), ref ptr, scratchBufferHead + scratchBuffer.Length)) - // { - // ExpandScratchBuffer(scratchBuffer.Length + 1); - // ptr = scratchBufferHead + scratchBufferOffset; - // } - // scratchBufferOffset = (int)(ptr - scratchBufferHead); - // } - // } - // if (count != args.Length + 1) - // { - // int extraSpace = NumUtils.NumDigits(count) - NumUtils.NumDigits(args.Length + 1); - // if (commandStartOffset < extraSpace) - // throw new InvalidOperationException("Invalid number of arguments"); - // commandStartOffset -= extraSpace; - // var head = scratchBufferHead + commandStartOffset; - // // There should be space as we have reserved it - // _ = RespWriteUtils.WriteArrayLength(count, ref head, scratchBufferHead + scratchBuffer.Length); - // } - - // var retVal = new ArgSlice(scratchBufferHead + commandStartOffset, scratchBufferOffset - commandStartOffset); - // Debug.Assert(scratchBufferOffset <= scratchBuffer.Length); - // return retVal; - //} + /// + /// Format specified command with arguments, as a RESP command. Lua state + /// can be specified to handle Lua tables as arguments. + /// + public ArgSlice FormatCommandAsResp(string cmd, ReadOnlySpan args) + { + if (scratchBuffer == null) + ExpandScratchBuffer(64); + + scratchBufferOffset += 10; // Reserve space for the array length if it is larger than expected + int commandStartOffset = scratchBufferOffset; + byte* ptr = scratchBufferHead + scratchBufferOffset; + + while (!RespWriteUtils.WriteArrayLength(args.Length + 1, ref ptr, scratchBufferHead + scratchBuffer.Length)) + { + ExpandScratchBuffer(scratchBuffer.Length + 1); + ptr = scratchBufferHead + scratchBufferOffset; + } + scratchBufferOffset = (int)(ptr - scratchBufferHead); + + while (!RespWriteUtils.WriteAsciiBulkString(cmd, ref ptr, scratchBufferHead + scratchBuffer.Length)) + { + ExpandScratchBuffer(scratchBuffer.Length + 1); + ptr = scratchBufferHead + scratchBufferOffset; + } + scratchBufferOffset = (int)(ptr - scratchBufferHead); + + int count = 1; + foreach (var str in args) + { + count++; + while (!RespWriteUtils.WriteAsciiBulkString(str, ref ptr, scratchBufferHead + scratchBuffer.Length)) + { + ExpandScratchBuffer(scratchBuffer.Length + 1); + ptr = scratchBufferHead + scratchBufferOffset; + } + scratchBufferOffset = (int)(ptr - scratchBufferHead); + } + + if (count != args.Length + 1) + { + int extraSpace = NumUtils.NumDigits(count) - NumUtils.NumDigits(args.Length + 1); + if (commandStartOffset < extraSpace) + throw new InvalidOperationException("Invalid number of arguments"); + + commandStartOffset -= extraSpace; + var head = scratchBufferHead + commandStartOffset; + + // There should be space as we have reserved it + _ = RespWriteUtils.WriteArrayLength(count, ref head, scratchBufferHead + scratchBuffer.Length); + } + + var retVal = new ArgSlice(scratchBufferHead + commandStartOffset, scratchBufferOffset - commandStartOffset); + Debug.Assert(scratchBufferOffset <= scratchBuffer.Length); + return retVal; + } /// /// Get length of a RESP Bulk-String formatted version of the specified ArgSlice diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index f05aa94623..52f1e1e1cd 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -2,8 +2,10 @@ // Licensed under the MIT license. using System; +using System.Buffers; using System.Collections.Generic; using System.Diagnostics; +using System.Runtime.InteropServices; using System.Text; using Garnet.common; using KeraLua; @@ -293,6 +295,11 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) var keyBuf = state.CheckBuffer(2); var valBuf = state.CheckBuffer(3); + if (keyBuf == null || valBuf == null) + { + return ErrorInvalidArgumentType(state); + } + var key = scratchBufferManager.CreateArgSlice(keyBuf); var value = scratchBufferManager.CreateArgSlice(valBuf); _ = api.SET(key, value); @@ -311,6 +318,11 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) // todo: no alloc var keyBuf = state.CheckBuffer(2); + if (keyBuf == null) + { + return ErrorInvalidArgumentType(state); + } + var key = scratchBufferManager.CreateArgSlice(keyBuf); var status = api.GET(key, out var value); if (status == GarnetStatus.OK) @@ -326,20 +338,60 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) return 1; } - // todo: implement - default: throw new NotImplementedException(); - - //// As fallback, we use RespServerSession with a RESP-formatted input. This could be optimized - //// in future to provide parse state directly. - //default: - // { - // var request = scratchBufferManager.FormatCommandAsResp(cmd, args, state); - // _ = respServerSession.TryConsumeMessages(request.ptr, request.length); - // var response = scratchBufferNetworkSender.GetResponse(); - // var result = ProcessResponse(response.ptr, response.length); - // scratchBufferNetworkSender.Reset(); - // return result; - // } + // As fallback, we use RespServerSession with a RESP-formatted input. This could be optimized + // in future to provide parse state directly. + default: + { + // todo: remove all these allocations + var args = ArrayPool.Shared.Rent(argCount); + + try + { + var top = state.GetTop(); + + // move backwards validating arguments + // and removing them from the stack + for (var i = argCount - 1; i >= 0; i--) + { + var argType = state.Type(top); + if (argType == LuaType.Nil) + { + args[i] = null; + } + else if (argType == LuaType.String) + { + args[i] = state.CheckString(top); + } + else if (argType == LuaType.Number) + { + var asNum = state.CheckNumber(top); + args[i] = ((long)asNum).ToString(); + } + else + { + state.Pop(1); + + return ErrorInvalidArgumentType(state); + } + + state.Pop(1); + top--; + } + + Debug.Assert(state.GetTop() == 0, "Should have emptied the stack"); + + var request = scratchBufferManager.FormatCommandAsResp(cmd, args.AsSpan()[..argCount]); + _ = respServerSession.TryConsumeMessages(request.ptr, request.length); + var response = scratchBufferNetworkSender.GetResponse(); + var result = ProcessResponse(response.ptr, response.length); + scratchBufferNetworkSender.Reset(); + return result; + } + finally + { + ArrayPool.Shared.Return(args); + } + } } } catch (Exception e) @@ -349,56 +401,112 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) state.PushString(e.Message); return state.Error(); } + + static int ErrorInvalidArgumentType(Lua state) + { + state.PushString("Lua redis lib command arguments must be strings or integers"); + return state.Error(); + } } - ///// - ///// Process a RESP-formatted response from the RespServerSession - ///// - //unsafe object ProcessResponse(byte* ptr, int length) - //{ - // switch (*ptr) - // { - // case (byte)'+': - // if (RespReadUtils.ReadSimpleString(out var resultStr, ref ptr, ptr + length)) - // return resultStr; - // break; - // case (byte)':': - // if (RespReadUtils.Read64Int(out var number, ref ptr, ptr + length)) - // return number; - // break; - // case (byte)'-': - // if (RespReadUtils.ReadErrorAsString(out resultStr, ref ptr, ptr + length)) - // return resultStr; - // break; - - // case (byte)'$': - // if (RespReadUtils.ReadStringResponseWithLengthHeader(out resultStr, ref ptr, ptr + length)) - // return resultStr; - // break; - - // case (byte)'*': - // if (RespReadUtils.ReadStringArrayResponseWithLengthHeader(out var resultArray, ref ptr, ptr + length)) - // { - // // Create return table - // var returnValue = (LuaTable)state.DoString("return { }")[0]; - - // // Queue up for disposal at the end of the script call - // disposeQueue ??= new(); - // disposeQueue.Enqueue(returnValue); - - // // Populate the table - // var i = 1; - // foreach (var item in resultArray) - // returnValue[i++] = item == null ? false : item; - // return returnValue; - // } - // break; - - // default: - // throw new Exception("Unexpected response: " + Encoding.UTF8.GetString(new Span(ptr, length)).Replace("\n", "|").Replace("\r", "") + "]"); - // } - // return null; - //} + /// + /// Process a RESP-formatted response from the RespServerSession. + /// + /// Pushes result onto state stack, and returns 1 + /// + unsafe int ProcessResponse(byte* ptr, int length) + { + Debug.Assert(state.GetTop() == 0, "Stack should be empty before processing response"); + + if (!state.CheckStack(3)) + { + throw new GarnetException("Insufficent space on stack to prepare response"); + } + + switch (*ptr) + { + case (byte)'+': + // todo: remove alloc + if (RespReadUtils.ReadSimpleString(out var resultStr, ref ptr, ptr + length)) + { + state.PushString(resultStr); + return 1; + } + goto default; + + case (byte)':': + if (RespReadUtils.Read64Int(out var number, ref ptr, ptr + length)) + { + state.PushNumber(number); + return 1; + } + goto default; + + case (byte)'-': + // todo: remove alloc + if (RespReadUtils.ReadErrorAsString(out resultStr, ref ptr, ptr + length)) + { + state.PushString(resultStr); + return state.Error(); + } + goto default; + + case (byte)'$': + // todo: remove alloc + if (RespReadUtils.ReadStringResponseWithLengthHeader(out resultStr, ref ptr, ptr + length)) + { + state.PushString(resultStr); + return 1; + } + goto default; + + case (byte)'*': + // todo: remove allocs + if (RespReadUtils.ReadRentedStringArrayResponseWithLengthHeader(ArrayPool.Shared, out var resultArray, ref ptr, ptr + length)) + { + try + { + // create the new table + state.NewTable(); + Debug.Assert(state.GetTop() == 1, "New table should be at top of stack"); + + // Populate the table + var i = 1; + foreach (var item in resultArray.Span) + { + state.PushNumber(i); + + if (item == null) + { + state.PushNil(); + } + else + { + state.PushString(item); + } + + state.RawSet(1); + } + + return 1; + } + finally + { + if (!resultArray.IsEmpty) + { + if (MemoryMarshal.TryGetArray(resultArray, out ArraySegment rented)) + { + ArrayPool.Shared.Return(rented.Array); + } + } + } + } + goto default; + + default: + throw new Exception("Unexpected response: " + Encoding.UTF8.GetString(new Span(ptr, length)).Replace("\n", "|").Replace("\r", "") + "]"); + } + } /// /// Runs the precompiled Lua function with specified parse state From de7c4bea03fe235f3ec65a453e968e389494abf9 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Mon, 9 Dec 2024 17:10:30 -0500 Subject: [PATCH 05/51] handle errors during script run (compilation is still unguarded, as we can't really do much there) --- libs/server/Lua/LuaRunner.cs | 66 ++++++++++++++++++++++++------------ 1 file changed, 45 insertions(+), 21 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 52f1e1e1cd..f3ea5fcbd3 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -167,7 +167,9 @@ public void Compile() Debug.Assert(state.GetTop() == 0, "Stack should be empty before compilation"); state.PushNumber(loadSandboxedRegistryIndex); - state.GetTable(LuaRegistry.Index); + var loadRes = state.GetTable(LuaRegistry.Index); + Debug.Assert(loadRes == LuaType.Function, "Unexpected load_sandboxed type"); + state.PushString(source); state.Call(1, -1); // multiple returns allowed @@ -710,7 +712,7 @@ object Run() Debug.Assert(state.GetTop() == 0, "Stack should be empty at start of invocation"); - if (!state.CheckStack(1)) + if (!state.CheckStack(2)) { throw new GarnetException("Insufficient stack space to run function"); } @@ -718,31 +720,53 @@ object Run() try { state.PushNumber(functionRegistryIndex); - state.GetTable(LuaRegistry.Index); - state.Call(0, 1); + var loadRes = state.GetTable(LuaRegistry.Index); + Debug.Assert(loadRes == LuaType.Function, "Unexpected type for function to invoke"); - if (state.GetTop() == 0) + var callRes = state.PCall(0, 1, 0); + if (callRes == LuaStatus.OK) { - return null; - } + // the actual call worked, handle the response - var retType = state.Type(1); - if (retType == LuaType.Nil) - { - return null; - } - else if (retType == LuaType.Number) - { - return state.CheckNumber(1); - } - else if (retType == LuaType.String) - { - return state.CheckString(1); + if (state.GetTop() == 0) + { + return null; + } + + var retType = state.Type(1); + if (retType == LuaType.Nil) + { + return null; + } + else if (retType == LuaType.Number) + { + return state.CheckNumber(1); + } + else if (retType == LuaType.String) + { + return state.CheckString(1); + } + else + { + // todo: implement + throw new NotImplementedException(); + } } else { - // todo: implement - throw new NotImplementedException(); + // an error was raised + + var stackTop = state.GetTop(); + if (stackTop == 0) + { + // and we got nothing back + throw new GarnetException("An error occurred while invoking a Lua script"); + } + + // todo: we should just write this out, not throw + // it's not exceptional + var msg = state.CheckString(stackTop); + throw new GarnetException(msg); } } finally From 400f2a6724a104bc04d444d718169539b07e25a0 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Mon, 9 Dec 2024 17:19:40 -0500 Subject: [PATCH 06/51] faster (and easier to understand, frankly) way to set and clear KEYS and ARGV at start --- libs/server/Lua/LuaRunner.cs | 40 +++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index f3ea5fcbd3..1890330567 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -646,13 +646,25 @@ object RunTransaction() void ResetParameters(int nKeys, int nArgs) { + Debug.Assert(state.GetTop() == 0, "Stack should be empty before resetting parameters"); + + if (!state.CheckStack(2)) + { + throw new GarnetException("Insufficient space on stack to reset parameters"); + } + if (keyLength > nKeys) { - var keyResetRes = state.DoString($"count = #KEYS for i={nKeys + 1}, {keyLength} do KEYS[i]=nil end"); + // get KEYS on the stack + state.PushNumber(keysTableRegistryIndex); + var loadRes = state.GetTable(LuaRegistry.Index); + Debug.Assert(loadRes == LuaType.Table, "Unexpected type for KEYS"); - if (keyResetRes) + // clear all the values in KEYS that we aren't going to set anyway + for (var i = nKeys + 1; i <= keyLength; i++) { - throw new GarnetException("Couldn't reset KEYS to run script"); + state.PushNil(); + state.RawSetInteger(1, i); } } @@ -660,15 +672,21 @@ void ResetParameters(int nKeys, int nArgs) if (argvLength > nArgs) { - var argvResetRes = state.DoString($"count = #ARGV for i={nArgs + 1}, {argvLength} do ARGV[i]=nil end"); + // get ARGV on the stack + state.PushNumber(argvTableRegistryIndex); + var loadRes = state.GetTable(LuaRegistry.Index); + Debug.Assert(loadRes == LuaType.Table, "Unexpected type for ARGV"); - if (argvResetRes) + for (var i = nArgs + 1; i <= argvLength; i++) { - throw new GarnetException("Couldn't reset ARGV to run script"); + state.PushNil(); + state.RawSetInteger(1, i); } } argvLength = nArgs; + + Debug.Assert(state.GetTop() == 0, "Stack should be empty after resetting parameters"); } void LoadParameters(string[] keys, string[] argv) @@ -683,22 +701,20 @@ void LoadParameters(string[] keys, string[] argv) ResetParameters(keys?.Length ?? 0, argv?.Length ?? 0); if (keys != null) { - for (int i = 0; i < keys.Length; i++) + for (var i = 0; i < keys.Length; i++) { // equivalent to KEYS[i+1] = keys[i] - state.PushNumber(i + 1); state.PushString(keys[i]); - state.SetTable(keysTableRegistryIndex); + state.RawSetInteger(keysTableRegistryIndex, i + 1); } } if (argv != null) { - for (int i = 0; i < argv.Length; i++) + for (var i = 0; i < argv.Length; i++) { // equivalent to ARGV[i+1] = keys[i] - state.PushNumber(i + 1); state.PushString(argv[i]); - state.SetTable(argvTableRegistryIndex); + state.RawSetInteger(argvTableRegistryIndex, i + 1); } } } From 1be33264c14ef5289fd5f1ff6bf86f81f757fd44 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Mon, 9 Dec 2024 17:44:22 -0500 Subject: [PATCH 07/51] a little more functional; moving OK/ERR checking into the loader script may not be viable... but was worth experimenting with --- libs/server/Lua/LuaRunner.cs | 32 +++++++++++++++++------------- test/Garnet.test/LuaScriptTests.cs | 22 ++++++-------------- 2 files changed, 24 insertions(+), 30 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 1890330567..9efc3c9fa8 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -38,7 +38,6 @@ internal sealed class LuaRunner : IDisposable readonly bool txnMode; int keyLength, argvLength; - Queue disposeQueue; /// /// Creates a new runner with the source of the script @@ -112,7 +111,24 @@ function redis.error_reply(text) } function load_sandboxed(source) if (not source) then return nil end - return load(source, nil, nil, sandbox_env) + local rawFunc = load(source, nil, nil, sandbox_env) + + return function() + local rawRet = rawFunc() + + -- handle err and ok response wrappers without crossing the pinvoke boundary + if rawRet and type(rawRet) == ""table"" then + if rawRet.err then + error(rawRet.err) + end + + if rawRet.ok then + return rawRet.ok + end + end + + return rawRet + end end "); if (sandboxRes) @@ -791,17 +807,5 @@ object Run() state.SetTop(0); } } - - void Cleanup() - { - if (disposeQueue != null) - { - while (disposeQueue.Count > 0) - { - var table = disposeQueue.Dequeue(); - table.Dispose(); - } - } - } } } \ No newline at end of file diff --git a/test/Garnet.test/LuaScriptTests.cs b/test/Garnet.test/LuaScriptTests.cs index 3d326f2a34..f40ddde3d7 100644 --- a/test/Garnet.test/LuaScriptTests.cs +++ b/test/Garnet.test/LuaScriptTests.cs @@ -330,23 +330,13 @@ public void FailureStatusReturn() using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); var db = redis.GetDatabase(0); var statusReplyScript = "return redis.error_reply('Failure')"; - try - { - _ = db.ScriptEvaluate(statusReplyScript); - } - catch (RedisServerException ex) - { - ClassicAssert.AreEqual(ex.Message, "Failure"); - } + + var excReply = ClassicAssert.Throws(() => db.ScriptEvaluate(statusReplyScript)); + ClassicAssert.AreEqual("ERR Failure", excReply.Message); + var directReplyScript = "return { err = 'Failure' }"; - try - { - _ = db.ScriptEvaluate(directReplyScript); - } - catch (RedisServerException ex) - { - ClassicAssert.AreEqual(ex.Message, "Failure"); - } + var excDirect = ClassicAssert.Throws(() => db.ScriptEvaluate(directReplyScript)); + ClassicAssert.AreEqual("Failure", excDirect.Message); } [Test] From 40d116e67b597c2effd447bb06f5ca9d5931b1bf Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Tue, 10 Dec 2024 11:02:32 -0500 Subject: [PATCH 08/51] redis.call should not accept table arguments, only strings and numbers; fix some ZADD scripting tests; special err response stuff can't be done on the Lua side, remove it --- libs/server/Lua/LuaCommands.cs | 1 + libs/server/Lua/LuaRunner.cs | 36 +++++++++++++++++------------- test/Garnet.test/LuaScriptTests.cs | 4 ++-- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/libs/server/Lua/LuaCommands.cs b/libs/server/Lua/LuaCommands.cs index c60fdb61c4..c3c4ef2a00 100644 --- a/libs/server/Lua/LuaCommands.cs +++ b/libs/server/Lua/LuaCommands.cs @@ -292,6 +292,7 @@ void WriteObject(object scriptResult) while (!RespWriteUtils.WriteInteger(l, ref dcurr, dend)) SendAndReset(); } + else { // todo: this should all go away diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 9efc3c9fa8..5647af4a70 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -116,15 +116,10 @@ function load_sandboxed(source) return function() local rawRet = rawFunc() - -- handle err and ok response wrappers without crossing the pinvoke boundary - if rawRet and type(rawRet) == ""table"" then - if rawRet.err then - error(rawRet.err) - end - - if rawRet.ok then - return rawRet.ok - end + -- handle ok response wrappers without crossing the pinvoke boundary + -- err response wrapper requires a bit more work, but is also rarer + if rawRet and type(rawRet) == ""table"" and rawRet.ok then + return rawRet.ok end return rawRet @@ -361,7 +356,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) default: { // todo: remove all these allocations - var args = ArrayPool.Shared.Rent(argCount); + var stackArgs = ArrayPool.Shared.Rent(argCount); try { @@ -374,16 +369,16 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) var argType = state.Type(top); if (argType == LuaType.Nil) { - args[i] = null; + stackArgs[i] = null; } else if (argType == LuaType.String) { - args[i] = state.CheckString(top); + stackArgs[i] = state.CheckString(top); } else if (argType == LuaType.Number) { var asNum = state.CheckNumber(top); - args[i] = ((long)asNum).ToString(); + stackArgs[i] = ((long)asNum).ToString(); } else { @@ -398,7 +393,10 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) Debug.Assert(state.GetTop() == 0, "Should have emptied the stack"); - var request = scratchBufferManager.FormatCommandAsResp(cmd, args.AsSpan()[..argCount]); + // command is handled specially, so trim it off + var cmdArgs = stackArgs.AsSpan().Slice(1, argCount - 1); + + var request = scratchBufferManager.FormatCommandAsResp(cmd, cmdArgs); _ = respServerSession.TryConsumeMessages(request.ptr, request.length); var response = scratchBufferNetworkSender.GetResponse(); var result = ProcessResponse(response.ptr, response.length); @@ -407,7 +405,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) } finally { - ArrayPool.Shared.Return(args); + ArrayPool.Shared.Return(stackArgs); } } } @@ -682,6 +680,8 @@ void ResetParameters(int nKeys, int nArgs) state.PushNil(); state.RawSetInteger(1, i); } + + state.Pop(1); } keyLength = nKeys; @@ -698,6 +698,8 @@ void ResetParameters(int nKeys, int nArgs) state.PushNil(); state.RawSetInteger(1, i); } + + state.Pop(1); } argvLength = nArgs; @@ -772,7 +774,9 @@ object Run() } else if (retType == LuaType.Number) { - return state.CheckNumber(1); + // Redis appears to unconditionally convert all "number" replies to integer replies + // so we match that + return (long)state.CheckNumber(1); } else if (retType == LuaType.String) { diff --git a/test/Garnet.test/LuaScriptTests.cs b/test/Garnet.test/LuaScriptTests.cs index f40ddde3d7..167fa1c6b3 100644 --- a/test/Garnet.test/LuaScriptTests.cs +++ b/test/Garnet.test/LuaScriptTests.cs @@ -143,7 +143,7 @@ public void CanDoEvalShaWithZAddMultiPairSE() var db = redis.GetDatabase(0); // Create a sorted set - var script = "local ptable = {100, \"value1\", 200, \"value2\"}; return redis.call('zadd', KEYS[1], ptable)"; + var script = "return redis.call('zadd', KEYS[1], 100, \"value1\", 200, \"value2\")"; var result = db.ScriptEvaluate(script, [(RedisKey)"mysskey"]); ClassicAssert.IsTrue(result.ToString() == "2"); @@ -156,7 +156,7 @@ public void CanDoEvalShaWithZAddMultiPairSE() ClassicAssert.IsTrue(result.ToString() == "0"); // Add more pairs - script = "local ptable = {300, \"value3\", 400, \"value4\"}; return redis.call('zadd', KEYS[1], ptable)"; + script = "return redis.call('zadd', KEYS[1], 300, \"value3\", 400, \"value4\")"; result = db.ScriptEvaluate(script, [(RedisKey)"mysskey"]); ClassicAssert.IsTrue(result.ToString() == "2"); From ed712c768e8de354b282cf0274eda8d29b38076d Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Tue, 10 Dec 2024 12:29:39 -0500 Subject: [PATCH 09/51] scripting appears to be at parity with NLua implementation --- libs/server/Lua/LuaCommands.cs | 15 ++++ libs/server/Lua/LuaRunner.cs | 109 +++++++++++++++++++++++++---- test/Garnet.test/LuaScriptTests.cs | 9 ++- 3 files changed, 117 insertions(+), 16 deletions(-) diff --git a/libs/server/Lua/LuaCommands.cs b/libs/server/Lua/LuaCommands.cs index c3c4ef2a00..79e4bf20da 100644 --- a/libs/server/Lua/LuaCommands.cs +++ b/libs/server/Lua/LuaCommands.cs @@ -292,7 +292,22 @@ void WriteObject(object scriptResult) while (!RespWriteUtils.WriteInteger(l, ref dcurr, dend)) SendAndReset(); } + else if (scriptResult is object[] o) + { + var count = o.Length; + while (!RespWriteUtils.WriteArrayLength(count, ref dcurr, dend)) + SendAndReset(); + foreach (var value in o) + { + WriteObject(value); + } + } + else if (scriptResult is ErrorResult e) + { + while (!RespWriteUtils.WriteError(e.Message, ref dcurr, dend)) + SendAndReset(); + } else { // todo: this should all go away diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 5647af4a70..2046008623 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -3,8 +3,8 @@ using System; using System.Buffers; -using System.Collections.Generic; using System.Diagnostics; +using System.Linq; using System.Runtime.InteropServices; using System.Text; using Garnet.common; @@ -13,6 +13,9 @@ namespace Garnet.server { + // hack hack hack + internal sealed record ErrorResult(string Message); + /// /// Creates the instance to run Lua scripts /// @@ -82,7 +85,7 @@ function redis.status_reply(text) return text end function redis.error_reply(text) - return { err = text } + return { err = 'ERR ' .. text } end KEYS = {} ARGV = {} @@ -471,7 +474,16 @@ unsafe int ProcessResponse(byte* ptr, int length) // todo: remove alloc if (RespReadUtils.ReadStringResponseWithLengthHeader(out resultStr, ref ptr, ptr + length)) { - state.PushString(resultStr); + // bulk null strings are mapped to FALSE + // see: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion + if (resultStr == null) + { + state.PushBoolean(false); + } + else + { + state.PushString(resultStr); + } return 1; } goto default; @@ -490,18 +502,20 @@ unsafe int ProcessResponse(byte* ptr, int length) var i = 1; foreach (var item in resultArray.Span) { - state.PushNumber(i); - + // null strings are mapped to false + // see: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion if (item == null) { - state.PushNil(); + state.PushBoolean(false); } else { state.PushString(item); } - state.RawSet(1); + state.RawSetInteger(1, i); + + i++; } return 1; @@ -566,7 +580,7 @@ public object Run(int count, SessionParseState parseState) // equivalent to KEYS[i+1] = key.ToString() state.PushNumber(i + 1); - state.PushString(parseState.GetString(offset)); + state.PushString(key.ToString()); state.RawSet(1); offset++; @@ -722,8 +736,7 @@ void LoadParameters(string[] keys, string[] argv) for (var i = 0; i < keys.Length; i++) { // equivalent to KEYS[i+1] = keys[i] - state.PushString(keys[i]); - state.RawSetInteger(keysTableRegistryIndex, i + 1); + throw new NotImplementedException(); } } if (argv != null) @@ -731,8 +744,7 @@ void LoadParameters(string[] keys, string[] argv) for (var i = 0; i < argv.Length; i++) { // equivalent to ARGV[i+1] = keys[i] - state.PushString(argv[i]); - state.RawSetInteger(argvTableRegistryIndex, i + 1); + throw new NotImplementedException(); } } } @@ -742,8 +754,10 @@ void LoadParameters(string[] keys, string[] argv) /// object Run() { - // todo: this shouldn't read the result, it should write the response out + // todo: mapping is dependent on Resp2 vs Resp3 settings + // and that's not implemented at all + // todo: this shouldn't read the result, it should write the response out Debug.Assert(state.GetTop() == 0, "Stack should be empty at start of invocation"); if (!state.CheckStack(2)) @@ -774,14 +788,81 @@ object Run() } else if (retType == LuaType.Number) { - // Redis appears to unconditionally convert all "number" replies to integer replies + // Redis unconditionally converts all "number" replies to integer replies // so we match that + // see: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion return (long)state.CheckNumber(1); } else if (retType == LuaType.String) { return state.CheckString(1); } + else if (retType == LuaType.Boolean) + { + // Redis maps Lua false to null, and Lua true to 1 + // this is strange, but documented + // see: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion + if (state.ToBoolean(1)) + { + return 1L; + } + else + { + return null; + } + } + else if (retType == LuaType.Table) + { + // todo: this is hacky, and doesn't support nested arrays or whatever + // but is good enough for now + // when refactored to avoid intermediate objects this should be fixed + + // note: because we are dealing with a user provided type, we MUST respect + // metatables - so we can't use any of the RawXXX methods + + // if the key err is in there, we need to short circuit + state.PushString("err"); + + var errType = state.GetTable(1); + if (errType == LuaType.String) + { + var errStr = state.CheckString(2); + // hack hack hack + // todo: all this goes away when we write results directly + return new ErrorResult(errStr); + } + + state.Pop(1); + + // otherwise, we need to convert the table to an array + var tableLength = state.Length(1); + + var ret = new object[tableLength]; + for (var i = 1; i <= tableLength; i++) + { + var type = state.GetInteger(1, i); + switch (type) + { + case LuaType.String: + ret[i - 1] = state.CheckString(2); + break; + case LuaType.Number: + ret[i - 1] = (long)state.CheckNumber(2); + break; + case LuaType.Boolean: + ret[i - 1] = state.ToBoolean(2) ? 1L : null; + break; + // Redis stops processesing the array when a nil is encountered + // see: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion + case LuaType.Nil: + return ret.Take(i - 1).ToArray(); + } + + state.Pop(1); + } + + return ret; + } else { // todo: implement diff --git a/test/Garnet.test/LuaScriptTests.cs b/test/Garnet.test/LuaScriptTests.cs index 167fa1c6b3..1c1ab2bdb1 100644 --- a/test/Garnet.test/LuaScriptTests.cs +++ b/test/Garnet.test/LuaScriptTests.cs @@ -420,10 +420,13 @@ local function callgetrange() using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); var db = redis.GetDatabase(0); + var response1 = db.ScriptEvaluate(script1, ["key1", "key2"], ["foo", 3, 60_000]); ClassicAssert.AreEqual("OK", (string)response1); + var response2 = db.ScriptEvaluate(script2, ["key3"], ["foo"]); ClassicAssert.AreEqual(false, (bool)response2); + var response3 = db.ScriptEvaluate(script2, ["key1", "key2"], ["foo"]); ClassicAssert.AreEqual("OK", (string)response3); } @@ -432,8 +435,10 @@ local function callgetrange() public void ComplexLuaTest3() { var script1 = """ -return redis.call("mget", unpack(KEYS)) -"""; + return redis.call("mget", unpack(KEYS)) + """; + + //var script1 = "return KEYS"; using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); var db = redis.GetDatabase(0); From 4ab85224521d1ec7f3199bc0977c95de074fdd76 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Tue, 10 Dec 2024 13:13:34 -0500 Subject: [PATCH 10/51] fix error propogation; implement needed functions for direct script running; all tests passing again --- libs/server/Lua/LuaRunner.cs | 41 ++++++++++++++++++++++++++++++------ 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 2046008623..f1afb35084 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -114,8 +114,14 @@ function redis.error_reply(text) } function load_sandboxed(source) if (not source) then return nil end - local rawFunc = load(source, nil, nil, sandbox_env) + local rawFunc, err = load(source, nil, nil, sandbox_env) + -- compilation error is returned directly + if err then + return rawFunc, err + end + + -- otherwise we wrap the compiled function in a helper return function() local rawRet = rawFunc() @@ -171,6 +177,8 @@ public void Compile() { Debug.Assert(functionRegistryIndex == -1, "Shouldn't compile multiple times"); + Debug.Assert(state.GetTop() == 0, "Stack should be empty at start of compilation"); + try { if (!state.CheckStack(2)) @@ -178,8 +186,6 @@ public void Compile() throw new GarnetException("Insufficient stack space to compile function"); } - Debug.Assert(state.GetTop() == 0, "Stack should be empty before compilation"); - state.PushNumber(loadSandboxedRegistryIndex); var loadRes = state.GetTable(LuaRegistry.Index); Debug.Assert(loadRes == LuaType.Function, "Unexpected load_sandboxed type"); @@ -207,6 +213,7 @@ public void Compile() var error = state.CheckString(2); throw new GarnetException($"Compilation error: {error}"); + } else { @@ -220,7 +227,8 @@ public void Compile() } finally { - Debug.Assert(state.GetTop() == 0, "Stack should be empty after compilation"); + // force stack empty after compilation, no matter what happens + state.SetTop(0); } } @@ -481,7 +489,7 @@ unsafe int ProcessResponse(byte* ptr, int length) state.PushBoolean(false); } else - { + { state.PushString(resultStr); } return 1; @@ -733,20 +741,39 @@ void LoadParameters(string[] keys, string[] argv) ResetParameters(keys?.Length ?? 0, argv?.Length ?? 0); if (keys != null) { + // get KEYS on the stack + state.PushNumber(keysTableRegistryIndex); + var loadRes = state.GetTable(LuaRegistry.Index); + Debug.Assert(loadRes == LuaType.Table, "Unexpected type for KEYS"); + for (var i = 0; i < keys.Length; i++) { // equivalent to KEYS[i+1] = keys[i] - throw new NotImplementedException(); + state.PushString(keys[i]); + state.RawSetInteger(1, i + 1); } + + state.Pop(1); } + if (argv != null) { + // get ARGV on the stack + state.PushNumber(argvTableRegistryIndex); + var loadRes = state.GetTable(LuaRegistry.Index); + Debug.Assert(loadRes == LuaType.Table, "Unexpected type for ARGV"); + for (var i = 0; i < argv.Length; i++) { // equivalent to ARGV[i+1] = keys[i] - throw new NotImplementedException(); + state.PushString(argv[i]); + state.RawSetInteger(1, i + 1); } + + state.Pop(1); } + + Debug.Assert(state.GetTop() == 0, "Stack should be empty when invocation ends"); } /// From d3bfdcec675546bcabeb716374915710801c43db Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Tue, 10 Dec 2024 14:17:05 -0500 Subject: [PATCH 11/51] start on removing allocations and fixing encodings --- libs/server/Lua/LuaRunner.cs | 216 ++++++++++++++++--------------- libs/server/Lua/NativeMethods.cs | 82 ++++++++++++ 2 files changed, 195 insertions(+), 103 deletions(-) create mode 100644 libs/server/Lua/NativeMethods.cs diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index f1afb35084..f44552e67e 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -298,128 +298,124 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) if (argCount == 0) { - return state.Error("Please specify at least one argument for this redis lib call script"); + return LuaError("Please specify at least one argument for this redis lib call script"u8); } - // todo: no alloc - var cmd = state.CheckString(1).ToUpperInvariant(); + if (!NativeMethods.CheckBuffer(state.Handle, 1, out var cmdSpan)) + { + return LuaError("Unknown Redis command called from script"u8); + } - switch (cmd) + // We special-case a few performance-sensitive operations to directly invoke via the storage API + if (AsciiUtils.EqualsUpperCaseSpanIgnoringCase(cmdSpan, "SET"u8) && argCount == 3) { - // We special-case a few performance-sensitive operations to directly invoke via the storage API - case "SET" when argCount == 3: - { - if (!respServerSession.CheckACLPermissions(RespCommand.SET)) - { - // todo: no alloc - return state.Error(Encoding.UTF8.GetString(CmdStrings.RESP_ERR_NOAUTH)); - } + if (!respServerSession.CheckACLPermissions(RespCommand.SET)) + { + return LuaError(CmdStrings.RESP_ERR_NOAUTH); + } - // todo: no alloc - var keyBuf = state.CheckBuffer(2); - var valBuf = state.CheckBuffer(3); + // todo: no alloc + var keyBuf = state.CheckBuffer(2); + var valBuf = state.CheckBuffer(3); - if (keyBuf == null || valBuf == null) - { - return ErrorInvalidArgumentType(state); - } + if (keyBuf == null || valBuf == null) + { + return ErrorInvalidArgumentType(state); + } - var key = scratchBufferManager.CreateArgSlice(keyBuf); - var value = scratchBufferManager.CreateArgSlice(valBuf); - _ = api.SET(key, value); + var key = scratchBufferManager.CreateArgSlice(keyBuf); + var value = scratchBufferManager.CreateArgSlice(valBuf); + _ = api.SET(key, value); - state.PushString("OK"); - return 1; - } - case "GET" when argCount == 2: - { - if (!respServerSession.CheckACLPermissions(RespCommand.GET)) - { - // todo: no alloc - return state.Error(Encoding.UTF8.GetString(CmdStrings.RESP_ERR_NOAUTH)); - } + NativeMethods.PushBuffer(state.Handle, "OK"u8); + return 1; + } + else if (AsciiUtils.EqualsUpperCaseSpanIgnoringCase(cmdSpan, "GET"u8) && argCount == 2) + { + if (!respServerSession.CheckACLPermissions(RespCommand.GET)) + { + return LuaError(CmdStrings.RESP_ERR_NOAUTH); + } - // todo: no alloc - var keyBuf = state.CheckBuffer(2); + // todo: no alloc + var keyBuf = state.CheckBuffer(2); - if (keyBuf == null) - { - return ErrorInvalidArgumentType(state); - } + if (keyBuf == null) + { + return ErrorInvalidArgumentType(state); + } - var key = scratchBufferManager.CreateArgSlice(keyBuf); - var status = api.GET(key, out var value); - if (status == GarnetStatus.OK) - { - // todo: no alloc - state.PushBuffer(value.ToArray()); - } - else - { - state.PushNil(); - } + var key = scratchBufferManager.CreateArgSlice(keyBuf); + var status = api.GET(key, out var value); + if (status == GarnetStatus.OK) + { + NativeMethods.PushBuffer(state.Handle, value.ReadOnlySpan); + } + else + { + state.PushNil(); + } - return 1; - } + return 1; + } + + // As fallback, we use RespServerSession with a RESP-formatted input. This could be optimized + // in future to provide parse state directly. + + // todo: remove all these allocations + var stackArgs = ArrayPool.Shared.Rent(argCount); + var cmd = Encoding.UTF8.GetString(cmdSpan); + + try + { + var top = state.GetTop(); - // As fallback, we use RespServerSession with a RESP-formatted input. This could be optimized - // in future to provide parse state directly. - default: + // move backwards validating arguments + // and removing them from the stack + for (var i = argCount - 1; i >= 0; i--) + { + var argType = state.Type(top); + if (argType == LuaType.Nil) + { + stackArgs[i] = null; + } + else if (argType == LuaType.String) + { + stackArgs[i] = state.CheckString(top); + } + else if (argType == LuaType.Number) { - // todo: remove all these allocations - var stackArgs = ArrayPool.Shared.Rent(argCount); + var asNum = state.CheckNumber(top); + stackArgs[i] = ((long)asNum).ToString(); + } + else + { + state.Pop(1); - try - { - var top = state.GetTop(); + return ErrorInvalidArgumentType(state); + } - // move backwards validating arguments - // and removing them from the stack - for (var i = argCount - 1; i >= 0; i--) - { - var argType = state.Type(top); - if (argType == LuaType.Nil) - { - stackArgs[i] = null; - } - else if (argType == LuaType.String) - { - stackArgs[i] = state.CheckString(top); - } - else if (argType == LuaType.Number) - { - var asNum = state.CheckNumber(top); - stackArgs[i] = ((long)asNum).ToString(); - } - else - { - state.Pop(1); - - return ErrorInvalidArgumentType(state); - } - - state.Pop(1); - top--; - } + state.Pop(1); + top--; + } - Debug.Assert(state.GetTop() == 0, "Should have emptied the stack"); + Debug.Assert(state.GetTop() == 0, "Should have emptied the stack"); - // command is handled specially, so trim it off - var cmdArgs = stackArgs.AsSpan().Slice(1, argCount - 1); + // command is handled specially, so trim it off + var cmdArgs = stackArgs.AsSpan().Slice(1, argCount - 1); - var request = scratchBufferManager.FormatCommandAsResp(cmd, cmdArgs); - _ = respServerSession.TryConsumeMessages(request.ptr, request.length); - var response = scratchBufferNetworkSender.GetResponse(); - var result = ProcessResponse(response.ptr, response.length); - scratchBufferNetworkSender.Reset(); - return result; - } - finally - { - ArrayPool.Shared.Return(stackArgs); - } - } + var request = scratchBufferManager.FormatCommandAsResp(cmd, cmdArgs); + _ = respServerSession.TryConsumeMessages(request.ptr, request.length); + var response = scratchBufferNetworkSender.GetResponse(); + var result = ProcessResponse(response.ptr, response.length); + scratchBufferNetworkSender.Reset(); + return result; + } + finally + { + ArrayPool.Shared.Return(stackArgs); } + } catch (Exception e) { @@ -431,11 +427,25 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) static int ErrorInvalidArgumentType(Lua state) { - state.PushString("Lua redis lib command arguments must be strings or integers"); + NativeMethods.PushBuffer(state.Handle, "Lua redis lib command arguments must be strings or integers"u8); return state.Error(); } } + /// + /// Cause a lua error to be raised with the given message. + /// + int LuaError(ReadOnlySpan msg) + { + if (!state.CheckStack(1)) + { + throw new GarnetException("Insufficient stack space to error"); + } + + NativeMethods.PushBuffer(state.Handle, msg); + return state.Error(); + } + /// /// Process a RESP-formatted response from the RespServerSession. /// diff --git a/libs/server/Lua/NativeMethods.cs b/libs/server/Lua/NativeMethods.cs new file mode 100644 index 0000000000..9a44bfdad2 --- /dev/null +++ b/libs/server/Lua/NativeMethods.cs @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using KeraLua; +using static System.Net.WebRequestMethods; +using charptr_t = System.IntPtr; +using lua_Integer = System.Int64; +using lua_State = System.IntPtr; +using size_t = System.UIntPtr; +using voidptr_t = System.IntPtr; + +namespace Garnet.server +{ + /// + /// Lua runtime methods we want that are not provided by . + /// + /// Long term we'll want to try and push these upstreams and move to just using KeraLua, + /// but for now we're just defining them ourselves. + /// + internal static class NativeMethods + { + private const string LuaLibraryName = "lua54"; + + /// + /// see: https://www.lua.org/manual/5.3/manual.html#lua_tolstring + /// + [DllImport(LuaLibraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern charptr_t lua_tolstring(lua_State L, int index, out size_t len); + + /// + /// see: https://www.lua.org/manual/5.3/manual.html#lua_type + /// + [DllImport(LuaLibraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern LuaType lua_type(lua_State L, int index); + + /// + /// see: https://www.lua.org/manual/5.3/manual.html#lua_pushlstring + /// + [DllImport(LuaLibraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern charptr_t lua_pushlstring(lua_State L, charptr_t s, size_t len); + + /// + /// Returns true if the given index on the stack holds a string. + /// + /// Sets to the string if so, otherwise leaves it empty. + /// + /// only remains valid as long as the buffer remains on the stack, + /// use with care. + /// + internal static bool CheckBuffer(lua_State luaState, int index, out ReadOnlySpan str) + { + if (lua_type(luaState, index) != LuaType.String) + { + str = []; + return false; + } + + var start = lua_tolstring(luaState, index, out var len); + unsafe + { + str = new ReadOnlySpan((byte*)start, (int)len); + return true; + } + } + + /// + /// Pushes given span to stack as a string. + /// + /// Provided data is copied, and can be reused once this call returns. + /// + internal static unsafe void PushBuffer(lua_State luaState, ReadOnlySpan str) + { + fixed (byte* ptr = str) + { + lua_pushlstring(luaState, (charptr_t)ptr, (size_t)str.Length); + } + } + } +} From 196c95a0fd222b4af0708388eb3c94f80b70cda7 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Tue, 10 Dec 2024 14:28:43 -0500 Subject: [PATCH 12/51] remove more allocs --- libs/server/Lua/LuaRunner.cs | 20 ++++++++------------ libs/server/Lua/NativeMethods.cs | 11 ++++++++--- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index f44552e67e..cab99c8fd3 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -314,17 +314,15 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) return LuaError(CmdStrings.RESP_ERR_NOAUTH); } - // todo: no alloc - var keyBuf = state.CheckBuffer(2); - var valBuf = state.CheckBuffer(3); - - if (keyBuf == null || valBuf == null) + if(!NativeMethods.CheckBuffer(state.Handle, 2, out var keySpan) || !NativeMethods.CheckBuffer(state.Handle, 3, out var valSpan)) { return ErrorInvalidArgumentType(state); } - var key = scratchBufferManager.CreateArgSlice(keyBuf); - var value = scratchBufferManager.CreateArgSlice(valBuf); + // note these spans are implicitly pinned, as they're actually on the Lua stack + var key = ArgSlice.FromPinnedSpan(keySpan); + var value = ArgSlice.FromPinnedSpan(valSpan); + _ = api.SET(key, value); NativeMethods.PushBuffer(state.Handle, "OK"u8); @@ -337,15 +335,13 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) return LuaError(CmdStrings.RESP_ERR_NOAUTH); } - // todo: no alloc - var keyBuf = state.CheckBuffer(2); - - if (keyBuf == null) + if(!NativeMethods.CheckBuffer(state.Handle, 2, out var keySpan)) { return ErrorInvalidArgumentType(state); } - var key = scratchBufferManager.CreateArgSlice(keyBuf); + // span is (implicitly) pinned since it's actually on the Lua stack + var key = ArgSlice.FromPinnedSpan(keySpan); var status = api.GET(key, out var value); if (status == GarnetStatus.OK) { diff --git a/libs/server/Lua/NativeMethods.cs b/libs/server/Lua/NativeMethods.cs index 9a44bfdad2..367b9f64c1 100644 --- a/libs/server/Lua/NativeMethods.cs +++ b/libs/server/Lua/NativeMethods.cs @@ -43,16 +43,21 @@ internal static class NativeMethods private static extern charptr_t lua_pushlstring(lua_State L, charptr_t s, size_t len); /// - /// Returns true if the given index on the stack holds a string. + /// Returns true if the given index on the stack holds a string or a number. /// - /// Sets to the string if so, otherwise leaves it empty. + /// Sets to the string equivalent if so, otherwise leaves it empty. /// /// only remains valid as long as the buffer remains on the stack, /// use with care. + /// + /// Note that is changes the value on the stack to be a string if it returns true, regardless of + /// what it was originally. /// internal static bool CheckBuffer(lua_State luaState, int index, out ReadOnlySpan str) { - if (lua_type(luaState, index) != LuaType.String) + var type = lua_type(luaState, index); + + if (type != LuaType.String && type != LuaType.Number) { str = []; return false; From 7b88a977e9ae82cbec5fd79e20649c2076bad47e Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Tue, 10 Dec 2024 14:59:36 -0500 Subject: [PATCH 13/51] add a test (and fixes) confirming that redis.call errors match Redis behavior --- libs/server/Lua/LuaRunner.cs | 19 ++++++++---- test/Garnet.test/LuaScriptTests.cs | 49 ++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 6 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index cab99c8fd3..0ddad4f8ae 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -78,8 +78,8 @@ public LuaRunner(string source, bool txnMode = false, RespServerSession respServ var sandboxRes = state.DoString(@" import = function () end redis = {} - function redis.call(cmd, ...) - return garnet_call(cmd, ...) + function redis.call(...) + return garnet_call(...) end function redis.status_reply(text) return text @@ -298,7 +298,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) if (argCount == 0) { - return LuaError("Please specify at least one argument for this redis lib call script"u8); + return LuaError("Please specify at least one argument for this redis lib call"u8); } if (!NativeMethods.CheckBuffer(state.Handle, 1, out var cmdSpan)) @@ -314,7 +314,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) return LuaError(CmdStrings.RESP_ERR_NOAUTH); } - if(!NativeMethods.CheckBuffer(state.Handle, 2, out var keySpan) || !NativeMethods.CheckBuffer(state.Handle, 3, out var valSpan)) + if (!NativeMethods.CheckBuffer(state.Handle, 2, out var keySpan) || !NativeMethods.CheckBuffer(state.Handle, 3, out var valSpan)) { return ErrorInvalidArgumentType(state); } @@ -322,7 +322,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) // note these spans are implicitly pinned, as they're actually on the Lua stack var key = ArgSlice.FromPinnedSpan(keySpan); var value = ArgSlice.FromPinnedSpan(valSpan); - + _ = api.SET(key, value); NativeMethods.PushBuffer(state.Handle, "OK"u8); @@ -335,7 +335,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) return LuaError(CmdStrings.RESP_ERR_NOAUTH); } - if(!NativeMethods.CheckBuffer(state.Handle, 2, out var keySpan)) + if (!NativeMethods.CheckBuffer(state.Handle, 2, out var keySpan)) { return ErrorInvalidArgumentType(state); } @@ -476,6 +476,13 @@ unsafe int ProcessResponse(byte* ptr, int length) goto default; case (byte)'-': + var errSpan = new ReadOnlySpan(ptr + 1, length - 3); // cut \r\n off too + if (errSpan.SequenceEqual(CmdStrings.RESP_ERR_GENERIC_UNK_CMD)) + { + // gets a special response + return LuaError("Unknown Redis command called from script"u8); + } + // todo: remove alloc if (RespReadUtils.ReadErrorAsString(out resultStr, ref ptr, ptr + length)) { diff --git a/test/Garnet.test/LuaScriptTests.cs b/test/Garnet.test/LuaScriptTests.cs index 1c1ab2bdb1..f10297ed95 100644 --- a/test/Garnet.test/LuaScriptTests.cs +++ b/test/Garnet.test/LuaScriptTests.cs @@ -533,5 +533,54 @@ public void ScriptExistsMultiple() ClassicAssert.AreEqual(0, (long)exists[2]); } } + + [Test] + public void RedisCallErrors() + { + // testing that our error replies for redis.call match Redis behavior + // + // todo: exact matching of the hash and line number would also be nice, but that is trickier + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + // no args + { + var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call()")); + ClassicAssert.IsTrue(exc.Message.StartsWith("ERR Please specify at least one argument for this redis lib call")); + } + + // unknown command + { + var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call('123')")); + ClassicAssert.IsTrue(exc.Message.StartsWith("ERR Unknown Redis command called from script")); + } + + // bad command type + { + var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call({ foo = 'bar'})")); + ClassicAssert.IsTrue(exc.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); + } + + // GET bad arg type + { + var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call('GET', { foo = 'bar' })")); + ClassicAssert.IsTrue(exc.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); + } + + // SET bad arg types + { + var exc1 = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call('SET', 'hello', { foo = 'bar' })")); + ClassicAssert.IsTrue(exc1.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); + + var exc2 = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call('SET', { foo = 'bar' }, 'world')")); + ClassicAssert.IsTrue(exc2.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); + } + + // other bad arg types + { + var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call('DEL', { foo = 'bar' })")); + ClassicAssert.IsTrue(exc.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); + } + } } } \ No newline at end of file From d8ca94ce16024ed5aac230fc485a75bc2b88cffb Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Tue, 10 Dec 2024 15:12:59 -0500 Subject: [PATCH 14/51] add a test (and fixes) confirming that redis.call errors match Redis behavior --- libs/server/Lua/LuaRunner.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 0ddad4f8ae..b7cfcf9282 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -303,7 +303,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) if (!NativeMethods.CheckBuffer(state.Handle, 1, out var cmdSpan)) { - return LuaError("Unknown Redis command called from script"u8); + return ErrorInvalidArgumentType(state); } // We special-case a few performance-sensitive operations to directly invoke via the storage API From fe4322b45ca8fd3818f8a64ae7732411f38c1de6 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Tue, 10 Dec 2024 16:12:37 -0500 Subject: [PATCH 15/51] remove more allocations --- libs/server/Lua/LuaRunner.cs | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index b7cfcf9282..96f0d81522 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -573,8 +573,8 @@ public object Run(int count, SessionParseState parseState) scratchBufferManager.Reset(); - int offset = 1; - int nKeys = parseState.GetInt(offset++); + var offset = 1; + var nKeys = parseState.GetInt(offset++); count--; ResetParameters(nKeys, count - nKeys); @@ -596,17 +596,15 @@ public object Run(int count, SessionParseState parseState) txnKeyEntries.AddKey(key, true, Tsavorite.core.LockType.Exclusive); } - // todo: no alloc - // todo: encoding is wrong here - - // equivalent to KEYS[i+1] = key.ToString() + // equivalent to KEYS[i+1] = key state.PushNumber(i + 1); - state.PushString(key.ToString()); + NativeMethods.PushBuffer(state.Handle, key.ReadOnlySpan); state.RawSet(1); offset++; } + // remove KEYS from the stack state.Pop(1); count -= nKeys; @@ -621,17 +619,17 @@ public object Run(int count, SessionParseState parseState) for (var i = 0; i < count; i++) { - // todo: no alloc - // todo encoding is wrong here + ref var argv = ref parseState.GetArgSliceByRef(offset); - // equivalent to ARGV[i+1] = parseState.GetString(offset); + // equivalent to ARGV[i+1] = argv state.PushNumber(i + 1); - state.PushString(parseState.GetString(offset)); + NativeMethods.PushBuffer(state.Handle, argv.ReadOnlySpan); state.RawSet(1); offset++; } + // remove ARGV from the stack state.Pop(1); } From e567d6fa6822d2ba50a8d3073a0e5568bed22f1e Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Tue, 10 Dec 2024 16:24:19 -0500 Subject: [PATCH 16/51] add a test for weird binary values, note this fails in main today --- test/Garnet.test/LuaScriptTests.cs | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/test/Garnet.test/LuaScriptTests.cs b/test/Garnet.test/LuaScriptTests.cs index f10297ed95..1fa09ef5e1 100644 --- a/test/Garnet.test/LuaScriptTests.cs +++ b/test/Garnet.test/LuaScriptTests.cs @@ -582,5 +582,31 @@ public void RedisCallErrors() ClassicAssert.IsTrue(exc.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); } } + + [Test] + public void BinaryValuesInScripts() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var trickyKey = new byte[] { 0, 1, 2, 3, 4 }; + var trickyValue = new byte[] { 5, 6, 7, 8, 9, 0 }; + var trickyValue2 = new byte[] { 0, 1, 0, 1, 0, 1, 255 }; + + ClassicAssert.IsTrue(db.StringSet(trickyKey, trickyValue)); + + var luaEscapeKeyString = $"{string.Join("", trickyKey.Select(x => $"\\{x:X2}"))}"; + + var readDirectKeyRaw = db.ScriptEvaluate($"return redis.call('GET', '{luaEscapeKeyString}')", [(RedisKey)trickyKey]); + var readDirectKeyBytes = (byte[])readDirectKeyRaw; + ClassicAssert.IsTrue(trickyValue.AsSpan().SequenceEqual(readDirectKeyBytes)); + + var setKey = db.ScriptEvaluate("return redis.call('SET', KEYS[1], ARGV[1])", [(RedisKey)trickyKey], [(RedisValue)trickyValue2]); + ClassicAssert.AreEqual("OK", (string)setKey); + + var readTrickyValue2Raw = db.StringGet(trickyKey); + var readTrickyValue2 = (byte[])readTrickyValue2Raw; + ClassicAssert.IsTrue(trickyValue2.AsSpan().SequenceEqual(readTrickyValue2)); + } } } \ No newline at end of file From a8b125229888d2998e73a3e8d41a2a65d41cee29 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Tue, 10 Dec 2024 17:21:47 -0500 Subject: [PATCH 17/51] knock out some alloc todos; do some cleanup --- libs/server/ArgSlice/ScratchBufferManager.cs | 34 +++-- libs/server/Lua/LuaRunner.cs | 134 ++++++++++--------- test/Garnet.test/LuaScriptTests.cs | 13 ++ 3 files changed, 106 insertions(+), 75 deletions(-) diff --git a/libs/server/ArgSlice/ScratchBufferManager.cs b/libs/server/ArgSlice/ScratchBufferManager.cs index 423f75efc7..c68c223214 100644 --- a/libs/server/ArgSlice/ScratchBufferManager.cs +++ b/libs/server/ArgSlice/ScratchBufferManager.cs @@ -220,14 +220,14 @@ public ArgSlice FormatScratch(int headerSize, ReadOnlySpan arg) /// Format specified command with arguments, as a RESP command. Lua state /// can be specified to handle Lua tables as arguments. /// - public ArgSlice FormatCommandAsResp(string cmd, ReadOnlySpan args) + public ArgSlice FormatCommandAsResp(ReadOnlySpan cmd, ReadOnlySpan args) { if (scratchBuffer == null) ExpandScratchBuffer(64); scratchBufferOffset += 10; // Reserve space for the array length if it is larger than expected - int commandStartOffset = scratchBufferOffset; - byte* ptr = scratchBufferHead + scratchBufferOffset; + var commandStartOffset = scratchBufferOffset; + var ptr = scratchBufferHead + scratchBufferOffset; while (!RespWriteUtils.WriteArrayLength(args.Length + 1, ref ptr, scratchBufferHead + scratchBuffer.Length)) { @@ -236,31 +236,45 @@ public ArgSlice FormatCommandAsResp(string cmd, ReadOnlySpan args) } scratchBufferOffset = (int)(ptr - scratchBufferHead); - while (!RespWriteUtils.WriteAsciiBulkString(cmd, ref ptr, scratchBufferHead + scratchBuffer.Length)) + while (!RespWriteUtils.WriteBulkString(cmd, ref ptr, scratchBufferHead + scratchBuffer.Length)) { ExpandScratchBuffer(scratchBuffer.Length + 1); ptr = scratchBufferHead + scratchBufferOffset; } scratchBufferOffset = (int)(ptr - scratchBufferHead); - int count = 1; + var count = 1; foreach (var str in args) { count++; - while (!RespWriteUtils.WriteAsciiBulkString(str, ref ptr, scratchBufferHead + scratchBuffer.Length)) + + // Smuggling a null-ish value in + if (str.Length < 0) + { + while (!RespWriteUtils.WriteNull(ref ptr, scratchBufferHead + scratchBuffer.Length)) + { + ExpandScratchBuffer(scratchBuffer.Length + 1); + ptr = scratchBufferHead + scratchBufferOffset; + } + } + else { - ExpandScratchBuffer(scratchBuffer.Length + 1); - ptr = scratchBufferHead + scratchBufferOffset; + while (!RespWriteUtils.WriteBulkString(str.ReadOnlySpan, ref ptr, scratchBufferHead + scratchBuffer.Length)) + { + ExpandScratchBuffer(scratchBuffer.Length + 1); + ptr = scratchBufferHead + scratchBufferOffset; + } } + scratchBufferOffset = (int)(ptr - scratchBufferHead); } if (count != args.Length + 1) { - int extraSpace = NumUtils.NumDigits(count) - NumUtils.NumDigits(args.Length + 1); + var extraSpace = NumUtils.NumDigits(count) - NumUtils.NumDigits(args.Length + 1); if (commandStartOffset < extraSpace) throw new InvalidOperationException("Invalid number of arguments"); - + commandStartOffset -= extraSpace; var head = scratchBufferHead + commandStartOffset; diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 96f0d81522..4b774d1cce 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -21,10 +21,10 @@ internal sealed record ErrorResult(string Message); /// internal sealed class LuaRunner : IDisposable { - // rooted to keep function pointer alive + // Rooted to keep function pointer alive readonly LuaFunction garnetCall; - // references into Registry on the Lua side + // References into Registry on the Lua side readonly int sandboxEnvRegistryIndex; readonly int keysTableRegistryIndex; readonly int argvTableRegistryIndex; @@ -60,13 +60,13 @@ public LuaRunner(string source, bool txnMode = false, RespServerSession respServ loadSandboxedRegistryIndex = -1; functionRegistryIndex = -1; - // todo: custom allocator? + // TODO: custom allocator? state = new Lua(); Debug.Assert(state.GetTop() == 0, "Stack should be empty at allocation"); if (txnMode) { - this.txnKeyEntries = new TxnKeyEntries(16, respServerSession.storageSession.lockableContext, respServerSession.storageSession.objectStoreLockableContext); + txnKeyEntries = new TxnKeyEntries(16, respServerSession.storageSession.lockableContext, respServerSession.storageSession.objectStoreLockableContext); garnetCall = garnet_call_txn; } @@ -140,7 +140,7 @@ function load_sandboxed(source) throw new GarnetException("Could not initialize Lua sandbox state"); } - // register garnet_call in global namespace + // Register garnet_call in global namespace state.Register("garnet_call", garnetCall); var sandboxEnvType = state.GetGlobal("sandbox_env"); @@ -191,7 +191,7 @@ public void Compile() Debug.Assert(loadRes == LuaType.Function, "Unexpected load_sandboxed type"); state.PushString(source); - state.Call(1, -1); // multiple returns allowed + state.Call(1, -1); // Multiple returns allowed var numRets = state.GetTop(); if (numRets == 0) @@ -227,7 +227,7 @@ public void Compile() } finally { - // force stack empty after compilation, no matter what happens + // Force stack empty after compilation, no matter what happens state.SetTop(0); } } @@ -319,7 +319,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) return ErrorInvalidArgumentType(state); } - // note these spans are implicitly pinned, as they're actually on the Lua stack + // Note these spans are implicitly pinned, as they're actually on the Lua stack var key = ArgSlice.FromPinnedSpan(keySpan); var value = ArgSlice.FromPinnedSpan(valSpan); @@ -340,7 +340,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) return ErrorInvalidArgumentType(state); } - // span is (implicitly) pinned since it's actually on the Lua stack + // Span is (implicitly) pinned since it's actually on the Lua stack var key = ArgSlice.FromPinnedSpan(keySpan); var status = api.GET(key, out var value); if (status == GarnetStatus.OK) @@ -357,51 +357,52 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) // As fallback, we use RespServerSession with a RESP-formatted input. This could be optimized // in future to provide parse state directly. + var trueArgCount = argCount - 1; - // todo: remove all these allocations - var stackArgs = ArrayPool.Shared.Rent(argCount); - var cmd = Encoding.UTF8.GetString(cmdSpan); + // Avoid allocating entirely if fewer than 16 commands (note we only store pointers, we make no copies) + // + // At 17+ we'll rent an array, which might allocate, but typically won't + var cmdArgsArr = trueArgCount <= 16 ? null : ArrayPool.Shared.Rent(argCount); + var cmdArgs = cmdArgsArr != null ? cmdArgsArr.AsSpan()[..trueArgCount] : stackalloc ArgSlice[trueArgCount]; try { - var top = state.GetTop(); - - // move backwards validating arguments - // and removing them from the stack - for (var i = argCount - 1; i >= 0; i--) + for (var i = 0; i < argCount - 1; i++) { - var argType = state.Type(top); + // Index 1 holds the command, so skip it + var argIx = 2 + i; + + var argType = state.Type(argIx); if (argType == LuaType.Nil) { - stackArgs[i] = null; + cmdArgs[i] = new ArgSlice(null, -1); } - else if (argType == LuaType.String) + else if (argType is LuaType.String or LuaType.Number) { - stackArgs[i] = state.CheckString(top); - } - else if (argType == LuaType.Number) - { - var asNum = state.CheckNumber(top); - stackArgs[i] = ((long)asNum).ToString(); + // CheckBuffer will coerce a number into a string + // + // Redis nominally converts numbers to integers, but in this case just ToStrings things + var checkRes = NativeMethods.CheckBuffer(state.Handle, argIx, out var span); + Debug.Assert(checkRes, "Should never fail"); + + // Span remains pinned so long as we don't pop the stack + cmdArgs[i] = ArgSlice.FromPinnedSpan(span); } else { - state.Pop(1); - return ErrorInvalidArgumentType(state); } - - state.Pop(1); - top--; } - Debug.Assert(state.GetTop() == 0, "Should have emptied the stack"); + var request = scratchBufferManager.FormatCommandAsResp(cmdSpan, cmdArgs); - // command is handled specially, so trim it off - var cmdArgs = stackArgs.AsSpan().Slice(1, argCount - 1); + // Once the request is formatted, we can release all the args on the Lua stack + // + // This keeps the stack size down for processing the response + state.Pop(argCount); - var request = scratchBufferManager.FormatCommandAsResp(cmd, cmdArgs); _ = respServerSession.TryConsumeMessages(request.ptr, request.length); + var response = scratchBufferNetworkSender.GetResponse(); var result = ProcessResponse(response.ptr, response.length); scratchBufferNetworkSender.Reset(); @@ -409,7 +410,10 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) } finally { - ArrayPool.Shared.Return(stackArgs); + if (cmdArgsArr != null) + { + ArrayPool.Shared.Return(cmdArgsArr); + } } } @@ -421,6 +425,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) return state.Error(); } + // Common failure mode is passing wrong arg static int ErrorInvalidArgumentType(Lua state) { NativeMethods.PushBuffer(state.Handle, "Lua redis lib command arguments must be strings or integers"u8); @@ -483,7 +488,7 @@ unsafe int ProcessResponse(byte* ptr, int length) return LuaError("Unknown Redis command called from script"u8); } - // todo: remove alloc + // TIDI: remove alloc if (RespReadUtils.ReadErrorAsString(out resultStr, ref ptr, ptr + length)) { state.PushString(resultStr); @@ -492,7 +497,7 @@ unsafe int ProcessResponse(byte* ptr, int length) goto default; case (byte)'$': - // todo: remove alloc + // TODO: remove alloc if (RespReadUtils.ReadStringResponseWithLengthHeader(out resultStr, ref ptr, ptr + length)) { // bulk null strings are mapped to FALSE @@ -510,7 +515,7 @@ unsafe int ProcessResponse(byte* ptr, int length) goto default; case (byte)'*': - // todo: remove allocs + // TODO: remove allocs if (RespReadUtils.ReadRentedStringArrayResponseWithLengthHeader(ArrayPool.Shared, out var resultArray, ref ptr, ptr + length)) { try @@ -580,7 +585,7 @@ public object Run(int count, SessionParseState parseState) if (nKeys > 0) { - // get KEYS on the stack + // Get KEYS on the stack state.PushNumber(keysTableRegistryIndex); var loadedType = state.RawGet(LuaRegistry.Index); Debug.Assert(loadedType == LuaType.Table, "Unexpected type loaded when expecting KEYS"); @@ -596,7 +601,7 @@ public object Run(int count, SessionParseState parseState) txnKeyEntries.AddKey(key, true, Tsavorite.core.LockType.Exclusive); } - // equivalent to KEYS[i+1] = key + // Equivalent to KEYS[i+1] = key state.PushNumber(i + 1); NativeMethods.PushBuffer(state.Handle, key.ReadOnlySpan); state.RawSet(1); @@ -604,7 +609,7 @@ public object Run(int count, SessionParseState parseState) offset++; } - // remove KEYS from the stack + // Remove KEYS from the stack state.Pop(1); count -= nKeys; @@ -612,7 +617,7 @@ public object Run(int count, SessionParseState parseState) if (count > 0) { - // GET ARGV on the stack + // Get ARGV on the stack state.PushNumber(argvTableRegistryIndex); var loadedType = state.RawGet(LuaRegistry.Index); Debug.Assert(loadedType == LuaType.Table, "Unexpected type loaded when expecting ARGV"); @@ -621,7 +626,7 @@ public object Run(int count, SessionParseState parseState) { ref var argv = ref parseState.GetArgSliceByRef(offset); - // equivalent to ARGV[i+1] = argv + // Equivalent to ARGV[i+1] = argv state.PushNumber(i + 1); NativeMethods.PushBuffer(state.Handle, argv.ReadOnlySpan); state.RawSet(1); @@ -629,7 +634,7 @@ public object Run(int count, SessionParseState parseState) offset++; } - // remove ARGV from the stack + // Remove ARGV from the stack state.Pop(1); } @@ -702,12 +707,12 @@ void ResetParameters(int nKeys, int nArgs) if (keyLength > nKeys) { - // get KEYS on the stack + // Get KEYS on the stack state.PushNumber(keysTableRegistryIndex); var loadRes = state.GetTable(LuaRegistry.Index); Debug.Assert(loadRes == LuaType.Table, "Unexpected type for KEYS"); - // clear all the values in KEYS that we aren't going to set anyway + // Clear all the values in KEYS that we aren't going to set anyway for (var i = nKeys + 1; i <= keyLength; i++) { state.PushNil(); @@ -721,7 +726,7 @@ void ResetParameters(int nKeys, int nArgs) if (argvLength > nArgs) { - // get ARGV on the stack + // Get ARGV on the stack state.PushNumber(argvTableRegistryIndex); var loadRes = state.GetTable(LuaRegistry.Index); Debug.Assert(loadRes == LuaType.Table, "Unexpected type for ARGV"); @@ -792,10 +797,10 @@ void LoadParameters(string[] keys, string[] argv) /// object Run() { - // todo: mapping is dependent on Resp2 vs Resp3 settings + // TODO: mapping is dependent on Resp2 vs Resp3 settings // and that's not implemented at all - // todo: this shouldn't read the result, it should write the response out + // TODO: this shouldn't read the result, it should write the response out Debug.Assert(state.GetTop() == 0, "Stack should be empty at start of invocation"); if (!state.CheckStack(2)) @@ -812,7 +817,7 @@ object Run() var callRes = state.PCall(0, 1, 0); if (callRes == LuaStatus.OK) { - // the actual call worked, handle the response + // The actual call worked, handle the response if (state.GetTop() == 0) { @@ -826,9 +831,9 @@ object Run() } else if (retType == LuaType.Number) { - // Redis unconditionally converts all "number" replies to integer replies - // so we match that - // see: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion + // Redis unconditionally converts all "number" replies to integer replies so we match that + // + // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion return (long)state.CheckNumber(1); } else if (retType == LuaType.String) @@ -837,9 +842,9 @@ object Run() } else if (retType == LuaType.Boolean) { - // Redis maps Lua false to null, and Lua true to 1 - // this is strange, but documented - // see: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion + // Redis maps Lua false to null, and Lua true to 1 this is strange, but documented + // + // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion if (state.ToBoolean(1)) { return 1L; @@ -851,11 +856,11 @@ object Run() } else if (retType == LuaType.Table) { - // todo: this is hacky, and doesn't support nested arrays or whatever + // TODO: this is hacky, and doesn't support nested arrays or whatever // but is good enough for now // when refactored to avoid intermediate objects this should be fixed - // note: because we are dealing with a user provided type, we MUST respect + // TODO: because we are dealing with a user provided type, we MUST respect // metatables - so we can't use any of the RawXXX methods // if the key err is in there, we need to short circuit @@ -872,7 +877,7 @@ object Run() state.Pop(1); - // otherwise, we need to convert the table to an array + // Otherwise, we need to convert the table to an array var tableLength = state.Length(1); var ret = new object[tableLength]; @@ -891,7 +896,7 @@ object Run() ret[i - 1] = state.ToBoolean(2) ? 1L : null; break; // Redis stops processesing the array when a nil is encountered - // see: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion + // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion case LuaType.Nil: return ret.Take(i - 1).ToArray(); } @@ -903,13 +908,13 @@ object Run() } else { - // todo: implement + // TODO: implement throw new NotImplementedException(); } } else { - // an error was raised + // An error was raised var stackTop = state.GetTop(); if (stackTop == 0) @@ -918,8 +923,7 @@ object Run() throw new GarnetException("An error occurred while invoking a Lua script"); } - // todo: we should just write this out, not throw - // it's not exceptional + // Todo: we should just write this out, not throw it's not exceptional var msg = state.CheckString(stackTop); throw new GarnetException(msg); } diff --git a/test/Garnet.test/LuaScriptTests.cs b/test/Garnet.test/LuaScriptTests.cs index 1fa09ef5e1..695f556b78 100644 --- a/test/Garnet.test/LuaScriptTests.cs +++ b/test/Garnet.test/LuaScriptTests.cs @@ -608,5 +608,18 @@ public void BinaryValuesInScripts() var readTrickyValue2 = (byte[])readTrickyValue2Raw; ClassicAssert.IsTrue(trickyValue2.AsSpan().SequenceEqual(readTrickyValue2)); } + + [Test] + public void NumberArgumentCoercion() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + db.StringSet("2", "hello"); + db.StringSet("2.1", "world"); + + var res = (string)db.ScriptEvaluate("return redis.call('GET', 2.1)"); + ClassicAssert.AreEqual("world", res); + } } } \ No newline at end of file From a17aceb6dfd4b9252705cb969abb26cb4695426b Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Tue, 10 Dec 2024 17:36:31 -0500 Subject: [PATCH 18/51] remove some more allocations --- libs/server/Lua/LuaRunner.cs | 55 ++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 4b774d1cce..0155c08213 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -450,7 +450,7 @@ int LuaError(ReadOnlySpan msg) /// /// Process a RESP-formatted response from the RespServerSession. /// - /// Pushes result onto state stack, and returns 1 + /// Pushes result onto state stack and returns 1, or raises an error and never returns. /// unsafe int ProcessResponse(byte* ptr, int length) { @@ -464,10 +464,11 @@ unsafe int ProcessResponse(byte* ptr, int length) switch (*ptr) { case (byte)'+': - // todo: remove alloc - if (RespReadUtils.ReadSimpleString(out var resultStr, ref ptr, ptr + length)) + ptr++; + length--; + if (RespReadUtils.ReadAsSpan(out var resultSpan, ref ptr, ptr + length)) { - state.PushString(resultStr); + NativeMethods.PushBuffer(state.Handle, resultSpan); return 1; } goto default; @@ -481,35 +482,35 @@ unsafe int ProcessResponse(byte* ptr, int length) goto default; case (byte)'-': - var errSpan = new ReadOnlySpan(ptr + 1, length - 3); // cut \r\n off too - if (errSpan.SequenceEqual(CmdStrings.RESP_ERR_GENERIC_UNK_CMD)) + ptr++; + length--; + if (RespReadUtils.ReadAsSpan(out var errSpan, ref ptr, ptr + length)) { - // gets a special response - return LuaError("Unknown Redis command called from script"u8); - } + if (errSpan.SequenceEqual(CmdStrings.RESP_ERR_GENERIC_UNK_CMD)) + { + // Gets a special response + return LuaError("Unknown Redis command called from script"u8); + } - // TIDI: remove alloc - if (RespReadUtils.ReadErrorAsString(out resultStr, ref ptr, ptr + length)) - { - state.PushString(resultStr); + NativeMethods.PushBuffer(state.Handle, errSpan); return state.Error(); + } goto default; case (byte)'$': - // TODO: remove alloc - if (RespReadUtils.ReadStringResponseWithLengthHeader(out resultStr, ref ptr, ptr + length)) + if (length >= 5 && new ReadOnlySpan(ptr + 1, 4).SequenceEqual("-1\r\n"u8)) { - // bulk null strings are mapped to FALSE - // see: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion - if (resultStr == null) - { - state.PushBoolean(false); - } - else - { - state.PushString(resultStr); - } + // Bulk null strings are mapped to FALSE + // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion + state.PushBoolean(false); + + return 1; + } + else if (RespReadUtils.ReadSpanWithLengthHeader(out var bulkSpan, ref ptr, ptr + length)) + { + NativeMethods.PushBuffer(state.Handle, bulkSpan); + return 1; } goto default; @@ -528,8 +529,8 @@ unsafe int ProcessResponse(byte* ptr, int length) var i = 1; foreach (var item in resultArray.Span) { - // null strings are mapped to false - // see: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion + // Null strings are mapped to false + // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion if (item == null) { state.PushBoolean(false); From 344c5977abce114ff0932ab13d722727db53338f Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Tue, 10 Dec 2024 18:07:58 -0500 Subject: [PATCH 19/51] response processesing is now allocation free --- libs/common/RespReadUtils.cs | 50 ------------------------------- libs/server/Lua/LuaRunner.cs | 57 ++++++++++++++++++------------------ 2 files changed, 29 insertions(+), 78 deletions(-) diff --git a/libs/common/RespReadUtils.cs b/libs/common/RespReadUtils.cs index 9df08a48fe..b3168a2484 100644 --- a/libs/common/RespReadUtils.cs +++ b/libs/common/RespReadUtils.cs @@ -958,56 +958,6 @@ public static bool ReadStringArrayWithLengthHeader(out string[] result, ref byte return true; } - /// - /// Read string array with length header. - /// - /// result will be backed by an empty array or one rented from the given pool upon return. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static bool ReadRentedStringArrayResponseWithLengthHeader(ArrayPool pool, out Memory result, ref byte* ptr, byte* end) - { - result = Array.Empty(); - - // Parse RESP array header - if (!ReadSignedArrayLength(out var length, ref ptr, end)) - { - return false; - } - - if (length < 0) - { - // NULL value ('*-1\r\n') - return true; - } - - // Parse individual strings in the array - result = ArrayPool.Shared.Rent(length); - result = result[..length]; - - var resultSpan = result.Span; - - for (var i = 0; i < length; i++) - { - if (*ptr == '$') - { - if (!ReadStringResponseWithLengthHeader(out resultSpan[i], ref ptr, end)) - return false; - } - else if (*ptr == '+') - { - if (!ReadSimpleString(out resultSpan[i], ref ptr, end)) - return false; - } - else - { - if (!ReadIntegerAsString(out resultSpan[i], ref ptr, end)) - return false; - } - } - - return true; - } - /// /// Read double with length header /// diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 0155c08213..a7f51084d7 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -516,47 +516,48 @@ unsafe int ProcessResponse(byte* ptr, int length) goto default; case (byte)'*': - // TODO: remove allocs - if (RespReadUtils.ReadRentedStringArrayResponseWithLengthHeader(ArrayPool.Shared, out var resultArray, ref ptr, ptr + length)) + if (RespReadUtils.ReadUnsignedArrayLength(out var itemCount, ref ptr, ptr + length)) { - try - { - // create the new table - state.NewTable(); - Debug.Assert(state.GetTop() == 1, "New table should be at top of stack"); + // Create the new table + state.CreateTable(itemCount, 0); + Debug.Assert(state.GetTop() == 1, "New table should be at top of stack"); - // Populate the table - var i = 1; - foreach (var item in resultArray.Span) + for (var itemIx = 0; itemIx < itemCount; itemIx++) + { + if (*ptr == '$') { - // Null strings are mapped to false - // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion - if (item == null) + // Bulk String + if (length >= 4 && new ReadOnlySpan(ptr + 1, 4).SequenceEqual("-1\r\n"u8)) { + // Null strings are mapped to false + // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion state.PushBoolean(false); } - else + else if (RespReadUtils.ReadSpanWithLengthHeader(out var strSpan, ref ptr, ptr + length)) { - state.PushString(item); + NativeMethods.PushBuffer(state.Handle, strSpan); } + else + { - state.RawSetInteger(1, i); - - i++; + // Error, drop the table we allocated + state.Pop(1); + goto default; + } } - - return 1; - } - finally - { - if (!resultArray.IsEmpty) + else { - if (MemoryMarshal.TryGetArray(resultArray, out ArraySegment rented)) - { - ArrayPool.Shared.Return(rented.Array); - } + // In practice, we ONLY ever return bulk strings + // So just... not implementing the rest for now + throw new NotImplementedException($"Unexpected sigil: {(char)*ptr}"); } + + // Stack now has table and value at itemIx on it + state.RawSetInteger(1, itemIx + 1); } + + Debug.Assert(state.GetTop() == 1, "Only the table should be on the stack"); + return 1; } goto default; From 7610349653502d2829066e732ead87b53ecb6ad4 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Wed, 11 Dec 2024 10:41:38 -0500 Subject: [PATCH 20/51] kill a hand full of additional allocations --- libs/server/Lua/LuaRunner.cs | 91 +++++++++++++++++++++++++----------- 1 file changed, 65 insertions(+), 26 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index a7f51084d7..92086eb993 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -5,7 +5,6 @@ using System.Buffers; using System.Diagnostics; using System.Linq; -using System.Runtime.InteropServices; using System.Text; using Garnet.common; using KeraLua; @@ -280,7 +279,7 @@ int NoSessionError() { logger?.LogError("Lua call came in without a valid resp session"); - state.PushString("No session available"); + NativeMethods.PushBuffer(state.Handle, "No session available"u8); // this will never return, but we can pretend it does return state.Error(); @@ -757,41 +756,81 @@ void LoadParameters(string[] keys, string[] argv) } ResetParameters(keys?.Length ?? 0, argv?.Length ?? 0); - if (keys != null) + + byte[] encodingBufferArr = null; + Span encodingBuffer = stackalloc byte[64]; + try { - // get KEYS on the stack - state.PushNumber(keysTableRegistryIndex); - var loadRes = state.GetTable(LuaRegistry.Index); - Debug.Assert(loadRes == LuaType.Table, "Unexpected type for KEYS"); - for (var i = 0; i < keys.Length; i++) + if (keys != null) { - // equivalent to KEYS[i+1] = keys[i] - state.PushString(keys[i]); - state.RawSetInteger(1, i + 1); + // get KEYS on the stack + state.PushNumber(keysTableRegistryIndex); + var loadRes = state.GetTable(LuaRegistry.Index); + Debug.Assert(loadRes == LuaType.Table, "Unexpected type for KEYS"); + + for (var i = 0; i < keys.Length; i++) + { + // equivalent to KEYS[i+1] = keys[i] + var key = keys[i]; + + var keyLen = PrepareString(key, ref encodingBufferArr, ref encodingBuffer); + NativeMethods.PushBuffer(state.Handle, encodingBuffer[..keyLen]); + + state.RawSetInteger(1, i + 1); + } + + state.Pop(1); } - state.Pop(1); - } + if (argv != null) + { + // get ARGV on the stack + state.PushNumber(argvTableRegistryIndex); + var loadRes = state.GetTable(LuaRegistry.Index); + Debug.Assert(loadRes == LuaType.Table, "Unexpected type for ARGV"); - if (argv != null) - { - // get ARGV on the stack - state.PushNumber(argvTableRegistryIndex); - var loadRes = state.GetTable(LuaRegistry.Index); - Debug.Assert(loadRes == LuaType.Table, "Unexpected type for ARGV"); + for (var i = 0; i < argv.Length; i++) + { + // equivalent to ARGV[i+1] = keys[i] + var arg = argv[i]; - for (var i = 0; i < argv.Length; i++) + var argLen = PrepareString(arg, ref encodingBufferArr, ref encodingBuffer); + NativeMethods.PushBuffer(state.Handle, encodingBuffer[..argLen]); + + state.RawSetInteger(1, i + 1); + } + + state.Pop(1); + } + } + finally + { + if(encodingBufferArr != null) { - // equivalent to ARGV[i+1] = keys[i] - state.PushString(argv[i]); - state.RawSetInteger(1, i + 1); + ArrayPool.Shared.Return(encodingBufferArr); } - - state.Pop(1); } + Debug.Assert(state.GetTop() == 0, "Stack should be empty when invocation ends"); + + static int PrepareString(string raw, ref byte[] arr, ref Span span) + { + var maxLen = Encoding.UTF8.GetMaxByteCount(raw.Length); + if(span.Length < maxLen) + { + if(arr != null) + { + ArrayPool.Shared.Return(arr); + } + + arr = ArrayPool.Shared.Rent(maxLen); + span = arr; + } + + return Encoding.UTF8.GetBytes(raw, span); + } } /// @@ -866,7 +905,7 @@ object Run() // metatables - so we can't use any of the RawXXX methods // if the key err is in there, we need to short circuit - state.PushString("err"); + NativeMethods.PushBuffer(state.Handle, "err"u8); var errType = state.GetTable(1); if (errType == LuaType.String) From 089f1e66aeaed1cb200c9c28d2f1dae64cd6c076 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Wed, 11 Dec 2024 11:27:57 -0500 Subject: [PATCH 21/51] DRY up some repeated checks; assert assumptions about Lua stack in DEBUG builds --- libs/server/Lua/LuaRunner.cs | 293 ++++++++++++++++++++++++----------- 1 file changed, 203 insertions(+), 90 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 92086eb993..c5dcdc085b 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -5,10 +5,12 @@ using System.Buffers; using System.Diagnostics; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using Garnet.common; using KeraLua; using Microsoft.Extensions.Logging; +using static System.Runtime.InteropServices.JavaScript.JSType; namespace Garnet.server { @@ -24,13 +26,14 @@ internal sealed class LuaRunner : IDisposable readonly LuaFunction garnetCall; // References into Registry on the Lua side + // TODO: essentially all constant strings should be pulled out of registry too to avoid copying cost readonly int sandboxEnvRegistryIndex; readonly int keysTableRegistryIndex; readonly int argvTableRegistryIndex; readonly int loadSandboxedRegistryIndex; int functionRegistryIndex; - readonly string source; + readonly ReadOnlyMemory source; readonly ScratchBufferNetworkSender scratchBufferNetworkSender; readonly RespServerSession respServerSession; readonly ScratchBufferManager scratchBufferManager; @@ -44,7 +47,7 @@ internal sealed class LuaRunner : IDisposable /// /// Creates a new runner with the source of the script /// - public LuaRunner(string source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null) + public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null) { this.source = source; this.txnMode = txnMode; @@ -61,7 +64,7 @@ public LuaRunner(string source, bool txnMode = false, RespServerSession respServ // TODO: custom allocator? state = new Lua(); - Debug.Assert(state.GetTop() == 0, "Stack should be empty at allocation"); + AssertLuaStackEmpty(); if (txnMode) { @@ -158,14 +161,14 @@ function load_sandboxed(source) Debug.Assert(loadSandboxedType == LuaType.Function, "Unexpected load_sandboxed type"); loadSandboxedRegistryIndex = state.Ref(LuaRegistry.Index); - Debug.Assert(state.GetTop() == 0, "Stack should be empty after initialization"); + AssertLuaStackEmpty(); } /// /// Creates a new runner with the source of the script /// - public LuaRunner(ReadOnlySpan source, bool txnMode, RespServerSession respServerSession, ScratchBufferNetworkSender scratchBufferNetworkSender, ILogger logger = null) - : this(Encoding.UTF8.GetString(source), txnMode, respServerSession, scratchBufferNetworkSender, logger) + public LuaRunner(string source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null) + : this(Encoding.UTF8.GetBytes(source), txnMode, respServerSession, scratchBufferNetworkSender, logger) { } @@ -174,22 +177,21 @@ public LuaRunner(ReadOnlySpan source, bool txnMode, RespServerSession resp /// public void Compile() { + const int NeededStackSpace = 2; + Debug.Assert(functionRegistryIndex == -1, "Shouldn't compile multiple times"); - Debug.Assert(state.GetTop() == 0, "Stack should be empty at start of compilation"); + AssertLuaStackEmpty(); try { - if (!state.CheckStack(2)) - { - throw new GarnetException("Insufficient stack space to compile function"); - } + ForceGrowLuaStack(NeededStackSpace); - state.PushNumber(loadSandboxedRegistryIndex); + CheckedPushNumber(NeededStackSpace, loadSandboxedRegistryIndex); var loadRes = state.GetTable(LuaRegistry.Index); Debug.Assert(loadRes == LuaType.Function, "Unexpected load_sandboxed type"); - state.PushString(source); + CheckedPushBuffer(NeededStackSpace, source.Span); state.Call(1, -1); // Multiple returns allowed var numRets = state.GetTop(); @@ -277,9 +279,13 @@ public int garnet_call_txn(IntPtr luaStatePtr) /// int NoSessionError() { + const int NeededStackSpace = 1; + logger?.LogError("Lua call came in without a valid resp session"); - NativeMethods.PushBuffer(state.Handle, "No session available"u8); + ForceGrowLuaStack(NeededStackSpace); + + CheckedPushBuffer(NeededStackSpace, "No session available"u8); // this will never return, but we can pretend it does return state.Error(); @@ -291,6 +297,8 @@ int NoSessionError() unsafe int ProcessCommandFromScripting(TGarnetApi api) where TGarnetApi : IGarnetApi { + const int AdditionalStackSpace = 1; + try { var argCount = state.GetTop(); @@ -300,9 +308,13 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) return LuaError("Please specify at least one argument for this redis lib call"u8); } + ForceGrowLuaStack(AdditionalStackSpace); + + var neededStackSpace = argCount + AdditionalStackSpace; + if (!NativeMethods.CheckBuffer(state.Handle, 1, out var cmdSpan)) { - return ErrorInvalidArgumentType(state); + return ErrorInvalidArgumentType(neededStackSpace); } // We special-case a few performance-sensitive operations to directly invoke via the storage API @@ -315,7 +327,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) if (!NativeMethods.CheckBuffer(state.Handle, 2, out var keySpan) || !NativeMethods.CheckBuffer(state.Handle, 3, out var valSpan)) { - return ErrorInvalidArgumentType(state); + return ErrorInvalidArgumentType(neededStackSpace); } // Note these spans are implicitly pinned, as they're actually on the Lua stack @@ -324,7 +336,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) _ = api.SET(key, value); - NativeMethods.PushBuffer(state.Handle, "OK"u8); + CheckedPushBuffer(neededStackSpace, "OK"u8); return 1; } else if (AsciiUtils.EqualsUpperCaseSpanIgnoringCase(cmdSpan, "GET"u8) && argCount == 2) @@ -336,7 +348,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) if (!NativeMethods.CheckBuffer(state.Handle, 2, out var keySpan)) { - return ErrorInvalidArgumentType(state); + return ErrorInvalidArgumentType(neededStackSpace); } // Span is (implicitly) pinned since it's actually on the Lua stack @@ -344,11 +356,11 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) var status = api.GET(key, out var value); if (status == GarnetStatus.OK) { - NativeMethods.PushBuffer(state.Handle, value.ReadOnlySpan); + CheckedPushBuffer(neededStackSpace, value.ReadOnlySpan); } else { - state.PushNil(); + CheckedPushNil(neededStackSpace); } return 1; @@ -389,7 +401,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) } else { - return ErrorInvalidArgumentType(state); + return ErrorInvalidArgumentType(neededStackSpace); } } @@ -420,29 +432,45 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) { logger?.LogError(e, "During Lua script execution"); - state.PushString(e.Message); - return state.Error(); - } + // Clear the stack + state.SetTop(0); - // Common failure mode is passing wrong arg - static int ErrorInvalidArgumentType(Lua state) - { - NativeMethods.PushBuffer(state.Handle, "Lua redis lib command arguments must be strings or integers"u8); - return state.Error(); + // Try real hard to raise an error in Lua, but we may just be SOL + // + // We don't use ForceGrowLuaStack here because we're in an exception handler + if (state.CheckStack(AdditionalStackSpace)) + { + // TODO: Remove alloc + var b = Encoding.UTF8.GetBytes(e.Message); + CheckedPushBuffer(AdditionalStackSpace, b); + return state.Error(); + } + + throw; } } + + + /// + /// Common failure mode is passing wrong arg, so DRY it up. + /// + int ErrorInvalidArgumentType(int neededCapacity) + { + CheckedPushBuffer(neededCapacity, "Lua redis lib command arguments must be strings or integers"u8); + return state.Error(); + } + /// /// Cause a lua error to be raised with the given message. /// int LuaError(ReadOnlySpan msg) { - if (!state.CheckStack(1)) - { - throw new GarnetException("Insufficient stack space to error"); - } + const int NeededStackSize = 1; - NativeMethods.PushBuffer(state.Handle, msg); + ForceGrowLuaStack(NeededStackSize); + + CheckedPushBuffer(NeededStackSize, msg); return state.Error(); } @@ -453,12 +481,11 @@ int LuaError(ReadOnlySpan msg) /// unsafe int ProcessResponse(byte* ptr, int length) { - Debug.Assert(state.GetTop() == 0, "Stack should be empty before processing response"); + const int NeededStackSize = 3; - if (!state.CheckStack(3)) - { - throw new GarnetException("Insufficent space on stack to prepare response"); - } + AssertLuaStackEmpty(); + + ForceGrowLuaStack(NeededStackSize); switch (*ptr) { @@ -467,7 +494,7 @@ unsafe int ProcessResponse(byte* ptr, int length) length--; if (RespReadUtils.ReadAsSpan(out var resultSpan, ref ptr, ptr + length)) { - NativeMethods.PushBuffer(state.Handle, resultSpan); + CheckedPushBuffer(NeededStackSize, resultSpan); return 1; } goto default; @@ -475,7 +502,7 @@ unsafe int ProcessResponse(byte* ptr, int length) case (byte)':': if (RespReadUtils.Read64Int(out var number, ref ptr, ptr + length)) { - state.PushNumber(number); + CheckedPushNumber(NeededStackSize, number); return 1; } goto default; @@ -491,7 +518,7 @@ unsafe int ProcessResponse(byte* ptr, int length) return LuaError("Unknown Redis command called from script"u8); } - NativeMethods.PushBuffer(state.Handle, errSpan); + CheckedPushBuffer(NeededStackSize, errSpan); return state.Error(); } @@ -502,13 +529,13 @@ unsafe int ProcessResponse(byte* ptr, int length) { // Bulk null strings are mapped to FALSE // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion - state.PushBoolean(false); + CheckedPushBoolean(NeededStackSize, false); return 1; } else if (RespReadUtils.ReadSpanWithLengthHeader(out var bulkSpan, ref ptr, ptr + length)) { - NativeMethods.PushBuffer(state.Handle, bulkSpan); + CheckedPushBuffer(NeededStackSize, bulkSpan); return 1; } @@ -530,15 +557,14 @@ unsafe int ProcessResponse(byte* ptr, int length) { // Null strings are mapped to false // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion - state.PushBoolean(false); + CheckedPushBoolean(NeededStackSize, false); } else if (RespReadUtils.ReadSpanWithLengthHeader(out var strSpan, ref ptr, ptr + length)) { - NativeMethods.PushBuffer(state.Handle, strSpan); + CheckedPushBuffer(NeededStackSize, strSpan); } else { - // Error, drop the table we allocated state.Pop(1); goto default; @@ -570,12 +596,11 @@ unsafe int ProcessResponse(byte* ptr, int length) /// public object Run(int count, SessionParseState parseState) { - Debug.Assert(state.GetTop() == 0, "Stack should be empty at invocation start"); + const int NeededStackSize = 3; - if (!state.CheckStack(3)) - { - throw new GarnetException("Insufficient stack space to run script"); - } + AssertLuaStackEmpty(); + + ForceGrowLuaStack(NeededStackSize); scratchBufferManager.Reset(); @@ -587,7 +612,7 @@ public object Run(int count, SessionParseState parseState) if (nKeys > 0) { // Get KEYS on the stack - state.PushNumber(keysTableRegistryIndex); + CheckedPushNumber(NeededStackSize, keysTableRegistryIndex); var loadedType = state.RawGet(LuaRegistry.Index); Debug.Assert(loadedType == LuaType.Table, "Unexpected type loaded when expecting KEYS"); @@ -603,8 +628,8 @@ public object Run(int count, SessionParseState parseState) } // Equivalent to KEYS[i+1] = key - state.PushNumber(i + 1); - NativeMethods.PushBuffer(state.Handle, key.ReadOnlySpan); + CheckedPushNumber(NeededStackSize, i + 1); + CheckedPushBuffer(NeededStackSize, key.ReadOnlySpan); state.RawSet(1); offset++; @@ -619,7 +644,7 @@ public object Run(int count, SessionParseState parseState) if (count > 0) { // Get ARGV on the stack - state.PushNumber(argvTableRegistryIndex); + CheckedPushNumber(NeededStackSize, argvTableRegistryIndex); var loadedType = state.RawGet(LuaRegistry.Index); Debug.Assert(loadedType == LuaType.Table, "Unexpected type loaded when expecting ARGV"); @@ -628,8 +653,8 @@ public object Run(int count, SessionParseState parseState) ref var argv = ref parseState.GetArgSliceByRef(offset); // Equivalent to ARGV[i+1] = argv - state.PushNumber(i + 1); - NativeMethods.PushBuffer(state.Handle, argv.ReadOnlySpan); + CheckedPushNumber(NeededStackSize, i + 1); + CheckedPushBuffer(NeededStackSize, argv.ReadOnlySpan); state.RawSet(1); offset++; @@ -639,7 +664,7 @@ public object Run(int count, SessionParseState parseState) state.Pop(1); } - Debug.Assert(state.GetTop() == 0, "Stack should be empty before running function"); + AssertLuaStackEmpty(); if (txnMode && nKeys > 0) { @@ -699,24 +724,24 @@ object RunTransaction() void ResetParameters(int nKeys, int nArgs) { - Debug.Assert(state.GetTop() == 0, "Stack should be empty before resetting parameters"); + // TODO: is this faster than punching a function in to do it? + const int NeededStackSize = 2; - if (!state.CheckStack(2)) - { - throw new GarnetException("Insufficient space on stack to reset parameters"); - } + AssertLuaStackEmpty(); + + ForceGrowLuaStack(NeededStackSize); if (keyLength > nKeys) { // Get KEYS on the stack - state.PushNumber(keysTableRegistryIndex); + CheckedPushNumber(NeededStackSize, keysTableRegistryIndex); var loadRes = state.GetTable(LuaRegistry.Index); Debug.Assert(loadRes == LuaType.Table, "Unexpected type for KEYS"); // Clear all the values in KEYS that we aren't going to set anyway for (var i = nKeys + 1; i <= keyLength; i++) { - state.PushNil(); + CheckedPushNil(NeededStackSize); state.RawSetInteger(1, i); } @@ -728,13 +753,13 @@ void ResetParameters(int nKeys, int nArgs) if (argvLength > nArgs) { // Get ARGV on the stack - state.PushNumber(argvTableRegistryIndex); + CheckedPushNumber(NeededStackSize, argvTableRegistryIndex); var loadRes = state.GetTable(LuaRegistry.Index); Debug.Assert(loadRes == LuaType.Table, "Unexpected type for ARGV"); for (var i = nArgs + 1; i <= argvLength; i++) { - state.PushNil(); + CheckedPushNil(NeededStackSize); state.RawSetInteger(1, i); } @@ -743,17 +768,16 @@ void ResetParameters(int nKeys, int nArgs) argvLength = nArgs; - Debug.Assert(state.GetTop() == 0, "Stack should be empty after resetting parameters"); + AssertLuaStackEmpty(); } void LoadParameters(string[] keys, string[] argv) { - Debug.Assert(state.GetTop() == 0, "Stack should be empty before invocation starts"); + const int NeededStackSize = 2; - if (!state.CheckStack(2)) - { - throw new GarnetException("Insufficient stack space to call function"); - } + AssertLuaStackEmpty(); + + ForceGrowLuaStack(NeededStackSize); ResetParameters(keys?.Length ?? 0, argv?.Length ?? 0); @@ -765,7 +789,7 @@ void LoadParameters(string[] keys, string[] argv) if (keys != null) { // get KEYS on the stack - state.PushNumber(keysTableRegistryIndex); + CheckedPushNumber(NeededStackSize, keysTableRegistryIndex); var loadRes = state.GetTable(LuaRegistry.Index); Debug.Assert(loadRes == LuaType.Table, "Unexpected type for KEYS"); @@ -775,7 +799,7 @@ void LoadParameters(string[] keys, string[] argv) var key = keys[i]; var keyLen = PrepareString(key, ref encodingBufferArr, ref encodingBuffer); - NativeMethods.PushBuffer(state.Handle, encodingBuffer[..keyLen]); + CheckedPushBuffer(NeededStackSize, encodingBuffer[..keyLen]); state.RawSetInteger(1, i + 1); } @@ -786,7 +810,7 @@ void LoadParameters(string[] keys, string[] argv) if (argv != null) { // get ARGV on the stack - state.PushNumber(argvTableRegistryIndex); + CheckedPushNumber(NeededStackSize, argvTableRegistryIndex); var loadRes = state.GetTable(LuaRegistry.Index); Debug.Assert(loadRes == LuaType.Table, "Unexpected type for ARGV"); @@ -796,7 +820,7 @@ void LoadParameters(string[] keys, string[] argv) var arg = argv[i]; var argLen = PrepareString(arg, ref encodingBufferArr, ref encodingBuffer); - NativeMethods.PushBuffer(state.Handle, encodingBuffer[..argLen]); + CheckedPushBuffer(NeededStackSize, encodingBuffer[..argLen]); state.RawSetInteger(1, i + 1); } @@ -806,21 +830,20 @@ void LoadParameters(string[] keys, string[] argv) } finally { - if(encodingBufferArr != null) + if (encodingBufferArr != null) { ArrayPool.Shared.Return(encodingBufferArr); } } - - Debug.Assert(state.GetTop() == 0, "Stack should be empty when invocation ends"); + AssertLuaStackEmpty(); static int PrepareString(string raw, ref byte[] arr, ref Span span) { var maxLen = Encoding.UTF8.GetMaxByteCount(raw.Length); - if(span.Length < maxLen) + if (span.Length < maxLen) { - if(arr != null) + if (arr != null) { ArrayPool.Shared.Return(arr); } @@ -838,20 +861,19 @@ static int PrepareString(string raw, ref byte[] arr, ref Span span) /// object Run() { + const int NeededStackSize = 2; + // TODO: mapping is dependent on Resp2 vs Resp3 settings // and that's not implemented at all // TODO: this shouldn't read the result, it should write the response out - Debug.Assert(state.GetTop() == 0, "Stack should be empty at start of invocation"); + AssertLuaStackEmpty(); - if (!state.CheckStack(2)) - { - throw new GarnetException("Insufficient stack space to run function"); - } + ForceGrowLuaStack(NeededStackSize); try { - state.PushNumber(functionRegistryIndex); + CheckedPushNumber(NeededStackSize, functionRegistryIndex); var loadRes = state.GetTable(LuaRegistry.Index); Debug.Assert(loadRes == LuaType.Function, "Unexpected type for function to invoke"); @@ -905,7 +927,7 @@ object Run() // metatables - so we can't use any of the RawXXX methods // if the key err is in there, we need to short circuit - NativeMethods.PushBuffer(state.Handle, "err"u8); + CheckedPushBuffer(NeededStackSize, "err"u8); var errType = state.GetTable(1); if (errType == LuaType.String) @@ -975,5 +997,96 @@ object Run() state.SetTop(0); } } + + /// + /// Ensure there's enough space on the Lua stack for more items. + /// + /// Throws if there is not. + /// + /// Prefer using this to calling directly. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void ForceGrowLuaStack(int additionalCapacity) + { + if (!state.CheckStack(additionalCapacity)) + { + throw new GarnetException("Could not reserve additional capacity on the Lua stack"); + } + } + + /// + /// Check that the Lua stack is empty in DEBUG builds. + /// + /// This is never necessary for correctness, but is often useful to find logical bugs. + /// + [Conditional("DEBUG")] + [MethodImpl(MethodImplOptions.NoInlining)] + private void AssertLuaStackEmpty([CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) + { + Debug.Assert(state.GetTop() == 0, $"Lua stack not empty when expected ({method}:{line} in {file})"); + } + + /// + /// Check the Lua stack has not grown beyond the capacity we initially reserved. + /// + /// This asserts (in DEBUG) that the next .PushXXX will succeed. + /// + /// In practice, Lua almost always gives us enough space (default is ~20 slots) but that's not guaranteed and can be false + /// for complicated redis.call invocations. + /// + [Conditional("DEBUG")] + [MethodImpl(MethodImplOptions.NoInlining)] + private void AssertLuaStackBelow(int reservedCapacity, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) + { + Debug.Assert(state.GetTop() < reservedCapacity, $"About to push to Lua stack without having reserved sufficient capacity."); + } + + /// + /// This should be used for all PushBuffer calls into Lua. + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void CheckedPushBuffer(int reservedCapacity, ReadOnlySpan buffer, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) + { + AssertLuaStackBelow(reservedCapacity, file, method, line); + + NativeMethods.PushBuffer(state.Handle, buffer); + } + + /// + /// This should be used for all PushNil calls into Lua. + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void CheckedPushNil(int reservedCapacity, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) + { + AssertLuaStackBelow(reservedCapacity, file, method, line); + + state.PushNil(); + } + + /// + /// This should be used for all PushNumber calls into Lua. + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void CheckedPushNumber(int reservedCapacity, double number, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) + { + AssertLuaStackBelow(reservedCapacity, file, method, line); + + state.PushNumber(number); + } + + /// + /// This should be used for all PushBoolean calls into Lua. + /// + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void CheckedPushBoolean(int reservedCapacity, bool b, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) + { + AssertLuaStackBelow(reservedCapacity, file, method, line); + + state.PushBoolean(b); + } } } \ No newline at end of file From 360503f403ddf91dfc97305cdd842392e5c7167e Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Wed, 11 Dec 2024 12:19:29 -0500 Subject: [PATCH 22/51] adjust some names, these keep confusing me; preparing for writing response directly to network stream --- benchmark/BDN.benchmark/Lua/LuaScripts.cs | 8 ++-- libs/server/Lua/LuaCommands.cs | 5 +-- libs/server/Lua/LuaRunner.cs | 45 +++++++++++++++-------- test/Garnet.test/LuaScriptRunnerTests.cs | 14 +++---- 4 files changed, 42 insertions(+), 30 deletions(-) diff --git a/benchmark/BDN.benchmark/Lua/LuaScripts.cs b/benchmark/BDN.benchmark/Lua/LuaScripts.cs index 5c6060d2e6..ae5bea1e9d 100644 --- a/benchmark/BDN.benchmark/Lua/LuaScripts.cs +++ b/benchmark/BDN.benchmark/Lua/LuaScripts.cs @@ -55,18 +55,18 @@ public void GlobalCleanup() [Benchmark] public void Script1() - => r1.Run(); + => r1.RunForRunner(); [Benchmark] public void Script2() - => r2.Run(); + => r2.RunForRunner(); [Benchmark] public void Script3() - => r3.Run(keys, null); + => r3.RunForRunner(keys, null); [Benchmark] public void Script4() - => r4.Run(keys, null); + => r4.RunForRunner(keys, null); } } \ No newline at end of file diff --git a/libs/server/Lua/LuaCommands.cs b/libs/server/Lua/LuaCommands.cs index 79e4bf20da..a2db625152 100644 --- a/libs/server/Lua/LuaCommands.cs +++ b/libs/server/Lua/LuaCommands.cs @@ -240,14 +240,11 @@ private bool CheckLuaEnabled() /// /// Invoke the execution of a server-side Lua script. /// - /// - /// - /// private unsafe bool ExecuteScript(int count, LuaRunner scriptRunner) { try { - var scriptResult = scriptRunner.Run(count, parseState); + var scriptResult = scriptRunner.RunForParseState(count, parseState); WriteObject(scriptResult); } catch (Exception ex) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index c5dcdc085b..5b6294a30e 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -10,7 +10,6 @@ using Garnet.common; using KeraLua; using Microsoft.Extensions.Logging; -using static System.Runtime.InteropServices.JavaScript.JSType; namespace Garnet.server { @@ -36,6 +35,8 @@ internal sealed class LuaRunner : IDisposable readonly ReadOnlyMemory source; readonly ScratchBufferNetworkSender scratchBufferNetworkSender; readonly RespServerSession respServerSession; + + // TODO: all buffers should be rented from this, remove ArrayPool use readonly ScratchBufferManager scratchBufferManager; readonly ILogger logger; readonly Lua state; @@ -592,9 +593,11 @@ unsafe int ProcessResponse(byte* ptr, int length) } /// - /// Runs the precompiled Lua function with specified parse state + /// Runs the precompiled Lua function with specified parse state. + /// + /// Meant for use directly from Garnet. /// - public object Run(int count, SessionParseState parseState) + public object RunForParseState(int count, SessionParseState parseState) { const int NeededStackSize = 3; @@ -668,21 +671,23 @@ public object Run(int count, SessionParseState parseState) if (txnMode && nKeys > 0) { - return RunTransaction(); + return RunInTransaction(); } else { - return Run(); + return RunCommon(); } } /// - /// Runs the precompiled Lua function with specified (keys, argv) state + /// Runs the precompiled Lua function with specified (keys, argv) state. + /// + /// Meant for use from a .NET host rather than in Garnet properly. /// - public object Run(string[] keys = null, string[] argv = null) + public object RunForRunner(string[] keys = null, string[] argv = null) { scratchBufferManager?.Reset(); - LoadParameters(keys, argv); + LoadParametersForRunner(keys, argv); if (txnMode && keys?.Length > 0) { // Add keys to the transaction @@ -693,15 +698,18 @@ public object Run(string[] keys = null, string[] argv = null) if (!respServerSession.storageSession.objectStoreLockableContext.IsNull) txnKeyEntries.AddKey(_key, true, Tsavorite.core.LockType.Exclusive); } - return RunTransaction(); + return RunInTransaction(); } else { - return Run(); + return RunCommon(); } } - object RunTransaction() + /// + /// Calls after setting up appropriate state for a transaction. + /// + object RunInTransaction() { try { @@ -710,7 +718,8 @@ object RunTransaction() respServerSession.storageSession.objectStoreLockableContext.BeginLockable(); respServerSession.SetTransactionMode(true); txnKeyEntries.LockAllKeys(); - return Run(); + + return RunCommon(); } finally { @@ -722,6 +731,9 @@ object RunTransaction() } } + /// + /// Remove extra keys and args from KEYS and ARGV globals. + /// void ResetParameters(int nKeys, int nArgs) { // TODO: is this faster than punching a function in to do it? @@ -771,7 +783,10 @@ void ResetParameters(int nKeys, int nArgs) AssertLuaStackEmpty(); } - void LoadParameters(string[] keys, string[] argv) + /// + /// Takes .NET strings for keys and args and pushes them into KEYS and ARGV globals. + /// + void LoadParametersForRunner(string[] keys, string[] argv) { const int NeededStackSize = 2; @@ -857,9 +872,9 @@ static int PrepareString(string raw, ref byte[] arr, ref Span span) } /// - /// Runs the precompiled Lua function + /// Runs the precompiled Lua function. /// - object Run() + object RunCommon() { const int NeededStackSize = 2; diff --git a/test/Garnet.test/LuaScriptRunnerTests.cs b/test/Garnet.test/LuaScriptRunnerTests.cs index a1f5c945d1..eb09a8cf9e 100644 --- a/test/Garnet.test/LuaScriptRunnerTests.cs +++ b/test/Garnet.test/LuaScriptRunnerTests.cs @@ -18,7 +18,7 @@ public void CannotRunUnsafeScript() using (var runner = new LuaRunner("luanet.load_assembly('mscorlib')")) { runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"luanet.load_assembly('mscorlib')\"]:1: attempt to index a nil value (global 'luanet')", ex.Message); } @@ -26,7 +26,7 @@ public void CannotRunUnsafeScript() using (var runner = new LuaRunner("os = require('os'); return os.time();")) { runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"os = require('os'); return os.time();\"]:1: attempt to call a nil value (global 'require')", ex.Message); } @@ -34,7 +34,7 @@ public void CannotRunUnsafeScript() using (var runner = new LuaRunner("dofile();")) { runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"dofile();\"]:1: attempt to call a nil value (global 'dofile')", ex.Message); } @@ -42,7 +42,7 @@ public void CannotRunUnsafeScript() using (var runner = new LuaRunner("require \"notepad\"")) { runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"require \"notepad\"\"]:1: attempt to call a nil value (global 'require')", ex.Message); } @@ -50,7 +50,7 @@ public void CannotRunUnsafeScript() using (var runner = new LuaRunner("os.exit();")) { runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"os.exit();\"]:1: attempt to index a nil value (global 'os')", ex.Message); } @@ -58,7 +58,7 @@ public void CannotRunUnsafeScript() using (var runner = new LuaRunner("import ('System.Diagnostics');")) { runner.Compile(); - var ex = Assert.Throws(() => runner.Run()); + var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"import ('System.Diagnostics');\"]:1: attempt to call a nil value (global 'import')", ex.Message); } } @@ -90,7 +90,7 @@ public void CanRunScript() using (var runner = new LuaRunner("local list; list = ARGV[1] ; return list;")) { runner.Compile(); - var res = runner.Run(keys, args); + var res = runner.RunForRunner(keys, args); ClassicAssert.AreEqual("arg1", res); } From abf803b2fb35a26fa280f629b5558801e9dff710 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Wed, 11 Dec 2024 13:14:04 -0500 Subject: [PATCH 23/51] first whack at getting script results directly into network stream; lots broken right now --- libs/server/Lua/LuaCommands.cs | 223 +++++++++++---------- libs/server/Lua/LuaRunner.cs | 269 +++++++++++++++++++------- libs/server/Resp/RespServerSession.cs | 4 +- 3 files changed, 313 insertions(+), 183 deletions(-) diff --git a/libs/server/Lua/LuaCommands.cs b/libs/server/Lua/LuaCommands.cs index a2db625152..ed08513911 100644 --- a/libs/server/Lua/LuaCommands.cs +++ b/libs/server/Lua/LuaCommands.cs @@ -244,8 +244,7 @@ private unsafe bool ExecuteScript(int count, LuaRunner scriptRunner) { try { - var scriptResult = scriptRunner.RunForParseState(count, parseState); - WriteObject(scriptResult); + scriptRunner.RunForSession(count, this); } catch (Exception ex) { @@ -257,115 +256,115 @@ private unsafe bool ExecuteScript(int count, LuaRunner scriptRunner) return true; } - void WriteObject(object scriptResult) - { - if (scriptResult != null) - { - if (scriptResult is string s) - { - while (!RespWriteUtils.WriteAsciiBulkString(s, ref dcurr, dend)) - SendAndReset(); - } - else if ((scriptResult as byte?) != null && (byte)scriptResult == 36) //equals to $ - { - while (!RespWriteUtils.WriteDirect((byte[])scriptResult, ref dcurr, dend)) - SendAndReset(); - } - else if (scriptResult is bool b) - { - if (b) - { - while (!RespWriteUtils.WriteInteger(1, ref dcurr, dend)) - SendAndReset(); - } - else - { - while (!RespWriteUtils.WriteDirect(CmdStrings.RESP_ERRNOTFOUND, ref dcurr, dend)) - SendAndReset(); - } - } - else if (scriptResult is long l) - { - while (!RespWriteUtils.WriteInteger(l, ref dcurr, dend)) - SendAndReset(); - } - else if (scriptResult is object[] o) - { - var count = o.Length; - while (!RespWriteUtils.WriteArrayLength(count, ref dcurr, dend)) - SendAndReset(); - - foreach (var value in o) - { - WriteObject(value); - } - } - else if (scriptResult is ErrorResult e) - { - while (!RespWriteUtils.WriteError(e.Message, ref dcurr, dend)) - SendAndReset(); - } - else - { - // todo: this should all go away - throw new NotImplementedException(); - } - //else if (scriptResult is ArgSlice a) - //{ - // while (!RespWriteUtils.WriteBulkString(a.ReadOnlySpan, ref dcurr, dend)) - // SendAndReset(); - //} - //else if (scriptResult is object[] o) - //{ - // // Two objects one boolean value and the result from the Lua Call - // while (!RespWriteUtils.WriteAsciiBulkString(o[1].ToString().AsSpan(), ref dcurr, dend)) - // SendAndReset(); - //} - //else if (scriptResult is LuaTable luaTable) - //{ - // try - // { - // var retVal = luaTable["err"]; - // if (retVal != null) - // { - // while (!RespWriteUtils.WriteError((string)retVal, ref dcurr, dend)) - // SendAndReset(); - // } - // else - // { - // retVal = luaTable["ok"]; - // if (retVal != null) - // { - // while (!RespWriteUtils.WriteAsciiBulkString((string)retVal, ref dcurr, dend)) - // SendAndReset(); - // } - // else - // { - // int count = luaTable.Values.Count; - // while (!RespWriteUtils.WriteArrayLength(count, ref dcurr, dend)) - // SendAndReset(); - // foreach (var value in luaTable.Values) - // { - // WriteObject(value); - // } - // } - // } - // } - // finally - // { - // luaTable.Dispose(); - // } - //} - //else - //{ - // throw new LuaScriptException("Unknown return type", ""); - //} - } - else - { - while (!RespWriteUtils.WriteDirect(CmdStrings.RESP_ERRNOTFOUND, ref dcurr, dend)) - SendAndReset(); - } - } + //void WriteObject(object scriptResult) + //{ + // if (scriptResult != null) + // { + // if (scriptResult is string s) + // { + // while (!RespWriteUtils.WriteAsciiBulkString(s, ref dcurr, dend)) + // SendAndReset(); + // } + // else if ((scriptResult as byte?) != null && (byte)scriptResult == 36) //equals to $ + // { + // while (!RespWriteUtils.WriteDirect((byte[])scriptResult, ref dcurr, dend)) + // SendAndReset(); + // } + // else if (scriptResult is bool b) + // { + // if (b) + // { + // while (!RespWriteUtils.WriteInteger(1, ref dcurr, dend)) + // SendAndReset(); + // } + // else + // { + // while (!RespWriteUtils.WriteDirect(CmdStrings.RESP_ERRNOTFOUND, ref dcurr, dend)) + // SendAndReset(); + // } + // } + // else if (scriptResult is long l) + // { + // while (!RespWriteUtils.WriteInteger(l, ref dcurr, dend)) + // SendAndReset(); + // } + // else if (scriptResult is object[] o) + // { + // var count = o.Length; + // while (!RespWriteUtils.WriteArrayLength(count, ref dcurr, dend)) + // SendAndReset(); + + // foreach (var value in o) + // { + // WriteObject(value); + // } + // } + // else if (scriptResult is ErrorResult e) + // { + // while (!RespWriteUtils.WriteError(e.Message, ref dcurr, dend)) + // SendAndReset(); + // } + // else + // { + // // todo: this should all go away + // throw new NotImplementedException(); + // } + // //else if (scriptResult is ArgSlice a) + // //{ + // // while (!RespWriteUtils.WriteBulkString(a.ReadOnlySpan, ref dcurr, dend)) + // // SendAndReset(); + // //} + // //else if (scriptResult is object[] o) + // //{ + // // // Two objects one boolean value and the result from the Lua Call + // // while (!RespWriteUtils.WriteAsciiBulkString(o[1].ToString().AsSpan(), ref dcurr, dend)) + // // SendAndReset(); + // //} + // //else if (scriptResult is LuaTable luaTable) + // //{ + // // try + // // { + // // var retVal = luaTable["err"]; + // // if (retVal != null) + // // { + // // while (!RespWriteUtils.WriteError((string)retVal, ref dcurr, dend)) + // // SendAndReset(); + // // } + // // else + // // { + // // retVal = luaTable["ok"]; + // // if (retVal != null) + // // { + // // while (!RespWriteUtils.WriteAsciiBulkString((string)retVal, ref dcurr, dend)) + // // SendAndReset(); + // // } + // // else + // // { + // // int count = luaTable.Values.Count; + // // while (!RespWriteUtils.WriteArrayLength(count, ref dcurr, dend)) + // // SendAndReset(); + // // foreach (var value in luaTable.Values) + // // { + // // WriteObject(value); + // // } + // // } + // // } + // // } + // // finally + // // { + // // luaTable.Dispose(); + // // } + // //} + // //else + // //{ + // // throw new LuaScriptException("Unknown return type", ""); + // //} + // } + // else + // { + // while (!RespWriteUtils.WriteDirect(CmdStrings.RESP_ERRNOTFOUND, ref dcurr, dend)) + // SendAndReset(); + // } + //} } } \ No newline at end of file diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 5b6294a30e..e713c69fa1 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -6,6 +6,7 @@ using System.Diagnostics; using System.Linq; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Text; using Garnet.common; using KeraLua; @@ -13,14 +14,106 @@ namespace Garnet.server { - // hack hack hack - internal sealed record ErrorResult(string Message); - /// /// Creates the instance to run Lua scripts /// internal sealed class LuaRunner : IDisposable { + /// + /// Adapter to allow us to write directly to the network + /// when in Garnet and still keep script runner work. + /// + private unsafe interface IResponseAdapter + { + /// + /// Equivalent to the ref curr we pass into methods. + /// + ref byte* BufferCur { get; } + + /// + /// Equivalent to the end we pass into methods. + /// + byte* BufferEnd { get; } + + /// + /// Equivalent to . + /// + void SendAndReset(); + } + + /// + /// Adapter so script results go directly + /// to the network. + /// + private readonly struct RespResponseAdapter : IResponseAdapter + { + private readonly RespServerSession session; + + internal RespResponseAdapter(RespServerSession session) + { + this.session = session; + } + + /// + public unsafe ref byte* BufferCur + => ref session.dcurr; + + /// + public unsafe byte* BufferEnd + => session.dend; + + /// + public void SendAndReset() + => session.SendAndReset(); + } + + /// + /// For the runner, put output into an array. + /// + private unsafe struct RunnerAdapter : IResponseAdapter + { + private byte* cur; + private byte[] pinnedArr; + + internal RunnerAdapter(byte[] initialPinnedArr) + { + pinnedArr = initialPinnedArr; + cur = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetArrayDataReference(initialPinnedArr)); + BufferEnd = cur + initialPinnedArr.Length; + } + +#pragma warning disable CS9084 // Struct member returns 'this' or other instance members by reference + /// + public unsafe ref byte* BufferCur + => ref cur; +#pragma warning restore CS9084 + + /// + public unsafe byte* BufferEnd { readonly get; private set; } + + /// + /// Gets a span that covers the responses as written so far. + /// + public readonly ReadOnlySpan Response + => new(cur, (int)(BufferEnd - cur)); + + /// + public void SendAndReset() + { + var newLen = pinnedArr.Length * 2; + var newPinnedArr = GC.AllocateUninitializedArray(newLen); + var copyLen = pinnedArr.Length - (int)(BufferEnd - cur); + + pinnedArr.AsSpan()[..copyLen].CopyTo(newPinnedArr); + + pinnedArr = newPinnedArr; + cur = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetArrayDataReference(newPinnedArr)); + + BufferEnd = cur + newLen; + cur += copyLen; + } + } + // Rooted to keep function pointer alive readonly LuaFunction garnetCall; @@ -310,7 +403,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) } ForceGrowLuaStack(AdditionalStackSpace); - + var neededStackSpace = argCount + AdditionalStackSpace; if (!NativeMethods.CheckBuffer(state.Handle, 1, out var cmdSpan)) @@ -470,7 +563,7 @@ int LuaError(ReadOnlySpan msg) const int NeededStackSize = 1; ForceGrowLuaStack(NeededStackSize); - + CheckedPushBuffer(NeededStackSize, msg); return state.Error(); } @@ -593,11 +686,11 @@ unsafe int ProcessResponse(byte* ptr, int length) } /// - /// Runs the precompiled Lua function with specified parse state. + /// Runs the precompiled Lua function with the given outer session. /// - /// Meant for use directly from Garnet. + /// Response is written directly into the . /// - public object RunForParseState(int count, SessionParseState parseState) + public void RunForSession(int count, RespServerSession outerSession) { const int NeededStackSize = 3; @@ -607,6 +700,8 @@ public object RunForParseState(int count, SessionParseState parseState) scratchBufferManager.Reset(); + var parseState = outerSession.parseState; + var offset = 1; var nKeys = parseState.GetInt(offset++); count--; @@ -669,13 +764,15 @@ public object RunForParseState(int count, SessionParseState parseState) AssertLuaStackEmpty(); + var adapter = new RespResponseAdapter(outerSession); + if (txnMode && nKeys > 0) { - return RunInTransaction(); + RunInTransaction(ref adapter); } else { - return RunCommon(); + RunCommon(ref adapter); } } @@ -686,8 +783,13 @@ public object RunForParseState(int count, SessionParseState parseState) /// public object RunForRunner(string[] keys = null, string[] argv = null) { + const int InitialSize = 64; + scratchBufferManager?.Reset(); LoadParametersForRunner(keys, argv); + + var adapter = new RunnerAdapter(GC.AllocateUninitializedArray(InitialSize, pinned: true)); + if (txnMode && keys?.Length > 0) { // Add keys to the transaction @@ -698,18 +800,23 @@ public object RunForRunner(string[] keys = null, string[] argv = null) if (!respServerSession.storageSession.objectStoreLockableContext.IsNull) txnKeyEntries.AddKey(_key, true, Tsavorite.core.LockType.Exclusive); } - return RunInTransaction(); + + RunInTransaction(ref adapter); } else { - return RunCommon(); + RunCommon(ref adapter); } + + // TODO: convert response into object + throw new NotImplementedException(); } /// /// Calls after setting up appropriate state for a transaction. /// - object RunInTransaction() + void RunInTransaction(ref TResponse response) + where TResponse : struct, IResponseAdapter { try { @@ -719,7 +826,7 @@ object RunInTransaction() respServerSession.SetTransactionMode(true); txnKeyEntries.LockAllKeys(); - return RunCommon(); + RunCommon(ref response); } finally { @@ -874,7 +981,8 @@ static int PrepareString(string raw, ref byte[] arr, ref Span span) /// /// Runs the precompiled Lua function. /// - object RunCommon() + unsafe void RunCommon(ref TResponse resp) + where TResponse : struct, IResponseAdapter { const int NeededStackSize = 2; @@ -899,24 +1007,41 @@ object RunCommon() if (state.GetTop() == 0) { - return null; + while (!RespWriteUtils.WriteNull(ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + return; } var retType = state.Type(1); if (retType == LuaType.Nil) { - return null; + while (!RespWriteUtils.WriteNull(ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + return; } else if (retType == LuaType.Number) { // Redis unconditionally converts all "number" replies to integer replies so we match that // // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion - return (long)state.CheckNumber(1); + var num = (long)state.CheckNumber(1); + + while (!RespWriteUtils.WriteInteger(num, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + return; } else if (retType == LuaType.String) { - return state.CheckString(1); + var checkRes = NativeMethods.CheckBuffer(state.Handle, 1, out var buf); + Debug.Assert(checkRes, "Should never fail"); + + while (!RespWriteUtils.WriteBulkString(buf, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + return; } else if (retType == LuaType.Boolean) { @@ -925,64 +1050,70 @@ object RunCommon() // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion if (state.ToBoolean(1)) { - return 1L; + while (!RespWriteUtils.WriteInteger(1, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); } else { - return null; + while (!RespWriteUtils.WriteNull(ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); } + + return; } else if (retType == LuaType.Table) { - // TODO: this is hacky, and doesn't support nested arrays or whatever - // but is good enough for now - // when refactored to avoid intermediate objects this should be fixed - - // TODO: because we are dealing with a user provided type, we MUST respect - // metatables - so we can't use any of the RawXXX methods - - // if the key err is in there, we need to short circuit - CheckedPushBuffer(NeededStackSize, "err"u8); - - var errType = state.GetTable(1); - if (errType == LuaType.String) - { - var errStr = state.CheckString(2); - // hack hack hack - // todo: all this goes away when we write results directly - return new ErrorResult(errStr); - } - - state.Pop(1); - - // Otherwise, we need to convert the table to an array - var tableLength = state.Length(1); - - var ret = new object[tableLength]; - for (var i = 1; i <= tableLength; i++) - { - var type = state.GetInteger(1, i); - switch (type) - { - case LuaType.String: - ret[i - 1] = state.CheckString(2); - break; - case LuaType.Number: - ret[i - 1] = (long)state.CheckNumber(2); - break; - case LuaType.Boolean: - ret[i - 1] = state.ToBoolean(2) ? 1L : null; - break; - // Redis stops processesing the array when a nil is encountered - // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion - case LuaType.Nil: - return ret.Take(i - 1).ToArray(); - } - - state.Pop(1); - } + throw new NotImplementedException(); - return ret; + //// TODO: this is hacky, and doesn't support nested arrays or whatever + //// but is good enough for now + //// when refactored to avoid intermediate objects this should be fixed + + //// TODO: because we are dealing with a user provided type, we MUST respect + //// metatables - so we can't use any of the RawXXX methods + + //// if the key err is in there, we need to short circuit + //CheckedPushBuffer(NeededStackSize, "err"u8); + + //var errType = state.GetTable(1); + //if (errType == LuaType.String) + //{ + // var errStr = state.CheckString(2); + // // hack hack hack + // // todo: all this goes away when we write results directly + // return new ErrorResult(errStr); + //} + + //state.Pop(1); + + //// Otherwise, we need to convert the table to an array + //var tableLength = state.Length(1); + + //var ret = new object[tableLength]; + //for (var i = 1; i <= tableLength; i++) + //{ + // var type = state.GetInteger(1, i); + // switch (type) + // { + // case LuaType.String: + // ret[i - 1] = state.CheckString(2); + // break; + // case LuaType.Number: + // ret[i - 1] = (long)state.CheckNumber(2); + // break; + // case LuaType.Boolean: + // ret[i - 1] = state.ToBoolean(2) ? 1L : null; + // break; + // // Redis stops processesing the array when a nil is encountered + // // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion + // case LuaType.Nil: + // return ret.Take(i - 1).ToArray(); + // } + + // state.Pop(1); + //} + + //return ret; } else { diff --git a/libs/server/Resp/RespServerSession.cs b/libs/server/Resp/RespServerSession.cs index a5b00f020c..5c6dbdacae 100644 --- a/libs/server/Resp/RespServerSession.cs +++ b/libs/server/Resp/RespServerSession.cs @@ -86,7 +86,7 @@ internal sealed unsafe partial class RespServerSession : ServerSessionBase /// int endReadHead; - byte* dcurr, dend; + internal byte* dcurr, dend; bool toDispose; int opCount; @@ -970,7 +970,7 @@ private unsafe bool Write(int seqNo, ref byte* dst, int length) } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void SendAndReset() + internal void SendAndReset() { byte* d = networkSender.GetResponseObjectHead(); if ((int)(dcurr - d) > 0) From ff2f7877bc404be2cdd310f4102380c76ae77a4b Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Wed, 11 Dec 2024 14:49:50 -0500 Subject: [PATCH 24/51] first pass of tables directly into network stream --- libs/common/RespReadUtils.cs | 1 - libs/server/Lua/LuaRunner.cs | 272 ++++++++++++++++++++++------------- 2 files changed, 175 insertions(+), 98 deletions(-) diff --git a/libs/common/RespReadUtils.cs b/libs/common/RespReadUtils.cs index b3168a2484..a82ed7ee87 100644 --- a/libs/common/RespReadUtils.cs +++ b/libs/common/RespReadUtils.cs @@ -2,7 +2,6 @@ // Licensed under the MIT license. using System; -using System.Buffers; using System.Buffers.Text; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index e713c69fa1..b88d465f99 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -399,7 +399,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) if (argCount == 0) { - return LuaError("Please specify at least one argument for this redis lib call"u8); + return LuaError("ERR Please specify at least one argument for this redis lib call"u8); } ForceGrowLuaStack(AdditionalStackSpace); @@ -551,7 +551,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) /// int ErrorInvalidArgumentType(int neededCapacity) { - CheckedPushBuffer(neededCapacity, "Lua redis lib command arguments must be strings or integers"u8); + CheckedPushBuffer(neededCapacity, "ERR Lua redis lib command arguments must be strings or integers"u8); return state.Error(); } @@ -609,7 +609,7 @@ unsafe int ProcessResponse(byte* ptr, int length) if (errSpan.SequenceEqual(CmdStrings.RESP_ERR_GENERIC_UNK_CMD)) { // Gets a special response - return LuaError("Unknown Redis command called from script"u8); + return LuaError("ERR Unknown Redis command called from script"u8); } CheckedPushBuffer(NeededStackSize, errSpan); @@ -989,13 +989,12 @@ unsafe void RunCommon(ref TResponse resp) // TODO: mapping is dependent on Resp2 vs Resp3 settings // and that's not implemented at all - // TODO: this shouldn't read the result, it should write the response out AssertLuaStackEmpty(); - ForceGrowLuaStack(NeededStackSize); - try { + ForceGrowLuaStack(NeededStackSize); + CheckedPushNumber(NeededStackSize, functionRegistryIndex); var loadRes = state.GetTable(LuaRegistry.Index); Debug.Assert(loadRes == LuaType.Function, "Unexpected type for function to invoke"); @@ -1007,113 +1006,89 @@ unsafe void RunCommon(ref TResponse resp) if (state.GetTop() == 0) { - while (!RespWriteUtils.WriteNull(ref resp.BufferCur, resp.BufferEnd)) - resp.SendAndReset(); - + WriteNull(state, 0, ref resp); return; } var retType = state.Type(1); if (retType == LuaType.Nil) { - while (!RespWriteUtils.WriteNull(ref resp.BufferCur, resp.BufferEnd)) - resp.SendAndReset(); - + WriteNull(state, 1, ref resp); return; } else if (retType == LuaType.Number) { - // Redis unconditionally converts all "number" replies to integer replies so we match that - // - // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion - var num = (long)state.CheckNumber(1); - - while (!RespWriteUtils.WriteInteger(num, ref resp.BufferCur, resp.BufferEnd)) - resp.SendAndReset(); - + WriteNumber(state, 1, ref resp); return; } else if (retType == LuaType.String) { - var checkRes = NativeMethods.CheckBuffer(state.Handle, 1, out var buf); - Debug.Assert(checkRes, "Should never fail"); - - while (!RespWriteUtils.WriteBulkString(buf, ref resp.BufferCur, resp.BufferEnd)) - resp.SendAndReset(); - + WriteString(state, 1, ref resp); return; } else if (retType == LuaType.Boolean) { - // Redis maps Lua false to null, and Lua true to 1 this is strange, but documented - // - // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion - if (state.ToBoolean(1)) - { - while (!RespWriteUtils.WriteInteger(1, ref resp.BufferCur, resp.BufferEnd)) - resp.SendAndReset(); - } - else - { - while (!RespWriteUtils.WriteNull(ref resp.BufferCur, resp.BufferEnd)) - resp.SendAndReset(); - } - + WriteBoolean(state, 1, ref resp); return; } else if (retType == LuaType.Table) { - throw new NotImplementedException(); + // TODO: this is hacky, and doesn't support nested arrays or whatever + // but is good enough for now + // when refactored to avoid intermediate objects this should be fixed + + // TODO: because we are dealing with a user provided type, we MUST respect + // metatables - so we can't use any of the RawXXX methods + // so we need a test that use metatables (and compare to how Redis does this) + + // If the key err is in there, we need to short circuit + CheckedPushBuffer(NeededStackSize, "err"u8); - //// TODO: this is hacky, and doesn't support nested arrays or whatever - //// but is good enough for now - //// when refactored to avoid intermediate objects this should be fixed - - //// TODO: because we are dealing with a user provided type, we MUST respect - //// metatables - so we can't use any of the RawXXX methods - - //// if the key err is in there, we need to short circuit - //CheckedPushBuffer(NeededStackSize, "err"u8); - - //var errType = state.GetTable(1); - //if (errType == LuaType.String) - //{ - // var errStr = state.CheckString(2); - // // hack hack hack - // // todo: all this goes away when we write results directly - // return new ErrorResult(errStr); - //} - - //state.Pop(1); - - //// Otherwise, we need to convert the table to an array - //var tableLength = state.Length(1); - - //var ret = new object[tableLength]; - //for (var i = 1; i <= tableLength; i++) - //{ - // var type = state.GetInteger(1, i); - // switch (type) - // { - // case LuaType.String: - // ret[i - 1] = state.CheckString(2); - // break; - // case LuaType.Number: - // ret[i - 1] = (long)state.CheckNumber(2); - // break; - // case LuaType.Boolean: - // ret[i - 1] = state.ToBoolean(2) ? 1L : null; - // break; - // // Redis stops processesing the array when a nil is encountered - // // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion - // case LuaType.Nil: - // return ret.Take(i - 1).ToArray(); - // } - - // state.Pop(1); - //} - - //return ret; + var errType = state.GetTable(1); + if (errType == LuaType.String) + { + WriteError(state, 2, ref resp); + + // Remove table from stack + state.Pop(1); + + return; + } + + // Remove whatever we read from the table under the "err" key + state.Pop(1); + + // Map this table to an array + + // Lua # operator - this stops at nils, so we don't need to explicitly handle them + // See: https://www.lua.org/manual/5.3/manual.html#3.4.7 + var tableLength = state.Length(1); + + while (!RespWriteUtils.WriteArrayLength((int)tableLength, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + for (var i = 1; i <= tableLength; i++) + { + // Push item at index i onto the stack + var type = state.GetInteger(1, i); + switch (type) + { + case LuaType.String: + WriteString(state, 2, ref resp); + break; + case LuaType.Number: + WriteNumber(state, 2, ref resp); + break; + case LuaType.Boolean: + WriteBoolean(state, 2, ref resp); + break; + default: + throw new NotImplementedException(); + } + } + + // Remove table from stack + state.Pop(1); } else { @@ -1128,19 +1103,122 @@ unsafe void RunCommon(ref TResponse resp) var stackTop = state.GetTop(); if (stackTop == 0) { - // and we got nothing back - throw new GarnetException("An error occurred while invoking a Lua script"); + while (!RespWriteUtils.WriteError("ERR An error occurred while invoking a Lua script"u8, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + return; } + else if (stackTop == 1) + { + if (NativeMethods.CheckBuffer(state.Handle, 1, out var errBuf)) + { + while (!RespWriteUtils.WriteError(errBuf, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + } - // Todo: we should just write this out, not throw it's not exceptional - var msg = state.CheckString(stackTop); - throw new GarnetException(msg); + state.Pop(1); + + return; + } + else + { + logger?.LogError("Got an unexpected number of values back from a pcall error {stackTop} {callRes}", stackTop, callRes); + + while (!RespWriteUtils.WriteError("ERR Unexpected error response"u8, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + state.Pop(stackTop); + + return; + } } } finally { - // FORCE the stack to be empty now - state.SetTop(0); + AssertLuaStackEmpty(); + } + + // Write a null RESP value, remove the top value on the stack if there is one + static void WriteNull(Lua state, int top, ref TResponse resp) + { + Debug.Assert(state.GetTop() == top, "Lua stack was not expected size"); + + while (!RespWriteUtils.WriteNull(ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + // The stack _could_ be empty if we're writing a null, so check before popping + if (top != 0) + { + state.Pop(1); + } + } + + // Writes the number on the top of the stack, removes it from the stack + static void WriteNumber(Lua state, int top, ref TResponse resp) + { + Debug.Assert(state.GetTop() == top, "Lua stack was not expected size"); + Debug.Assert(state.Type(top) == LuaType.Number, "Number was not on top of stack"); + + // Redis unconditionally converts all "number" replies to integer replies so we match that + // + // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion + var num = (long)state.CheckNumber(top); + + while (!RespWriteUtils.WriteInteger(num, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + state.Pop(1); + } + + // Writes the string on the top of the stack, removes it from the stack + static void WriteString(Lua state, int top, ref TResponse resp) + { + Debug.Assert(state.GetTop() == top, "Lua stack was not expected size"); + + var checkRes = NativeMethods.CheckBuffer(state.Handle, top, out var buf); + Debug.Assert(checkRes, "Should never fail"); + + while (!RespWriteUtils.WriteBulkString(buf, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + state.Pop(1); + } + + // Writes the boolean on the top of the stack, removes it from the stack + static void WriteBoolean(Lua state, int top, ref TResponse resp) + { + Debug.Assert(state.GetTop() == top, "Lua stack was not expected size"); + Debug.Assert(state.Type(top) == LuaType.Boolean, "Boolean was not on top of stack"); + + // Redis maps Lua false to null, and Lua true to 1 this is strange, but documented + // + // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion + if (state.ToBoolean(top)) + { + while (!RespWriteUtils.WriteInteger(1, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + } + else + { + while (!RespWriteUtils.WriteNull(ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + } + + state.Pop(1); + } + + // Writes the string on the top of the stack out as an error, removes the string from the stack + static void WriteError(Lua state, int top, ref TResponse resp) + { + Debug.Assert(state.GetTop() == top, "Lua stack was not expected size"); + + var errRes = NativeMethods.CheckBuffer(state.Handle, top, out var errBuff); + Debug.Assert(errRes, "Should never fail"); + + while (!RespWriteUtils.WriteError(errBuff, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + state.Pop(1); } } From eda5ba4f9691c9dfd38ced51ebb75aab70ed614a Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Wed, 11 Dec 2024 16:28:52 -0500 Subject: [PATCH 25/51] complex data types now written out directly --- libs/server/Lua/LuaRunner.cs | 134 +++++++++++++++++++++-------- test/Garnet.test/LuaScriptTests.cs | 115 +++++++++++++++++++++++++ 2 files changed, 213 insertions(+), 36 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index b88d465f99..87282db017 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -271,6 +271,8 @@ public LuaRunner(string source, bool txnMode = false, RespServerSession respServ /// public void Compile() { + // TODO: remove exceptions from this path + const int NeededStackSpace = 2; Debug.Assert(functionRegistryIndex == -1, "Shouldn't compile multiple times"); @@ -1011,7 +1013,9 @@ unsafe void RunCommon(ref TResponse resp) } var retType = state.Type(1); - if (retType == LuaType.Nil) + var isNullish = retType is LuaType.Nil or LuaType.UserData or LuaType.Function or LuaType.Thread or LuaType.UserData; + + if (isNullish) { WriteNull(state, 1, ref resp); return; @@ -1059,41 +1063,8 @@ unsafe void RunCommon(ref TResponse resp) state.Pop(1); // Map this table to an array - - // Lua # operator - this stops at nils, so we don't need to explicitly handle them - // See: https://www.lua.org/manual/5.3/manual.html#3.4.7 - var tableLength = state.Length(1); - - while (!RespWriteUtils.WriteArrayLength((int)tableLength, ref resp.BufferCur, resp.BufferEnd)) - resp.SendAndReset(); - - for (var i = 1; i <= tableLength; i++) - { - // Push item at index i onto the stack - var type = state.GetInteger(1, i); - switch (type) - { - case LuaType.String: - WriteString(state, 2, ref resp); - break; - case LuaType.Number: - WriteNumber(state, 2, ref resp); - break; - case LuaType.Boolean: - WriteBoolean(state, 2, ref resp); - break; - default: - throw new NotImplementedException(); - } - } - - // Remove table from stack - state.Pop(1); - } - else - { - // TODO: implement - throw new NotImplementedException(); + var maxStackDepth = NeededStackSize; + WriteArray(state, 1, ref resp, ref maxStackDepth); } } else @@ -1220,6 +1191,86 @@ static void WriteError(Lua state, int top, ref TResponse resp) state.Pop(1); } + + static void WriteArray(Lua state, int top, ref TResponse resp, ref int maxStackDepth) + { + // 1 for the table, 1 for the pending value + const int AdditonalNeededStackSize = 2; + + Debug.Assert(state.GetTop() == top, "Lua stack was not expected size"); + Debug.Assert(state.Type(top) == LuaType.Table, "Table was not on top of stack"); + + // Lua # operator - this MAY stop at nils, but isn't guaranteed to + // See: https://www.lua.org/manual/5.3/manual.html#3.4.7 + var maxLen = state.Length(top); + + // Find the TRUE length by scanning for nils + var trueLen = 0; + for (trueLen = 0; trueLen < maxLen; trueLen++) + { + var type = state.GetInteger(top, trueLen + 1); + state.Pop(1); + + if (type == LuaType.Nil) + { + break; + } + } + + while (!RespWriteUtils.WriteArrayLength((int)trueLen, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + var valueStackSlot = top + 1; + + for (var i = 1; i <= trueLen; i++) + { + // Push item at index i onto the stack + var type = state.GetInteger(top, i); + + switch (type) + { + case LuaType.String: + WriteString(state, valueStackSlot, ref resp); + break; + case LuaType.Number: + WriteNumber(state, valueStackSlot, ref resp); + break; + case LuaType.Boolean: + WriteBoolean(state, valueStackSlot, ref resp); + break; + + + case LuaType.Table: + // For tables, we need to recurse - which means we need to check stack sizes again + if (maxStackDepth < valueStackSlot + AdditonalNeededStackSize) + { + try + { + ForceGrowLuaStack(state, AdditonalNeededStackSize); + maxStackDepth += AdditonalNeededStackSize; + } + catch + { + // This is the only place we can raise an exception, cull the Stack + state.SetTop(0); + + throw; + } + } + + WriteArray(state, valueStackSlot, ref resp, ref maxStackDepth); + + break; + + // All other Lua types map to nulls + default: + WriteNull(state, valueStackSlot, ref resp); + break; + } + } + + state.Pop(1); + } } /// @@ -1231,6 +1282,17 @@ static void WriteError(Lua state, int top, ref TResponse resp) /// [MethodImpl(MethodImplOptions.AggressiveInlining)] private void ForceGrowLuaStack(int additionalCapacity) + => ForceGrowLuaStack(state, additionalCapacity); + + /// + /// Ensure there's enough space on the Lua stack for more items. + /// + /// Throws if there is not. + /// + /// Prefer using this to calling directly. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void ForceGrowLuaStack(Lua state, int additionalCapacity) { if (!state.CheckStack(additionalCapacity)) { diff --git a/test/Garnet.test/LuaScriptTests.cs b/test/Garnet.test/LuaScriptTests.cs index 695f556b78..06d4d645a8 100644 --- a/test/Garnet.test/LuaScriptTests.cs +++ b/test/Garnet.test/LuaScriptTests.cs @@ -621,5 +621,120 @@ public void NumberArgumentCoercion() var res = (string)db.ScriptEvaluate("return redis.call('GET', 2.1)"); ClassicAssert.AreEqual("world", res); } + + [Test] + public void ComplexLuaReturns() + { + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + // Relatively complicated + { + var res1 = (RedisResult[])db.ScriptEvaluate("return { 1, 'hello', { true, false }, { fizz = 'buzz', hello = 'world' }, 5, 6 }"); + ClassicAssert.AreEqual(6, res1.Length); + ClassicAssert.AreEqual(1, (long)res1[0]); + ClassicAssert.AreEqual("hello", (string)res1[1]); + var res1Sub1 = (RedisResult[])res1[2]; + ClassicAssert.AreEqual(2, res1Sub1.Length); + ClassicAssert.AreEqual(1, (long)res1Sub1[0]); + ClassicAssert.IsTrue(res1Sub1[1].IsNull); + var res1Sub2 = (RedisResult[])res1[2]; + ClassicAssert.AreEqual(2, res1Sub2.Length); + ClassicAssert.IsTrue((bool)res1Sub2[0]); + ClassicAssert.IsTrue(res1Sub2[1].IsNull); + var res1Sub3 = (RedisResult[])res1[3]; + ClassicAssert.AreEqual(0, res1Sub3.Length); + ClassicAssert.AreEqual(5, (long)res1[4]); + ClassicAssert.AreEqual(6, (long)res1[5]); + } + + // Only indexable will be included + { + var res2 = (RedisResult[])db.ScriptEvaluate("return { 1, 2, fizz='buzz' }"); + ClassicAssert.AreEqual(2, res2.Length); + ClassicAssert.AreEqual(1, (long)res2[0]); + ClassicAssert.AreEqual(2, (long)res2[1]); + } + + // Non-string, non-number, are nullish + { + var res3 = (RedisResult[])db.ScriptEvaluate("return { 1, function() end, 3 }"); + ClassicAssert.AreEqual(3, res3.Length); + ClassicAssert.AreEqual(1, (long)res3[0]); + ClassicAssert.IsTrue(res3[1].IsNull); + ClassicAssert.AreEqual(3, (long)res3[2]); + } + + // Nil stops return of subsequent values + { + var res4 = (RedisResult[])db.ScriptEvaluate("return { 1, nil, 3 }"); + ClassicAssert.AreEqual(1, res4.Length); + ClassicAssert.AreEqual(1, (long)res4[0]); + } + + // Incredibly deeply nested return + { + const int Depth = 100; + + var tableDepth = new StringBuilder(); + for (var i = 1; i <= Depth; i++) + { + if (i != 1) + { + tableDepth.Append(", "); + } + tableDepth.Append("{ "); + tableDepth.Append(i); + } + for (var i = 1; i <= Depth; i++) + { + tableDepth.Append(" }"); + } + + var script = "return " + tableDepth.ToString(); + + var res5 = db.ScriptEvaluate(script); + + var cur = res5; + for (var i = 1; i < Depth; i++) + { + var top = (RedisResult[])cur; + ClassicAssert.AreEqual(2, top.Length); + ClassicAssert.AreEqual(i, (long)top[0]); + + cur = top[1]; + } + + // Remainder should have a single element + var remainder = (RedisResult[])cur; + ClassicAssert.AreEqual(1, remainder.Length); + ClassicAssert.AreEqual(Depth, (long)remainder[0]); + } + + // Incredibly wide + { + const int Width = 100_000; + + var tableDepth = new StringBuilder(); + for (var i = 1; i <= Width; i++) + { + if (i != 1) + { + tableDepth.Append(", "); + } + tableDepth.Append("{ " + i + " }"); + } + + var script = "return { " + tableDepth.ToString() + " }"; + + var res5 = (RedisResult[])db.ScriptEvaluate(script); + for (var i = 0; i < Width; i++) + { + var elem = (RedisResult[])res5[i]; + ClassicAssert.AreEqual(1, elem.Length); + ClassicAssert.AreEqual(i + 1, (long)elem[0]); + } + } + } } } \ No newline at end of file From 46c2c9f314d7e65cd95b41079f2ecb09b37f4547 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Wed, 11 Dec 2024 17:20:08 -0500 Subject: [PATCH 26/51] Runner (ie. mapping Resp <-> .NET) restored to functionality; bits of cleanup --- benchmark/BDN.benchmark/Lua/LuaScripts.cs | 8 +- libs/server/Lua/LuaCommands.cs | 145 ++------------------- libs/server/Lua/LuaRunner.cs | 147 +++++++++++++++++++--- libs/server/Lua/SessionScriptCache.cs | 8 +- test/Garnet.test/LuaScriptRunnerTests.cs | 20 +-- test/Garnet.test/LuaScriptTests.cs | 12 +- 6 files changed, 167 insertions(+), 173 deletions(-) diff --git a/benchmark/BDN.benchmark/Lua/LuaScripts.cs b/benchmark/BDN.benchmark/Lua/LuaScripts.cs index ae5bea1e9d..e04510a4cb 100644 --- a/benchmark/BDN.benchmark/Lua/LuaScripts.cs +++ b/benchmark/BDN.benchmark/Lua/LuaScripts.cs @@ -35,13 +35,13 @@ public IEnumerable LuaParamsProvider() public void GlobalSetup() { r1 = new LuaRunner("return"); - r1.Compile(); + r1.CompileForRunner(); r2 = new LuaRunner("return 1 + 1"); - r2.Compile(); + r2.CompileForRunner(); r3 = new LuaRunner("return KEYS[1]"); - r3.Compile(); + r3.CompileForRunner(); r4 = new LuaRunner("return redis.call(KEYS[1])"); - r4.Compile(); + r4.CompileForRunner(); } [GlobalCleanup] diff --git a/libs/server/Lua/LuaCommands.cs b/libs/server/Lua/LuaCommands.cs index ed08513911..370782c1e7 100644 --- a/libs/server/Lua/LuaCommands.cs +++ b/libs/server/Lua/LuaCommands.cs @@ -32,18 +32,16 @@ private unsafe bool TryEVALSHA() var digestAsSpanByteMem = new SpanByteAndMemory(digest.SpanByte); - var result = false; if (!sessionScriptCache.TryGetFromDigest(digestAsSpanByteMem, out var runner)) { if (storeWrapper.storeScriptCache.TryGetValue(digestAsSpanByteMem, out var source)) { - if (!sessionScriptCache.TryLoad(source, digestAsSpanByteMem, out runner, out var error)) + if (!sessionScriptCache.TryLoad(this, source, digestAsSpanByteMem, out runner, out var error)) { - while (!RespWriteUtils.WriteError(error, ref dcurr, dend)) - SendAndReset(); + // TryLoad will have written an error out, it any _ = storeWrapper.storeScriptCache.TryRemove(digestAsSpanByteMem, out _); - return result; + return true; } } } @@ -55,10 +53,10 @@ private unsafe bool TryEVALSHA() } else { - result = ExecuteScript(count - 1, runner); + ExecuteScript(count - 1, runner); } - return result; + return true; } @@ -85,13 +83,10 @@ private unsafe bool TryEVAL() Span digest = stackalloc byte[SessionScriptCache.SHA1Len]; sessionScriptCache.GetScriptDigest(script, digest); - var result = false; - if (!sessionScriptCache.TryLoad(script, new SpanByteAndMemory(SpanByte.FromPinnedSpan(digest)), out var runner, out var error)) + if (!sessionScriptCache.TryLoad(this, script, new SpanByteAndMemory(SpanByte.FromPinnedSpan(digest)), out var runner, out var error)) { - while (!RespWriteUtils.WriteError(error, ref dcurr, dend)) - SendAndReset(); - - return result; + // TryLoad will have written any errors out + return true; } if (runner == null) @@ -101,10 +96,10 @@ private unsafe bool TryEVAL() } else { - result = ExecuteScript(count - 1, runner); + ExecuteScript(count - 1, runner); } - return result; + return true; } /// @@ -200,10 +195,9 @@ private bool NetworkScriptLoad() } var source = parseState.GetArgSliceByRef(0).ToArray(); - if (!sessionScriptCache.TryLoad(source, out var digest, out _, out var error)) + if (!sessionScriptCache.TryLoad(this, source, out var digest, out _, out var error)) { - while (!RespWriteUtils.WriteError(error, ref dcurr, dend)) - SendAndReset(); + // TryLoad will write any errors out } else { @@ -240,7 +234,7 @@ private bool CheckLuaEnabled() /// /// Invoke the execution of a server-side Lua script. /// - private unsafe bool ExecuteScript(int count, LuaRunner scriptRunner) + private void ExecuteScript(int count, LuaRunner scriptRunner) { try { @@ -251,120 +245,7 @@ private unsafe bool ExecuteScript(int count, LuaRunner scriptRunner) logger?.LogError(ex, "Error executing Lua script"); while (!RespWriteUtils.WriteError("ERR " + ex.Message, ref dcurr, dend)) SendAndReset(); - return true; } - return true; } - - //void WriteObject(object scriptResult) - //{ - // if (scriptResult != null) - // { - // if (scriptResult is string s) - // { - // while (!RespWriteUtils.WriteAsciiBulkString(s, ref dcurr, dend)) - // SendAndReset(); - // } - // else if ((scriptResult as byte?) != null && (byte)scriptResult == 36) //equals to $ - // { - // while (!RespWriteUtils.WriteDirect((byte[])scriptResult, ref dcurr, dend)) - // SendAndReset(); - // } - // else if (scriptResult is bool b) - // { - // if (b) - // { - // while (!RespWriteUtils.WriteInteger(1, ref dcurr, dend)) - // SendAndReset(); - // } - // else - // { - // while (!RespWriteUtils.WriteDirect(CmdStrings.RESP_ERRNOTFOUND, ref dcurr, dend)) - // SendAndReset(); - // } - // } - // else if (scriptResult is long l) - // { - // while (!RespWriteUtils.WriteInteger(l, ref dcurr, dend)) - // SendAndReset(); - // } - // else if (scriptResult is object[] o) - // { - // var count = o.Length; - // while (!RespWriteUtils.WriteArrayLength(count, ref dcurr, dend)) - // SendAndReset(); - - // foreach (var value in o) - // { - // WriteObject(value); - // } - // } - // else if (scriptResult is ErrorResult e) - // { - // while (!RespWriteUtils.WriteError(e.Message, ref dcurr, dend)) - // SendAndReset(); - // } - // else - // { - // // todo: this should all go away - // throw new NotImplementedException(); - // } - // //else if (scriptResult is ArgSlice a) - // //{ - // // while (!RespWriteUtils.WriteBulkString(a.ReadOnlySpan, ref dcurr, dend)) - // // SendAndReset(); - // //} - // //else if (scriptResult is object[] o) - // //{ - // // // Two objects one boolean value and the result from the Lua Call - // // while (!RespWriteUtils.WriteAsciiBulkString(o[1].ToString().AsSpan(), ref dcurr, dend)) - // // SendAndReset(); - // //} - // //else if (scriptResult is LuaTable luaTable) - // //{ - // // try - // // { - // // var retVal = luaTable["err"]; - // // if (retVal != null) - // // { - // // while (!RespWriteUtils.WriteError((string)retVal, ref dcurr, dend)) - // // SendAndReset(); - // // } - // // else - // // { - // // retVal = luaTable["ok"]; - // // if (retVal != null) - // // { - // // while (!RespWriteUtils.WriteAsciiBulkString((string)retVal, ref dcurr, dend)) - // // SendAndReset(); - // // } - // // else - // // { - // // int count = luaTable.Values.Count; - // // while (!RespWriteUtils.WriteArrayLength(count, ref dcurr, dend)) - // // SendAndReset(); - // // foreach (var value in luaTable.Values) - // // { - // // WriteObject(value); - // // } - // // } - // // } - // // } - // // finally - // // { - // // luaTable.Dispose(); - // // } - // //} - // //else - // //{ - // // throw new LuaScriptException("Unknown return type", ""); - // //} - // } - // else - // { - // while (!RespWriteUtils.WriteDirect(CmdStrings.RESP_ERRNOTFOUND, ref dcurr, dend)) - // SendAndReset(); - // } - //} } } \ No newline at end of file diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 87282db017..e285fc6e93 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -95,13 +95,23 @@ public unsafe ref byte* BufferCur /// Gets a span that covers the responses as written so far. /// public readonly ReadOnlySpan Response - => new(cur, (int)(BufferEnd - cur)); + + { + get + { + var origin = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetArrayDataReference(pinnedArr)); + var length = (int)(cur - origin); + + return new(origin, length); + } + } /// public void SendAndReset() { + // We don't actually send anywhere, we grow the backing array var newLen = pinnedArr.Length * 2; - var newPinnedArr = GC.AllocateUninitializedArray(newLen); + var newPinnedArr = GC.AllocateUninitializedArray(newLen, pinned: true); var copyLen = pinnedArr.Length - (int)(BufferEnd - cur); pinnedArr.AsSpan()[..copyLen].CopyTo(newPinnedArr); @@ -267,12 +277,42 @@ public LuaRunner(string source, bool txnMode = false, RespServerSession respServ } /// - /// Compile script + /// Compile script for running in a .NET host. + /// + /// Errors are raised as exceptions. /// - public void Compile() + public unsafe void CompileForRunner() { - // TODO: remove exceptions from this path + var adapter = new RunnerAdapter(GC.AllocateUninitializedArray(64, pinned: true)); + CompileCommon(ref adapter); + + var resp = adapter.Response; + var respStart = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(resp)); + var respEnd = respStart + resp.Length; + if (RespReadUtils.TryReadErrorAsSpan(out var errSpan, ref respStart, respEnd)) + { + var errStr = Encoding.UTF8.GetString(errSpan); + throw new GarnetException(errStr); + } + } + /// + /// Compile script for a . + /// + /// Any errors encountered are written out as Resp errors. + /// + public void CompileForSession(RespServerSession session) + { + var adapter = new RespResponseAdapter(session); + CompileCommon(ref adapter); + } + + /// + /// Compile script, writing errors out to given response. + /// + unsafe void CompileCommon(ref TResponse resp) + where TResponse : struct, IResponseAdapter + { const int NeededStackSpace = 2; Debug.Assert(functionRegistryIndex == -1, "Shouldn't compile multiple times"); @@ -293,14 +333,21 @@ public void Compile() var numRets = state.GetTop(); if (numRets == 0) { - throw new GarnetException("Shouldn't happen, no returns from load_sandboxed"); + while (!RespWriteUtils.WriteError("Shouldn't happen, no returns from load_sandboxed"u8, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + return; } else if (numRets == 1) { var returnType = state.Type(1); if (returnType != LuaType.Function) { - throw new GarnetException($"Could not compile function, got back a {returnType}"); + var errStr = $"Could not compile function, got back a {returnType}"; + while (!RespWriteUtils.WriteError(errStr, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + return; } functionRegistryIndex = state.Ref(LuaRegistry.Index); @@ -309,11 +356,18 @@ public void Compile() { var error = state.CheckString(2); - throw new GarnetException($"Compilation error: {error}"); + var errStr = $"Compilation error: {error}"; + while (!RespWriteUtils.WriteError(errStr, ref resp.BufferCur, resp.BufferEnd)) + resp.SendAndReset(); + + state.Pop(2); + return; } else { + state.Pop(numRets); + throw new GarnetException($"Unexpected error compiling, got too many replies back: reply count = {numRets}"); } } @@ -324,8 +378,7 @@ public void Compile() } finally { - // Force stack empty after compilation, no matter what happens - state.SetTop(0); + AssertLuaStackEmpty(); } } @@ -783,7 +836,7 @@ public void RunForSession(int count, RespServerSession outerSession) /// /// Meant for use from a .NET host rather than in Garnet properly. /// - public object RunForRunner(string[] keys = null, string[] argv = null) + public unsafe object RunForRunner(string[] keys = null, string[] argv = null) { const int InitialSize = 64; @@ -810,8 +863,72 @@ public object RunForRunner(string[] keys = null, string[] argv = null) RunCommon(ref adapter); } - // TODO: convert response into object - throw new NotImplementedException(); + var resp = adapter.Response; + var respCur = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(resp)); + var respEnd = respCur + resp.Length; + + if (RespReadUtils.TryReadErrorAsSpan(out var errSpan, ref respCur, respEnd)) + { + var errStr = Encoding.UTF8.GetString(errSpan); + throw new GarnetException(errStr); + } + + var ret = MapRespToObject(ref respCur, respEnd); + Debug.Assert(respCur == respEnd, "Should have fully consumed response"); + + return ret; + + static object MapRespToObject(ref byte* cur, byte* end) + { + switch (*cur) + { + case (byte)'+': + var simpleStrRes = RespReadUtils.ReadSimpleString(out var simpleStr, ref cur, end); + Debug.Assert(simpleStrRes, "Should never fail"); + + return simpleStr; + + case (byte)':': + var readIntRes = RespReadUtils.Read64Int(out var int64, ref cur, end); + Debug.Assert(readIntRes, "Should never fail"); + + return int64; + + // Error ('-') is handled before call to MapRespToObject + + case (byte)'$': + var length = end - cur; + + if (length >= 5 && new ReadOnlySpan(cur + 1, 4).SequenceEqual("-1\r\n"u8)) + { + return null; + } + + var bulkStrRes = RespReadUtils.ReadStringResponseWithLengthHeader(out var bulkStr, ref cur, end); + Debug.Assert(bulkStrRes, "Should never fail"); + + return bulkStr; + + case (byte)'*': + var arrayLengthRes = RespReadUtils.ReadUnsignedArrayLength(out var itemCount, ref cur, end); + Debug.Assert(arrayLengthRes, "Should never fail"); + + if (itemCount == 0) + { + return Array.Empty(); + } + + var array = new object[itemCount]; + for (var i = 0; i < array.Length; i++) + { + array[i] = MapRespToObject(ref cur, end); + } + + return array; + + default: throw new NotImplementedException($"Unexpected sigil {(char)*cur}"); + } + } } /// @@ -1037,10 +1154,6 @@ unsafe void RunCommon(ref TResponse resp) } else if (retType == LuaType.Table) { - // TODO: this is hacky, and doesn't support nested arrays or whatever - // but is good enough for now - // when refactored to avoid intermediate objects this should be fixed - // TODO: because we are dealing with a user provided type, we MUST respect // metatables - so we can't use any of the RawXXX methods // so we need a test that use metatables (and compare to how Redis does this) diff --git a/libs/server/Lua/SessionScriptCache.cs b/libs/server/Lua/SessionScriptCache.cs index eddc0d9ffc..616433e689 100644 --- a/libs/server/Lua/SessionScriptCache.cs +++ b/libs/server/Lua/SessionScriptCache.cs @@ -58,15 +58,15 @@ public bool TryGetFromDigest(SpanByteAndMemory digest, out LuaRunner scriptRunne /// /// Load script into the cache /// - public bool TryLoad(byte[] source, out byte[] digest, out LuaRunner runner, out string error) + public bool TryLoad(RespServerSession session, byte[] source, out byte[] digest, out LuaRunner runner, out string error) { digest = new byte[SHA1Len]; GetScriptDigest(source, digest); - return TryLoad(source, new SpanByteAndMemory(new ScriptHashOwner(digest), digest.Length), out runner, out error); + return TryLoad(session, source, new SpanByteAndMemory(new ScriptHashOwner(digest), digest.Length), out runner, out error); } - internal bool TryLoad(byte[] source, SpanByteAndMemory digest, out LuaRunner runner, out string error) + internal bool TryLoad(RespServerSession session, byte[] source, SpanByteAndMemory digest, out LuaRunner runner, out string error) { error = null; @@ -76,7 +76,7 @@ internal bool TryLoad(byte[] source, SpanByteAndMemory digest, out LuaRunner run try { runner = new LuaRunner(source, storeWrapper.serverOptions.LuaTransactionMode, processor, scratchBufferNetworkSender, logger); - runner.Compile(); + runner.CompileForSession(session); // need to make sure the key is on the heap, so move it over if needed var storeKeyDigest = digest; diff --git a/test/Garnet.test/LuaScriptRunnerTests.cs b/test/Garnet.test/LuaScriptRunnerTests.cs index eb09a8cf9e..a21e40e689 100644 --- a/test/Garnet.test/LuaScriptRunnerTests.cs +++ b/test/Garnet.test/LuaScriptRunnerTests.cs @@ -17,7 +17,7 @@ public void CannotRunUnsafeScript() // Try to load an assembly using (var runner = new LuaRunner("luanet.load_assembly('mscorlib')")) { - runner.Compile(); + runner.CompileForRunner(); var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"luanet.load_assembly('mscorlib')\"]:1: attempt to index a nil value (global 'luanet')", ex.Message); } @@ -25,7 +25,7 @@ public void CannotRunUnsafeScript() // Try to call a OS function using (var runner = new LuaRunner("os = require('os'); return os.time();")) { - runner.Compile(); + runner.CompileForRunner(); var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"os = require('os'); return os.time();\"]:1: attempt to call a nil value (global 'require')", ex.Message); } @@ -33,7 +33,7 @@ public void CannotRunUnsafeScript() // Try to execute the input stream using (var runner = new LuaRunner("dofile();")) { - runner.Compile(); + runner.CompileForRunner(); var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"dofile();\"]:1: attempt to call a nil value (global 'dofile')", ex.Message); } @@ -41,7 +41,7 @@ public void CannotRunUnsafeScript() // Try to call a windows executable using (var runner = new LuaRunner("require \"notepad\"")) { - runner.Compile(); + runner.CompileForRunner(); var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"require \"notepad\"\"]:1: attempt to call a nil value (global 'require')", ex.Message); } @@ -49,7 +49,7 @@ public void CannotRunUnsafeScript() // Try to call an OS function using (var runner = new LuaRunner("os.exit();")) { - runner.Compile(); + runner.CompileForRunner(); var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"os.exit();\"]:1: attempt to index a nil value (global 'os')", ex.Message); } @@ -57,7 +57,7 @@ public void CannotRunUnsafeScript() // Try to include a new .net library using (var runner = new LuaRunner("import ('System.Diagnostics');")) { - runner.Compile(); + runner.CompileForRunner(); var ex = Assert.Throws(() => runner.RunForRunner()); ClassicAssert.AreEqual("[string \"import ('System.Diagnostics');\"]:1: attempt to call a nil value (global 'import')", ex.Message); } @@ -69,14 +69,14 @@ public void CanLoadScript() // Code with error using (var runner = new LuaRunner("local;")) { - var ex = Assert.Throws(runner.Compile); + var ex = Assert.Throws(runner.CompileForRunner); ClassicAssert.AreEqual("Compilation error: [string \"local;\"]:1: expected near ';'", ex.Message); } // Code without error using (var runner = new LuaRunner("local list; list = 1; return list;")) { - runner.Compile(); + runner.CompileForRunner(); } } @@ -89,7 +89,7 @@ public void CanRunScript() // Run code without errors using (var runner = new LuaRunner("local list; list = ARGV[1] ; return list;")) { - runner.Compile(); + runner.CompileForRunner(); var res = runner.RunForRunner(keys, args); ClassicAssert.AreEqual("arg1", res); } @@ -97,7 +97,7 @@ public void CanRunScript() // Run code with errors using (var runner = new LuaRunner("local list; list = ; return list;")) { - var ex = Assert.Throws(runner.Compile); + var ex = Assert.Throws(runner.CompileForRunner); ClassicAssert.AreEqual("Compilation error: [string \"local list; list = ; return list;\"]:1: unexpected symbol near ';'", ex.Message); } } diff --git a/test/Garnet.test/LuaScriptTests.cs b/test/Garnet.test/LuaScriptTests.cs index 06d4d645a8..cf40a09ae3 100644 --- a/test/Garnet.test/LuaScriptTests.cs +++ b/test/Garnet.test/LuaScriptTests.cs @@ -537,25 +537,25 @@ public void ScriptExistsMultiple() [Test] public void RedisCallErrors() { - // testing that our error replies for redis.call match Redis behavior + // Testing that our error replies for redis.call match Redis behavior // - // todo: exact matching of the hash and line number would also be nice, but that is trickier + // TODO: exact matching of the hash and line number would also be nice, but that is trickier using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); var db = redis.GetDatabase(); - // no args + // No args { var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call()")); ClassicAssert.IsTrue(exc.Message.StartsWith("ERR Please specify at least one argument for this redis lib call")); } - // unknown command + // Unknown command { var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call('123')")); ClassicAssert.IsTrue(exc.Message.StartsWith("ERR Unknown Redis command called from script")); } - // bad command type + // Bad command type { var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call({ foo = 'bar'})")); ClassicAssert.IsTrue(exc.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); @@ -576,7 +576,7 @@ public void RedisCallErrors() ClassicAssert.IsTrue(exc2.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); } - // other bad arg types + // Other bad arg types { var exc = ClassicAssert.Throws(() => db.ScriptEvaluate("return redis.call('DEL', { foo = 'bar' })")); ClassicAssert.IsTrue(exc.Message.StartsWith("ERR Lua redis lib command arguments must be strings or integers")); From d2f7759e6029121154396f1e2c562f690dd86596 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Wed, 11 Dec 2024 18:01:38 -0500 Subject: [PATCH 27/51] knock out some todo; pulling Lua constants in CmdStrings like everything else; avoid copying regular used strings into Lua each time they're needed --- libs/server/Lua/LuaRunner.cs | 124 ++++++++++++++++++++++----------- libs/server/Resp/CmdStrings.cs | 8 +++ 2 files changed, 93 insertions(+), 39 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index e285fc6e93..7551b5e92f 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -128,11 +128,21 @@ public void SendAndReset() readonly LuaFunction garnetCall; // References into Registry on the Lua side - // TODO: essentially all constant strings should be pulled out of registry too to avoid copying cost + // + // These are mix of objects we regularly update, + // constants we want to avoid copying from .NET to Lua, + // and the compiled function definition. readonly int sandboxEnvRegistryIndex; readonly int keysTableRegistryIndex; readonly int argvTableRegistryIndex; readonly int loadSandboxedRegistryIndex; + readonly int okConstStringRegisteryIndex; + readonly int errConstStringRegistryIndex; + readonly int noSessionAvailableConstStringRegisteryIndex; + readonly int pleaseSpecifyRedisCallConstStringRegistryIndex; + readonly int errNoAuthConstStringRegistryIndex; + readonly int errUnknownConstStringRegistryIndex; + readonly int errBadArgConstStringRegistryIndex; int functionRegistryIndex; readonly ReadOnlyMemory source; @@ -153,6 +163,8 @@ public void SendAndReset() /// public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null) { + const int NeededStackSize = 1; + this.source = source; this.txnMode = txnMode; this.respServerSession = respServerSession; @@ -170,6 +182,8 @@ public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSe state = new Lua(); AssertLuaStackEmpty(); + ForceGrowLuaStack(NeededStackSize); + if (txnMode) { txnKeyEntries = new TxnKeyEntries(16, respServerSession.storageSession.lockableContext, respServerSession.storageSession.objectStoreLockableContext); @@ -265,6 +279,15 @@ function load_sandboxed(source) Debug.Assert(loadSandboxedType == LuaType.Function, "Unexpected load_sandboxed type"); loadSandboxedRegistryIndex = state.Ref(LuaRegistry.Index); + // Commonly used strings, register them once so we don't have to copy them over each time we need them + okConstStringRegisteryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.LUA_OK); + errConstStringRegistryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.LUA_err); + noSessionAvailableConstStringRegisteryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.LUA_No_session_available); + pleaseSpecifyRedisCallConstStringRegistryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.LUA_ERR_Please_specify_at_least_one_argument_for_this_redis_lib_call); + errNoAuthConstStringRegistryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.RESP_ERR_NOAUTH); + errUnknownConstStringRegistryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.LUA_ERR_Unknown_Redis_command_called_from_script); + errBadArgConstStringRegistryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.LUA_ERR_Lua_redis_lib_command_arguments_must_be_strings_or_integers); + AssertLuaStackEmpty(); } @@ -276,6 +299,19 @@ public LuaRunner(string source, bool txnMode = false, RespServerSession respServ { } + /// + /// Some strings we use a bunch, and copying them to Lua each time is wasteful + /// + /// So instead we stash them in the Registry and load them by index + /// + int ConstantStringToRegistery(int top, ReadOnlySpan str) + { + AssertLuaStackEmpty(); + + CheckedPushBuffer(top, str); + return state.Ref(LuaRegistry.Index); + } + /// /// Compile script for running in a .NET host. /// @@ -434,7 +470,7 @@ int NoSessionError() ForceGrowLuaStack(NeededStackSpace); - CheckedPushBuffer(NeededStackSpace, "No session available"u8); + CheckedPushConstantString(NeededStackSpace, noSessionAvailableConstStringRegisteryIndex); // this will never return, but we can pretend it does return state.Error(); @@ -454,16 +490,15 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) if (argCount == 0) { - return LuaError("ERR Please specify at least one argument for this redis lib call"u8); + return LuaStaticError(argCount, pleaseSpecifyRedisCallConstStringRegistryIndex); } ForceGrowLuaStack(AdditionalStackSpace); - var neededStackSpace = argCount + AdditionalStackSpace; if (!NativeMethods.CheckBuffer(state.Handle, 1, out var cmdSpan)) { - return ErrorInvalidArgumentType(neededStackSpace); + return LuaStaticError(neededStackSpace, errBadArgConstStringRegistryIndex); } // We special-case a few performance-sensitive operations to directly invoke via the storage API @@ -471,12 +506,12 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) { if (!respServerSession.CheckACLPermissions(RespCommand.SET)) { - return LuaError(CmdStrings.RESP_ERR_NOAUTH); + return LuaStaticError(neededStackSpace, errNoAuthConstStringRegistryIndex); } if (!NativeMethods.CheckBuffer(state.Handle, 2, out var keySpan) || !NativeMethods.CheckBuffer(state.Handle, 3, out var valSpan)) { - return ErrorInvalidArgumentType(neededStackSpace); + return LuaStaticError(neededStackSpace, errBadArgConstStringRegistryIndex); } // Note these spans are implicitly pinned, as they're actually on the Lua stack @@ -485,19 +520,19 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) _ = api.SET(key, value); - CheckedPushBuffer(neededStackSpace, "OK"u8); + CheckedPushConstantString(neededStackSpace, okConstStringRegisteryIndex); return 1; } else if (AsciiUtils.EqualsUpperCaseSpanIgnoringCase(cmdSpan, "GET"u8) && argCount == 2) { if (!respServerSession.CheckACLPermissions(RespCommand.GET)) { - return LuaError(CmdStrings.RESP_ERR_NOAUTH); + return LuaStaticError(neededStackSpace, errNoAuthConstStringRegistryIndex); } if (!NativeMethods.CheckBuffer(state.Handle, 2, out var keySpan)) { - return ErrorInvalidArgumentType(neededStackSpace); + return LuaStaticError(neededStackSpace, errBadArgConstStringRegistryIndex); } // Span is (implicitly) pinned since it's actually on the Lua stack @@ -550,7 +585,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) } else { - return ErrorInvalidArgumentType(neededStackSpace); + return LuaStaticError(neededStackSpace, errBadArgConstStringRegistryIndex); } } @@ -575,7 +610,6 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) ArrayPool.Shared.Return(cmdArgsArr); } } - } catch (Exception e) { @@ -584,42 +618,25 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) // Clear the stack state.SetTop(0); - // Try real hard to raise an error in Lua, but we may just be SOL - // - // We don't use ForceGrowLuaStack here because we're in an exception handler - if (state.CheckStack(AdditionalStackSpace)) - { - // TODO: Remove alloc - var b = Encoding.UTF8.GetBytes(e.Message); - CheckedPushBuffer(AdditionalStackSpace, b); - return state.Error(); - } + ForceGrowLuaStack(1); - throw; + // TODO: Remove alloc + var b = Encoding.UTF8.GetBytes(e.Message); + CheckedPushBuffer(AdditionalStackSpace, b); + return state.Error(); } } - - /// - /// Common failure mode is passing wrong arg, so DRY it up. + /// Cause a Lua error to be raised with a message previously registered. /// - int ErrorInvalidArgumentType(int neededCapacity) - { - CheckedPushBuffer(neededCapacity, "ERR Lua redis lib command arguments must be strings or integers"u8); - return state.Error(); - } - - /// - /// Cause a lua error to be raised with the given message. - /// - int LuaError(ReadOnlySpan msg) + int LuaStaticError(int top, int constStringRegistryIndex) { const int NeededStackSize = 1; ForceGrowLuaStack(NeededStackSize); - CheckedPushBuffer(NeededStackSize, msg); + CheckedPushConstantString(top + NeededStackSize, constStringRegistryIndex); return state.Error(); } @@ -664,7 +681,7 @@ unsafe int ProcessResponse(byte* ptr, int length) if (errSpan.SequenceEqual(CmdStrings.RESP_ERR_GENERIC_UNK_CMD)) { // Gets a special response - return LuaError("ERR Unknown Redis command called from script"u8); + return LuaStaticError(NeededStackSize, errUnknownConstStringRegistryIndex); } CheckedPushBuffer(NeededStackSize, errSpan); @@ -1079,6 +1096,7 @@ void LoadParametersForRunner(string[] keys, string[] argv) AssertLuaStackEmpty(); + // TODO: replace with scratchBufferManager static int PrepareString(string raw, ref byte[] arr, ref Span span) { var maxLen = Encoding.UTF8.GetMaxByteCount(raw.Length); @@ -1159,7 +1177,7 @@ unsafe void RunCommon(ref TResponse resp) // so we need a test that use metatables (and compare to how Redis does this) // If the key err is in there, we need to short circuit - CheckedPushBuffer(NeededStackSize, "err"u8); + CheckedPushConstantString(NeededStackSize, errConstStringRegistryIndex); var errType = state.GetTable(1); if (errType == LuaType.String) @@ -1442,6 +1460,8 @@ private void AssertLuaStackBelow(int reservedCapacity, [CallerFilePath] string f /// /// This should be used for all PushBuffer calls into Lua. + /// + /// If the string is a constant, consider registering it in the constructor and using instead. /// /// [MethodImpl(MethodImplOptions.AggressiveInlining)] @@ -1487,5 +1507,31 @@ private void CheckedPushBoolean(int reservedCapacity, bool b, [CallerFilePath] s state.PushBoolean(b); } + + /// + /// This should be used to push all known constants strings (registered in constructor with ) + /// into Lua. + /// + /// This avoids extra copying of data between .NET and Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void CheckedPushConstantString(int reservedCapacity, int constStringRegistryIndex, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) + { + AssertLuaStackBelow(reservedCapacity, file, method, line); + Debug.Assert(IsConstantStringRegistryIndex(constStringRegistryIndex), "Can't use this with unknown string"); + + var loadRes = state.RawGetInteger(LuaRegistry.Index, constStringRegistryIndex); + Debug.Assert(loadRes == LuaType.String, "Expected constant string to be loaded on stack"); + + // Check if index corresponds to value registered in constructor + bool IsConstantStringRegistryIndex(int index) + => index == okConstStringRegisteryIndex || + index == errConstStringRegistryIndex || + index == noSessionAvailableConstStringRegisteryIndex || + index == pleaseSpecifyRedisCallConstStringRegistryIndex || + index == errNoAuthConstStringRegistryIndex || + index == errUnknownConstStringRegistryIndex || + index == errBadArgConstStringRegistryIndex; + } } } \ No newline at end of file diff --git a/libs/server/Resp/CmdStrings.cs b/libs/server/Resp/CmdStrings.cs index aad5e5e43f..b6beabbfb1 100644 --- a/libs/server/Resp/CmdStrings.cs +++ b/libs/server/Resp/CmdStrings.cs @@ -332,5 +332,13 @@ static partial class CmdStrings public static ReadOnlySpan initiate_replica_sync => "INITIATE_REPLICA_SYNC"u8; public static ReadOnlySpan send_ckpt_file_segment => "SEND_CKPT_FILE_SEGMENT"u8; public static ReadOnlySpan send_ckpt_metadata => "SEND_CKPT_METADATA"u8; + + // Lua scripting strings + public static ReadOnlySpan LUA_OK => "OK"u8; + public static ReadOnlySpan LUA_err => "err"u8; + public static ReadOnlySpan LUA_No_session_available => "No session available"u8; + public static ReadOnlySpan LUA_ERR_Please_specify_at_least_one_argument_for_this_redis_lib_call => "ERR Please specify at least one argument for this redis lib call"u8; + public static ReadOnlySpan LUA_ERR_Unknown_Redis_command_called_from_script => "ERR Unknown Redis command called from script"u8; + public static ReadOnlySpan LUA_ERR_Lua_redis_lib_command_arguments_must_be_strings_or_integers => "ERR Lua redis lib command arguments must be strings or integers"u8; } } \ No newline at end of file From 6e8d40b78a2d9d2e1ef6bd7319a6df28fda89682 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Wed, 11 Dec 2024 18:27:03 -0500 Subject: [PATCH 28/51] Benchmark depends on missing sessions always causing nil responses, which is odd, but easy enough to restore --- libs/server/Lua/LuaRunner.cs | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 7551b5e92f..aa9dd1c057 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -435,7 +435,7 @@ public int garnet_call(IntPtr luaStatePtr) if (respServerSession == null) { - return NoSessionError(); + return NoSessionResponse(); } return ProcessCommandFromScripting(respServerSession.basicGarnetApi); @@ -450,7 +450,7 @@ public int garnet_call_txn(IntPtr luaStatePtr) if (respServerSession == null) { - return NoSessionError(); + return NoSessionResponse(); } return ProcessCommandFromScripting(respServerSession.lockableGarnetApi); @@ -459,21 +459,16 @@ public int garnet_call_txn(IntPtr luaStatePtr) /// /// Call somehow came in with no valid resp server session. /// - /// Raise an error. + /// This is used in benchmarking. /// - /// - int NoSessionError() + int NoSessionResponse() { const int NeededStackSpace = 1; - logger?.LogError("Lua call came in without a valid resp session"); - ForceGrowLuaStack(NeededStackSpace); - CheckedPushConstantString(NeededStackSpace, noSessionAvailableConstStringRegisteryIndex); - - // this will never return, but we can pretend it does - return state.Error(); + CheckedPushNil(NeededStackSpace); + return 1; } /// From 135149eafc58a1a2b14a08fad1424c9bc25493b2 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Thu, 12 Dec 2024 12:41:40 -0500 Subject: [PATCH 29/51] remove one bespoke array rental --- libs/server/ArgSlice/ScratchBufferManager.cs | 21 +++++++ libs/server/Lua/LuaRunner.cs | 60 ++++++++++---------- 2 files changed, 52 insertions(+), 29 deletions(-) diff --git a/libs/server/ArgSlice/ScratchBufferManager.cs b/libs/server/ArgSlice/ScratchBufferManager.cs index c68c223214..9ac9bcd5d6 100644 --- a/libs/server/ArgSlice/ScratchBufferManager.cs +++ b/libs/server/ArgSlice/ScratchBufferManager.cs @@ -42,6 +42,12 @@ public ScratchBufferManager() /// public void Reset() => scratchBufferOffset = 0; + /// + /// Return the full buffer managed by this . + /// + public Span FullBuffer() + => scratchBuffer; + /// /// Rewind (pop) the last entry of scratch buffer (rewinding the current scratch buffer offset), /// if it contains the given ArgSlice @@ -329,5 +335,20 @@ public ArgSlice GetSliceFromTail(int length) { return new ArgSlice(scratchBufferHead + scratchBufferOffset - length, length); } + + /// + /// Force backing buffer to grow. + /// + public void GrowBuffer() + { + if (scratchBuffer == null) + { + ExpandScratchBuffer(64); + } + else + { + ExpandScratchBuffer(scratchBuffer.Length + 1); + } + } } } \ No newline at end of file diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index aa9dd1c057..6b9fd3420a 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -72,55 +72,60 @@ public void SendAndReset() /// private unsafe struct RunnerAdapter : IResponseAdapter { - private byte* cur; - private byte[] pinnedArr; + private readonly ScratchBufferManager bufferManager; + private byte* origin; + private byte* curHead; + private byte* curEnd; - internal RunnerAdapter(byte[] initialPinnedArr) + internal RunnerAdapter(ScratchBufferManager bufferManager) { - pinnedArr = initialPinnedArr; - cur = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetArrayDataReference(initialPinnedArr)); - BufferEnd = cur + initialPinnedArr.Length; + this.bufferManager = bufferManager; + this.bufferManager.Reset(); + + var scratchSpace = bufferManager.FullBuffer(); + + origin = curHead = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(scratchSpace)); + curEnd = curHead + scratchSpace.Length; } #pragma warning disable CS9084 // Struct member returns 'this' or other instance members by reference /// public unsafe ref byte* BufferCur - => ref cur; + => ref curHead; #pragma warning restore CS9084 /// - public unsafe byte* BufferEnd { readonly get; private set; } + public unsafe byte* BufferEnd + => curEnd; /// /// Gets a span that covers the responses as written so far. /// public readonly ReadOnlySpan Response - { get { - var origin = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetArrayDataReference(pinnedArr)); - var length = (int)(cur - origin); + var len = (int)(curHead - origin); + + var full = bufferManager.FullBuffer(); - return new(origin, length); + return full[..len]; } } /// public void SendAndReset() { - // We don't actually send anywhere, we grow the backing array - var newLen = pinnedArr.Length * 2; - var newPinnedArr = GC.AllocateUninitializedArray(newLen, pinned: true); - var copyLen = pinnedArr.Length - (int)(BufferEnd - cur); + var len = (int)(curHead - origin); - pinnedArr.AsSpan()[..copyLen].CopyTo(newPinnedArr); + // We don't actually send anywhere, we grow the backing array + bufferManager.GrowBuffer(); - pinnedArr = newPinnedArr; - cur = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetArrayDataReference(newPinnedArr)); + var scratchSpace = bufferManager.FullBuffer(); - BufferEnd = cur + newLen; - cur += copyLen; + origin = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(scratchSpace)); + curEnd = origin + scratchSpace.Length; + curHead = origin + len; } } @@ -169,7 +174,7 @@ public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSe this.txnMode = txnMode; this.respServerSession = respServerSession; this.scratchBufferNetworkSender = scratchBufferNetworkSender; - this.scratchBufferManager = respServerSession?.scratchBufferManager; + this.scratchBufferManager = respServerSession?.scratchBufferManager ?? new(); this.logger = logger; sandboxEnvRegistryIndex = -1; @@ -319,7 +324,7 @@ int ConstantStringToRegistery(int top, ReadOnlySpan str) /// public unsafe void CompileForRunner() { - var adapter = new RunnerAdapter(GC.AllocateUninitializedArray(64, pinned: true)); + var adapter = new RunnerAdapter(scratchBufferManager); CompileCommon(ref adapter); var resp = adapter.Response; @@ -549,9 +554,8 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) // in future to provide parse state directly. var trueArgCount = argCount - 1; - // Avoid allocating entirely if fewer than 16 commands (note we only store pointers, we make no copies) - // - // At 17+ we'll rent an array, which might allocate, but typically won't + scratchBufferManager.ResetScratchBuffer(0); + var cmdArgsArr = trueArgCount <= 16 ? null : ArrayPool.Shared.Rent(argCount); var cmdArgs = cmdArgsArr != null ? cmdArgsArr.AsSpan()[..trueArgCount] : stackalloc ArgSlice[trueArgCount]; @@ -850,12 +854,10 @@ public void RunForSession(int count, RespServerSession outerSession) /// public unsafe object RunForRunner(string[] keys = null, string[] argv = null) { - const int InitialSize = 64; - scratchBufferManager?.Reset(); LoadParametersForRunner(keys, argv); - var adapter = new RunnerAdapter(GC.AllocateUninitializedArray(InitialSize, pinned: true)); + var adapter = new RunnerAdapter(scratchBufferManager); if (txnMode && keys?.Length > 0) { From 979b98b436bd60dbf37719a2d307412699b67e77 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Thu, 12 Dec 2024 13:01:13 -0500 Subject: [PATCH 30/51] Remvoe another bespoke array rental --- libs/server/ArgSlice/ScratchBufferManager.cs | 48 ++++++++++++ libs/server/Lua/LuaRunner.cs | 82 ++++++++------------ 2 files changed, 82 insertions(+), 48 deletions(-) diff --git a/libs/server/ArgSlice/ScratchBufferManager.cs b/libs/server/ArgSlice/ScratchBufferManager.cs index 9ac9bcd5d6..50eca9b0eb 100644 --- a/libs/server/ArgSlice/ScratchBufferManager.cs +++ b/libs/server/ArgSlice/ScratchBufferManager.cs @@ -222,6 +222,54 @@ public ArgSlice FormatScratch(int headerSize, ReadOnlySpan arg) return retVal; } + public void StartCommand(ReadOnlySpan cmd, int argCount) + { + if (scratchBuffer == null) + ExpandScratchBuffer(64); + + var ptr = scratchBufferHead + scratchBufferOffset; + + while (!RespWriteUtils.WriteArrayLength(argCount+1, ref ptr, scratchBufferHead + scratchBuffer.Length)) + { + ExpandScratchBuffer(scratchBuffer.Length + 1); + ptr = scratchBufferHead + scratchBufferOffset; + } + scratchBufferOffset = (int)(ptr - scratchBufferHead); + + while (!RespWriteUtils.WriteBulkString(cmd, ref ptr, scratchBufferHead + scratchBuffer.Length)) + { + ExpandScratchBuffer(scratchBuffer.Length + 1); + ptr = scratchBufferHead + scratchBufferOffset; + } + scratchBufferOffset = (int)(ptr - scratchBufferHead); + } + + public void WriteNullArgument() + { + var ptr = scratchBufferHead + scratchBufferOffset; + + while (!RespWriteUtils.WriteNull(ref ptr, scratchBufferHead + scratchBuffer.Length)) + { + ExpandScratchBuffer(scratchBuffer.Length + 1); + ptr = scratchBufferHead + scratchBufferOffset; + } + + scratchBufferOffset = (int)(ptr - scratchBufferHead); + } + + public void WriteArgument(ReadOnlySpan arg) + { + var ptr = scratchBufferHead + scratchBufferOffset; + + while (!RespWriteUtils.WriteBulkString(arg, ref ptr, scratchBufferHead + scratchBuffer.Length)) + { + ExpandScratchBuffer(scratchBuffer.Length + 1); + ptr = scratchBufferHead + scratchBufferOffset; + } + + scratchBufferOffset = (int)(ptr - scratchBufferHead); + } + /// /// Format specified command with arguments, as a RESP command. Lua state /// can be specified to handle Lua tables as arguments. diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 6b9fd3420a..3a0a96b94c 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -552,63 +552,49 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) // As fallback, we use RespServerSession with a RESP-formatted input. This could be optimized // in future to provide parse state directly. - var trueArgCount = argCount - 1; - scratchBufferManager.ResetScratchBuffer(0); + scratchBufferManager.Reset(); + scratchBufferManager.StartCommand(cmdSpan, argCount - 1); - var cmdArgsArr = trueArgCount <= 16 ? null : ArrayPool.Shared.Rent(argCount); - var cmdArgs = cmdArgsArr != null ? cmdArgsArr.AsSpan()[..trueArgCount] : stackalloc ArgSlice[trueArgCount]; - - try + for (var i = 0; i < argCount - 1; i++) { - for (var i = 0; i < argCount - 1; i++) - { - // Index 1 holds the command, so skip it - var argIx = 2 + i; + var argIx = 2 + i; - var argType = state.Type(argIx); - if (argType == LuaType.Nil) - { - cmdArgs[i] = new ArgSlice(null, -1); - } - else if (argType is LuaType.String or LuaType.Number) - { - // CheckBuffer will coerce a number into a string - // - // Redis nominally converts numbers to integers, but in this case just ToStrings things - var checkRes = NativeMethods.CheckBuffer(state.Handle, argIx, out var span); - Debug.Assert(checkRes, "Should never fail"); - - // Span remains pinned so long as we don't pop the stack - cmdArgs[i] = ArgSlice.FromPinnedSpan(span); - } - else - { - return LuaStaticError(neededStackSpace, errBadArgConstStringRegistryIndex); - } + var argType = state.Type(argIx); + if (argType == LuaType.Nil) + { + scratchBufferManager.WriteNullArgument(); + } + else if (argType is LuaType.String or LuaType.Number) + { + // CheckBuffer will coerce a number into a string + // + // Redis nominally converts numbers to integers, but in this case just ToStrings things + var checkRes = NativeMethods.CheckBuffer(state.Handle, argIx, out var span); + Debug.Assert(checkRes, "Should never fail"); + + // Span remains pinned so long as we don't pop the stack + scratchBufferManager.WriteArgument(span); } + else + { + return LuaStaticError(neededStackSpace, errBadArgConstStringRegistryIndex); + } + } - var request = scratchBufferManager.FormatCommandAsResp(cmdSpan, cmdArgs); + var request = scratchBufferManager.ViewFullArgSlice(); - // Once the request is formatted, we can release all the args on the Lua stack - // - // This keeps the stack size down for processing the response - state.Pop(argCount); + // Once the request is formatted, we can release all the args on the Lua stack + // + // This keeps the stack size down for processing the response + state.Pop(argCount); - _ = respServerSession.TryConsumeMessages(request.ptr, request.length); + _ = respServerSession.TryConsumeMessages(request.ptr, request.length); - var response = scratchBufferNetworkSender.GetResponse(); - var result = ProcessResponse(response.ptr, response.length); - scratchBufferNetworkSender.Reset(); - return result; - } - finally - { - if (cmdArgsArr != null) - { - ArrayPool.Shared.Return(cmdArgsArr); - } - } + var response = scratchBufferNetworkSender.GetResponse(); + var result = ProcessResponse(response.ptr, response.length); + scratchBufferNetworkSender.Reset(); + return result; } catch (Exception e) { From 360eb16a465eb715c52bc0bfdba2db17efdeafd2 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Thu, 12 Dec 2024 14:12:51 -0500 Subject: [PATCH 31/51] Yet another bespoke array rental --- libs/server/ArgSlice/ScratchBufferManager.cs | 84 +++--------------- libs/server/Lua/LuaRunner.cs | 89 ++++++++------------ 2 files changed, 45 insertions(+), 128 deletions(-) diff --git a/libs/server/ArgSlice/ScratchBufferManager.cs b/libs/server/ArgSlice/ScratchBufferManager.cs index 50eca9b0eb..e2106bfaa7 100644 --- a/libs/server/ArgSlice/ScratchBufferManager.cs +++ b/libs/server/ArgSlice/ScratchBufferManager.cs @@ -222,6 +222,11 @@ public ArgSlice FormatScratch(int headerSize, ReadOnlySpan arg) return retVal; } + /// + /// Start a RESP array to hold a command and arguments. + /// + /// Fill it with calls to and/or . + /// public void StartCommand(ReadOnlySpan cmd, int argCount) { if (scratchBuffer == null) @@ -229,7 +234,7 @@ public void StartCommand(ReadOnlySpan cmd, int argCount) var ptr = scratchBufferHead + scratchBufferOffset; - while (!RespWriteUtils.WriteArrayLength(argCount+1, ref ptr, scratchBufferHead + scratchBuffer.Length)) + while (!RespWriteUtils.WriteArrayLength(argCount + 1, ref ptr, scratchBufferHead + scratchBuffer.Length)) { ExpandScratchBuffer(scratchBuffer.Length + 1); ptr = scratchBufferHead + scratchBufferOffset; @@ -244,6 +249,9 @@ public void StartCommand(ReadOnlySpan cmd, int argCount) scratchBufferOffset = (int)(ptr - scratchBufferHead); } + /// + /// Use to fill a RESP array with arguments after a call to . + /// public void WriteNullArgument() { var ptr = scratchBufferHead + scratchBufferOffset; @@ -257,88 +265,20 @@ public void WriteNullArgument() scratchBufferOffset = (int)(ptr - scratchBufferHead); } - public void WriteArgument(ReadOnlySpan arg) - { - var ptr = scratchBufferHead + scratchBufferOffset; - - while (!RespWriteUtils.WriteBulkString(arg, ref ptr, scratchBufferHead + scratchBuffer.Length)) - { - ExpandScratchBuffer(scratchBuffer.Length + 1); - ptr = scratchBufferHead + scratchBufferOffset; - } - - scratchBufferOffset = (int)(ptr - scratchBufferHead); - } - /// - /// Format specified command with arguments, as a RESP command. Lua state - /// can be specified to handle Lua tables as arguments. + /// Use to fill a RESP array with arguments after a call to . /// - public ArgSlice FormatCommandAsResp(ReadOnlySpan cmd, ReadOnlySpan args) + public void WriteArgument(ReadOnlySpan arg) { - if (scratchBuffer == null) - ExpandScratchBuffer(64); - - scratchBufferOffset += 10; // Reserve space for the array length if it is larger than expected - var commandStartOffset = scratchBufferOffset; var ptr = scratchBufferHead + scratchBufferOffset; - while (!RespWriteUtils.WriteArrayLength(args.Length + 1, ref ptr, scratchBufferHead + scratchBuffer.Length)) + while (!RespWriteUtils.WriteBulkString(arg, ref ptr, scratchBufferHead + scratchBuffer.Length)) { ExpandScratchBuffer(scratchBuffer.Length + 1); ptr = scratchBufferHead + scratchBufferOffset; } - scratchBufferOffset = (int)(ptr - scratchBufferHead); - while (!RespWriteUtils.WriteBulkString(cmd, ref ptr, scratchBufferHead + scratchBuffer.Length)) - { - ExpandScratchBuffer(scratchBuffer.Length + 1); - ptr = scratchBufferHead + scratchBufferOffset; - } scratchBufferOffset = (int)(ptr - scratchBufferHead); - - var count = 1; - foreach (var str in args) - { - count++; - - // Smuggling a null-ish value in - if (str.Length < 0) - { - while (!RespWriteUtils.WriteNull(ref ptr, scratchBufferHead + scratchBuffer.Length)) - { - ExpandScratchBuffer(scratchBuffer.Length + 1); - ptr = scratchBufferHead + scratchBufferOffset; - } - } - else - { - while (!RespWriteUtils.WriteBulkString(str.ReadOnlySpan, ref ptr, scratchBufferHead + scratchBuffer.Length)) - { - ExpandScratchBuffer(scratchBuffer.Length + 1); - ptr = scratchBufferHead + scratchBufferOffset; - } - } - - scratchBufferOffset = (int)(ptr - scratchBufferHead); - } - - if (count != args.Length + 1) - { - var extraSpace = NumUtils.NumDigits(count) - NumUtils.NumDigits(args.Length + 1); - if (commandStartOffset < extraSpace) - throw new InvalidOperationException("Invalid number of arguments"); - - commandStartOffset -= extraSpace; - var head = scratchBufferHead + commandStartOffset; - - // There should be space as we have reserved it - _ = RespWriteUtils.WriteArrayLength(count, ref head, scratchBufferHead + scratchBuffer.Length); - } - - var retVal = new ArgSlice(scratchBufferHead + commandStartOffset, scratchBufferOffset - commandStartOffset); - Debug.Assert(scratchBufferOffset <= scratchBuffer.Length); - return retVal; } /// diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 3a0a96b94c..c553fb5d42 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -1022,79 +1022,56 @@ void LoadParametersForRunner(string[] keys, string[] argv) ResetParameters(keys?.Length ?? 0, argv?.Length ?? 0); - byte[] encodingBufferArr = null; - Span encodingBuffer = stackalloc byte[64]; - try + if (keys != null) { + // get KEYS on the stack + CheckedPushNumber(NeededStackSize, keysTableRegistryIndex); + var loadRes = state.GetTable(LuaRegistry.Index); + Debug.Assert(loadRes == LuaType.Table, "Unexpected type for KEYS"); - if (keys != null) + for (var i = 0; i < keys.Length; i++) { - // get KEYS on the stack - CheckedPushNumber(NeededStackSize, keysTableRegistryIndex); - var loadRes = state.GetTable(LuaRegistry.Index); - Debug.Assert(loadRes == LuaType.Table, "Unexpected type for KEYS"); - - for (var i = 0; i < keys.Length; i++) - { - // equivalent to KEYS[i+1] = keys[i] - var key = keys[i]; - - var keyLen = PrepareString(key, ref encodingBufferArr, ref encodingBuffer); - CheckedPushBuffer(NeededStackSize, encodingBuffer[..keyLen]); - - state.RawSetInteger(1, i + 1); - } - - state.Pop(1); + // equivalent to KEYS[i+1] = keys[i] + var key = keys[i]; + PrepareString(key, scratchBufferManager, out var encoded); + CheckedPushBuffer(NeededStackSize, encoded); + state.RawSetInteger(1, i + 1); } - if (argv != null) - { - // get ARGV on the stack - CheckedPushNumber(NeededStackSize, argvTableRegistryIndex); - var loadRes = state.GetTable(LuaRegistry.Index); - Debug.Assert(loadRes == LuaType.Table, "Unexpected type for ARGV"); - - for (var i = 0; i < argv.Length; i++) - { - // equivalent to ARGV[i+1] = keys[i] - var arg = argv[i]; - - var argLen = PrepareString(arg, ref encodingBufferArr, ref encodingBuffer); - CheckedPushBuffer(NeededStackSize, encodingBuffer[..argLen]); - - state.RawSetInteger(1, i + 1); - } - - state.Pop(1); - } + state.Pop(1); } - finally + + if (argv != null) { - if (encodingBufferArr != null) + // get ARGV on the stack + CheckedPushNumber(NeededStackSize, argvTableRegistryIndex); + var loadRes = state.GetTable(LuaRegistry.Index); + Debug.Assert(loadRes == LuaType.Table, "Unexpected type for ARGV"); + + for (var i = 0; i < argv.Length; i++) { - ArrayPool.Shared.Return(encodingBufferArr); + // equivalent to ARGV[i+1] = keys[i] + var arg = argv[i]; + PrepareString(arg, scratchBufferManager, out var encoded); + CheckedPushBuffer(NeededStackSize, encoded); + state.RawSetInteger(1, i + 1); } + + state.Pop(1); } AssertLuaStackEmpty(); - // TODO: replace with scratchBufferManager - static int PrepareString(string raw, ref byte[] arr, ref Span span) + static void PrepareString(string raw, ScratchBufferManager buffer, out ReadOnlySpan strBytes) { var maxLen = Encoding.UTF8.GetMaxByteCount(raw.Length); - if (span.Length < maxLen) - { - if (arr != null) - { - ArrayPool.Shared.Return(arr); - } - arr = ArrayPool.Shared.Rent(maxLen); - span = arr; - } + buffer.Reset(); + var argSlice = buffer.CreateArgSlice(maxLen); + var span = argSlice.Span; - return Encoding.UTF8.GetBytes(raw, span); + var written = Encoding.UTF8.GetBytes(raw, span); + strBytes = span[..written]; } } From b286371fce957f39b1ec39940b59ad705f45faf4 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Thu, 12 Dec 2024 14:13:05 -0500 Subject: [PATCH 32/51] Cleanup --- libs/server/Lua/LuaRunner.cs | 2 -- 1 file changed, 2 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index c553fb5d42..3893885c16 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -2,7 +2,6 @@ // Licensed under the MIT license. using System; -using System.Buffers; using System.Diagnostics; using System.Linq; using System.Runtime.CompilerServices; @@ -154,7 +153,6 @@ public void SendAndReset() readonly ScratchBufferNetworkSender scratchBufferNetworkSender; readonly RespServerSession respServerSession; - // TODO: all buffers should be rented from this, remove ArrayPool use readonly ScratchBufferManager scratchBufferManager; readonly ILogger logger; readonly Lua state; From f7b541c4faf05fffe20780829d97094ae6e11631 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Thu, 12 Dec 2024 14:40:49 -0500 Subject: [PATCH 33/51] more benchmarks for LuaRunner; test that parameter resets work --- .../BDN.benchmark/Lua/LuaRunnerOperations.cs | 220 ++++++++++++++++++ benchmark/BDN.benchmark/Program.cs | 1 + libs/server/Lua/LuaRunner.cs | 2 +- test/Garnet.test/LuaScriptRunnerTests.cs | 29 +++ 4 files changed, 251 insertions(+), 1 deletion(-) create mode 100644 benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs diff --git a/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs b/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs new file mode 100644 index 0000000000..223d6cceb1 --- /dev/null +++ b/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs @@ -0,0 +1,220 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using BenchmarkDotNet.Attributes; +using BenchmarkDotNet.Columns; +using Embedded.perftest; +using Garnet.server; + +namespace BDN.benchmark.Lua +{ + /// + /// Benchmark for non-script running operations in LuaRunner + /// + [MemoryDiagnoser] + [HideColumns(Column.Gen0)] + public unsafe class LuaRunnerOperations + { + private const string SmallScript = "return nil"; + + private const string LargeScript = @" +-- based on a real UDF, with some extraneous ops removed + +local userKey = KEYS[1] +local identifier = ARGV[1] +local currentTime = ARGV[2] +local newSession = ARGV[3] + +local userKeyValue = redis.call(""GET"", userKey) + +local updatedValue = nil +local returnValue = -1 + +if userKeyValue then +-- needs to be updated + + local oldestEntry = nil + local oldestEntryUpdateTime = nil + + local match = nil + + local entryCount = 0 + +-- loop over each entry, looking for one to update + for entry in string.gmatch(userKeyValue, ""([^%|]+)"") do + entryCount = entryCount + 1 + + local entryIdentifier = nil + local entrySessionNumber = -1 + local entryRequestCount = -1 + local entryLastSessionUpdateTime = -1 + + local ix = 0 + for part in string.gmatch(entry, ""([^:]+)"") do + if ix == 0 then + entryIdentifier = part + elseif ix == 1 then + entrySessionNumber = tonumber(part) + elseif ix == 2 then + entryRequestCount = tonumber(part) + elseif ix == 3 then + entryLastSessionUpdateTime = tonumber(part) + else +-- malformed, too many parts + return -1 + end + + ix = ix + 1 + end + + if ix ~= 4 then +-- malformed, too few parts + return -2 + end + + if entryIdentifier == identifier then +-- found the one to update + local updatedEntry = nil + + if tonumber(newSession) == 1 then + local updatedSessionNumber = entrySessionNumber + 1 + updatedEntry = entryIdentifier .. "":"" .. tostring(updatedSessionNumber) .. "":1:"" .. tostring(currentTime) + returnValue = 3 + else + local updatedRequestCount = entryRequestCount + 1 + updatedEntry = entryIdentifier .. "":"" .. tostring(entrySessionNumber) .. "":"" .. tostring(updatedRequestCount) .. "":"" .. tostring(currentTime) + returnValue = 2 + end + +-- have to escape the replacement, since Lua doesn't have a literal replace :/ + local escapedEntry = string.gsub(entry, ""%p"", ""%%%1"") + updatedValue = string.gsub(userKeyValue, escapedEntry, updatedEntry) + + break + end + + if oldestEntryUpdateTime == nil or oldestEntryUpdateTime > entryLastSessionUpdateTime then +-- remember the oldest entry, so we can replace it if needed + oldestEntry = entry + oldestEntryUpdateTime = entryLastSessionUpdateTime + end + end + + if updatedValue == nil then +-- we didn't update an existing value, so we need to add it + + local newEntry = identifier .. "":1:1:"" .. tostring(currentTime) + + if entryCount < 20 then +-- there's room, just append it + updatedValue = userKeyValue .. ""|"" .. newEntry + returnValue = 4 + else +-- there isn't room, replace the LRU entry + +-- have to escape the replacement, since Lua doesn't have a literal replace :/ + local escapedOldestEntry = string.gsub(oldestEntry, ""%p"", ""%%%1"") + + updatedValue = string.gsub(userKeyValue, escapedOldestEntry, newEntry) + returnValue = 5 + end + end +else +-- needs to be created + updatedValue = identifier .. "":1:1:"" .. tostring(currentTime) + + returnValue = 1 +end + +redis.call(""SET"", userKey, updatedValue) + +return returnValue +"; + + /// + /// Lua parameters + /// + [ParamsSource(nameof(LuaParamsProvider))] + public LuaParams Params { get; set; } + + /// + /// Lua parameters provider + /// + public IEnumerable LuaParamsProvider() + { + yield return new(); + } + + private EmbeddedRespServer server; + private RespServerSession session; + + private LuaRunner paramsRunner; + + [GlobalSetup] + public void GlobalSetup() + { + server = new EmbeddedRespServer(new GarnetServerOptions() { EnableLua = true, QuietMode = true }); + + session = server.GetRespSession(); + paramsRunner = new LuaRunner("return nil"); + } + + [GlobalCleanup] + public void GlobalCleanup() + { + session.Dispose(); + server.Dispose(); + paramsRunner.Dispose(); + } + + [Benchmark] + public void ResetParametersSmall() + { + // First force up + paramsRunner.ResetParameters(1, 1); + + // Then require a small amount of clearing (1 key, 1 arg) + paramsRunner.ResetParameters(0, 0); + } + + [Benchmark] + public void ResetParametersLarge() + { + // First force up + paramsRunner.ResetParameters(10, 10); + + // Then require a large amount of clearing (10 keys, 10 args) + paramsRunner.ResetParameters(0, 0); + } + + [Benchmark] + public void ConstructSmall() + { + using var runner = new LuaRunner(SmallScript); + } + + [Benchmark] + public void ConstructLarge() + { + using var runner = new LuaRunner(LargeScript); + } + + [Benchmark] + public void CompileForSessionSmall() + { + using (var runner = new LuaRunner(SmallScript)) + { + runner.CompileForSession(session); + } + } + + [Benchmark] + public void CompileForSessionLarge() + { + using (var runner = new LuaRunner(LargeScript)) + { + runner.CompileForSession(session); + } + } + } +} \ No newline at end of file diff --git a/benchmark/BDN.benchmark/Program.cs b/benchmark/BDN.benchmark/Program.cs index 06f1b46bde..470908f5e7 100644 --- a/benchmark/BDN.benchmark/Program.cs +++ b/benchmark/BDN.benchmark/Program.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +using BDN.benchmark.Lua; using BenchmarkDotNet.Columns; using BenchmarkDotNet.Configs; using BenchmarkDotNet.Environments; diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 3893885c16..d226578d32 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -958,7 +958,7 @@ void RunInTransaction(ref TResponse response) /// /// Remove extra keys and args from KEYS and ARGV globals. /// - void ResetParameters(int nKeys, int nArgs) + internal void ResetParameters(int nKeys, int nArgs) { // TODO: is this faster than punching a function in to do it? const int NeededStackSize = 2; diff --git a/test/Garnet.test/LuaScriptRunnerTests.cs b/test/Garnet.test/LuaScriptRunnerTests.cs index a21e40e689..b7a7a7f0f7 100644 --- a/test/Garnet.test/LuaScriptRunnerTests.cs +++ b/test/Garnet.test/LuaScriptRunnerTests.cs @@ -101,5 +101,34 @@ public void CanRunScript() ClassicAssert.AreEqual("Compilation error: [string \"local list; list = ; return list;\"]:1: unexpected symbol near ';'", ex.Message); } } + + [Test] + public void KeysAndArgsCleared() + { + using (var runner = new LuaRunner("return { KEYS[1], ARGV[1], KEYS[2], ARGV[2] }")) + { + runner.CompileForRunner(); + var res1 = runner.RunForRunner(["hello", "world"], ["fizz", "buzz"]); + var obj1 = (object[])res1; + ClassicAssert.AreEqual(4, obj1.Length); + ClassicAssert.AreEqual("hello", (string)obj1[0]); + ClassicAssert.AreEqual("fizz", (string)obj1[1]); + ClassicAssert.AreEqual("world", (string)obj1[2]); + ClassicAssert.AreEqual("buzz", (string)obj1[3]); + + var res2 = runner.RunForRunner(["abc"], ["def"]); + var obj2 = (object[])res2; + ClassicAssert.AreEqual(2, obj2.Length); + ClassicAssert.AreEqual("abc", (string)obj2[0]); + ClassicAssert.AreEqual("def", (string)obj2[1]); + + var res3 = runner.RunForRunner(["012", "345"], ["678"]); + var obj3 = (object[])res3; + ClassicAssert.AreEqual(3, obj3.Length); + ClassicAssert.AreEqual("012", (string)obj3[0]); + ClassicAssert.AreEqual("678", (string)obj3[1]); + ClassicAssert.AreEqual("345", (string)obj3[2]); + } + } } } \ No newline at end of file From 64c32368fc77e98987fc76e86b7e7538e9d35340 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Thu, 12 Dec 2024 15:11:35 -0500 Subject: [PATCH 34/51] most overhead is in pinvoke, so start moving some stuff over --- libs/server/Lua/LuaRunner.cs | 92 ++++++++++++++++++++++++------------ 1 file changed, 62 insertions(+), 30 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index d226578d32..bc94309dbc 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -140,6 +140,7 @@ public void SendAndReset() readonly int keysTableRegistryIndex; readonly int argvTableRegistryIndex; readonly int loadSandboxedRegistryIndex; + readonly int resetKeysAndArgvRegistryIndex; readonly int okConstStringRegisteryIndex; readonly int errConstStringRegistryIndex; readonly int noSessionAvailableConstStringRegisteryIndex; @@ -235,6 +236,19 @@ function redis.error_reply(text) KEYS = KEYS; ARGV = ARGV; } + -- do resets in the Lua side to minimize pinvokes + function reset_keys_and_argv(fromKey, fromArgv) + local keyCount = #KEYS + for i = fromKey, keyCount do + KEYS[i] = nil + end + + local argvCount = #ARGV + for i = fromArgv, argvCount do + ARGV[i] = nil + end + end + -- responsible for sandboxing user provided code function load_sandboxed(source) if (not source) then return nil end local rawFunc, err = load(source, nil, nil, sandbox_env) @@ -282,6 +296,10 @@ function load_sandboxed(source) Debug.Assert(loadSandboxedType == LuaType.Function, "Unexpected load_sandboxed type"); loadSandboxedRegistryIndex = state.Ref(LuaRegistry.Index); + var resetKeysAndArgvType = state.GetGlobal("reset_keys_and_argv"); + Debug.Assert(resetKeysAndArgvType == LuaType.Function, "Unexpected reset_keys_and_argv type"); + resetKeysAndArgvRegistryIndex = state.Ref(LuaRegistry.Index); + // Commonly used strings, register them once so we don't have to copy them over each time we need them okConstStringRegisteryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.LUA_OK); errConstStringRegistryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.LUA_err); @@ -967,43 +985,57 @@ internal void ResetParameters(int nKeys, int nArgs) ForceGrowLuaStack(NeededStackSize); - if (keyLength > nKeys) + if (keyLength > nKeys || argvLength > nArgs) { - // Get KEYS on the stack - CheckedPushNumber(NeededStackSize, keysTableRegistryIndex); - var loadRes = state.GetTable(LuaRegistry.Index); - Debug.Assert(loadRes == LuaType.Table, "Unexpected type for KEYS"); + var getRes = state.RawGetInteger(LuaRegistry.Index, resetKeysAndArgvRegistryIndex); + Debug.Assert(getRes == LuaType.Function, "Unexpected type when loading reset_keys_and_argv"); - // Clear all the values in KEYS that we aren't going to set anyway - for (var i = nKeys + 1; i <= keyLength; i++) - { - CheckedPushNil(NeededStackSize); - state.RawSetInteger(1, i); - } - - state.Pop(1); + CheckedPushNumber(NeededStackSize, nKeys + 1); + CheckedPushNumber(NeededStackSize, nArgs + 1); + var resetRes = state.PCall(2, 0, 0); + Debug.Assert(resetRes == LuaStatus.OK, "Resetting should never fail"); } keyLength = nKeys; - - if (argvLength > nArgs) - { - // Get ARGV on the stack - CheckedPushNumber(NeededStackSize, argvTableRegistryIndex); - var loadRes = state.GetTable(LuaRegistry.Index); - Debug.Assert(loadRes == LuaType.Table, "Unexpected type for ARGV"); - - for (var i = nArgs + 1; i <= argvLength; i++) - { - CheckedPushNil(NeededStackSize); - state.RawSetInteger(1, i); - } - - state.Pop(1); - } - argvLength = nArgs; + //if (keyLength > nKeys) + //{ + // // Get KEYS on the stack + // CheckedPushNumber(NeededStackSize, keysTableRegistryIndex); + // var loadRes = state.GetTable(LuaRegistry.Index); + // Debug.Assert(loadRes == LuaType.Table, "Unexpected type for KEYS"); + + // // Clear all the values in KEYS that we aren't going to set anyway + // for (var i = nKeys + 1; i <= keyLength; i++) + // { + // CheckedPushNil(NeededStackSize); + // state.RawSetInteger(1, i); + // } + + // state.Pop(1); + //} + + //keyLength = nKeys; + + //if (argvLength > nArgs) + //{ + // // Get ARGV on the stack + // CheckedPushNumber(NeededStackSize, argvTableRegistryIndex); + // var loadRes = state.GetTable(LuaRegistry.Index); + // Debug.Assert(loadRes == LuaType.Table, "Unexpected type for ARGV"); + + // for (var i = nArgs + 1; i <= argvLength; i++) + // { + // CheckedPushNil(NeededStackSize); + // state.RawSetInteger(1, i); + // } + + // state.Pop(1); + //} + + //argvLength = nArgs; + AssertLuaStackEmpty(); } From a08ae3276aa577e66ca842942f00daa67b77e401 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Thu, 12 Dec 2024 15:13:41 -0500 Subject: [PATCH 35/51] most overhead is in pinvoke, so start moving some stuff over --- benchmark/BDN.benchmark/Program.cs | 1 - libs/server/Lua/LuaRunner.cs | 40 +----------------------------- 2 files changed, 1 insertion(+), 40 deletions(-) diff --git a/benchmark/BDN.benchmark/Program.cs b/benchmark/BDN.benchmark/Program.cs index 470908f5e7..06f1b46bde 100644 --- a/benchmark/BDN.benchmark/Program.cs +++ b/benchmark/BDN.benchmark/Program.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. -using BDN.benchmark.Lua; using BenchmarkDotNet.Columns; using BenchmarkDotNet.Configs; using BenchmarkDotNet.Environments; diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index bc94309dbc..a44af4c79a 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -978,8 +978,7 @@ void RunInTransaction(ref TResponse response) /// internal void ResetParameters(int nKeys, int nArgs) { - // TODO: is this faster than punching a function in to do it? - const int NeededStackSize = 2; + const int NeededStackSize = 3; AssertLuaStackEmpty(); @@ -999,43 +998,6 @@ internal void ResetParameters(int nKeys, int nArgs) keyLength = nKeys; argvLength = nArgs; - //if (keyLength > nKeys) - //{ - // // Get KEYS on the stack - // CheckedPushNumber(NeededStackSize, keysTableRegistryIndex); - // var loadRes = state.GetTable(LuaRegistry.Index); - // Debug.Assert(loadRes == LuaType.Table, "Unexpected type for KEYS"); - - // // Clear all the values in KEYS that we aren't going to set anyway - // for (var i = nKeys + 1; i <= keyLength; i++) - // { - // CheckedPushNil(NeededStackSize); - // state.RawSetInteger(1, i); - // } - - // state.Pop(1); - //} - - //keyLength = nKeys; - - //if (argvLength > nArgs) - //{ - // // Get ARGV on the stack - // CheckedPushNumber(NeededStackSize, argvTableRegistryIndex); - // var loadRes = state.GetTable(LuaRegistry.Index); - // Debug.Assert(loadRes == LuaType.Table, "Unexpected type for ARGV"); - - // for (var i = nArgs + 1; i <= argvLength; i++) - // { - // CheckedPushNil(NeededStackSize); - // state.RawSetInteger(1, i); - // } - - // state.Pop(1); - //} - - //argvLength = nArgs; - AssertLuaStackEmpty(); } From 70b67b1d9ede0de3aeb8171dee4b86106b943b97 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Thu, 12 Dec 2024 15:21:27 -0500 Subject: [PATCH 36/51] where we've already proven a type is a number or string, skip a pinvoke to double check the type --- libs/server/Lua/LuaRunner.cs | 13 +++++-------- libs/server/Lua/NativeMethods.cs | 21 +++++++++++++++++++++ 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index a44af4c79a..cb4fe8b201 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -583,12 +583,11 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) } else if (argType is LuaType.String or LuaType.Number) { - // CheckBuffer will coerce a number into a string + // KnownStringToBuffer will coerce a number into a string // // Redis nominally converts numbers to integers, but in this case just ToStrings things - var checkRes = NativeMethods.CheckBuffer(state.Handle, argIx, out var span); - Debug.Assert(checkRes, "Should never fail"); - + NativeMethods.KnownStringToBuffer(state.Handle, argIx, out var span); + // Span remains pinned so long as we don't pop the stack scratchBufferManager.WriteArgument(span); } @@ -1229,8 +1228,7 @@ static void WriteString(Lua state, int top, ref TResponse resp) { Debug.Assert(state.GetTop() == top, "Lua stack was not expected size"); - var checkRes = NativeMethods.CheckBuffer(state.Handle, top, out var buf); - Debug.Assert(checkRes, "Should never fail"); + NativeMethods.KnownStringToBuffer(state.Handle, top, out var buf); while (!RespWriteUtils.WriteBulkString(buf, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); @@ -1266,8 +1264,7 @@ static void WriteError(Lua state, int top, ref TResponse resp) { Debug.Assert(state.GetTop() == top, "Lua stack was not expected size"); - var errRes = NativeMethods.CheckBuffer(state.Handle, top, out var errBuff); - Debug.Assert(errRes, "Should never fail"); + NativeMethods.KnownStringToBuffer(state.Handle, top, out var errBuff); while (!RespWriteUtils.WriteError(errBuff, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); diff --git a/libs/server/Lua/NativeMethods.cs b/libs/server/Lua/NativeMethods.cs index 367b9f64c1..484e82e5b8 100644 --- a/libs/server/Lua/NativeMethods.cs +++ b/libs/server/Lua/NativeMethods.cs @@ -22,6 +22,9 @@ namespace Garnet.server /// internal static class NativeMethods { + // TODO: LibraryImport? + // TODO: Suppress GC transition (requires Lua audit) + private const string LuaLibraryName = "lua54"; /// @@ -71,6 +74,24 @@ internal static bool CheckBuffer(lua_State luaState, int index, out ReadOnlySpan } } + /// + /// Call when value at index is KNOWN to be a string or number + /// + /// only remains valid as long as the buffer remains on the stack, + /// use with care. + /// + /// Note that is changes the value on the stack to be a string if it returns true, regardless of + /// what it was originally. + /// + internal static void KnownStringToBuffer(lua_State luaState, int index, out ReadOnlySpan str) + { + var start = lua_tolstring(luaState, index, out var len); + unsafe + { + str = new ReadOnlySpan((byte*)start, (int)len); + } + } + /// /// Pushes given span to stack as a string. /// From 3160ffcca29241c6df6210c91b92b03dd3723591 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Thu, 12 Dec 2024 16:38:47 -0500 Subject: [PATCH 37/51] avoid checkstack calls by tracking stack capacity on our side --- libs/server/Lua/LuaRunner.cs | 510 ++++++++++++++++++++++++----------- 1 file changed, 353 insertions(+), 157 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index cb4fe8b201..6f732e8c4a 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -162,6 +162,8 @@ public void SendAndReset() int keyLength, argvLength; + int curStackSize, curStackTop; + /// /// Creates a new runner with the source of the script /// @@ -185,6 +187,7 @@ public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSe // TODO: custom allocator? state = new Lua(); AssertLuaStackEmpty(); + curStackTop = 0; ForceGrowLuaStack(NeededStackSize); @@ -280,25 +283,20 @@ function load_sandboxed(source) // Register garnet_call in global namespace state.Register("garnet_call", garnetCall); - var sandboxEnvType = state.GetGlobal("sandbox_env"); - Debug.Assert(sandboxEnvType == LuaType.Table, "Unexpected sandbox_env type"); - sandboxEnvRegistryIndex = state.Ref(LuaRegistry.Index); + CheckedGetGlobal(LuaType.Table, "sandbox_env"); + sandboxEnvRegistryIndex = CheckedRef(); - var keyTableType = state.GetGlobal("KEYS"); - Debug.Assert(keyTableType == LuaType.Table, "Unexpected KEYS type"); - keysTableRegistryIndex = state.Ref(LuaRegistry.Index); + CheckedGetGlobal(LuaType.Table, "KEYS"); + keysTableRegistryIndex = CheckedRef(); - var argvTableType = state.GetGlobal("ARGV"); - Debug.Assert(argvTableType == LuaType.Table, "Unexpected ARGV type"); - argvTableRegistryIndex = state.Ref(LuaRegistry.Index); + CheckedGetGlobal(LuaType.Table, "ARGV"); + argvTableRegistryIndex = CheckedRef(); - var loadSandboxedType = state.GetGlobal("load_sandboxed"); - Debug.Assert(loadSandboxedType == LuaType.Function, "Unexpected load_sandboxed type"); - loadSandboxedRegistryIndex = state.Ref(LuaRegistry.Index); + CheckedGetGlobal(LuaType.Function, "load_sandboxed"); + loadSandboxedRegistryIndex = CheckedRef(); - var resetKeysAndArgvType = state.GetGlobal("reset_keys_and_argv"); - Debug.Assert(resetKeysAndArgvType == LuaType.Function, "Unexpected reset_keys_and_argv type"); - resetKeysAndArgvRegistryIndex = state.Ref(LuaRegistry.Index); + CheckedGetGlobal(LuaType.Function, "reset_keys_and_argv"); + resetKeysAndArgvRegistryIndex = CheckedRef(); // Commonly used strings, register them once so we don't have to copy them over each time we need them okConstStringRegisteryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.LUA_OK); @@ -330,7 +328,7 @@ int ConstantStringToRegistery(int top, ReadOnlySpan str) AssertLuaStackEmpty(); CheckedPushBuffer(top, str); - return state.Ref(LuaRegistry.Index); + return CheckedRef(); } /// @@ -375,19 +373,21 @@ unsafe void CompileCommon(ref TResponse resp) Debug.Assert(functionRegistryIndex == -1, "Shouldn't compile multiple times"); AssertLuaStackEmpty(); + curStackTop = 0; try { ForceGrowLuaStack(NeededStackSpace); CheckedPushNumber(NeededStackSpace, loadSandboxedRegistryIndex); - var loadRes = state.GetTable(LuaRegistry.Index); - Debug.Assert(loadRes == LuaType.Function, "Unexpected load_sandboxed type"); + CheckedGetTable(LuaType.Function, (int)LuaRegistry.Index); CheckedPushBuffer(NeededStackSpace, source.Span); - state.Call(1, -1); // Multiple returns allowed + CheckedCall(1, -1); // Multiple returns allowed var numRets = state.GetTop(); + curStackTop = numRets; + if (numRets == 0) { while (!RespWriteUtils.WriteError("Shouldn't happen, no returns from load_sandboxed"u8, ref resp.BufferCur, resp.BufferEnd)) @@ -407,23 +407,23 @@ unsafe void CompileCommon(ref TResponse resp) return; } - functionRegistryIndex = state.Ref(LuaRegistry.Index); + functionRegistryIndex = CheckedRef(); } else if (numRets == 2) { - var error = state.CheckString(2); + NativeMethods.CheckBuffer(state.Handle, 2, out var errorBuf); - var errStr = $"Compilation error: {error}"; + var errStr = $"Compilation error: {Encoding.UTF8.GetString(errorBuf)}"; while (!RespWriteUtils.WriteError(errStr, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); - state.Pop(2); + CheckedPop(2); return; } else { - state.Pop(numRets); + CheckedPop(numRets); throw new GarnetException($"Unexpected error compiling, got too many replies back: reply count = {numRets}"); } @@ -500,9 +500,13 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) { const int AdditionalStackSpace = 1; + // This is LUA_MINSTACK, which is 20 + curStackSize = 20; + curStackTop = state.GetTop(); + try { - var argCount = state.GetTop(); + var argCount = curStackTop; if (argCount == 0) { @@ -587,7 +591,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) // // Redis nominally converts numbers to integers, but in this case just ToStrings things NativeMethods.KnownStringToBuffer(state.Handle, argIx, out var span); - + // Span remains pinned so long as we don't pop the stack scratchBufferManager.WriteArgument(span); } @@ -602,7 +606,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) // Once the request is formatted, we can release all the args on the Lua stack // // This keeps the stack size down for processing the response - state.Pop(argCount); + CheckedPop(argCount); _ = respServerSession.TryConsumeMessages(request.ptr, request.length); @@ -617,6 +621,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) // Clear the stack state.SetTop(0); + curStackTop = 0; ForceGrowLuaStack(1); @@ -711,8 +716,7 @@ unsafe int ProcessResponse(byte* ptr, int length) if (RespReadUtils.ReadUnsignedArrayLength(out var itemCount, ref ptr, ptr + length)) { // Create the new table - state.CreateTable(itemCount, 0); - Debug.Assert(state.GetTop() == 1, "New table should be at top of stack"); + CheckedCreateTable(itemCount, 0); for (var itemIx = 0; itemIx < itemCount; itemIx++) { @@ -732,7 +736,7 @@ unsafe int ProcessResponse(byte* ptr, int length) else { // Error, drop the table we allocated - state.Pop(1); + CheckedPop(1); goto default; } } @@ -744,10 +748,9 @@ unsafe int ProcessResponse(byte* ptr, int length) } // Stack now has table and value at itemIx on it - state.RawSetInteger(1, itemIx + 1); + CheckedRawSetInteger(1, itemIx + 1); } - Debug.Assert(state.GetTop() == 1, "Only the table should be on the stack"); return 1; } goto default; @@ -783,8 +786,7 @@ public void RunForSession(int count, RespServerSession outerSession) { // Get KEYS on the stack CheckedPushNumber(NeededStackSize, keysTableRegistryIndex); - var loadedType = state.RawGet(LuaRegistry.Index); - Debug.Assert(loadedType == LuaType.Table, "Unexpected type loaded when expecting KEYS"); + CheckedRawGet(LuaType.Table, (int)LuaRegistry.Index); for (var i = 0; i < nKeys; i++) { @@ -800,13 +802,13 @@ public void RunForSession(int count, RespServerSession outerSession) // Equivalent to KEYS[i+1] = key CheckedPushNumber(NeededStackSize, i + 1); CheckedPushBuffer(NeededStackSize, key.ReadOnlySpan); - state.RawSet(1); + CheckedRawSet(1); offset++; } // Remove KEYS from the stack - state.Pop(1); + CheckedPop(1); count -= nKeys; } @@ -815,8 +817,7 @@ public void RunForSession(int count, RespServerSession outerSession) { // Get ARGV on the stack CheckedPushNumber(NeededStackSize, argvTableRegistryIndex); - var loadedType = state.RawGet(LuaRegistry.Index); - Debug.Assert(loadedType == LuaType.Table, "Unexpected type loaded when expecting ARGV"); + CheckedRawGet(LuaType.Table, (int)LuaRegistry.Index); for (var i = 0; i < count; i++) { @@ -825,13 +826,13 @@ public void RunForSession(int count, RespServerSession outerSession) // Equivalent to ARGV[i+1] = argv CheckedPushNumber(NeededStackSize, i + 1); CheckedPushBuffer(NeededStackSize, argv.ReadOnlySpan); - state.RawSet(1); + CheckedRawSet(1); offset++; } // Remove ARGV from the stack - state.Pop(1); + CheckedPop(1); } AssertLuaStackEmpty(); @@ -985,12 +986,12 @@ internal void ResetParameters(int nKeys, int nArgs) if (keyLength > nKeys || argvLength > nArgs) { - var getRes = state.RawGetInteger(LuaRegistry.Index, resetKeysAndArgvRegistryIndex); - Debug.Assert(getRes == LuaType.Function, "Unexpected type when loading reset_keys_and_argv"); + CheckedRawGetInteger(LuaType.Function, (int)LuaRegistry.Index, resetKeysAndArgvRegistryIndex); CheckedPushNumber(NeededStackSize, nKeys + 1); CheckedPushNumber(NeededStackSize, nArgs + 1); - var resetRes = state.PCall(2, 0, 0); + + var resetRes = CheckedPCall(2, 0); Debug.Assert(resetRes == LuaStatus.OK, "Resetting should never fail"); } @@ -1017,8 +1018,7 @@ void LoadParametersForRunner(string[] keys, string[] argv) { // get KEYS on the stack CheckedPushNumber(NeededStackSize, keysTableRegistryIndex); - var loadRes = state.GetTable(LuaRegistry.Index); - Debug.Assert(loadRes == LuaType.Table, "Unexpected type for KEYS"); + CheckedGetTable(LuaType.Table, (int)LuaRegistry.Index); for (var i = 0; i < keys.Length; i++) { @@ -1026,18 +1026,17 @@ void LoadParametersForRunner(string[] keys, string[] argv) var key = keys[i]; PrepareString(key, scratchBufferManager, out var encoded); CheckedPushBuffer(NeededStackSize, encoded); - state.RawSetInteger(1, i + 1); + CheckedRawSetInteger(1, i + 1); } - state.Pop(1); + CheckedPop(1); } if (argv != null) { // get ARGV on the stack CheckedPushNumber(NeededStackSize, argvTableRegistryIndex); - var loadRes = state.GetTable(LuaRegistry.Index); - Debug.Assert(loadRes == LuaType.Table, "Unexpected type for ARGV"); + CheckedGetTable(LuaType.Table, (int)LuaRegistry.Index); for (var i = 0; i < argv.Length; i++) { @@ -1045,10 +1044,10 @@ void LoadParametersForRunner(string[] keys, string[] argv) var arg = argv[i]; PrepareString(arg, scratchBufferManager, out var encoded); CheckedPushBuffer(NeededStackSize, encoded); - state.RawSetInteger(1, i + 1); + CheckedRawSetInteger(1, i + 1); } - state.Pop(1); + CheckedPop(1); } AssertLuaStackEmpty(); @@ -1084,17 +1083,16 @@ unsafe void RunCommon(ref TResponse resp) ForceGrowLuaStack(NeededStackSize); CheckedPushNumber(NeededStackSize, functionRegistryIndex); - var loadRes = state.GetTable(LuaRegistry.Index); - Debug.Assert(loadRes == LuaType.Function, "Unexpected type for function to invoke"); + CheckedGetTable(LuaType.Function, (int)LuaRegistry.Index); - var callRes = state.PCall(0, 1, 0); + var callRes = CheckedPCall(0, 1); if (callRes == LuaStatus.OK) { // The actual call worked, handle the response - if (state.GetTop() == 0) + if (curStackTop == 0) { - WriteNull(state, 0, ref resp); + WriteNull(this, ref resp); return; } @@ -1103,22 +1101,22 @@ unsafe void RunCommon(ref TResponse resp) if (isNullish) { - WriteNull(state, 1, ref resp); + WriteNull(this, ref resp); return; } else if (retType == LuaType.Number) { - WriteNumber(state, 1, ref resp); + WriteNumber(this, ref resp); return; } else if (retType == LuaType.String) { - WriteString(state, 1, ref resp); + WriteString(this, ref resp); return; } else if (retType == LuaType.Boolean) { - WriteBoolean(state, 1, ref resp); + WriteBoolean(this, ref resp); return; } else if (retType == LuaType.Table) @@ -1130,38 +1128,36 @@ unsafe void RunCommon(ref TResponse resp) // If the key err is in there, we need to short circuit CheckedPushConstantString(NeededStackSize, errConstStringRegistryIndex); - var errType = state.GetTable(1); + var errType = CheckedGetTable(null, 1); if (errType == LuaType.String) { - WriteError(state, 2, ref resp); + WriteError(this, ref resp); // Remove table from stack - state.Pop(1); + CheckedPop(1); return; } // Remove whatever we read from the table under the "err" key - state.Pop(1); + CheckedPop(1); // Map this table to an array - var maxStackDepth = NeededStackSize; - WriteArray(state, 1, ref resp, ref maxStackDepth); + WriteArray(this, ref resp); } } else { // An error was raised - var stackTop = state.GetTop(); - if (stackTop == 0) + if (curStackTop == 0) { while (!RespWriteUtils.WriteError("ERR An error occurred while invoking a Lua script"u8, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); return; } - else if (stackTop == 1) + else if (curStackTop == 1) { if (NativeMethods.CheckBuffer(state.Handle, 1, out var errBuf)) { @@ -1169,18 +1165,19 @@ unsafe void RunCommon(ref TResponse resp) resp.SendAndReset(); } - state.Pop(1); + CheckedPop(1); return; } else { - logger?.LogError("Got an unexpected number of values back from a pcall error {stackTop} {callRes}", stackTop, callRes); + logger?.LogError("Got an unexpected number of values back from a pcall error {callRes}", callRes); while (!RespWriteUtils.WriteError("ERR Unexpected error response"u8, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); - state.Pop(stackTop); + state.SetTop(0); + curStackTop = 0; return; } @@ -1192,60 +1189,54 @@ unsafe void RunCommon(ref TResponse resp) } // Write a null RESP value, remove the top value on the stack if there is one - static void WriteNull(Lua state, int top, ref TResponse resp) + static void WriteNull(LuaRunner runner, ref TResponse resp) { - Debug.Assert(state.GetTop() == top, "Lua stack was not expected size"); - while (!RespWriteUtils.WriteNull(ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); // The stack _could_ be empty if we're writing a null, so check before popping - if (top != 0) + if (runner.curStackTop != 0) { - state.Pop(1); + runner.CheckedPop(1); } } // Writes the number on the top of the stack, removes it from the stack - static void WriteNumber(Lua state, int top, ref TResponse resp) + static void WriteNumber(LuaRunner runner, ref TResponse resp) { - Debug.Assert(state.GetTop() == top, "Lua stack was not expected size"); - Debug.Assert(state.Type(top) == LuaType.Number, "Number was not on top of stack"); + Debug.Assert(runner.state.Type(runner.curStackTop) == LuaType.Number, "Number was not on top of stack"); // Redis unconditionally converts all "number" replies to integer replies so we match that // // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion - var num = (long)state.CheckNumber(top); + var num = (long)runner.state.CheckNumber(runner.curStackTop); while (!RespWriteUtils.WriteInteger(num, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); - state.Pop(1); + runner.CheckedPop(1); } // Writes the string on the top of the stack, removes it from the stack - static void WriteString(Lua state, int top, ref TResponse resp) + static void WriteString(LuaRunner runner, ref TResponse resp) { - Debug.Assert(state.GetTop() == top, "Lua stack was not expected size"); - - NativeMethods.KnownStringToBuffer(state.Handle, top, out var buf); + NativeMethods.KnownStringToBuffer(runner.state.Handle, runner.curStackTop, out var buf); while (!RespWriteUtils.WriteBulkString(buf, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); - state.Pop(1); + runner.CheckedPop(1); } // Writes the boolean on the top of the stack, removes it from the stack - static void WriteBoolean(Lua state, int top, ref TResponse resp) + static void WriteBoolean(LuaRunner runner, ref TResponse resp) { - Debug.Assert(state.GetTop() == top, "Lua stack was not expected size"); - Debug.Assert(state.Type(top) == LuaType.Boolean, "Boolean was not on top of stack"); + Debug.Assert(runner.state.Type(runner.curStackTop) == LuaType.Boolean, "Boolean was not on top of stack"); // Redis maps Lua false to null, and Lua true to 1 this is strange, but documented // // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion - if (state.ToBoolean(top)) + if (runner.state.ToBoolean(runner.curStackTop)) { while (!RespWriteUtils.WriteInteger(1, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); @@ -1256,40 +1247,38 @@ static void WriteBoolean(Lua state, int top, ref TResponse resp) resp.SendAndReset(); } - state.Pop(1); + runner.CheckedPop(1); } // Writes the string on the top of the stack out as an error, removes the string from the stack - static void WriteError(Lua state, int top, ref TResponse resp) + static void WriteError(LuaRunner runner, ref TResponse resp) { - Debug.Assert(state.GetTop() == top, "Lua stack was not expected size"); - - NativeMethods.KnownStringToBuffer(state.Handle, top, out var errBuff); + NativeMethods.KnownStringToBuffer(runner.state.Handle, runner.curStackTop, out var errBuff); while (!RespWriteUtils.WriteError(errBuff, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); - state.Pop(1); + runner.CheckedPop(1); } - static void WriteArray(Lua state, int top, ref TResponse resp, ref int maxStackDepth) + static void WriteArray(LuaRunner runner, ref TResponse resp) { // 1 for the table, 1 for the pending value const int AdditonalNeededStackSize = 2; - Debug.Assert(state.GetTop() == top, "Lua stack was not expected size"); - Debug.Assert(state.Type(top) == LuaType.Table, "Table was not on top of stack"); + Debug.Assert(runner.state.Type(runner.curStackTop) == LuaType.Table, "Table was not on top of stack"); // Lua # operator - this MAY stop at nils, but isn't guaranteed to // See: https://www.lua.org/manual/5.3/manual.html#3.4.7 - var maxLen = state.Length(top); + var maxLen = runner.state.Length(runner.curStackTop); + // TODO: is it faster to punch a function in for this? // Find the TRUE length by scanning for nils var trueLen = 0; for (trueLen = 0; trueLen < maxLen; trueLen++) { - var type = state.GetInteger(top, trueLen + 1); - state.Pop(1); + var type = runner.CheckedGetInteger(null, runner.curStackTop, trueLen + 1); + runner.CheckedPop(1); if (type == LuaType.Nil) { @@ -1300,111 +1289,83 @@ static void WriteArray(Lua state, int top, ref TResponse resp, ref int maxStackD while (!RespWriteUtils.WriteArrayLength((int)trueLen, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); - var valueStackSlot = top + 1; - for (var i = 1; i <= trueLen; i++) { // Push item at index i onto the stack - var type = state.GetInteger(top, i); + var type = runner.CheckedGetInteger(null, runner.curStackTop, i); switch (type) { case LuaType.String: - WriteString(state, valueStackSlot, ref resp); + WriteString(runner, ref resp); break; case LuaType.Number: - WriteNumber(state, valueStackSlot, ref resp); + WriteNumber(runner, ref resp); break; case LuaType.Boolean: - WriteBoolean(state, valueStackSlot, ref resp); + WriteBoolean(runner, ref resp); break; case LuaType.Table: // For tables, we need to recurse - which means we need to check stack sizes again - if (maxStackDepth < valueStackSlot + AdditonalNeededStackSize) + if (runner.curStackSize < runner.curStackTop + AdditonalNeededStackSize) { try { - ForceGrowLuaStack(state, AdditonalNeededStackSize); - maxStackDepth += AdditonalNeededStackSize; + runner.ForceGrowLuaStack(AdditonalNeededStackSize); } catch { // This is the only place we can raise an exception, cull the Stack - state.SetTop(0); + runner.state.SetTop(0); + runner.curStackTop = 0; throw; } } - WriteArray(state, valueStackSlot, ref resp, ref maxStackDepth); + WriteArray(runner, ref resp); break; // All other Lua types map to nulls default: - WriteNull(state, valueStackSlot, ref resp); + WriteNull(runner, ref resp); break; } } - state.Pop(1); + runner.CheckedPop(1); } } - /// - /// Ensure there's enough space on the Lua stack for more items. - /// - /// Throws if there is not. - /// - /// Prefer using this to calling directly. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void ForceGrowLuaStack(int additionalCapacity) - => ForceGrowLuaStack(state, additionalCapacity); + // TODO: I think we'd prefer all these helpers factor into their own file /// /// Ensure there's enough space on the Lua stack for more items. /// /// Throws if there is not. /// - /// Prefer using this to calling directly. + /// Maintains to avoid unnecessary p/invokes. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void ForceGrowLuaStack(Lua state, int additionalCapacity) + private void ForceGrowLuaStack(int additionalCapacity) { - if (!state.CheckStack(additionalCapacity)) + var availableSpace = curStackSize - curStackTop; + + if (availableSpace >= additionalCapacity) { - throw new GarnetException("Could not reserve additional capacity on the Lua stack"); + return; } - } - /// - /// Check that the Lua stack is empty in DEBUG builds. - /// - /// This is never necessary for correctness, but is often useful to find logical bugs. - /// - [Conditional("DEBUG")] - [MethodImpl(MethodImplOptions.NoInlining)] - private void AssertLuaStackEmpty([CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) - { - Debug.Assert(state.GetTop() == 0, $"Lua stack not empty when expected ({method}:{line} in {file})"); - } + var needed = additionalCapacity - availableSpace; + if (!state.CheckStack(needed)) + { + throw new GarnetException("Could not reserve additional capacity on the Lua stack"); + } - /// - /// Check the Lua stack has not grown beyond the capacity we initially reserved. - /// - /// This asserts (in DEBUG) that the next .PushXXX will succeed. - /// - /// In practice, Lua almost always gives us enough space (default is ~20 slots) but that's not guaranteed and can be false - /// for complicated redis.call invocations. - /// - [Conditional("DEBUG")] - [MethodImpl(MethodImplOptions.NoInlining)] - private void AssertLuaStackBelow(int reservedCapacity, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) - { - Debug.Assert(state.GetTop() < reservedCapacity, $"About to push to Lua stack without having reserved sufficient capacity."); + curStackSize += additionalCapacity; } /// @@ -1419,6 +1380,7 @@ private void CheckedPushBuffer(int reservedCapacity, ReadOnlySpan buffer, AssertLuaStackBelow(reservedCapacity, file, method, line); NativeMethods.PushBuffer(state.Handle, buffer); + curStackTop++; } /// @@ -1431,6 +1393,7 @@ private void CheckedPushNil(int reservedCapacity, [CallerFilePath] string file = AssertLuaStackBelow(reservedCapacity, file, method, line); state.PushNil(); + curStackTop++; } /// @@ -1443,6 +1406,7 @@ private void CheckedPushNumber(int reservedCapacity, double number, [CallerFileP AssertLuaStackBelow(reservedCapacity, file, method, line); state.PushNumber(number); + curStackTop++; } /// @@ -1455,6 +1419,202 @@ private void CheckedPushBoolean(int reservedCapacity, bool b, [CallerFilePath] s AssertLuaStackBelow(reservedCapacity, file, method, line); state.PushBoolean(b); + curStackTop++; + } + + /// + /// This should be used for all Pop calls into Lua. + /// + /// Maintains to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void CheckedPop(int num) + { + state.Pop(num); + curStackTop -= num; + + AssertLuaStackExpected(); + } + + /// + /// This should be used for all Calls into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void CheckedCall(int args, int rets) + { + var oldStackTop = curStackTop; + state.Call(args, rets); + + if (rets < 0) + { + curStackTop = state.GetTop(); + } + else + { + curStackTop = oldStackTop - (args + 1) + rets; + } + + AssertLuaStackExpected(); + } + + /// + /// This should be used for all PCalls into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private LuaStatus CheckedPCall(int args, int rets) + { + var oldStack = curStackTop; + var res = state.PCall(args, rets, 0); + + if (res != LuaStatus.OK || rets < 0) + { + curStackTop = state.GetTop(); + } + else + { + curStackTop = oldStack - (args + 1) + rets; + } + + AssertLuaStackExpected(); + + return res; + } + + /// + /// This should be used for all RawSetIntegers into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void CheckedRawSetInteger(int stackIndex, int tableIndex) + { + state.RawSetInteger(stackIndex, tableIndex); + curStackTop--; + + AssertLuaStackExpected(); + } + + /// This should be used for all RawSets into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void CheckedRawSet(int stackIndex) + { + state.RawSet(stackIndex); + curStackTop -= 2; + + AssertLuaStackExpected(); + } + + /// + /// This should be used for all RawGetIntegers into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void CheckedRawGetInteger(LuaType expectedType, int stackIndex, int tableIndex) + { + var actual = state.RawGetInteger(stackIndex, tableIndex); + Debug.Assert(actual == expectedType, "Unexpected type received"); + curStackTop++; + + AssertLuaStackExpected(); + } + + /// + /// This should be used for all GetIntegers into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private LuaType CheckedGetInteger(LuaType? expectedType, int stackIndex, int tableIndex) + { + var actual = state.GetInteger(stackIndex, tableIndex); + Debug.Assert(expectedType == null || actual == expectedType, "Unexpected type received"); + curStackTop++; + + AssertLuaStackExpected(); + + return actual; + } + + /// + /// This should be used for all RawGets into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void CheckedRawGet(LuaType expectedType, int stackIndex) + { + var actual = state.RawGet(stackIndex); + Debug.Assert(actual == expectedType, "Unexpected type received"); + + AssertLuaStackExpected(); + } + + /// + /// This should be used for all GetTables into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private LuaType CheckedGetTable(LuaType? expectedType, int stackIndex) + { + var actual = state.GetTable(stackIndex); + Debug.Assert(expectedType == null || actual == expectedType, "Unexpected type received"); + + AssertLuaStackExpected(); + + return actual; + } + + /// + /// This should be used for all Refs into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private int CheckedRef() + { + var ret = state.Ref(LuaRegistry.Index); + curStackTop--; + + AssertLuaStackExpected(); + + return ret; + } + + /// + /// This should be used for all CreateTables into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + private void CheckedCreateTable(int numArr, int numRec) + { + state.CreateTable(numArr, numRec); + curStackTop++; + + AssertLuaStackExpected(); + } + + /// + /// This should be used for all GetGlobals into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + private void CheckedGetGlobal(LuaType expectedType, string globalName) + { + var type = state.GetGlobal(globalName); + Debug.Assert(type == expectedType, "Unexpected type received"); + + curStackTop++; + + AssertLuaStackExpected(); } /// @@ -1469,8 +1629,7 @@ private void CheckedPushConstantString(int reservedCapacity, int constStringRegi AssertLuaStackBelow(reservedCapacity, file, method, line); Debug.Assert(IsConstantStringRegistryIndex(constStringRegistryIndex), "Can't use this with unknown string"); - var loadRes = state.RawGetInteger(LuaRegistry.Index, constStringRegistryIndex); - Debug.Assert(loadRes == LuaType.String, "Expected constant string to be loaded on stack"); + CheckedRawGetInteger(LuaType.String, (int)LuaRegistry.Index, constStringRegistryIndex); // Check if index corresponds to value registered in constructor bool IsConstantStringRegistryIndex(int index) @@ -1482,5 +1641,42 @@ bool IsConstantStringRegistryIndex(int index) index == errUnknownConstStringRegistryIndex || index == errBadArgConstStringRegistryIndex; } + + /// + /// Check that the Lua stack is empty in DEBUG builds. + /// + /// This is never necessary for correctness, but is often useful to find logical bugs. + /// + [Conditional("DEBUG")] + [MethodImpl(MethodImplOptions.NoInlining)] + private void AssertLuaStackEmpty([CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) + { + Debug.Assert(state.GetTop() == 0, $"Lua stack not empty when expected ({method}:{line} in {file})"); + } + + /// + /// Check that the Lua stack top is where expected in DEBUG builds. + /// + [Conditional("DEBUG")] + [MethodImpl(MethodImplOptions.NoInlining)] + private void AssertLuaStackExpected([CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) + { + Debug.Assert(state.GetTop() == curStackTop, $"Lua stack not where expected ({method}:{line} in {file})"); + } + + /// + /// Check the Lua stack has not grown beyond the capacity we initially reserved. + /// + /// This asserts (in DEBUG) that the next .PushXXX will succeed. + /// + /// In practice, Lua almost always gives us enough space (default is ~20 slots) but that's not guaranteed and can be false + /// for complicated redis.call invocations. + /// + [Conditional("DEBUG")] + [MethodImpl(MethodImplOptions.NoInlining)] + private void AssertLuaStackBelow(int reservedCapacity, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) + { + Debug.Assert(state.GetTop() < reservedCapacity, $"About to push to Lua stack without having reserved sufficient capacity."); + } } } \ No newline at end of file From 00765b9d846a5fa13847eace4f4620d216ebd98c Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Thu, 12 Dec 2024 17:31:00 -0500 Subject: [PATCH 38/51] remove more allocs --- .../BDN.benchmark/Lua/LuaRunnerOperations.cs | 18 +- libs/server/Lua/LuaRunner.cs | 195 +++++++++++------- libs/server/Lua/NativeMethods.cs | 21 +- 3 files changed, 148 insertions(+), 86 deletions(-) diff --git a/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs b/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs index 223d6cceb1..27ecf2467d 100644 --- a/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs +++ b/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs @@ -150,6 +150,9 @@ public IEnumerable LuaParamsProvider() private LuaRunner paramsRunner; + private LuaRunner smallCompileRunner; + private LuaRunner largeCompileRunner; + [GlobalSetup] public void GlobalSetup() { @@ -157,6 +160,9 @@ public void GlobalSetup() session = server.GetRespSession(); paramsRunner = new LuaRunner("return nil"); + + smallCompileRunner = new LuaRunner(SmallScript); + largeCompileRunner = new LuaRunner(LargeScript); } [GlobalCleanup] @@ -202,19 +208,15 @@ public void ConstructLarge() [Benchmark] public void CompileForSessionSmall() { - using (var runner = new LuaRunner(SmallScript)) - { - runner.CompileForSession(session); - } + smallCompileRunner.ResetCompilation(); + smallCompileRunner.CompileForSession(session); } [Benchmark] public void CompileForSessionLarge() { - using (var runner = new LuaRunner(LargeScript)) - { - runner.CompileForSession(session); - } + largeCompileRunner.ResetCompilation(); + largeCompileRunner.CompileForSession(session); } } } \ No newline at end of file diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 6f732e8c4a..3053f08bb1 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -128,6 +128,82 @@ public void SendAndReset() } } + const string LoaderBlock = @" +import = function () end +redis = {} +function redis.call(...) + return garnet_call(...) +end +function redis.status_reply(text) + return text +end +function redis.error_reply(text) + return { err = 'ERR ' .. text } +end +KEYS = {} +ARGV = {} +sandbox_env = { + tostring = tostring; + next = next; + assert = assert; + tonumber = tonumber; + rawequal = rawequal; + collectgarbage = collectgarbage; + coroutine = coroutine; + type = type; + select = select; + unpack = table.unpack; + gcinfo = gcinfo; + pairs = pairs; + loadstring = loadstring; + ipairs = ipairs; + error = error; + redis = redis; + math = math; + table = table; + string = string; + KEYS = KEYS; + ARGV = ARGV; +} +-- do resets in the Lua side to minimize pinvokes +function reset_keys_and_argv(fromKey, fromArgv) + local keyCount = #KEYS + for i = fromKey, keyCount do + KEYS[i] = nil + end + + local argvCount = #ARGV + for i = fromArgv, argvCount do + ARGV[i] = nil + end +end +-- responsible for sandboxing user provided code +function load_sandboxed(source) + if (not source) then return nil end + local rawFunc, err = load(source, nil, nil, sandbox_env) + + -- compilation error is returned directly + if err then + return rawFunc, err + end + + -- otherwise we wrap the compiled function in a helper + return function() + local rawRet = rawFunc() + + -- handle ok response wrappers without crossing the pinvoke boundary + -- err response wrapper requires a bit more work, but is also rarer + if rawRet and type(rawRet) == ""table"" and rawRet.ok then + return rawRet.ok + end + + return rawRet + end +end +"; + + private static readonly ReadOnlyMemory LoaderBlockBytes = Encoding.UTF8.GetBytes(LoaderBlock); + // Rooted to keep function pointer alive readonly LuaFunction garnetCall; @@ -187,9 +263,10 @@ public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSe // TODO: custom allocator? state = new Lua(); AssertLuaStackEmpty(); - curStackTop = 0; - ForceGrowLuaStack(NeededStackSize); + ForceGrowLuaStack(20); + + curStackSize = 20; if (txnMode) { @@ -202,80 +279,14 @@ public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSe garnetCall = garnet_call; } - var sandboxRes = state.DoString(@" - import = function () end - redis = {} - function redis.call(...) - return garnet_call(...) - end - function redis.status_reply(text) - return text - end - function redis.error_reply(text) - return { err = 'ERR ' .. text } - end - KEYS = {} - ARGV = {} - sandbox_env = { - tostring = tostring; - next = next; - assert = assert; - tonumber = tonumber; - rawequal = rawequal; - collectgarbage = collectgarbage; - coroutine = coroutine; - type = type; - select = select; - unpack = table.unpack; - gcinfo = gcinfo; - pairs = pairs; - loadstring = loadstring; - ipairs = ipairs; - error = error; - redis = redis; - math = math; - table = table; - string = string; - KEYS = KEYS; - ARGV = ARGV; - } - -- do resets in the Lua side to minimize pinvokes - function reset_keys_and_argv(fromKey, fromArgv) - local keyCount = #KEYS - for i = fromKey, keyCount do - KEYS[i] = nil - end - - local argvCount = #ARGV - for i = fromArgv, argvCount do - ARGV[i] = nil - end - end - -- responsible for sandboxing user provided code - function load_sandboxed(source) - if (not source) then return nil end - local rawFunc, err = load(source, nil, nil, sandbox_env) - - -- compilation error is returned directly - if err then - return rawFunc, err - end - - -- otherwise we wrap the compiled function in a helper - return function() - local rawRet = rawFunc() - - -- handle ok response wrappers without crossing the pinvoke boundary - -- err response wrapper requires a bit more work, but is also rarer - if rawRet and type(rawRet) == ""table"" and rawRet.ok then - return rawRet.ok - end - - return rawRet - end - end - "); - if (sandboxRes) + var loadRes = CheckedLoadBuffer(LoaderBlockBytes.Span); + if (loadRes != LuaStatus.OK) + { + throw new GarnetException("Could load loader into Lua"); + } + + var sandboxRes = CheckedPCall(0, -1); + if (sandboxRes != LuaStatus.OK) { throw new GarnetException("Could not initialize Lua sandbox state"); } @@ -362,6 +373,18 @@ public void CompileForSession(RespServerSession session) CompileCommon(ref adapter); } + /// + /// Drops compiled function, just for benchmarking purposes. + /// + public void ResetCompilation() + { + if (functionRegistryIndex != -1) + { + state.Unref(LuaRegistry.Index, functionRegistryIndex); + functionRegistryIndex = -1; + } + } + /// /// Compile script, writing errors out to given response. /// @@ -1594,6 +1617,7 @@ private int CheckedRef() /// /// Maintains and to minimize p/invoke calls. /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] private void CheckedCreateTable(int numArr, int numRec) { state.CreateTable(numArr, numRec); @@ -1617,6 +1641,25 @@ private void CheckedGetGlobal(LuaType expectedType, string globalName) AssertLuaStackExpected(); } + /// + /// This should be used for all LoadBuffers into Lua. + /// + /// Note that this is different from pushing a buffer, as the loaded buffer is compiled. + /// + /// Maintains and to minimize p/invoke calls. + /// + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private LuaStatus CheckedLoadBuffer(ReadOnlySpan buffer) + { + var ret = NativeMethods.LoadBuffer(state.Handle, buffer); + curStackTop++; + + AssertLuaStackExpected(); + + return ret; + } + /// /// This should be used to push all known constants strings (registered in constructor with ) /// into Lua. diff --git a/libs/server/Lua/NativeMethods.cs b/libs/server/Lua/NativeMethods.cs index 484e82e5b8..702cfe05e7 100644 --- a/libs/server/Lua/NativeMethods.cs +++ b/libs/server/Lua/NativeMethods.cs @@ -7,10 +7,8 @@ using KeraLua; using static System.Net.WebRequestMethods; using charptr_t = System.IntPtr; -using lua_Integer = System.Int64; using lua_State = System.IntPtr; using size_t = System.UIntPtr; -using voidptr_t = System.IntPtr; namespace Garnet.server { @@ -45,6 +43,12 @@ internal static class NativeMethods [DllImport(LuaLibraryName, CallingConvention = CallingConvention.Cdecl)] private static extern charptr_t lua_pushlstring(lua_State L, charptr_t s, size_t len); + /// + /// see: https://www.lua.org/manual/5.3/manual.html#luaL_loadbufferx + /// + [DllImport(LuaLibraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern LuaStatus luaL_loadbufferx(lua_State luaState, charptr_t buff, size_t sz, charptr_t name, charptr_t mode); + /// /// Returns true if the given index on the stack holds a string or a number. /// @@ -104,5 +108,18 @@ internal static unsafe void PushBuffer(lua_State luaState, ReadOnlySpan st lua_pushlstring(luaState, (charptr_t)ptr, (size_t)str.Length); } } + + /// + /// Push given span to stack, and compiles it. + /// + /// Provided data is copied, and can be reused once this call returns. + /// + internal static unsafe LuaStatus LoadBuffer(lua_State luaState, ReadOnlySpan str) + { + fixed (byte* ptr = str) + { + return luaL_loadbufferx(luaState, (charptr_t)ptr, (size_t)str.Length, (charptr_t)UIntPtr.Zero, (charptr_t)UIntPtr.Zero); + } + } } } From e94813ae1f116dd275323d7a6dc5b4c0f262286f Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Fri, 13 Dec 2024 10:03:40 -0500 Subject: [PATCH 39/51] remove more allocs --- libs/server/Lua/LuaCommands.cs | 33 ++++++++++++++++--------- libs/server/Lua/SessionScriptCache.cs | 35 ++++++++++++++++++--------- 2 files changed, 45 insertions(+), 23 deletions(-) diff --git a/libs/server/Lua/LuaCommands.cs b/libs/server/Lua/LuaCommands.cs index 370782c1e7..2f1c70e82c 100644 --- a/libs/server/Lua/LuaCommands.cs +++ b/libs/server/Lua/LuaCommands.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. using System; +using System.Collections.Generic; using Garnet.common; using Microsoft.Extensions.Logging; using Tsavorite.core; @@ -36,7 +37,7 @@ private unsafe bool TryEVALSHA() { if (storeWrapper.storeScriptCache.TryGetValue(digestAsSpanByteMem, out var source)) { - if (!sessionScriptCache.TryLoad(this, source, digestAsSpanByteMem, out runner, out var error)) + if (!sessionScriptCache.TryLoad(this, source, digestAsSpanByteMem, out runner, out _, out var error)) { // TryLoad will have written an error out, it any @@ -77,13 +78,13 @@ private unsafe bool TryEVAL() return AbortWithWrongNumberOfArguments("EVAL"); } - var script = parseState.GetArgSliceByRef(0).ToArray(); + ref var script = ref parseState.GetArgSliceByRef(0); // that this is stack allocated is load bearing - if it moves, things will break Span digest = stackalloc byte[SessionScriptCache.SHA1Len]; - sessionScriptCache.GetScriptDigest(script, digest); + sessionScriptCache.GetScriptDigest(script.ReadOnlySpan, digest); - if (!sessionScriptCache.TryLoad(this, script, new SpanByteAndMemory(SpanByte.FromPinnedSpan(digest)), out var runner, out var error)) + if (!sessionScriptCache.TryLoad(this, script.ReadOnlySpan, new SpanByteAndMemory(SpanByte.FromPinnedSpan(digest)), out var runner, out _, out var error)) { // TryLoad will have written any errors out return true; @@ -194,17 +195,27 @@ private bool NetworkScriptLoad() return AbortWithWrongNumberOfArguments("script|load"); } - var source = parseState.GetArgSliceByRef(0).ToArray(); - if (!sessionScriptCache.TryLoad(this, source, out var digest, out _, out var error)) + ref var source = ref parseState.GetArgSliceByRef(0); + + Span digest = stackalloc byte[SessionScriptCache.SHA1Len]; + sessionScriptCache.GetScriptDigest(source.Span, digest); + + if (sessionScriptCache.TryLoad(this, source.ReadOnlySpan, SpanByteAndMemory.FromPinnedSpan(digest), out _, out var digestOnHeap, out var error)) { // TryLoad will write any errors out - } - else - { // Add script to the store dictionary - var scriptKey = new SpanByteAndMemory(new ScriptHashOwner(digest.AsMemory()), digest.Length); - _ = storeWrapper.storeScriptCache.TryAdd(scriptKey, source); + if (digestOnHeap == null) + { + var newAlloc = new SpanByteAndMemory(new ScriptHashOwner(digest.ToArray())); + _ = storeWrapper.storeScriptCache.TryAdd(newAlloc, source.ToArray()); + } + else + { + _ = storeWrapper.storeScriptCache.TryAdd(digestOnHeap.Value, source.ToArray()); + } + + while (!RespWriteUtils.WriteBulkString(digest, ref dcurr, dend)) SendAndReset(); diff --git a/libs/server/Lua/SessionScriptCache.cs b/libs/server/Lua/SessionScriptCache.cs index 616433e689..eaf1e26083 100644 --- a/libs/server/Lua/SessionScriptCache.cs +++ b/libs/server/Lua/SessionScriptCache.cs @@ -56,26 +56,32 @@ public bool TryGetFromDigest(SpanByteAndMemory digest, out LuaRunner scriptRunne => scriptCache.TryGetValue(digest, out scriptRunner); /// - /// Load script into the cache + /// Load script into the cache. + /// + /// If necessary, will be set so the allocation can be reused. /// - public bool TryLoad(RespServerSession session, byte[] source, out byte[] digest, out LuaRunner runner, out string error) - { - digest = new byte[SHA1Len]; - GetScriptDigest(source, digest); - - return TryLoad(session, source, new SpanByteAndMemory(new ScriptHashOwner(digest), digest.Length), out runner, out error); - } - - internal bool TryLoad(RespServerSession session, byte[] source, SpanByteAndMemory digest, out LuaRunner runner, out string error) + internal bool TryLoad( + RespServerSession session, + ReadOnlySpan source, + SpanByteAndMemory digest, + out LuaRunner runner, + out SpanByteAndMemory? digestOnHeap, + out string error + ) { error = null; if (scriptCache.TryGetValue(digest, out runner)) + { + digestOnHeap = null; return true; + } try { - runner = new LuaRunner(source, storeWrapper.serverOptions.LuaTransactionMode, processor, scratchBufferNetworkSender, logger); + var sourceOnHeap = source.ToArray(); + + runner = new LuaRunner(sourceOnHeap, storeWrapper.serverOptions.LuaTransactionMode, processor, scratchBufferNetworkSender, logger); runner.CompileForSession(session); // need to make sure the key is on the heap, so move it over if needed @@ -85,7 +91,11 @@ internal bool TryLoad(RespServerSession session, byte[] source, SpanByteAndMemor var into = new byte[storeKeyDigest.Length]; storeKeyDigest.AsReadOnlySpan().CopyTo(into); - storeKeyDigest = new SpanByteAndMemory(new ScriptHashOwner(into), into.Length); + digestOnHeap = storeKeyDigest = new SpanByteAndMemory(new ScriptHashOwner(into), into.Length); + } + else + { + digestOnHeap = digest; } _ = scriptCache.TryAdd(storeKeyDigest, runner); @@ -93,6 +103,7 @@ internal bool TryLoad(RespServerSession session, byte[] source, SpanByteAndMemor catch (Exception ex) { error = ex.Message; + digestOnHeap = null; return false; } From 159bfb3f485ad0d1e8d2367226ddac1348a0cdf2 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Fri, 13 Dec 2024 10:39:03 -0500 Subject: [PATCH 40/51] add a benchmark for script operations --- .../Lua/LuaScriptCacheOperations.cs | 154 ++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs diff --git a/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs b/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs new file mode 100644 index 0000000000..1469c91796 --- /dev/null +++ b/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs @@ -0,0 +1,154 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using BenchmarkDotNet.Attributes; +using Embedded.perftest; +using Garnet.common; +using Garnet.server; +using Garnet.server.Auth; +using Tsavorite.core; + +namespace BDN.benchmark.Lua +{ + [MemoryDiagnoser] + public class LuaScriptCacheOperations + { + /// + /// Lua parameters + /// + [ParamsSource(nameof(LuaParamsProvider))] + public LuaParams Params { get; set; } + + /// + /// Lua parameters provider + /// + public IEnumerable LuaParamsProvider() + { + yield return new(); + } + + private EmbeddedRespServer server; + private StoreWrapper storeWrapper; + private SessionScriptCache sessionScriptCache; + private RespServerSession session; + + private byte[] outerHitDigest; + private byte[] innerHitDigest; + private byte[] missDigest; + + [GlobalSetup] + public void GlobalSetup() + { + server = new EmbeddedRespServer(new GarnetServerOptions() { EnableLua = true, QuietMode = true }); + storeWrapper = server.StoreWrapper; + sessionScriptCache = new SessionScriptCache(storeWrapper, new GarnetNoAuthAuthenticator()); + session = server.GetRespSession(); + + outerHitDigest = GC.AllocateUninitializedArray(SessionScriptCache.SHA1Len, pinned: true); + sessionScriptCache.GetScriptDigest("return 1"u8, outerHitDigest); + if (!storeWrapper.storeScriptCache.TryAdd(SpanByteAndMemory.FromPinnedSpan(outerHitDigest), "return 1"u8.ToArray())) + { + throw new InvalidOperationException("Should have been able to load into global cache"); + } + + innerHitDigest = GC.AllocateUninitializedArray(SessionScriptCache.SHA1Len, pinned: true); + sessionScriptCache.GetScriptDigest("return 1 + 1"u8, innerHitDigest); + if (!storeWrapper.storeScriptCache.TryAdd(SpanByteAndMemory.FromPinnedSpan(innerHitDigest), "return 1 + 1"u8.ToArray())) + { + throw new InvalidOperationException("Should have been able to load into global cache"); + } + + missDigest = GC.AllocateUninitializedArray(SessionScriptCache.SHA1Len, pinned: true); + sessionScriptCache.GetScriptDigest("foobar"u8, missDigest); + } + + [GlobalCleanup] + public void GlobalCleanup() + { + session?.Dispose(); + server?.Dispose(); + } + + [IterationSetup] + public void IterationSetup() + { + // Force lookup to do work + sessionScriptCache.Clear(); + + // Make outer hit available for every iteration + if (!sessionScriptCache.TryLoad(session, "return 1"u8, SpanByteAndMemory.FromPinnedSpan(outerHitDigest), out _, out _, out var error)) + { + throw new InvalidOperationException($"Should have been able to load: {error}"); + } + } + + [Benchmark] + public void LookupHit() + { + _ = sessionScriptCache.TryGetFromDigest(SpanByteAndMemory.FromPinnedSpan(outerHitDigest), out _); + } + + [Benchmark] + public void LookupMiss() + { + _ = sessionScriptCache.TryGetFromDigest(SpanByteAndMemory.FromPinnedSpan(missDigest), out _); + } + + [Benchmark] + public void LoadOuterHit() + { + // First if returns true + // + // This is the common case + LoadScript(outerHitDigest); + } + + [Benchmark] + public void LoadInnerHit() + { + // First if returns false, second if returns true + // + // This is expected, but rare + LoadScript(innerHitDigest); + } + + [Benchmark] + public void LoadMiss() + { + // First if returns false, second if returns false + // + // This is extremely unlikely, basically implies an error on the client + LoadScript(missDigest); + } + + [Benchmark] + public void Digest() + { + Span digest = stackalloc byte[SessionScriptCache.SHA1Len]; + sessionScriptCache.GetScriptDigest("return 1 + redis.call('GET', KEYS[1])"u8, digest); + } + + /// + /// The moral equivalent to our cache load operation. + /// + private void LoadScript(Span digest) + { + AsciiUtils.ToLowerInPlace(digest); + + var digestAsSpanByteMem = new SpanByteAndMemory(SpanByte.FromPinnedSpan(digest)); + + if (!sessionScriptCache.TryGetFromDigest(digestAsSpanByteMem, out var runner)) + { + if (storeWrapper.storeScriptCache.TryGetValue(digestAsSpanByteMem, out var source)) + { + if (!sessionScriptCache.TryLoad(session, source, digestAsSpanByteMem, out runner, out _, out var error)) + { + // TryLoad will have written an error out, it any + + _ = storeWrapper.storeScriptCache.TryRemove(digestAsSpanByteMem, out _); + } + } + } + } + } +} From 2a46b5b85f92def9633e1525220fed0bbf60e94a Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Fri, 13 Dec 2024 11:07:13 -0500 Subject: [PATCH 41/51] script lookup is in the hot path, so optimize the key type we're using a bit --- .../Lua/LuaScriptCacheOperations.cs | 20 +++--- libs/server/Lua/LuaCommands.cs | 48 +++++++++----- libs/server/Lua/ScriptHashKey.cs | 66 +++++++++++++++++++ libs/server/Lua/SessionScriptCache.cs | 33 +++++----- libs/server/Resp/SpanByteAndMemoryComparer.cs | 38 ----------- libs/server/StoreWrapper.cs | 8 ++- 6 files changed, 126 insertions(+), 87 deletions(-) create mode 100644 libs/server/Lua/ScriptHashKey.cs delete mode 100644 libs/server/Resp/SpanByteAndMemoryComparer.cs diff --git a/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs b/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs index 1469c91796..872c396232 100644 --- a/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs +++ b/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs @@ -46,14 +46,14 @@ public void GlobalSetup() outerHitDigest = GC.AllocateUninitializedArray(SessionScriptCache.SHA1Len, pinned: true); sessionScriptCache.GetScriptDigest("return 1"u8, outerHitDigest); - if (!storeWrapper.storeScriptCache.TryAdd(SpanByteAndMemory.FromPinnedSpan(outerHitDigest), "return 1"u8.ToArray())) + if (!storeWrapper.storeScriptCache.TryAdd(new(outerHitDigest), "return 1"u8.ToArray())) { throw new InvalidOperationException("Should have been able to load into global cache"); } innerHitDigest = GC.AllocateUninitializedArray(SessionScriptCache.SHA1Len, pinned: true); sessionScriptCache.GetScriptDigest("return 1 + 1"u8, innerHitDigest); - if (!storeWrapper.storeScriptCache.TryAdd(SpanByteAndMemory.FromPinnedSpan(innerHitDigest), "return 1 + 1"u8.ToArray())) + if (!storeWrapper.storeScriptCache.TryAdd(new(innerHitDigest), "return 1 + 1"u8.ToArray())) { throw new InvalidOperationException("Should have been able to load into global cache"); } @@ -76,7 +76,7 @@ public void IterationSetup() sessionScriptCache.Clear(); // Make outer hit available for every iteration - if (!sessionScriptCache.TryLoad(session, "return 1"u8, SpanByteAndMemory.FromPinnedSpan(outerHitDigest), out _, out _, out var error)) + if (!sessionScriptCache.TryLoad(session, "return 1"u8, new(outerHitDigest), out _, out _, out var error)) { throw new InvalidOperationException($"Should have been able to load: {error}"); } @@ -85,13 +85,13 @@ public void IterationSetup() [Benchmark] public void LookupHit() { - _ = sessionScriptCache.TryGetFromDigest(SpanByteAndMemory.FromPinnedSpan(outerHitDigest), out _); + _ = sessionScriptCache.TryGetFromDigest(new(outerHitDigest), out _); } [Benchmark] public void LookupMiss() { - _ = sessionScriptCache.TryGetFromDigest(SpanByteAndMemory.FromPinnedSpan(missDigest), out _); + _ = sessionScriptCache.TryGetFromDigest(new(missDigest), out _); } [Benchmark] @@ -135,17 +135,17 @@ private void LoadScript(Span digest) { AsciiUtils.ToLowerInPlace(digest); - var digestAsSpanByteMem = new SpanByteAndMemory(SpanByte.FromPinnedSpan(digest)); + var digestKey = new ScriptHashKey(digest); - if (!sessionScriptCache.TryGetFromDigest(digestAsSpanByteMem, out var runner)) + if (!sessionScriptCache.TryGetFromDigest(digestKey, out var runner)) { - if (storeWrapper.storeScriptCache.TryGetValue(digestAsSpanByteMem, out var source)) + if (storeWrapper.storeScriptCache.TryGetValue(digestKey, out var source)) { - if (!sessionScriptCache.TryLoad(session, source, digestAsSpanByteMem, out runner, out _, out var error)) + if (!sessionScriptCache.TryLoad(session, source, digestKey, out runner, out _, out var error)) { // TryLoad will have written an error out, it any - _ = storeWrapper.storeScriptCache.TryRemove(digestAsSpanByteMem, out _); + _ = storeWrapper.storeScriptCache.TryRemove(digestKey, out _); } } } diff --git a/libs/server/Lua/LuaCommands.cs b/libs/server/Lua/LuaCommands.cs index 2f1c70e82c..f9a4df47f8 100644 --- a/libs/server/Lua/LuaCommands.cs +++ b/libs/server/Lua/LuaCommands.cs @@ -29,20 +29,27 @@ private unsafe bool TryEVALSHA() } ref var digest = ref parseState.GetArgSliceByRef(0); - AsciiUtils.ToLowerInPlace(digest.Span); - var digestAsSpanByteMem = new SpanByteAndMemory(digest.SpanByte); + LuaRunner runner = null; - if (!sessionScriptCache.TryGetFromDigest(digestAsSpanByteMem, out var runner)) + // Length check is mandatory, as ScriptHashKey assumes correct length + if (digest.length == SessionScriptCache.SHA1Len) { - if (storeWrapper.storeScriptCache.TryGetValue(digestAsSpanByteMem, out var source)) + AsciiUtils.ToLowerInPlace(digest.Span); + + var scriptKey = new ScriptHashKey(digest.Span); + + if (!sessionScriptCache.TryGetFromDigest(scriptKey, out runner)) { - if (!sessionScriptCache.TryLoad(this, source, digestAsSpanByteMem, out runner, out _, out var error)) + if (storeWrapper.storeScriptCache.TryGetValue(scriptKey, out var source)) { - // TryLoad will have written an error out, it any + if (!sessionScriptCache.TryLoad(this, source, scriptKey, out runner, out _, out var error)) + { + // TryLoad will have written an error out, it any - _ = storeWrapper.storeScriptCache.TryRemove(digestAsSpanByteMem, out _); - return true; + _ = storeWrapper.storeScriptCache.TryRemove(scriptKey, out _); + return true; + } } } } @@ -84,7 +91,7 @@ private unsafe bool TryEVAL() Span digest = stackalloc byte[SessionScriptCache.SHA1Len]; sessionScriptCache.GetScriptDigest(script.ReadOnlySpan, digest); - if (!sessionScriptCache.TryLoad(this, script.ReadOnlySpan, new SpanByteAndMemory(SpanByte.FromPinnedSpan(digest)), out var runner, out _, out var error)) + if (!sessionScriptCache.TryLoad(this, script.ReadOnlySpan, new ScriptHashKey(digest), out var runner, out _, out var error)) { // TryLoad will have written any errors out return true; @@ -118,7 +125,7 @@ private bool NetworkScriptExists() return AbortWithWrongNumberOfArguments("script|exists"); } - // returns an array where each element is a 0 if the script does not exist, and a 1 if it does + // Returns an array where each element is a 0 if the script does not exist, and a 1 if it does while (!RespWriteUtils.WriteArrayLength(parseState.Count, ref dcurr, dend)) SendAndReset(); @@ -126,11 +133,17 @@ private bool NetworkScriptExists() for (var shaIx = 0; shaIx < parseState.Count; shaIx++) { ref var sha1 = ref parseState.GetArgSliceByRef(shaIx); - AsciiUtils.ToLowerInPlace(sha1.Span); + var exists = 0; - var sha1Arg = new SpanByteAndMemory(sha1.SpanByte); + // Length check is required, as ScriptHashKey makes a hard assumption + if (sha1.length == SessionScriptCache.SHA1Len) + { + AsciiUtils.ToLowerInPlace(sha1.Span); - var exists = storeWrapper.storeScriptCache.ContainsKey(sha1Arg) ? 1 : 0; + var sha1Arg = new ScriptHashKey(sha1.Span); + + exists = storeWrapper.storeScriptCache.ContainsKey(sha1Arg) ? 1 : 0; + } while (!RespWriteUtils.WriteArrayItem(exists, ref dcurr, dend)) SendAndReset(); @@ -200,23 +213,22 @@ private bool NetworkScriptLoad() Span digest = stackalloc byte[SessionScriptCache.SHA1Len]; sessionScriptCache.GetScriptDigest(source.Span, digest); - if (sessionScriptCache.TryLoad(this, source.ReadOnlySpan, SpanByteAndMemory.FromPinnedSpan(digest), out _, out var digestOnHeap, out var error)) + if (sessionScriptCache.TryLoad(this, source.ReadOnlySpan, new(digest), out _, out var digestOnHeap, out var error)) { // TryLoad will write any errors out // Add script to the store dictionary if (digestOnHeap == null) { - var newAlloc = new SpanByteAndMemory(new ScriptHashOwner(digest.ToArray())); - _ = storeWrapper.storeScriptCache.TryAdd(newAlloc, source.ToArray()); + var newAlloc = GC.AllocateUninitializedArray(SessionScriptCache.SHA1Len, pinned: true); + digest.CopyTo(newAlloc); + _ = storeWrapper.storeScriptCache.TryAdd(new(newAlloc), source.ToArray()); } else { _ = storeWrapper.storeScriptCache.TryAdd(digestOnHeap.Value, source.ToArray()); } - - while (!RespWriteUtils.WriteBulkString(digest, ref dcurr, dend)) SendAndReset(); } diff --git a/libs/server/Lua/ScriptHashKey.cs b/libs/server/Lua/ScriptHashKey.cs new file mode 100644 index 0000000000..adcc0a602d --- /dev/null +++ b/libs/server/Lua/ScriptHashKey.cs @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; + +namespace Garnet.server +{ + /// + /// Specialized key type for storing script hashes. + /// + public readonly struct ScriptHashKey : IEquatable + { + // Necessary to keep this alive + private readonly byte[] arrRef; + private readonly unsafe long* ptr; + + internal unsafe ScriptHashKey(ReadOnlySpan stackSpan) + { + Debug.Assert(stackSpan.Length == SessionScriptCache.SHA1Len, "Only one valid length for script hash keys"); + + ptr = (long*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(stackSpan)); + } + + internal unsafe ScriptHashKey(byte[] pohArr) + : this(pohArr.AsSpan()) + { + arrRef = pohArr; + } + + /// + /// Copy key data. + /// + public unsafe void CopyTo(Span into) + { + new Span(ptr, SessionScriptCache.SHA1Len).CopyTo(into); + } + + /// + public unsafe bool Equals(ScriptHashKey other) + { + Debug.Assert(SessionScriptCache.SHA1Len == 40, "Making a hard assumption that we're comparing 40 bytes"); + + var a = ptr; + var b = other.ptr; + + return + *(a++) == *(b++) && + *(a++) == *(b++) && + *(a++) == *(b++) && + *(a++) == *(b++) && + *(a++) == *(b++); + } + + /// + public override unsafe int GetHashCode() + => *(int*)ptr; + + /// + public override bool Equals([NotNullWhen(true)] object obj) + => obj is ScriptHashKey other && Equals(other); + } +} \ No newline at end of file diff --git a/libs/server/Lua/SessionScriptCache.cs b/libs/server/Lua/SessionScriptCache.cs index eaf1e26083..93499f84e1 100644 --- a/libs/server/Lua/SessionScriptCache.cs +++ b/libs/server/Lua/SessionScriptCache.cs @@ -25,7 +25,7 @@ internal sealed class SessionScriptCache : IDisposable readonly ScratchBufferNetworkSender scratchBufferNetworkSender; readonly StoreWrapper storeWrapper; readonly ILogger logger; - readonly Dictionary scriptCache = new(SpanByteAndMemoryComparer.Instance); + readonly Dictionary scriptCache = []; readonly byte[] hash = new byte[SHA1Len / 2]; public SessionScriptCache(StoreWrapper storeWrapper, IGarnetAuthenticator authenticator, ILogger logger = null) @@ -52,7 +52,7 @@ public void SetUser(User user) /// /// Try get script runner for given digest /// - public bool TryGetFromDigest(SpanByteAndMemory digest, out LuaRunner scriptRunner) + public bool TryGetFromDigest(ScriptHashKey digest, out LuaRunner scriptRunner) => scriptCache.TryGetValue(digest, out scriptRunner); /// @@ -63,9 +63,9 @@ public bool TryGetFromDigest(SpanByteAndMemory digest, out LuaRunner scriptRunne internal bool TryLoad( RespServerSession session, ReadOnlySpan source, - SpanByteAndMemory digest, + ScriptHashKey digest, out LuaRunner runner, - out SpanByteAndMemory? digestOnHeap, + out ScriptHashKey? digestOnHeap, out string error ) { @@ -84,20 +84,17 @@ out string error runner = new LuaRunner(sourceOnHeap, storeWrapper.serverOptions.LuaTransactionMode, processor, scratchBufferNetworkSender, logger); runner.CompileForSession(session); - // need to make sure the key is on the heap, so move it over if needed - var storeKeyDigest = digest; - if (storeKeyDigest.IsSpanByte) - { - var into = new byte[storeKeyDigest.Length]; - storeKeyDigest.AsReadOnlySpan().CopyTo(into); - - digestOnHeap = storeKeyDigest = new SpanByteAndMemory(new ScriptHashOwner(into), into.Length); - } - else - { - digestOnHeap = digest; - } - + // Need to make sure the key is on the heap, so move it over + // + // There's an implicit assumption that all callers are using unmanaged memory. + // If that becomes untrue, there's an optimization opportunity to re-use the + // managed memory here. + var into = GC.AllocateUninitializedArray(SHA1Len, pinned: true); + digest.CopyTo(into); + + ScriptHashKey storeKeyDigest = new(into); + digestOnHeap = storeKeyDigest; + _ = scriptCache.TryAdd(storeKeyDigest, runner); } catch (Exception ex) diff --git a/libs/server/Resp/SpanByteAndMemoryComparer.cs b/libs/server/Resp/SpanByteAndMemoryComparer.cs deleted file mode 100644 index e3ecc4dded..0000000000 --- a/libs/server/Resp/SpanByteAndMemoryComparer.cs +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -using System; -using System.Collections.Generic; -using Tsavorite.core; - -namespace Garnet.server -{ - /// - /// equality comparer. - /// - public sealed class SpanByteAndMemoryComparer : IEqualityComparer - { - /// - /// The default instance. - /// - /// Used to avoid allocating new comparers. - public static readonly SpanByteAndMemoryComparer Instance = new(); - - private SpanByteAndMemoryComparer() { } - - /// - public bool Equals(SpanByteAndMemory left, SpanByteAndMemory right) - => left.AsReadOnlySpan().SequenceEqual(right.AsReadOnlySpan()); - - /// - public unsafe int GetHashCode(SpanByteAndMemory key) - { - var hash = new HashCode(); - hash.AddBytes(key.AsReadOnlySpan()); - - var ret = hash.ToHashCode(); - - return ret; - } - } -} \ No newline at end of file diff --git a/libs/server/StoreWrapper.cs b/libs/server/StoreWrapper.cs index 49b30eff39..428003f335 100644 --- a/libs/server/StoreWrapper.cs +++ b/libs/server/StoreWrapper.cs @@ -99,8 +99,10 @@ public sealed class StoreWrapper internal readonly string run_id; private SingleWriterMultiReaderLock _checkpointTaskLock; - // Lua script cache - public readonly ConcurrentDictionary storeScriptCache; + /// + /// Lua script cache + /// + public readonly ConcurrentDictionary storeScriptCache; public readonly TimeSpan loggingFrequncy; @@ -153,7 +155,7 @@ public StoreWrapper( // Initialize store scripting cache if (serverOptions.EnableLua) - this.storeScriptCache = new(SpanByteAndMemoryComparer.Instance); + this.storeScriptCache = []; if (accessControlList == null) { From dcfe9d2a795dfd74391cb2d420cc5df54f697d67 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Fri, 13 Dec 2024 13:25:17 -0500 Subject: [PATCH 42/51] expand ScriptOperations benchmark to actually invoke some functions and do some logic --- .../BDN.benchmark/Lua/LuaRunnerOperations.cs | 2 - .../Lua/LuaScriptCacheOperations.cs | 1 - benchmark/BDN.benchmark/Lua/LuaScripts.cs | 2 - .../Operations/OperationsBase.cs | 2 +- .../Operations/ScriptOperations.cs | 190 ++++++++++++++++++ libs/server/Lua/LuaRunner.cs | 1 + 6 files changed, 192 insertions(+), 6 deletions(-) diff --git a/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs b/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs index 27ecf2467d..ecd5002bfb 100644 --- a/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs +++ b/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs @@ -2,7 +2,6 @@ // Licensed under the MIT license. using BenchmarkDotNet.Attributes; -using BenchmarkDotNet.Columns; using Embedded.perftest; using Garnet.server; @@ -12,7 +11,6 @@ namespace BDN.benchmark.Lua /// Benchmark for non-script running operations in LuaRunner /// [MemoryDiagnoser] - [HideColumns(Column.Gen0)] public unsafe class LuaRunnerOperations { private const string SmallScript = "return nil"; diff --git a/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs b/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs index 872c396232..ca36c75870 100644 --- a/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs +++ b/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs @@ -6,7 +6,6 @@ using Garnet.common; using Garnet.server; using Garnet.server.Auth; -using Tsavorite.core; namespace BDN.benchmark.Lua { diff --git a/benchmark/BDN.benchmark/Lua/LuaScripts.cs b/benchmark/BDN.benchmark/Lua/LuaScripts.cs index e04510a4cb..3cff4b30e3 100644 --- a/benchmark/BDN.benchmark/Lua/LuaScripts.cs +++ b/benchmark/BDN.benchmark/Lua/LuaScripts.cs @@ -2,7 +2,6 @@ // Licensed under the MIT license. using BenchmarkDotNet.Attributes; -using BenchmarkDotNet.Columns; using Garnet.server; namespace BDN.benchmark.Lua @@ -11,7 +10,6 @@ namespace BDN.benchmark.Lua /// Benchmark for Lua /// [MemoryDiagnoser] - [HideColumns(Column.Gen0)] public unsafe class LuaScripts { /// diff --git a/benchmark/BDN.benchmark/Operations/OperationsBase.cs b/benchmark/BDN.benchmark/Operations/OperationsBase.cs index 8d58631fe9..be4fa09608 100644 --- a/benchmark/BDN.benchmark/Operations/OperationsBase.cs +++ b/benchmark/BDN.benchmark/Operations/OperationsBase.cs @@ -39,7 +39,7 @@ public IEnumerable OperationParamsProvider() /// 25 us = 4 Mops/sec /// 100 us = 1 Mops/sec /// - const int batchSize = 100; + internal const int batchSize = 100; internal EmbeddedRespServer server; internal RespServerSession session; diff --git a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs index f068a7c0ac..b544c14345 100644 --- a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs +++ b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. +using System.Runtime.CompilerServices; +using System.Security.Cryptography; +using System.Text; using BenchmarkDotNet.Attributes; namespace BDN.benchmark.Operations @@ -11,6 +14,124 @@ namespace BDN.benchmark.Operations [MemoryDiagnoser] public unsafe class ScriptOperations : OperationsBase { + // Small script that does 1 operation and no logic + const string SmallScriptText = @"return redis.call('GET', KEYS[1]);"; + + // Large script that does 2 operations and lots of logic + const string LargeScriptText = @" +-- based on a real UDF, with some extraneous ops removed + +local userKey = KEYS[1] +local identifier = ARGV[1] +local currentTime = ARGV[2] +local newSession = ARGV[3] + +local userKeyValue = redis.call(""GET"", userKey) + +local updatedValue = nil +local returnValue = -1 + +if userKeyValue then +-- needs to be updated + + local oldestEntry = nil + local oldestEntryUpdateTime = nil + + local match = nil + + local entryCount = 0 + +-- loop over each entry, looking for one to update + for entry in string.gmatch(userKeyValue, ""([^%|]+)"") do + entryCount = entryCount + 1 + + local entryIdentifier = nil + local entrySessionNumber = -1 + local entryRequestCount = -1 + local entryLastSessionUpdateTime = -1 + + local ix = 0 + for part in string.gmatch(entry, ""([^:]+)"") do + if ix == 0 then + entryIdentifier = part + elseif ix == 1 then + entrySessionNumber = tonumber(part) + elseif ix == 2 then + entryRequestCount = tonumber(part) + elseif ix == 3 then + entryLastSessionUpdateTime = tonumber(part) + else +-- malformed, too many parts + return -1 + end + + ix = ix + 1 + end + + if ix ~= 4 then +-- malformed, too few parts + return -2 + end + + if entryIdentifier == identifier then +-- found the one to update + local updatedEntry = nil + + if tonumber(newSession) == 1 then + local updatedSessionNumber = entrySessionNumber + 1 + updatedEntry = entryIdentifier .. "":"" .. tostring(updatedSessionNumber) .. "":1:"" .. tostring(currentTime) + returnValue = 3 + else + local updatedRequestCount = entryRequestCount + 1 + updatedEntry = entryIdentifier .. "":"" .. tostring(entrySessionNumber) .. "":"" .. tostring(updatedRequestCount) .. "":"" .. tostring(currentTime) + returnValue = 2 + end + +-- have to escape the replacement, since Lua doesn't have a literal replace :/ + local escapedEntry = string.gsub(entry, ""%p"", ""%%%1"") + updatedValue = string.gsub(userKeyValue, escapedEntry, updatedEntry) + + break + end + + if oldestEntryUpdateTime == nil or oldestEntryUpdateTime > entryLastSessionUpdateTime then +-- remember the oldest entry, so we can replace it if needed + oldestEntry = entry + oldestEntryUpdateTime = entryLastSessionUpdateTime + end + end + + if updatedValue == nil then +-- we didn't update an existing value, so we need to add it + + local newEntry = identifier .. "":1:1:"" .. tostring(currentTime) + + if entryCount < 20 then +-- there's room, just append it + updatedValue = userKeyValue .. ""|"" .. newEntry + returnValue = 4 + else +-- there isn't room, replace the LRU entry + +-- have to escape the replacement, since Lua doesn't have a literal replace :/ + local escapedOldestEntry = string.gsub(oldestEntry, ""%p"", ""%%%1"") + + updatedValue = string.gsub(userKeyValue, escapedOldestEntry, newEntry) + returnValue = 5 + end + end +else +-- needs to be created + updatedValue = identifier .. "":1:1:"" .. tostring(currentTime) + + returnValue = 1 +end + +redis.call(""SET"", userKey, updatedValue) + +return returnValue +"; + static ReadOnlySpan SCRIPT_LOAD => "*3\r\n$6\r\nSCRIPT\r\n$4\r\nLOAD\r\n$8\r\nreturn 1\r\n"u8; byte[] scriptLoadRequestBuffer; byte* scriptLoadRequestBufferPointer; @@ -33,6 +154,12 @@ public unsafe class ScriptOperations : OperationsBase byte[] evalShaRequestBuffer; byte* evalShaRequestBufferPointer; + byte[] evalShaSmallScriptBuffer; + byte* evalShaSmallScriptBufferPointer; + + byte[] evalShaLargeScriptBuffer; + byte* evalShaLargeScriptBufferPointer; + public override void GlobalSetup() { base.GlobalSetup(); @@ -50,6 +177,57 @@ public override void GlobalSetup() SetupOperation(ref evalRequestBuffer, ref evalRequestBufferPointer, EVAL); SetupOperation(ref evalShaRequestBuffer, ref evalShaRequestBufferPointer, EVALSHA); + + // Setup small script + var loadSmallScript = $"*3\r\n$6\r\nSCRIPT\r\n$4\r\nLOAD\r\n${SmallScriptText.Length}\r\n{SmallScriptText}\r\n"; + var loadSmallScriptBytes = Encoding.UTF8.GetBytes(loadSmallScript); + fixed (byte* loadPtr = loadSmallScriptBytes) + { + _ = session.TryConsumeMessages(loadPtr, loadSmallScriptBytes.Length); + } + + var smallScriptHash = string.Join("", SHA1.HashData(Encoding.UTF8.GetBytes(SmallScriptText)).Select(static x => x.ToString("x2"))); + var evalShaSmallScript = $"*4\r\n$7\r\nEVALSHA\r\n$40\r\n{smallScriptHash}\r\n$1\r\n1\r\n$3\r\nfoo\r\n"; + evalShaSmallScriptBuffer = GC.AllocateUninitializedArray(evalShaSmallScript.Length * batchSize, pinned: true); + for(var i = 0; i < batchSize; i++) + { + var start = i * evalShaSmallScript.Length; + Encoding.UTF8.GetBytes(evalShaSmallScript, evalShaSmallScriptBuffer.AsSpan().Slice(start, evalShaSmallScript.Length)); + } + evalShaSmallScriptBufferPointer = (byte*)Unsafe.AsPointer(ref evalShaSmallScriptBuffer[0]); + + // Setup large script + var loadLargeScript = $"*3\r\n$6\r\nSCRIPT\r\n$4\r\nLOAD\r\n${LargeScriptText.Length}\r\n{LargeScriptText}\r\n"; + var loadLargeScriptBytes = Encoding.UTF8.GetBytes(loadLargeScript); + fixed (byte* loadPtr = loadLargeScriptBytes) + { + _ = session.TryConsumeMessages(loadPtr, loadLargeScriptBytes.Length); + } + + var largeScriptHash = string.Join("", SHA1.HashData(Encoding.UTF8.GetBytes(LargeScriptText)).Select(static x => x.ToString("x2"))); + var largeScriptEvals = new List(); + for (var i = 0; i < batchSize; i++) + { + var evalShaLargeScript = $"*7\r\n$7\r\nEVALSHA\r\n$40\r\n{largeScriptHash}\r\n$1\r\n1\r\n$5\r\nhello\r\n"; + + var id = Guid.NewGuid().ToString(); + evalShaLargeScript += $"${id.Length}\r\n"; + evalShaLargeScript += $"{id}\r\n"; + + var time = (i * 10).ToString(); + evalShaLargeScript += $"${time.Length}\r\n"; + evalShaLargeScript += $"{time}\r\n"; + + var newSession = i % 2; + evalShaLargeScript += "$1\r\n"; + evalShaLargeScript += $"{newSession}\r\n"; + + var asBytes = Encoding.UTF8.GetBytes(evalShaLargeScript); + largeScriptEvals.AddRange(asBytes); + } + evalShaLargeScriptBuffer = GC.AllocateUninitializedArray(largeScriptEvals.Count, pinned: true); + largeScriptEvals.CopyTo(evalShaLargeScriptBuffer); + evalShaLargeScriptBufferPointer = (byte*)Unsafe.AsPointer(ref evalShaLargeScriptBuffer[0]); } [Benchmark] @@ -81,5 +259,17 @@ public void EvalSha() { _ = session.TryConsumeMessages(evalShaRequestBufferPointer, evalShaRequestBuffer.Length); } + + [Benchmark] + public void SmallScript() + { + _ = session.TryConsumeMessages(evalShaSmallScriptBufferPointer, evalShaSmallScriptBuffer.Length); + } + + [Benchmark] + public void LargeScript() + { + _ = session.TryConsumeMessages(evalShaLargeScriptBufferPointer, evalShaLargeScriptBuffer.Length); + } } } \ No newline at end of file diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 3053f08bb1..08b8c84a08 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -940,6 +940,7 @@ static object MapRespToObject(ref byte* cur, byte* end) if (length >= 5 && new ReadOnlySpan(cur + 1, 4).SequenceEqual("-1\r\n"u8)) { + cur += 5; return null; } From 4cdfa974b1cd8c2248eac2065b5a9bce96afca05 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Fri, 13 Dec 2024 15:09:57 -0500 Subject: [PATCH 43/51] huge cleanup; move all the Lua interop into a dedicated class, normalize stack checking, normalize assertions --- libs/server/Lua/LuaRunner.cs | 708 ++++++----------------------- libs/server/Lua/LuaStateWrapper.cs | 552 ++++++++++++++++++++++ 2 files changed, 698 insertions(+), 562 deletions(-) create mode 100644 libs/server/Lua/LuaStateWrapper.cs diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 08b8c84a08..a27b8ae15c 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -2,6 +2,7 @@ // Licensed under the MIT license. using System; +using System.ComponentModel.DataAnnotations; using System.Diagnostics; using System.Linq; using System.Runtime.CompilerServices; @@ -232,21 +233,19 @@ function load_sandboxed(source) readonly ScratchBufferManager scratchBufferManager; readonly ILogger logger; - readonly Lua state; readonly TxnKeyEntries txnKeyEntries; readonly bool txnMode; - int keyLength, argvLength; + // This cannot be readonly, as it is a mutable struct + LuaStateWrapper state; - int curStackSize, curStackTop; + int keyLength, argvLength; /// /// Creates a new runner with the source of the script /// public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null) { - const int NeededStackSize = 1; - this.source = source; this.txnMode = txnMode; this.respServerSession = respServerSession; @@ -261,13 +260,8 @@ public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSe functionRegistryIndex = -1; // TODO: custom allocator? - state = new Lua(); - AssertLuaStackEmpty(); - - ForceGrowLuaStack(20); - - curStackSize = 20; - + state = new LuaStateWrapper(new Lua()); + if (txnMode) { txnKeyEntries = new TxnKeyEntries(16, respServerSession.storageSession.lockableContext, respServerSession.storageSession.objectStoreLockableContext); @@ -279,13 +273,13 @@ public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSe garnetCall = garnet_call; } - var loadRes = CheckedLoadBuffer(LoaderBlockBytes.Span); + var loadRes = state.LoadBuffer(LoaderBlockBytes.Span); if (loadRes != LuaStatus.OK) { throw new GarnetException("Could load loader into Lua"); } - var sandboxRes = CheckedPCall(0, -1); + var sandboxRes = state.PCall(0, -1); if (sandboxRes != LuaStatus.OK) { throw new GarnetException("Could not initialize Lua sandbox state"); @@ -294,31 +288,31 @@ public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSe // Register garnet_call in global namespace state.Register("garnet_call", garnetCall); - CheckedGetGlobal(LuaType.Table, "sandbox_env"); - sandboxEnvRegistryIndex = CheckedRef(); + state.GetGlobal(LuaType.Table, "sandbox_env"); + sandboxEnvRegistryIndex = state.Ref(); - CheckedGetGlobal(LuaType.Table, "KEYS"); - keysTableRegistryIndex = CheckedRef(); + state.GetGlobal(LuaType.Table, "KEYS"); + keysTableRegistryIndex = state.Ref(); - CheckedGetGlobal(LuaType.Table, "ARGV"); - argvTableRegistryIndex = CheckedRef(); + state.GetGlobal(LuaType.Table, "ARGV"); + argvTableRegistryIndex = state.Ref(); - CheckedGetGlobal(LuaType.Function, "load_sandboxed"); - loadSandboxedRegistryIndex = CheckedRef(); + state.GetGlobal(LuaType.Function, "load_sandboxed"); + loadSandboxedRegistryIndex = state.Ref(); - CheckedGetGlobal(LuaType.Function, "reset_keys_and_argv"); - resetKeysAndArgvRegistryIndex = CheckedRef(); + state.GetGlobal(LuaType.Function, "reset_keys_and_argv"); + resetKeysAndArgvRegistryIndex = state.Ref(); // Commonly used strings, register them once so we don't have to copy them over each time we need them - okConstStringRegisteryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.LUA_OK); - errConstStringRegistryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.LUA_err); - noSessionAvailableConstStringRegisteryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.LUA_No_session_available); - pleaseSpecifyRedisCallConstStringRegistryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.LUA_ERR_Please_specify_at_least_one_argument_for_this_redis_lib_call); - errNoAuthConstStringRegistryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.RESP_ERR_NOAUTH); - errUnknownConstStringRegistryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.LUA_ERR_Unknown_Redis_command_called_from_script); - errBadArgConstStringRegistryIndex = ConstantStringToRegistery(NeededStackSize, CmdStrings.LUA_ERR_Lua_redis_lib_command_arguments_must_be_strings_or_integers); - - AssertLuaStackEmpty(); + okConstStringRegisteryIndex = ConstantStringToRegistery(CmdStrings.LUA_OK); + errConstStringRegistryIndex = ConstantStringToRegistery(CmdStrings.LUA_err); + noSessionAvailableConstStringRegisteryIndex = ConstantStringToRegistery(CmdStrings.LUA_No_session_available); + pleaseSpecifyRedisCallConstStringRegistryIndex = ConstantStringToRegistery(CmdStrings.LUA_ERR_Please_specify_at_least_one_argument_for_this_redis_lib_call); + errNoAuthConstStringRegistryIndex = ConstantStringToRegistery(CmdStrings.RESP_ERR_NOAUTH); + errUnknownConstStringRegistryIndex = ConstantStringToRegistery(CmdStrings.LUA_ERR_Unknown_Redis_command_called_from_script); + errBadArgConstStringRegistryIndex = ConstantStringToRegistery(CmdStrings.LUA_ERR_Lua_redis_lib_command_arguments_must_be_strings_or_integers); + + state.ExpectLuaStackEmpty(); } /// @@ -334,12 +328,10 @@ public LuaRunner(string source, bool txnMode = false, RespServerSession respServ /// /// So instead we stash them in the Registry and load them by index /// - int ConstantStringToRegistery(int top, ReadOnlySpan str) + int ConstantStringToRegistery(ReadOnlySpan str) { - AssertLuaStackEmpty(); - - CheckedPushBuffer(top, str); - return CheckedRef(); + state.PushBuffer(str); + return state.Ref(); } /// @@ -395,21 +387,19 @@ unsafe void CompileCommon(ref TResponse resp) Debug.Assert(functionRegistryIndex == -1, "Shouldn't compile multiple times"); - AssertLuaStackEmpty(); - curStackTop = 0; + state.ExpectLuaStackEmpty(); try { - ForceGrowLuaStack(NeededStackSpace); + state.ForceMinimumStackCapacity(NeededStackSpace); - CheckedPushNumber(NeededStackSpace, loadSandboxedRegistryIndex); - CheckedGetTable(LuaType.Function, (int)LuaRegistry.Index); + state.PushNumber(loadSandboxedRegistryIndex); + state.GetTable(LuaType.Function, (int)LuaRegistry.Index); - CheckedPushBuffer(NeededStackSpace, source.Span); - CheckedCall(1, -1); // Multiple returns allowed + state.PushBuffer(source.Span); + state.Call(1, -1); // Multiple returns allowed - var numRets = state.GetTop(); - curStackTop = numRets; + var numRets = state.StackTop; if (numRets == 0) { @@ -430,23 +420,23 @@ unsafe void CompileCommon(ref TResponse resp) return; } - functionRegistryIndex = CheckedRef(); + functionRegistryIndex = state.Ref(); } else if (numRets == 2) { - NativeMethods.CheckBuffer(state.Handle, 2, out var errorBuf); + state.CheckBuffer(2, out var errorBuf); var errStr = $"Compilation error: {Encoding.UTF8.GetString(errorBuf)}"; while (!RespWriteUtils.WriteError(errStr, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); - CheckedPop(2); + state.Pop(2); return; } else { - CheckedPop(numRets); + state.Pop(numRets); throw new GarnetException($"Unexpected error compiling, got too many replies back: reply count = {numRets}"); } @@ -458,7 +448,7 @@ unsafe void CompileCommon(ref TResponse resp) } finally { - AssertLuaStackEmpty(); + state.ExpectLuaStackEmpty(); } } @@ -467,7 +457,7 @@ unsafe void CompileCommon(ref TResponse resp) /// public void Dispose() { - state?.Dispose(); + state.Dispose(); } /// @@ -475,7 +465,7 @@ public void Dispose() /// public int garnet_call(IntPtr luaStatePtr) { - Debug.Assert(state.Handle == luaStatePtr, "Unexpected state provided in call"); + state.CallFromLuaEntered(luaStatePtr); if (respServerSession == null) { @@ -490,7 +480,7 @@ public int garnet_call(IntPtr luaStatePtr) /// public int garnet_call_txn(IntPtr luaStatePtr) { - Debug.Assert(state.Handle == luaStatePtr, "Unexpected state provided in call"); + state.CallFromLuaEntered(luaStatePtr); if (respServerSession == null) { @@ -509,9 +499,9 @@ int NoSessionResponse() { const int NeededStackSpace = 1; - ForceGrowLuaStack(NeededStackSpace); + state.ForceMinimumStackCapacity(NeededStackSpace); - CheckedPushNil(NeededStackSpace); + state.PushNil(); return 1; } @@ -523,25 +513,20 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) { const int AdditionalStackSpace = 1; - // This is LUA_MINSTACK, which is 20 - curStackSize = 20; - curStackTop = state.GetTop(); - try { - var argCount = curStackTop; + var argCount = state.StackTop; if (argCount == 0) { - return LuaStaticError(argCount, pleaseSpecifyRedisCallConstStringRegistryIndex); + return LuaStaticError(pleaseSpecifyRedisCallConstStringRegistryIndex); } - ForceGrowLuaStack(AdditionalStackSpace); - var neededStackSpace = argCount + AdditionalStackSpace; + state.ForceMinimumStackCapacity(AdditionalStackSpace); - if (!NativeMethods.CheckBuffer(state.Handle, 1, out var cmdSpan)) + if (!state.CheckBuffer(1, out var cmdSpan)) { - return LuaStaticError(neededStackSpace, errBadArgConstStringRegistryIndex); + return LuaStaticError(errBadArgConstStringRegistryIndex); } // We special-case a few performance-sensitive operations to directly invoke via the storage API @@ -549,12 +534,12 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) { if (!respServerSession.CheckACLPermissions(RespCommand.SET)) { - return LuaStaticError(neededStackSpace, errNoAuthConstStringRegistryIndex); + return LuaStaticError(errNoAuthConstStringRegistryIndex); } - if (!NativeMethods.CheckBuffer(state.Handle, 2, out var keySpan) || !NativeMethods.CheckBuffer(state.Handle, 3, out var valSpan)) + if (!state.CheckBuffer(2, out var keySpan) || !state.CheckBuffer(3, out var valSpan)) { - return LuaStaticError(neededStackSpace, errBadArgConstStringRegistryIndex); + return LuaStaticError(errBadArgConstStringRegistryIndex); } // Note these spans are implicitly pinned, as they're actually on the Lua stack @@ -563,19 +548,19 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) _ = api.SET(key, value); - CheckedPushConstantString(neededStackSpace, okConstStringRegisteryIndex); + state.PushConstantString(okConstStringRegisteryIndex); return 1; } else if (AsciiUtils.EqualsUpperCaseSpanIgnoringCase(cmdSpan, "GET"u8) && argCount == 2) { if (!respServerSession.CheckACLPermissions(RespCommand.GET)) { - return LuaStaticError(neededStackSpace, errNoAuthConstStringRegistryIndex); + return LuaStaticError(errNoAuthConstStringRegistryIndex); } - if (!NativeMethods.CheckBuffer(state.Handle, 2, out var keySpan)) + if (!state.CheckBuffer(2, out var keySpan)) { - return LuaStaticError(neededStackSpace, errBadArgConstStringRegistryIndex); + return LuaStaticError(errBadArgConstStringRegistryIndex); } // Span is (implicitly) pinned since it's actually on the Lua stack @@ -583,11 +568,11 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) var status = api.GET(key, out var value); if (status == GarnetStatus.OK) { - CheckedPushBuffer(neededStackSpace, value.ReadOnlySpan); + state.PushBuffer(value.ReadOnlySpan); } else { - CheckedPushNil(neededStackSpace); + state.PushNil(); } return 1; @@ -613,14 +598,14 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) // KnownStringToBuffer will coerce a number into a string // // Redis nominally converts numbers to integers, but in this case just ToStrings things - NativeMethods.KnownStringToBuffer(state.Handle, argIx, out var span); + state.KnownStringToBuffer(argIx, out var span); // Span remains pinned so long as we don't pop the stack scratchBufferManager.WriteArgument(span); } else { - return LuaStaticError(neededStackSpace, errBadArgConstStringRegistryIndex); + return LuaStaticError(errBadArgConstStringRegistryIndex); } } @@ -629,7 +614,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) // Once the request is formatted, we can release all the args on the Lua stack // // This keeps the stack size down for processing the response - CheckedPop(argCount); + state.Pop(argCount); _ = respServerSession.TryConsumeMessages(request.ptr, request.length); @@ -642,30 +627,21 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) { logger?.LogError(e, "During Lua script execution"); - // Clear the stack - state.SetTop(0); - curStackTop = 0; - - ForceGrowLuaStack(1); - - // TODO: Remove alloc - var b = Encoding.UTF8.GetBytes(e.Message); - CheckedPushBuffer(AdditionalStackSpace, b); - return state.Error(); + return state.RaiseError(e.Message); } } /// /// Cause a Lua error to be raised with a message previously registered. /// - int LuaStaticError(int top, int constStringRegistryIndex) + int LuaStaticError(int constStringRegistryIndex) { const int NeededStackSize = 1; - ForceGrowLuaStack(NeededStackSize); + state.ForceMinimumStackCapacity(NeededStackSize); - CheckedPushConstantString(top + NeededStackSize, constStringRegistryIndex); - return state.Error(); + state.PushConstantString(constStringRegistryIndex); + return state.RaiseErrorFromStack(); } /// @@ -677,9 +653,7 @@ unsafe int ProcessResponse(byte* ptr, int length) { const int NeededStackSize = 3; - AssertLuaStackEmpty(); - - ForceGrowLuaStack(NeededStackSize); + state.ForceMinimumStackCapacity(NeededStackSize); switch (*ptr) { @@ -688,7 +662,7 @@ unsafe int ProcessResponse(byte* ptr, int length) length--; if (RespReadUtils.ReadAsSpan(out var resultSpan, ref ptr, ptr + length)) { - CheckedPushBuffer(NeededStackSize, resultSpan); + state.PushBuffer(resultSpan); return 1; } goto default; @@ -696,7 +670,7 @@ unsafe int ProcessResponse(byte* ptr, int length) case (byte)':': if (RespReadUtils.Read64Int(out var number, ref ptr, ptr + length)) { - CheckedPushNumber(NeededStackSize, number); + state.PushNumber(number); return 1; } goto default; @@ -709,11 +683,11 @@ unsafe int ProcessResponse(byte* ptr, int length) if (errSpan.SequenceEqual(CmdStrings.RESP_ERR_GENERIC_UNK_CMD)) { // Gets a special response - return LuaStaticError(NeededStackSize, errUnknownConstStringRegistryIndex); + return LuaStaticError(errUnknownConstStringRegistryIndex); } - CheckedPushBuffer(NeededStackSize, errSpan); - return state.Error(); + state.PushBuffer(errSpan); + return state.RaiseErrorFromStack(); } goto default; @@ -723,13 +697,13 @@ unsafe int ProcessResponse(byte* ptr, int length) { // Bulk null strings are mapped to FALSE // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion - CheckedPushBoolean(NeededStackSize, false); + state.PushBoolean(false); return 1; } else if (RespReadUtils.ReadSpanWithLengthHeader(out var bulkSpan, ref ptr, ptr + length)) { - CheckedPushBuffer(NeededStackSize, bulkSpan); + state.PushBuffer(bulkSpan); return 1; } @@ -739,7 +713,7 @@ unsafe int ProcessResponse(byte* ptr, int length) if (RespReadUtils.ReadUnsignedArrayLength(out var itemCount, ref ptr, ptr + length)) { // Create the new table - CheckedCreateTable(itemCount, 0); + state.CreateTable(itemCount, 0); for (var itemIx = 0; itemIx < itemCount; itemIx++) { @@ -750,16 +724,16 @@ unsafe int ProcessResponse(byte* ptr, int length) { // Null strings are mapped to false // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion - CheckedPushBoolean(NeededStackSize, false); + state.PushBoolean(false); } else if (RespReadUtils.ReadSpanWithLengthHeader(out var strSpan, ref ptr, ptr + length)) { - CheckedPushBuffer(NeededStackSize, strSpan); + state.PushBuffer(strSpan); } else { // Error, drop the table we allocated - CheckedPop(1); + state.Pop(1); goto default; } } @@ -771,7 +745,7 @@ unsafe int ProcessResponse(byte* ptr, int length) } // Stack now has table and value at itemIx on it - CheckedRawSetInteger(1, itemIx + 1); + state.RawSetInteger(1, itemIx + 1); } return 1; @@ -792,9 +766,7 @@ public void RunForSession(int count, RespServerSession outerSession) { const int NeededStackSize = 3; - AssertLuaStackEmpty(); - - ForceGrowLuaStack(NeededStackSize); + state.ForceMinimumStackCapacity(NeededStackSize); scratchBufferManager.Reset(); @@ -808,8 +780,8 @@ public void RunForSession(int count, RespServerSession outerSession) if (nKeys > 0) { // Get KEYS on the stack - CheckedPushNumber(NeededStackSize, keysTableRegistryIndex); - CheckedRawGet(LuaType.Table, (int)LuaRegistry.Index); + state.PushNumber(keysTableRegistryIndex); + state.RawGet(LuaType.Table, (int)LuaRegistry.Index); for (var i = 0; i < nKeys; i++) { @@ -823,15 +795,15 @@ public void RunForSession(int count, RespServerSession outerSession) } // Equivalent to KEYS[i+1] = key - CheckedPushNumber(NeededStackSize, i + 1); - CheckedPushBuffer(NeededStackSize, key.ReadOnlySpan); - CheckedRawSet(1); + state.PushNumber(i + 1); + state.PushBuffer(key.ReadOnlySpan); + state.RawSet(1); offset++; } // Remove KEYS from the stack - CheckedPop(1); + state.Pop(1); count -= nKeys; } @@ -839,27 +811,25 @@ public void RunForSession(int count, RespServerSession outerSession) if (count > 0) { // Get ARGV on the stack - CheckedPushNumber(NeededStackSize, argvTableRegistryIndex); - CheckedRawGet(LuaType.Table, (int)LuaRegistry.Index); + state.PushNumber(argvTableRegistryIndex); + state.RawGet(LuaType.Table, (int)LuaRegistry.Index); for (var i = 0; i < count; i++) { ref var argv = ref parseState.GetArgSliceByRef(offset); // Equivalent to ARGV[i+1] = argv - CheckedPushNumber(NeededStackSize, i + 1); - CheckedPushBuffer(NeededStackSize, argv.ReadOnlySpan); - CheckedRawSet(1); + state.PushNumber(i + 1); + state.PushBuffer(argv.ReadOnlySpan); + state.RawSet(1); offset++; } // Remove ARGV from the stack - CheckedPop(1); + state.Pop(1); } - AssertLuaStackEmpty(); - var adapter = new RespResponseAdapter(outerSession); if (txnMode && nKeys > 0) @@ -1004,25 +974,21 @@ internal void ResetParameters(int nKeys, int nArgs) { const int NeededStackSize = 3; - AssertLuaStackEmpty(); - - ForceGrowLuaStack(NeededStackSize); + state.ForceMinimumStackCapacity(NeededStackSize); if (keyLength > nKeys || argvLength > nArgs) { - CheckedRawGetInteger(LuaType.Function, (int)LuaRegistry.Index, resetKeysAndArgvRegistryIndex); + state.RawGetInteger(LuaType.Function, (int)LuaRegistry.Index, resetKeysAndArgvRegistryIndex); - CheckedPushNumber(NeededStackSize, nKeys + 1); - CheckedPushNumber(NeededStackSize, nArgs + 1); + state.PushNumber(nKeys + 1); + state.PushNumber(nArgs + 1); - var resetRes = CheckedPCall(2, 0); + var resetRes = state.PCall(2, 0); Debug.Assert(resetRes == LuaStatus.OK, "Resetting should never fail"); } keyLength = nKeys; argvLength = nArgs; - - AssertLuaStackEmpty(); } /// @@ -1032,50 +998,46 @@ void LoadParametersForRunner(string[] keys, string[] argv) { const int NeededStackSize = 2; - AssertLuaStackEmpty(); - - ForceGrowLuaStack(NeededStackSize); + state.ForceMinimumStackCapacity(NeededStackSize); ResetParameters(keys?.Length ?? 0, argv?.Length ?? 0); if (keys != null) { // get KEYS on the stack - CheckedPushNumber(NeededStackSize, keysTableRegistryIndex); - CheckedGetTable(LuaType.Table, (int)LuaRegistry.Index); + state.PushNumber(keysTableRegistryIndex); + state.GetTable(LuaType.Table, (int)LuaRegistry.Index); for (var i = 0; i < keys.Length; i++) { // equivalent to KEYS[i+1] = keys[i] var key = keys[i]; PrepareString(key, scratchBufferManager, out var encoded); - CheckedPushBuffer(NeededStackSize, encoded); - CheckedRawSetInteger(1, i + 1); + state.PushBuffer(encoded); + state.RawSetInteger(1, i + 1); } - CheckedPop(1); + state.Pop(1); } if (argv != null) { // get ARGV on the stack - CheckedPushNumber(NeededStackSize, argvTableRegistryIndex); - CheckedGetTable(LuaType.Table, (int)LuaRegistry.Index); + state.PushNumber(argvTableRegistryIndex); + state.GetTable(LuaType.Table, (int)LuaRegistry.Index); for (var i = 0; i < argv.Length; i++) { // equivalent to ARGV[i+1] = keys[i] var arg = argv[i]; PrepareString(arg, scratchBufferManager, out var encoded); - CheckedPushBuffer(NeededStackSize, encoded); - CheckedRawSetInteger(1, i + 1); + state.PushBuffer(encoded); + state.RawSetInteger(1, i + 1); } - CheckedPop(1); + state.Pop(1); } - AssertLuaStackEmpty(); - static void PrepareString(string raw, ScratchBufferManager buffer, out ReadOnlySpan strBytes) { var maxLen = Encoding.UTF8.GetMaxByteCount(raw.Length); @@ -1100,21 +1062,19 @@ unsafe void RunCommon(ref TResponse resp) // TODO: mapping is dependent on Resp2 vs Resp3 settings // and that's not implemented at all - AssertLuaStackEmpty(); - try { - ForceGrowLuaStack(NeededStackSize); + state.ForceMinimumStackCapacity(NeededStackSize); - CheckedPushNumber(NeededStackSize, functionRegistryIndex); - CheckedGetTable(LuaType.Function, (int)LuaRegistry.Index); + state.PushNumber(functionRegistryIndex); + _ = state.GetTable(LuaType.Function, (int)LuaRegistry.Index); - var callRes = CheckedPCall(0, 1); + var callRes = state.PCall(0, 1); if (callRes == LuaStatus.OK) { // The actual call worked, handle the response - if (curStackTop == 0) + if (state.StackTop == 0) { WriteNull(this, ref resp); return; @@ -1150,21 +1110,21 @@ unsafe void RunCommon(ref TResponse resp) // so we need a test that use metatables (and compare to how Redis does this) // If the key err is in there, we need to short circuit - CheckedPushConstantString(NeededStackSize, errConstStringRegistryIndex); + state.PushConstantString(errConstStringRegistryIndex); - var errType = CheckedGetTable(null, 1); + var errType =state.GetTable(null, 1); if (errType == LuaType.String) { WriteError(this, ref resp); // Remove table from stack - CheckedPop(1); + state.Pop(1); return; } // Remove whatever we read from the table under the "err" key - CheckedPop(1); + state.Pop(1); // Map this table to an array WriteArray(this, ref resp); @@ -1174,22 +1134,22 @@ unsafe void RunCommon(ref TResponse resp) { // An error was raised - if (curStackTop == 0) + if (state.StackTop == 0) { while (!RespWriteUtils.WriteError("ERR An error occurred while invoking a Lua script"u8, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); return; } - else if (curStackTop == 1) + else if (state.StackTop == 1) { - if (NativeMethods.CheckBuffer(state.Handle, 1, out var errBuf)) + if (state.CheckBuffer(1, out var errBuf)) { while (!RespWriteUtils.WriteError(errBuf, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); } - CheckedPop(1); + state.Pop(1); return; } @@ -1200,8 +1160,7 @@ unsafe void RunCommon(ref TResponse resp) while (!RespWriteUtils.WriteError("ERR Unexpected error response"u8, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); - state.SetTop(0); - curStackTop = 0; + state.ClearStack(); return; } @@ -1209,7 +1168,7 @@ unsafe void RunCommon(ref TResponse resp) } finally { - AssertLuaStackEmpty(); + state.ExpectLuaStackEmpty(); } // Write a null RESP value, remove the top value on the stack if there is one @@ -1219,48 +1178,48 @@ static void WriteNull(LuaRunner runner, ref TResponse resp) resp.SendAndReset(); // The stack _could_ be empty if we're writing a null, so check before popping - if (runner.curStackTop != 0) + if (runner.state.StackTop != 0) { - runner.CheckedPop(1); + runner.state.Pop(1); } } // Writes the number on the top of the stack, removes it from the stack static void WriteNumber(LuaRunner runner, ref TResponse resp) { - Debug.Assert(runner.state.Type(runner.curStackTop) == LuaType.Number, "Number was not on top of stack"); + Debug.Assert(runner.state.Type(runner.state.StackTop) == LuaType.Number, "Number was not on top of stack"); // Redis unconditionally converts all "number" replies to integer replies so we match that // // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion - var num = (long)runner.state.CheckNumber(runner.curStackTop); + var num = (long)runner.state.CheckNumber(runner.state.StackTop); while (!RespWriteUtils.WriteInteger(num, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); - runner.CheckedPop(1); + runner.state.Pop(1); } // Writes the string on the top of the stack, removes it from the stack static void WriteString(LuaRunner runner, ref TResponse resp) { - NativeMethods.KnownStringToBuffer(runner.state.Handle, runner.curStackTop, out var buf); + runner.state.KnownStringToBuffer(runner.state.StackTop, out var buf); while (!RespWriteUtils.WriteBulkString(buf, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); - runner.CheckedPop(1); + runner.state.Pop(1); } // Writes the boolean on the top of the stack, removes it from the stack static void WriteBoolean(LuaRunner runner, ref TResponse resp) { - Debug.Assert(runner.state.Type(runner.curStackTop) == LuaType.Boolean, "Boolean was not on top of stack"); + Debug.Assert(runner.state.Type(runner.state.StackTop) == LuaType.Boolean, "Boolean was not on top of stack"); // Redis maps Lua false to null, and Lua true to 1 this is strange, but documented // // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion - if (runner.state.ToBoolean(runner.curStackTop)) + if (runner.state.ToBoolean(runner.state.StackTop)) { while (!RespWriteUtils.WriteInteger(1, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); @@ -1271,18 +1230,18 @@ static void WriteBoolean(LuaRunner runner, ref TResponse resp) resp.SendAndReset(); } - runner.CheckedPop(1); + runner.state.Pop(1); } // Writes the string on the top of the stack out as an error, removes the string from the stack static void WriteError(LuaRunner runner, ref TResponse resp) { - NativeMethods.KnownStringToBuffer(runner.state.Handle, runner.curStackTop, out var errBuff); + runner.state.KnownStringToBuffer(runner.state.StackTop, out var errBuff); while (!RespWriteUtils.WriteError(errBuff, ref resp.BufferCur, resp.BufferEnd)) resp.SendAndReset(); - runner.CheckedPop(1); + runner.state.Pop(1); } static void WriteArray(LuaRunner runner, ref TResponse resp) @@ -1290,19 +1249,19 @@ static void WriteArray(LuaRunner runner, ref TResponse resp) // 1 for the table, 1 for the pending value const int AdditonalNeededStackSize = 2; - Debug.Assert(runner.state.Type(runner.curStackTop) == LuaType.Table, "Table was not on top of stack"); + Debug.Assert(runner.state.Type(runner.state.StackTop) == LuaType.Table, "Table was not on top of stack"); // Lua # operator - this MAY stop at nils, but isn't guaranteed to // See: https://www.lua.org/manual/5.3/manual.html#3.4.7 - var maxLen = runner.state.Length(runner.curStackTop); + var maxLen = runner.state.Length(runner.state.StackTop); // TODO: is it faster to punch a function in for this? // Find the TRUE length by scanning for nils var trueLen = 0; for (trueLen = 0; trueLen < maxLen; trueLen++) { - var type = runner.CheckedGetInteger(null, runner.curStackTop, trueLen + 1); - runner.CheckedPop(1); + var type = runner.state.GetInteger(null, runner.state.StackTop, trueLen + 1); + runner.state.Pop(1); if (type == LuaType.Nil) { @@ -1316,7 +1275,7 @@ static void WriteArray(LuaRunner runner, ref TResponse resp) for (var i = 1; i <= trueLen; i++) { // Push item at index i onto the stack - var type = runner.CheckedGetInteger(null, runner.curStackTop, i); + var type = runner.state.GetInteger(null, runner.state.StackTop, i); switch (type) { @@ -1329,28 +1288,10 @@ static void WriteArray(LuaRunner runner, ref TResponse resp) case LuaType.Boolean: WriteBoolean(runner, ref resp); break; - - case LuaType.Table: // For tables, we need to recurse - which means we need to check stack sizes again - if (runner.curStackSize < runner.curStackTop + AdditonalNeededStackSize) - { - try - { - runner.ForceGrowLuaStack(AdditonalNeededStackSize); - } - catch - { - // This is the only place we can raise an exception, cull the Stack - runner.state.SetTop(0); - runner.curStackTop = 0; - - throw; - } - } - + runner.state.ForceMinimumStackCapacity(AdditonalNeededStackSize); WriteArray(runner, ref resp); - break; // All other Lua types map to nulls @@ -1360,367 +1301,10 @@ static void WriteArray(LuaRunner runner, ref TResponse resp) } } - runner.CheckedPop(1); - } - } - - // TODO: I think we'd prefer all these helpers factor into their own file - - /// - /// Ensure there's enough space on the Lua stack for more items. - /// - /// Throws if there is not. - /// - /// Maintains to avoid unnecessary p/invokes. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void ForceGrowLuaStack(int additionalCapacity) - { - var availableSpace = curStackSize - curStackTop; - - if (availableSpace >= additionalCapacity) - { - return; - } - - var needed = additionalCapacity - availableSpace; - if (!state.CheckStack(needed)) - { - throw new GarnetException("Could not reserve additional capacity on the Lua stack"); - } - - curStackSize += additionalCapacity; - } - - /// - /// This should be used for all PushBuffer calls into Lua. - /// - /// If the string is a constant, consider registering it in the constructor and using instead. - /// - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void CheckedPushBuffer(int reservedCapacity, ReadOnlySpan buffer, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) - { - AssertLuaStackBelow(reservedCapacity, file, method, line); - - NativeMethods.PushBuffer(state.Handle, buffer); - curStackTop++; - } - - /// - /// This should be used for all PushNil calls into Lua. - /// - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void CheckedPushNil(int reservedCapacity, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) - { - AssertLuaStackBelow(reservedCapacity, file, method, line); - - state.PushNil(); - curStackTop++; - } - - /// - /// This should be used for all PushNumber calls into Lua. - /// - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void CheckedPushNumber(int reservedCapacity, double number, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) - { - AssertLuaStackBelow(reservedCapacity, file, method, line); - - state.PushNumber(number); - curStackTop++; - } - - /// - /// This should be used for all PushBoolean calls into Lua. - /// - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void CheckedPushBoolean(int reservedCapacity, bool b, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) - { - AssertLuaStackBelow(reservedCapacity, file, method, line); - - state.PushBoolean(b); - curStackTop++; - } - - /// - /// This should be used for all Pop calls into Lua. - /// - /// Maintains to minimize p/invoke calls. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void CheckedPop(int num) - { - state.Pop(num); - curStackTop -= num; - - AssertLuaStackExpected(); - } - - /// - /// This should be used for all Calls into Lua. - /// - /// Maintains and to minimize p/invoke calls. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void CheckedCall(int args, int rets) - { - var oldStackTop = curStackTop; - state.Call(args, rets); - - if (rets < 0) - { - curStackTop = state.GetTop(); - } - else - { - curStackTop = oldStackTop - (args + 1) + rets; + runner.state.Pop(1); } - - AssertLuaStackExpected(); } - /// - /// This should be used for all PCalls into Lua. - /// - /// Maintains and to minimize p/invoke calls. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private LuaStatus CheckedPCall(int args, int rets) - { - var oldStack = curStackTop; - var res = state.PCall(args, rets, 0); - - if (res != LuaStatus.OK || rets < 0) - { - curStackTop = state.GetTop(); - } - else - { - curStackTop = oldStack - (args + 1) + rets; - } - - AssertLuaStackExpected(); - - return res; - } - - /// - /// This should be used for all RawSetIntegers into Lua. - /// - /// Maintains and to minimize p/invoke calls. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void CheckedRawSetInteger(int stackIndex, int tableIndex) - { - state.RawSetInteger(stackIndex, tableIndex); - curStackTop--; - - AssertLuaStackExpected(); - } - - /// This should be used for all RawSets into Lua. - /// - /// Maintains and to minimize p/invoke calls. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void CheckedRawSet(int stackIndex) - { - state.RawSet(stackIndex); - curStackTop -= 2; - - AssertLuaStackExpected(); - } - - /// - /// This should be used for all RawGetIntegers into Lua. - /// - /// Maintains and to minimize p/invoke calls. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void CheckedRawGetInteger(LuaType expectedType, int stackIndex, int tableIndex) - { - var actual = state.RawGetInteger(stackIndex, tableIndex); - Debug.Assert(actual == expectedType, "Unexpected type received"); - curStackTop++; - - AssertLuaStackExpected(); - } - - /// - /// This should be used for all GetIntegers into Lua. - /// - /// Maintains and to minimize p/invoke calls. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private LuaType CheckedGetInteger(LuaType? expectedType, int stackIndex, int tableIndex) - { - var actual = state.GetInteger(stackIndex, tableIndex); - Debug.Assert(expectedType == null || actual == expectedType, "Unexpected type received"); - curStackTop++; - - AssertLuaStackExpected(); - - return actual; - } - - /// - /// This should be used for all RawGets into Lua. - /// - /// Maintains and to minimize p/invoke calls. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void CheckedRawGet(LuaType expectedType, int stackIndex) - { - var actual = state.RawGet(stackIndex); - Debug.Assert(actual == expectedType, "Unexpected type received"); - - AssertLuaStackExpected(); - } - - /// - /// This should be used for all GetTables into Lua. - /// - /// Maintains and to minimize p/invoke calls. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private LuaType CheckedGetTable(LuaType? expectedType, int stackIndex) - { - var actual = state.GetTable(stackIndex); - Debug.Assert(expectedType == null || actual == expectedType, "Unexpected type received"); - - AssertLuaStackExpected(); - - return actual; - } - - /// - /// This should be used for all Refs into Lua. - /// - /// Maintains and to minimize p/invoke calls. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private int CheckedRef() - { - var ret = state.Ref(LuaRegistry.Index); - curStackTop--; - - AssertLuaStackExpected(); - - return ret; - } - - /// - /// This should be used for all CreateTables into Lua. - /// - /// Maintains and to minimize p/invoke calls. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void CheckedCreateTable(int numArr, int numRec) - { - state.CreateTable(numArr, numRec); - curStackTop++; - - AssertLuaStackExpected(); - } - - /// - /// This should be used for all GetGlobals into Lua. - /// - /// Maintains and to minimize p/invoke calls. - /// - private void CheckedGetGlobal(LuaType expectedType, string globalName) - { - var type = state.GetGlobal(globalName); - Debug.Assert(type == expectedType, "Unexpected type received"); - - curStackTop++; - - AssertLuaStackExpected(); - } - - /// - /// This should be used for all LoadBuffers into Lua. - /// - /// Note that this is different from pushing a buffer, as the loaded buffer is compiled. - /// - /// Maintains and to minimize p/invoke calls. - /// - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private LuaStatus CheckedLoadBuffer(ReadOnlySpan buffer) - { - var ret = NativeMethods.LoadBuffer(state.Handle, buffer); - curStackTop++; - - AssertLuaStackExpected(); - - return ret; - } - - /// - /// This should be used to push all known constants strings (registered in constructor with ) - /// into Lua. - /// - /// This avoids extra copying of data between .NET and Lua. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private void CheckedPushConstantString(int reservedCapacity, int constStringRegistryIndex, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) - { - AssertLuaStackBelow(reservedCapacity, file, method, line); - Debug.Assert(IsConstantStringRegistryIndex(constStringRegistryIndex), "Can't use this with unknown string"); - - CheckedRawGetInteger(LuaType.String, (int)LuaRegistry.Index, constStringRegistryIndex); - - // Check if index corresponds to value registered in constructor - bool IsConstantStringRegistryIndex(int index) - => index == okConstStringRegisteryIndex || - index == errConstStringRegistryIndex || - index == noSessionAvailableConstStringRegisteryIndex || - index == pleaseSpecifyRedisCallConstStringRegistryIndex || - index == errNoAuthConstStringRegistryIndex || - index == errUnknownConstStringRegistryIndex || - index == errBadArgConstStringRegistryIndex; - } - - /// - /// Check that the Lua stack is empty in DEBUG builds. - /// - /// This is never necessary for correctness, but is often useful to find logical bugs. - /// - [Conditional("DEBUG")] - [MethodImpl(MethodImplOptions.NoInlining)] - private void AssertLuaStackEmpty([CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) - { - Debug.Assert(state.GetTop() == 0, $"Lua stack not empty when expected ({method}:{line} in {file})"); - } - - /// - /// Check that the Lua stack top is where expected in DEBUG builds. - /// - [Conditional("DEBUG")] - [MethodImpl(MethodImplOptions.NoInlining)] - private void AssertLuaStackExpected([CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) - { - Debug.Assert(state.GetTop() == curStackTop, $"Lua stack not where expected ({method}:{line} in {file})"); - } - - /// - /// Check the Lua stack has not grown beyond the capacity we initially reserved. - /// - /// This asserts (in DEBUG) that the next .PushXXX will succeed. - /// - /// In practice, Lua almost always gives us enough space (default is ~20 slots) but that's not guaranteed and can be false - /// for complicated redis.call invocations. - /// - [Conditional("DEBUG")] - [MethodImpl(MethodImplOptions.NoInlining)] - private void AssertLuaStackBelow(int reservedCapacity, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) - { - Debug.Assert(state.GetTop() < reservedCapacity, $"About to push to Lua stack without having reserved sufficient capacity."); - } + } } \ No newline at end of file diff --git a/libs/server/Lua/LuaStateWrapper.cs b/libs/server/Lua/LuaStateWrapper.cs new file mode 100644 index 0000000000..157f8aab8d --- /dev/null +++ b/libs/server/Lua/LuaStateWrapper.cs @@ -0,0 +1,552 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +using System; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Text; +using Garnet.common; +using KeraLua; + +namespace Garnet.server +{ + /// + /// For performance purposes, we need to track some additional state alongside the + /// raw Lua runtime. + /// + /// This type does that. + /// + internal struct LuaStateWrapper : IDisposable + { + private const int LUA_MINSTACK = 20; + + private readonly Lua state; + + private int curStackSize; + + internal LuaStateWrapper(Lua state) + { + this.state = state; + + curStackSize = LUA_MINSTACK; + StackTop = 0; + + AssertLuaStackExpected(); + } + + /// + public readonly void Dispose() + { + state.Dispose(); + } + + /// + /// Current top item in the stack. + /// + /// 0 implies the stack is empty. + /// + internal int StackTop { get; private set; } + + /// + /// Call when ambient state indicates that the Lua stack is in fact empty. + /// + /// Maintains to avoid unnecessary p/invokes. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ExpectLuaStackEmpty() + { + StackTop = 0; + AssertLuaStackExpected(); + } + + /// + /// Ensure there's enough space on the Lua stack for more items. + /// + /// Throws if there is not. + /// + /// Maintains to avoid unnecessary p/invokes. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void ForceMinimumStackCapacity(int additionalCapacity) + { + var availableSpace = curStackSize - StackTop; + + if (availableSpace >= additionalCapacity) + { + return; + } + + var needed = additionalCapacity - availableSpace; + if (!state.CheckStack(needed)) + { + throw new GarnetException("Could not reserve additional capacity on the Lua stack"); + } + + curStackSize += additionalCapacity; + } + + /// + /// Call when the Lua runtime calls back into .NET code. + /// + /// Figures out the state of the Lua stack once, to avoid unnecessary p/invokes. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void CallFromLuaEntered(IntPtr luaStatePtr) + { + Debug.Assert(luaStatePtr == state.Handle, "Unexpected Lua state presented"); + + StackTop = state.GetTop(); + curStackSize = StackTop > LUA_MINSTACK ? StackTop : LUA_MINSTACK; + } + + /// + /// This should be used for all CheckBuffer calls into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly bool CheckBuffer(int index, out ReadOnlySpan str) + { + AssertLuaStackIndexInBounds(index); + + return NativeMethods.CheckBuffer(state.Handle, index, out str); + } + + /// + /// This should be used for all Type calls into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly LuaType Type(int stackIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + + return state.Type(stackIndex); + } + + /// + /// This should be used for all PushBuffer calls into Lua. + /// + /// If the string is a constant, consider registering it in the constructor and using instead. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void PushBuffer(ReadOnlySpan buffer) + { + AssertLuaStackNotFull(); + + NativeMethods.PushBuffer(state.Handle, buffer); + UpdateStackTop(1); + } + + /// + /// This should be used for all PushNil calls into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void PushNil() + { + AssertLuaStackNotFull(); + + state.PushNil(); + UpdateStackTop(1); + } + + /// + /// This should be used for all PushNumber calls into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void PushNumber(double number) + { + AssertLuaStackNotFull(); + + state.PushNumber(number); + UpdateStackTop(1); + } + + /// + /// This should be used for all PushBoolean calls into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void PushBoolean(bool b) + { + AssertLuaStackNotFull(); + + state.PushBoolean(b); + UpdateStackTop(1); + } + + /// + /// This should be used for all Pop calls into Lua. + /// + /// Maintains to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void Pop(int num) + { + state.Pop(num); + UpdateStackTop(-num); + } + + /// + /// This should be used for all Calls into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void Call(int args, int rets) + { + // We have to copy this off, as once we Call curStackTop could be modified + var oldStackTop = StackTop; + state.Call(args, rets); + + if (rets < 0) + { + StackTop = state.GetTop(); + AssertLuaStackExpected(); + } + else + { + var newPosition = oldStackTop - (args + 1) + rets; + var update = newPosition - StackTop; + UpdateStackTop(update); + } + } + + /// + /// This should be used for all PCalls into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal LuaStatus PCall(int args, int rets) + { + // We have to copy this off, as once we Call curStackTop could be modified + var oldStackTop = StackTop; + var res = state.PCall(args, rets, 0); + + if (res != LuaStatus.OK || rets < 0) + { + StackTop = state.GetTop(); + AssertLuaStackExpected(); + } + else + { + var newPosition = oldStackTop - (args + 1) + rets; + var update = newPosition - StackTop; + UpdateStackTop(update); + } + + return res; + } + + /// + /// This should be used for all RawSetIntegers into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void RawSetInteger(int stackIndex, int tableIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + + state.RawSetInteger(stackIndex, tableIndex); + UpdateStackTop(-1); + } + + /// + /// This should be used for all RawSets into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void RawSet(int stackIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + + state.RawSet(stackIndex); + UpdateStackTop(-2); + } + + /// + /// This should be used for all RawGetIntegers into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void RawGetInteger(LuaType expectedType, int stackIndex, int tableIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + AssertLuaStackNotFull(); + + var actual = state.RawGetInteger(stackIndex, tableIndex); + Debug.Assert(actual == expectedType, "Unexpected type received"); + + UpdateStackTop(1); + } + + /// + /// This should be used for all GetIntegers into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal LuaType GetInteger(LuaType? expectedType, int stackIndex, int tableIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + AssertLuaStackNotFull(); + + var actual = state.GetInteger(stackIndex, tableIndex); + Debug.Assert(expectedType == null || actual == expectedType, "Unexpected type received"); + + UpdateStackTop(1); + + return actual; + } + + /// + /// This should be used for all RawGets into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly void RawGet(LuaType expectedType, int stackIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + + var actual = state.RawGet(stackIndex); + Debug.Assert(actual == expectedType, "Unexpected type received"); + + AssertLuaStackExpected(); + } + + /// + /// This should be used for all GetTables into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly LuaType GetTable(LuaType? expectedType, int stackIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + + var actual = state.GetTable(stackIndex); + Debug.Assert(expectedType == null || actual == expectedType, "Unexpected type received"); + + AssertLuaStackExpected(); + + return actual; + } + + /// + /// This should be used for all Refs into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal int Ref() + { + var ret = state.Ref(LuaRegistry.Index); + UpdateStackTop(-1); + + return ret; + } + + /// + /// This should be used for all Unrefs into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly void Unref(LuaRegistry registery, int reference) + { + state.Unref(registery, reference); + } + + /// + /// This should be used for all CreateTables into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void CreateTable(int numArr, int numRec) + { + AssertLuaStackNotFull(); + + state.CreateTable(numArr, numRec); + UpdateStackTop(1); + } + + /// + /// This should be used for all GetGlobals into Lua. + /// + /// Maintains and to minimize p/invoke calls. + /// + internal void GetGlobal(LuaType expectedType, string globalName) + { + AssertLuaStackNotFull(); + + var type = state.GetGlobal(globalName); + Debug.Assert(type == expectedType, "Unexpected type received"); + + UpdateStackTop(1); + } + + /// + /// This should be used for all LoadBuffers into Lua. + /// + /// Note that this is different from pushing a buffer, as the loaded buffer is compiled. + /// + /// Maintains and to minimize p/invoke calls. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal LuaStatus LoadBuffer(ReadOnlySpan buffer) + { + AssertLuaStackNotFull(); + + var ret = NativeMethods.LoadBuffer(state.Handle, buffer); + + UpdateStackTop(1); + + return ret; + } + + /// + /// Call when value at index is KNOWN to be a string or number + /// + /// only remains valid as long as the buffer remains on the stack, + /// use with care. + /// + /// Note that is changes the value on the stack to be a string if it returns true, regardless of + /// what it was originally. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly void KnownStringToBuffer(int stackIndex, out ReadOnlySpan str) + { + AssertLuaStackIndexInBounds(stackIndex); + + Debug.Assert(state.Type(stackIndex) is LuaType.String or LuaType.Number, "Called with non-string, non-number"); + + NativeMethods.KnownStringToBuffer(state.Handle, stackIndex, out str); + } + + /// + /// This should be used for all CheckNumbers into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly double CheckNumber(int stackIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + + return state.CheckNumber(stackIndex); + } + + /// + /// This should be used for all ToBooleans into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly bool ToBoolean(int stackIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + + return state.ToBoolean(stackIndex); + } + + /// + /// This should be used for all Lengths into Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal readonly long Length(int stackIndex) + { + AssertLuaStackIndexInBounds(stackIndex); + + return state.Length(stackIndex); + } + + /// + /// Call to register a function in the Lua global namespace. + /// + internal readonly void Register(string name, LuaFunction func) + => state.Register(name, func); + + /// + /// This should be used to push all known constants strings into Lua. + /// + /// This avoids extra copying of data between .NET and Lua. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void PushConstantString(int constStringRegistryIndex, [CallerFilePath] string file = null, [CallerMemberName] string method = null, [CallerLineNumber] int line = -1) + => RawGetInteger(LuaType.String, (int)LuaRegistry.Index, constStringRegistryIndex); + + // Rarely used + + /// + /// Remove everything from the Lua stack. + /// + internal void ClearStack() + { + state.SetTop(0); + StackTop = 0; + + AssertLuaStackExpected(); + } + + /// + /// Clear the stack and raise an error with the given message. + /// + internal int RaiseError(string msg) + { + ClearStack(); + + var b = Encoding.UTF8.GetBytes(msg); + return RaiseErrorFromStack(); + } + + /// + /// Raise an error, where the top of the stack is the error message. + /// + internal readonly int RaiseErrorFromStack() + { + Debug.Assert(StackTop != 0, "Expected error message on the stack"); + + return state.Error(); + } + + /// + /// Helper to update . + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void UpdateStackTop(int by) + { + StackTop += by; + AssertLuaStackExpected(); + } + + // Conditional compilation checks + + /// + /// Check that the given index refers to a valid part of the stack. + /// + [Conditional("DEBUG")] + [MethodImpl(MethodImplOptions.NoInlining)] + private readonly void AssertLuaStackIndexInBounds(int stackIndex) + { + Debug.Assert(stackIndex == (int)LuaRegistry.Index || (stackIndex > 0 && stackIndex <= StackTop), "Lua stack index out of bounds"); + } + + /// + /// Check that the Lua stack top is where expected in DEBUG builds. + /// + [Conditional("DEBUG")] + [MethodImpl(MethodImplOptions.NoInlining)] + private readonly void AssertLuaStackExpected() + { + Debug.Assert(state.GetTop() == StackTop, "Lua stack not where expected"); + } + + /// + /// Check that there's space to push some number of elements. + /// + [Conditional("DEBUG")] + [MethodImpl(MethodImplOptions.NoInlining)] + private readonly void AssertLuaStackNotFull(int probe = 1) + { + Debug.Assert((StackTop + probe) <= curStackSize, "Lua stack should have been grown before pushing"); + } + } +} From 2fa996aacc4bff72e04aebe028b0f37d1df1789e Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Fri, 13 Dec 2024 15:14:32 -0500 Subject: [PATCH 44/51] switch to LibraryImport since we're on .NET 8 --- libs/server/Lua/NativeMethods.cs | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/libs/server/Lua/NativeMethods.cs b/libs/server/Lua/NativeMethods.cs index 702cfe05e7..f78f8cf104 100644 --- a/libs/server/Lua/NativeMethods.cs +++ b/libs/server/Lua/NativeMethods.cs @@ -5,7 +5,6 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using KeraLua; -using static System.Net.WebRequestMethods; using charptr_t = System.IntPtr; using lua_State = System.IntPtr; using size_t = System.UIntPtr; @@ -18,9 +17,8 @@ namespace Garnet.server /// Long term we'll want to try and push these upstreams and move to just using KeraLua, /// but for now we're just defining them ourselves. /// - internal static class NativeMethods + internal static partial class NativeMethods { - // TODO: LibraryImport? // TODO: Suppress GC transition (requires Lua audit) private const string LuaLibraryName = "lua54"; @@ -28,26 +26,30 @@ internal static class NativeMethods /// /// see: https://www.lua.org/manual/5.3/manual.html#lua_tolstring /// - [DllImport(LuaLibraryName, CallingConvention = CallingConvention.Cdecl)] - private static extern charptr_t lua_tolstring(lua_State L, int index, out size_t len); + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial charptr_t lua_tolstring(lua_State L, int index, out size_t len); /// /// see: https://www.lua.org/manual/5.3/manual.html#lua_type /// - [DllImport(LuaLibraryName, CallingConvention = CallingConvention.Cdecl)] - private static extern LuaType lua_type(lua_State L, int index); + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial LuaType lua_type(lua_State L, int index); /// /// see: https://www.lua.org/manual/5.3/manual.html#lua_pushlstring /// - [DllImport(LuaLibraryName, CallingConvention = CallingConvention.Cdecl)] - private static extern charptr_t lua_pushlstring(lua_State L, charptr_t s, size_t len); + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial charptr_t lua_pushlstring(lua_State L, charptr_t s, size_t len); /// /// see: https://www.lua.org/manual/5.3/manual.html#luaL_loadbufferx /// - [DllImport(LuaLibraryName, CallingConvention = CallingConvention.Cdecl)] - private static extern LuaStatus luaL_loadbufferx(lua_State luaState, charptr_t buff, size_t sz, charptr_t name, charptr_t mode); + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + private static partial LuaStatus luaL_loadbufferx(lua_State luaState, charptr_t buff, size_t sz, charptr_t name, charptr_t mode); /// /// Returns true if the given index on the stack holds a string or a number. From f93e93f3df3d8f51c02af183feaa780781301cdd Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Fri, 13 Dec 2024 16:15:52 -0500 Subject: [PATCH 45/51] do a big audit of Lua invokes and mark where GC transition can be suppressed - it's unclear if .NET is actually doing that today, but it's safe --- libs/server/Lua/LuaRunner.cs | 25 +++-- libs/server/Lua/LuaStateWrapper.cs | 28 +++--- libs/server/Lua/NativeMethods.cs | 151 ++++++++++++++++++++++++++--- 3 files changed, 165 insertions(+), 39 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index a27b8ae15c..2e8b73c331 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -2,7 +2,6 @@ // Licensed under the MIT license. using System; -using System.ComponentModel.DataAnnotations; using System.Diagnostics; using System.Linq; using System.Runtime.CompilerServices; @@ -393,8 +392,8 @@ unsafe void CompileCommon(ref TResponse resp) { state.ForceMinimumStackCapacity(NeededStackSpace); - state.PushNumber(loadSandboxedRegistryIndex); - state.GetTable(LuaType.Function, (int)LuaRegistry.Index); + state.PushInteger(loadSandboxedRegistryIndex); + _ = state.GetTable(LuaType.Function, (int)LuaRegistry.Index); state.PushBuffer(source.Span); state.Call(1, -1); // Multiple returns allowed @@ -670,7 +669,7 @@ unsafe int ProcessResponse(byte* ptr, int length) case (byte)':': if (RespReadUtils.Read64Int(out var number, ref ptr, ptr + length)) { - state.PushNumber(number); + state.PushInteger(number); return 1; } goto default; @@ -780,7 +779,7 @@ public void RunForSession(int count, RespServerSession outerSession) if (nKeys > 0) { // Get KEYS on the stack - state.PushNumber(keysTableRegistryIndex); + state.PushInteger(keysTableRegistryIndex); state.RawGet(LuaType.Table, (int)LuaRegistry.Index); for (var i = 0; i < nKeys; i++) @@ -795,7 +794,7 @@ public void RunForSession(int count, RespServerSession outerSession) } // Equivalent to KEYS[i+1] = key - state.PushNumber(i + 1); + state.PushInteger(i + 1); state.PushBuffer(key.ReadOnlySpan); state.RawSet(1); @@ -811,7 +810,7 @@ public void RunForSession(int count, RespServerSession outerSession) if (count > 0) { // Get ARGV on the stack - state.PushNumber(argvTableRegistryIndex); + state.PushInteger(argvTableRegistryIndex); state.RawGet(LuaType.Table, (int)LuaRegistry.Index); for (var i = 0; i < count; i++) @@ -819,7 +818,7 @@ public void RunForSession(int count, RespServerSession outerSession) ref var argv = ref parseState.GetArgSliceByRef(offset); // Equivalent to ARGV[i+1] = argv - state.PushNumber(i + 1); + state.PushInteger(i + 1); state.PushBuffer(argv.ReadOnlySpan); state.RawSet(1); @@ -980,8 +979,8 @@ internal void ResetParameters(int nKeys, int nArgs) { state.RawGetInteger(LuaType.Function, (int)LuaRegistry.Index, resetKeysAndArgvRegistryIndex); - state.PushNumber(nKeys + 1); - state.PushNumber(nArgs + 1); + state.PushInteger(nKeys + 1); + state.PushInteger(nArgs + 1); var resetRes = state.PCall(2, 0); Debug.Assert(resetRes == LuaStatus.OK, "Resetting should never fail"); @@ -1005,7 +1004,7 @@ void LoadParametersForRunner(string[] keys, string[] argv) if (keys != null) { // get KEYS on the stack - state.PushNumber(keysTableRegistryIndex); + state.PushInteger(keysTableRegistryIndex); state.GetTable(LuaType.Table, (int)LuaRegistry.Index); for (var i = 0; i < keys.Length; i++) @@ -1023,7 +1022,7 @@ void LoadParametersForRunner(string[] keys, string[] argv) if (argv != null) { // get ARGV on the stack - state.PushNumber(argvTableRegistryIndex); + state.PushInteger(argvTableRegistryIndex); state.GetTable(LuaType.Table, (int)LuaRegistry.Index); for (var i = 0; i < argv.Length; i++) @@ -1066,7 +1065,7 @@ unsafe void RunCommon(ref TResponse resp) { state.ForceMinimumStackCapacity(NeededStackSize); - state.PushNumber(functionRegistryIndex); + state.PushInteger(functionRegistryIndex); _ = state.GetTable(LuaType.Function, (int)LuaRegistry.Index); var callRes = state.PCall(0, 1); diff --git a/libs/server/Lua/LuaStateWrapper.cs b/libs/server/Lua/LuaStateWrapper.cs index 157f8aab8d..0c8c7b9b22 100644 --- a/libs/server/Lua/LuaStateWrapper.cs +++ b/libs/server/Lua/LuaStateWrapper.cs @@ -95,7 +95,7 @@ internal void CallFromLuaEntered(IntPtr luaStatePtr) { Debug.Assert(luaStatePtr == state.Handle, "Unexpected Lua state presented"); - StackTop = state.GetTop(); + StackTop = NativeMethods.GetTop(state.Handle); curStackSize = StackTop > LUA_MINSTACK ? StackTop : LUA_MINSTACK; } @@ -118,7 +118,7 @@ internal readonly LuaType Type(int stackIndex) { AssertLuaStackIndexInBounds(stackIndex); - return state.Type(stackIndex); + return NativeMethods.Type(state.Handle, stackIndex); } /// @@ -143,19 +143,20 @@ internal void PushNil() { AssertLuaStackNotFull(); - state.PushNil(); + NativeMethods.PushNil(state.Handle); UpdateStackTop(1); } /// - /// This should be used for all PushNumber calls into Lua. + /// This should be used for all PushInteger calls into Lua. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void PushNumber(double number) + internal void PushInteger(long number) { AssertLuaStackNotFull(); - state.PushNumber(number); + NativeMethods.PushInteger(state.Handle, number); + UpdateStackTop(1); } @@ -167,7 +168,7 @@ internal void PushBoolean(bool b) { AssertLuaStackNotFull(); - state.PushBoolean(b); + NativeMethods.PushBoolean(state.Handle, b); UpdateStackTop(1); } @@ -179,7 +180,8 @@ internal void PushBoolean(bool b) [MethodImpl(MethodImplOptions.AggressiveInlining)] internal void Pop(int num) { - state.Pop(num); + NativeMethods.Pop(state.Handle, num); + UpdateStackTop(-num); } @@ -197,7 +199,7 @@ internal void Call(int args, int rets) if (rets < 0) { - StackTop = state.GetTop(); + StackTop = NativeMethods.GetTop(state.Handle); AssertLuaStackExpected(); } else @@ -222,7 +224,7 @@ internal LuaStatus PCall(int args, int rets) if (res != LuaStatus.OK || rets < 0) { - StackTop = state.GetTop(); + StackTop = NativeMethods.GetTop(state.Handle); AssertLuaStackExpected(); } else @@ -420,7 +422,7 @@ internal readonly void KnownStringToBuffer(int stackIndex, out ReadOnlySpan @@ -536,7 +538,7 @@ private readonly void AssertLuaStackIndexInBounds(int stackIndex) [MethodImpl(MethodImplOptions.NoInlining)] private readonly void AssertLuaStackExpected() { - Debug.Assert(state.GetTop() == StackTop, "Lua stack not where expected"); + Debug.Assert(NativeMethods.GetTop(state.Handle) == StackTop, "Lua stack not where expected"); } /// diff --git a/libs/server/Lua/NativeMethods.cs b/libs/server/Lua/NativeMethods.cs index f78f8cf104..4bc8c9fd46 100644 --- a/libs/server/Lua/NativeMethods.cs +++ b/libs/server/Lua/NativeMethods.cs @@ -5,9 +5,9 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using KeraLua; -using charptr_t = System.IntPtr; -using lua_State = System.IntPtr; -using size_t = System.UIntPtr; +using charptr_t = nint; +using lua_State = nint; +using size_t = nuint; namespace Garnet.server { @@ -24,32 +24,99 @@ internal static partial class NativeMethods private const string LuaLibraryName = "lua54"; /// - /// see: https://www.lua.org/manual/5.3/manual.html#lua_tolstring + /// see: https://www.lua.org/manual/5.4/manual.html#lua_tolstring /// [LibraryImport(LuaLibraryName)] [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] private static partial charptr_t lua_tolstring(lua_State L, int index, out size_t len); /// - /// see: https://www.lua.org/manual/5.3/manual.html#lua_type + /// see: https://www.lua.org/manual/5.4/manual.html#lua_pushlstring /// [LibraryImport(LuaLibraryName)] [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] - private static partial LuaType lua_type(lua_State L, int index); + private static partial charptr_t lua_pushlstring(lua_State L, charptr_t s, size_t len); /// - /// see: https://www.lua.org/manual/5.3/manual.html#lua_pushlstring + /// see: https://www.lua.org/manual/5.4/manual.html#luaL_loadbufferx /// [LibraryImport(LuaLibraryName)] [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] - private static partial charptr_t lua_pushlstring(lua_State L, charptr_t s, size_t len); + private static partial LuaStatus luaL_loadbufferx(lua_State luaState, charptr_t buff, size_t sz, charptr_t name, charptr_t mode); + + // GC Transition suppressed - only do this after auditing the Lua method and confirming constant-ish, fast, runtime w/o allocations /// - /// see: https://www.lua.org/manual/5.3/manual.html#luaL_loadbufferx + /// see: https://www.lua.org/manual/5.4/manual.html#lua_gettop + /// + /// Does basically nothing, so suppressing GC transition. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_gettop /// [LibraryImport(LuaLibraryName)] - [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] - private static partial LuaStatus luaL_loadbufferx(lua_State luaState, charptr_t buff, size_t sz, charptr_t name, charptr_t mode); + [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + private static partial int lua_gettop(lua_State luaState); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_type + /// + /// Does some very basic ifs and then returns, so suppressing GC transition. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_type + /// And + /// see: https://www.lua.org/source/5.4/lapi.c.html#index2value + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + private static partial LuaType lua_type(lua_State L, int index); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_pushnil + /// + /// Does some very small writes, and stack size is pre-validated, so suppressing GC transition. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_pushnil + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + private static partial void lua_pushnil(lua_State L); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_pushinteger + /// + /// Does some very small writes, and stack size is pre-validated, so suppressing GC transition. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_pushinteger + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + private static partial void lua_pushinteger(lua_State L, long num); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_pushboolean + /// + /// Does some very small writes, and stack size is pre-validated, so suppressing GC transition. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_pushboolean + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + private static partial void lua_pushboolean(lua_State L, int b); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_toboolean + /// + /// Does some very basic ifs and then returns an int, so suppressing GC transition. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_toboolean + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + private static partial int lua_toboolean(lua_State L, int ix); + + /// + /// see: https://www.lua.org/manual/5.4/manual.html#lua_settop + /// + /// We aren't pushing complex types, so none of the close logic should run. + /// see: https://www.lua.org/source/5.4/lapi.c.html#lua_settop + /// + [LibraryImport(LuaLibraryName)] + [UnmanagedCallConv(CallConvs = [typeof(CallConvSuppressGCTransition)])] + private static partial void lua_settop(lua_State L, int num); /// /// Returns true if the given index on the stack holds a string or a number. @@ -66,7 +133,7 @@ internal static bool CheckBuffer(lua_State luaState, int index, out ReadOnlySpan { var type = lua_type(luaState, index); - if (type != LuaType.String && type != LuaType.Number) + if (type is not LuaType.String and not LuaType.Number) { str = []; return false; @@ -107,7 +174,7 @@ internal static unsafe void PushBuffer(lua_State luaState, ReadOnlySpan st { fixed (byte* ptr = str) { - lua_pushlstring(luaState, (charptr_t)ptr, (size_t)str.Length); + _ = lua_pushlstring(luaState, (charptr_t)ptr, (size_t)str.Length); } } @@ -123,5 +190,63 @@ internal static unsafe LuaStatus LoadBuffer(lua_State luaState, ReadOnlySpan + /// Get the top index on the stack. + /// + /// 0 indicates empty. + /// + /// Differs from by suppressing GC transition. + /// + internal static int GetTop(lua_State luaState) + => lua_gettop(luaState); + + /// + /// Gets the type of the value at the stack index. + /// + /// Differs from by suppressing GC transition. + /// + internal static LuaType Type(lua_State luaState, int index) + => lua_type(luaState, index); + + /// + /// Pushes a nil value onto the stack. + /// + /// Differs from by suppressing GC transition. + /// + internal static void PushNil(lua_State luaState) + => lua_pushnil(luaState); + + /// + /// Pushes a double onto the stack. + /// + /// Differs from by suppressing GC transition. + /// + internal static void PushInteger(lua_State luaState, long num) + => lua_pushinteger(luaState, num); + + /// + /// Pushes a boolean onto the stack. + /// + /// Differs from by suppressing GC transition. + /// + internal static void PushBoolean(lua_State luaState, bool b) + => lua_pushboolean(luaState, b ? 1 : 0); + + /// + /// Read a boolean off the stack + /// + /// Differs from by suppressing GC transition. + /// + internal static bool ToBoolean(lua_State luaState, int index) + => lua_toboolean(luaState, index) != 0; + + /// + /// Remove some number of items from the stack. + /// + /// Differs form by suppressing GC transition. + /// + internal static void Pop(lua_State luaState, int num) + => lua_settop(luaState, -num - 1); } } From 9e170dc5b175e07f3cf8ae98ce8c15bff9e1079b Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Fri, 13 Dec 2024 16:20:49 -0500 Subject: [PATCH 46/51] add a benchmark that returns an array, as there's an outstanding TODO to look at removing some p/invokes --- .../BDN.benchmark/Operations/ScriptOperations.cs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs index b544c14345..6986cc9cab 100644 --- a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs +++ b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs @@ -160,6 +160,10 @@ public unsafe class ScriptOperations : OperationsBase byte[] evalShaLargeScriptBuffer; byte* evalShaLargeScriptBufferPointer; + static ReadOnlySpan ARRAY_RETURN => "*3\r\n$4\r\nEVAL\r\n$22\r\nreturn {1, 2, 3, 4, 5}\r\n$1\r\n0\r\n"u8; + byte[] arrayReturnRequestBuffer; + byte* arrayReturnRequestBufferPointer; + public override void GlobalSetup() { base.GlobalSetup(); @@ -178,6 +182,8 @@ public override void GlobalSetup() SetupOperation(ref evalShaRequestBuffer, ref evalShaRequestBufferPointer, EVALSHA); + SetupOperation(ref arrayReturnRequestBuffer, ref arrayReturnRequestBufferPointer, ARRAY_RETURN); + // Setup small script var loadSmallScript = $"*3\r\n$6\r\nSCRIPT\r\n$4\r\nLOAD\r\n${SmallScriptText.Length}\r\n{SmallScriptText}\r\n"; var loadSmallScriptBytes = Encoding.UTF8.GetBytes(loadSmallScript); @@ -271,5 +277,11 @@ public void LargeScript() { _ = session.TryConsumeMessages(evalShaLargeScriptBufferPointer, evalShaLargeScriptBuffer.Length); } + + [Benchmark] + public void ArrayReturn() + { + _ = session.TryConsumeMessages(arrayReturnRequestBufferPointer, arrayReturnRequestBuffer.Length); + } } } \ No newline at end of file From 9684003731137101b8a0062e9d10ac7c44b01e7e Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Fri, 13 Dec 2024 16:44:20 -0500 Subject: [PATCH 47/51] nope --- libs/server/Lua/LuaRunner.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 2e8b73c331..bde5252ec0 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -1254,7 +1254,6 @@ static void WriteArray(LuaRunner runner, ref TResponse resp) // See: https://www.lua.org/manual/5.3/manual.html#3.4.7 var maxLen = runner.state.Length(runner.state.StackTop); - // TODO: is it faster to punch a function in for this? // Find the TRUE length by scanning for nils var trueLen = 0; for (trueLen = 0; trueLen < maxLen; trueLen++) From e0c8c047335ab341b923168b02ffc6186000773d Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Fri, 13 Dec 2024 17:27:59 -0500 Subject: [PATCH 48/51] add a test for metatable behavior matching Redis; switch to Raw operations where now allowed; more closely sync methods exposed in Lua to those provided by Redis --- libs/server/Lua/LuaRunner.cs | 59 ++++++++++++++++++------------ libs/server/Lua/LuaStateWrapper.cs | 43 +++------------------- test/Garnet.test/LuaScriptTests.cs | 49 +++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 61 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index bde5252ec0..3c80f1d136 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -143,25 +143,38 @@ function redis.error_reply(text) KEYS = {} ARGV = {} sandbox_env = { - tostring = tostring; - next = next; + _G = _G; + _VERSION = _VERSION; + assert = assert; - tonumber = tonumber; - rawequal = rawequal; collectgarbage = collectgarbage; coroutine = coroutine; - type = type; - select = select; - unpack = table.unpack; + error = error; gcinfo = gcinfo; - pairs = pairs; - loadstring = loadstring; + -- explicitly not allowing getfenv + getmetatable = getmetatable; ipairs = ipairs; - error = error; - redis = redis; + load = load; + loadstring = loadstring; math = math; - table = table; + next = next; + pairs = pairs; + pcall = pcall; + rawequal = rawequal; + rawget = rawget; + rawset = rawset; + redis = redis; + select = select; + -- explicitly not allowing setfenv string = string; + setmetatable = setmetatable; + table = table; + tonumber = tonumber; + tostring = tostring; + type = type; + unpack = table.unpack; + xpcall = xpcall; + KEYS = KEYS; ARGV = ARGV; } @@ -393,7 +406,7 @@ unsafe void CompileCommon(ref TResponse resp) state.ForceMinimumStackCapacity(NeededStackSpace); state.PushInteger(loadSandboxedRegistryIndex); - _ = state.GetTable(LuaType.Function, (int)LuaRegistry.Index); + _ = state.RawGet(LuaType.Function, (int)LuaRegistry.Index); state.PushBuffer(source.Span); state.Call(1, -1); // Multiple returns allowed @@ -1005,7 +1018,7 @@ void LoadParametersForRunner(string[] keys, string[] argv) { // get KEYS on the stack state.PushInteger(keysTableRegistryIndex); - state.GetTable(LuaType.Table, (int)LuaRegistry.Index); + _ = state.RawGet(LuaType.Table, (int)LuaRegistry.Index); for (var i = 0; i < keys.Length; i++) { @@ -1023,7 +1036,7 @@ void LoadParametersForRunner(string[] keys, string[] argv) { // get ARGV on the stack state.PushInteger(argvTableRegistryIndex); - state.GetTable(LuaType.Table, (int)LuaRegistry.Index); + _ = state.RawGet(LuaType.Table, (int)LuaRegistry.Index); for (var i = 0; i < argv.Length; i++) { @@ -1066,7 +1079,7 @@ unsafe void RunCommon(ref TResponse resp) state.ForceMinimumStackCapacity(NeededStackSize); state.PushInteger(functionRegistryIndex); - _ = state.GetTable(LuaType.Function, (int)LuaRegistry.Index); + _ = state.RawGet(LuaType.Function, (int)LuaRegistry.Index); var callRes = state.PCall(0, 1); if (callRes == LuaStatus.OK) @@ -1104,14 +1117,12 @@ unsafe void RunCommon(ref TResponse resp) } else if (retType == LuaType.Table) { - // TODO: because we are dealing with a user provided type, we MUST respect - // metatables - so we can't use any of the RawXXX methods - // so we need a test that use metatables (and compare to how Redis does this) + // Redis does not respect metatables, so RAW access is ok here // If the key err is in there, we need to short circuit state.PushConstantString(errConstStringRegistryIndex); - var errType =state.GetTable(null, 1); + var errType = state.RawGet(null, 1); if (errType == LuaType.String) { WriteError(this, ref resp); @@ -1245,6 +1256,8 @@ static void WriteError(LuaRunner runner, ref TResponse resp) static void WriteArray(LuaRunner runner, ref TResponse resp) { + // Redis does not respect metatables, so RAW access is ok here + // 1 for the table, 1 for the pending value const int AdditonalNeededStackSize = 2; @@ -1252,13 +1265,13 @@ static void WriteArray(LuaRunner runner, ref TResponse resp) // Lua # operator - this MAY stop at nils, but isn't guaranteed to // See: https://www.lua.org/manual/5.3/manual.html#3.4.7 - var maxLen = runner.state.Length(runner.state.StackTop); + var maxLen = runner.state.RawLen(runner.state.StackTop); // Find the TRUE length by scanning for nils var trueLen = 0; for (trueLen = 0; trueLen < maxLen; trueLen++) { - var type = runner.state.GetInteger(null, runner.state.StackTop, trueLen + 1); + var type = runner.state.RawGetInteger(null, runner.state.StackTop, trueLen + 1); runner.state.Pop(1); if (type == LuaType.Nil) @@ -1273,7 +1286,7 @@ static void WriteArray(LuaRunner runner, ref TResponse resp) for (var i = 1; i <= trueLen; i++) { // Push item at index i onto the stack - var type = runner.state.GetInteger(null, runner.state.StackTop, i); + var type = runner.state.RawGetInteger(null, runner.state.StackTop, i); switch (type) { diff --git a/libs/server/Lua/LuaStateWrapper.cs b/libs/server/Lua/LuaStateWrapper.cs index 0c8c7b9b22..e4ebb0d611 100644 --- a/libs/server/Lua/LuaStateWrapper.cs +++ b/libs/server/Lua/LuaStateWrapper.cs @@ -271,29 +271,12 @@ internal void RawSet(int stackIndex) /// Maintains and to minimize p/invoke calls. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void RawGetInteger(LuaType expectedType, int stackIndex, int tableIndex) + internal LuaType RawGetInteger(LuaType? expectedType, int stackIndex, int tableIndex) { AssertLuaStackIndexInBounds(stackIndex); AssertLuaStackNotFull(); var actual = state.RawGetInteger(stackIndex, tableIndex); - Debug.Assert(actual == expectedType, "Unexpected type received"); - - UpdateStackTop(1); - } - - /// - /// This should be used for all GetIntegers into Lua. - /// - /// Maintains and to minimize p/invoke calls. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal LuaType GetInteger(LuaType? expectedType, int stackIndex, int tableIndex) - { - AssertLuaStackIndexInBounds(stackIndex); - AssertLuaStackNotFull(); - - var actual = state.GetInteger(stackIndex, tableIndex); Debug.Assert(expectedType == null || actual == expectedType, "Unexpected type received"); UpdateStackTop(1); @@ -307,27 +290,11 @@ internal LuaType GetInteger(LuaType? expectedType, int stackIndex, int tableInde /// Maintains and to minimize p/invoke calls. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal readonly void RawGet(LuaType expectedType, int stackIndex) + internal readonly LuaType RawGet(LuaType? expectedType, int stackIndex) { AssertLuaStackIndexInBounds(stackIndex); var actual = state.RawGet(stackIndex); - Debug.Assert(actual == expectedType, "Unexpected type received"); - - AssertLuaStackExpected(); - } - - /// - /// This should be used for all GetTables into Lua. - /// - /// Maintains and to minimize p/invoke calls. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal readonly LuaType GetTable(LuaType? expectedType, int stackIndex) - { - AssertLuaStackIndexInBounds(stackIndex); - - var actual = state.GetTable(stackIndex); Debug.Assert(expectedType == null || actual == expectedType, "Unexpected type received"); AssertLuaStackExpected(); @@ -450,14 +417,14 @@ internal readonly bool ToBoolean(int stackIndex) } /// - /// This should be used for all Lengths into Lua. + /// This should be used for all RawLens into Lua. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal readonly long Length(int stackIndex) + internal readonly long RawLen(int stackIndex) { AssertLuaStackIndexInBounds(stackIndex); - return state.Length(stackIndex); + return state.RawLen(stackIndex); } /// diff --git a/test/Garnet.test/LuaScriptTests.cs b/test/Garnet.test/LuaScriptTests.cs index cf40a09ae3..ace0593340 100644 --- a/test/Garnet.test/LuaScriptTests.cs +++ b/test/Garnet.test/LuaScriptTests.cs @@ -736,5 +736,54 @@ public void ComplexLuaReturns() } } } + + [Test] + public void MetatableReturn() + { + const string Script = @" + local table = { abc = 'def', ghi = 'jkl' } + local ret = setmetatable( + table, + { + __len = function (self) + return 4 + end, + __index = function(self, key) + if key == 1 then + return KEYS[1] + end + + if key == 2 then + return ARGV[1] + end + + if key == 3 then + return self.ghi + end + + if key == 4 then + return self.abc + end + + return nil + end + } + ) + + -- prove that metatables WORK but also that we don't deconstruct them for returns + return { ret, ret[1], ret[2], ret[3], ret[4] }"; + + using var redis = ConnectionMultiplexer.Connect(TestUtils.GetConfig()); + var db = redis.GetDatabase(); + + var ret = (RedisResult[])db.ScriptEvaluate(Script, [(RedisKey)"foo"], [(RedisValue)"bar"]); + ClassicAssert.AreEqual(5, ret.Length); + var firstEmpty = (RedisResult[])ret[0]; + ClassicAssert.AreEqual(0, firstEmpty.Length); + ClassicAssert.AreEqual("foo", (string)ret[1]); + ClassicAssert.AreEqual("bar", (string)ret[2]); + ClassicAssert.AreEqual("jkl", (string)ret[3]); + ClassicAssert.AreEqual("def", (string)ret[4]); + } } } \ No newline at end of file From a5aead959d86f49391ca302d02526a34c7cf22c0 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Fri, 13 Dec 2024 17:31:29 -0500 Subject: [PATCH 49/51] todone --- .../BDN.benchmark/Operations/ScriptOperations.cs | 2 +- libs/server/Lua/LuaRunner.cs | 4 +--- libs/server/Lua/LuaStateWrapper.cs | 2 +- libs/server/Lua/NativeMethods.cs | 4 +--- libs/server/Lua/SessionScriptCache.cs | 11 ++--------- 5 files changed, 6 insertions(+), 17 deletions(-) diff --git a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs index 6986cc9cab..531ff5d16f 100644 --- a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs +++ b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs @@ -215,7 +215,7 @@ public override void GlobalSetup() for (var i = 0; i < batchSize; i++) { var evalShaLargeScript = $"*7\r\n$7\r\nEVALSHA\r\n$40\r\n{largeScriptHash}\r\n$1\r\n1\r\n$5\r\nhello\r\n"; - + var id = Guid.NewGuid().ToString(); evalShaLargeScript += $"${id.Length}\r\n"; evalShaLargeScript += $"{id}\r\n"; diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index 3c80f1d136..f1ce7576eb 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -273,7 +273,7 @@ public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSe // TODO: custom allocator? state = new LuaStateWrapper(new Lua()); - + if (txnMode) { txnKeyEntries = new TxnKeyEntries(16, respServerSession.storageSession.lockableContext, respServerSession.storageSession.objectStoreLockableContext); @@ -1315,7 +1315,5 @@ static void WriteArray(LuaRunner runner, ref TResponse resp) runner.state.Pop(1); } } - - } } \ No newline at end of file diff --git a/libs/server/Lua/LuaStateWrapper.cs b/libs/server/Lua/LuaStateWrapper.cs index e4ebb0d611..4e91f06f86 100644 --- a/libs/server/Lua/LuaStateWrapper.cs +++ b/libs/server/Lua/LuaStateWrapper.cs @@ -518,4 +518,4 @@ private readonly void AssertLuaStackNotFull(int probe = 1) Debug.Assert((StackTop + probe) <= curStackSize, "Lua stack should have been grown before pushing"); } } -} +} \ No newline at end of file diff --git a/libs/server/Lua/NativeMethods.cs b/libs/server/Lua/NativeMethods.cs index 4bc8c9fd46..cc6bcf085b 100644 --- a/libs/server/Lua/NativeMethods.cs +++ b/libs/server/Lua/NativeMethods.cs @@ -19,8 +19,6 @@ namespace Garnet.server /// internal static partial class NativeMethods { - // TODO: Suppress GC transition (requires Lua audit) - private const string LuaLibraryName = "lua54"; /// @@ -249,4 +247,4 @@ internal static bool ToBoolean(lua_State luaState, int index) internal static void Pop(lua_State luaState, int num) => lua_settop(luaState, -num - 1); } -} +} \ No newline at end of file diff --git a/libs/server/Lua/SessionScriptCache.cs b/libs/server/Lua/SessionScriptCache.cs index 93499f84e1..4dfd1ac296 100644 --- a/libs/server/Lua/SessionScriptCache.cs +++ b/libs/server/Lua/SessionScriptCache.cs @@ -60,14 +60,7 @@ public bool TryGetFromDigest(ScriptHashKey digest, out LuaRunner scriptRunner) /// /// If necessary, will be set so the allocation can be reused. /// - internal bool TryLoad( - RespServerSession session, - ReadOnlySpan source, - ScriptHashKey digest, - out LuaRunner runner, - out ScriptHashKey? digestOnHeap, - out string error - ) + internal bool TryLoad(RespServerSession session, ReadOnlySpan source, ScriptHashKey digest, out LuaRunner runner, out ScriptHashKey? digestOnHeap, out string error) { error = null; @@ -94,7 +87,7 @@ out string error ScriptHashKey storeKeyDigest = new(into); digestOnHeap = storeKeyDigest; - + _ = scriptCache.TryAdd(storeKeyDigest, runner); } catch (Exception ex) From 4d73bceadfe8e63788106fa2e9c747cfd5785fcf Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Fri, 13 Dec 2024 17:33:34 -0500 Subject: [PATCH 50/51] formatting --- benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs | 2 +- benchmark/BDN.benchmark/Operations/ScriptOperations.cs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs b/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs index ca36c75870..c804e0d9ac 100644 --- a/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs +++ b/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs @@ -150,4 +150,4 @@ private void LoadScript(Span digest) } } } -} +} \ No newline at end of file diff --git a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs index 531ff5d16f..26436bfea9 100644 --- a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs +++ b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs @@ -195,7 +195,7 @@ public override void GlobalSetup() var smallScriptHash = string.Join("", SHA1.HashData(Encoding.UTF8.GetBytes(SmallScriptText)).Select(static x => x.ToString("x2"))); var evalShaSmallScript = $"*4\r\n$7\r\nEVALSHA\r\n$40\r\n{smallScriptHash}\r\n$1\r\n1\r\n$3\r\nfoo\r\n"; evalShaSmallScriptBuffer = GC.AllocateUninitializedArray(evalShaSmallScript.Length * batchSize, pinned: true); - for(var i = 0; i < batchSize; i++) + for (var i = 0; i < batchSize; i++) { var start = i * evalShaSmallScript.Length; Encoding.UTF8.GetBytes(evalShaSmallScript, evalShaSmallScriptBuffer.AsSpan().Slice(start, evalShaSmallScript.Length)); From 3afa3fa405052d6ec3aa935ddf78c68f4a62e6f3 Mon Sep 17 00:00:00 2001 From: Kevin Montrose Date: Wed, 18 Dec 2024 10:13:02 -0500 Subject: [PATCH 51/51] address feedback; spelling nits --- libs/server/Lua/LuaRunner.cs | 22 +++++++++++----------- libs/server/Lua/LuaStateWrapper.cs | 4 ++-- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs index f1ce7576eb..38851209da 100644 --- a/libs/server/Lua/LuaRunner.cs +++ b/libs/server/Lua/LuaRunner.cs @@ -230,9 +230,9 @@ function load_sandboxed(source) readonly int argvTableRegistryIndex; readonly int loadSandboxedRegistryIndex; readonly int resetKeysAndArgvRegistryIndex; - readonly int okConstStringRegisteryIndex; + readonly int okConstStringRegistryIndex; readonly int errConstStringRegistryIndex; - readonly int noSessionAvailableConstStringRegisteryIndex; + readonly int noSessionAvailableConstStringRegistryIndex; readonly int pleaseSpecifyRedisCallConstStringRegistryIndex; readonly int errNoAuthConstStringRegistryIndex; readonly int errUnknownConstStringRegistryIndex; @@ -316,13 +316,13 @@ public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSe resetKeysAndArgvRegistryIndex = state.Ref(); // Commonly used strings, register them once so we don't have to copy them over each time we need them - okConstStringRegisteryIndex = ConstantStringToRegistery(CmdStrings.LUA_OK); - errConstStringRegistryIndex = ConstantStringToRegistery(CmdStrings.LUA_err); - noSessionAvailableConstStringRegisteryIndex = ConstantStringToRegistery(CmdStrings.LUA_No_session_available); - pleaseSpecifyRedisCallConstStringRegistryIndex = ConstantStringToRegistery(CmdStrings.LUA_ERR_Please_specify_at_least_one_argument_for_this_redis_lib_call); - errNoAuthConstStringRegistryIndex = ConstantStringToRegistery(CmdStrings.RESP_ERR_NOAUTH); - errUnknownConstStringRegistryIndex = ConstantStringToRegistery(CmdStrings.LUA_ERR_Unknown_Redis_command_called_from_script); - errBadArgConstStringRegistryIndex = ConstantStringToRegistery(CmdStrings.LUA_ERR_Lua_redis_lib_command_arguments_must_be_strings_or_integers); + okConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_OK); + errConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_err); + noSessionAvailableConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_No_session_available); + pleaseSpecifyRedisCallConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_ERR_Please_specify_at_least_one_argument_for_this_redis_lib_call); + errNoAuthConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.RESP_ERR_NOAUTH); + errUnknownConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_ERR_Unknown_Redis_command_called_from_script); + errBadArgConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_ERR_Lua_redis_lib_command_arguments_must_be_strings_or_integers); state.ExpectLuaStackEmpty(); } @@ -340,7 +340,7 @@ public LuaRunner(string source, bool txnMode = false, RespServerSession respServ /// /// So instead we stash them in the Registry and load them by index /// - int ConstantStringToRegistery(ReadOnlySpan str) + int ConstantStringToRegistry(ReadOnlySpan str) { state.PushBuffer(str); return state.Ref(); @@ -560,7 +560,7 @@ unsafe int ProcessCommandFromScripting(TGarnetApi api) _ = api.SET(key, value); - state.PushConstantString(okConstStringRegisteryIndex); + state.PushConstantString(okConstStringRegistryIndex); return 1; } else if (AsciiUtils.EqualsUpperCaseSpanIgnoringCase(cmdSpan, "GET"u8) && argCount == 2) diff --git a/libs/server/Lua/LuaStateWrapper.cs b/libs/server/Lua/LuaStateWrapper.cs index 4e91f06f86..51ac6b0e0a 100644 --- a/libs/server/Lua/LuaStateWrapper.cs +++ b/libs/server/Lua/LuaStateWrapper.cs @@ -322,9 +322,9 @@ internal int Ref() /// Maintains and to minimize p/invoke calls. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal readonly void Unref(LuaRegistry registery, int reference) + internal readonly void Unref(LuaRegistry registry, int reference) { - state.Unref(registery, reference); + state.Unref(registry, reference); } ///