diff --git a/Directory.Packages.props b/Directory.Packages.props
index d674d273f0..0719a40ed8 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -8,7 +8,7 @@
-
+
diff --git a/benchmark/BDN.benchmark/BDN.benchmark.csproj b/benchmark/BDN.benchmark/BDN.benchmark.csproj
index 60753023ab..525b4edc72 100644
--- a/benchmark/BDN.benchmark/BDN.benchmark.csproj
+++ b/benchmark/BDN.benchmark/BDN.benchmark.csproj
@@ -10,7 +10,7 @@
-
+
diff --git a/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs b/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs
new file mode 100644
index 0000000000..ecd5002bfb
--- /dev/null
+++ b/benchmark/BDN.benchmark/Lua/LuaRunnerOperations.cs
@@ -0,0 +1,220 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+using BenchmarkDotNet.Attributes;
+using Embedded.perftest;
+using Garnet.server;
+
+namespace BDN.benchmark.Lua
+{
+ ///
+ /// Benchmark for non-script running operations in LuaRunner
+ ///
+ [MemoryDiagnoser]
+ public unsafe class LuaRunnerOperations
+ {
+ private const string SmallScript = "return nil";
+
+ private const string LargeScript = @"
+-- based on a real UDF, with some extraneous ops removed
+
+local userKey = KEYS[1]
+local identifier = ARGV[1]
+local currentTime = ARGV[2]
+local newSession = ARGV[3]
+
+local userKeyValue = redis.call(""GET"", userKey)
+
+local updatedValue = nil
+local returnValue = -1
+
+if userKeyValue then
+-- needs to be updated
+
+ local oldestEntry = nil
+ local oldestEntryUpdateTime = nil
+
+ local match = nil
+
+ local entryCount = 0
+
+-- loop over each entry, looking for one to update
+ for entry in string.gmatch(userKeyValue, ""([^%|]+)"") do
+ entryCount = entryCount + 1
+
+ local entryIdentifier = nil
+ local entrySessionNumber = -1
+ local entryRequestCount = -1
+ local entryLastSessionUpdateTime = -1
+
+ local ix = 0
+ for part in string.gmatch(entry, ""([^:]+)"") do
+ if ix == 0 then
+ entryIdentifier = part
+ elseif ix == 1 then
+ entrySessionNumber = tonumber(part)
+ elseif ix == 2 then
+ entryRequestCount = tonumber(part)
+ elseif ix == 3 then
+ entryLastSessionUpdateTime = tonumber(part)
+ else
+-- malformed, too many parts
+ return -1
+ end
+
+ ix = ix + 1
+ end
+
+ if ix ~= 4 then
+-- malformed, too few parts
+ return -2
+ end
+
+ if entryIdentifier == identifier then
+-- found the one to update
+ local updatedEntry = nil
+
+ if tonumber(newSession) == 1 then
+ local updatedSessionNumber = entrySessionNumber + 1
+ updatedEntry = entryIdentifier .. "":"" .. tostring(updatedSessionNumber) .. "":1:"" .. tostring(currentTime)
+ returnValue = 3
+ else
+ local updatedRequestCount = entryRequestCount + 1
+ updatedEntry = entryIdentifier .. "":"" .. tostring(entrySessionNumber) .. "":"" .. tostring(updatedRequestCount) .. "":"" .. tostring(currentTime)
+ returnValue = 2
+ end
+
+-- have to escape the replacement, since Lua doesn't have a literal replace :/
+ local escapedEntry = string.gsub(entry, ""%p"", ""%%%1"")
+ updatedValue = string.gsub(userKeyValue, escapedEntry, updatedEntry)
+
+ break
+ end
+
+ if oldestEntryUpdateTime == nil or oldestEntryUpdateTime > entryLastSessionUpdateTime then
+-- remember the oldest entry, so we can replace it if needed
+ oldestEntry = entry
+ oldestEntryUpdateTime = entryLastSessionUpdateTime
+ end
+ end
+
+ if updatedValue == nil then
+-- we didn't update an existing value, so we need to add it
+
+ local newEntry = identifier .. "":1:1:"" .. tostring(currentTime)
+
+ if entryCount < 20 then
+-- there's room, just append it
+ updatedValue = userKeyValue .. ""|"" .. newEntry
+ returnValue = 4
+ else
+-- there isn't room, replace the LRU entry
+
+-- have to escape the replacement, since Lua doesn't have a literal replace :/
+ local escapedOldestEntry = string.gsub(oldestEntry, ""%p"", ""%%%1"")
+
+ updatedValue = string.gsub(userKeyValue, escapedOldestEntry, newEntry)
+ returnValue = 5
+ end
+ end
+else
+-- needs to be created
+ updatedValue = identifier .. "":1:1:"" .. tostring(currentTime)
+
+ returnValue = 1
+end
+
+redis.call(""SET"", userKey, updatedValue)
+
+return returnValue
+";
+
+ ///
+ /// Lua parameters
+ ///
+ [ParamsSource(nameof(LuaParamsProvider))]
+ public LuaParams Params { get; set; }
+
+ ///
+ /// Lua parameters provider
+ ///
+ public IEnumerable LuaParamsProvider()
+ {
+ yield return new();
+ }
+
+ private EmbeddedRespServer server;
+ private RespServerSession session;
+
+ private LuaRunner paramsRunner;
+
+ private LuaRunner smallCompileRunner;
+ private LuaRunner largeCompileRunner;
+
+ [GlobalSetup]
+ public void GlobalSetup()
+ {
+ server = new EmbeddedRespServer(new GarnetServerOptions() { EnableLua = true, QuietMode = true });
+
+ session = server.GetRespSession();
+ paramsRunner = new LuaRunner("return nil");
+
+ smallCompileRunner = new LuaRunner(SmallScript);
+ largeCompileRunner = new LuaRunner(LargeScript);
+ }
+
+ [GlobalCleanup]
+ public void GlobalCleanup()
+ {
+ session.Dispose();
+ server.Dispose();
+ paramsRunner.Dispose();
+ }
+
+ [Benchmark]
+ public void ResetParametersSmall()
+ {
+ // First force up
+ paramsRunner.ResetParameters(1, 1);
+
+ // Then require a small amount of clearing (1 key, 1 arg)
+ paramsRunner.ResetParameters(0, 0);
+ }
+
+ [Benchmark]
+ public void ResetParametersLarge()
+ {
+ // First force up
+ paramsRunner.ResetParameters(10, 10);
+
+ // Then require a large amount of clearing (10 keys, 10 args)
+ paramsRunner.ResetParameters(0, 0);
+ }
+
+ [Benchmark]
+ public void ConstructSmall()
+ {
+ using var runner = new LuaRunner(SmallScript);
+ }
+
+ [Benchmark]
+ public void ConstructLarge()
+ {
+ using var runner = new LuaRunner(LargeScript);
+ }
+
+ [Benchmark]
+ public void CompileForSessionSmall()
+ {
+ smallCompileRunner.ResetCompilation();
+ smallCompileRunner.CompileForSession(session);
+ }
+
+ [Benchmark]
+ public void CompileForSessionLarge()
+ {
+ largeCompileRunner.ResetCompilation();
+ largeCompileRunner.CompileForSession(session);
+ }
+ }
+}
\ No newline at end of file
diff --git a/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs b/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs
new file mode 100644
index 0000000000..c804e0d9ac
--- /dev/null
+++ b/benchmark/BDN.benchmark/Lua/LuaScriptCacheOperations.cs
@@ -0,0 +1,153 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+using BenchmarkDotNet.Attributes;
+using Embedded.perftest;
+using Garnet.common;
+using Garnet.server;
+using Garnet.server.Auth;
+
+namespace BDN.benchmark.Lua
+{
+ [MemoryDiagnoser]
+ public class LuaScriptCacheOperations
+ {
+ ///
+ /// Lua parameters
+ ///
+ [ParamsSource(nameof(LuaParamsProvider))]
+ public LuaParams Params { get; set; }
+
+ ///
+ /// Lua parameters provider
+ ///
+ public IEnumerable LuaParamsProvider()
+ {
+ yield return new();
+ }
+
+ private EmbeddedRespServer server;
+ private StoreWrapper storeWrapper;
+ private SessionScriptCache sessionScriptCache;
+ private RespServerSession session;
+
+ private byte[] outerHitDigest;
+ private byte[] innerHitDigest;
+ private byte[] missDigest;
+
+ [GlobalSetup]
+ public void GlobalSetup()
+ {
+ server = new EmbeddedRespServer(new GarnetServerOptions() { EnableLua = true, QuietMode = true });
+ storeWrapper = server.StoreWrapper;
+ sessionScriptCache = new SessionScriptCache(storeWrapper, new GarnetNoAuthAuthenticator());
+ session = server.GetRespSession();
+
+ outerHitDigest = GC.AllocateUninitializedArray(SessionScriptCache.SHA1Len, pinned: true);
+ sessionScriptCache.GetScriptDigest("return 1"u8, outerHitDigest);
+ if (!storeWrapper.storeScriptCache.TryAdd(new(outerHitDigest), "return 1"u8.ToArray()))
+ {
+ throw new InvalidOperationException("Should have been able to load into global cache");
+ }
+
+ innerHitDigest = GC.AllocateUninitializedArray(SessionScriptCache.SHA1Len, pinned: true);
+ sessionScriptCache.GetScriptDigest("return 1 + 1"u8, innerHitDigest);
+ if (!storeWrapper.storeScriptCache.TryAdd(new(innerHitDigest), "return 1 + 1"u8.ToArray()))
+ {
+ throw new InvalidOperationException("Should have been able to load into global cache");
+ }
+
+ missDigest = GC.AllocateUninitializedArray(SessionScriptCache.SHA1Len, pinned: true);
+ sessionScriptCache.GetScriptDigest("foobar"u8, missDigest);
+ }
+
+ [GlobalCleanup]
+ public void GlobalCleanup()
+ {
+ session?.Dispose();
+ server?.Dispose();
+ }
+
+ [IterationSetup]
+ public void IterationSetup()
+ {
+ // Force lookup to do work
+ sessionScriptCache.Clear();
+
+ // Make outer hit available for every iteration
+ if (!sessionScriptCache.TryLoad(session, "return 1"u8, new(outerHitDigest), out _, out _, out var error))
+ {
+ throw new InvalidOperationException($"Should have been able to load: {error}");
+ }
+ }
+
+ [Benchmark]
+ public void LookupHit()
+ {
+ _ = sessionScriptCache.TryGetFromDigest(new(outerHitDigest), out _);
+ }
+
+ [Benchmark]
+ public void LookupMiss()
+ {
+ _ = sessionScriptCache.TryGetFromDigest(new(missDigest), out _);
+ }
+
+ [Benchmark]
+ public void LoadOuterHit()
+ {
+ // First if returns true
+ //
+ // This is the common case
+ LoadScript(outerHitDigest);
+ }
+
+ [Benchmark]
+ public void LoadInnerHit()
+ {
+ // First if returns false, second if returns true
+ //
+ // This is expected, but rare
+ LoadScript(innerHitDigest);
+ }
+
+ [Benchmark]
+ public void LoadMiss()
+ {
+ // First if returns false, second if returns false
+ //
+ // This is extremely unlikely, basically implies an error on the client
+ LoadScript(missDigest);
+ }
+
+ [Benchmark]
+ public void Digest()
+ {
+ Span digest = stackalloc byte[SessionScriptCache.SHA1Len];
+ sessionScriptCache.GetScriptDigest("return 1 + redis.call('GET', KEYS[1])"u8, digest);
+ }
+
+ ///
+ /// The moral equivalent to our cache load operation.
+ ///
+ private void LoadScript(Span digest)
+ {
+ AsciiUtils.ToLowerInPlace(digest);
+
+ var digestKey = new ScriptHashKey(digest);
+
+ if (!sessionScriptCache.TryGetFromDigest(digestKey, out var runner))
+ {
+ if (storeWrapper.storeScriptCache.TryGetValue(digestKey, out var source))
+ {
+ if (!sessionScriptCache.TryLoad(session, source, digestKey, out runner, out _, out var error))
+ {
+ // TryLoad will have written an error out, it any
+
+ _ = storeWrapper.storeScriptCache.TryRemove(digestKey, out _);
+ }
+ }
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/benchmark/BDN.benchmark/Lua/LuaScripts.cs b/benchmark/BDN.benchmark/Lua/LuaScripts.cs
index 5c6060d2e6..3cff4b30e3 100644
--- a/benchmark/BDN.benchmark/Lua/LuaScripts.cs
+++ b/benchmark/BDN.benchmark/Lua/LuaScripts.cs
@@ -2,7 +2,6 @@
// Licensed under the MIT license.
using BenchmarkDotNet.Attributes;
-using BenchmarkDotNet.Columns;
using Garnet.server;
namespace BDN.benchmark.Lua
@@ -11,7 +10,6 @@ namespace BDN.benchmark.Lua
/// Benchmark for Lua
///
[MemoryDiagnoser]
- [HideColumns(Column.Gen0)]
public unsafe class LuaScripts
{
///
@@ -35,13 +33,13 @@ public IEnumerable LuaParamsProvider()
public void GlobalSetup()
{
r1 = new LuaRunner("return");
- r1.Compile();
+ r1.CompileForRunner();
r2 = new LuaRunner("return 1 + 1");
- r2.Compile();
+ r2.CompileForRunner();
r3 = new LuaRunner("return KEYS[1]");
- r3.Compile();
+ r3.CompileForRunner();
r4 = new LuaRunner("return redis.call(KEYS[1])");
- r4.Compile();
+ r4.CompileForRunner();
}
[GlobalCleanup]
@@ -55,18 +53,18 @@ public void GlobalCleanup()
[Benchmark]
public void Script1()
- => r1.Run();
+ => r1.RunForRunner();
[Benchmark]
public void Script2()
- => r2.Run();
+ => r2.RunForRunner();
[Benchmark]
public void Script3()
- => r3.Run(keys, null);
+ => r3.RunForRunner(keys, null);
[Benchmark]
public void Script4()
- => r4.Run(keys, null);
+ => r4.RunForRunner(keys, null);
}
}
\ No newline at end of file
diff --git a/benchmark/BDN.benchmark/Operations/OperationsBase.cs b/benchmark/BDN.benchmark/Operations/OperationsBase.cs
index 8d58631fe9..be4fa09608 100644
--- a/benchmark/BDN.benchmark/Operations/OperationsBase.cs
+++ b/benchmark/BDN.benchmark/Operations/OperationsBase.cs
@@ -39,7 +39,7 @@ public IEnumerable OperationParamsProvider()
/// 25 us = 4 Mops/sec
/// 100 us = 1 Mops/sec
///
- const int batchSize = 100;
+ internal const int batchSize = 100;
internal EmbeddedRespServer server;
internal RespServerSession session;
diff --git a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs
index f068a7c0ac..26436bfea9 100644
--- a/benchmark/BDN.benchmark/Operations/ScriptOperations.cs
+++ b/benchmark/BDN.benchmark/Operations/ScriptOperations.cs
@@ -1,6 +1,9 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
+using System.Runtime.CompilerServices;
+using System.Security.Cryptography;
+using System.Text;
using BenchmarkDotNet.Attributes;
namespace BDN.benchmark.Operations
@@ -11,6 +14,124 @@ namespace BDN.benchmark.Operations
[MemoryDiagnoser]
public unsafe class ScriptOperations : OperationsBase
{
+ // Small script that does 1 operation and no logic
+ const string SmallScriptText = @"return redis.call('GET', KEYS[1]);";
+
+ // Large script that does 2 operations and lots of logic
+ const string LargeScriptText = @"
+-- based on a real UDF, with some extraneous ops removed
+
+local userKey = KEYS[1]
+local identifier = ARGV[1]
+local currentTime = ARGV[2]
+local newSession = ARGV[3]
+
+local userKeyValue = redis.call(""GET"", userKey)
+
+local updatedValue = nil
+local returnValue = -1
+
+if userKeyValue then
+-- needs to be updated
+
+ local oldestEntry = nil
+ local oldestEntryUpdateTime = nil
+
+ local match = nil
+
+ local entryCount = 0
+
+-- loop over each entry, looking for one to update
+ for entry in string.gmatch(userKeyValue, ""([^%|]+)"") do
+ entryCount = entryCount + 1
+
+ local entryIdentifier = nil
+ local entrySessionNumber = -1
+ local entryRequestCount = -1
+ local entryLastSessionUpdateTime = -1
+
+ local ix = 0
+ for part in string.gmatch(entry, ""([^:]+)"") do
+ if ix == 0 then
+ entryIdentifier = part
+ elseif ix == 1 then
+ entrySessionNumber = tonumber(part)
+ elseif ix == 2 then
+ entryRequestCount = tonumber(part)
+ elseif ix == 3 then
+ entryLastSessionUpdateTime = tonumber(part)
+ else
+-- malformed, too many parts
+ return -1
+ end
+
+ ix = ix + 1
+ end
+
+ if ix ~= 4 then
+-- malformed, too few parts
+ return -2
+ end
+
+ if entryIdentifier == identifier then
+-- found the one to update
+ local updatedEntry = nil
+
+ if tonumber(newSession) == 1 then
+ local updatedSessionNumber = entrySessionNumber + 1
+ updatedEntry = entryIdentifier .. "":"" .. tostring(updatedSessionNumber) .. "":1:"" .. tostring(currentTime)
+ returnValue = 3
+ else
+ local updatedRequestCount = entryRequestCount + 1
+ updatedEntry = entryIdentifier .. "":"" .. tostring(entrySessionNumber) .. "":"" .. tostring(updatedRequestCount) .. "":"" .. tostring(currentTime)
+ returnValue = 2
+ end
+
+-- have to escape the replacement, since Lua doesn't have a literal replace :/
+ local escapedEntry = string.gsub(entry, ""%p"", ""%%%1"")
+ updatedValue = string.gsub(userKeyValue, escapedEntry, updatedEntry)
+
+ break
+ end
+
+ if oldestEntryUpdateTime == nil or oldestEntryUpdateTime > entryLastSessionUpdateTime then
+-- remember the oldest entry, so we can replace it if needed
+ oldestEntry = entry
+ oldestEntryUpdateTime = entryLastSessionUpdateTime
+ end
+ end
+
+ if updatedValue == nil then
+-- we didn't update an existing value, so we need to add it
+
+ local newEntry = identifier .. "":1:1:"" .. tostring(currentTime)
+
+ if entryCount < 20 then
+-- there's room, just append it
+ updatedValue = userKeyValue .. ""|"" .. newEntry
+ returnValue = 4
+ else
+-- there isn't room, replace the LRU entry
+
+-- have to escape the replacement, since Lua doesn't have a literal replace :/
+ local escapedOldestEntry = string.gsub(oldestEntry, ""%p"", ""%%%1"")
+
+ updatedValue = string.gsub(userKeyValue, escapedOldestEntry, newEntry)
+ returnValue = 5
+ end
+ end
+else
+-- needs to be created
+ updatedValue = identifier .. "":1:1:"" .. tostring(currentTime)
+
+ returnValue = 1
+end
+
+redis.call(""SET"", userKey, updatedValue)
+
+return returnValue
+";
+
static ReadOnlySpan SCRIPT_LOAD => "*3\r\n$6\r\nSCRIPT\r\n$4\r\nLOAD\r\n$8\r\nreturn 1\r\n"u8;
byte[] scriptLoadRequestBuffer;
byte* scriptLoadRequestBufferPointer;
@@ -33,6 +154,16 @@ public unsafe class ScriptOperations : OperationsBase
byte[] evalShaRequestBuffer;
byte* evalShaRequestBufferPointer;
+ byte[] evalShaSmallScriptBuffer;
+ byte* evalShaSmallScriptBufferPointer;
+
+ byte[] evalShaLargeScriptBuffer;
+ byte* evalShaLargeScriptBufferPointer;
+
+ static ReadOnlySpan ARRAY_RETURN => "*3\r\n$4\r\nEVAL\r\n$22\r\nreturn {1, 2, 3, 4, 5}\r\n$1\r\n0\r\n"u8;
+ byte[] arrayReturnRequestBuffer;
+ byte* arrayReturnRequestBufferPointer;
+
public override void GlobalSetup()
{
base.GlobalSetup();
@@ -50,6 +181,59 @@ public override void GlobalSetup()
SetupOperation(ref evalRequestBuffer, ref evalRequestBufferPointer, EVAL);
SetupOperation(ref evalShaRequestBuffer, ref evalShaRequestBufferPointer, EVALSHA);
+
+ SetupOperation(ref arrayReturnRequestBuffer, ref arrayReturnRequestBufferPointer, ARRAY_RETURN);
+
+ // Setup small script
+ var loadSmallScript = $"*3\r\n$6\r\nSCRIPT\r\n$4\r\nLOAD\r\n${SmallScriptText.Length}\r\n{SmallScriptText}\r\n";
+ var loadSmallScriptBytes = Encoding.UTF8.GetBytes(loadSmallScript);
+ fixed (byte* loadPtr = loadSmallScriptBytes)
+ {
+ _ = session.TryConsumeMessages(loadPtr, loadSmallScriptBytes.Length);
+ }
+
+ var smallScriptHash = string.Join("", SHA1.HashData(Encoding.UTF8.GetBytes(SmallScriptText)).Select(static x => x.ToString("x2")));
+ var evalShaSmallScript = $"*4\r\n$7\r\nEVALSHA\r\n$40\r\n{smallScriptHash}\r\n$1\r\n1\r\n$3\r\nfoo\r\n";
+ evalShaSmallScriptBuffer = GC.AllocateUninitializedArray(evalShaSmallScript.Length * batchSize, pinned: true);
+ for (var i = 0; i < batchSize; i++)
+ {
+ var start = i * evalShaSmallScript.Length;
+ Encoding.UTF8.GetBytes(evalShaSmallScript, evalShaSmallScriptBuffer.AsSpan().Slice(start, evalShaSmallScript.Length));
+ }
+ evalShaSmallScriptBufferPointer = (byte*)Unsafe.AsPointer(ref evalShaSmallScriptBuffer[0]);
+
+ // Setup large script
+ var loadLargeScript = $"*3\r\n$6\r\nSCRIPT\r\n$4\r\nLOAD\r\n${LargeScriptText.Length}\r\n{LargeScriptText}\r\n";
+ var loadLargeScriptBytes = Encoding.UTF8.GetBytes(loadLargeScript);
+ fixed (byte* loadPtr = loadLargeScriptBytes)
+ {
+ _ = session.TryConsumeMessages(loadPtr, loadLargeScriptBytes.Length);
+ }
+
+ var largeScriptHash = string.Join("", SHA1.HashData(Encoding.UTF8.GetBytes(LargeScriptText)).Select(static x => x.ToString("x2")));
+ var largeScriptEvals = new List();
+ for (var i = 0; i < batchSize; i++)
+ {
+ var evalShaLargeScript = $"*7\r\n$7\r\nEVALSHA\r\n$40\r\n{largeScriptHash}\r\n$1\r\n1\r\n$5\r\nhello\r\n";
+
+ var id = Guid.NewGuid().ToString();
+ evalShaLargeScript += $"${id.Length}\r\n";
+ evalShaLargeScript += $"{id}\r\n";
+
+ var time = (i * 10).ToString();
+ evalShaLargeScript += $"${time.Length}\r\n";
+ evalShaLargeScript += $"{time}\r\n";
+
+ var newSession = i % 2;
+ evalShaLargeScript += "$1\r\n";
+ evalShaLargeScript += $"{newSession}\r\n";
+
+ var asBytes = Encoding.UTF8.GetBytes(evalShaLargeScript);
+ largeScriptEvals.AddRange(asBytes);
+ }
+ evalShaLargeScriptBuffer = GC.AllocateUninitializedArray(largeScriptEvals.Count, pinned: true);
+ largeScriptEvals.CopyTo(evalShaLargeScriptBuffer);
+ evalShaLargeScriptBufferPointer = (byte*)Unsafe.AsPointer(ref evalShaLargeScriptBuffer[0]);
}
[Benchmark]
@@ -81,5 +265,23 @@ public void EvalSha()
{
_ = session.TryConsumeMessages(evalShaRequestBufferPointer, evalShaRequestBuffer.Length);
}
+
+ [Benchmark]
+ public void SmallScript()
+ {
+ _ = session.TryConsumeMessages(evalShaSmallScriptBufferPointer, evalShaSmallScriptBuffer.Length);
+ }
+
+ [Benchmark]
+ public void LargeScript()
+ {
+ _ = session.TryConsumeMessages(evalShaLargeScriptBufferPointer, evalShaLargeScriptBuffer.Length);
+ }
+
+ [Benchmark]
+ public void ArrayReturn()
+ {
+ _ = session.TryConsumeMessages(arrayReturnRequestBufferPointer, arrayReturnRequestBuffer.Length);
+ }
}
}
\ No newline at end of file
diff --git a/libs/common/RespReadUtils.cs b/libs/common/RespReadUtils.cs
index 8b93be12e0..a82ed7ee87 100644
--- a/libs/common/RespReadUtils.cs
+++ b/libs/common/RespReadUtils.cs
@@ -957,50 +957,6 @@ public static bool ReadStringArrayWithLengthHeader(out string[] result, ref byte
return true;
}
- ///
- /// Read string array with length header
- ///
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- public static bool ReadStringArrayResponseWithLengthHeader(out string[] result, ref byte* ptr, byte* end)
- {
- result = null;
-
- // Parse RESP array header
- if (!ReadSignedArrayLength(out var length, ref ptr, end))
- {
- return false;
- }
-
- if (length < 0)
- {
- // NULL value ('*-1\r\n')
- return true;
- }
-
- // Parse individual strings in the array
- result = new string[length];
- for (var i = 0; i < length; i++)
- {
- if (*ptr == '$')
- {
- if (!ReadStringResponseWithLengthHeader(out result[i], ref ptr, end))
- return false;
- }
- else if (*ptr == '+')
- {
- if (!ReadSimpleString(out result[i], ref ptr, end))
- return false;
- }
- else
- {
- if (!ReadIntegerAsString(out result[i], ref ptr, end))
- return false;
- }
- }
-
- return true;
- }
-
///
/// Read double with length header
///
diff --git a/libs/host/Garnet.host.csproj b/libs/host/Garnet.host.csproj
index a4d034dd85..1c1259fec3 100644
--- a/libs/host/Garnet.host.csproj
+++ b/libs/host/Garnet.host.csproj
@@ -29,7 +29,7 @@
-
+
diff --git a/libs/server/ArgSlice/ScratchBufferManager.cs b/libs/server/ArgSlice/ScratchBufferManager.cs
index 68861d3985..e2106bfaa7 100644
--- a/libs/server/ArgSlice/ScratchBufferManager.cs
+++ b/libs/server/ArgSlice/ScratchBufferManager.cs
@@ -7,7 +7,6 @@
using System.Runtime.CompilerServices;
using System.Text;
using Garnet.common;
-using NLua;
namespace Garnet.server
{
@@ -43,6 +42,12 @@ public ScratchBufferManager()
///
public void Reset() => scratchBufferOffset = 0;
+ ///
+ /// Return the full buffer managed by this .
+ ///
+ public Span FullBuffer()
+ => scratchBuffer;
+
///
/// Rewind (pop) the last entry of scratch buffer (rewinding the current scratch buffer offset),
/// if it contains the given ArgSlice
@@ -218,95 +223,62 @@ public ArgSlice FormatScratch(int headerSize, ReadOnlySpan arg)
}
///
- /// Format specified command with arguments, as a RESP command. Lua state
- /// can be specified to handle Lua tables as arguments.
+ /// Start a RESP array to hold a command and arguments.
+ ///
+ /// Fill it with calls to and/or .
///
- public ArgSlice FormatCommandAsResp(string cmd, object[] args, Lua state)
+ public void StartCommand(ReadOnlySpan cmd, int argCount)
{
if (scratchBuffer == null)
ExpandScratchBuffer(64);
- scratchBufferOffset += 10; // Reserve space for the array length if it is larger than expected
- int commandStartOffset = scratchBufferOffset;
- byte* ptr = scratchBufferHead + scratchBufferOffset;
+ var ptr = scratchBufferHead + scratchBufferOffset;
- while (!RespWriteUtils.WriteArrayLength(args.Length + 1, ref ptr, scratchBufferHead + scratchBuffer.Length))
+ while (!RespWriteUtils.WriteArrayLength(argCount + 1, ref ptr, scratchBufferHead + scratchBuffer.Length))
{
ExpandScratchBuffer(scratchBuffer.Length + 1);
ptr = scratchBufferHead + scratchBufferOffset;
}
scratchBufferOffset = (int)(ptr - scratchBufferHead);
- while (!RespWriteUtils.WriteAsciiBulkString(cmd, ref ptr, scratchBufferHead + scratchBuffer.Length))
+ while (!RespWriteUtils.WriteBulkString(cmd, ref ptr, scratchBufferHead + scratchBuffer.Length))
{
ExpandScratchBuffer(scratchBuffer.Length + 1);
ptr = scratchBufferHead + scratchBufferOffset;
}
scratchBufferOffset = (int)(ptr - scratchBufferHead);
+ }
- int count = 1;
- foreach (var item in args)
+ ///
+ /// Use to fill a RESP array with arguments after a call to .
+ ///
+ public void WriteNullArgument()
+ {
+ var ptr = scratchBufferHead + scratchBufferOffset;
+
+ while (!RespWriteUtils.WriteNull(ref ptr, scratchBufferHead + scratchBuffer.Length))
{
- if (item is string str)
- {
- count++;
- while (!RespWriteUtils.WriteAsciiBulkString(str, ref ptr, scratchBufferHead + scratchBuffer.Length))
- {
- ExpandScratchBuffer(scratchBuffer.Length + 1);
- ptr = scratchBufferHead + scratchBufferOffset;
- }
- scratchBufferOffset = (int)(ptr - scratchBufferHead);
- }
- else if (item is LuaTable t)
- {
- var d = state.GetTableDict(t);
- foreach (var value in d.Values)
- {
- count++;
- while (!RespWriteUtils.WriteAsciiBulkString(Convert.ToString(value), ref ptr, scratchBufferHead + scratchBuffer.Length))
- {
- ExpandScratchBuffer(scratchBuffer.Length + 1);
- ptr = scratchBufferHead + scratchBufferOffset;
- }
- scratchBufferOffset = (int)(ptr - scratchBufferHead);
- }
- t.Dispose();
- }
- else if (item is long i)
- {
- count++;
- while (!RespWriteUtils.WriteIntegerAsBulkString((int)i, ref ptr, scratchBufferHead + scratchBuffer.Length))
- {
- ExpandScratchBuffer(scratchBuffer.Length + 1);
- ptr = scratchBufferHead + scratchBufferOffset;
- }
- scratchBufferOffset = (int)(ptr - scratchBufferHead);
- }
- else
- {
- count++;
- while (!RespWriteUtils.WriteAsciiBulkString(Convert.ToString(item), ref ptr, scratchBufferHead + scratchBuffer.Length))
- {
- ExpandScratchBuffer(scratchBuffer.Length + 1);
- ptr = scratchBufferHead + scratchBufferOffset;
- }
- scratchBufferOffset = (int)(ptr - scratchBufferHead);
- }
+ ExpandScratchBuffer(scratchBuffer.Length + 1);
+ ptr = scratchBufferHead + scratchBufferOffset;
}
- if (count != args.Length + 1)
+
+ scratchBufferOffset = (int)(ptr - scratchBufferHead);
+ }
+
+ ///
+ /// Use to fill a RESP array with arguments after a call to .
+ ///
+ public void WriteArgument(ReadOnlySpan arg)
+ {
+ var ptr = scratchBufferHead + scratchBufferOffset;
+
+ while (!RespWriteUtils.WriteBulkString(arg, ref ptr, scratchBufferHead + scratchBuffer.Length))
{
- int extraSpace = NumUtils.NumDigits(count) - NumUtils.NumDigits(args.Length + 1);
- if (commandStartOffset < extraSpace)
- throw new InvalidOperationException("Invalid number of arguments");
- commandStartOffset -= extraSpace;
- var head = scratchBufferHead + commandStartOffset;
- // There should be space as we have reserved it
- _ = RespWriteUtils.WriteArrayLength(count, ref head, scratchBufferHead + scratchBuffer.Length);
+ ExpandScratchBuffer(scratchBuffer.Length + 1);
+ ptr = scratchBufferHead + scratchBufferOffset;
}
- var retVal = new ArgSlice(scratchBufferHead + commandStartOffset, scratchBufferOffset - commandStartOffset);
- Debug.Assert(scratchBufferOffset <= scratchBuffer.Length);
- return retVal;
+ scratchBufferOffset = (int)(ptr - scratchBufferHead);
}
///
@@ -351,5 +323,20 @@ public ArgSlice GetSliceFromTail(int length)
{
return new ArgSlice(scratchBufferHead + scratchBufferOffset - length, length);
}
+
+ ///
+ /// Force backing buffer to grow.
+ ///
+ public void GrowBuffer()
+ {
+ if (scratchBuffer == null)
+ {
+ ExpandScratchBuffer(64);
+ }
+ else
+ {
+ ExpandScratchBuffer(scratchBuffer.Length + 1);
+ }
+ }
}
}
\ No newline at end of file
diff --git a/libs/server/Garnet.server.csproj b/libs/server/Garnet.server.csproj
index e3abc5b5c2..04f91c2909 100644
--- a/libs/server/Garnet.server.csproj
+++ b/libs/server/Garnet.server.csproj
@@ -21,7 +21,7 @@
-
+
\ No newline at end of file
diff --git a/libs/server/Lua/LuaCommands.cs b/libs/server/Lua/LuaCommands.cs
index f306262719..f9a4df47f8 100644
--- a/libs/server/Lua/LuaCommands.cs
+++ b/libs/server/Lua/LuaCommands.cs
@@ -2,11 +2,9 @@
// Licensed under the MIT license.
using System;
-using System.Buffers;
+using System.Collections.Generic;
using Garnet.common;
using Microsoft.Extensions.Logging;
-using NLua;
-using NLua.Exceptions;
using Tsavorite.core;
namespace Garnet.server
@@ -31,22 +29,27 @@ private unsafe bool TryEVALSHA()
}
ref var digest = ref parseState.GetArgSliceByRef(0);
- AsciiUtils.ToLowerInPlace(digest.Span);
- var digestAsSpanByteMem = new SpanByteAndMemory(digest.SpanByte);
+ LuaRunner runner = null;
- var result = false;
- if (!sessionScriptCache.TryGetFromDigest(digestAsSpanByteMem, out var runner))
+ // Length check is mandatory, as ScriptHashKey assumes correct length
+ if (digest.length == SessionScriptCache.SHA1Len)
{
- if (storeWrapper.storeScriptCache.TryGetValue(digestAsSpanByteMem, out var source))
+ AsciiUtils.ToLowerInPlace(digest.Span);
+
+ var scriptKey = new ScriptHashKey(digest.Span);
+
+ if (!sessionScriptCache.TryGetFromDigest(scriptKey, out runner))
{
- if (!sessionScriptCache.TryLoad(source, digestAsSpanByteMem, out runner, out var error))
+ if (storeWrapper.storeScriptCache.TryGetValue(scriptKey, out var source))
{
- while (!RespWriteUtils.WriteError(error, ref dcurr, dend))
- SendAndReset();
+ if (!sessionScriptCache.TryLoad(this, source, scriptKey, out runner, out _, out var error))
+ {
+ // TryLoad will have written an error out, it any
- _ = storeWrapper.storeScriptCache.TryRemove(digestAsSpanByteMem, out _);
- return result;
+ _ = storeWrapper.storeScriptCache.TryRemove(scriptKey, out _);
+ return true;
+ }
}
}
}
@@ -58,10 +61,10 @@ private unsafe bool TryEVALSHA()
}
else
{
- result = ExecuteScript(count - 1, runner);
+ ExecuteScript(count - 1, runner);
}
- return result;
+ return true;
}
@@ -82,19 +85,16 @@ private unsafe bool TryEVAL()
return AbortWithWrongNumberOfArguments("EVAL");
}
- var script = parseState.GetArgSliceByRef(0).ToArray();
+ ref var script = ref parseState.GetArgSliceByRef(0);
// that this is stack allocated is load bearing - if it moves, things will break
Span digest = stackalloc byte[SessionScriptCache.SHA1Len];
- sessionScriptCache.GetScriptDigest(script, digest);
+ sessionScriptCache.GetScriptDigest(script.ReadOnlySpan, digest);
- var result = false;
- if (!sessionScriptCache.TryLoad(script, new SpanByteAndMemory(SpanByte.FromPinnedSpan(digest)), out var runner, out var error))
+ if (!sessionScriptCache.TryLoad(this, script.ReadOnlySpan, new ScriptHashKey(digest), out var runner, out _, out var error))
{
- while (!RespWriteUtils.WriteError(error, ref dcurr, dend))
- SendAndReset();
-
- return result;
+ // TryLoad will have written any errors out
+ return true;
}
if (runner == null)
@@ -104,10 +104,10 @@ private unsafe bool TryEVAL()
}
else
{
- result = ExecuteScript(count - 1, runner);
+ ExecuteScript(count - 1, runner);
}
- return result;
+ return true;
}
///
@@ -125,7 +125,7 @@ private bool NetworkScriptExists()
return AbortWithWrongNumberOfArguments("script|exists");
}
- // returns an array where each element is a 0 if the script does not exist, and a 1 if it does
+ // Returns an array where each element is a 0 if the script does not exist, and a 1 if it does
while (!RespWriteUtils.WriteArrayLength(parseState.Count, ref dcurr, dend))
SendAndReset();
@@ -133,11 +133,17 @@ private bool NetworkScriptExists()
for (var shaIx = 0; shaIx < parseState.Count; shaIx++)
{
ref var sha1 = ref parseState.GetArgSliceByRef(shaIx);
- AsciiUtils.ToLowerInPlace(sha1.Span);
+ var exists = 0;
- var sha1Arg = new SpanByteAndMemory(sha1.SpanByte);
+ // Length check is required, as ScriptHashKey makes a hard assumption
+ if (sha1.length == SessionScriptCache.SHA1Len)
+ {
+ AsciiUtils.ToLowerInPlace(sha1.Span);
+
+ var sha1Arg = new ScriptHashKey(sha1.Span);
- var exists = storeWrapper.storeScriptCache.ContainsKey(sha1Arg) ? 1 : 0;
+ exists = storeWrapper.storeScriptCache.ContainsKey(sha1Arg) ? 1 : 0;
+ }
while (!RespWriteUtils.WriteArrayItem(exists, ref dcurr, dend))
SendAndReset();
@@ -202,18 +208,26 @@ private bool NetworkScriptLoad()
return AbortWithWrongNumberOfArguments("script|load");
}
- var source = parseState.GetArgSliceByRef(0).ToArray();
- if (!sessionScriptCache.TryLoad(source, out var digest, out _, out var error))
- {
- while (!RespWriteUtils.WriteError(error, ref dcurr, dend))
- SendAndReset();
- }
- else
+ ref var source = ref parseState.GetArgSliceByRef(0);
+
+ Span digest = stackalloc byte[SessionScriptCache.SHA1Len];
+ sessionScriptCache.GetScriptDigest(source.Span, digest);
+
+ if (sessionScriptCache.TryLoad(this, source.ReadOnlySpan, new(digest), out _, out var digestOnHeap, out var error))
{
+ // TryLoad will write any errors out
// Add script to the store dictionary
- var scriptKey = new SpanByteAndMemory(new ScriptHashOwner(digest.AsMemory()), digest.Length);
- _ = storeWrapper.storeScriptCache.TryAdd(scriptKey, source);
+ if (digestOnHeap == null)
+ {
+ var newAlloc = GC.AllocateUninitializedArray(SessionScriptCache.SHA1Len, pinned: true);
+ digest.CopyTo(newAlloc);
+ _ = storeWrapper.storeScriptCache.TryAdd(new(newAlloc), source.ToArray());
+ }
+ else
+ {
+ _ = storeWrapper.storeScriptCache.TryAdd(digestOnHeap.Value, source.ToArray());
+ }
while (!RespWriteUtils.WriteBulkString(digest, ref dcurr, dend))
SendAndReset();
@@ -243,120 +257,17 @@ private bool CheckLuaEnabled()
///
/// Invoke the execution of a server-side Lua script.
///
- ///
- ///
- ///
- private unsafe bool ExecuteScript(int count, LuaRunner scriptRunner)
+ private void ExecuteScript(int count, LuaRunner scriptRunner)
{
try
{
- var scriptResult = scriptRunner.Run(count, parseState);
- WriteObject(scriptResult);
- }
- catch (LuaScriptException ex)
- {
- logger?.LogError(ex.InnerException ?? ex, "Error executing Lua script callback");
- while (!RespWriteUtils.WriteError("ERR " + (ex.InnerException ?? ex).Message, ref dcurr, dend))
- SendAndReset();
- return true;
+ scriptRunner.RunForSession(count, this);
}
catch (Exception ex)
{
logger?.LogError(ex, "Error executing Lua script");
while (!RespWriteUtils.WriteError("ERR " + ex.Message, ref dcurr, dend))
SendAndReset();
- return true;
- }
- return true;
- }
-
- void WriteObject(object scriptResult)
- {
- if (scriptResult != null)
- {
- if (scriptResult is string s)
- {
- while (!RespWriteUtils.WriteAsciiBulkString(s, ref dcurr, dend))
- SendAndReset();
- }
- else if ((scriptResult as byte?) != null && (byte)scriptResult == 36) //equals to $
- {
- while (!RespWriteUtils.WriteDirect((byte[])scriptResult, ref dcurr, dend))
- SendAndReset();
- }
- else if (scriptResult is bool b)
- {
- if (b)
- {
- while (!RespWriteUtils.WriteInteger(1, ref dcurr, dend))
- SendAndReset();
- }
- else
- {
- while (!RespWriteUtils.WriteDirect(CmdStrings.RESP_ERRNOTFOUND, ref dcurr, dend))
- SendAndReset();
- }
- }
- else if (scriptResult is long l)
- {
- while (!RespWriteUtils.WriteInteger(l, ref dcurr, dend))
- SendAndReset();
- }
- else if (scriptResult is ArgSlice a)
- {
- while (!RespWriteUtils.WriteBulkString(a.ReadOnlySpan, ref dcurr, dend))
- SendAndReset();
- }
- else if (scriptResult is object[] o)
- {
- // Two objects one boolean value and the result from the Lua Call
- while (!RespWriteUtils.WriteAsciiBulkString(o[1].ToString().AsSpan(), ref dcurr, dend))
- SendAndReset();
- }
- else if (scriptResult is LuaTable luaTable)
- {
- try
- {
- var retVal = luaTable["err"];
- if (retVal != null)
- {
- while (!RespWriteUtils.WriteError((string)retVal, ref dcurr, dend))
- SendAndReset();
- }
- else
- {
- retVal = luaTable["ok"];
- if (retVal != null)
- {
- while (!RespWriteUtils.WriteAsciiBulkString((string)retVal, ref dcurr, dend))
- SendAndReset();
- }
- else
- {
- int count = luaTable.Values.Count;
- while (!RespWriteUtils.WriteArrayLength(count, ref dcurr, dend))
- SendAndReset();
- foreach (var value in luaTable.Values)
- {
- WriteObject(value);
- }
- }
- }
- }
- finally
- {
- luaTable.Dispose();
- }
- }
- else
- {
- throw new LuaScriptException("Unknown return type", "");
- }
- }
- else
- {
- while (!RespWriteUtils.WriteDirect(CmdStrings.RESP_ERRNOTFOUND, ref dcurr, dend))
- SendAndReset();
}
}
}
diff --git a/libs/server/Lua/LuaRunner.cs b/libs/server/Lua/LuaRunner.cs
index ea2ef3f0dd..38851209da 100644
--- a/libs/server/Lua/LuaRunner.cs
+++ b/libs/server/Lua/LuaRunner.cs
@@ -2,11 +2,14 @@
// Licensed under the MIT license.
using System;
-using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
using System.Text;
using Garnet.common;
+using KeraLua;
using Microsoft.Extensions.Logging;
-using NLua;
namespace Garnet.server
{
@@ -15,121 +18,439 @@ namespace Garnet.server
///
internal sealed class LuaRunner : IDisposable
{
- readonly string source;
+ ///
+ /// Adapter to allow us to write directly to the network
+ /// when in Garnet and still keep script runner work.
+ ///
+ private unsafe interface IResponseAdapter
+ {
+ ///
+ /// Equivalent to the ref curr we pass into methods.
+ ///
+ ref byte* BufferCur { get; }
+
+ ///
+ /// Equivalent to the end we pass into methods.
+ ///
+ byte* BufferEnd { get; }
+
+ ///
+ /// Equivalent to .
+ ///
+ void SendAndReset();
+ }
+
+ ///
+ /// Adapter so script results go directly
+ /// to the network.
+ ///
+ private readonly struct RespResponseAdapter : IResponseAdapter
+ {
+ private readonly RespServerSession session;
+
+ internal RespResponseAdapter(RespServerSession session)
+ {
+ this.session = session;
+ }
+
+ ///
+ public unsafe ref byte* BufferCur
+ => ref session.dcurr;
+
+ ///
+ public unsafe byte* BufferEnd
+ => session.dend;
+
+ ///
+ public void SendAndReset()
+ => session.SendAndReset();
+ }
+
+ ///
+ /// For the runner, put output into an array.
+ ///
+ private unsafe struct RunnerAdapter : IResponseAdapter
+ {
+ private readonly ScratchBufferManager bufferManager;
+ private byte* origin;
+ private byte* curHead;
+ private byte* curEnd;
+
+ internal RunnerAdapter(ScratchBufferManager bufferManager)
+ {
+ this.bufferManager = bufferManager;
+ this.bufferManager.Reset();
+
+ var scratchSpace = bufferManager.FullBuffer();
+
+ origin = curHead = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(scratchSpace));
+ curEnd = curHead + scratchSpace.Length;
+ }
+
+#pragma warning disable CS9084 // Struct member returns 'this' or other instance members by reference
+ ///
+ public unsafe ref byte* BufferCur
+ => ref curHead;
+#pragma warning restore CS9084
+
+ ///
+ public unsafe byte* BufferEnd
+ => curEnd;
+
+ ///
+ /// Gets a span that covers the responses as written so far.
+ ///
+ public readonly ReadOnlySpan Response
+ {
+ get
+ {
+ var len = (int)(curHead - origin);
+
+ var full = bufferManager.FullBuffer();
+
+ return full[..len];
+ }
+ }
+
+ ///
+ public void SendAndReset()
+ {
+ var len = (int)(curHead - origin);
+
+ // We don't actually send anywhere, we grow the backing array
+ bufferManager.GrowBuffer();
+
+ var scratchSpace = bufferManager.FullBuffer();
+
+ origin = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(scratchSpace));
+ curEnd = origin + scratchSpace.Length;
+ curHead = origin + len;
+ }
+ }
+
+ const string LoaderBlock = @"
+import = function () end
+redis = {}
+function redis.call(...)
+ return garnet_call(...)
+end
+function redis.status_reply(text)
+ return text
+end
+function redis.error_reply(text)
+ return { err = 'ERR ' .. text }
+end
+KEYS = {}
+ARGV = {}
+sandbox_env = {
+ _G = _G;
+ _VERSION = _VERSION;
+
+ assert = assert;
+ collectgarbage = collectgarbage;
+ coroutine = coroutine;
+ error = error;
+ gcinfo = gcinfo;
+ -- explicitly not allowing getfenv
+ getmetatable = getmetatable;
+ ipairs = ipairs;
+ load = load;
+ loadstring = loadstring;
+ math = math;
+ next = next;
+ pairs = pairs;
+ pcall = pcall;
+ rawequal = rawequal;
+ rawget = rawget;
+ rawset = rawset;
+ redis = redis;
+ select = select;
+ -- explicitly not allowing setfenv
+ string = string;
+ setmetatable = setmetatable;
+ table = table;
+ tonumber = tonumber;
+ tostring = tostring;
+ type = type;
+ unpack = table.unpack;
+ xpcall = xpcall;
+
+ KEYS = KEYS;
+ ARGV = ARGV;
+}
+-- do resets in the Lua side to minimize pinvokes
+function reset_keys_and_argv(fromKey, fromArgv)
+ local keyCount = #KEYS
+ for i = fromKey, keyCount do
+ KEYS[i] = nil
+ end
+
+ local argvCount = #ARGV
+ for i = fromArgv, argvCount do
+ ARGV[i] = nil
+ end
+end
+-- responsible for sandboxing user provided code
+function load_sandboxed(source)
+ if (not source) then return nil end
+ local rawFunc, err = load(source, nil, nil, sandbox_env)
+
+ -- compilation error is returned directly
+ if err then
+ return rawFunc, err
+ end
+
+ -- otherwise we wrap the compiled function in a helper
+ return function()
+ local rawRet = rawFunc()
+
+ -- handle ok response wrappers without crossing the pinvoke boundary
+ -- err response wrapper requires a bit more work, but is also rarer
+ if rawRet and type(rawRet) == ""table"" and rawRet.ok then
+ return rawRet.ok
+ end
+
+ return rawRet
+ end
+end
+";
+
+ private static readonly ReadOnlyMemory LoaderBlockBytes = Encoding.UTF8.GetBytes(LoaderBlock);
+
+ // Rooted to keep function pointer alive
+ readonly LuaFunction garnetCall;
+
+ // References into Registry on the Lua side
+ //
+ // These are mix of objects we regularly update,
+ // constants we want to avoid copying from .NET to Lua,
+ // and the compiled function definition.
+ readonly int sandboxEnvRegistryIndex;
+ readonly int keysTableRegistryIndex;
+ readonly int argvTableRegistryIndex;
+ readonly int loadSandboxedRegistryIndex;
+ readonly int resetKeysAndArgvRegistryIndex;
+ readonly int okConstStringRegistryIndex;
+ readonly int errConstStringRegistryIndex;
+ readonly int noSessionAvailableConstStringRegistryIndex;
+ readonly int pleaseSpecifyRedisCallConstStringRegistryIndex;
+ readonly int errNoAuthConstStringRegistryIndex;
+ readonly int errUnknownConstStringRegistryIndex;
+ readonly int errBadArgConstStringRegistryIndex;
+ int functionRegistryIndex;
+
+ readonly ReadOnlyMemory source;
readonly ScratchBufferNetworkSender scratchBufferNetworkSender;
readonly RespServerSession respServerSession;
+
readonly ScratchBufferManager scratchBufferManager;
readonly ILogger logger;
- readonly Lua state;
- readonly LuaTable sandbox_env;
- LuaFunction function;
readonly TxnKeyEntries txnKeyEntries;
readonly bool txnMode;
- readonly LuaFunction garnetCall;
- readonly LuaTable keyTable, argvTable;
+
+ // This cannot be readonly, as it is a mutable struct
+ LuaStateWrapper state;
+
int keyLength, argvLength;
- Queue disposeQueue;
///
/// Creates a new runner with the source of the script
///
- public LuaRunner(string source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null)
+ public LuaRunner(ReadOnlyMemory source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null)
{
this.source = source;
this.txnMode = txnMode;
this.respServerSession = respServerSession;
this.scratchBufferNetworkSender = scratchBufferNetworkSender;
- this.scratchBufferManager = respServerSession?.scratchBufferManager;
+ this.scratchBufferManager = respServerSession?.scratchBufferManager ?? new();
this.logger = logger;
- state = new Lua();
- state.State.Encoding = Encoding.UTF8;
+ sandboxEnvRegistryIndex = -1;
+ keysTableRegistryIndex = -1;
+ argvTableRegistryIndex = -1;
+ loadSandboxedRegistryIndex = -1;
+ functionRegistryIndex = -1;
+
+ // TODO: custom allocator?
+ state = new LuaStateWrapper(new Lua());
+
if (txnMode)
{
- this.txnKeyEntries = new TxnKeyEntries(16, respServerSession.storageSession.lockableContext, respServerSession.storageSession.objectStoreLockableContext);
- garnetCall = state.RegisterFunction("garnet_call", this, this.GetType().GetMethod(nameof(garnet_call_txn)));
+ txnKeyEntries = new TxnKeyEntries(16, respServerSession.storageSession.lockableContext, respServerSession.storageSession.objectStoreLockableContext);
+
+ garnetCall = garnet_call_txn;
}
else
{
- garnetCall = state.RegisterFunction("garnet_call", this, this.GetType().GetMethod("garnet_call"));
- }
- _ = state.DoString(@"
- import = function () end
- redis = {}
- function redis.call(cmd, ...)
- return garnet_call(cmd, ...)
- end
- function redis.status_reply(text)
- return text
- end
- function redis.error_reply(text)
- return { err = text }
- end
- KEYS = {}
- ARGV = {}
- sandbox_env = {
- tostring = tostring;
- next = next;
- assert = assert;
- tonumber = tonumber;
- rawequal = rawequal;
- collectgarbage = collectgarbage;
- coroutine = coroutine;
- type = type;
- select = select;
- unpack = table.unpack;
- gcinfo = gcinfo;
- pairs = pairs;
- loadstring = loadstring;
- ipairs = ipairs;
- error = error;
- redis = redis;
- math = math;
- table = table;
- string = string;
- KEYS = KEYS;
- ARGV = ARGV;
- }
- function load_sandboxed(source)
- if (not source) then return nil end
- return load(source, nil, nil, sandbox_env)
- end
- ");
- sandbox_env = (LuaTable)state["sandbox_env"];
- keyTable = (LuaTable)state["KEYS"];
- argvTable = (LuaTable)state["ARGV"];
+ garnetCall = garnet_call;
+ }
+
+ var loadRes = state.LoadBuffer(LoaderBlockBytes.Span);
+ if (loadRes != LuaStatus.OK)
+ {
+ throw new GarnetException("Could load loader into Lua");
+ }
+
+ var sandboxRes = state.PCall(0, -1);
+ if (sandboxRes != LuaStatus.OK)
+ {
+ throw new GarnetException("Could not initialize Lua sandbox state");
+ }
+
+ // Register garnet_call in global namespace
+ state.Register("garnet_call", garnetCall);
+
+ state.GetGlobal(LuaType.Table, "sandbox_env");
+ sandboxEnvRegistryIndex = state.Ref();
+
+ state.GetGlobal(LuaType.Table, "KEYS");
+ keysTableRegistryIndex = state.Ref();
+
+ state.GetGlobal(LuaType.Table, "ARGV");
+ argvTableRegistryIndex = state.Ref();
+
+ state.GetGlobal(LuaType.Function, "load_sandboxed");
+ loadSandboxedRegistryIndex = state.Ref();
+
+ state.GetGlobal(LuaType.Function, "reset_keys_and_argv");
+ resetKeysAndArgvRegistryIndex = state.Ref();
+
+ // Commonly used strings, register them once so we don't have to copy them over each time we need them
+ okConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_OK);
+ errConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_err);
+ noSessionAvailableConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_No_session_available);
+ pleaseSpecifyRedisCallConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_ERR_Please_specify_at_least_one_argument_for_this_redis_lib_call);
+ errNoAuthConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.RESP_ERR_NOAUTH);
+ errUnknownConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_ERR_Unknown_Redis_command_called_from_script);
+ errBadArgConstStringRegistryIndex = ConstantStringToRegistry(CmdStrings.LUA_ERR_Lua_redis_lib_command_arguments_must_be_strings_or_integers);
+
+ state.ExpectLuaStackEmpty();
}
///
/// Creates a new runner with the source of the script
///
- public LuaRunner(ReadOnlySpan source, bool txnMode, RespServerSession respServerSession, ScratchBufferNetworkSender scratchBufferNetworkSender, ILogger logger = null)
- : this(Encoding.UTF8.GetString(source), txnMode, respServerSession, scratchBufferNetworkSender, logger)
+ public LuaRunner(string source, bool txnMode = false, RespServerSession respServerSession = null, ScratchBufferNetworkSender scratchBufferNetworkSender = null, ILogger logger = null)
+ : this(Encoding.UTF8.GetBytes(source), txnMode, respServerSession, scratchBufferNetworkSender, logger)
+ {
+ }
+
+ ///
+ /// Some strings we use a bunch, and copying them to Lua each time is wasteful
+ ///
+ /// So instead we stash them in the Registry and load them by index
+ ///
+ int ConstantStringToRegistry(ReadOnlySpan str)
+ {
+ state.PushBuffer(str);
+ return state.Ref();
+ }
+
+ ///
+ /// Compile script for running in a .NET host.
+ ///
+ /// Errors are raised as exceptions.
+ ///
+ public unsafe void CompileForRunner()
+ {
+ var adapter = new RunnerAdapter(scratchBufferManager);
+ CompileCommon(ref adapter);
+
+ var resp = adapter.Response;
+ var respStart = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(resp));
+ var respEnd = respStart + resp.Length;
+ if (RespReadUtils.TryReadErrorAsSpan(out var errSpan, ref respStart, respEnd))
+ {
+ var errStr = Encoding.UTF8.GetString(errSpan);
+ throw new GarnetException(errStr);
+ }
+ }
+
+ ///
+ /// Compile script for a .
+ ///
+ /// Any errors encountered are written out as Resp errors.
+ ///
+ public void CompileForSession(RespServerSession session)
{
+ var adapter = new RespResponseAdapter(session);
+ CompileCommon(ref adapter);
}
///
- /// Compile script
+ /// Drops compiled function, just for benchmarking purposes.
///
- public void Compile()
+ public void ResetCompilation()
{
+ if (functionRegistryIndex != -1)
+ {
+ state.Unref(LuaRegistry.Index, functionRegistryIndex);
+ functionRegistryIndex = -1;
+ }
+ }
+
+ ///
+ /// Compile script, writing errors out to given response.
+ ///
+ unsafe void CompileCommon(ref TResponse resp)
+ where TResponse : struct, IResponseAdapter
+ {
+ const int NeededStackSpace = 2;
+
+ Debug.Assert(functionRegistryIndex == -1, "Shouldn't compile multiple times");
+
+ state.ExpectLuaStackEmpty();
+
try
{
- using var loader = (LuaFunction)state["load_sandboxed"];
- var result = loader.Call(source);
- if (result?.Length == 1)
+ state.ForceMinimumStackCapacity(NeededStackSpace);
+
+ state.PushInteger(loadSandboxedRegistryIndex);
+ _ = state.RawGet(LuaType.Function, (int)LuaRegistry.Index);
+
+ state.PushBuffer(source.Span);
+ state.Call(1, -1); // Multiple returns allowed
+
+ var numRets = state.StackTop;
+
+ if (numRets == 0)
{
- function = result[0] as LuaFunction;
+ while (!RespWriteUtils.WriteError("Shouldn't happen, no returns from load_sandboxed"u8, ref resp.BufferCur, resp.BufferEnd))
+ resp.SendAndReset();
+
return;
}
+ else if (numRets == 1)
+ {
+ var returnType = state.Type(1);
+ if (returnType != LuaType.Function)
+ {
+ var errStr = $"Could not compile function, got back a {returnType}";
+ while (!RespWriteUtils.WriteError(errStr, ref resp.BufferCur, resp.BufferEnd))
+ resp.SendAndReset();
- if (result?.Length == 2)
+ return;
+ }
+
+ functionRegistryIndex = state.Ref();
+ }
+ else if (numRets == 2)
{
- throw new GarnetException($"Compilation error: {(string)result[1]}");
+ state.CheckBuffer(2, out var errorBuf);
+
+ var errStr = $"Compilation error: {Encoding.UTF8.GetString(errorBuf)}";
+ while (!RespWriteUtils.WriteError(errStr, ref resp.BufferCur, resp.BufferEnd))
+ resp.SendAndReset();
+
+ state.Pop(2);
+
+ return;
}
else
{
- throw new GarnetException($"Unable to load script");
+ state.Pop(numRets);
+
+ throw new GarnetException($"Unexpected error compiling, got too many replies back: reply count = {numRets}");
}
}
catch (Exception ex)
@@ -137,6 +458,10 @@ public void Compile()
logger?.LogError(ex, "CreateFunction threw an exception");
throw;
}
+ finally
+ {
+ state.ExpectLuaStackEmpty();
+ }
}
///
@@ -144,184 +469,403 @@ public void Compile()
///
public void Dispose()
{
- garnetCall?.Dispose();
- keyTable?.Dispose();
- argvTable?.Dispose();
- sandbox_env?.Dispose();
- function?.Dispose();
- state?.Dispose();
+ state.Dispose();
}
///
/// Entry point for redis.call method from a Lua script (non-transactional mode)
///
- ///
- /// Parameters
- ///
- public object garnet_call(string cmd, params object[] args)
- => respServerSession == null ? null : ProcessCommandFromScripting(respServerSession.basicGarnetApi, cmd, args);
+ public int garnet_call(IntPtr luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ if (respServerSession == null)
+ {
+ return NoSessionResponse();
+ }
+
+ return ProcessCommandFromScripting(respServerSession.basicGarnetApi);
+ }
///
/// Entry point for redis.call method from a Lua script (transactional mode)
///
- ///
- /// Parameters
- ///
- public object garnet_call_txn(string cmd, params object[] args)
- => respServerSession == null ? null : ProcessCommandFromScripting(respServerSession.lockableGarnetApi, cmd, args);
+ public int garnet_call_txn(IntPtr luaStatePtr)
+ {
+ state.CallFromLuaEntered(luaStatePtr);
+
+ if (respServerSession == null)
+ {
+ return NoSessionResponse();
+ }
+
+ return ProcessCommandFromScripting(respServerSession.lockableGarnetApi);
+ }
+
+ ///
+ /// Call somehow came in with no valid resp server session.
+ ///
+ /// This is used in benchmarking.
+ ///
+ int NoSessionResponse()
+ {
+ const int NeededStackSpace = 1;
+
+ state.ForceMinimumStackCapacity(NeededStackSpace);
+
+ state.PushNil();
+ return 1;
+ }
///
/// Entry point method for executing commands from a Lua Script
///
- unsafe object ProcessCommandFromScripting(TGarnetApi api, string cmd, params object[] args)
+ unsafe int ProcessCommandFromScripting(TGarnetApi api)
where TGarnetApi : IGarnetApi
{
- switch (cmd)
+ const int AdditionalStackSpace = 1;
+
+ try
{
+ var argCount = state.StackTop;
+
+ if (argCount == 0)
+ {
+ return LuaStaticError(pleaseSpecifyRedisCallConstStringRegistryIndex);
+ }
+
+ state.ForceMinimumStackCapacity(AdditionalStackSpace);
+
+ if (!state.CheckBuffer(1, out var cmdSpan))
+ {
+ return LuaStaticError(errBadArgConstStringRegistryIndex);
+ }
+
// We special-case a few performance-sensitive operations to directly invoke via the storage API
- case "SET" when args.Length == 2:
- case "set" when args.Length == 2:
+ if (AsciiUtils.EqualsUpperCaseSpanIgnoringCase(cmdSpan, "SET"u8) && argCount == 3)
+ {
+ if (!respServerSession.CheckACLPermissions(RespCommand.SET))
{
- if (!respServerSession.CheckACLPermissions(RespCommand.SET))
- return Encoding.ASCII.GetString(CmdStrings.RESP_ERR_NOAUTH);
- var key = scratchBufferManager.CreateArgSlice(Convert.ToString(args[0]));
- var value = scratchBufferManager.CreateArgSlice(Convert.ToString(args[1]));
- _ = api.SET(key, value);
- return "OK";
+ return LuaStaticError(errNoAuthConstStringRegistryIndex);
}
- case "GET":
- case "get":
+
+ if (!state.CheckBuffer(2, out var keySpan) || !state.CheckBuffer(3, out var valSpan))
{
- if (!respServerSession.CheckACLPermissions(RespCommand.GET))
- throw new Exception(Encoding.ASCII.GetString(CmdStrings.RESP_ERR_NOAUTH));
- var key = scratchBufferManager.CreateArgSlice(Convert.ToString(args[0]));
- var status = api.GET(key, out var value);
- if (status == GarnetStatus.OK)
- return value.ToString();
- return null;
+ return LuaStaticError(errBadArgConstStringRegistryIndex);
}
+
+ // Note these spans are implicitly pinned, as they're actually on the Lua stack
+ var key = ArgSlice.FromPinnedSpan(keySpan);
+ var value = ArgSlice.FromPinnedSpan(valSpan);
+
+ _ = api.SET(key, value);
+
+ state.PushConstantString(okConstStringRegistryIndex);
+ return 1;
+ }
+ else if (AsciiUtils.EqualsUpperCaseSpanIgnoringCase(cmdSpan, "GET"u8) && argCount == 2)
+ {
+ if (!respServerSession.CheckACLPermissions(RespCommand.GET))
+ {
+ return LuaStaticError(errNoAuthConstStringRegistryIndex);
+ }
+
+ if (!state.CheckBuffer(2, out var keySpan))
+ {
+ return LuaStaticError(errBadArgConstStringRegistryIndex);
+ }
+
+ // Span is (implicitly) pinned since it's actually on the Lua stack
+ var key = ArgSlice.FromPinnedSpan(keySpan);
+ var status = api.GET(key, out var value);
+ if (status == GarnetStatus.OK)
+ {
+ state.PushBuffer(value.ReadOnlySpan);
+ }
+ else
+ {
+ state.PushNil();
+ }
+
+ return 1;
+ }
+
// As fallback, we use RespServerSession with a RESP-formatted input. This could be optimized
// in future to provide parse state directly.
- default:
+
+ scratchBufferManager.Reset();
+ scratchBufferManager.StartCommand(cmdSpan, argCount - 1);
+
+ for (var i = 0; i < argCount - 1; i++)
+ {
+ var argIx = 2 + i;
+
+ var argType = state.Type(argIx);
+ if (argType == LuaType.Nil)
+ {
+ scratchBufferManager.WriteNullArgument();
+ }
+ else if (argType is LuaType.String or LuaType.Number)
+ {
+ // KnownStringToBuffer will coerce a number into a string
+ //
+ // Redis nominally converts numbers to integers, but in this case just ToStrings things
+ state.KnownStringToBuffer(argIx, out var span);
+
+ // Span remains pinned so long as we don't pop the stack
+ scratchBufferManager.WriteArgument(span);
+ }
+ else
{
- var request = scratchBufferManager.FormatCommandAsResp(cmd, args, state);
- _ = respServerSession.TryConsumeMessages(request.ptr, request.length);
- var response = scratchBufferNetworkSender.GetResponse();
- var result = ProcessResponse(response.ptr, response.length);
- scratchBufferNetworkSender.Reset();
- return result;
+ return LuaStaticError(errBadArgConstStringRegistryIndex);
}
+ }
+
+ var request = scratchBufferManager.ViewFullArgSlice();
+
+ // Once the request is formatted, we can release all the args on the Lua stack
+ //
+ // This keeps the stack size down for processing the response
+ state.Pop(argCount);
+
+ _ = respServerSession.TryConsumeMessages(request.ptr, request.length);
+
+ var response = scratchBufferNetworkSender.GetResponse();
+ var result = ProcessResponse(response.ptr, response.length);
+ scratchBufferNetworkSender.Reset();
+ return result;
+ }
+ catch (Exception e)
+ {
+ logger?.LogError(e, "During Lua script execution");
+
+ return state.RaiseError(e.Message);
}
}
///
- /// Process a RESP-formatted response from the RespServerSession
+ /// Cause a Lua error to be raised with a message previously registered.
///
- unsafe object ProcessResponse(byte* ptr, int length)
+ int LuaStaticError(int constStringRegistryIndex)
{
+ const int NeededStackSize = 1;
+
+ state.ForceMinimumStackCapacity(NeededStackSize);
+
+ state.PushConstantString(constStringRegistryIndex);
+ return state.RaiseErrorFromStack();
+ }
+
+ ///
+ /// Process a RESP-formatted response from the RespServerSession.
+ ///
+ /// Pushes result onto state stack and returns 1, or raises an error and never returns.
+ ///
+ unsafe int ProcessResponse(byte* ptr, int length)
+ {
+ const int NeededStackSize = 3;
+
+ state.ForceMinimumStackCapacity(NeededStackSize);
+
switch (*ptr)
{
case (byte)'+':
- if (RespReadUtils.ReadSimpleString(out var resultStr, ref ptr, ptr + length))
- return resultStr;
- break;
+ ptr++;
+ length--;
+ if (RespReadUtils.ReadAsSpan(out var resultSpan, ref ptr, ptr + length))
+ {
+ state.PushBuffer(resultSpan);
+ return 1;
+ }
+ goto default;
+
case (byte)':':
if (RespReadUtils.Read64Int(out var number, ref ptr, ptr + length))
- return number;
- break;
+ {
+ state.PushInteger(number);
+ return 1;
+ }
+ goto default;
+
case (byte)'-':
- if (RespReadUtils.ReadErrorAsString(out resultStr, ref ptr, ptr + length))
- return resultStr;
- break;
+ ptr++;
+ length--;
+ if (RespReadUtils.ReadAsSpan(out var errSpan, ref ptr, ptr + length))
+ {
+ if (errSpan.SequenceEqual(CmdStrings.RESP_ERR_GENERIC_UNK_CMD))
+ {
+ // Gets a special response
+ return LuaStaticError(errUnknownConstStringRegistryIndex);
+ }
+
+ state.PushBuffer(errSpan);
+ return state.RaiseErrorFromStack();
+
+ }
+ goto default;
case (byte)'$':
- if (RespReadUtils.ReadStringResponseWithLengthHeader(out resultStr, ref ptr, ptr + length))
- return resultStr;
- break;
+ if (length >= 5 && new ReadOnlySpan(ptr + 1, 4).SequenceEqual("-1\r\n"u8))
+ {
+ // Bulk null strings are mapped to FALSE
+ // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion
+ state.PushBoolean(false);
+
+ return 1;
+ }
+ else if (RespReadUtils.ReadSpanWithLengthHeader(out var bulkSpan, ref ptr, ptr + length))
+ {
+ state.PushBuffer(bulkSpan);
+
+ return 1;
+ }
+ goto default;
case (byte)'*':
- if (RespReadUtils.ReadStringArrayResponseWithLengthHeader(out var resultArray, ref ptr, ptr + length))
+ if (RespReadUtils.ReadUnsignedArrayLength(out var itemCount, ref ptr, ptr + length))
{
- // Create return table
- var returnValue = (LuaTable)state.DoString("return { }")[0];
-
- // Queue up for disposal at the end of the script call
- disposeQueue ??= new();
- disposeQueue.Enqueue(returnValue);
-
- // Populate the table
- var i = 1;
- foreach (var item in resultArray)
- returnValue[i++] = item == null ? false : item;
- return returnValue;
+ // Create the new table
+ state.CreateTable(itemCount, 0);
+
+ for (var itemIx = 0; itemIx < itemCount; itemIx++)
+ {
+ if (*ptr == '$')
+ {
+ // Bulk String
+ if (length >= 4 && new ReadOnlySpan(ptr + 1, 4).SequenceEqual("-1\r\n"u8))
+ {
+ // Null strings are mapped to false
+ // See: https://redis.io/docs/latest/develop/interact/programmability/lua-api/#lua-to-resp2-type-conversion
+ state.PushBoolean(false);
+ }
+ else if (RespReadUtils.ReadSpanWithLengthHeader(out var strSpan, ref ptr, ptr + length))
+ {
+ state.PushBuffer(strSpan);
+ }
+ else
+ {
+ // Error, drop the table we allocated
+ state.Pop(1);
+ goto default;
+ }
+ }
+ else
+ {
+ // In practice, we ONLY ever return bulk strings
+ // So just... not implementing the rest for now
+ throw new NotImplementedException($"Unexpected sigil: {(char)*ptr}");
+ }
+
+ // Stack now has table and value at itemIx on it
+ state.RawSetInteger(1, itemIx + 1);
+ }
+
+ return 1;
}
- break;
+ goto default;
default:
throw new Exception("Unexpected response: " + Encoding.UTF8.GetString(new Span(ptr, length)).Replace("\n", "|").Replace("\r", "") + "]");
}
- return null;
}
///
- /// Runs the precompiled Lua function with specified parse state
+ /// Runs the precompiled Lua function with the given outer session.
+ ///
+ /// Response is written directly into the .
///
- public object Run(int count, SessionParseState parseState)
+ public void RunForSession(int count, RespServerSession outerSession)
{
+ const int NeededStackSize = 3;
+
+ state.ForceMinimumStackCapacity(NeededStackSize);
+
scratchBufferManager.Reset();
- int offset = 1;
- int nKeys = parseState.GetInt(offset++);
+ var parseState = outerSession.parseState;
+
+ var offset = 1;
+ var nKeys = parseState.GetInt(offset++);
count--;
ResetParameters(nKeys, count - nKeys);
if (nKeys > 0)
{
- for (int i = 0; i < nKeys; i++)
+ // Get KEYS on the stack
+ state.PushInteger(keysTableRegistryIndex);
+ state.RawGet(LuaType.Table, (int)LuaRegistry.Index);
+
+ for (var i = 0; i < nKeys; i++)
{
+ ref var key = ref parseState.GetArgSliceByRef(offset);
+
if (txnMode)
{
- var key = parseState.GetArgSliceByRef(offset);
txnKeyEntries.AddKey(key, false, Tsavorite.core.LockType.Exclusive);
if (!respServerSession.storageSession.objectStoreLockableContext.IsNull)
txnKeyEntries.AddKey(key, true, Tsavorite.core.LockType.Exclusive);
}
- keyTable[i + 1] = parseState.GetString(offset++);
+
+ // Equivalent to KEYS[i+1] = key
+ state.PushInteger(i + 1);
+ state.PushBuffer(key.ReadOnlySpan);
+ state.RawSet(1);
+
+ offset++;
}
- count -= nKeys;
- //TODO: handle slot verification for Lua script keys
- //if (NetworkKeyArraySlotVerify(keys, true))
- //{
- // return true;
- //}
+ // Remove KEYS from the stack
+ state.Pop(1);
+
+ count -= nKeys;
}
if (count > 0)
{
- for (int i = 0; i < count; i++)
+ // Get ARGV on the stack
+ state.PushInteger(argvTableRegistryIndex);
+ state.RawGet(LuaType.Table, (int)LuaRegistry.Index);
+
+ for (var i = 0; i < count; i++)
{
- argvTable[i + 1] = parseState.GetString(offset++);
+ ref var argv = ref parseState.GetArgSliceByRef(offset);
+
+ // Equivalent to ARGV[i+1] = argv
+ state.PushInteger(i + 1);
+ state.PushBuffer(argv.ReadOnlySpan);
+ state.RawSet(1);
+
+ offset++;
}
+
+ // Remove ARGV from the stack
+ state.Pop(1);
}
+ var adapter = new RespResponseAdapter(outerSession);
+
if (txnMode && nKeys > 0)
{
- return RunTransaction();
+ RunInTransaction(ref adapter);
}
else
{
- return Run();
+ RunCommon(ref adapter);
}
}
///
- /// Runs the precompiled Lua function with specified (keys, argv) state
+ /// Runs the precompiled Lua function with specified (keys, argv) state.
+ ///
+ /// Meant for use from a .NET host rather than in Garnet properly.
///
- public object Run(string[] keys = null, string[] argv = null)
+ public unsafe object RunForRunner(string[] keys = null, string[] argv = null)
{
scratchBufferManager?.Reset();
- LoadParameters(keys, argv);
+ LoadParametersForRunner(keys, argv);
+
+ var adapter = new RunnerAdapter(scratchBufferManager);
+
if (txnMode && keys?.Length > 0)
{
// Add keys to the transaction
@@ -332,15 +876,88 @@ public object Run(string[] keys = null, string[] argv = null)
if (!respServerSession.storageSession.objectStoreLockableContext.IsNull)
txnKeyEntries.AddKey(_key, true, Tsavorite.core.LockType.Exclusive);
}
- return RunTransaction();
+
+ RunInTransaction(ref adapter);
}
else
{
- return Run();
+ RunCommon(ref adapter);
+ }
+
+ var resp = adapter.Response;
+ var respCur = (byte*)Unsafe.AsPointer(ref MemoryMarshal.GetReference(resp));
+ var respEnd = respCur + resp.Length;
+
+ if (RespReadUtils.TryReadErrorAsSpan(out var errSpan, ref respCur, respEnd))
+ {
+ var errStr = Encoding.UTF8.GetString(errSpan);
+ throw new GarnetException(errStr);
+ }
+
+ var ret = MapRespToObject(ref respCur, respEnd);
+ Debug.Assert(respCur == respEnd, "Should have fully consumed response");
+
+ return ret;
+
+ static object MapRespToObject(ref byte* cur, byte* end)
+ {
+ switch (*cur)
+ {
+ case (byte)'+':
+ var simpleStrRes = RespReadUtils.ReadSimpleString(out var simpleStr, ref cur, end);
+ Debug.Assert(simpleStrRes, "Should never fail");
+
+ return simpleStr;
+
+ case (byte)':':
+ var readIntRes = RespReadUtils.Read64Int(out var int64, ref cur, end);
+ Debug.Assert(readIntRes, "Should never fail");
+
+ return int64;
+
+ // Error ('-') is handled before call to MapRespToObject
+
+ case (byte)'$':
+ var length = end - cur;
+
+ if (length >= 5 && new ReadOnlySpan(cur + 1, 4).SequenceEqual("-1\r\n"u8))
+ {
+ cur += 5;
+ return null;
+ }
+
+ var bulkStrRes = RespReadUtils.ReadStringResponseWithLengthHeader(out var bulkStr, ref cur, end);
+ Debug.Assert(bulkStrRes, "Should never fail");
+
+ return bulkStr;
+
+ case (byte)'*':
+ var arrayLengthRes = RespReadUtils.ReadUnsignedArrayLength(out var itemCount, ref cur, end);
+ Debug.Assert(arrayLengthRes, "Should never fail");
+
+ if (itemCount == 0)
+ {
+ return Array.Empty