From 53feadef511b3b5838327ebe8a8ec2e904b13fcc Mon Sep 17 00:00:00 2001 From: nick-w-nick <43578531+nick-w-nick@users.noreply.github.com> Date: Mon, 30 Dec 2024 12:09:56 -0500 Subject: [PATCH] feat(core): Add support for `boolean` metadata attributes in `FunctionalTranslator` (#7407) Co-authored-by: jacoblee93 --- langchain-core/src/structured_query/base.ts | 2 +- .../src/structured_query/functional.ts | 51 +++- langchain-core/src/structured_query/ir.ts | 6 +- .../structured_query/tests/functional.test.ts | 254 ++++++++++++++++++ .../src/structured_query/tests/utils.test.ts | 9 +- langchain-core/src/structured_query/utils.ts | 11 +- .../src/chains/query_constructor/parser.ts | 5 +- .../tests/query_chain.int.test.ts | 18 +- .../tests/query_parser.test.ts | 3 +- 9 files changed, 346 insertions(+), 13 deletions(-) create mode 100644 langchain-core/src/structured_query/tests/functional.test.ts diff --git a/langchain-core/src/structured_query/base.ts b/langchain-core/src/structured_query/base.ts index d82f615d63c4..89efdb486af2 100644 --- a/langchain-core/src/structured_query/base.ts +++ b/langchain-core/src/structured_query/base.ts @@ -100,7 +100,7 @@ export class BasicTranslator< this.allowedComparators.indexOf(func as Comparator) === -1 ) { throw new Error( - `Comparator ${func} not allowed. Allowed operators: ${this.allowedComparators.join( + `Comparator ${func} not allowed. Allowed comparators: ${this.allowedComparators.join( ", " )}` ); diff --git a/langchain-core/src/structured_query/functional.ts b/langchain-core/src/structured_query/functional.ts index 7a0c2e894ad0..3b57b63bc6dc 100644 --- a/langchain-core/src/structured_query/functional.ts +++ b/langchain-core/src/structured_query/functional.ts @@ -17,8 +17,8 @@ import { castValue, isFilterEmpty } from "./utils.js"; * the result of a comparison operation. */ type ValueType = { - eq: string | number; - ne: string | number; + eq: string | number | boolean; + ne: string | number | boolean; lt: string | number; lte: string | number; gt: string | number; @@ -66,6 +66,42 @@ export class FunctionalTranslator extends BaseTranslator { throw new Error("Not implemented"); } + /** + * Returns the allowed comparators for a given data type. + * @param input The input value to get the allowed comparators for. + * @returns An array of allowed comparators for the input data type. + */ + getAllowedComparatorsForType(inputType: string): Comparator[] { + switch (inputType) { + case "string": { + return [ + Comparators.eq, + Comparators.ne, + Comparators.gt, + Comparators.gte, + Comparators.lt, + Comparators.lte, + ]; + } + case "number": { + return [ + Comparators.eq, + Comparators.ne, + Comparators.gt, + Comparators.gte, + Comparators.lt, + Comparators.lte, + ]; + } + case "boolean": { + return [Comparators.eq, Comparators.ne]; + } + default: { + throw new Error(`Unsupported data type: ${inputType}`); + } + } + } + /** * Returns a function that performs a comparison based on the provided * comparator. @@ -155,10 +191,19 @@ export class FunctionalTranslator extends BaseTranslator { * @param comparison The comparison part of a structured query. * @returns A function that takes a `Document` as an argument and returns a boolean based on the comparison. */ - visitComparison(comparison: Comparison): this["VisitComparisonOutput"] { + visitComparison( + comparison: Comparison + ): this["VisitComparisonOutput"] { const { comparator, attribute, value } = comparison; const undefinedTrue = [Comparators.ne]; if (this.allowedComparators.includes(comparator)) { + if ( + !this.getAllowedComparatorsForType(typeof value).includes(comparator) + ) { + throw new Error( + `'${comparator}' comparator not allowed to be used with ${typeof value}` + ); + } const comparatorFunction = this.getComparatorFunction(comparator); return (document: Document) => { const documentValue = document.metadata[attribute]; diff --git a/langchain-core/src/structured_query/ir.ts b/langchain-core/src/structured_query/ir.ts index d2bfa6215d81..bbfcbd098654 100644 --- a/langchain-core/src/structured_query/ir.ts +++ b/langchain-core/src/structured_query/ir.ts @@ -82,7 +82,7 @@ export type VisitorOperationResult = { */ export type VisitorComparisonResult = { [attr: string]: { - [comparator: string]: string | number; + [comparator: string]: string | number | boolean; }; }; @@ -149,13 +149,13 @@ export abstract class FilterDirective extends Expression {} * Class representing a comparison filter directive. It extends the * FilterDirective class. */ -export class Comparison extends FilterDirective { +export class Comparison extends FilterDirective { exprName = "Comparison" as const; constructor( public comparator: Comparator, public attribute: string, - public value: string | number + public value: ValueTypes ) { super(); } diff --git a/langchain-core/src/structured_query/tests/functional.test.ts b/langchain-core/src/structured_query/tests/functional.test.ts new file mode 100644 index 000000000000..1106c9e69193 --- /dev/null +++ b/langchain-core/src/structured_query/tests/functional.test.ts @@ -0,0 +1,254 @@ +import { test, expect, describe } from "@jest/globals"; +import { Document } from "../../documents/document.js"; +import { FunctionalTranslator } from "../functional.js"; +import { Comparators, Visitor } from "../ir.js"; + +describe("FunctionalTranslator", () => { + const translator = new FunctionalTranslator(); + + describe("getAllowedComparatorsForType", () => { + test("string", () => { + expect(translator.getAllowedComparatorsForType("string")).toEqual([ + Comparators.eq, + Comparators.ne, + Comparators.gt, + Comparators.gte, + Comparators.lt, + Comparators.lte, + ]); + }); + test("number", () => { + expect(translator.getAllowedComparatorsForType("number")).toEqual([ + Comparators.eq, + Comparators.ne, + Comparators.gt, + Comparators.gte, + Comparators.lt, + Comparators.lte, + ]); + }); + test("boolean", () => { + expect(translator.getAllowedComparatorsForType("boolean")).toEqual([ + Comparators.eq, + Comparators.ne, + ]); + }); + test("unsupported", () => { + expect(() => + translator.getAllowedComparatorsForType("unsupported") + ).toThrow("Unsupported data type: unsupported"); + }); + }); + + describe("visitComparison", () => { + describe("returns true or false for valid comparisons", () => { + const attributesByType = { + string: "stringValue", + number: "numberValue", + boolean: "booleanValue", + }; + + const inputValuesByAttribute: { + [key in string]: string | number | boolean; + } = { + stringValue: "value", + numberValue: 1, + booleanValue: true, + }; + + // documents that will match against the comparison + const validDocumentsByComparator: { + [key in string]: Document>[]; + } = { + [Comparators.eq]: [ + new Document({ + pageContent: "", + metadata: { + stringValue: "value", + numberValue: 1, + booleanValue: true, + }, + }), + ], + [Comparators.ne]: [ + new Document({ + pageContent: "", + metadata: { + stringValue: "not-value", + numberValue: 0, + booleanValue: false, + }, + }), + ], + [Comparators.gt]: [ + new Document({ + pageContent: "", + metadata: { + stringValue: "valueee", + numberValue: 2, + booleanValue: true, + }, + }), + ], + [Comparators.gte]: [ + // test for greater than + new Document({ + pageContent: "", + metadata: { + stringValue: "valueee", + numberValue: 2, + booleanValue: true, + }, + }), + // test for equal to + new Document({ + pageContent: "", + metadata: { + stringValue: "value", + numberValue: 1, + booleanValue: true, + }, + }), + ], + [Comparators.lt]: [ + new Document({ + pageContent: "", + metadata: { + stringValue: "val", + numberValue: 0, + booleanValue: true, + }, + }), + ], + [Comparators.lte]: [ + // test for less than + new Document({ + pageContent: "", + metadata: { + stringValue: "val", + numberValue: 0, + booleanValue: true, + }, + }), + // test for equal to + new Document({ + pageContent: "", + metadata: { + stringValue: "value", + numberValue: 1, + booleanValue: true, + }, + }), + ], + }; + + // documents that will not match against the comparison + const invalidDocumentsByComparator: { + [key in string]: Document>[]; + } = { + [Comparators.eq]: [ + new Document({ + pageContent: "", + metadata: { + stringValue: "not-value", + numberValue: 0, + booleanValue: false, + }, + }), + ], + [Comparators.ne]: [ + new Document({ + pageContent: "", + metadata: { + stringValue: "value", + numberValue: 1, + booleanValue: true, + }, + }), + ], + [Comparators.gt]: [ + new Document({ + pageContent: "", + metadata: { + stringValue: "value", + numberValue: 1, + booleanValue: true, + }, + }), + ], + [Comparators.gte]: [ + new Document({ + pageContent: "", + metadata: { + stringValue: "val", + numberValue: 0, + booleanValue: true, + }, + }), + ], + [Comparators.lt]: [ + new Document({ + pageContent: "", + metadata: { + stringValue: "valueee", + numberValue: 2, + booleanValue: true, + }, + }), + ], + [Comparators.lte]: [ + new Document({ + pageContent: "", + metadata: { + stringValue: "valueee", + numberValue: 2, + booleanValue: true, + }, + }), + ], + }; + + function generateComparatorTestsForType( + type: "string" | "number" | "boolean" + ) { + const comparators = translator.getAllowedComparatorsForType(type); + for (const comparator of comparators) { + const attribute = attributesByType[type]; + const value = inputValuesByAttribute[attribute]; + const validDocuments = validDocumentsByComparator[comparator]; + for (const validDocument of validDocuments) { + test(`${value} -> ${comparator} -> ${validDocument.metadata[attribute]}`, () => { + const comparison = translator.visitComparison({ + attribute, + comparator, + value, + exprName: "Comparison", + accept: (visitor: Visitor) => visitor, + }); + const result = comparison(validDocument); + expect(result).toBeTruthy(); + }); + } + const invalidDocuments = invalidDocumentsByComparator[comparator]; + for (const invalidDocument of invalidDocuments) { + test(`${value} -> ${comparator} -> ${invalidDocument.metadata[attribute]}`, () => { + const comparison = translator.visitComparison({ + attribute, + comparator, + value, + exprName: "Comparison", + accept: (visitor: Visitor) => visitor, + }); + const result = comparison(invalidDocument); + expect(result).toBeFalsy(); + }); + } + } + } + + generateComparatorTestsForType("string"); + generateComparatorTestsForType("number"); + generateComparatorTestsForType("boolean"); + }); + }); +}); diff --git a/langchain-core/src/structured_query/tests/utils.test.ts b/langchain-core/src/structured_query/tests/utils.test.ts index 94ac85e84dfa..53b413702c2f 100644 --- a/langchain-core/src/structured_query/tests/utils.test.ts +++ b/langchain-core/src/structured_query/tests/utils.test.ts @@ -1,6 +1,6 @@ /* eslint-disable no-process-env */ import { test, expect } from "@jest/globals"; -import { castValue, isFloat, isInt, isString } from "../utils.js"; +import { castValue, isFloat, isInt, isString, isBoolean } from "../utils.js"; test("Casting values correctly", () => { const stringString = [ @@ -28,6 +28,8 @@ test("Casting values correctly", () => { const floatFloat = ["1.1", 2.2, 3.3]; + const booleanBoolean = [true, false]; + stringString.map(castValue).forEach((value) => { expect(typeof value).toBe("string"); expect(isString(value)).toBe(true); @@ -54,4 +56,9 @@ test("Casting values correctly", () => { expect(typeof value).toBe("number"); expect(isFloat(value)).toBe(true); }); + + booleanBoolean.map(castValue).forEach((value) => { + expect(typeof value).toBe("boolean"); + expect(isBoolean(value)).toBe(true); + }); }); diff --git a/langchain-core/src/structured_query/utils.ts b/langchain-core/src/structured_query/utils.ts index 05e699e83394..92c9639b365b 100644 --- a/langchain-core/src/structured_query/utils.ts +++ b/langchain-core/src/structured_query/utils.ts @@ -72,13 +72,20 @@ export function isString(value: unknown): boolean { ); } +/** + * Checks if the provided value is a boolean. + */ +export function isBoolean(value: unknown): boolean { + return typeof value === "boolean"; +} + /** * Casts a value that might be string or number to actual string or number. * Since LLM might return back an integer/float as a string, we need to cast * it back to a number, as many vector databases can't handle number as string * values as a comparator. */ -export function castValue(input: unknown): string | number { +export function castValue(input: unknown): string | number | boolean { let value; if (isString(input)) { value = input as string; @@ -86,6 +93,8 @@ export function castValue(input: unknown): string | number { value = parseInt(input as string, 10); } else if (isFloat(input)) { value = parseFloat(input as string); + } else if (isBoolean(input)) { + value = Boolean(input); } else { throw new Error("Unsupported value type"); } diff --git a/langchain/src/chains/query_constructor/parser.ts b/langchain/src/chains/query_constructor/parser.ts index 49998a274718..c340395a860a 100644 --- a/langchain/src/chains/query_constructor/parser.ts +++ b/langchain/src/chains/query_constructor/parser.ts @@ -91,10 +91,11 @@ export class QueryTransformer { } if (funcName in Comparators) { if (node.args && node.args.length === 2) { + const [attribute, value] = node.args; return new Comparison( funcName as Comparator, - traverse(node.args[0]) as string, - traverse(node.args[1]) as string | number + traverse(attribute) as string, + traverse(value) as string | number ); } throw new Error("Comparator must have exactly 2 arguments"); diff --git a/langchain/src/chains/query_constructor/tests/query_chain.int.test.ts b/langchain/src/chains/query_constructor/tests/query_chain.int.test.ts index 332ab6afdb08..59279fe70ef1 100644 --- a/langchain/src/chains/query_constructor/tests/query_chain.int.test.ts +++ b/langchain/src/chains/query_constructor/tests/query_chain.int.test.ts @@ -33,6 +33,10 @@ test("Query Chain Test", async () => { new Comparison(Comparators.lt, "length", 90), ]) ); + const sq6 = new StructuredQuery( + "", + new Comparison(Comparators.eq, "isReleased", true) + ); const filter1 = { length: { $lt: 90 } }; const filter3 = { rating: { $gt: 8.5 } }; @@ -43,6 +47,7 @@ test("Query Chain Test", async () => { { length: { $lt: 90 } }, ], }; + const filter6 = { isReleased: { $eq: true } }; const attributeInfo: AttributeInfo[] = [ { @@ -70,6 +75,11 @@ test("Query Chain Test", async () => { description: "The length of the movie in minutes", type: "number", }, + { + name: "isReleased", + description: "Whether the movie has been released", + type: "boolean", + }, ]; const documentContents = "Brief summary of a movie"; @@ -100,22 +110,28 @@ test("Query Chain Test", async () => { query: "Which movies are either comedy or drama and are less than 90 minutes?", }); + const c6 = queryChain.invoke({ + query: "Which movies have already been released?", + }); - const [r1, r3, r4, r5] = await Promise.all([c1, c3, c4, c5]); + const [r1, r3, r4, r5, r6] = await Promise.all([c1, c3, c4, c5, c6]); expect(r1).toMatchObject(sq1); expect(r3).toMatchObject(sq3); expect(r4).toMatchObject(sq4); expect(r5).toMatchObject(sq5); + expect(r6).toMatchObject(sq6); const testTranslator = new BasicTranslator(); const { filter: parsedFilter1 } = testTranslator.visitStructuredQuery(r1); const { filter: parsedFilter3 } = testTranslator.visitStructuredQuery(r3); const { filter: parsedFilter4 } = testTranslator.visitStructuredQuery(r4); const { filter: parsedFilter5 } = testTranslator.visitStructuredQuery(r5); + const { filter: parsedFilter6 } = testTranslator.visitStructuredQuery(r6); expect(parsedFilter1).toMatchObject(filter1); expect(parsedFilter3).toMatchObject(filter3); expect(parsedFilter4).toMatchObject(filter4); expect(parsedFilter5).toMatchObject(filter5); + expect(parsedFilter6).toMatchObject(filter6); }); diff --git a/langchain/src/chains/query_constructor/tests/query_parser.test.ts b/langchain/src/chains/query_constructor/tests/query_parser.test.ts index 62dc607e79f1..10428a2e6272 100644 --- a/langchain/src/chains/query_constructor/tests/query_parser.test.ts +++ b/langchain/src/chains/query_constructor/tests/query_parser.test.ts @@ -17,6 +17,7 @@ const correctQuery = new StructuredQuery( ]), new Comparison(Comparators.lt, "length", 180), new Comparison(Comparators.eq, "genre", "pop"), + new Comparison(Comparators.eq, "hasLyrics", true), ]) ); @@ -35,7 +36,7 @@ test("StructuredQueryOutputParser test", async () => { const exampleOutput = `json\`\`\` { "query": "teenager love", - "filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))" + "filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"), eq(\\"hasLyrics\\", true))" } \`\`\``;