From d890fecad38ed11d90a85e6472e64c81c607cf91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Tue, 18 Jun 2024 09:29:42 +0000 Subject: [PATCH 1/8] docs(llm): add docs for azure openai (#55) --- docs/how-to/llms/litellm.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/how-to/llms/litellm.md b/docs/how-to/llms/litellm.md index 03f41208..6d995af8 100644 --- a/docs/how-to/llms/litellm.md +++ b/docs/how-to/llms/litellm.md @@ -48,6 +48,24 @@ Integrate db-ally with your LLM vendor. llm=LiteLLM(model_name="anyscale/meta-llama/Llama-2-70b-chat-hf") ``` +=== "Azure OpenAI" + + ```python + import os + from dbally.llms.litellm import LiteLLM + + ## set ENV variables + os.environ["AZURE_API_KEY"] = "your-api-key" + os.environ["AZURE_API_BASE"] = "your-api-base-url" + os.environ["AZURE_API_VERSION"] = "your-api-version" + + # optional + os.environ["AZURE_AD_TOKEN"] = "" + os.environ["AZURE_API_TYPE"] = "" + + llm = LiteLLM(model_name="azure/") + ``` + Use LLM in your collection. 
```python From 9fd817f3955e4e0c61da1cf9be44e9b6ac426c15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Thu, 20 Jun 2024 08:13:53 +0000 Subject: [PATCH 2/8] refactor: move events to audit module (#58) --- docs/how-to/create_custom_event_handler.md | 3 +-- docs/reference/event_handlers/index.md | 16 +++++++++----- src/dbally/audit/__init__.py | 11 ++++++++++ src/dbally/audit/event_handlers/base.py | 10 ++++----- .../audit/event_handlers/cli_event_handler.py | 14 +++++------- .../event_handlers/langsmith_event_handler.py | 10 ++++----- src/dbally/audit/event_span.py | 22 ------------------- src/dbally/audit/event_tracker.py | 8 +++---- .../{data_models/audit.py => audit/events.py} | 13 +++++------ src/dbally/audit/spans.py | 21 ++++++++++++++++++ src/dbally/collection/collection.py | 2 +- src/dbally/data_models/__init__.py | 0 src/dbally/llms/base.py | 2 +- src/dbally/llms/clients/base.py | 2 +- src/dbally/llms/clients/litellm.py | 2 +- src/dbally/similarity/index.py | 2 +- 16 files changed, 73 insertions(+), 65 deletions(-) delete mode 100644 src/dbally/audit/event_span.py rename src/dbally/{data_models/audit.py => audit/events.py} (87%) create mode 100644 src/dbally/audit/spans.py delete mode 100644 src/dbally/data_models/__init__.py diff --git a/docs/how-to/create_custom_event_handler.md b/docs/how-to/create_custom_event_handler.md index 410973c5..d4c26f74 100644 --- a/docs/how-to/create_custom_event_handler.md +++ b/docs/how-to/create_custom_event_handler.md @@ -10,8 +10,7 @@ In this guide we will implement a simple [Event Handler](../reference/event_hand First, we need to create a new class that inherits from `EventHandler` and implements the all abstract methods. 
```python -from dbally.audit import EventHandler -from dbally.data_models.audit import RequestStart, RequestEnd +from dbally.audit import EventHandler, RequestStart, RequestEnd class FileEventHandler(EventHandler): diff --git a/docs/reference/event_handlers/index.md b/docs/reference/event_handlers/index.md index ae69bc0d..f95f5798 100644 --- a/docs/reference/event_handlers/index.md +++ b/docs/reference/event_handlers/index.md @@ -10,10 +10,10 @@ db-ally provides an `EventHandler` abstract class that can be used to log the ru Each run of [dbally.Collection.ask][dbally.Collection.ask] will trigger all instances of EventHandler that were passed to the Collection's constructor (or the [dbally.create_collection][dbally.create_collection] function). -1. `EventHandler.request_start` is called with [RequestStart][dbally.data_models.audit.RequestStart], it can return a context object that will be passed to next calls. +1. `EventHandler.request_start` is called with [RequestStart][dbally.audit.events.RequestStart], it can return a context object that will be passed to next calls. 2. For each event that occurs during the run, `EventHandler.event_start` is called with the context object returned by `EventHandler.request_start` and an Event object. It can return context for the `EventHandler.event_end` method. 3. When the event ends `EventHandler.event_end` is called with the context object returned by `EventHandler.event_start` and an Event object. -4. On the end of the run `EventHandler.request_end` is called with the context object returned by `EventHandler.request_start` and the [RequestEnd][dbally.data_models.audit.RequestEnd]. +4. On the end of the run `EventHandler.request_end` is called with the context object returned by `EventHandler.request_start` and the [RequestEnd][dbally.audit.events.RequestEnd]. 
``` mermaid @@ -42,8 +42,14 @@ Currently handled events: ::: dbally.audit.EventHandler -::: dbally.data_models.audit.RequestStart +::: dbally.audit.events.RequestStart -::: dbally.data_models.audit.RequestEnd +::: dbally.audit.events.RequestEnd -::: dbally.data_models.audit.LLMEvent +::: dbally.audit.events.Event + +::: dbally.audit.events.LLMEvent + +::: dbally.audit.events.SimilarityEvent + +::: dbally.audit.spans.EventSpan diff --git a/src/dbally/audit/__init__.py b/src/dbally/audit/__init__.py index af9a5384..73253f71 100644 --- a/src/dbally/audit/__init__.py +++ b/src/dbally/audit/__init__.py @@ -7,8 +7,19 @@ except ImportError: pass +from .event_tracker import EventTracker +from .events import Event, LLMEvent, RequestEnd, RequestStart, SimilarityEvent +from .spans import EventSpan + __all__ = [ "CLIEventHandler", "LangSmithEventHandler", + "Event", "EventHandler", + "EventTracker", + "EventSpan", + "LLMEvent", + "RequestEnd", + "RequestStart", + "SimilarityEvent", ] diff --git a/src/dbally/audit/event_handlers/base.py b/src/dbally/audit/event_handlers/base.py index 10fce0cf..dc3ea7f8 100644 --- a/src/dbally/audit/event_handlers/base.py +++ b/src/dbally/audit/event_handlers/base.py @@ -1,8 +1,8 @@ import abc from abc import ABC -from typing import Generic, TypeVar, Union +from typing import Generic, Optional, TypeVar -from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent +from dbally.audit.events import Event, RequestEnd, RequestStart RequestCtx = TypeVar("RequestCtx") EventCtx = TypeVar("EventCtx") @@ -26,7 +26,7 @@ async def request_start(self, user_request: RequestStart) -> RequestCtx: """ @abc.abstractmethod - async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_context: RequestCtx) -> EventCtx: + async def event_start(self, event: Event, request_context: RequestCtx) -> EventCtx: """ Function that is called during every event execution. 
@@ -40,9 +40,7 @@ async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_con """ @abc.abstractmethod - async def event_end( - self, event: Union[None, LLMEvent, SimilarityEvent], request_context: RequestCtx, event_context: EventCtx - ) -> None: + async def event_end(self, event: Optional[Event], request_context: RequestCtx, event_context: EventCtx) -> None: """ Function that is called during every event execution. diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index f738f90b..aa48e049 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -1,7 +1,7 @@ import re from io import StringIO from sys import stdout -from typing import Optional, Union +from typing import Optional try: from rich import print as pprint @@ -15,7 +15,7 @@ pprint = print # type: ignore from dbally.audit.event_handlers.base import EventHandler -from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent +from dbally.audit.events import Event, LLMEvent, RequestEnd, RequestStart, SimilarityEvent _RICH_FORMATING_KEYWORD_SET = {"green", "orange", "grey", "bold", "cyan"} _RICH_FORMATING_PATTERN = rf"\[.*({'|'.join(_RICH_FORMATING_KEYWORD_SET)}).*\]" @@ -40,14 +40,14 @@ class CLIEventHandler(EventHandler): ![Example output from CLIEventHandler](../../assets/event_handler_example.png) """ - def __init__(self, buffer: StringIO = None) -> None: + def __init__(self, buffer: Optional[StringIO] = None) -> None: super().__init__() self.buffer = buffer out = self.buffer if buffer else stdout self._console = Console(file=out, record=True) if RICH_OUTPUT else None - def _print_syntax(self, content: str, lexer: str = None) -> None: + def _print_syntax(self, content: str, lexer: Optional[str] = None) -> None: if self._console: if lexer: console_content = Syntax(content, lexer, word_wrap=True) @@ -69,7 +69,7 @@ async def 
request_start(self, user_request: RequestStart) -> None: self._print_syntax("[grey53]\n=======================================") self._print_syntax("[grey53]=======================================\n") - async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_context: None) -> None: + async def event_start(self, event: Event, request_context: None) -> None: """ Displays information that event has started, then all messages inside the prompt @@ -98,9 +98,7 @@ async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_con f"[cyan bold]FETCHER: [grey53]{event.fetcher}\n" ) - async def event_end( - self, event: Union[None, LLMEvent, SimilarityEvent], request_context: None, event_context: None - ) -> None: + async def event_end(self, event: Optional[Event], request_context: None, event_context: None) -> None: """ Displays the response from the LLM. diff --git a/src/dbally/audit/event_handlers/langsmith_event_handler.py b/src/dbally/audit/event_handlers/langsmith_event_handler.py index 5974a068..c0b619c2 100644 --- a/src/dbally/audit/event_handlers/langsmith_event_handler.py +++ b/src/dbally/audit/event_handlers/langsmith_event_handler.py @@ -1,12 +1,12 @@ import socket from getpass import getuser -from typing import Optional, Union +from typing import Optional from langsmith.client import Client from langsmith.run_trees import RunTree from dbally.audit.event_handlers.base import EventHandler -from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent +from dbally.audit.events import Event, LLMEvent, RequestEnd, RequestStart, SimilarityEvent class LangSmithEventHandler(EventHandler[RunTree, RunTree]): @@ -47,7 +47,7 @@ async def request_start(self, user_request: RequestStart) -> RunTree: return run_tree - async def event_start(self, event: Union[None, LLMEvent, SimilarityEvent], request_context: RunTree) -> RunTree: + async def event_start(self, event: Event, request_context: RunTree) -> RunTree: """ Log 
the start of the event. @@ -79,9 +79,7 @@ async def event_start(self, event: Union[None, LLMEvent, SimilarityEvent], reque raise ValueError("Unsupported event") - async def event_end( - self, event: Union[None, LLMEvent, SimilarityEvent], request_context: RunTree, event_context: RunTree - ) -> None: + async def event_end(self, event: Optional[Event], request_context: RunTree, event_context: RunTree) -> None: """ Log the end of the event. diff --git a/src/dbally/audit/event_span.py b/src/dbally/audit/event_span.py deleted file mode 100644 index c7cba584..00000000 --- a/src/dbally/audit/event_span.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Any, Optional, Union - -from dbally.data_models.audit import LLMEvent, SimilarityEvent - - -class EventSpan: - """Helper class for logging events.""" - - data: Optional[Any] - - def __init__(self) -> None: - self.data = None - - def __call__(self, data: Union[LLMEvent, SimilarityEvent]) -> None: - """ - Call method for logging events. - - Args: - data: Event data. 
- """ - - self.data = data diff --git a/src/dbally/audit/event_tracker.py b/src/dbally/audit/event_tracker.py index c483a65e..34faf803 100644 --- a/src/dbally/audit/event_tracker.py +++ b/src/dbally/audit/event_tracker.py @@ -1,9 +1,9 @@ from contextlib import asynccontextmanager -from typing import AsyncIterator, Dict, List, Optional, Union +from typing import AsyncIterator, Dict, List, Optional from dbally.audit.event_handlers.base import EventHandler -from dbally.audit.event_span import EventSpan -from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent +from dbally.audit.events import Event, RequestEnd, RequestStart +from dbally.audit.spans import EventSpan class EventTracker: @@ -69,7 +69,7 @@ def subscribe(self, event_handler: EventHandler) -> None: self._handlers.append(event_handler) @asynccontextmanager - async def track_event(self, event: Union[LLMEvent, SimilarityEvent]) -> AsyncIterator[EventSpan]: + async def track_event(self, event: Event) -> AsyncIterator[EventSpan]: """ Context manager for processing an event. diff --git a/src/dbally/data_models/audit.py b/src/dbally/audit/events.py similarity index 87% rename from src/dbally/data_models/audit.py rename to src/dbally/audit/events.py index c56e73be..c02cd5cb 100644 --- a/src/dbally/data_models/audit.py +++ b/src/dbally/audit/events.py @@ -1,21 +1,20 @@ +from abc import ABC from dataclasses import dataclass -from enum import Enum from typing import Optional, Union from dbally.collection.results import ExecutionResult from dbally.prompts import ChatFormat -class EventType(Enum): +@dataclass +class Event(ABC): """ - Enum for event types. + Base class for all events. """ - LLM = "LLM" - @dataclass -class LLMEvent: +class LLMEvent(Event): """ Class for LLM event. """ @@ -30,7 +29,7 @@ class LLMEvent: @dataclass -class SimilarityEvent: +class SimilarityEvent(Event): """ SimilarityEvent is fired when a SimilarityIndex lookup is performed. 
""" diff --git a/src/dbally/audit/spans.py b/src/dbally/audit/spans.py new file mode 100644 index 00000000..0b9d273d --- /dev/null +++ b/src/dbally/audit/spans.py @@ -0,0 +1,21 @@ +from typing import Optional + +from dbally.audit.events import Event + + +class EventSpan: + """ + Helper class for logging events. + """ + + def __init__(self) -> None: + self.data: Optional[Event] = None + + def __call__(self, data: Event) -> None: + """ + Call method for logging events. + + Args: + data: Event data. + """ + self.data = data diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py index c207d95b..5c059fbc 100644 --- a/src/dbally/collection/collection.py +++ b/src/dbally/collection/collection.py @@ -7,9 +7,9 @@ from dbally.audit.event_handlers.base import EventHandler from dbally.audit.event_tracker import EventTracker +from dbally.audit.events import RequestEnd, RequestStart from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ExecutionResult -from dbally.data_models.audit import RequestEnd, RequestStart from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder diff --git a/src/dbally/data_models/__init__.py b/src/dbally/data_models/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/dbally/llms/base.py b/src/dbally/llms/base.py index e570547c..55b85c20 100644 --- a/src/dbally/llms/base.py +++ b/src/dbally/llms/base.py @@ -3,7 +3,7 @@ from typing import Dict, Generic, Optional, Type from dbally.audit.event_tracker import EventTracker -from dbally.data_models.audit import LLMEvent +from dbally.audit.events import LLMEvent from dbally.llms.clients.base import LLMClient, LLMClientOptions, LLMOptions from dbally.prompts.common_validation_utils import ChatFormat from dbally.prompts.prompt_template import PromptTemplate diff --git a/src/dbally/llms/clients/base.py 
b/src/dbally/llms/clients/base.py index bc55f6ea..5de63ce7 100644 --- a/src/dbally/llms/clients/base.py +++ b/src/dbally/llms/clients/base.py @@ -2,7 +2,7 @@ from dataclasses import asdict, dataclass from typing import Any, ClassVar, Dict, Generic, Optional, TypeVar -from dbally.data_models.audit import LLMEvent +from dbally.audit.events import LLMEvent from dbally.prompts import ChatFormat from ..._types import NotGiven diff --git a/src/dbally/llms/clients/litellm.py b/src/dbally/llms/clients/litellm.py index 3ec4ccc9..b15ad362 100644 --- a/src/dbally/llms/clients/litellm.py +++ b/src/dbally/llms/clients/litellm.py @@ -9,7 +9,7 @@ HAVE_LITELLM = False -from dbally.data_models.audit import LLMEvent +from dbally.audit.events import LLMEvent from dbally.llms.clients.base import LLMClient, LLMOptions from dbally.llms.clients.exceptions import LLMConnectionError, LLMResponseError, LLMStatusError from dbally.prompts import ChatFormat diff --git a/src/dbally/similarity/index.py b/src/dbally/similarity/index.py index 6895c566..31cfe8cf 100644 --- a/src/dbally/similarity/index.py +++ b/src/dbally/similarity/index.py @@ -2,7 +2,7 @@ from typing import Optional from dbally.audit.event_tracker import EventTracker -from dbally.data_models.audit import SimilarityEvent +from dbally.audit.events import SimilarityEvent from dbally.similarity.fetcher import SimilarityFetcher from dbally.similarity.store import SimilarityStore From db3b53c48eda605ceacf64edd13bdf5e8577c7c1 Mon Sep 17 00:00:00 2001 From: pwyzgow Date: Thu, 20 Jun 2024 10:51:27 +0200 Subject: [PATCH 3/8] Adding initial changes for aggregation example in quickstart code. 
--- docs/quickstart/quickstart_code.py | 10 +++++++- src/dbally/iql_generator/iql_generator.py | 23 +++++++++++++++++-- .../nl_responder_prompt_template.py | 4 ++-- .../query_explainer_prompt_template.py | 6 ++--- .../view_selection/llm_view_selector.py | 16 ++++++++++++- src/dbally/views/decorators.py | 14 +++++++++++ src/dbally/views/sqlalchemy_base.py | 20 ++++++++++++++++ src/dbally/views/structured.py | 18 +++++++++++++++ 8 files changed, 102 insertions(+), 9 deletions(-) diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index 34ee9765..82571711 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -1,4 +1,6 @@ # pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring +from typing import Union, Tuple, Any + import dbally import asyncio @@ -54,6 +56,11 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement: """ return Candidate.country == country + @decorators.view_aggregation() + def group_by_university(self, aggregation:str): # -> Union[Select[Tuple[Any, Any]], Select]: # pylint: disable=W0602, C0116, W9011 + return sqlalchemy.select(Candidate.university, sqlalchemy.func.count(Candidate.university).label("count")) \ + .group_by(Candidate.university) + async def main(): llm = LiteLLM(model_name="gpt-3.5-turbo") @@ -61,7 +68,8 @@ async def main(): collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) collection.add(CandidateView, lambda: CandidateView(engine)) - result = await collection.ask("Find me French candidates suitable for a senior data scientist position.") + # result = await collection.ask("Find me French candidates suitable for a senior data scientist position.") + result = await collection.ask("Could you count the candidates university-wise and present the rows?") print(f"The generated SQL query is: {result.context.get('sql')}") print() diff --git a/src/dbally/iql_generator/iql_generator.py 
b/src/dbally/iql_generator/iql_generator.py index 8633afc0..48557fae 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -38,11 +38,12 @@ def __init__( """ self._llm = llm self._prompt_template = prompt_template or copy.deepcopy(default_iql_template) - self._promptify_view = promptify_view or _promptify_filters + self._promptify_view = promptify_view or _promptify_filters or _promptify_aggregations async def generate_iql( self, filters: List[ExposedFunction], + aggregations: List[ExposedFunction], question: str, event_tracker: EventTracker, conversation: Optional[IQLPromptTemplate] = None, @@ -62,12 +63,14 @@ async def generate_iql( IQL - iql generated based on the user question """ filters_for_prompt = self._promptify_view(filters) + aggregations_for_prompt = self._promptify_view(aggregations) template = conversation or self._prompt_template llm_response = await self._llm.generate_text( template=template, - fmt={"filters": filters_for_prompt, "question": question}, + fmt={"filters": filters_for_prompt, "question": question, + "aggregation": aggregations_for_prompt}, event_tracker=event_tracker, options=llm_options, ) @@ -114,3 +117,19 @@ def _promptify_filters( """ filters_for_prompt = "\n".join([str(filter) for filter in filters]) return filters_for_prompt + + +def _promptify_aggregations( + aggregations: List[ExposedFunction], +) -> str: + """ + Formats filters for prompt + + Args: + filters: list of filters exposed by the view + + Returns: + filters_for_prompt: filters formatted for prompt + """ + aggregations_for_prompt = "\n".join([str(aggregation) for aggregation in aggregations]) + return aggregations_for_prompt diff --git a/src/dbally/nl_responder/nl_responder_prompt_template.py b/src/dbally/nl_responder/nl_responder_prompt_template.py index 9e6e687e..eeb3bed9 100644 --- a/src/dbally/nl_responder/nl_responder_prompt_template.py +++ b/src/dbally/nl_responder/nl_responder_prompt_template.py @@ -24,7 +24,7 
@@ def __init__( """ super().__init__(chat, response_format, llm_response_parser) - self.chat = check_prompt_variables(chat, {"rows", "question"}) + self.chat = check_prompt_variables(chat, {"rows", "question", "aggregation"}) default_nl_responder_template = NLResponderPromptTemplate( @@ -34,7 +34,7 @@ def __init__( "content": "You are a helpful assistant that helps answer the user's questions " "based on the table provided. You MUST use the table to answer the question. " "You are very intelligent and obedient.\n" - "The table ALWAYS contains full answer to a question.\n" + "The table ALWAYS contains full answer to a question including necessary {aggregation}.\n" "Answer the question in a way that is easy to understand and informative.\n" "DON'T MENTION using a table in your answer.", }, diff --git a/src/dbally/nl_responder/query_explainer_prompt_template.py b/src/dbally/nl_responder/query_explainer_prompt_template.py index 00a3e6a6..8c60122f 100644 --- a/src/dbally/nl_responder/query_explainer_prompt_template.py +++ b/src/dbally/nl_responder/query_explainer_prompt_template.py @@ -21,7 +21,7 @@ def __init__( llm_response_parser: Callable = lambda x: x, ) -> None: super().__init__(chat, response_format, llm_response_parser) - self.chat = check_prompt_variables(chat, {"question", "query", "number_of_results"}) + self.chat = check_prompt_variables(chat, {"question", "query", "aggregation", "number_of_results"}) default_query_explainer_template = QueryExplainerPromptTemplate( @@ -34,14 +34,14 @@ def __init__( "Your task is to provide natural language description of the table used by the logical query " "to the database.\n" "Describe the table in a way that is short and informative.\n" - "Make your answer as short as possible, start it by infroming the user that the underlying " + "Make your answer as short as possible, start it by informing the user that the underlying " "data is too long to print and then describe the table based on the question and the query.\n" 
"DON'T MENTION using a query in your answer.\n", }, { "role": "user", "content": "The query below represents the answer to a question: {question}.\n" - "Describe the table generated using this query: {query}.\n" + "Describe the table generated using this query: {query} which applies {aggregation}.\n" "Number of results to this query: {number_of_results}.\n", }, ) diff --git a/src/dbally/view_selection/llm_view_selector.py b/src/dbally/view_selection/llm_view_selector.py index 2d501922..d209b0f6 100644 --- a/src/dbally/view_selection/llm_view_selector.py +++ b/src/dbally/view_selection/llm_view_selector.py @@ -35,7 +35,7 @@ def __init__( """ self._llm = llm self._prompt_template = prompt_template or copy.deepcopy(default_view_selector_template) - self._promptify_views = promptify_views or _promptify_views + self._promptify_views = promptify_views or _promptify_views or _promptify_aggregations async def select_view( self, @@ -81,3 +81,17 @@ def _promptify_views(views: Dict[str, str]) -> str: """ return "\n".join([f"{name}: {description}" for name, description in views.items()]) + + +def _promptify_aggregations(views: Dict[str, str]) -> str: + """ + Formats views for aggregation + + Args: + views: dictionary of available view names with corresponding descriptions. 
+ + Returns: + views_for_prompt: views formatted for prompt + """ + + return "\n".join([f"{name}: {description}" for name, description in views.items()]) diff --git a/src/dbally/views/decorators.py b/src/dbally/views/decorators.py index ac537f5f..bd49cfaa 100644 --- a/src/dbally/views/decorators.py +++ b/src/dbally/views/decorators.py @@ -14,3 +14,17 @@ def wrapped(func: typing.Callable) -> typing.Callable: # pylint: disable=missin return func return wrapped + +def view_aggregation() -> typing.Callable: + """ + Decorator for marking a method as an aggregation + + Returns: + Function that returns the decorated method + """ + + def wrapped(func: typing.Callable) -> typing.Callable: # pylint: disable=missing-return-doc + func._methodDecorator = view_aggregation # type:ignore # pylint: disable=protected-access + return func + + return wrapped diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index b1783558..cfc720dc 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -64,6 +64,26 @@ async def _build_filter_bool_op(self, bool_op: syntax.BoolOp) -> sqlalchemy.Colu return alchemy_op(await self._build_filter_node(bool_op.child)) raise ValueError(f"BoolOp {bool_op} has no children") + async def _build_aggregation_node(self, node: syntax.Node) -> sqlalchemy.ColumnElement: + """ + Converts a filter node from the IQLQuery to a SQLAlchemy expression. + """ + if isinstance(node, syntax.BoolOp): + return await self._build_filter_bool_op(node) + if isinstance(node, syntax.FunctionCall): + return await self.call_filter_method(node) + + raise ValueError(f"Unsupported grammar: {node}") + + async def apply_aggregation(self, aggregation: IQLQuery) -> None: + """ + Applies the chosen aggregation to the view. 
+ + Args: + aggregation: IQLQuery object representing the aggregation to apply + """ + self._select = self._select.where(await self._build_filter_node(aggregation.root)) + def execute(self, dry_run: bool = False) -> ViewExecutionResult: """ Executes the generated SQL query and returns the results. diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index c43e4c2b..2107da77 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -58,6 +58,7 @@ async def ask( """ iql_generator = self.get_iql_generator(llm) filter_list = self.list_filters() + aggregation_list = self.list_aggregations() iql_filters, conversation = await iql_generator.generate_iql( question=query, @@ -104,6 +105,23 @@ async def apply_filters(self, filters: IQLQuery) -> None: filters: [IQLQuery](../../concepts/iql.md) object representing the filters to apply """ + @abc.abstractmethod + def list_aggregations(self) -> List[ExposedFunction]: + """ + + Returns: + Aggregations defined inside the View. + """ + + @abc.abstractmethod + async def apply_aggregations(self, filters: IQLQuery) -> None: + """ + Applies the chosen filters to the view. + + Args: + filters: [IQLQuery](../../concepts/iql.md) object representing the filters to apply + """ + @abc.abstractmethod def execute(self, dry_run: bool = False) -> ViewExecutionResult: """ From be639c0d94d5fa8e372ebc29293804ab237018d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Hordy=C5=84ski?= Date: Thu, 20 Jun 2024 11:24:27 +0200 Subject: [PATCH 4/8] Revert "Adding initial changes for aggregation example in quickstart code." This reverts commit db3b53c48eda605ceacf64edd13bdf5e8577c7c1. 
--- docs/quickstart/quickstart_code.py | 10 +------- src/dbally/iql_generator/iql_generator.py | 23 ++----------------- .../nl_responder_prompt_template.py | 4 ++-- .../query_explainer_prompt_template.py | 6 ++--- .../view_selection/llm_view_selector.py | 16 +------------ src/dbally/views/decorators.py | 14 ----------- src/dbally/views/sqlalchemy_base.py | 20 ---------------- src/dbally/views/structured.py | 18 --------------- 8 files changed, 9 insertions(+), 102 deletions(-) diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index 82571711..34ee9765 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -1,6 +1,4 @@ # pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring -from typing import Union, Tuple, Any - import dbally import asyncio @@ -56,11 +54,6 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement: """ return Candidate.country == country - @decorators.view_aggregation() - def group_by_university(self, aggregation:str): # -> Union[Select[Tuple[Any, Any]], Select]: # pylint: disable=W0602, C0116, W9011 - return sqlalchemy.select(Candidate.university, sqlalchemy.func.count(Candidate.university).label("count")) \ - .group_by(Candidate.university) - async def main(): llm = LiteLLM(model_name="gpt-3.5-turbo") @@ -68,8 +61,7 @@ async def main(): collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) collection.add(CandidateView, lambda: CandidateView(engine)) - # result = await collection.ask("Find me French candidates suitable for a senior data scientist position.") - result = await collection.ask("Could you count the candidates university-wise and present the rows?") + result = await collection.ask("Find me French candidates suitable for a senior data scientist position.") print(f"The generated SQL query is: {result.context.get('sql')}") print() diff --git a/src/dbally/iql_generator/iql_generator.py 
b/src/dbally/iql_generator/iql_generator.py index 48557fae..8633afc0 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -38,12 +38,11 @@ def __init__( """ self._llm = llm self._prompt_template = prompt_template or copy.deepcopy(default_iql_template) - self._promptify_view = promptify_view or _promptify_filters or _promptify_aggregations + self._promptify_view = promptify_view or _promptify_filters async def generate_iql( self, filters: List[ExposedFunction], - aggregations: List[ExposedFunction], question: str, event_tracker: EventTracker, conversation: Optional[IQLPromptTemplate] = None, @@ -63,14 +62,12 @@ async def generate_iql( IQL - iql generated based on the user question """ filters_for_prompt = self._promptify_view(filters) - aggregations_for_prompt = self._promptify_view(aggregations) template = conversation or self._prompt_template llm_response = await self._llm.generate_text( template=template, - fmt={"filters": filters_for_prompt, "question": question, - "aggregation": aggregations_for_prompt}, + fmt={"filters": filters_for_prompt, "question": question}, event_tracker=event_tracker, options=llm_options, ) @@ -117,19 +114,3 @@ def _promptify_filters( """ filters_for_prompt = "\n".join([str(filter) for filter in filters]) return filters_for_prompt - - -def _promptify_aggregations( - aggregations: List[ExposedFunction], -) -> str: - """ - Formats filters for prompt - - Args: - filters: list of filters exposed by the view - - Returns: - filters_for_prompt: filters formatted for prompt - """ - aggregations_for_prompt = "\n".join([str(aggregation) for aggregation in aggregations]) - return aggregations_for_prompt diff --git a/src/dbally/nl_responder/nl_responder_prompt_template.py b/src/dbally/nl_responder/nl_responder_prompt_template.py index eeb3bed9..9e6e687e 100644 --- a/src/dbally/nl_responder/nl_responder_prompt_template.py +++ b/src/dbally/nl_responder/nl_responder_prompt_template.py @@ -24,7 +24,7 
@@ def __init__( """ super().__init__(chat, response_format, llm_response_parser) - self.chat = check_prompt_variables(chat, {"rows", "question", "aggregation"}) + self.chat = check_prompt_variables(chat, {"rows", "question"}) default_nl_responder_template = NLResponderPromptTemplate( @@ -34,7 +34,7 @@ def __init__( "content": "You are a helpful assistant that helps answer the user's questions " "based on the table provided. You MUST use the table to answer the question. " "You are very intelligent and obedient.\n" - "The table ALWAYS contains full answer to a question including necessary {aggregation}.\n" + "The table ALWAYS contains full answer to a question.\n" "Answer the question in a way that is easy to understand and informative.\n" "DON'T MENTION using a table in your answer.", }, diff --git a/src/dbally/nl_responder/query_explainer_prompt_template.py b/src/dbally/nl_responder/query_explainer_prompt_template.py index 8c60122f..00a3e6a6 100644 --- a/src/dbally/nl_responder/query_explainer_prompt_template.py +++ b/src/dbally/nl_responder/query_explainer_prompt_template.py @@ -21,7 +21,7 @@ def __init__( llm_response_parser: Callable = lambda x: x, ) -> None: super().__init__(chat, response_format, llm_response_parser) - self.chat = check_prompt_variables(chat, {"question", "query", "aggregation", "number_of_results"}) + self.chat = check_prompt_variables(chat, {"question", "query", "number_of_results"}) default_query_explainer_template = QueryExplainerPromptTemplate( @@ -34,14 +34,14 @@ def __init__( "Your task is to provide natural language description of the table used by the logical query " "to the database.\n" "Describe the table in a way that is short and informative.\n" - "Make your answer as short as possible, start it by informing the user that the underlying " + "Make your answer as short as possible, start it by infroming the user that the underlying " "data is too long to print and then describe the table based on the question and the query.\n" 
"DON'T MENTION using a query in your answer.\n", }, { "role": "user", "content": "The query below represents the answer to a question: {question}.\n" - "Describe the table generated using this query: {query} which applies {aggregation}.\n" + "Describe the table generated using this query: {query}.\n" "Number of results to this query: {number_of_results}.\n", }, ) diff --git a/src/dbally/view_selection/llm_view_selector.py b/src/dbally/view_selection/llm_view_selector.py index d209b0f6..2d501922 100644 --- a/src/dbally/view_selection/llm_view_selector.py +++ b/src/dbally/view_selection/llm_view_selector.py @@ -35,7 +35,7 @@ def __init__( """ self._llm = llm self._prompt_template = prompt_template or copy.deepcopy(default_view_selector_template) - self._promptify_views = promptify_views or _promptify_views or _promptify_aggregations + self._promptify_views = promptify_views or _promptify_views async def select_view( self, @@ -81,17 +81,3 @@ def _promptify_views(views: Dict[str, str]) -> str: """ return "\n".join([f"{name}: {description}" for name, description in views.items()]) - - -def _promptify_aggregations(views: Dict[str, str]) -> str: - """ - Formats views for aggregation - - Args: - views: dictionary of available view names with corresponding descriptions. 
- - Returns: - views_for_prompt: views formatted for prompt - """ - - return "\n".join([f"{name}: {description}" for name, description in views.items()]) diff --git a/src/dbally/views/decorators.py b/src/dbally/views/decorators.py index bd49cfaa..ac537f5f 100644 --- a/src/dbally/views/decorators.py +++ b/src/dbally/views/decorators.py @@ -14,17 +14,3 @@ def wrapped(func: typing.Callable) -> typing.Callable: # pylint: disable=missin return func return wrapped - -def view_aggregation() -> typing.Callable: - """ - Decorator for marking a method as an aggregation - - Returns: - Function that returns the decorated method - """ - - def wrapped(func: typing.Callable) -> typing.Callable: # pylint: disable=missing-return-doc - func._methodDecorator = view_aggregation # type:ignore # pylint: disable=protected-access - return func - - return wrapped diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index cfc720dc..b1783558 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -64,26 +64,6 @@ async def _build_filter_bool_op(self, bool_op: syntax.BoolOp) -> sqlalchemy.Colu return alchemy_op(await self._build_filter_node(bool_op.child)) raise ValueError(f"BoolOp {bool_op} has no children") - async def _build_aggregation_node(self, node: syntax.Node) -> sqlalchemy.ColumnElement: - """ - Converts a filter node from the IQLQuery to a SQLAlchemy expression. - """ - if isinstance(node, syntax.BoolOp): - return await self._build_filter_bool_op(node) - if isinstance(node, syntax.FunctionCall): - return await self.call_filter_method(node) - - raise ValueError(f"Unsupported grammar: {node}") - - async def apply_aggregation(self, aggregation: IQLQuery) -> None: - """ - Applies the chosen aggregation to the view. 
- - Args: - aggregation: IQLQuery object representing the aggregation to apply - """ - self._select = self._select.where(await self._build_filter_node(aggregation.root)) - def execute(self, dry_run: bool = False) -> ViewExecutionResult: """ Executes the generated SQL query and returns the results. diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index 2107da77..c43e4c2b 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -58,7 +58,6 @@ async def ask( """ iql_generator = self.get_iql_generator(llm) filter_list = self.list_filters() - aggregation_list = self.list_aggregations() iql_filters, conversation = await iql_generator.generate_iql( question=query, @@ -105,23 +104,6 @@ async def apply_filters(self, filters: IQLQuery) -> None: filters: [IQLQuery](../../concepts/iql.md) object representing the filters to apply """ - @abc.abstractmethod - def list_aggregations(self) -> List[ExposedFunction]: - """ - - Returns: - Aggregations defined inside the View. - """ - - @abc.abstractmethod - async def apply_aggregations(self, filters: IQLQuery) -> None: - """ - Applies the chosen filters to the view. 
- - Args: - filters: [IQLQuery](../../concepts/iql.md) object representing the filters to apply - """ - @abc.abstractmethod def execute(self, dry_run: bool = False) -> ViewExecutionResult: """ From cd5bf7b76b97e8d9e46ff872859ccd0ffdef859e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Hordy=C5=84ski?= Date: Mon, 24 Jun 2024 13:31:35 +0200 Subject: [PATCH 5/8] chore: change feature request label --- .github/ISSUE_TEMPLATE/01_feature_request.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/01_feature_request.yml b/.github/ISSUE_TEMPLATE/01_feature_request.yml index c5d3a360..90087c8b 100644 --- a/.github/ISSUE_TEMPLATE/01_feature_request.yml +++ b/.github/ISSUE_TEMPLATE/01_feature_request.yml @@ -1,7 +1,7 @@ name: 🚀 Feature Request description: Submit a proposal/request for a new db-ally feature. title: "feat: " -labels: ["enhancement"] +labels: ["feature"] body: - type: markdown attributes: From d4826385e95505c077a1c710feeba68ddcaef20c Mon Sep 17 00:00:00 2001 From: sgnatonski Date: Tue, 25 Jun 2024 17:12:30 +0200 Subject: [PATCH 6/8] feat: few-shot selector (#42) --- .coveragerc | 1 + benchmark/dbally_benchmark/iql_benchmark.py | 6 +- examples/recruiting/views.py | 45 +++++++- src/dbally/iql_generator/iql_generator.py | 52 ++------- src/dbally/llms/base.py | 2 +- src/dbally/prompts/elements.py | 61 ++++++++++ src/dbally/prompts/formatters.py | 119 ++++++++++++++++++++ src/dbally/views/base.py | 10 ++ src/dbally/views/freeform/text2sql/view.py | 1 - src/dbally/views/structured.py | 19 +++- tests/unit/test_fewshot.py | 72 ++++++++++++ tests/unit/test_iql_format.py | 68 +++++++++++ tests/unit/test_iql_generator.py | 41 +++++-- 13 files changed, 435 insertions(+), 62 deletions(-) create mode 100644 src/dbally/prompts/elements.py create mode 100644 src/dbally/prompts/formatters.py create mode 100644 tests/unit/test_fewshot.py create mode 100644 tests/unit/test_iql_format.py diff --git a/.coveragerc b/.coveragerc index 
2c70198f..5ba74cca 100644 --- a/.coveragerc +++ b/.coveragerc @@ -12,4 +12,5 @@ omit = exclude_lines = pragma: no cover if __name__ == .__main__. + \.\.\. show_missing = True \ No newline at end of file diff --git a/benchmark/dbally_benchmark/iql_benchmark.py b/benchmark/dbally_benchmark/iql_benchmark.py index adf33710..7bb2ae28 100644 --- a/benchmark/dbally_benchmark/iql_benchmark.py +++ b/benchmark/dbally_benchmark/iql_benchmark.py @@ -23,6 +23,7 @@ from dbally.iql_generator.iql_generator import IQLGenerator from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError, default_iql_template from dbally.llms.litellm import LiteLLM +from dbally.prompts.formatters import IQLInputFormatter from dbally.views.structured import BaseStructuredView @@ -31,11 +32,10 @@ async def _run_iql_for_single_example( ) -> IQLResult: filter_list = view.list_filters() event_tracker = EventTracker() + input_formatter = IQLInputFormatter(question=example.question, filters=filter_list) try: - iql_filters, _ = await iql_generator.generate_iql( - question=example.question, filters=filter_list, event_tracker=event_tracker - ) + iql_filters, _ = await iql_generator.generate_iql(input_formatter=input_formatter, event_tracker=event_tracker) except UnsupportedQueryError: return IQLResult(question=example.question, iql_filters="UNSUPPORTED_QUERY", exception_raised=True) diff --git a/examples/recruiting/views.py b/examples/recruiting/views.py index 22afb455..63a6c821 100644 --- a/examples/recruiting/views.py +++ b/examples/recruiting/views.py @@ -1,10 +1,13 @@ -from typing import Literal +from datetime import date +from typing import List, Literal import awoc # pip install a-world-of-countries import sqlalchemy +from dateutil.relativedelta import relativedelta from sqlalchemy import and_, select from dbally import SqlAlchemyBaseView, decorators +from dbally.prompts.elements import FewShotExample from .db import Candidate @@ -57,3 +60,43 @@ def is_from_continent( # pylint: 
disable=W0602, C0116, W9011 @decorators.view_filter() def studied_at(self, university: str) -> sqlalchemy.ColumnElement: # pylint: disable=W0602, C0116, W9011 return Candidate.university == university + + +class FewShotRecruitmentView(RecruitmentView): + """ + A view for the recruitment database including examples of question:answers pairs (few-shot). + """ + + @decorators.view_filter() + def is_available_within_months( # pylint: disable=W0602, C0116, W9011 + self, months: int + ) -> sqlalchemy.ColumnElement: + start = date.today() + end = start + relativedelta(months=months) + return Candidate.available_from.between(start, end) + + def list_few_shots(self) -> List[FewShotExample]: # pylint: disable=W9011 + return [ + FewShotExample( + "Which candidates studied at University of Toronto?", + 'studied_at("University of Toronto")', + ), + FewShotExample( + "Do we have any soon available candidate?", + lambda: self.is_available_within_months(1), + ), + FewShotExample( + "Do we have any soon available perfect fits for senior data scientist positions?", + lambda: ( + self.is_available_within_months(1) + and self.data_scientist_position() + and self.has_seniority("senior") + ), + ), + FewShotExample( + "List all junior or senior data scientist positions", + lambda: ( + self.data_scientist_position() and (self.has_seniority("junior") or self.has_seniority("senior")) + ), + ), + ] diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 8633afc0..cea13957 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -1,11 +1,10 @@ -import copy -from typing import Callable, List, Optional, Tuple, TypeVar +from typing import List, Optional, Tuple, TypeVar from dbally.audit.event_tracker import EventTracker -from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template +from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, 
default_iql_template # noqa from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions -from dbally.views.exposed_functions import ExposedFunction +from dbally.prompts.formatters import IQLInputFormatter class IQLGenerator: @@ -24,26 +23,16 @@ class IQLGenerator: TException = TypeVar("TException", bound=Exception) - def __init__( - self, - llm: LLM, - prompt_template: Optional[IQLPromptTemplate] = None, - promptify_view: Optional[Callable] = None, - ) -> None: + def __init__(self, llm: LLM) -> None: """ Args: llm: LLM used to generate IQL - prompt_template: If not provided by the users is set to `default_iql_template` - promptify_view: Function formatting filters for prompt """ self._llm = llm - self._prompt_template = prompt_template or copy.deepcopy(default_iql_template) - self._promptify_view = promptify_view or _promptify_filters async def generate_iql( self, - filters: List[ExposedFunction], - question: str, + input_formatter: IQLInputFormatter, event_tracker: EventTracker, conversation: Optional[IQLPromptTemplate] = None, llm_options: Optional[LLMOptions] = None, @@ -52,8 +41,7 @@ async def generate_iql( Uses LLM to generate IQL in text form Args: - question: user question - filters: list of filters exposed by the view + input_formatter: formatter used to prepare prompt arguments dictionary event_tracker: event store used to audit the generation process conversation: conversation to be continued llm_options: options to use for the LLM client @@ -61,21 +49,17 @@ async def generate_iql( Returns: IQL - iql generated based on the user question """ - filters_for_prompt = self._promptify_view(filters) - template = conversation or self._prompt_template + conversation, fmt = input_formatter(conversation or default_iql_template) llm_response = await self._llm.generate_text( - template=template, - fmt={"filters": filters_for_prompt, "question": question}, + template=conversation, + fmt=fmt, event_tracker=event_tracker, options=llm_options, ) - 
iql_filters = self._prompt_template.llm_response_parser(llm_response) - - if conversation is None: - conversation = self._prompt_template + iql_filters = conversation.llm_response_parser(llm_response) conversation = conversation.add_assistant_message(content=llm_response) @@ -98,19 +82,3 @@ def add_error_msg(self, conversation: IQLPromptTemplate, errors: List[TException msg += str(error) + "\n" return conversation.add_user_message(content=msg) - - -def _promptify_filters( - filters: List[ExposedFunction], -) -> str: - """ - Formats filters for prompt - - Args: - filters: list of filters exposed by the view - - Returns: - filters_for_prompt: filters formatted for prompt - """ - filters_for_prompt = "\n".join([str(filter) for filter in filters]) - return filters_for_prompt diff --git a/src/dbally/llms/base.py b/src/dbally/llms/base.py index 55b85c20..067fbe56 100644 --- a/src/dbally/llms/base.py +++ b/src/dbally/llms/base.py @@ -52,7 +52,7 @@ def format_prompt(self, template: PromptTemplate, fmt: Dict[str, str]) -> ChatFo Returns: Prompt in the format of the client. """ - return [{**message, "content": message["content"].format(**fmt)} for message in template.chat] + return [{"role": message["role"], "content": message["content"].format(**fmt)} for message in template.chat] def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int: """ diff --git a/src/dbally/prompts/elements.py b/src/dbally/prompts/elements.py new file mode 100644 index 00000000..2937d7c1 --- /dev/null +++ b/src/dbally/prompts/elements.py @@ -0,0 +1,61 @@ +import inspect +import re +import textwrap +from typing import Callable, Union + + +class FewShotExample: + """ + A question:answer representation for few-shot prompting + """ + + def __init__(self, question: str, answer_expr: Union[str, Callable]) -> None: + """ + Args: + question: sample question + answer_expr: it can be either a stringified expression or a lambda for greater safety and code completions. 
+ + Raises: + ValueError: If answer_expr is not a correct type. + """ + self.question = question + self.answer_expr = answer_expr + + if isinstance(self.answer_expr, str): + self.answer = self.answer_expr + elif callable(answer_expr): + self.answer = self._parse_lambda(answer_expr) + else: + raise ValueError("Answer expression should be either a string or a lambda") + + def _parse_lambda(self, expr: Callable) -> str: + """ + Parses provided callable in order to extract the lambda code. + All comments and references to variables like `self` etc will be removed + to form a simple lambda representation. + + Args: + expr: lambda expression to parse + + Returns: + Parsed lambda in a form of cleaned up string + """ + # extract lambda from code + expr_source = textwrap.dedent(inspect.getsource(expr)) + expr_body = expr_source.replace("lambda:", "") + + # clean up by removing comments, new lines, free vars (self etc) + parsed_expr = re.sub("\\#.*\n", "\n", expr_body, flags=re.MULTILINE) + + for m_name in expr.__code__.co_names: + parsed_expr = parsed_expr.replace(f"{expr.__code__.co_freevars[0]}.{m_name}", m_name) + + # clean up any dangling commas or leading and trailing brackets + parsed_expr = " ".join(parsed_expr.split()).strip().rstrip(",").replace("( ", "(").replace(" )", ")") + if parsed_expr.startswith("("): + parsed_expr = parsed_expr[1:-1] + + return parsed_expr + + def __str__(self) -> str: + return self.answer diff --git a/src/dbally/prompts/formatters.py b/src/dbally/prompts/formatters.py new file mode 100644 index 00000000..c2cce950 --- /dev/null +++ b/src/dbally/prompts/formatters.py @@ -0,0 +1,119 @@ +import copy +from abc import ABCMeta, abstractmethod +from typing import Dict, List, Tuple + +from dbally.prompts.elements import FewShotExample +from dbally.prompts.prompt_template import PromptTemplate +from dbally.views.exposed_functions import ExposedFunction + + +def _promptify_filters( + filters: List[ExposedFunction], +) -> str: + """ + Formats filters 
for prompt + + Args: + filters: list of filters exposed by the view + + Returns: + filters formatted for prompt + """ + filters_for_prompt = "\n".join([str(filter) for filter in filters]) + return filters_for_prompt + + +class InputFormatter(metaclass=ABCMeta): + """ + Formats provided parameters to a form acceptable by IQL prompt + """ + + @abstractmethod + def __call__(self, conversation_template: PromptTemplate) -> Tuple[PromptTemplate, Dict[str, str]]: + """ + Runs the input formatting for provided prompt template. + + Args: + conversation_template: a prompt template to use. + + Returns: + A tuple with template and a dictionary with formatted inputs. + """ + + +class IQLInputFormatter(InputFormatter): + """ + Formats provided parameters to a form acceptable by default IQL prompt + """ + + def __init__(self, filters: List[ExposedFunction], question: str) -> None: + self.filters = filters + self.question = question + + def __call__(self, conversation_template: PromptTemplate) -> Tuple[PromptTemplate, Dict[str, str]]: + """ + Runs the input formatting for provided prompt template. + + Args: + conversation_template: a prompt template to use. + + Returns: + A tuple with template and a dictionary with formatted filters and a question. + """ + return conversation_template, { + "filters": _promptify_filters(self.filters), + "question": self.question, + } + + +class IQLFewShotInputFormatter(InputFormatter): + """ + Formats provided parameters to a form acceptable by default IQL prompt. + Calling it will inject `examples` before last message in a conversation. + """ + + def __init__( + self, + filters: List[ExposedFunction], + examples: List[FewShotExample], + question: str, + ) -> None: + self.filters = filters + self.question = question + self.examples = examples + + def __call__(self, conversation_template: PromptTemplate) -> Tuple[PromptTemplate, Dict[str, str]]: + """ + Performs a deep copy of provided template and injects examples into chat history. 
+ Also prepares filters and question to be included within the prompt. + + Args: + conversation_template: a prompt template to use to inject few-shot examples. + + Returns: + A tuple with deeply-copied and enriched with examples template + and a dictionary with formatted filters and a question. + """ + + template_copy = copy.deepcopy(conversation_template) + sys_msg = template_copy.chat[0] + existing_msgs = [msg for msg in template_copy.chat[1:] if "is_example" not in msg] + chat_examples = [ + msg + for example in self.examples + for msg in [ + {"role": "user", "content": example.question, "is_example": True}, + {"role": "assistant", "content": example.answer, "is_example": True}, + ] + ] + + template_copy.chat = ( + sys_msg, + *chat_examples, + *existing_msgs, + ) + + return template_copy, { + "filters": _promptify_filters(self.filters), + "question": self.question, + } diff --git a/src/dbally/views/base.py b/src/dbally/views/base.py index 7b6dfc81..a3278281 100644 --- a/src/dbally/views/base.py +++ b/src/dbally/views/base.py @@ -5,6 +5,7 @@ from dbally.collection.results import ViewExecutionResult from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions +from dbally.prompts.elements import FewShotExample from dbally.similarity import AbstractSimilarityIndex IndexLocation = Tuple[str, str, str] @@ -49,3 +50,12 @@ def list_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLoc Mapping of similarity indexes to their locations. """ return {} + + def list_few_shots(self) -> List[FewShotExample]: + """ + List all examples to be injected into few-shot prompt. 
+ + Returns: + List of few-shot examples + """ + return [] diff --git a/src/dbally/views/freeform/text2sql/view.py b/src/dbally/views/freeform/text2sql/view.py index af948b3b..6891785e 100644 --- a/src/dbally/views/freeform/text2sql/view.py +++ b/src/dbally/views/freeform/text2sql/view.py @@ -142,7 +142,6 @@ async def ask( Raises: Text2SQLError: If the text2sql query generation fails after n_retries. """ - conversation = text2sql_prompt sql, rows = None, None exceptions = [] diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index c43e4c2b..8b95ecaa 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -8,6 +8,7 @@ from dbally.iql_generator.iql_generator import IQLGenerator from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions +from dbally.prompts.formatters import IQLFewShotInputFormatter, IQLInputFormatter from dbally.views.exposed_functions import ExposedFunction from ..similarity import AbstractSimilarityIndex @@ -56,26 +57,32 @@ async def ask( Returns: The result of the query. 
""" + + filters = self.list_filters() + examples = self.list_few_shots() iql_generator = self.get_iql_generator(llm) - filter_list = self.list_filters() + + input_formatter = ( + IQLFewShotInputFormatter(question=query, filters=filters, examples=examples) + if examples + else IQLInputFormatter(question=query, filters=filters) + ) iql_filters, conversation = await iql_generator.generate_iql( - question=query, - filters=filter_list, + input_formatter=input_formatter, event_tracker=event_tracker, llm_options=llm_options, ) for _ in range(n_retries): try: - filters = await IQLQuery.parse(iql_filters, filter_list, event_tracker=event_tracker) + filters = await IQLQuery.parse(iql_filters, filters, event_tracker=event_tracker) await self.apply_filters(filters) break except (IQLError, ValueError) as e: conversation = iql_generator.add_error_msg(conversation, [e]) iql_filters, conversation = await iql_generator.generate_iql( - question=query, - filters=filter_list, + input_formatter=input_formatter, event_tracker=event_tracker, conversation=conversation, llm_options=llm_options, diff --git a/tests/unit/test_fewshot.py b/tests/unit/test_fewshot.py new file mode 100644 index 00000000..e2f4cf8d --- /dev/null +++ b/tests/unit/test_fewshot.py @@ -0,0 +1,72 @@ +from typing import Callable, List, Tuple + +import pytest + +from dbally.prompts.elements import FewShotExample + + +class TestExamples: + def studied_at(self, _: str): + return False + + def is_available_within_months(self, _: int): + return False + + def data_scientist_position(self): + return False + + def has_seniority(self, _: str): + return False + + def __call__(self) -> List[Tuple[str, Callable]]: # pylint: disable=W0602, C0116, W9011 + return [ + ( + # dummy test + "None", + lambda: None, + ), + ( + # test lambda + "True and False or data_scientist_position() or (True or True)", + lambda: (True and False or self.data_scientist_position() or (True or True)), + ), + ( + # test string + 'studied_at("University of 
Toronto")', + lambda: self.studied_at("University of Toronto"), + ), + ( + # test complex conditions with comments + 'is_available_within_months(1) and data_scientist_position() and has_seniority("senior")', + lambda: ( + self.is_available_within_months(1) + and self.data_scientist_position() + and self.has_seniority("senior") + ), # pylint: disable=line-too-long + ), + ( + # test nested conditions with comments + 'data_scientist_position(1) and (has_seniority("junior") or has_seniority("senior"))', + lambda: ( + self.data_scientist_position(1) + and ( + self.has_seniority("junior") or self.has_seniority("senior") + ) # pylint: disable=too-many-function-args + ), + ), + ] + + +def test_fewshot_string(): + result = FewShotExample("question", "answer") + assert result.answer == "answer" + assert str(result) == "answer" + + +@pytest.mark.parametrize( + "repr_lambda", + TestExamples()(), +) +def test_fewshot_lambda(repr_lambda: Tuple[str, Callable]): + result = FewShotExample("question", repr_lambda[1]) + assert str(result) == repr_lambda[0] diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py new file mode 100644 index 00000000..c2fb4274 --- /dev/null +++ b/tests/unit/test_iql_format.py @@ -0,0 +1,68 @@ +from typing import List + +import pytest + +from dbally.iql_generator.iql_prompt_template import default_iql_template +from dbally.prompts.elements import FewShotExample +from dbally.prompts.formatters import IQLFewShotInputFormatter, IQLInputFormatter + + +async def test_iql_input_format_default() -> None: + input_fmt = IQLInputFormatter([], "") + + conversation, format = input_fmt(default_iql_template) + + assert len(conversation.chat) == len(default_iql_template.chat) + assert "filters" in format + assert "question" in format + + +async def test_iql_input_format_few_shot_default() -> None: + input_fmt = IQLFewShotInputFormatter([], [], "") + + conversation, format = input_fmt(default_iql_template) + + assert len(conversation.chat) == 
len(default_iql_template.chat) + assert "filters" in format + assert "question" in format + + +@pytest.mark.parametrize( + "examples", + [ + [], + [FewShotExample("q1", "a1")], + ], +) +async def test_iql_input_format_few_shot_examples_injected(examples: List[FewShotExample]) -> None: + examples = [FewShotExample("q1", "a1")] + input_fmt = IQLFewShotInputFormatter([], examples, "") + + conversation, format = input_fmt(default_iql_template) + + assert len(conversation.chat) == len(default_iql_template.chat) + (len(examples) * 2) + assert "filters" in format + assert "question" in format + + +async def test_iql_input_format_few_shot_examples_repeat_no_example_duplicates() -> None: + examples = [FewShotExample("q1", "a1")] + input_fmt = IQLFewShotInputFormatter([], examples, "q") + + conversation, _ = input_fmt(default_iql_template) + + assert len(conversation.chat) == len(default_iql_template.chat) + (len(examples) * 2) + assert conversation.chat[1]["role"] == "user" + assert conversation.chat[1]["content"] == examples[0].question + assert conversation.chat[2]["role"] == "assistant" + assert conversation.chat[2]["content"] == examples[0].answer + + conversation = conversation.add_assistant_message("response") + + conversation2, _ = input_fmt(conversation) + + assert len(conversation2.chat) == len(conversation.chat) + assert conversation2.chat[1]["role"] == "user" + assert conversation2.chat[1]["content"] == examples[0].question + assert conversation2.chat[2]["role"] == "assistant" + assert conversation2.chat[2]["content"] == examples[0].answer diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index c330f747..8c8df9e7 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -10,6 +10,8 @@ from dbally.iql import IQLQuery from dbally.iql_generator.iql_generator import IQLGenerator from dbally.iql_generator.iql_prompt_template import default_iql_template +from dbally.prompts.elements import FewShotExample 
+from dbally.prompts.formatters import IQLFewShotInputFormatter, IQLInputFormatter from dbally.views.methods_base import MethodsBaseView from tests.unit.mocks import MockLLM @@ -52,28 +54,51 @@ def event_tracker() -> EventTracker: @pytest.mark.asyncio async def test_iql_generation(llm: MockLLM, event_tracker: EventTracker, view: MockView) -> None: - iql_generator = IQLGenerator(llm, default_iql_template) + iql_generator = IQLGenerator(llm) - filters_for_prompt = iql_generator._promptify_view(view.list_filters()) - filters_in_prompt = set(filters_for_prompt.split("\n")) + filters = {str(_filter) for _filter in view.list_filters()} + assert filters == {"filter_by_id(idx: int)", "filter_by_name(city: str)"} - assert filters_in_prompt == {"filter_by_id(idx: int)", "filter_by_name(city: str)"} + input_formatter = IQLInputFormatter(question="Mock_question", filters=view.list_filters()) - response = await iql_generator.generate_iql(view.list_filters(), "Mock_question", event_tracker) + response = await iql_generator.generate_iql(input_formatter, event_tracker, default_iql_template) template_after_response = default_iql_template.add_assistant_message(content="LLM IQL mock answer") assert response == ("LLM IQL mock answer", template_after_response) template_after_response = template_after_response.add_user_message(content="Mock_error") - response2 = await iql_generator.generate_iql( - view.list_filters(), "Mock_question", event_tracker, template_after_response + response2 = await iql_generator.generate_iql(input_formatter, event_tracker, template_after_response) + template_after_2nd_response = template_after_response.add_assistant_message(content="LLM IQL mock answer") + assert response2 == ("LLM IQL mock answer", template_after_2nd_response) + + +@pytest.mark.asyncio +async def test_iql_few_shot_generation(llm: MockLLM, event_tracker: EventTracker, view: MockView) -> None: + iql_generator = IQLGenerator(llm) + + filters = {str(_filter) for _filter in view.list_filters()} + 
assert filters == {"filter_by_id(idx: int)", "filter_by_name(city: str)"} + + input_formatter = IQLFewShotInputFormatter( + question="Mock_question", + filters=view.list_filters(), + examples=[FewShotExample("question", "filter_by_id(0)")], ) + + response = await iql_generator.generate_iql(input_formatter, event_tracker, default_iql_template) + + expected_conversation, _ = input_formatter(default_iql_template) + template_after_response = expected_conversation.add_assistant_message(content="LLM IQL mock answer") + assert response == ("LLM IQL mock answer", template_after_response) + + template_after_response = template_after_response.add_user_message(content="Mock_error") + response2 = await iql_generator.generate_iql(input_formatter, event_tracker, template_after_response) template_after_2nd_response = template_after_response.add_assistant_message(content="LLM IQL mock answer") assert response2 == ("LLM IQL mock answer", template_after_2nd_response) def test_add_error_msg(llm: MockLLM) -> None: - iql_generator = IQLGenerator(llm, default_iql_template) + iql_generator = IQLGenerator(llm) errors = [ValueError("Mock_error")] conversation = default_iql_template.add_assistant_message(content="Assistant") From a4fd4115bc7884f5043a6839cfefdd36c97e94ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Hordy=C5=84ski?= <26008518+mhordynski@users.noreply.github.com> Date: Tue, 2 Jul 2024 15:31:34 +0200 Subject: [PATCH 7/8] chore: doggify project (#67) --- README.md | 2 +- docs/assets/guide_dog_lg.png | Bin 0 -> 15525 bytes docs/assets/guide_dog_sm.png | Bin 0 -> 2533 bytes docs/stylesheets/extra.css | 10 ++++++++++ mkdocs.yml | 2 ++ 5 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 docs/assets/guide_dog_lg.png create mode 100644 docs/assets/guide_dog_sm.png diff --git a/README.md b/README.md index 9088bd15..0bbc2843 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -#

db-ally

+#

🦮 db-ally

Efficient, consistent and secure library for querying structured data with natural language diff --git a/docs/assets/guide_dog_lg.png b/docs/assets/guide_dog_lg.png new file mode 100644 index 0000000000000000000000000000000000000000..dee16c227311f84319c95ebc81fcc5e740eb5486 GIT binary patch literal 15525 zcmb`uc{r5cA2)vPnKAaPWsO1BvdfZfM3e~GsSH{uvQ(BVGnTScD3vWsN`#`U6=rNv zDxxA;#xCpF2eUkPpWpY--#^cFJ#)D(_w7FOKCkm$U*~-vNw%jfxH&{Q007)4j+>nZ z00Cbj02>`)ZKY=UNrgs=b}KtlomT(4vT38@vQU-g)=Ny zJ@4yrSd4Xw`xiej*%6#lG`i&P(PXpp{jXqER!!)oXaB_hzxfo<);grFiC_CxtvVUG zMf?D?C%TCYql>4qHchfumOp^nD~Ec%x>SF#%2y2de1mpR``YIPq|1=Y0#T#iDqO06 zg?S1p?;o~gXzZJK8$0yzMZtGo?9W1C_Zbb-7hlrcUu;~&8B#cixlN@W2l~IbYEqoS z8b>dE&e1qO2ukm=IX}r+4J!N$O5WrhIStEqfF#f-F>NR$$zmc6*=}GHFeE(!{)OdMv%#UUQLZn zg1DM*RAi&F5`&u_Gm4r%*vEKq_2WeRS(0Hpv&K8Bke!f~vU4M=kFyUoL;dP)g-9m) zAZ%&Gi!O`H@{G-~oc#SeV3gLl*^i=Vv=g)NmL408tO-#-ddV?k!r z#zeFzVY|7W{Z;4XV8#W(wH1{QK%a5tF!0A7(`I~9dtq=QU+)T=_D2?D`pwOQy68k# zrZ&T$=&H-&FnGdatls0+$M#nQAK=FF!^}iUZ35r9AQa8Wu_GDpbz%nJz0q=)*@{j* z7y^lTZm7l?LI=Mn(>KZ-TqD-3eh~&WWdq;oZ<~bNl zOU3`zpVzH>ePy>>(zEnI!^h^3(=)d`ElpbE^|G)c8W*nqG}2MocDbx-R<8Oj zOQo$2rCTN!YXfw{pRVa&nclza_Up{q6Eo))BbJZkcW*uQG(U0aRaN(3#M-8Z;bhUa ze3`#}w_}CM()~`)qXyz@SNNHHug>X0_(cPQp$+EN3D(kGVtA$K>Mhevwc^p{%ZpEt z(u%ZzL&oNYn(LQ0QuBv?f0`iHJJS0<%Lq8KkUeL+4=`9Fd3;A0;bHOC^tjIakqNdN zfqnOfSCdZi+`i= z9QA!7Q3kiK5X0czeM?p(xPm;?Eym!5QwJv(^ApgR(c>$`%3?ldDSD767 z7^@EmzQ_O65y?|}k_K_bV^$I>#a1%h7&@pO(|hp1bZW-CGC~y7K#)0&=15bd_kUgP z5I^t{WLD!t4IaHSX<3v+ZO8o^YSAY>x$E;wSK}?`zx6!39&n7->|9~w5Kj%rqj=h^ zgcdU|@D9{sC2lj^yeUzF*2mUtDXndG`{HJaxUhiby%#+ z@*kB5dnsqesv5Pw=Ouo2Ms@FM;%&&TH=*8H3v9{*4X#d^Q(D)8^VRBCQWF9gOK*vp z>%?2)gxOx^nv8oaNSPCj@LSX5W>Q zT>40Ti4T17t~D@uK-5K)c_4q!&2$?Q!QO=mBs#Qb4Amq^A> z6z!lRy0(vQ4r8a?>?jV4f4F$NcCZkEhWKFX%#H%-8!Lz_BK zrxrXytTtT<=+hOwARru*X()j049ej-OAFsME)b7ICst4|`KNC0r5NrNp{0i?F!b(m zB!6h4A9g@`IkV4({(0II!kp+I-)dXl#A7K2ZhS^awRIu(r)}w|IDFC(LA*{77uV4t zzCH0hGLoQY2G7*tak?^7f6HTTk$BVO@yCc%3DQ4}X;!mI;GXNQ*z`W7)+NX?ru7Jm 
z@v&E#2U9rQGFxN20V0KvOI=*8;^Ja_($fOR?r{iQ7qWeRVTd7>rykP{xab#H@HYNW zsPZdn>%~XWuWljYAJj4XGxm$|1sV6KMxu`|JtsEMw9Q}dJi6fnH14SZBtAs!*865n zrYHXb$6Fwwm!?zXlR?)<9bm}b&qJ3X)6%BJ43^o06i+OzfJa~Lwl>-Yt z)xMnB(EuN=8MX?Zw;8Fisa0G`sEaloq+bgDZDX}ZX!(McUR6RdTHHE8?$wDC^9VK{ z8xXE^&Ggs`UgKI^9p>be=V%+=jw67M5t^X0w41Ro6{xP(QYB{=NZW;tcY4Z`P2&07W>q> zU);QT5240ewM>DL${csv_*L4v$oba5Ya3n^kyfj#vOBio@+$wkpjZJ{5e!AzJM&omo7b5UgP9jZzD+qwcqkSPvCG;L~l~ z8kK^HqKwM@IV*N_U5FOAI+~s(_Cd(IQtr&ip1ZA5hMO1J8Rj~i5K=i__NmS zj!#hY?{wa^o}5zy_ug#?9N0Z5N>EN$Nt5nmxcc7;%!BIdh`8-v`6;-O=CM3x+o{=u z9DOVn>7|=uC$2o?L~WtOz-lioo$gO}kM@g-ZE)jH*qwg$GeQ6T;{BI%WHZ%gcB^sG z(iwC!{E%up@ob59_0PHU4>=Q$aAVuLnC;&i+u7dc)Y0CSUd=4ur`BUZdgbBOjW?Td zq+}N4V}my$?b}7Ufg79RXG9-U|6<-%o4d`W7+vY(lY5*zujJF~qQ|*wV%_(tz+@jOS;pTGSy_`>7riwzUb$|wB6Kw2RlsSHQemFE9w$Bb+`n}D+}U?D z;|VUA&R|zhT_v@VGeLs+(B<6Amzcjs_oa8gh$~{6CE?90C3HpgZx0T4CvMXnzw8vN zC7^C@(;qTNU81kw^WsY3c%&1xW)yJbebIZ#_gve1wJb?m#@`K@hv=G%s)W0XzUIPX z?i*AOy*0*f#?}z!pc!W{<+&yMmI%Rpm+P4Im2#tJgZdcV4C(ABMhebQ0li|azqD$s z51oG#-jZ;ThnW?bhfznr%r;3|xwkW>iJfwPE~j* z4CW*$=lPwWN8(kI75${wu7=wTx51Aee;MumTZ9E8llxtZFVAP7J94!9CE>v+vp=QL zDiPz3a}5=DKkP1-+x@6)X`Al1SVf%4f61&DCU~6=V19kHNt+-#?;K|1J6$Q(%dEhE z)sx~}ee-l}Fq-Brbn?Tc-F3~CZR~U5Rqrbvj<=Iu;%jNLQ=6l6vPtT42B*>KBU=LKB*4x4NmVJ+44kC zh@4u`8Rl*qe&qQ%0p(7KLo3rivvt358_95bVsFzP1Xo@7n_q|Dj$&?FV{0A0ure03 z9a}STnLrqwT*snTUOY}6Rd0OG8N7Jtb%F#b@>ESmxyWkQ0Y4^ml9|NEL>30_vNnsI zypE{vpuG=R{dtPROb*<+r7+&i3UsWgffO&llWB%wMf>)J2N{qq39f4b(`Z2^TlRsx zEeS&mM;)^#I;0s@J)09C95Ju1D@*&+% z70RC7a_?^DStu0o2{JAAST%0V#-wONy*CQ1d-ol9xK*YOQELvL;yU27T!sUM(w)2$ zJ>fwkic8(4kED!m(=53ia&vfceVk5&jYLfrg!m&KYD1Qv4nxP1J{@O!&quv{7bqe2 zT|x&U$o6rsQ`0bU0$UeM78@o%oyhge9vEjd_+|L&eM;??8`9x(L{cct$NUiv8SJ{) zQua=c;}(OOkgSyHd2U`6vvc4hi93If)}yOTRUgv})^|-XUR$V$RJGo7gc%osTPM!|Vcm2!b|9XBZ z!g!P}aDC)ll*`(mD|;N(z*PZ^8c$lfhQiavLYG&1T+6#pa|pU(1el{&_<)cvx~5GS z_o=wgS@u~q6Dog~G_P1w-;Y?uhDr%4C*iBB__cuRy_>~9PzKW9P??b@w5$l>{A^XC zY3uv2L9u=@!@E60hE-ivzDMu$U-~6_Kg}c0knVI%fWwv+KH_$x9*LxEie|8%F-LB)$TVZ 
zV>_PIIZ1qH{3@0C;a#0LX)2(ON@deejk}SBg43_3BuC*1qb>h&-gruF_ze0zSJX!ZLHxylQIcQtAGrG zmYJ(@F%tGDI}vQGWZjk$izJbiWP8ACgJPd0rX&Ue)UU>4B6_M0>0gZ6-a*E6a zcqnE#_q5PKzDexB_fN!7^=qnPOS>f=E3Artk6s`a=gZ>Ja^WKjVIG6RppKn4n>Eqg zMdXxhroZL$)qSKcGmf|WPiXKzPF0Ay$;Q*nw$`gAJZjsfot-TheK+@4s^94#ZB7;0gd1T@U``jb;%$siGM>zLukEaQC&Q8TT8Z!n^&^cH zNFxZVrm>5V6UqmeZ>W2T!`?gbBr5ClzRdkYJMgCY@nNR>{>=pC#c`pn0oARl$?kh_ z>blTT+t{=qTx3l+x|d|Y;_w9-mVgie<*U90aV=N-fq`*pfH&%QobqnaYn}8)*Zf9X zL&hEV75c9>`F-UCs+il+54nRp+ZRezZD5g~q75pP>2*)Ep|fb<2%E@d#V9G9;KG<0 za$xCw`~`56I-49hR^Qcdyk(u)_NvM`F&xB@`1Myg!RGx8=125eXgZHcfrj{)SkeQ$ zk^1dqG%4Kw@0-#Fb_H%0Z37;%mh4o)kEbWDyv>fR$wg^_SDx&_i&odKH5>UUwZ4x~ zakAOd@aEzeYga>~cVa|EJwY+R$9hg{uN& zc^+`oLp1&*M`{;Fc<^u*|LO&W18uJVmsS*uM1 z971wfrVaZoOt$V`!@r7rp3c1$Z8amjGrLXqzrks2$>>*&DlezR2}CoiZ?0rm927p$ zEJZXJeeVByc`EF@C_w@)4s(w7ENXHjeZg6I9 z@((l8!{0({PHl{)EKBvc6Zx1#=3L|zrj+KLh%n82Un-xcOgcJ*opZe&`RCxBn_HK;j$7hwn1zk(iPFSQ8YWXY@;XK5BI+(P;xNz2}duH+#o5y}tR1iZp8y zUD270KeXieZj`<~xSU76r}T7Y>Q+ank3hjW*M@!+wHxVjR^ilG}cC zaoX6zimhW0oPo|A%wYxhx{#f8Ghdk5K87H+gO+n6v1e3;9?j&f3U-{k?btr;KUzw` zxKcKT> z1it|^SrH#a3A5P}JYN}Ntg_jytEvL+C>Cbiyg2K4(Z?pOCBU9GELR{~U+439*YJyD z3@r2)H;IpH)&z*>iR!tt(BF0YKjHQgVQQ}yymCpIY;nv?a z@MSJd5{xYU)7eh>&*PKuIg1aBXXMC~2HoY!i~iupNzPNJRWsnK9Zu+#GT}ZW`p|dW zleeeRf9_p$JHIrS$)=#+`k>gklB9FFCQZd?521&%IVMID0kpjF05|&O>Cm#<{SHeW z_2-jBA2V)?QhwSTR334nez%U|aI0U-jx~2JEcw*#^F&o4&`rAg95T?-DF^)Yg?*lo zYCuqj;3?a_Bu@%#-fqp9&Oh@gu#lhfrKij;WvhKHcYPBTo9fOGKnDE4Gxo{cH+4msgGKv;!sdtm&73hF!3y zG2g!@uI)VQxDXVfWO_x!%TF-u$ves8jk;Tc7CB&JJhG0R?GO)6U9c0cBZ4Y#5-1S# z``T9+#oZ-Snw2uy^Jku)osKyx zw#tqDVBQR_J3>Qp1;G*p_fN)6;@@il5xA!?G&|4dlBgRS?ow`#rZE@LTQr_!_4M62 zua&v-&6QRz?@#*7Yz`=phurs6GF8Pt`XS_L6Q}sPp+)He=PVHyX#FuRx z9~$Q&5D4oQDJ;NC>&1isv#3{ZIMRo!Y`EQ@Afx^;SD)N+jNfTUXI#V6MbtltNr=!) 
zLfpSmqjBRrvZ+Pgn-G3zOR8S6xM;NYxj+f*N^cy5* zXZ~)hB%9zW_i%I)bH-jSb1~8>ux`737;x23+{SW|jOFq~XJ=NZruj9BG|1$OM-{@XJ5bZ2$6U89Zyl{*CRDmUT!0( zhm#QYgaqbpkqMln*xdqa;S`rWdk&@PLX|U$@t#!(vK}H#HdHJ zMzg=~_u+Zd(k_%1QW$(?b@**ta7h9~D>C_VFA$R^4{z(yHc*m4RBPoZ!)2Xk{?a9h zTn8l1P2P2tr}$`EUm*Fr5t>~6L>pK-Ll6I;T)L1bqE=^8r^Q!f@e8{XhoNa}Vlc6( zaf{(dY=TDd(%e|0yBks+bnSRtE~}H2z4@3Xc6xt-$bkCVsC%eEVD=#U=S)KIf1!(b z{7>^*k)8^EB&xqMK7*(hjYS zNV1108y-?b1H`>Ci(l=Mxk$ll2Tkt{KJWzJI(-ESaR2C983$-s12C!}eT5n5#|1P+ z^#~WVkIhkwW!o;_x#}ZL@Mx-HWLRsR^5*L3?xMh9Bt~{N!I(!AsW1|v_|es+*4nr# z3f>*UTa?ltD2H#zdTT;rkqK8=umn~CduHXeiupq0=&mp5)!F0-mc%;J{J-}5jqcc; zC@Gz^@j=yE>z5sSR?Tkv5w$SjjVS$sWOETUJ2rKWK-zAtxH)EhHhVULxYY*;=3=0# zt$0ipe^jRO#@uhkdwss(Mebh>s8~z@AVK65ggV6S1=I@Bm{)c!&H6`nX4*OMC-C^^ zA~&2*4sn6lXlYAEyg9Dx;Wc~r;wSb!Xe{K7sQq9~vLee1QXyGbr&a(-;5!VIMP-G* zeAbzYev=u_2qHIC{sQr5bkpM?gJWA;x&7(XA`Q<$QAfWV3EiI z#eCMZO07f)TZ?pKDZ0MUA*4*ZL}%c44{i4Wu5AbROwO1u{F3Vbts=ZP{GD=oxBY)r zZ~!e}dr|`Xtrv3&oYyaC3}&X{_06g&8MV~!AaRh*CpOyaKlpuZoWVKrxfQU^cC*2n zU{91p!BASkuo|QR+Yy2D8iZVzLYJ2_Sey|#vPR+yre-2B)&g$9gP`QAfG5CaK%fNi zBafB$jaEsV(&rG;?T5RrtHn^*&7DuV9$SM66)noSRkS1pnaULg&y4mdAG;i^o`71> zfA%Gz*xW=PzDRz%zfx){`Z3KHL~E|LSeh7rTb19gbpCm@e)-u;?yaZ;#oUBmcyp%` zm_=CoF->0sqA}rMS;|Blw1tr7Ohg*IrK;D_L}SJ7Uw75)7DV_##Z#9 znz518eee?h(+y1TFRTHQ14Iawz_#>anwVAlD|bSpANJ1koqtSw#~(IV*<6#Tf0EJm z{W(&Axgr^G)|lZ9^D4_6f(jC6MVME_+5fc%YK~ypw%~b@NhI@SQl%%2MW2rWr_H7KGCd_?_tz^W#P9qPqw!T{97 z$7T-UMG5T2AKaDCy?-c+cjxUeKOVSPf9FsC$hG^=tO>p7e?9?sS-0##i8c=qz+Cg? z-DUpXBg`|#O1;_7c&oQ#Jb-$QS{e1nibZ> zvL5ckiGqi;xcUv|^3T7)u&X}lZQSg>H~VLGsIzIY``$u&M6^SI#j|S2b`StK&C2d( z5^hMyk(?y3q+U!dv%LqLFI|}ExEej*((P`{jpZFrH8Bwuo?K}Z825?!YOGINZK;xR z+Y5(apM{|@$GX#d4%=qbndg?PK(xXD8!mte^R>NyIcq2cJCvzT>Ufq4ccUdob`~p8 zwdaBDW`NJ9eBQ*@+E6YU>kKKu-5>yva1L?#jD+2@ZG)(;4W(pIB9^BzM*BxI-IdFO;sv! 
z(W-|-A8Or)e@N3Yj7)0pzTJab56U6*76V&79mnOfoNC@JMo)qZGB1w7$ z??Z%dZ2x((v{iT5ql%uzn+^2;3zyyuwB8_iUh+WQ&`sAqdk=RP-4lz%wRv2JHL%`f zx^U$)5_p4#1Qff6==C!xN+JQsG|qDczx%o$kv7|zF!f^16{riR0k(ZviE`k(Um*K9S#E1^vZye zd>$!YE2e3zG*K(Dj`>{t6J%U1Z@!g)-h%G#@aW4#7pF?Fj|DV9m`IKig98e4C*v|{ zHofeH+-Z=pD8ch@fD0c}ukVG)`Y~b9QVbS26IXgMFprdDW+6-V?^%5cQU3}J{%Das zS>CYw>_+7e@P&4&Rhz~a!qOJ*Yl{eF^Z_l-|dx)TwxTgw$*h zOoy|W?Aj2E-0u7>&Y0*QA@2^aK9RZzPgBuR?@ZYnjRoj`XuKfEBpO>ba}!L|ZIq6%!P?JrLupJx;p zy(O(Vir&k+H4AG%iBx$I?lsnHl@wTGlRtM%XhX+1h#vpi3QVK;Khu@@$3?(V%=EM( z=IF^daessY!=UCCq(~EI^+VGKBsc9v6n6nzGyi@E=q`IA9OV{bgngy^ZS;_7y}C8e zcbJT_wQZ!72pTjpYTG2m7`$+Y3c{2lU_LNZWCG_NLA->GK-!8Pi8YF5wITln)U(oE z#5V<8kyy5m3oRhifIu!X@!Ohssmd9+vK)Qyextew*|g`>R+*IY<(;|;%FiX8&?RxU zx+!k_&#IfGj`qmg@qJLf$W2);0vZhx-9;zeY4D3;H?0)+b@kU77Z_nljJAri%lm#> zqDkvJ3EB%C+(HPxX7a~Z?M%xxL5#nz?Z$X_>oaEiQcuaOX%R$lMfVzZ-swtpU>=Nk@Aiv%dbDt-HRmR49<`$~ zU)`gz?B@3wbF8#@qFx?iD4c*+;SR25_j#7Q^1~ZsA5QL^x%V@H{CHVq(b*d2j>-oZ zuef)Uei(tHm99Mh8Z17re!p(RI&RL#TK*xl;}>*5kECX zp3A?0g8Pozz_}S!zXm*2xs@&pzybH~V1ei{QOTj*jyC(1l+;xnRL>FmpVvbj;I5L|c$2bj>*5A)*pabu;&G)eYx6)Kkyeuub~d<01D#x|XL z;PFz>Gob_Pn}t`+6*h!t)U z9@*)((y>H_`5>NT4i64v2e`xX6G%BBAI#LRl7?@HflM|5xi&v&*F#aWVf=q94C^)M_3C31-@s|>M0eQ=m>I?X}!>eRikmP8Nl91zR z&&q(FKO(z58NsuKcsTMxz5Ip4lb~lQw{_F)m@B#2?UJwJa$9WR5NZ>bD;PAu+A*8% zYC{bxYn{3~3I1Kk(MAncsZiNLZ?^6BnX9i((9|`ApTZeZg{|4CNuKrWga(usPT1p* zsBcKHcn$^fOA$~{8!uY`ne~JslcI5+O3ySj=4s`Ail9GueG7t{HFwzq*X^;V#0VkW z0lNlX?L<92#POcmgdBrGYN_v5+(0Oj8++PQAHxNA4vw8=neCCoyyC{5AEd(m<(4iR zs|79!U`Ey(uhO6G;71c9Ww?Not(B{zYJqS4V7xhw;nDYk`mg2=hFSqs0`-yEpL$Qf;BH?mIZJcih&bFxm$s zf8xVFXnK=jjU<#JxUmnU?n_L!nj+I|9Maqf%hLqIh@Z5>05ymM3y* zPtAHxU74>(gB!i!l@2$M+DN}Lohc9_FJt5+Btj7J zEbY-_I`EwMty3Gtk|Iz=K8@$C5FgQ1unEXH?xLyCwL(}VQvU@eRNzu$y#+r@|weqGfeZ0MQKs+-~fzz?0alqq22Sbx(NGk0c^xw=E5*CO<7MYVMQf+bgh)_wH`EW zDpgG(M@4-%YnN+$JT>{XSQcQPaBAFt$e;2!t0%%RL`ZJi>zJhWObyLU(sz%8JOrMg zc*DbqQxYY+D39{VJ8z>3Qbx%t!#6Y#J^Cb_tmBs3p*JFM>g>1QhLmMYawa=QBqSMX 
zzmw?|pC$jys4xRYjfGRCiloG><4WxM?lK+=o-GSWlp%eq@<+X0^RElJQjcFNKsvj> z=7++Tr8RCxx1RufaiaI9vFx0Hd_5AIpQG&Tu-e1O19 z5aC`(A-@s$D2@TUuik1wfH9G4b-;w3`jNcEmNJO@hymLp2UYJ^0-(aej|hxwzL%Gc zH~j+3w#>B~GqT zS?%8y8zFImV9SC5fr7SX?zj>JFR0&Z+ZKc^L1?}C|3tdHX(JJouyY7tke>8oMIBH& z?fQOL17HEQdjoQYJ0o9`*t~jUAHcO)wRyvPZeF0vp)jrv93VG0`0z692Jr&`T>TwT zi(&yioLDwe{2K^FMhpj{01DeS3~X6S?g0QdiBsyOh1mqc3lYEvec&})%2hNz|0=w-LQv8 zRbg9kiY_RO3YSoe%l9G%QKDZI*!C(AS3M}DR2}ap`dO{K8|BS5JsEjt zK1YB9o|+g)fjGz`Mos=1?b>1lNjeXMvU?>6uqCeHZOlga8GScUgVd1;5=#&T#@a}J zLgy*Qp6<=ELYtci0$_t&)V0+JYqhQq0ayI8N-hRJ-`QUJWF2`EF^e8B;ERHAnq(OZ zYx{W({P*(W=J^f`wHv+2QS{AcU^HNE}mjD=uk3&4_OwS(@cs=aB^qv!& zNj3f32&I4c3J+Z2(_x9$tdB!WAgkF4NyL@-M%E^yVOC79JbxgW#L8}fF1ZEbb2qt% z2l{V+BPdG&=@gcuSI8Lf=@s?Q%{3jdBS`IYz|4;gC1LEj`j7>)P|AGq>c^;Md=rvc z8OqALR{j&&!9!-Mx}dOkuupHD#HwOrSs)iS0z0;sKhedG{mmHSd=MuI1Gtje_9Z2n z+t#mWw;o#B90?r_$Lm*huYH;6f(`|SOzU@F`=+{_po$7@7^Y>y?5 zPjLy;5#1kSLqA$q#mI%mu90){Lc*wMEN7VLnmg#B`uuTpi-p{cPE zPHb)!YJmSZ0N%0fg5M>#MtKPfO*;W>1bg^ce?5nf#;ztOqB1pT9TE7&6>y=Yfj!`_ zwBresVZ)|)(EAaU-Dkm$r+M%rFeH{n7+y5Kj{(1p+9IHms|~nZIM}d;1U2qJUP3u> zM|DUTJt;^4z?LV?g_qEJA$xQ`3YWgwHT)0m7gk;r*`S@a$(VZAX%B)D;~_y74n*ML zRy+do5enzP%pKoHO2e3d!Iv4eXTBl=;iGwAmW!9km}E*Qg5sA=JEJ@{1yEDg!<`@OQ&tq=zYXK;W3zWf}~i33_i>?BDdGZ$tXOUxAcH#JZS;B=in@^CWsGXM)?15Xy!PS7}lS-k^ z8A8+2oQQT65LpknJodwY?hcL=^9J_DVj~dKdMtpA)N^Nn@x}`3Ebg(9BGUvAxd?zF zuTHI`df?kzFN7m(jd#jCqzcNw2Z(^WSX##E@(?)ON68_YFq3O}pR>5wcnMEcNUg=^ zAz==>M)%m}gNi!~9K0QXaurfw3O@kn-DZpTkogC1aHA0to0J942M)4A7|<%tj13N} z1JLZm>l!^Y5<(S7?DFpXQ6=V4Td|R=ajEC{3ED)~qHr(=RihA}V5};c6h`pMt|Wd@ z6G2JvD7T?2jK;1}9U8z;HV>v&-4HZ1)s~;?m%VNusq%a7BNG`mNJg)*3a)32Xz=r{~$ZwML!%rfex!4X!f&N!< zFiGr^uv#*m`8>16FH2T6rkfv}u|}gF@Da*h$23Ltyd*{1iTTx(3~b#Anr!?;zy>1% zqXK^OXa$ApneY-ej6Nz6;R6q?Lm+XZhB>Z|8FH0!_&1kdt~39t!oeR!igF z_j0`-6}5uHAE8Jc#jue=JXM);ufyaZ>xMmqhmwAcyz6_vMJgrhro>H#G-Kc z8F3F)oH!1qjTt&Jy!l9v{vWg9Xz}erXVbe7I(;CWG6H-N#F~VIYja9SgMEVh$ohq31kM zA}S1SKtVlwT;wT|C1;Re+Z3#3vtYpw$YI6Phv)i#QVCovH~~WjXO#yMpa|dz8!1kW 
zr{Lb7|2=OduwgxK5sK?Mk04>w{e?nMmf#BVW= zAP0c~*f16V+sh0h_|8U%G(PZTULJu31YdYMY*W6#w+*G$1n_V5yr6?Q z$Yg`va(X1`6wr>DKVzI7Y694R8cqP0h2X9Kw)S&i4nKis$1?tXp>r)aT?K)KxhNL? zpDEYyC@}EFYy^Z z(#~u!oP7`Lj;O4J4|k-M!3eM^ zSfUm@gX`#Sod^dhdQ6qM7bXFg@<5UwMb?$huV!A*uuF$kS|9}^a8>!|e*ZTGjt4-L zjS5&Kc=3YopDvgZZ)5=q1Bck=K4`BzU>(oCULFhl*hv!*q&wxLXblb_-doYS)K10B zUV8`(BzOs$@XR8$48u#JjsyS{51mq6efpRnz=CTyZQ4Uf>qf`V0ni z&{%-&F>827((?1ou~(=9IY6jj?%F^Hz!d~;XS9JAkoGG%@NhizBv{f;=t=;MM` z!5`&A#QFGhv=8+xkS-s(4F}aAF!^Qtlrl-78ek@0&TqSld9TGQUP{SKL9T805>d zdoWoMX}pyiHbUIwCH(D#fzms*cPoTQZN_J{DyLMLZ(h2=UMa|U`4noP*$I#6*Sy;|e2$T#-Vc?~@B!k89!7?DbH4NoKnvg(@*#lgvwmF$AFoHUH#lg` z&r9|oTteT})~+uFagaWK*)sdnt`v;!B$pB2)@r@k11EhyWDh!X!kH*7w;xVv=FC{& zkkm-+u2QxON_dxS27Y{nX#=VBj`n>a<~LN;?{RUG@O4N_)>890iP3CGw~;uPs`n9P z)xEyT%C7s3Rzu9+O7U39mzFzEF}~atv5h=HHaH|P*d|lC(M>o0_$+O6k{By{f>e8F zE!T~D>a2rr z%iih$;MjXy-30(xJAqAaxlTNjN`pWkH}0Yy0JRxeuCNjpFJxp*%>w|A1*m>#SP#9= z0+aDQCGU0b4dH%q5npnxK%jt2(dEX<`e_tA?EmLu)rJ2P=MBM62quhQ=>x9+F8?n+ aftco+M=U=j%i+Wez=>n0%w8YG#r_`w^dcz$ literal 0 HcmV?d00001 diff --git a/docs/assets/guide_dog_sm.png b/docs/assets/guide_dog_sm.png new file mode 100644 index 0000000000000000000000000000000000000000..85f91ee53228202908c70665841c7720820724bb GIT binary patch literal 2533 zcmVCfm%U|o+1*?W1SkaLsud8DRth4fRzgfgiXzZIY)i{nm?;J7U~zP+ z4NmJ+(YChLj+J(>NI*Ib)>4gPXhUKX0}^6@1UC2G?ESr*bAC_X5U?3%JN9`H~T}FRy_%0y8%p?7s$8HHLPT7^*c>bnb z=zSsMo5=#h9Jqv#mN!u*_rx;2ABQreEpP&qd}IBi$VNbf6etY$SvpVnZ-CDMmjX>X z;)Vpt9Ak@P!ok_z_{j6Wm>2EX6;iL8n1WA!vDz9t8VaT4(}x$u7{jNNwEaZ#3v+03 z35q=R=pPXrsk$UW1t=_EFZ?c<_ZuAq_Q`IZn;#}nph^jf*b0&pIzms@dgt1+w|-Sw zx|xJC#-~F3>&MgTWTQy($>D5nx}Q+Vqpf#(PjVRtg+dJDI*W+ZRFWdvhzme>V%Atj z_^*UNOUHBRP#DtlkOv$S0TC$aNNbh+Q_qhrx-CxLB7PXbA!-CbT@e2iF06_Ub^$x7f> zU`H6J&;93FLkCFGqFKhoHbqt~Y0vgR41`e7p6f~4tuFFbeN8gApsW7gBV{&tYEd+s zOTmAS;?0;3m#+OVq5T$Mbs)c^bcPA>rt~Tp3LzwgadOIt5-t$&;Z9#n`+gq~4|zn@MtyIxxJ z(JhyZt?#?KD{J==!w6}4U*SrvH9qLRW`n$_vN8M$I|ffquDa#>H*6oh;qBu)wl1IE 
z{qntkaLNN)xQsKlG(K|80`0nMaq`S<7hUtf#%+)PX64-UpV~2|%0CLs!#o^ID3PSZ#|YeK=={=Qc)d1Y{U{A&2U@cn_esCVFy?+y+BCPoUGE^g#+nn{f9mISdX$$hOuloeW5(NHO$z?4YeyS`S^u*$HQ72 zZZH1A>zbcX)^3}WxS%qQw4k?V>6_R){Og=oK2|-mw!Q?i=xivc#OOlc!jP1=R%&yf zYc=ai~^Tbd^V z6Lh4!LFPA#_2u`%N?Oz&`AYu&KVDI88X*!})YX_xH8$D9!LeRF_v%tETzwt;w{G)9 z*B)^yc5Z^AY?%(p0Z43_5X%K4qs&eWRq|GS3*x#Xe|H9DktyX3G2w=EHd(D%HU8J6Pq{s@XiDBWql@sKt`4juMsxpm>7~ueO6ry z#_wEw>yJPC(*3_gPydlnsWxNE>%x25-2UMYuw=>8YL6(+PM*ARC@s!b#WAZ})hyRw z-vmSbGb|oH$PahrdLphwSeQ4+=4T`euq}rWwdkB3KWTrfi?UtICJvsQJ0_%@=*Seq zt9J(}mcKpxUtedY{##gq5FwN%lp>Unz^cW8oukvtG%5y!Y+fUzD8RNljO5Z*U@pBl zt2X{6tF+wR?Bd}5rKFr-uy2~~C1WKS9Nf?5k^0=rpi?ms2t=8fD7FJhSI&wHUi7Pn z1_T%r6&y%QIZuowg@`Qyl8j*q4s^txlMc}Lp-}=SxOy4Pj1Iwsl7(F~i0PMz0h~k} z3=$n7M3dvQP8rp3|FkI;6`Yhss9XwKTuMnpa?|CCC50ZR?R&;FeR5EQ5qi|ec__t-TUli%uNn6HhKmW z0fiWCb(oGS78oNFHzWl94|CJ~VRC$7v2XYJADJB=DU=n7k}}vz0AZqJrAchY7(Vv$ zwoBLS-g@ODN!BjYYJ=H{5s$rkhDc%p*cKgcHuM>Rg=7LhC#))v%=Yhm(Et7Io7l7c z^2yy>uWIhywt|`Qg|8@-{|+1iLXt#A2Bz;(lK;u@Y0rl5eeJ#jSWjUQk|K=mdfQCi z=-RFjQ=ko;3S7%^iDMuz563|7&^{KOzMbB|{k^s{HoBajZoc~8j_f`ADm#497D1-B zcP0-`XpAkbti9_d3}3f}WMYxNFne_26!$QI4M2;4!3p;z*ml98;G zdDIoLIZot2LH9KIc7B|l{~2sHT1qgQADv=YcV2%sS$3h^{ zVNxLDxJh24Nh2^|rqoyRKwv#%rR}j7#UQ%r!b_LO<)0^uwi<9QK)uJm%s)5jB*ST1m+sji4$Pz8{~R%<#F*vA7lW6UdG+4mB`>yXwb>f5%+{pgcHjWA25F3k$14WWaU0ez16N=sWCS>69?hvU!3iiLm;;eyt Date: Tue, 2 Jul 2024 13:33:40 +0000 Subject: [PATCH 8/8] refactor(prompts): prompt templates (#66) --- benchmark/dbally_benchmark/e2e_benchmark.py | 10 +- benchmark/dbally_benchmark/iql_benchmark.py | 14 +- .../text2sql/prompt_template.py | 2 +- docs/about/roadmap.md | 3 +- docs/how-to/llms/custom.md | 25 +- docs/how-to/views/few-shots.md | 97 ++++++++ docs/reference/collection.md | 2 - docs/reference/index.md | 1 - docs/reference/iql/iql_generator.md | 4 - docs/reference/nl_responder.md | 4 - docs/reference/prompt.md | 7 + .../view_selection/llm_view_selector.md | 2 - examples/recruiting.py | 39 ++- examples/recruiting/views.py | 2 +- 
mkdocs.yml | 2 + src/dbally/assistants/openai.py | 2 +- src/dbally/audit/events.py | 2 +- src/dbally/gradio/gradio_interface.py | 4 +- src/dbally/iql/_query.py | 14 +- src/dbally/iql_generator/iql_generator.py | 100 ++++---- .../iql_generator/iql_prompt_template.py | 69 ------ src/dbally/iql_generator/prompt.py | 87 +++++++ src/dbally/llms/base.py | 42 +--- src/dbally/llms/clients/base.py | 10 +- src/dbally/llms/clients/litellm.py | 16 +- src/dbally/llms/litellm.py | 17 +- src/dbally/nl_responder/nl_responder.py | 97 +++----- .../nl_responder_prompt_template.py | 47 ---- src/dbally/nl_responder/prompts.py | 111 +++++++++ .../query_explainer_prompt_template.py | 48 ---- src/dbally/prompt/__init__.py | 3 + src/dbally/{prompts => prompt}/elements.py | 2 +- src/dbally/prompt/template.py | 234 ++++++++++++++++++ src/dbally/prompts/__init__.py | 4 - src/dbally/prompts/common_validation_utils.py | 51 ---- src/dbally/prompts/formatters.py | 119 --------- src/dbally/prompts/prompt_template.py | 83 ------- .../view_selection/llm_view_selector.py | 44 +--- src/dbally/view_selection/prompt.py | 52 ++++ .../view_selector_prompt_template.py | 51 ---- src/dbally/views/base.py | 2 +- src/dbally/views/freeform/text2sql/prompt.py | 61 +++++ src/dbally/views/freeform/text2sql/view.py | 55 ++-- src/dbally/views/structured.py | 41 +-- src/dbally_codegen/autodiscovery.py | 93 +++++-- tests/integration/test_llm_options.py | 12 +- tests/unit/mocks.py | 9 +- tests/unit/test_assistants_adapters.py | 2 +- tests/unit/test_collection.py | 46 +--- tests/unit/test_fewshot.py | 27 +- tests/unit/test_iql_format.py | 147 ++++++----- tests/unit/test_iql_generator.py | 112 +++++---- tests/unit/test_prompt_builder.py | 149 +++++------ tests/unit/test_view_selector.py | 2 +- 54 files changed, 1200 insertions(+), 1081 deletions(-) create mode 100644 docs/how-to/views/few-shots.md create mode 100644 docs/reference/prompt.md delete mode 100644 src/dbally/iql_generator/iql_prompt_template.py create mode 
100644 src/dbally/iql_generator/prompt.py delete mode 100644 src/dbally/nl_responder/nl_responder_prompt_template.py create mode 100644 src/dbally/nl_responder/prompts.py delete mode 100644 src/dbally/nl_responder/query_explainer_prompt_template.py create mode 100644 src/dbally/prompt/__init__.py rename src/dbally/{prompts => prompt}/elements.py (97%) create mode 100644 src/dbally/prompt/template.py delete mode 100644 src/dbally/prompts/__init__.py delete mode 100644 src/dbally/prompts/common_validation_utils.py delete mode 100644 src/dbally/prompts/formatters.py delete mode 100644 src/dbally/prompts/prompt_template.py create mode 100644 src/dbally/view_selection/prompt.py delete mode 100644 src/dbally/view_selection/view_selector_prompt_template.py create mode 100644 src/dbally/views/freeform/text2sql/prompt.py diff --git a/benchmark/dbally_benchmark/e2e_benchmark.py b/benchmark/dbally_benchmark/e2e_benchmark.py index 9ba0871c..aa686727 100644 --- a/benchmark/dbally_benchmark/e2e_benchmark.py +++ b/benchmark/dbally_benchmark/e2e_benchmark.py @@ -23,9 +23,9 @@ import dbally from dbally.collection import Collection from dbally.collection.exceptions import NoViewFoundError -from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError, default_iql_template +from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError from dbally.llms.litellm import LiteLLM -from dbally.view_selection.view_selector_prompt_template import default_view_selector_template +from dbally.view_selection.prompt import VIEW_SELECTION_TEMPLATE async def _run_dbally_for_single_example(example: BIRDExample, collection: Collection) -> Text2SQLResult: @@ -126,9 +126,9 @@ async def evaluate(cfg: DictConfig) -> Any: logger.info(f"db-ally predictions saved under directory: {output_dir}") if run: - run["config/iql_prompt_template"] = stringify_unsupported(default_iql_template.chat) - run["config/view_selection_prompt_template"] = 
stringify_unsupported(default_view_selector_template.chat) - run["config/iql_prompt_template"] = stringify_unsupported(default_iql_template) + run["config/iql_prompt_template"] = stringify_unsupported(IQL_GENERATION_TEMPLATE.chat) + run["config/view_selection_prompt_template"] = stringify_unsupported(VIEW_SELECTION_TEMPLATE.chat) + run["config/iql_prompt_template"] = stringify_unsupported(IQL_GENERATION_TEMPLATE) run[f"evaluation/{metrics_file_name}"].upload((output_dir / metrics_file_name).as_posix()) run[f"evaluation/{results_file_name}"].upload((output_dir / results_file_name).as_posix()) run["evaluation/metrics"] = stringify_unsupported(metrics) diff --git a/benchmark/dbally_benchmark/iql_benchmark.py b/benchmark/dbally_benchmark/iql_benchmark.py index 7bb2ae28..2557b2c2 100644 --- a/benchmark/dbally_benchmark/iql_benchmark.py +++ b/benchmark/dbally_benchmark/iql_benchmark.py @@ -21,9 +21,8 @@ from dbally.audit.event_tracker import EventTracker from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError, default_iql_template +from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError from dbally.llms.litellm import LiteLLM -from dbally.prompts.formatters import IQLInputFormatter from dbally.views.structured import BaseStructuredView @@ -32,14 +31,17 @@ async def _run_iql_for_single_example( ) -> IQLResult: filter_list = view.list_filters() event_tracker = EventTracker() - input_formatter = IQLInputFormatter(question=example.question, filters=filter_list) try: - iql_filters, _ = await iql_generator.generate_iql(input_formatter=input_formatter, event_tracker=event_tracker) + iql_filters = await iql_generator.generate_iql( + question=example.question, + filters=filter_list, + event_tracker=event_tracker, + ) except UnsupportedQueryError: return IQLResult(question=example.question, iql_filters="UNSUPPORTED_QUERY", exception_raised=True) - return 
IQLResult(question=example.question, iql_filters=iql_filters, exception_raised=False) + return IQLResult(question=example.question, iql_filters=str(iql_filters), exception_raised=False) async def run_iql_for_dataset( @@ -139,7 +141,7 @@ async def evaluate(cfg: DictConfig) -> Any: logger.info(f"IQL predictions saved under directory: {output_dir}") if run: - run["config/iql_prompt_template"] = stringify_unsupported(default_iql_template.chat) + run["config/iql_prompt_template"] = stringify_unsupported(IQL_GENERATION_TEMPLATE.chat) run[f"evaluation/{metrics_file_name}"].upload((output_dir / metrics_file_name).as_posix()) run[f"evaluation/{results_file_name}"].upload((output_dir / results_file_name).as_posix()) run["evaluation/metrics"] = stringify_unsupported(metrics) diff --git a/benchmark/dbally_benchmark/text2sql/prompt_template.py b/benchmark/dbally_benchmark/text2sql/prompt_template.py index abee9659..60349f38 100644 --- a/benchmark/dbally_benchmark/text2sql/prompt_template.py +++ b/benchmark/dbally_benchmark/text2sql/prompt_template.py @@ -1,4 +1,4 @@ -from dbally.prompts import PromptTemplate +from dbally.prompt import PromptTemplate TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate( ( diff --git a/docs/about/roadmap.md b/docs/about/roadmap.md index f6449c88..288aa359 100644 --- a/docs/about/roadmap.md +++ b/docs/about/roadmap.md @@ -10,14 +10,13 @@ Below you can find a list of planned features and integrations. ## Planned Features - [ ] **Support analytical queries**: support for exposing operations beyond filtering. -- [ ] **Few-shot prompting configuration**: allow users to configure the few-shot prompting in View definition to +- [x] **Few-shot prompting configuration**: allow users to configure the few-shot prompting in View definition to improve IQL generation accuracy. - [ ] **Request contextualization**: allow to provide extra context for db-ally runs, such as user asking the question. 
- [X] **OpenAI Assistants API adapter**: allow to embed db-ally into OpenAI's Assistants API to easily extend the capabilities of the assistant. - [ ] **Langchain adapter**: allow to embed db-ally into Langchain applications. - ## Integrations Being agnostic to the underlying technology is one of the main goals of db-ally. diff --git a/docs/how-to/llms/custom.md b/docs/how-to/llms/custom.md index c262351d..7e249847 100644 --- a/docs/how-to/llms/custom.md +++ b/docs/how-to/llms/custom.md @@ -44,42 +44,29 @@ class MyLLMClient(LLMClient[LiteLLMOptions]): async def call( self, - prompt: ChatFormat, - response_format: Optional[Dict[str, str]], + conversation: ChatFormat, options: LiteLLMOptions, event: LLMEvent, + json_mode: bool = False, ) -> str: # Your LLM API call ``` -The [`call`](../../reference/llms/index.md#dbally.llms.clients.base.LLMClient.call) method is an abstract method that must be implemented in your subclass. This method should call the LLM inference API and return the response. +The [`call`](../../reference/llms/index.md#dbally.llms.clients.base.LLMClient.call) method is an abstract method that must be implemented in your subclass. This method should call the LLM inference API and return the response in string format. ### Step 3: Use tokenizer to count tokens -The [`count_tokens`](../../reference/llms/index.md#dbally.llms.base.LLM.count_tokens) method is used to count the number of tokens in the messages. You can override this method in your custom class to use the tokenizer and count tokens specifically for your model. +The [`count_tokens`](../../reference/llms/index.md#dbally.llms.base.LLM.count_tokens) method is used to count the number of tokens in the prompt. You can override this method in your custom class to use the tokenizer and count tokens specifically for your model. 
```python class MyLLM(LLM[LiteLLMOptions]): - def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int: - # Count tokens in the messages in a custom way + def count_tokens(self, prompt: PromptTemplate) -> int: + # Count tokens in the prompt in a custom way ``` !!!warning Incorrect token counting can cause problems in the [`NLResponder`](../../reference/nl_responder.md#dbally.nl_responder.nl_responder.NLResponder) and force the use of an explanation prompt template that is more generic and does not include specific rows from the IQL response. -### Step 4: Define custom prompt formatting - -The [`format_prompt`](../../reference/llms/index.md#dbally.llms.base.LLM.format_prompt) method is used to apply formatting to the prompt template. You can override this method in your custom class to change how the formatting is performed. - -```python -class MyLLM(LLM[LiteLLMOptions]): - - def format_prompt(self, template: PromptTemplate, fmt: Dict[str, str]) -> ChatFormat: - # Apply custom formatting to the prompt template -``` -!!!note - In general, implementation of this method is not required unless the LLM API does not support [OpenAI conversation formatting](https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages){:target="_blank"}. If your model API expects a different format, override this method to avoid issues with inference call. - ## Customising LLM Options [`LLMOptions`](../../reference/llms/index.md#dbally.llms.clients.base.LLMOptions) is a class that defines the options your LLM will use. To create a custom options, you need to create a subclass of [`LLMOptions`](../../reference/llms/index.md#dbally.llms.clients.base.LLMOptions) and define the required properties that will be passed to the [`LLMClient`](../../reference/llms/index.md#dbally.llms.clients.base.LLMClient). 
diff --git a/docs/how-to/views/few-shots.md b/docs/how-to/views/few-shots.md new file mode 100644 index 00000000..806ab171 --- /dev/null +++ b/docs/how-to/views/few-shots.md @@ -0,0 +1,97 @@ +# How-To: Define few shots + +There are many ways to improve the accuracy of IQL generation - one of them is to use few-shot prompting. db-ally allows you to inject few-shot examples for any type of defined view, both structured and freeform. + +Few shots are defined in the [`list_few_shots`](../../reference/views/index.md#dbally.views.base.BaseView.list_few_shots) method, each few shot example should be an instance of [`FewShotExample`](../../reference/prompt.md#dbally.prompt.elements.FewShotExample) class that defines example question and expected LLM answer. + +## Structured views + +For structured views, both questions and answers for [`FewShotExample`](../../reference/prompt.md#dbally.prompt.elements.FewShotExample) can be defined as strings, whereas in case of answers Python expressions are also allowed (please see lambda function in example below). + +```python +from dbally.prompt.elements import FewShotExample +from dbally.views.sqlalchemy_base import SqlAlchemyBaseView + +class RecruitmentView(SqlAlchemyBaseView): + """ + A view for retrieving candidates from the database. + """ + + def list_few_shots(self) -> List[FewShotExample]: + return [ + FewShotExample( + "Which candidates studied at University of Toronto?", + 'studied_at("University of Toronto")', + ), + FewShotExample( + "Do we have any soon available perfect fits for senior data scientist positions?", + lambda: ( + self.is_available_within_months(1) + and self.data_scientist_position() + and self.has_seniority("senior") + ), + ), + ... + ] +``` + +## Freeform views + +Currently freeform views accept SQL query syntax as a raw string. A wider variety of ways to pass parameters is planned for future db-ally releases. 
+ +```python +from dbally.prompt.elements import FewShotExample +from dbally.views.freeform.text2sql import BaseText2SQLView + +class RecruitmentView(BaseText2SQLView): + """ + A view for retrieving candidates from the database. + """ + + def list_few_shots(self) -> List[FewShotExample]: + return [ + FewShotExample( + "Which candidates studied at University of Toronto?", + 'SELECT name FROM candidates WHERE university = "University of Toronto"', + ), + FewShotExample( + "Which clients are from NY?", + 'SELECT name FROM clients WHERE city = "NY"', + ), + ... + ] +``` + +## Prompt format + +By default each few shot is injected subsequent to a system prompt message. The format is as follows: + +```python +[ + { + "role": "user", + "content": "Question", + }, + { + "role": "assistant", + "content": "Answer", + } +] +``` + +If you use the `examples` formatting tag in the content field of the system or user message, all examples are going to be injected inside the message without additional conversation. + +An example of a prompt utilizing the `examples` tag: + +```python +[ + { + "role": "system", + "content": "Here are example responses:\n {examples}", + }, +] +``` + +!!!info + There is no best way to inject a few shot example. Different models can behave differently based on few shots formatting of choice. + Generally, the first approach should yield the best results in most cases. Therefore, adding example tags in your custom prompts is not recommended. diff --git a/docs/reference/collection.md b/docs/reference/collection.md index cb9b4b97..c7b7269a 100644 --- a/docs/reference/collection.md +++ b/docs/reference/collection.md @@ -3,8 +3,6 @@ !!! tip To understand the general idea better, visit the [Collection concept page](../concepts/collections.md). 
-::: dbally.create_collection - ::: dbally.collection.Collection ::: dbally.collection.results.ExecutionResult diff --git a/docs/reference/index.md b/docs/reference/index.md index 0deb591a..fa1abc4f 100644 --- a/docs/reference/index.md +++ b/docs/reference/index.md @@ -1,4 +1,3 @@ # dbally - ::: dbally.create_collection diff --git a/docs/reference/iql/iql_generator.md b/docs/reference/iql/iql_generator.md index 15edcb56..b91a0b0c 100644 --- a/docs/reference/iql/iql_generator.md +++ b/docs/reference/iql/iql_generator.md @@ -1,7 +1,3 @@ # IQLGenerator ::: dbally.iql_generator.iql_generator.IQLGenerator - -::: dbally.iql_generator.iql_prompt_template.IQLPromptTemplate - -::: dbally.iql_generator.iql_prompt_template.default_iql_template diff --git a/docs/reference/nl_responder.md b/docs/reference/nl_responder.md index fb80741c..531243de 100644 --- a/docs/reference/nl_responder.md +++ b/docs/reference/nl_responder.md @@ -26,7 +26,3 @@ Otherwise, a response is generated using a `nl_responder_prompt_template`. To understand general idea better, visit the [NL Responder concept page](../concepts/nl_responder.md). 
::: dbally.nl_responder.nl_responder.NLResponder - -::: dbally.nl_responder.query_explainer_prompt_template - -::: dbally.nl_responder.nl_responder_prompt_template.default_nl_responder_template diff --git a/docs/reference/prompt.md b/docs/reference/prompt.md new file mode 100644 index 00000000..42ab8901 --- /dev/null +++ b/docs/reference/prompt.md @@ -0,0 +1,7 @@ +# Prompt + +::: dbally.prompt.template.PromptTemplate + +::: dbally.prompt.template.PromptFormat + +::: dbally.prompt.elements.FewShotExample diff --git a/docs/reference/view_selection/llm_view_selector.md b/docs/reference/view_selection/llm_view_selector.md index 774aa4b9..a177a8bd 100644 --- a/docs/reference/view_selection/llm_view_selector.md +++ b/docs/reference/view_selection/llm_view_selector.md @@ -1,5 +1,3 @@ # LLMViewSelector ::: dbally.view_selection.LLMViewSelector - -::: dbally.view_selection.view_selector_prompt_template.default_view_selector_template diff --git a/examples/recruiting.py b/examples/recruiting.py index a4813b41..ea16a934 100644 --- a/examples/recruiting.py +++ b/examples/recruiting.py @@ -9,9 +9,37 @@ from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.audit.event_tracker import EventTracker from dbally.llms.litellm import LiteLLM -from dbally.prompts import PromptTemplate +from dbally.prompt import PromptTemplate +from dbally.prompt.elements import FewShotExample +from dbally.prompt.template import PromptFormat -TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate( + +class Text2SQLPromptFormat(PromptFormat): + """ + Formats provided parameters to a form acceptable by SQL prompt. + """ + + def __init__( + self, + *, + question: str, + schema: str, + examples: List[FewShotExample] = None, + ) -> None: + """ + Constructs a new Text2SQLInputFormat instance. + + Args: + question: Question to be asked. + schema: SQL schema description. + examples: List of examples to be injected into the conversation. 
+ """ + super().__init__(examples) + self.question = question + self.schema = schema + + +TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate[Text2SQLPromptFormat]( ( { "role": "system", @@ -112,9 +140,10 @@ async def recruiting_example(db_description: str, benchmark: Benchmark = example for question in benchmark.questions: await recruitment_db.ask(question.dbally_question, return_natural_response=True) gpt_question = question.gpt_question if question.gpt_question else question.dbally_question - gpt_response = await llm.generate_text( - TEXT2SQL_PROMPT_TEMPLATE, {"schema": db_description, "question": gpt_question}, event_tracker=event_tracker - ) + + prompt_format = Text2SQLPromptFormat(question=gpt_question, schema=db_description) + formatted_prompt = TEXT2SQL_PROMPT_TEMPLATE.format_prompt(prompt_format) + gpt_response = await llm.generate_text(formatted_prompt, event_tracker=event_tracker) print(f"GPT response: {gpt_response}") diff --git a/examples/recruiting/views.py b/examples/recruiting/views.py index 63a6c821..773d3f62 100644 --- a/examples/recruiting/views.py +++ b/examples/recruiting/views.py @@ -7,7 +7,7 @@ from sqlalchemy import and_, select from dbally import SqlAlchemyBaseView, decorators -from dbally.prompts.elements import FewShotExample +from dbally.prompt.elements import FewShotExample from .db import Candidate diff --git a/mkdocs.yml b/mkdocs.yml index 852eac20..826ffe15 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -21,6 +21,7 @@ nav: - how-to/views/text-to-sql.md - how-to/views/pandas.md - how-to/views/custom.md + - how-to/views/few-shots.md - Using LLMs: - how-to/llms/litellm.md - how-to/llms/custom.md @@ -59,6 +60,7 @@ nav: - LLMs: - reference/llms/index.md - reference/llms/litellm.md + - reference/prompt.md - Similarity: - reference/similarity/index.md - Store: diff --git a/src/dbally/assistants/openai.py b/src/dbally/assistants/openai.py index 8560cc95..4ec239df 100644 --- a/src/dbally/assistants/openai.py +++ b/src/dbally/assistants/openai.py @@ -6,7 
+6,7 @@ from dbally.assistants.base import AssistantAdapter, FunctionCallingError, FunctionCallState from dbally.collection import Collection -from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError +from dbally.iql_generator.prompt import UnsupportedQueryError _DBALLY_INFO = "Dbally has access to the following database views: " diff --git a/src/dbally/audit/events.py b/src/dbally/audit/events.py index c02cd5cb..de397a74 100644 --- a/src/dbally/audit/events.py +++ b/src/dbally/audit/events.py @@ -3,7 +3,7 @@ from typing import Optional, Union from dbally.collection.results import ExecutionResult -from dbally.prompts import ChatFormat +from dbally.prompt.template import ChatFormat @dataclass diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py index 30182b37..761b0dd2 100644 --- a/src/dbally/gradio/gradio_interface.py +++ b/src/dbally/gradio/gradio_interface.py @@ -9,8 +9,8 @@ from dbally.audit import CLIEventHandler from dbally.collection import Collection from dbally.collection.exceptions import NoViewFoundError -from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError -from dbally.prompts import PromptTemplateError +from dbally.iql_generator.prompt import UnsupportedQueryError +from dbally.prompt.template import PromptTemplateError async def create_gradio_interface(user_collection: Collection, preview_limit: int = 10) -> gradio.Interface: diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index 7ad86490..c2131a57 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -15,12 +15,19 @@ class IQLQuery: root: syntax.Node - def __init__(self, root: syntax.Node): + def __init__(self, root: syntax.Node, source: str) -> None: self.root = root + self._source = source + + def __str__(self) -> str: + return self._source @classmethod async def parse( - cls, source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None + cls, + 
source: str, + allowed_functions: List["ExposedFunction"], + event_tracker: Optional[EventTracker] = None, ) -> "IQLQuery": """ Parse IQL string to IQLQuery object. @@ -32,4 +39,5 @@ async def parse( Returns: IQLQuery object """ - return cls(await IQLProcessor(source, allowed_functions, event_tracker=event_tracker).process()) + root = await IQLProcessor(source, allowed_functions, event_tracker=event_tracker).process() + return cls(root=root, source=source) diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index cea13957..7eeb9154 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -1,10 +1,16 @@ -from typing import List, Optional, Tuple, TypeVar +from typing import List, Optional from dbally.audit.event_tracker import EventTracker -from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template # noqa +from dbally.iql import IQLError, IQLQuery +from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions -from dbally.prompts.formatters import IQLInputFormatter +from dbally.prompt.elements import FewShotExample +from dbally.prompt.template import PromptTemplate +from dbally.views.exposed_functions import ExposedFunction + +ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \ + generation of correct IQL is very important. Below you have errors generated by the system:\n{error}" class IQLGenerator: @@ -18,67 +24,61 @@ class IQLGenerator: It uses LLM to generate text-based responses, passing in the prompt template, formatted filters, and user question. """ - _ERROR_MSG_PREFIX = "Unfortunately, generated IQL is not valid. Please try again, \ - generation of correct IQL is very important. 
Below you have errors generated by the system: \n" - - TException = TypeVar("TException", bound=Exception) - - def __init__(self, llm: LLM) -> None: + def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None) -> None: """ + Constructs a new IQLGenerator instance. + Args: llm: LLM used to generate IQL """ self._llm = llm + self._prompt_template = prompt_template or IQL_GENERATION_TEMPLATE async def generate_iql( self, - input_formatter: IQLInputFormatter, + question: str, + filters: List[ExposedFunction], event_tracker: EventTracker, - conversation: Optional[IQLPromptTemplate] = None, + examples: Optional[List[FewShotExample]] = None, llm_options: Optional[LLMOptions] = None, - ) -> Tuple[str, IQLPromptTemplate]: + n_retries: int = 3, + ) -> IQLQuery: """ - Uses LLM to generate IQL in text form + Generates IQL in text form using LLM. Args: - input_formatter: formatter used to prepare prompt arguments dictionary - event_tracker: event store used to audit the generation process - conversation: conversation to be continued - llm_options: options to use for the LLM client + question: User question. + filters: List of filters exposed by the view. + event_tracker: Event store used to audit the generation process. + examples: List of examples to be injected into the conversation. + llm_options: Options to use for the LLM client. + n_retries: Number of retries to regenerate IQL in case of errors. Returns: - IQL - iql generated based on the user question + Generated IQL query. 
""" - - conversation, fmt = input_formatter(conversation or default_iql_template) - - llm_response = await self._llm.generate_text( - template=conversation, - fmt=fmt, - event_tracker=event_tracker, - options=llm_options, + prompt_format = IQLGenerationPromptFormat( + question=question, + filters=filters, + examples=examples, ) - - iql_filters = conversation.llm_response_parser(llm_response) - - conversation = conversation.add_assistant_message(content=llm_response) - - return iql_filters, conversation - - def add_error_msg(self, conversation: IQLPromptTemplate, errors: List[TException]) -> IQLPromptTemplate: - """ - Appends to the conversation error messages returned due to the invalid IQL generated by the LLM. - - Args: - conversation (IQLPromptTemplate): conversation containing current IQL generation trace - errors (List[Exception]): errors to be appended - - Returns: - IQLPromptTemplate: Conversation extended with errors - """ - - msg = self._ERROR_MSG_PREFIX - for error in errors: - msg += str(error) + "\n" - - return conversation.add_user_message(content=msg) + formatted_prompt = self._prompt_template.format_prompt(prompt_format) + + for _ in range(n_retries + 1): + try: + response = await self._llm.generate_text( + prompt=formatted_prompt, + event_tracker=event_tracker, + options=llm_options, + ) + # TODO: Move response parsing to llm generate_text method + iql = formatted_prompt.response_parser(response) + # TODO: Move IQL query parsing to prompt response parser + return await IQLQuery.parse( + source=iql, + allowed_functions=filters, + event_tracker=event_tracker, + ) + except IQLError as exc: + formatted_prompt = formatted_prompt.add_assistant_message(response) + formatted_prompt = formatted_prompt.add_user_message(ERROR_MESSAGE.format(error=exc)) diff --git a/src/dbally/iql_generator/iql_prompt_template.py b/src/dbally/iql_generator/iql_prompt_template.py deleted file mode 100644 index 2da8abd2..00000000 --- 
a/src/dbally/iql_generator/iql_prompt_template.py +++ /dev/null @@ -1,69 +0,0 @@ -from typing import Callable, Dict, Optional - -from dbally.exceptions import DbAllyError -from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables - - -class UnsupportedQueryError(DbAllyError): - """ - Error raised when IQL generator is unable to construct a query - with given filters. - """ - - -class IQLPromptTemplate(PromptTemplate): - """ - Class for prompt templates meant for the IQL - """ - - def __init__( - self, - chat: ChatFormat, - response_format: Optional[Dict[str, str]] = None, - llm_response_parser: Callable = lambda x: x, - ): - super().__init__(chat, response_format, llm_response_parser) - self.chat = check_prompt_variables(chat, {"filters", "question"}) - - -def _validate_iql_response(llm_response: str) -> str: - """ - Validates LLM response to IQL - - Args: - llm_response: LLM response - - Returns: - A string containing IQL for filters. - - Raises: - UnsuppotedQueryError: When IQL generator is unable to construct a query - with given filters. - """ - - if "unsupported query" in llm_response.lower(): - raise UnsupportedQueryError - return llm_response - - -default_iql_template = IQLPromptTemplate( - chat=( - { - "role": "system", - "content": "You have access to API that lets you query a database:\n" - "\n{filters}\n" - "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" - "Remember! Don't give any comments, just the function calls.\n" - "The output will look like this:\n" - 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' - "DO NOT INCLUDE arguments names in your response. Only the values.\n" - "You MUST use only these methods:\n" - "\n{filters}\n" - "It is VERY IMPORTANT not to use methods other than those listed above." - """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ - "This is CRUCIAL, otherwise the system will crash. 
", - }, - {"role": "user", "content": "{question}"}, - ), - llm_response_parser=_validate_iql_response, -) diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py new file mode 100644 index 00000000..44bb2cd4 --- /dev/null +++ b/src/dbally/iql_generator/prompt.py @@ -0,0 +1,87 @@ +# pylint: disable=C0301 + +from typing import List + +from dbally.exceptions import DbAllyError +from dbally.prompt.elements import FewShotExample +from dbally.prompt.template import PromptFormat, PromptTemplate +from dbally.views.exposed_functions import ExposedFunction + + +class UnsupportedQueryError(DbAllyError): + """ + Error raised when IQL generator is unable to construct a query + with given filters. + """ + + +def _validate_iql_response(llm_response: str) -> str: + """ + Validates LLM response to IQL + + Args: + llm_response: LLM response + + Returns: + A string containing IQL for filters. + + Raises: + UnsuppotedQueryError: When IQL generator is unable to construct a query + with given filters. + """ + if "unsupported query" in llm_response.lower(): + raise UnsupportedQueryError + return llm_response + + +class IQLGenerationPromptFormat(PromptFormat): + """ + IQL prompt format, providing a question and filters to be used in the conversation. + """ + + def __init__( + self, + *, + question: str, + filters: List[ExposedFunction], + examples: List[FewShotExample] = None, + ) -> None: + """ + Constructs a new IQLGenerationPromptFormat instance. + + Args: + question: Question to be asked. + filters: List of filters exposed by the view. + examples: List of examples to be injected into the conversation. 
+ """ + super().__init__(examples) + self.question = question + self.filters = "\n".join([str(filter) for filter in filters]) + + +IQL_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( + [ + { + "role": "system", + "content": ( + "You have access to API that lets you query a database:\n" + "\n{filters}\n" + "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" + "Remember! Don't give any comments, just the function calls.\n" + "The output will look like this:\n" + 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' + "DO NOT INCLUDE arguments names in your response. Only the values.\n" + "You MUST use only these methods:\n" + "\n{filters}\n" + "It is VERY IMPORTANT not to use methods other than those listed above." + """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ + "This is CRUCIAL, otherwise the system will crash. " + ), + }, + { + "role": "user", + "content": "{question}", + }, + ], + response_parser=_validate_iql_response, +) diff --git a/src/dbally/llms/base.py b/src/dbally/llms/base.py index 067fbe56..7e2381e1 100644 --- a/src/dbally/llms/base.py +++ b/src/dbally/llms/base.py @@ -1,12 +1,11 @@ from abc import ABC, abstractmethod from functools import cached_property -from typing import Dict, Generic, Optional, Type +from typing import Generic, Optional, Type from dbally.audit.event_tracker import EventTracker from dbally.audit.events import LLMEvent from dbally.llms.clients.base import LLMClient, LLMClientOptions, LLMOptions -from dbally.prompts.common_validation_utils import ChatFormat -from dbally.prompts.prompt_template import PromptTemplate +from dbally.prompt.template import PromptTemplate class LLM(Generic[LLMClientOptions], ABC): @@ -41,36 +40,21 @@ def client(self) -> LLMClient: Client for the LLM. 
""" - def format_prompt(self, template: PromptTemplate, fmt: Dict[str, str]) -> ChatFormat: + def count_tokens(self, prompt: PromptTemplate) -> int: """ - Applies formatting to the prompt template. + Counts tokens in the prompt. Args: - template: Prompt template in system/user/assistant openAI format. - fmt: Dictionary with formatting. + prompt: Formatted prompt template with conversation and response parsing configuration. Returns: - Prompt in the format of the client. + Number of tokens in the prompt. """ - return [{"role": message["role"], "content": message["content"].format(**fmt)} for message in template.chat] - - def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int: - """ - Counts tokens in the messages. - - Args: - messages: Messages to count tokens for. - fmt: Arguments to be used with prompt. - - Returns: - Number of tokens in the messages. - """ - return sum(len(message["content"].format(**fmt)) for message in messages) + return sum(len(message["content"]) for message in prompt.chat) async def generate_text( self, - template: PromptTemplate, - fmt: Dict[str, str], + prompt: PromptTemplate, *, event_tracker: Optional[EventTracker] = None, options: Optional[LLMOptions] = None, @@ -79,8 +63,7 @@ async def generate_text( Prepares and sends a prompt to the LLM and returns the response. Args: - template: Prompt template in system/user/assistant openAI format. - fmt: Dictionary with formatting. + prompt: Formatted prompt template with conversation and response parsing configuration. event_tracker: Event store used to audit the generation process. options: Options to use for the LLM client. @@ -88,16 +71,15 @@ async def generate_text( Text response from LLM. 
""" options = (self.default_options | options) if options else self.default_options - prompt = self.format_prompt(template, fmt) - event = LLMEvent(prompt=prompt, type=type(template).__name__) + event = LLMEvent(prompt=prompt.chat, type=type(prompt).__name__) event_tracker = event_tracker or EventTracker() async with event_tracker.track_event(event) as span: event.response = await self.client.call( - prompt=prompt, - response_format=template.response_format, + conversation=prompt.chat, options=options, event=event, + json_mode=prompt.json_mode, ) span(event) diff --git a/src/dbally/llms/clients/base.py b/src/dbally/llms/clients/base.py index 5de63ce7..0293390f 100644 --- a/src/dbally/llms/clients/base.py +++ b/src/dbally/llms/clients/base.py @@ -3,7 +3,7 @@ from typing import Any, ClassVar, Dict, Generic, Optional, TypeVar from dbally.audit.events import LLMEvent -from dbally.prompts import ChatFormat +from dbally.prompt.template import ChatFormat from ..._types import NotGiven @@ -67,19 +67,19 @@ def __init__(self, model_name: str) -> None: @abstractmethod async def call( self, - prompt: ChatFormat, - response_format: Optional[Dict[str, str]], + conversation: ChatFormat, options: LLMClientOptions, event: LLMEvent, + json_mode: bool = False, ) -> str: """ Calls LLM inference API. Args: - prompt: Prompt passed to the LLM. - response_format: Optional argument used in the OpenAI API - used to force a json output + conversation: List of dicts with "role" and "content" keys, representing the chat history so far. options: Additional settings used by LLM. event: LLMEvent instance which fields should be filled during the method execution. + json_mode: Force the response to be in JSON format. Returns: Response string from LLM. 
diff --git a/src/dbally/llms/clients/litellm.py b/src/dbally/llms/clients/litellm.py index b15ad362..1e23df91 100644 --- a/src/dbally/llms/clients/litellm.py +++ b/src/dbally/llms/clients/litellm.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union try: import litellm @@ -12,7 +12,7 @@ from dbally.audit.events import LLMEvent from dbally.llms.clients.base import LLMClient, LLMOptions from dbally.llms.clients.exceptions import LLMConnectionError, LLMResponseError, LLMStatusError -from dbally.prompts import ChatFormat +from dbally.prompt.template import ChatFormat from ..._types import NOT_GIVEN, NotGiven @@ -72,19 +72,19 @@ def __init__( async def call( self, - prompt: ChatFormat, - response_format: Optional[Dict[str, str]], + conversation: ChatFormat, options: LiteLLMOptions, event: LLMEvent, + json_mode: bool = False, ) -> str: """ Calls the appropriate LLM endpoint with the given prompt and options. Args: - prompt: Prompt as an OpenAI client style list. - response_format: Optional argument used in the OpenAI API - used to force the json output + conversation: List of dicts with "role" and "content" keys, representing the chat history so far. options: Additional settings used by the LLM. event: Container with the prompt, LLM response and call metrics. + json_mode: Force the response to be in JSON format. Returns: Response string from LLM. @@ -94,9 +94,11 @@ async def call( LLMStatusError: If the LLM API returns an error status code. LLMResponseError: If the LLM API response is invalid. 
""" + response_format = {"type": "json_object"} if json_mode else None + try: response = await litellm.acompletion( - messages=prompt, + messages=conversation, model=self.model_name, base_url=self.base_url, api_key=self.api_key, diff --git a/src/dbally/llms/litellm.py b/src/dbally/llms/litellm.py index c5699a1e..077474e9 100644 --- a/src/dbally/llms/litellm.py +++ b/src/dbally/llms/litellm.py @@ -1,5 +1,5 @@ from functools import cached_property -from typing import Dict, Optional +from typing import Optional try: import litellm @@ -10,7 +10,7 @@ from dbally.llms.base import LLM from dbally.llms.clients.litellm import LiteLLMClient, LiteLLMOptions -from dbally.prompts import ChatFormat +from dbally.prompt.template import PromptTemplate class LiteLLM(LLM[LiteLLMOptions]): @@ -65,17 +65,14 @@ def client(self) -> LiteLLMClient: api_version=self.api_version, ) - def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int: + def count_tokens(self, prompt: PromptTemplate) -> int: """ - Counts tokens in the messages using a specified model. + Counts tokens in the prompt. Args: - messages: Messages to count tokens for. - fmt: Arguments to be used with prompt. + prompt: Formatted prompt template with conversation and response parsing configuration. Returns: - Number of tokens in the messages. + Number of tokens in the prompt. 
""" - return sum( - litellm.token_counter(model=self.model_name, text=message["content"].format(**fmt)) for message in messages - ) + return sum(litellm.token_counter(model=self.model_name, text=message["content"]) for message in prompt.chat) diff --git a/src/dbally/nl_responder/nl_responder.py b/src/dbally/nl_responder/nl_responder.py index 8bcafb11..7a8f98e4 100644 --- a/src/dbally/nl_responder/nl_responder.py +++ b/src/dbally/nl_responder/nl_responder.py @@ -1,48 +1,44 @@ -import copy -from typing import Dict, List, Optional - -import pandas as pd +from typing import Optional from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions -from dbally.nl_responder.nl_responder_prompt_template import NLResponderPromptTemplate, default_nl_responder_template -from dbally.nl_responder.query_explainer_prompt_template import ( - QueryExplainerPromptTemplate, - default_query_explainer_template, +from dbally.nl_responder.prompts import ( + NL_RESPONSE_TEMPLATE, + QUERY_EXPLANATION_TEMPLATE, + NLResponsePromptFormat, + QueryExplanationPromptFormat, ) +from dbally.prompt.template import PromptTemplate class NLResponder: - """Class used to generate natural language response from the database output.""" - - # Keys used to extract the query from the context (ordered by priority) - QUERY_KEYS = ["iql", "sql", "query"] + """ + Class used to generate natural language response from the database output. + """ def __init__( self, llm: LLM, - query_explainer_prompt_template: Optional[QueryExplainerPromptTemplate] = None, - nl_responder_prompt_template: Optional[NLResponderPromptTemplate] = None, + prompt_template: Optional[PromptTemplate[NLResponsePromptFormat]] = None, + explainer_prompt_template: Optional[PromptTemplate[QueryExplanationPromptFormat]] = None, max_tokens_count: int = 4096, ) -> None: """ + Constructs a new NLResponder instance. 
+ Args: - llm: LLM used to generate natural language response - query_explainer_prompt_template: template for the prompt used to generate the iql explanation - if not set defaults to `default_query_explainer_template` - nl_responder_prompt_template: template for the prompt used to generate the NL response - if not set defaults to `nl_responder_prompt_template` - max_tokens_count: maximum number of tokens that can be used in the prompt + llm: LLM used to generate natural language response. + prompt_template: Template for the prompt used to generate the NL response + if not set defaults to `NL_RESPONSE_TEMPLATE`. + explainer_prompt_template: Template for the prompt used to generate the iql explanation + if not set defaults to `QUERY_EXPLANATION_TEMPLATE`. + max_tokens_count: Maximum number of tokens that can be used in the prompt. """ self._llm = llm - self._nl_responder_prompt_template = nl_responder_prompt_template or copy.deepcopy( - default_nl_responder_template - ) - self._query_explainer_prompt_template = query_explainer_prompt_template or copy.deepcopy( - default_query_explainer_template - ) + self._prompt_template = prompt_template or NL_RESPONSE_TEMPLATE + self._explainer_prompt_template = explainer_prompt_template or QUERY_EXPLANATION_TEMPLATE self._max_tokens_count = max_tokens_count async def generate_response( @@ -56,53 +52,38 @@ async def generate_response( Uses LLM to generate a response in natural language form. Args: - result: object representing the result of the query execution - question: user question - event_tracker: event store used to audit the generation process - llm_options: options to use for the LLM client. + result: Object representing the result of the query execution. + question: User question. + event_tracker: Event store used to audit the generation process. + llm_options: Options to use for the LLM client. Returns: Natural language response to the user question. 
""" - rows = _promptify_rows(result.results) - - tokens_count = self._llm.count_tokens( - messages=self._nl_responder_prompt_template.chat, - fmt={"rows": rows, "question": question}, + prompt_format = NLResponsePromptFormat( + question=question, + results=result.results, ) + formatted_prompt = self._prompt_template.format_prompt(prompt_format) + tokens_count = self._llm.count_tokens(formatted_prompt) if tokens_count > self._max_tokens_count: - context = result.context - query = next((context.get(key) for key in self.QUERY_KEYS if context.get(key)), question) + prompt_format = QueryExplanationPromptFormat( + question=question, + context=result.context, + results=result.results, + ) + formatted_prompt = self._explainer_prompt_template.format_prompt(prompt_format) llm_response = await self._llm.generate_text( - template=self._query_explainer_prompt_template, - fmt={"question": question, "query": query, "number_of_results": len(result.results)}, + prompt=formatted_prompt, event_tracker=event_tracker, options=llm_options, ) - return llm_response llm_response = await self._llm.generate_text( - template=self._nl_responder_prompt_template, - fmt={"rows": _promptify_rows(result.results), "question": question}, + prompt=formatted_prompt, event_tracker=event_tracker, options=llm_options, ) return llm_response - - -def _promptify_rows(rows: List[Dict]) -> str: - """ - Formats rows into a markdown table. 
- - Args: - rows: list of rows to be formatted - - Returns: - str: formatted rows - """ - - df = pd.DataFrame.from_records(rows) - - return df.to_markdown(index=False, headers="keys", tablefmt="psql") diff --git a/src/dbally/nl_responder/nl_responder_prompt_template.py b/src/dbally/nl_responder/nl_responder_prompt_template.py deleted file mode 100644 index 9e6e687e..00000000 --- a/src/dbally/nl_responder/nl_responder_prompt_template.py +++ /dev/null @@ -1,47 +0,0 @@ -from typing import Callable, Dict, Optional - -from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables - - -class NLResponderPromptTemplate(PromptTemplate): - """ - Class for prompt templates meant for the natural response. - """ - - def __init__( - self, - chat: ChatFormat, - response_format: Optional[Dict[str, str]] = None, - llm_response_parser: Callable = lambda x: x, - ) -> None: - """ - Initializes NLResponderPromptTemplate class. - - Args: - chat: chat format - response_format: response format - llm_response_parser: function to parse llm response - """ - - super().__init__(chat, response_format, llm_response_parser) - self.chat = check_prompt_variables(chat, {"rows", "question"}) - - -default_nl_responder_template = NLResponderPromptTemplate( - chat=( - { - "role": "system", - "content": "You are a helpful assistant that helps answer the user's questions " - "based on the table provided. You MUST use the table to answer the question. 
" - "You are very intelligent and obedient.\n" - "The table ALWAYS contains full answer to a question.\n" - "Answer the question in a way that is easy to understand and informative.\n" - "DON'T MENTION using a table in your answer.", - }, - { - "role": "user", - "content": "The table below represents the answer to a question: {question}.\n" - "{rows}\nAnswer the question: {question}.", - }, - ) -) diff --git a/src/dbally/nl_responder/prompts.py b/src/dbally/nl_responder/prompts.py new file mode 100644 index 00000000..f99a8a6c --- /dev/null +++ b/src/dbally/nl_responder/prompts.py @@ -0,0 +1,111 @@ +from typing import Any, Dict, List + +import pandas as pd + +from dbally.prompt.elements import FewShotExample +from dbally.prompt.template import PromptFormat, PromptTemplate + + +class NLResponsePromptFormat(PromptFormat): + """ + Formats provided parameters to a form acceptable by default NL response prompt. + """ + + def __init__( + self, + *, + question: str, + results: List[Dict[str, Any]], + examples: List[FewShotExample] = None, + ) -> None: + """ + Constructs a new IQLGenerationPromptFormat instance. + + Args: + question: Question to be asked. + filters: List of filters exposed by the view. + examples: List of examples to be injected into the conversation. + """ + super().__init__(examples) + self.question = question + self.results = pd.DataFrame.from_records(results).to_markdown(index=False, headers="keys", tablefmt="psql") + + +class QueryExplanationPromptFormat(PromptFormat): + """ + Formats provided parameters to a form acceptable by default query explanation prompt. + """ + + def __init__( + self, + *, + question: str, + context: Dict[str, Any], + results: List[Dict[str, Any]], + examples: List[FewShotExample] = None, + ) -> None: + """ + Constructs a new QueryExplanationPromptFormat instance. + + Args: + question: Question to be asked. + context: Context of the query. + results: List of results returned by the query. 
+ examples: List of examples to be injected into the conversation. + """ + super().__init__(examples) + self.question = question + self.query = next((context.get(key) for key in ("iql", "sql", "query") if context.get(key)), question) + self.number_of_results = len(results) + + +NL_RESPONSE_TEMPLATE = PromptTemplate[NLResponsePromptFormat]( + [ + { + "role": "system", + "content": ( + "You are a helpful assistant that helps answer the user's questions " + "based on the table provided. You MUST use the table to answer the question. " + "You are very intelligent and obedient.\n" + "The table ALWAYS contains full answer to a question.\n" + "Answer the question in a way that is easy to understand and informative.\n" + "DON'T MENTION using a table in your answer." + ), + }, + { + "role": "user", + "content": ( + "The table below represents the answer to a question: {question}.\n" + "{results}\n" + "Answer the question: {question}." + ), + }, + ], +) + +QUERY_EXPLANATION_TEMPLATE = PromptTemplate[QueryExplanationPromptFormat]( + [ + { + "role": "system", + "content": ( + "You are a helpful assistant that helps describe a table generated by a query " + "that answers users' question. " + "You are very intelligent and obedient.\n" + "Your task is to provide natural language description of the table used by the logical query " + "to the database.\n" + "Describe the table in a way that is short and informative.\n" + "Make your answer as short as possible, start it by infroming the user that the underlying " + "data is too long to print and then describe the table based on the question and the query.\n" + "DON'T MENTION using a query in your answer." + ), + }, + { + "role": "user", + "content": ( + "The query below represents the answer to a question: {question}.\n" + "Describe the table generated using this query: {query}.\n" + "Number of results to this query: {number_of_results}." 
+ ), + }, + ], +) diff --git a/src/dbally/nl_responder/query_explainer_prompt_template.py b/src/dbally/nl_responder/query_explainer_prompt_template.py deleted file mode 100644 index 00a3e6a6..00000000 --- a/src/dbally/nl_responder/query_explainer_prompt_template.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Callable, Dict, Optional - -from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables - - -class QueryExplainerPromptTemplate(PromptTemplate): - """ - Class for prompt templates meant to generate explanations for queries - (when the data cannot be shown due to token limit). - - Args: - chat: chat format - response_format: response format - llm_response_parser: function to parse llm response - """ - - def __init__( - self, - chat: ChatFormat, - response_format: Optional[Dict[str, str]] = None, - llm_response_parser: Callable = lambda x: x, - ) -> None: - super().__init__(chat, response_format, llm_response_parser) - self.chat = check_prompt_variables(chat, {"question", "query", "number_of_results"}) - - -default_query_explainer_template = QueryExplainerPromptTemplate( - chat=( - { - "role": "system", - "content": "You are a helpful assistant that helps describe a table generated by a query " - "that answers users' question. 
" - "You are very intelligent and obedient.\n" - "Your task is to provide natural language description of the table used by the logical query " - "to the database.\n" - "Describe the table in a way that is short and informative.\n" - "Make your answer as short as possible, start it by infroming the user that the underlying " - "data is too long to print and then describe the table based on the question and the query.\n" - "DON'T MENTION using a query in your answer.\n", - }, - { - "role": "user", - "content": "The query below represents the answer to a question: {question}.\n" - "Describe the table generated using this query: {query}.\n" - "Number of results to this query: {number_of_results}.\n", - }, - ) -) diff --git a/src/dbally/prompt/__init__.py b/src/dbally/prompt/__init__.py new file mode 100644 index 00000000..61495d33 --- /dev/null +++ b/src/dbally/prompt/__init__.py @@ -0,0 +1,3 @@ +from .template import ChatFormat, PromptTemplate, PromptTemplateError + +__all__ = ["PromptTemplate", "PromptTemplateError", "ChatFormat"] diff --git a/src/dbally/prompts/elements.py b/src/dbally/prompt/elements.py similarity index 97% rename from src/dbally/prompts/elements.py rename to src/dbally/prompt/elements.py index 2937d7c1..37375508 100644 --- a/src/dbally/prompts/elements.py +++ b/src/dbally/prompt/elements.py @@ -58,4 +58,4 @@ def _parse_lambda(self, expr: Callable) -> str: return parsed_expr def __str__(self) -> str: - return self.answer + return f"{self.question} -> {self.answer}" diff --git a/src/dbally/prompt/template.py b/src/dbally/prompt/template.py new file mode 100644 index 00000000..124a3e1c --- /dev/null +++ b/src/dbally/prompt/template.py @@ -0,0 +1,234 @@ +import copy +import re +from typing import Callable, Dict, Generic, List, TypeVar + +from typing_extensions import Self + +from dbally.exceptions import DbAllyError +from dbally.prompt.elements import FewShotExample + +ChatFormat = List[Dict[str, str]] + + +class PromptTemplateError(DbAllyError): + 
""" + Error raised on incorrect PromptTemplate construction. + """ + + +def _check_chat_order(chat: ChatFormat) -> ChatFormat: + """ + Pydantic validator. Checks if the chat template is constructed correctly (system, user, assistant alternating). + + Args: + chat: Chat template + + Raises: + PromptTemplateError: if chat template is not constructed correctly. + + Returns: + Chat template + """ + if len(chat) == 0: + raise PromptTemplateError("Template should not be empty") + + expected_order = ["user", "assistant"] + for i, message in enumerate(chat): + role = message["role"] + if role == "system": + if i != 0: + raise PromptTemplateError("Only first message should come from system") + continue + index = i % len(expected_order) + if role != expected_order[index - 1]: + raise PromptTemplateError( + "Template format is not correct. It should be system, and then user/assistant alternating." + ) + + if expected_order[index] not in ["user", "assistant"]: + raise PromptTemplateError("Template needs to end on either user or assistant turn") + return chat + + +class PromptFormat: + """ + Generic format for prompts allowing to inject few shot examples into the conversation. + """ + + def __init__(self, examples: List[FewShotExample] = None) -> None: + """ + Constructs a new PromptFormat instance. + + Args: + examples: List of examples to be injected into the conversation. + """ + self.examples = examples or [] + + +PromptFormatT = TypeVar("PromptFormatT", bound=PromptFormat) + + +class PromptTemplate(Generic[PromptFormatT]): + """ + Class for prompt templates. + """ + + def __init__( + self, + chat: ChatFormat, + *, + json_mode: bool = False, + response_parser: Callable = lambda x: x, + ) -> None: + """ + Constructs a new PromptTemplate instance. + + Args: + chat: Chat-formatted conversation template. + json_mode: Whether to enforce JSON response from LLM. + response_parser: Function parsing the LLM response into the desired format. 
+ """ + self.chat: ChatFormat = _check_chat_order(chat) + self.json_mode = json_mode + self.response_parser = response_parser + + def __eq__(self, other: "PromptTemplate") -> bool: + return isinstance(other, PromptTemplate) and self.chat == other.chat + + def _has_variable(self, variable: str) -> bool: + """ + Validates a given chat to make sure it contains variables required. + + Args: + variable: Variable to check. + + Returns: + True if the variable is present in the chat. + """ + for message in self.chat: + if re.match(rf"{{{variable}}}", message["content"]): + return True + return False + + def format_prompt(self, prompt_format: PromptFormatT) -> Self: + """ + Applies formatting to the prompt template chat contents. + + Args: + prompt_format: Format to be applied to the prompt. + + Returns: + PromptTemplate with formatted chat contents. + """ + formatted_prompt = copy.deepcopy(self) + formatting = dict(prompt_format.__dict__) + + if self._has_variable("examples"): + formatting["examples"] = "\n".join(prompt_format.examples) + else: + formatted_prompt = formatted_prompt.clear_few_shot_messages() + for example in prompt_format.examples: + formatted_prompt = formatted_prompt.add_few_shot_message(example) + + formatted_prompt.chat = [ + { + "role": message.get("role"), + "content": message.get("content").format(**formatting), + "is_example": message.get("is_example", False), + } + for message in formatted_prompt.chat + ] + return formatted_prompt + + def set_system_message(self, content: str) -> Self: + """ + Sets a system message to the template prompt. + + Args: + content: Message to be added. + + Returns: + PromptTemplate with appended system message. + """ + return self.__class__( + chat=[{"role": "system", "content": content}, *self.chat], + json_mode=self.json_mode, + response_parser=self.response_parser, + ) + + def add_user_message(self, content: str) -> Self: + """ + Add a user message to the template prompt. + + Args: + content: Message to be added. 
+ + Returns: + PromptTemplate with appended user message. + """ + return self.__class__( + chat=[*self.chat, {"role": "user", "content": content}], + json_mode=self.json_mode, + response_parser=self.response_parser, + ) + + def add_assistant_message(self, content: str) -> Self: + """ + Add an assistant message to the template prompt. + + Args: + content: Message to be added. + + Returns: + PromptTemplate with appended assistant message. + """ + return self.__class__( + chat=[*self.chat, {"role": "assistant", "content": content}], + json_mode=self.json_mode, + response_parser=self.response_parser, + ) + + def add_few_shot_message(self, example: FewShotExample) -> Self: + """ + Add a few-shot message to the template prompt. + + Args: + example: Few-shot example to be added. + + Returns: + PromptTemplate with appended few-shot message. + + Raises: + PromptTemplateError: if the template is empty. + """ + if len(self.chat) == 0: + raise PromptTemplateError("Cannot add few-shot messages to an empty template.") + + few_shot = [ + {"role": "user", "content": example.question, "is_example": True}, + {"role": "assistant", "content": example.answer, "is_example": True}, + ] + few_shot_index = max( + (i for i, entry in enumerate(self.chat) if entry.get("is_example") or entry.get("role") == "system"), + default=0, + ) + chat = self.chat[: few_shot_index + 1] + few_shot + self.chat[few_shot_index + 1 :] + + return self.__class__( + chat=chat, + json_mode=self.json_mode, + response_parser=self.response_parser, + ) + + def clear_few_shot_messages(self) -> Self: + """ + Removes all few-shot messages from the template prompt. + + Returns: + PromptTemplate with few-shot messages removed. 
+ """ + return self.__class__( + chat=[message for message in self.chat if not message.get("is_example")], + json_mode=self.json_mode, + response_parser=self.response_parser, + ) diff --git a/src/dbally/prompts/__init__.py b/src/dbally/prompts/__init__.py deleted file mode 100644 index 38e20cc7..00000000 --- a/src/dbally/prompts/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .common_validation_utils import ChatFormat, PromptTemplateError, check_prompt_variables -from .prompt_template import PromptTemplate - -__all__ = ["PromptTemplate", "PromptTemplateError", "check_prompt_variables", "ChatFormat"] diff --git a/src/dbally/prompts/common_validation_utils.py b/src/dbally/prompts/common_validation_utils.py deleted file mode 100644 index f4660810..00000000 --- a/src/dbally/prompts/common_validation_utils.py +++ /dev/null @@ -1,51 +0,0 @@ -import re -from typing import Dict, List, Set - -from dbally.exceptions import DbAllyError - -ChatFormat = List[Dict[str, str]] - - -class PromptTemplateError(DbAllyError): - """Error raised on incorrect PromptTemplate construction""" - - -def _extract_variables(text: str) -> List[str]: - """ - Given a text string, extract all variables that can be filled using .format - - Args: - text: string to process - - Returns: - list of variables extracted from text - """ - pattern = r"\{([^}]+)\}" - return re.findall(pattern, text) - - -def check_prompt_variables(chat: ChatFormat, variables_to_check: Set[str]) -> ChatFormat: - """ - Function validates a given chat to make sure it contains variables required. - - Args: - chat: chat to validate - variables_to_check: set of variables to assert - - Raises: - PromptTemplateError: If required variables are missing - - Returns: - Chat, if it's valid. 
- """ - variables = [] - for message in chat: - content = message["content"] - variables.extend(_extract_variables(content)) - if not set(variables_to_check).issubset(variables): - raise PromptTemplateError( - "Cannot build a prompt template from the provided chat, " - "because it lacks necessary string variables. " - "You need to format the following variables: {variables_to_check}" - ) - return chat diff --git a/src/dbally/prompts/formatters.py b/src/dbally/prompts/formatters.py deleted file mode 100644 index c2cce950..00000000 --- a/src/dbally/prompts/formatters.py +++ /dev/null @@ -1,119 +0,0 @@ -import copy -from abc import ABCMeta, abstractmethod -from typing import Dict, List, Tuple - -from dbally.prompts.elements import FewShotExample -from dbally.prompts.prompt_template import PromptTemplate -from dbally.views.exposed_functions import ExposedFunction - - -def _promptify_filters( - filters: List[ExposedFunction], -) -> str: - """ - Formats filters for prompt - - Args: - filters: list of filters exposed by the view - - Returns: - filters formatted for prompt - """ - filters_for_prompt = "\n".join([str(filter) for filter in filters]) - return filters_for_prompt - - -class InputFormatter(metaclass=ABCMeta): - """ - Formats provided parameters to a form acceptable by IQL prompt - """ - - @abstractmethod - def __call__(self, conversation_template: PromptTemplate) -> Tuple[PromptTemplate, Dict[str, str]]: - """ - Runs the input formatting for provided prompt template. - - Args: - conversation_template: a prompt template to use. - - Returns: - A tuple with template and a dictionary with formatted inputs. 
- """ - - -class IQLInputFormatter(InputFormatter): - """ - Formats provided parameters to a form acceptable by default IQL prompt - """ - - def __init__(self, filters: List[ExposedFunction], question: str) -> None: - self.filters = filters - self.question = question - - def __call__(self, conversation_template: PromptTemplate) -> Tuple[PromptTemplate, Dict[str, str]]: - """ - Runs the input formatting for provided prompt template. - - Args: - conversation_template: a prompt template to use. - - Returns: - A tuple with template and a dictionary with formatted filters and a question. - """ - return conversation_template, { - "filters": _promptify_filters(self.filters), - "question": self.question, - } - - -class IQLFewShotInputFormatter(InputFormatter): - """ - Formats provided parameters to a form acceptable by default IQL prompt. - Calling it will inject `examples` before last message in a conversation. - """ - - def __init__( - self, - filters: List[ExposedFunction], - examples: List[FewShotExample], - question: str, - ) -> None: - self.filters = filters - self.question = question - self.examples = examples - - def __call__(self, conversation_template: PromptTemplate) -> Tuple[PromptTemplate, Dict[str, str]]: - """ - Performs a deep copy of provided template and injects examples into chat history. - Also prepares filters and question to be included within the prompt. - - Args: - conversation_template: a prompt template to use to inject few-shot examples. - - Returns: - A tuple with deeply-copied and enriched with examples template - and a dictionary with formatted filters and a question. 
- """ - - template_copy = copy.deepcopy(conversation_template) - sys_msg = template_copy.chat[0] - existing_msgs = [msg for msg in template_copy.chat[1:] if "is_example" not in msg] - chat_examples = [ - msg - for example in self.examples - for msg in [ - {"role": "user", "content": example.question, "is_example": True}, - {"role": "assistant", "content": example.answer, "is_example": True}, - ] - ] - - template_copy.chat = ( - sys_msg, - *chat_examples, - *existing_msgs, - ) - - return template_copy, { - "filters": _promptify_filters(self.filters), - "question": self.question, - } diff --git a/src/dbally/prompts/prompt_template.py b/src/dbally/prompts/prompt_template.py deleted file mode 100644 index 8e2746fe..00000000 --- a/src/dbally/prompts/prompt_template.py +++ /dev/null @@ -1,83 +0,0 @@ -from typing import Callable, Dict, Optional - -from typing_extensions import Self - -from .common_validation_utils import ChatFormat, PromptTemplateError - - -def _check_chat_order(chat: ChatFormat) -> ChatFormat: - """ - Pydantic validator. Checks if the chat template is constructed correctly (system, user, assistant alternating). - - Args: - chat: Chat template - - Raises: - PromptTemplateError: if chat template is not constructed correctly. - - Returns: - Chat template - """ - expected_order = ["user", "assistant"] - for i, message in enumerate(chat): - role = message["role"] - if role == "system": - if i != 0: - raise PromptTemplateError("Only first message should come from system") - continue - index = i % len(expected_order) - if role != expected_order[index - 1]: - raise PromptTemplateError( - "Template format is not correct. It should be system, and then user/assistant alternating." 
- ) - - if expected_order[index] not in ["user", "assistant"]: - raise PromptTemplateError("Template needs to end on either user or assistant turn") - return chat - - -class PromptTemplate: - """ - Class for prompt templates - - Attributes: - response_format: Optional argument for OpenAI Turbo models - may be used to force json output - llm_response_parser: Function parsing the LLM response into IQL - """ - - def __init__( - self, - chat: ChatFormat, - response_format: Optional[Dict[str, str]] = None, - llm_response_parser: Callable = lambda x: x, - ): - self.chat: ChatFormat = _check_chat_order(chat) - self.response_format = response_format - self.llm_response_parser = llm_response_parser - - def __eq__(self, __value: object) -> bool: - return isinstance(__value, PromptTemplate) and self.chat == __value.chat - - def add_user_message(self, content: str) -> Self: - """ - Add a user message to the template prompt. - - Args: - content: Message to be added - - Returns: - PromptTemplate with appended user message - """ - return self.__class__((*self.chat, {"role": "user", "content": content})) - - def add_assistant_message(self, content: str) -> Self: - """ - Add an assistant message to the template prompt. 
- - Args: - content: Message to be added - - Returns: - PromptTemplate with appended assistant message - """ - return self.__class__((*self.chat, {"role": "assistant", "content": content})) diff --git a/src/dbally/view_selection/llm_view_selector.py b/src/dbally/view_selection/llm_view_selector.py index 2d501922..b4069bb1 100644 --- a/src/dbally/view_selection/llm_view_selector.py +++ b/src/dbally/view_selection/llm_view_selector.py @@ -1,12 +1,11 @@ -import copy -from typing import Callable, Dict, Optional +from typing import Dict, Optional from dbally.audit.event_tracker import EventTracker -from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions +from dbally.prompt.template import PromptTemplate from dbally.view_selection.base import ViewSelector -from dbally.view_selection.view_selector_prompt_template import default_view_selector_template +from dbally.view_selection.prompt import VIEW_SELECTION_TEMPLATE, ViewSelectionPromptFormat class LLMViewSelector(ViewSelector): @@ -20,22 +19,16 @@ class LLMViewSelector(ViewSelector): ultimately returning the name of the most suitable view. """ - def __init__( - self, - llm: LLM, - prompt_template: Optional[IQLPromptTemplate] = None, - promptify_views: Optional[Callable[[Dict[str, str]], str]] = None, - ) -> None: + def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[ViewSelectionPromptFormat]] = None) -> None: """ + Constructs a new LLMViewSelector instance. + Args: llm: LLM used to generate IQL prompt_template: template for the prompt used for the view selection - promptify_views: Function formatting filters for prompt. 
By default names and descriptions of\ - all views are concatenated """ self._llm = llm - self._prompt_template = prompt_template or copy.deepcopy(default_view_selector_template) - self._promptify_views = promptify_views or _promptify_views + self._prompt_template = prompt_template or VIEW_SELECTION_TEMPLATE async def select_view( self, @@ -56,28 +49,13 @@ async def select_view( Returns: The most relevant view name. """ - - views_for_prompt = self._promptify_views(views) + prompt_format = ViewSelectionPromptFormat(question=question, views=views) + formatted_prompt = self._prompt_template.format_prompt(prompt_format) llm_response = await self._llm.generate_text( - template=self._prompt_template, - fmt={"views": views_for_prompt, "question": question}, + prompt=formatted_prompt, event_tracker=event_tracker, options=llm_options, ) - selected_view = self._prompt_template.llm_response_parser(llm_response) + selected_view = self._prompt_template.response_parser(llm_response) return selected_view - - -def _promptify_views(views: Dict[str, str]) -> str: - """ - Formats views for prompt - - Args: - views: dictionary of available view names with corresponding descriptions. - - Returns: - views_for_prompt: views formatted for prompt - """ - - return "\n".join([f"{name}: {description}" for name, description in views.items()]) diff --git a/src/dbally/view_selection/prompt.py b/src/dbally/view_selection/prompt.py new file mode 100644 index 00000000..cdbedf5a --- /dev/null +++ b/src/dbally/view_selection/prompt.py @@ -0,0 +1,52 @@ +from typing import Dict, List + +from dbally.prompt.elements import FewShotExample +from dbally.prompt.template import PromptFormat, PromptTemplate + + +class ViewSelectionPromptFormat(PromptFormat): + """ + Formats provided parameters to a form acceptable by default IQL prompt. 
+ """ + + def __init__( + self, + *, + question: str, + views: Dict[str, str], + examples: List[FewShotExample] = None, + ) -> None: + """ + Constructs a new ViewSelectionPromptFormat instance. + + Args: + question: Question to be asked. + views: Dictionary of available view names with corresponding descriptions. + examples: List of examples to be injected into the conversation. + """ + super().__init__(examples) + self.question = question + self.views = "\n".join([f"{name}: {description}" for name, description in views.items()]) + + +VIEW_SELECTION_TEMPLATE = PromptTemplate[ViewSelectionPromptFormat]( + [ + { + "role": "system", + "content": ( + "You are a very smart database programmer. " + "You have access to API that lets you query a database:\n" + "First you need to select a class to query, based on its description and the user question. " + "You have the following classes to choose from:\n" + "{views}\n" + "Return only the selected view name. Don't give any comments.\n" + "You can only use the classes that were listed. 
" + "If none of the classes listed can be used to answer the user question, say `NoViewFoundError`" + ), + }, + { + "role": "user", + "content": "{question}", + }, + ], +) diff --git a/src/dbally/view_selection/view_selector_prompt_template.py b/src/dbally/view_selection/view_selector_prompt_template.py deleted file mode 100644 index 60440c84..00000000 --- a/src/dbally/view_selection/view_selector_prompt_template.py +++ /dev/null @@ -1,51 +0,0 @@ -import json -from typing import Callable, Dict, Optional - -from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables - - -class ViewSelectorPromptTemplate(PromptTemplate): - """ - Class for prompt templates meant for the ViewSelector - """ - - def __init__( - self, - chat: ChatFormat, - response_format: Optional[Dict[str, str]] = None, - llm_response_parser: Callable = lambda x: x, - ): - super().__init__(chat, response_format, llm_response_parser) - self.chat = check_prompt_variables(chat, {"views"}) - - -def _convert_llm_json_response_to_selected_view(llm_response_json: str) -> str: - """ - Converts LLM json response to IQL - - Args: - llm_response_json: LLM response in JSON format - - Returns: - A string containing selected view - """ - llm_response_dict = json.loads(llm_response_json) - return llm_response_dict.get("view") - - -default_view_selector_template = ViewSelectorPromptTemplate( - chat=( - { - "role": "system", - "content": "You are a very smart database programmer. " - "You have access to API that lets you query a database:\n" - "First you need to select a class to query, based on its description and the user question. " - "You have the following classes to choose from:\n" - "{views}\n" - "Return only the selected view name. Don't give any comments.\n" - "You can only use the classes that were listed. 
" - "If none of the classes listed can be used to answer the user question, say `NoViewFoundError`", - }, - {"role": "user", "content": "{question}"}, - ), -) diff --git a/src/dbally/views/base.py b/src/dbally/views/base.py index a3278281..d5103884 100644 --- a/src/dbally/views/base.py +++ b/src/dbally/views/base.py @@ -5,7 +5,7 @@ from dbally.collection.results import ViewExecutionResult from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions -from dbally.prompts.elements import FewShotExample +from dbally.prompt.elements import FewShotExample from dbally.similarity import AbstractSimilarityIndex IndexLocation = Tuple[str, str, str] diff --git a/src/dbally/views/freeform/text2sql/prompt.py b/src/dbally/views/freeform/text2sql/prompt.py new file mode 100644 index 00000000..5f9a547d --- /dev/null +++ b/src/dbally/views/freeform/text2sql/prompt.py @@ -0,0 +1,61 @@ +# pylint: disable=C0301 + +from typing import List + +from dbally.prompt.elements import FewShotExample +from dbally.prompt.template import PromptFormat, PromptTemplate +from dbally.views.freeform.text2sql.config import TableConfig + + +class SQLGenerationPromptFormat(PromptFormat): + """ + Formats provided parameters to a form acceptable by default SQL prompt. + """ + + def __init__( + self, + *, + question: str, + dialect: str, + tables: List[TableConfig], + examples: List[FewShotExample] = None, + ) -> None: + """ + Constructs a new SQLGenerationPromptFormat instance. + + Args: + question: Question to be asked. + context: Context of the query. + examples: List of examples to be injected into the conversation. + """ + super().__init__(examples) + self.question = question + self.dialect = dialect + self.tables = "\n".join(table.ddl for table in tables) + + +SQL_GENERATION_TEMPLATE = PromptTemplate[SQLGenerationPromptFormat]( + [ + { + "role": "system", + "content": ( + "You are a very smart database programmer. 
" + "You have access to the following {dialect} tables:\n" + "{tables}\n" + "Create SQL query to answer user question. Response with JSON containing following keys:\n\n" + "- sql: SQL query to answer the question, with parameter :placeholders for user input.\n" + "- parameters: a list of parameters to be used in the query, represented by maps with the following keys:\n" + " - name: the name of the parameter\n" + " - value: the value of the parameter\n" + " - table: the table the parameter is used with (if any)\n" + " - column: the column the parameter is compared to (if any)\n\n" + "Respond ONLY with the raw JSON response. Don't include any additional text or characters." + ), + }, + { + "role": "user", + "content": "{question}", + }, + ], + json_mode=True, +) diff --git a/src/dbally/views/freeform/text2sql/view.py b/src/dbally/views/freeform/text2sql/view.py index 6891785e..7f24f00e 100644 --- a/src/dbally/views/freeform/text2sql/view.py +++ b/src/dbally/views/freeform/text2sql/view.py @@ -10,32 +10,12 @@ from dbally.collection.results import ViewExecutionResult from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions -from dbally.prompts import PromptTemplate +from dbally.prompt.template import PromptTemplate from dbally.similarity import AbstractSimilarityIndex, SimpleSqlAlchemyFetcher from dbally.views.base import BaseView, IndexLocation from dbally.views.freeform.text2sql.config import TableConfig from dbally.views.freeform.text2sql.exceptions import Text2SQLError - -text2sql_prompt = PromptTemplate( - chat=( - { - "role": "system", - "content": "You are a very smart database programmer. " - "You have access to the following {dialect} tables:\n" - "{tables}\n" - "Create SQL query to answer user question. 
Response with JSON containing following keys:\n\n" - "- sql: SQL query to answer the question, with parameter :placeholders for user input.\n" - "- parameters: a list of parameters to be used in the query, represented by maps with the following keys:\n" - " - name: the name of the parameter\n" - " - value: the value of the parameter\n" - " - table: the table the parameter is used with (if any)\n" - " - column: the column the parameter is compared to (if any)\n\n" - "Respond ONLY with the raw JSON response. Don't include any additional text or characters.", - }, - {"role": "user", "content": "{question}"}, - ), - response_format={"type": "json_object"}, -) +from dbally.views.freeform.text2sql.prompt import SQL_GENERATION_TEMPLATE, SQLGenerationPromptFormat @dataclass @@ -142,17 +122,26 @@ async def ask( Raises: Text2SQLError: If the text2sql query generation fails after n_retries. """ - conversation = text2sql_prompt sql, rows = None, None exceptions = [] - for _ in range(n_retries): + tables = self.get_tables() + examples = self.list_few_shots() + + prompt_format = SQLGenerationPromptFormat( + question=query, + dialect=self._engine.dialect.name, + tables=tables, + examples=examples, + ) + formatted_prompt = SQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + + for _ in range(n_retries + 1): # We want to catch all exceptions to retry the process. # pylint: disable=broad-except try: - sql, parameters, conversation = await self._generate_sql( - query=query, - conversation=conversation, + sql, parameters, formatted_prompt = await self._generate_sql( + conversation=formatted_prompt, llm=llm, event_tracker=event_tracker, llm_options=llm_options, @@ -164,7 +153,7 @@ async def ask( rows = await self._execute_sql(sql, parameters, event_tracker=event_tracker) break except Exception as e: - conversation = conversation.add_user_message(f"Response is invalid! Error: {e}") + formatted_prompt = formatted_prompt.add_user_message(f"Response is invalid! 
Error: {e}") exceptions.append(e) continue @@ -182,15 +171,13 @@ async def ask( async def _generate_sql( self, - query: str, conversation: PromptTemplate, llm: LLM, event_tracker: EventTracker, llm_options: Optional[LLMOptions] = None, ) -> Tuple[str, List[SQLParameterOption], PromptTemplate]: response = await llm.generate_text( - template=conversation, - fmt={"tables": self._get_tables_context(), "dialect": self._engine.dialect.name, "question": query}, + prompt=conversation, event_tracker=event_tracker, options=llm_options, ) @@ -221,12 +208,6 @@ async def _execute_sql( with self._engine.connect() as conn: return conn.execute(text(sql), param_values).fetchall() - def _get_tables_context(self) -> str: - context = "" - for table in self._table_index.values(): - context += f"{table.ddl}\n" - return context - def _create_default_fetcher(self, table: str, column: str) -> SimpleSqlAlchemyFetcher: return SimpleSqlAlchemyFetcher( sqlalchemy_engine=self._engine, diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index 8b95ecaa..b5863075 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -4,11 +4,10 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLError, IQLQuery +from dbally.iql import IQLQuery from dbally.iql_generator.iql_generator import IQLGenerator from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions -from dbally.prompts.formatters import IQLFewShotInputFormatter, IQLInputFormatter from dbally.views.exposed_functions import ExposedFunction from ..similarity import AbstractSimilarityIndex @@ -26,10 +25,10 @@ def get_iql_generator(self, llm: LLM) -> IQLGenerator: Returns the IQL generator for the view. Args: - llm: LLM used to generate the IQL queries + llm: LLM used to generate the IQL queries. Returns: - IQLGenerator: IQL generator for the view + IQL generator for the view. 
""" return IQLGenerator(llm=llm) @@ -57,46 +56,30 @@ async def ask( Returns: The result of the query. """ + iql_generator = self.get_iql_generator(llm) filters = self.list_filters() examples = self.list_few_shots() - iql_generator = self.get_iql_generator(llm) - input_formatter = ( - IQLFewShotInputFormatter(question=query, filters=filters, examples=examples) - if examples - else IQLInputFormatter(question=query, filters=filters) - ) - - iql_filters, conversation = await iql_generator.generate_iql( - input_formatter=input_formatter, + iql = await iql_generator.generate_iql( + question=query, + filters=filters, + examples=examples, event_tracker=event_tracker, llm_options=llm_options, + n_retries=n_retries, ) - - for _ in range(n_retries): - try: - filters = await IQLQuery.parse(iql_filters, filters, event_tracker=event_tracker) - await self.apply_filters(filters) - break - except (IQLError, ValueError) as e: - conversation = iql_generator.add_error_msg(conversation, [e]) - iql_filters, conversation = await iql_generator.generate_iql( - input_formatter=input_formatter, - event_tracker=event_tracker, - conversation=conversation, - llm_options=llm_options, - ) - continue + await self.apply_filters(iql) result = self.execute(dry_run=dry_run) - result.context["iql"] = iql_filters + result.context["iql"] = f"{iql}" return result @abc.abstractmethod def list_filters(self) -> List[ExposedFunction]: """ + Lists all available filters for the View. Returns: Filters defined inside the View. 
diff --git a/src/dbally_codegen/autodiscovery.py b/src/dbally_codegen/autodiscovery.py index 1e20c542..c842a07f 100644 --- a/src/dbally_codegen/autodiscovery.py +++ b/src/dbally_codegen/autodiscovery.py @@ -6,12 +6,59 @@ from typing_extensions import Self from dbally.llms.base import LLM -from dbally.prompts import PromptTemplate +from dbally.prompt.template import PromptFormat, PromptTemplate from dbally.similarity.index import SimilarityIndex -from dbally.views.freeform.text2sql import ColumnConfig, TableConfig +from dbally.views.freeform.text2sql.config import ColumnConfig, TableConfig -DISCOVERY_TEMPLATE = PromptTemplate( - chat=( + +class DiscoveryPromptFormat(PromptFormat): + """ + Formats provided parameters to a form acceptable by default discovery prompt. + """ + + def __init__( + self, + *, + dialect: str, + table_ddl: str, + samples: List[Dict[str, Any]], + ) -> None: + """ + Constructs a new DiscoveryPromptFormat instance. + + Args: + dialect: The SQL dialect of the database. + table_ddl: The DDL of the table. + samples: The example rows from the table. + """ + super().__init__() + self.dialect = dialect + self.table_ddl = table_ddl + self.samples = samples + + +class SimilarityPromptFormat(PromptFormat): + """ + Formats provided parameters to a form acceptable by default similarity prompt. + """ + + def __init__(self, *, table_summary: str, column_name: str, samples: List[Any]) -> None: + """ + Constructs a new SimilarityPromptFormat instance. + + Args: + table_summary: The summary of the table. + column_name: The name of the column. + samples: The example values from the column. 
+ """ + super().__init__() + self.table_summary = table_summary + self.column_name = column_name + self.samples = samples + + +DISCOVERY_TEMPLATE = PromptTemplate[DiscoveryPromptFormat]( + [ { "role": "system", "content": ( @@ -24,11 +71,11 @@ "role": "user", "content": "DDL:\n {table_ddl}\n" "EXAMPLE ROWS:\n {samples}", }, - ), + ], ) -SIMILARITY_TEMPLATE = PromptTemplate( - chat=( +SIMILARITY_TEMPLATE = PromptTemplate[SimilarityPromptFormat]( + [ { "role": "system", "content": ( @@ -43,7 +90,7 @@ "role": "user", "content": "TABLE SUMMARY: {table_summary}\n" "COLUMN NAME: {column_name}\n" "EXAMPLE VALUES: {samples}", }, - ) + ], ) @@ -108,14 +155,15 @@ async def extract_description(self, table: Table, connection: Connection) -> str """ ddl = self._generate_ddl(table) samples = self._fetch_samples(connection, table) - return await self.llm.generate_text( - template=DISCOVERY_TEMPLATE, - fmt={ - "dialect": self.engine.dialect.name, - "table_ddl": ddl, - "samples": samples, - }, + + prompt_format = DiscoveryPromptFormat( + dialect=self.engine.dialect.name, + table_ddl=ddl, + samples=samples, ) + formatted_prompt = DISCOVERY_TEMPLATE.format_prompt(prompt_format) + + return await self.llm.generate_text(formatted_prompt) def _fetch_samples(self, connection: Connection, table: Table) -> List[Dict[str, Any]]: rows = connection.execute(table.select().limit(self.samples_count)).fetchall() @@ -218,14 +266,15 @@ async def select_index( table=table, column=column, ) - use_index = await self.llm.generate_text( - template=SIMILARITY_TEMPLATE, - fmt={ - "table_summary": description, - "column_name": column.name, - "samples": samples, - }, + + prompt_format = SimilarityPromptFormat( + table_summary=description, + column_name=column.name, + samples=samples, ) + formatted_prompt = SIMILARITY_TEMPLATE.format_prompt(prompt_format) + + use_index = await self.llm.generate_text(formatted_prompt) return self.index_builder(connection.engine, table, column) if use_index.upper() == "TRUE" 
else None def _fetch_samples(self, connection: Connection, table: Table, column: Column) -> List[Any]: diff --git a/tests/integration/test_llm_options.py b/tests/integration/test_llm_options.py index e8c53435..fb8cfba4 100644 --- a/tests/integration/test_llm_options.py +++ b/tests/integration/test_llm_options.py @@ -35,20 +35,20 @@ async def test_llm_options_propagation(): llm.client.call.assert_has_calls( [ call( - prompt=ANY, - response_format=ANY, + conversation=ANY, + json_mode=ANY, event=ANY, options=expected_options, ), call( - prompt=ANY, - response_format=ANY, + conversation=ANY, + json_mode=ANY, event=ANY, options=expected_options, ), call( - prompt=ANY, - response_format=ANY, + conversation=ANY, + json_mode=ANY, event=ANY, options=expected_options, ), diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 9858e45f..75cc914b 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -6,12 +6,11 @@ from dataclasses import dataclass from functools import cached_property -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union from dbally import NOT_GIVEN, NotGiven from dbally.iql import IQLQuery from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template from dbally.llms.base import LLM from dbally.llms.clients.base import LLMClient, LLMOptions from dbally.similarity.index import AbstractSimilarityIndex @@ -35,12 +34,12 @@ def execute(self, dry_run=False) -> ViewExecutionResult: class MockIQLGenerator(IQLGenerator): - def __init__(self, iql: str) -> None: + def __init__(self, iql: IQLQuery) -> None: self.iql = iql super().__init__(llm=MockLLM()) - async def generate_iql(self, *_, **__) -> Tuple[str, IQLPromptTemplate]: - return self.iql, default_iql_template + async def generate_iql(self, *_, **__) -> IQLQuery: + return self.iql class MockViewSelector(ViewSelector): diff --git a/tests/unit/test_assistants_adapters.py 
b/tests/unit/test_assistants_adapters.py index 72a55e06..9c203bd6 100644 --- a/tests/unit/test_assistants_adapters.py +++ b/tests/unit/test_assistants_adapters.py @@ -8,7 +8,7 @@ from dbally.assistants.base import FunctionCallingError, FunctionCallState from dbally.assistants.openai import _DBALLY_INFO, _DBALLY_INSTRUCTION, OpenAIAdapter, OpenAIDballyResponse -from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError +from dbally.iql_generator.prompt import UnsupportedQueryError MOCK_VIEWS = {"view1": "description1", "view2": "description2"} F_ID = "f_id" diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index 3e5bddb5..38ec3e99 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -1,7 +1,7 @@ # pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, disallowed-name, missing-return-type-doc from typing import List, Tuple, Type -from unittest.mock import AsyncMock, Mock, call, patch +from unittest.mock import AsyncMock, Mock import pytest from typing_extensions import Annotated @@ -10,9 +10,9 @@ from dbally.collection import Collection from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ViewExecutionResult -from dbally.iql._exceptions import IQLError +from dbally.iql import IQLQuery +from dbally.iql.syntax import FunctionCall from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -from dbally.views.structured import BaseStructuredView from tests.unit.mocks import MockIQLGenerator, MockLLM, MockSimilarityIndex, MockViewBase, MockViewSelector @@ -59,8 +59,8 @@ def execute(self, dry_run=False) -> ViewExecutionResult: def list_filters(self) -> List[ExposedFunction]: return [ExposedFunction("test_filter", "", [])] - def get_iql_generator(self, *_, **__): - return MockIQLGenerator("test_filter()") + def get_iql_generator(self, *_, **__) -> MockIQLGenerator: + return 
MockIQLGenerator(IQLQuery(FunctionCall("test_filter", []), "test_filter()")) @pytest.fixture(name="similarity_classes") @@ -275,42 +275,6 @@ def get_iql_generator(self, *_, **__): return collection -async def test_ask_feedback_loop(collection_feedback: Collection) -> None: - """ - Tests that the ask_feedback_loop method works correctly - """ - - mock_node = Mock(col_offset=0, end_col_offset=-1) - errors = [ - IQLError("err1", mock_node, "src1"), - IQLError("err2", mock_node, "src2"), - ValueError("err3"), - ValueError("err4"), - ] - with patch("dbally.iql._query.IQLQuery.parse") as mock_iql_query: - mock_iql_query.side_effect = errors - view = collection_feedback.get("ViewWithMockGenerator") - assert isinstance(view, BaseStructuredView) - iql_generator = view.get_iql_generator(llm=MockLLM()) - - await collection_feedback.ask("Mock question") - - iql_gen_error: Mock = iql_generator.add_error_msg # type: ignore - - iql_gen_error.assert_has_calls( - [call("iql1_c", [errors[0]]), call("iql2_c", [errors[1]]), call("iql3_c", [errors[2]])] - ) - assert iql_gen_error.call_count == 3 - - iql_gen_gen_iql: Mock = iql_generator.generate_iql # type: ignore - - for i, c in enumerate(iql_gen_gen_iql.call_args_list): - if i > 0: - assert c[1]["conversation"] == f"err{i}" - - assert iql_gen_gen_iql.call_count == 4 - - async def test_ask_view_selection_single_view() -> None: """ Tests that the ask method select view correctly when there is only one view diff --git a/tests/unit/test_fewshot.py b/tests/unit/test_fewshot.py index e2f4cf8d..2b8ba8b3 100644 --- a/tests/unit/test_fewshot.py +++ b/tests/unit/test_fewshot.py @@ -2,20 +2,20 @@ import pytest -from dbally.prompts.elements import FewShotExample +from dbally.prompt.elements import FewShotExample class TestExamples: - def studied_at(self, _: str): + def studied_at(self, _: str) -> bool: return False - def is_available_within_months(self, _: int): + def is_available_within_months(self, _: int) -> bool: return False - def 
data_scientist_position(self): + def data_scientist_position(self) -> bool: return False - def has_seniority(self, _: str): + def has_seniority(self, _: str) -> bool: return False def __call__(self) -> List[Tuple[str, Callable]]: # pylint: disable=W0602, C0116, W9011 @@ -57,16 +57,17 @@ def __call__(self) -> List[Tuple[str, Callable]]: # pylint: disable=W0602, C011 ] -def test_fewshot_string(): - result = FewShotExample("question", "answer") - assert result.answer == "answer" - assert str(result) == "answer" - - @pytest.mark.parametrize( "repr_lambda", TestExamples()(), ) -def test_fewshot_lambda(repr_lambda: Tuple[str, Callable]): +def test_fewshot_lambda(repr_lambda: Tuple[str, Callable]) -> None: result = FewShotExample("question", repr_lambda[1]) - assert str(result) == repr_lambda[0] + assert result.answer == repr_lambda[0] + assert str(result) == f"question -> {repr_lambda[0]}" + + +def test_fewshot_string() -> None: + result = FewShotExample("question", "answer") + assert result.answer == "answer" + assert str(result) == "question -> answer" diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index c2fb4274..8f583c4c 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -1,68 +1,89 @@ -from typing import List - -import pytest - -from dbally.iql_generator.iql_prompt_template import default_iql_template -from dbally.prompts.elements import FewShotExample -from dbally.prompts.formatters import IQLFewShotInputFormatter, IQLInputFormatter - - -async def test_iql_input_format_default() -> None: - input_fmt = IQLInputFormatter([], "") - - conversation, format = input_fmt(default_iql_template) - - assert len(conversation.chat) == len(default_iql_template.chat) - assert "filters" in format - assert "question" in format - - -async def test_iql_input_format_few_shot_default() -> None: - input_fmt = IQLFewShotInputFormatter([], [], "") - - conversation, format = input_fmt(default_iql_template) - - assert 
len(conversation.chat) == len(default_iql_template.chat) - assert "filters" in format - assert "question" in format - - -@pytest.mark.parametrize( - "examples", - [ - [], - [FewShotExample("q1", "a1")], - ], -) -async def test_iql_input_format_few_shot_examples_injected(examples: List[FewShotExample]) -> None: +from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat +from dbally.prompt.elements import FewShotExample + + +async def test_iql_prompt_format_default() -> None: + prompt_format = IQLGenerationPromptFormat( + question="", + filters=[], + examples=[], + ) + formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + + assert formatted_prompt.chat == [ + { + "role": "system", + "content": "You have access to API that lets you query a database:\n" + "\n\n" + "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" + "Remember! Don't give any comments, just the function calls.\n" + "The output will look like this:\n" + 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' + "DO NOT INCLUDE arguments names in your response. Only the values.\n" + "You MUST use only these methods:\n" + "\n\n" + "It is VERY IMPORTANT not to use methods other than those listed above." + """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ + "This is CRUCIAL, otherwise the system will crash. 
", + "is_example": False, + }, + {"role": "user", "content": "", "is_example": False}, + ] + + +async def test_iql_prompt_format_few_shots_injected() -> None: examples = [FewShotExample("q1", "a1")] - input_fmt = IQLFewShotInputFormatter([], examples, "") - - conversation, format = input_fmt(default_iql_template) - - assert len(conversation.chat) == len(default_iql_template.chat) + (len(examples) * 2) - assert "filters" in format - assert "question" in format + prompt_format = IQLGenerationPromptFormat( + question="", + filters=[], + examples=examples, + ) + formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + + assert formatted_prompt.chat == [ + { + "role": "system", + "content": "You have access to API that lets you query a database:\n" + "\n\n" + "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" + "Remember! Don't give any comments, just the function calls.\n" + "The output will look like this:\n" + 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' + "DO NOT INCLUDE arguments names in your response. Only the values.\n" + "You MUST use only these methods:\n" + "\n\n" + "It is VERY IMPORTANT not to use methods other than those listed above." + """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ + "This is CRUCIAL, otherwise the system will crash. 
", + "is_example": False, + }, + {"role": "user", "content": examples[0].question, "is_example": True}, + {"role": "assistant", "content": examples[0].answer, "is_example": True}, + {"role": "user", "content": "", "is_example": False}, + ] async def test_iql_input_format_few_shot_examples_repeat_no_example_duplicates() -> None: examples = [FewShotExample("q1", "a1")] - input_fmt = IQLFewShotInputFormatter([], examples, "q") - - conversation, _ = input_fmt(default_iql_template) - - assert len(conversation.chat) == len(default_iql_template.chat) + (len(examples) * 2) - assert conversation.chat[1]["role"] == "user" - assert conversation.chat[1]["content"] == examples[0].question - assert conversation.chat[2]["role"] == "assistant" - assert conversation.chat[2]["content"] == examples[0].answer - - conversation = conversation.add_assistant_message("response") - - conversation2, _ = input_fmt(conversation) - - assert len(conversation2.chat) == len(conversation.chat) - assert conversation2.chat[1]["role"] == "user" - assert conversation2.chat[1]["content"] == examples[0].question - assert conversation2.chat[2]["role"] == "assistant" - assert conversation2.chat[2]["content"] == examples[0].answer + prompt_format = IQLGenerationPromptFormat( + question="", + filters=[], + examples=examples, + ) + formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + + assert len(formatted_prompt.chat) == len(IQL_GENERATION_TEMPLATE.chat) + (len(examples) * 2) + assert formatted_prompt.chat[1]["role"] == "user" + assert formatted_prompt.chat[1]["content"] == examples[0].question + assert formatted_prompt.chat[2]["role"] == "assistant" + assert formatted_prompt.chat[2]["content"] == examples[0].answer + + formatted_prompt = formatted_prompt.add_assistant_message("response") + + formatted_prompt2 = formatted_prompt.format_prompt(prompt_format) + + assert len(formatted_prompt2.chat) == len(formatted_prompt.chat) + assert formatted_prompt2.chat[1]["role"] == "user" + assert 
formatted_prompt2.chat[1]["content"] == examples[0].question + assert formatted_prompt2.chat[2]["role"] == "assistant" + assert formatted_prompt2.chat[2]["content"] == examples[0].answer diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index 8c8df9e7..ce3f593d 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -1,17 +1,15 @@ # mypy: disable-error-code="empty-body" -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock, patch import pytest import sqlalchemy from dbally import decorators from dbally.audit.event_tracker import EventTracker -from dbally.iql import IQLQuery +from dbally.iql import IQLError, IQLQuery from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.iql_generator.iql_prompt_template import default_iql_template -from dbally.prompts.elements import FewShotExample -from dbally.prompts.formatters import IQLFewShotInputFormatter, IQLInputFormatter +from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat from dbally.views.methods_base import MethodsBaseView from tests.unit.mocks import MockLLM @@ -43,7 +41,7 @@ def view() -> MockView: @pytest.fixture def llm() -> MockLLM: llm = MockLLM() - llm.client.call = AsyncMock(return_value="LLM IQL mock answer") + llm.generate_text = AsyncMock(return_value="filter_by_id(1)") return llm @@ -52,58 +50,64 @@ def event_tracker() -> EventTracker: return EventTracker() -@pytest.mark.asyncio -async def test_iql_generation(llm: MockLLM, event_tracker: EventTracker, view: MockView) -> None: - iql_generator = IQLGenerator(llm) - - filters = {str(_filter) for _filter in view.list_filters()} - assert filters == {"filter_by_id(idx: int)", "filter_by_name(city: str)"} - - input_formatter = IQLInputFormatter(question="Mock_question", filters=view.list_filters()) - - response = await iql_generator.generate_iql(input_formatter, event_tracker, default_iql_template) - - 
template_after_response = default_iql_template.add_assistant_message(content="LLM IQL mock answer") - assert response == ("LLM IQL mock answer", template_after_response) - - template_after_response = template_after_response.add_user_message(content="Mock_error") - response2 = await iql_generator.generate_iql(input_formatter, event_tracker, template_after_response) - template_after_2nd_response = template_after_response.add_assistant_message(content="LLM IQL mock answer") - assert response2 == ("LLM IQL mock answer", template_after_2nd_response) +@pytest.fixture +def iql_generator(llm: MockLLM) -> IQLGenerator: + return IQLGenerator(llm) @pytest.mark.asyncio -async def test_iql_few_shot_generation(llm: MockLLM, event_tracker: EventTracker, view: MockView) -> None: - iql_generator = IQLGenerator(llm) - - filters = {str(_filter) for _filter in view.list_filters()} - assert filters == {"filter_by_id(idx: int)", "filter_by_name(city: str)"} - - input_formatter = IQLFewShotInputFormatter( +async def test_iql_generation(iql_generator: IQLGenerator, event_tracker: EventTracker, view: MockView) -> None: + filters = view.list_filters() + prompt_format = IQLGenerationPromptFormat( question="Mock_question", - filters=view.list_filters(), - examples=[FewShotExample("question", "filter_by_id(0)")], + filters=filters, ) - - response = await iql_generator.generate_iql(input_formatter, event_tracker, default_iql_template) - - expected_conversation, _ = input_formatter(default_iql_template) - template_after_response = expected_conversation.add_assistant_message(content="LLM IQL mock answer") - assert response == ("LLM IQL mock answer", template_after_response) - - template_after_response = template_after_response.add_user_message(content="Mock_error") - response2 = await iql_generator.generate_iql(input_formatter, event_tracker, template_after_response) - template_after_2nd_response = template_after_response.add_assistant_message(content="LLM IQL mock answer") - assert response2 == 
("LLM IQL mock answer", template_after_2nd_response) + formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + + with patch("dbally.iql.IQLQuery.parse", AsyncMock(return_value="filter_by_id(1)")) as mock_parse: + iql = await iql_generator.generate_iql( + question="Mock_question", + filters=filters, + event_tracker=event_tracker, + ) + assert iql == "filter_by_id(1)" + iql_generator._llm.generate_text.assert_called_once_with( + prompt=formatted_prompt, + event_tracker=event_tracker, + options=None, + ) + mock_parse.assert_called_once_with( + source="filter_by_id(1)", + allowed_functions=filters, + event_tracker=event_tracker, + ) -def test_add_error_msg(llm: MockLLM) -> None: - iql_generator = IQLGenerator(llm) - errors = [ValueError("Mock_error")] - - conversation = default_iql_template.add_assistant_message(content="Assistant") - - conversation_with_error = iql_generator.add_error_msg(conversation, errors) - - error_msg = iql_generator._ERROR_MSG_PREFIX + "Mock_error\n" - assert conversation_with_error == conversation.add_user_message(content=error_msg) +@pytest.mark.asyncio +async def test_iql_generation_error_handling( + iql_generator: IQLGenerator, + event_tracker: EventTracker, + view: MockView, +) -> None: + filters = view.list_filters() + + mock_node = Mock(col_offset=0, end_col_offset=-1) + errors = [ + IQLError("err1", mock_node, "src1"), + IQLError("err2", mock_node, "src2"), + IQLError("err3", mock_node, "src3"), + IQLError("err4", mock_node, "src4"), + ] + + with patch("dbally.iql.IQLQuery.parse", AsyncMock(return_value="filter_by_id(1)")) as mock_parse: + mock_parse.side_effect = errors + iql = await iql_generator.generate_iql( + question="Mock_question", + filters=filters, + event_tracker=event_tracker, + ) + + assert iql is None + assert iql_generator._llm.generate_text.call_count == 4 + for i, arg in enumerate(iql_generator._llm.generate_text.call_args_list[1:], start=1): + assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] 
diff --git a/tests/unit/test_prompt_builder.py b/tests/unit/test_prompt_builder.py index f8a886fe..00fa7fd5 100644 --- a/tests/unit/test_prompt_builder.py +++ b/tests/unit/test_prompt_builder.py @@ -1,116 +1,99 @@ +from typing import List + import pytest -from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate -from dbally.prompts import ChatFormat, PromptTemplate, PromptTemplateError -from tests.unit.mocks import MockLLM +from dbally.prompt.elements import FewShotExample +from dbally.prompt.template import ChatFormat, PromptFormat, PromptTemplate, PromptTemplateError + + +class QuestionPromptFormat(PromptFormat): + """ + Generic format for prompts allowing to inject few shot examples into the conversation. + """ + + def __init__(self, question: str, examples: List[FewShotExample] = None) -> None: + """ + Constructs a new PromptFormat instance. + + Args: + question: Question to be asked. + examples: List of examples to be injected into the conversation. + """ + super().__init__(examples) + self.question = question @pytest.fixture() -def simple_template(): - simple_template = PromptTemplate( - chat=( +def template() -> PromptTemplate[QuestionPromptFormat]: + return PromptTemplate[QuestionPromptFormat]( + [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "{question}"}, - ) + ] ) - return simple_template -@pytest.fixture() -def llm(): - return MockLLM() +def test_prompt_template_formatting(template: PromptTemplate[QuestionPromptFormat]) -> None: + prompt_format = QuestionPromptFormat(question="Example user question?") + formatted_prompt = template.format_prompt(prompt_format) + assert formatted_prompt.chat == [ + {"content": "You are a helpful assistant.", "role": "system", "is_example": False}, + {"content": "Example user question?", "role": "user", "is_example": False}, + ] -def test_default_llm_format_prompt(llm, simple_template): - prompt = llm.format_prompt( - template=simple_template, - fmt={"question": 
"Example user question?"}, - ) - assert prompt == [ - {"content": "You are a helpful assistant.", "role": "system"}, - {"content": "Example user question?", "role": "user"}, +def test_missing_prompt_template_formatting(template: PromptTemplate[QuestionPromptFormat]) -> None: + prompt_format = PromptFormat() + with pytest.raises(KeyError): + template.format_prompt(prompt_format) + + +def test_add_few_shots(template: PromptTemplate[QuestionPromptFormat]) -> None: + examples = [ + FewShotExample( + question="What is the capital of France?", + answer_expr="Paris", + ), + FewShotExample( + question="What is the capital of Germany?", + answer_expr="Berlin", + ), ] + for example in examples: + template = template.add_few_shot_message(example) -def test_missing_format_dict(llm, simple_template): - with pytest.raises(KeyError): - _ = llm.format_prompt(simple_template, fmt={}) + assert template.chat == [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?", "is_example": True}, + {"role": "assistant", "content": "Paris", "is_example": True}, + {"role": "user", "content": "What is the capital of Germany?", "is_example": True}, + {"role": "assistant", "content": "Berlin", "is_example": True}, + {"role": "user", "content": "{question}"}, + ] @pytest.mark.parametrize( "invalid_chat", [ - ( + [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "{question}"}, {"role": "user", "content": "{question}"}, - ), - ( + ], + [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "assistant", "content": "{question}"}, {"role": "assistant", "content": "{question}"}, - ), - ( + ], + [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "{question}"}, {"role": "assistant", "content": "{question}"}, {"role": "system", "content": "{question}"}, - ), + ], ], ) -def test_chat_order_validation(invalid_chat): +def 
test_chat_order_validation(invalid_chat: ChatFormat) -> None: with pytest.raises(PromptTemplateError): - _ = PromptTemplate(chat=invalid_chat) - - -def test_dynamic_few_shot(llm, simple_template): - assert ( - len( - llm.format_prompt( - simple_template.add_assistant_message("assistant message").add_user_message("user message"), - fmt={"question": "user question"}, - ) - ) - == 4 - ) - - -@pytest.mark.parametrize( - "invalid_chat", - [ - ( - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "{question}"}, - ), - ( - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "Hello"}, - ), - ( - {"role": "system", "content": "You are a helpful assistant. {filters}}"}, - {"role": "user", "content": "Hello"}, - ), - ], - ids=["Missing filters", "Missing filters, question", "Missing question"], -) -def test_bad_iql_prompt_template(invalid_chat: ChatFormat): - with pytest.raises(PromptTemplateError): - _ = IQLPromptTemplate(invalid_chat) - - -@pytest.mark.parametrize( - "chat", - [ - ( - {"role": "system", "content": "You are a helpful assistant.{filters}"}, - {"role": "user", "content": "{question}"}, - ), - ( - {"role": "system", "content": "{filters}{filters}{filters}}}"}, - {"role": "user", "content": "{question}"}, - ), - ], - ids=["Good template", "Good template with repeating variables"], -) -def test_good_iql_prompt_template(chat: ChatFormat): - _ = IQLPromptTemplate(chat) + PromptTemplate[QuestionPromptFormat](invalid_chat) diff --git a/tests/unit/test_view_selector.py b/tests/unit/test_view_selector.py index 2d3b1d9c..8de038e2 100644 --- a/tests/unit/test_view_selector.py +++ b/tests/unit/test_view_selector.py @@ -31,7 +31,7 @@ def views() -> Dict[str, str]: @pytest.mark.asyncio -async def test_view_selection(llm: LLM, views: Dict[str, str]): +async def test_view_selection(llm: LLM, views: Dict[str, str]) -> None: view_selector = LLMViewSelector(llm) view = await 
view_selector.select_view("Mock question?", views, event_tracker=EventTracker()) assert view == "MockView1"