feat: request contextualisation - core functionality #65

Open
wants to merge 57 commits into
base: main
Changes from 28 commits
21506e6
context logic subpackage; type-hint context extraction
Jun 21, 2024
a87e8e2
reworked type hint info extraction; extended functionality to also re…
ds-jakub-cierocki Jun 24, 2024
3ad4ecd
hidden args handling enabled
ds-jakub-cierocki Jun 24, 2024
b0cc0ae
improved type hints parsing and compatibility using package
ds-jakub-cierocki Jun 28, 2024
4ff5f62
dedicated exceptions for contex-related operations
ds-jakub-cierocki Jun 28, 2024
c479c50
useful classmethods for context-related operations
ds-jakub-cierocki Jun 28, 2024
e3bb127
make whole context utils module protected; added IQL parsing helper; …
ds-jakub-cierocki Jun 28, 2024
de72c7c
parsing type hints _extract_params_and_context() no longer excludes B…
ds-jakub-cierocki Jun 28, 2024
d3958c0
adjusted the existing code to be aware of contexts (promts yet untouc…
ds-jakub-cierocki Jun 28, 2024
be338bf
adjusted _type_validators.validate_arg_type() to handle typing.Union[]
ds-jakub-cierocki Jul 2, 2024
78f1535
context._utils._does_arg_allow_context() fix
ds-jakub-cierocki Jul 2, 2024
308e2e1
context record is now based on pydantic.BaseModel rather than datacla…
ds-jakub-cierocki Jul 2, 2024
73741d9
type hint lifting
ds-jakub-cierocki Jul 2, 2024
902f5ff
IQL generating LLM prompt passes BaseCallerContext() as filter argume…
ds-jakub-cierocki Jul 2, 2024
6309070
comments cleanup
ds-jakub-cierocki Jul 2, 2024
d523bf7
type hint fixes
ds-jakub-cierocki Jul 3, 2024
efe212f
Merge branch 'main' (which includes a large refactor by Michal) into …
ds-jakub-cierocki Jul 3, 2024
9ba89e5
post-merge fixes + minor refactor
ds-jakub-cierocki Jul 3, 2024
5fd802f
added missing docstrings; fixed type hints; fixed issues detected by …
ds-jakub-cierocki Jul 4, 2024
09bac55
reworked parse_param_type() function to increase performance, general…
ds-jakub-cierocki Jul 4, 2024
d42a369
fix: removed duplicated line from the prompt template
ds-jakub-cierocki Jul 4, 2024
c0b0522
adjusted existing unit tests to work with new contextualization logic
ds-jakub-cierocki Jul 4, 2024
9b2e131
linter-recommended fixes
ds-jakub-cierocki Jul 4, 2024
2d0ef4b
contextualization mechanism - dedicated unit tests
ds-jakub-cierocki Jul 5, 2024
6466f61
cleaned up overengineered code remanining from the previous iteration…
ds-jakub-cierocki Jul 5, 2024
637f7fa
replaced pydantic.BaseModel by dataclasses.dataclass, pydantic no lon…
ds-jakub-cierocki Jul 8, 2024
f867e25
BaseCallerContext: dataclass w.o. fields -> interface (abstract class…
ds-jakub-cierocki Jul 8, 2024
3423033
LLM now pastes Context() instead of BaseCallerContext() to indicate t…
ds-jakub-cierocki Jul 8, 2024
0d8cd1e
docstring typo fixes; more precise return type hint
ds-jakub-cierocki Jul 9, 2024
c97ba15
renamed Context() -> AskerContext(); added more detailed detailed exa…
ds-jakub-cierocki Jul 9, 2024
1294a9c
type hint parsing changes: SomeCustomContext -> AskerContext; Union[a…
ds-jakub-cierocki Jul 9, 2024
999759b
refactor: collection.results.[ViewExecutionResult, ExecutionResult]."…
ds-jakub-cierocki Jul 12, 2024
2e1005a
param type parsing: correctly handling builtins types with args (e.g.…
ds-jakub-cierocki Jul 12, 2024
820066d
type hint fix: explcitly marked BaseCallerContext.alias as typing.Cla…
ds-jakub-cierocki Jul 12, 2024
25fbfa6
docs + benchmarks adjusted to meet new naming [ExecutionResult, ViewE…
ds-jakub-cierocki Jul 15, 2024
a154577
redesigned context-not-available error to follow the same principles …
ds-jakub-cierocki Jul 15, 2024
623effd
EXPERIMENTAL: reworked context injection such it is handled immediate…
ds-jakub-cierocki Jul 15, 2024
afacf5b
additional unit tests for the new contextualization mechanism
ds-jakub-cierocki Jul 19, 2024
dd8b339
context benchmark script and data
ds-jakub-cierocki Jul 22, 2024
6bb0816
refactored main prompt (too long lines), missing end-of-line characters
ds-jakub-cierocki Jul 22, 2024
f388f92
better error handling
ds-jakub-cierocki Jul 22, 2024
fbecc51
context benchmark dataset fix
ds-jakub-cierocki Jul 23, 2024
5d4ff64
added polars-based accuracy summary to the benchmark
ds-jakub-cierocki Jul 23, 2024
e7e8826
adjusted prompt to reduce halucinations: nested filter/context calls …
ds-jakub-cierocki Jul 23, 2024
f8bf64e
merged main (inc. new benchmarks + large refactor) -> jc/issue-54-req…
ds-jakub-cierocki Aug 7, 2024
c1c871b
merge main
micpst Sep 23, 2024
8eefd9b
fix linters
micpst Sep 23, 2024
c28091f
fix tests
micpst Sep 23, 2024
69a8d58
fix tests
micpst Sep 23, 2024
d6c8fc6
fix tests
micpst Sep 23, 2024
d7026d4
rm old benchmarks
micpst Sep 23, 2024
e8271ac
some renames and stuff
micpst Sep 23, 2024
bdcc7b3
fix benchmarks
micpst Sep 23, 2024
71f53be
merge main
micpst Sep 25, 2024
c82e579
rm chroma file
micpst Sep 25, 2024
f5a40cb
add contexts to benchmarks + fix types
micpst Sep 30, 2024
fab9d3f
small refactor
micpst Oct 1, 2024
7 changes: 6 additions & 1 deletion src/dbally/collection/collection.py
@@ -3,13 +3,14 @@
import textwrap
import time
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Type, TypeVar
from typing import Callable, Dict, Iterable, List, Optional, Type, TypeVar

from dbally.audit.event_handlers.base import EventHandler
from dbally.audit.event_tracker import EventTracker
from dbally.audit.events import RequestEnd, RequestStart
from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError
from dbally.collection.results import ExecutionResult
from dbally.context.context import CustomContext
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.nl_responder.nl_responder import NLResponder
@@ -156,6 +157,7 @@ async def ask(
dry_run: bool = False,
return_natural_response: bool = False,
llm_options: Optional[LLMOptions] = None,
contexts: Optional[Iterable[CustomContext]] = None,
) -> ExecutionResult:
"""
Ask question in a text form and retrieve the answer based on the available views.
@@ -175,6 +177,8 @@ async def ask(
the natural response will be included in the answer
llm_options: options to use for the LLM client. If provided, these options will be merged with the default
options provided to the LLM client, prioritizing option values other than NOT_GIVEN
contexts: An iterable (typically a list) of context objects, each being an instance of
a subclass of BaseCallerContext. May contain contexts irrelevant for the currently processed query.

Returns:
ExecutionResult object representing the result of the query execution.
@@ -215,6 +219,7 @@ async def ask(
n_retries=self.n_retries,
dry_run=dry_run,
llm_options=llm_options,
contexts=contexts,
)
end_time_view = time.monotonic()

3 changes: 3 additions & 0 deletions src/dbally/context/__init__.py
@@ -0,0 +1,3 @@
from .context import BaseCallerContext

__all__ = ["BaseCallerContext"]
75 changes: 75 additions & 0 deletions src/dbally/context/_utils.py
@@ -0,0 +1,75 @@
from inspect import isclass
from typing import Any, Optional, Sequence, Tuple, Type, Union

import typing_extensions as type_ext

from dbally.context.context import BaseCallerContext
from dbally.views.exposed_functions import MethodParamWithTyping

ContextClass: type_ext.TypeAlias = Optional[Type[BaseCallerContext]]


def _extract_params_and_context(
filter_method_: type_ext.Callable, hidden_args: Sequence[str]
) -> Tuple[Sequence[MethodParamWithTyping], ContextClass]:
"""
Processes the MethodsBaseView filter method signature to extract the args and type hints in the desired format.
Context classes are excluded from the returned MethodParamWithTyping list. Only the first BaseCallerContext
class is returned.

Args:
filter_method_: MethodsBaseView filter method (annotated with @decorators.view_filter() decorator)
hidden_args: method arguments that should not be extracted

Returns:
A tuple whose first element is the list of arguments, each wrapped in MethodParamWithTyping,
and whose second element is the BaseCallerContext subclass provided for this filter, or None if no context is specified.
"""

params = []
context = None
# TODO confirm whether to use typing.get_type_hints(method) or method.__annotations__
for name_, type_ in type_ext.get_type_hints(filter_method_).items():
if name_ in hidden_args:
continue

if isclass(type_) and issubclass(type_, BaseCallerContext):
# this is the case when the user provides a context but no other type hint for a specific arg
context = type_
type_ = Any
elif type_ext.get_origin(type_) is Union:
union_subtypes = type_ext.get_args(type_)
if not union_subtypes:
type_ = Any

for subtype_ in union_subtypes: # type: ignore
# TODO add custom error for the situation when user provides more than one context for a single filter
# for now we extract only the first context
if isclass(subtype_) and issubclass(subtype_, BaseCallerContext):
if context is None:
context = subtype_

params.append(MethodParamWithTyping(name_, type_))

return params, context
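The extraction described above can be sketched with the standard library alone. This is a minimal illustration, not the PR's implementation: `BaseCallerContext` is a local stand-in, `UserContext` and `filter_by_city` are hypothetical, parameters are returned as plain `(name, hint)` tuples instead of `MethodParamWithTyping`, and the original hint is kept rather than being replaced by `Any` for a bare context class.

```python
from dataclasses import dataclass
from inspect import isclass
from typing import Any, List, Optional, Tuple, Type, Union, get_args, get_origin, get_type_hints


class BaseCallerContext:
    """Local stand-in for dbally's BaseCallerContext (assumption: a marker base class)."""


@dataclass
class UserContext(BaseCallerContext):
    """Hypothetical context carrying data about the caller."""

    city: str


def extract_params_and_context(method, hidden_args=("self", "return")):
    """Return ([(name, hint), ...], first context class found or None)."""
    params: List[Tuple[str, Any]] = []
    context: Optional[Type[BaseCallerContext]] = None
    for name, hint in get_type_hints(method).items():
        if name in hidden_args:
            continue
        # A Union contributes each of its members; any other hint is checked as-is.
        candidates = get_args(hint) if get_origin(hint) is Union else (hint,)
        for candidate in candidates:
            # isclass() guards issubclass() against non-class hints such as List[int].
            if context is None and isclass(candidate) and issubclass(candidate, BaseCallerContext):
                context = candidate
        params.append((name, hint))
    return params, context


def filter_by_city(city: Union[str, UserContext], min_age: int) -> bool:
    """Hypothetical filter whose first argument may be supplied from the context."""
    return True


params, context = extract_params_and_context(filter_by_city)
```

Note that `get_type_hints()` also reports the return annotation under the key `"return"`, so the sketch hides it by default.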


def _does_arg_allow_context(arg: MethodParamWithTyping) -> bool:
"""
Verifies whether a method argument allows contextualization based on the type hints attached to a method signature.

Args:
arg: MethodParamWithTyping container preserving information about the method argument

Returns:
Verification result.
"""

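if type_ext.get_origin(arg.type) is not Union:
# isclass() guards issubclass() against non-class hints such as typing.Any or typing.List[int]
return isclass(arg.type) and issubclass(arg.type, BaseCallerContext)

for subtype in type_ext.get_args(arg.type):
if isclass(subtype) and issubclass(subtype, BaseCallerContext):
return True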

return False
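A per-argument check like the one above has to cope with type hints that are not classes, because `issubclass()` raises `TypeError` when its first argument is not a class. The sketch below shows one defensive variant under stated assumptions: `BaseCallerContext` is a local stand-in, `UserContext` is hypothetical, and the function takes the raw hint rather than a `MethodParamWithTyping`.

```python
from inspect import isclass
from typing import Any, List, Optional, Union, get_args, get_origin


class BaseCallerContext:
    """Local stand-in for dbally's BaseCallerContext."""


class UserContext(BaseCallerContext):
    """Hypothetical concrete context."""


def allows_context(hint: Any) -> bool:
    """True if the hint is a context class or a Union containing one."""
    candidates = get_args(hint) if get_origin(hint) is Union else (hint,)
    # Only real classes are passed to issubclass(); everything else counts as "no context".
    return any(isclass(c) and issubclass(c, BaseCallerContext) for c in candidates)
```

With this shape, `Optional[UserContext]` is accepted as well, since `Optional[X]` is represented as `Union[X, None]`.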
64 changes: 64 additions & 0 deletions src/dbally/context/context.py
@@ -0,0 +1,64 @@
import ast
from abc import ABC
from typing import Iterable

from typing_extensions import Self, TypeAlias

from dbally.context.exceptions import ContextNotAvailableError

CustomContext: TypeAlias = "BaseCallerContext"


class BaseCallerContext(ABC):
"""
An interface for contexts that are used to pass additional knowledge about
the caller environment to the filters. The LLM always returns `Context()`
when the context is required; this call is later replaced by an instance of
a class implementing this interface, selected based on the filter method signature (type hints).
"""

_alias: str = "Context"

@classmethod
def select_context(cls, contexts: Iterable[CustomContext]) -> Self:
"""
Typically called on a subclass of BaseCallerContext. Selects the first object in `contexts` that is
an instance of that subclass. Effectively provides a type dispatch mechanism, replacing the context
class with a matching instance.

Args:
contexts: A sequence of objects, each being an instance of a different BaseCallerContext subclass.

Returns:
An instance of the same BaseCallerContext subclass this method is called from.

Raises:
ContextNotAvailableError: If the sequence of context objects passed as argument is empty.
"""

if not contexts:
raise ContextNotAvailableError(
"The LLM detected that the context is required to execute the query "
"and the filter signature allows contextualization while the context was not provided."
)

# TODO confirm whether it is possible to design a correct type hints here and skipping `type: ignore`
return next(filter(lambda obj: isinstance(obj, cls), contexts)) # type: ignore
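Two edge cases are worth keeping in mind for this dispatch: `next()` without a default raises `StopIteration` when no element matches, and a truthiness test such as `if not contexts` cannot detect an empty generator. The sketch below is one possible variant that handles both, not the PR's code; the classes are local stand-ins and `UserContext` / `TenantContext` are hypothetical.

```python
from typing import Iterable


class ContextNotAvailableError(Exception):
    """Local stand-in for dbally's ContextNotAvailableError."""


class BaseCallerContext:
    """Local stand-in for dbally's BaseCallerContext."""

    @classmethod
    def select_context(cls, contexts: Iterable["BaseCallerContext"]):
        # A default of None covers both an empty iterable and one without a matching instance.
        match = next((obj for obj in contexts if isinstance(obj, cls)), None)
        if match is None:
            raise ContextNotAvailableError(f"No {cls.__name__} instance was provided.")
        return match


class UserContext(BaseCallerContext):
    def __init__(self, city: str) -> None:
        self.city = city


class TenantContext(BaseCallerContext):
    def __init__(self, tenant_id: int) -> None:
        self.tenant_id = tenant_id


contexts = [TenantContext(7), UserContext("Rome")]
selected = UserContext.select_context(contexts)
```

Calling `select_context` on the subclass is what makes the dispatch work: `cls` is bound to `UserContext`, so the `TenantContext` instance is skipped.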

@classmethod
def is_context_call(cls, node: ast.expr) -> bool:
"""
Verifies whether an AST node indicates context substitution.

Args:
node: An AST node (expression) to verify.

Returns:
Verification result.
"""

return (
isinstance(node, ast.Call)
and isinstance(node.func, ast.Name)
and node.func.id in [cls._alias, cls.__name__]
)
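To make the AST check concrete, the following sketch parses a few IQL-like expressions with the standard `ast` module and applies the same three conditions. The alias set and the `first_arg` helper are assumptions made for illustration; `filter_by_city` is a hypothetical filter name.

```python
import ast

# Assumed names under which the LLM may emit the context placeholder.
CONTEXT_NAMES = {"Context", "BaseCallerContext"}


def is_context_call(node: ast.AST) -> bool:
    """True for a bare call such as `Context()`; attribute calls like `x.Context()` do not match."""
    return (
        isinstance(node, ast.Call)
        and isinstance(node.func, ast.Name)
        and node.func.id in CONTEXT_NAMES
    )


def first_arg(source: str) -> ast.expr:
    """Hypothetical helper: parse a single call expression and return its first argument."""
    call = ast.parse(source, mode="eval").body
    return call.args[0]
```

The `isinstance(node.func, ast.Name)` condition is what excludes qualified calls, whose `func` is an `ast.Attribute`.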
26 changes: 26 additions & 0 deletions src/dbally/context/exceptions.py
@@ -0,0 +1,26 @@
from abc import ABC


class BaseContextException(Exception, ABC):
"""
A base (abstract) exception for all context-related exceptions.
"""


class ContextNotAvailableError(Exception):
"""
An exception, registered as a BaseContextException subclass, indicating that the user did not provide
sufficient context information when calling view.ask().
"""


class ContextualisationNotAllowed(Exception):
"""
An exception, registered as a BaseContextException subclass, indicating that the filter method signature
does not allow providing an additional context.
"""


# WORKAROUND - traditional inheritance syntax is not working in the context of abstract Exceptions
BaseContextException.register(ContextNotAvailableError)
BaseContextException.register(ContextualisationNotAllowed)
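One consequence of `ABC.register()` is easy to miss: it affects `issubclass()` and `isinstance()`, but `except` clauses match on the real inheritance chain, so a registered (virtual) subclass is not caught by `except BaseContextException`. The standalone check below illustrates this with hypothetical class names (`RegisteredError`, `InheritedError`); it also shows that plain inheritance from an `Exception`/`ABC` base works in current CPython, which may make the workaround unnecessary.

```python
from abc import ABC


class BaseContextException(Exception, ABC):
    """Local copy of the abstract base used for this experiment."""


class RegisteredError(Exception):
    """Hypothetical error attached via ABC.register(), as in the diff above."""


BaseContextException.register(RegisteredError)


class InheritedError(BaseContextException):
    """Hypothetical error using ordinary inheritance."""


def caught_by_base(exc: Exception) -> bool:
    """Report whether `except BaseContextException` catches the given exception."""
    try:
        raise exc
    except BaseContextException:
        return True
    except Exception:
        return False
```

Here `issubclass(RegisteredError, BaseContextException)` is true, yet only `InheritedError` is caught by the base-class handler.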
76 changes: 65 additions & 11 deletions src/dbally/iql/_processor.py
@@ -1,7 +1,10 @@
import ast
from typing import TYPE_CHECKING, Any, List, Optional, Union
from typing import Any, Iterable, List, Mapping, Optional, Union

from dbally.audit.event_tracker import EventTracker
from dbally.context._utils import _does_arg_allow_context
from dbally.context.context import BaseCallerContext, CustomContext
from dbally.context.exceptions import ContextualisationNotAllowed
from dbally.iql import syntax
from dbally.iql._exceptions import (
IQLArgumentParsingError,
@@ -11,21 +14,46 @@
IQLUnsupportedSyntaxError,
)
from dbally.iql._type_validators import validate_arg_type

if TYPE_CHECKING:
from dbally.views.structured import ExposedFunction
from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping


class IQLProcessor:
"""
Parses IQL string to tree structure.

Attributes:
source: Raw LLM response containing IQL filter calls.
allowed_functions: A mapping (typically a dict) of all filters implemented for a certain View.
contexts: A sequence (typically a list) of context objects, each being an instance of
a subclass of BaseCallerContext. May contain contexts irrelevant for the currently processed query.
"""

source: str
allowed_functions: Mapping[str, "ExposedFunction"]
contexts: Iterable[CustomContext]
_event_tracker: EventTracker

def __init__(
self, source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None
self,
source: str,
allowed_functions: Iterable[ExposedFunction],
contexts: Optional[Iterable[CustomContext]] = None,
event_tracker: Optional[EventTracker] = None,
) -> None:
"""
IQLProcessor class constructor.

Args:
source: Raw LLM response containing IQL filter calls.
allowed_functions: An iterable (typically a list) of all filters implemented for a certain View.
contexts: An iterable (typically a list) of context objects, each being an instance of
a subclass of BaseCallerContext.
event_tracker: An EventTracker instance.
"""

self.source = source
self.allowed_functions = {func.name: func for func in allowed_functions}
self.contexts = contexts or []
self._event_tracker = event_tracker or EventTracker()

async def process(self) -> syntax.Node:
@@ -38,6 +66,7 @@ async def process(self) -> syntax.Node:
Raises:
IQLError: if parsing fails.
"""

self.source = self._to_lower_except_in_quotes(self.source, ["AND", "OR", "NOT"])

ast_tree = ast.parse(self.source)
@@ -84,13 +113,13 @@ async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall:
if len(func_def.parameters) != len(node.args):
raise ValueError(f"The method {func.id} has incorrect number of arguments")

for arg, arg_def in zip(node.args, func_def.parameters):
arg_value = self._parse_arg(arg)
for arg, arg_spec in zip(node.args, func_def.parameters):
arg_value = self._parse_arg(arg, arg_spec=arg_spec, parent_func_def=func_def)

if arg_def.similarity_index:
arg_value = await arg_def.similarity_index.similar(arg_value, event_tracker=self._event_tracker)
if arg_spec.similarity_index:
arg_value = await arg_spec.similarity_index.similar(arg_value, event_tracker=self._event_tracker)

check_result = validate_arg_type(arg_def.type, arg_value)
check_result = validate_arg_type(arg_spec.type, arg_value)

if not check_result.valid:
raise IQLArgumentValidationError(message=check_result.reason or "", node=arg, source=self.source)
@@ -99,12 +128,37 @@ async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall:

return syntax.FunctionCall(func.id, args)

def _parse_arg(self, arg: ast.expr) -> Any:
def _parse_arg(
self,
arg: ast.expr,
arg_spec: Optional[MethodParamWithTyping] = None,
parent_func_def: Optional[ExposedFunction] = None,
) -> Any:
if isinstance(arg, ast.List):
return [self._parse_arg(x) for x in arg.elts]

if BaseCallerContext.is_context_call(arg):
if parent_func_def is None or arg_spec is None:
# not sure whether this line will be ever reached
raise IQLArgumentParsingError(arg, self.source)

if parent_func_def.context_class is None:
raise ContextualisationNotAllowed(
"The LLM detected that the context is required "
"to execute the query while the filter signature does not allow it at all."
)

if not _does_arg_allow_context(arg_spec):
raise ContextualisationNotAllowed(
"The LLM detected that the context is required "
f"to execute the query while the filter signature does not allow it for the `{arg_spec.name}` argument."
)

return parent_func_def.context_class.select_context(self.contexts)

if not isinstance(arg, ast.Constant):
raise IQLArgumentParsingError(arg, self.source)

return arg.value
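The argument-parsing flow above (lists, the context placeholder, then constants) can be exercised end to end with a small standalone sketch. It is a simplification under stated assumptions: `UserContext` and `parse_call_args` are hypothetical, the placeholder is matched by the literal name `Context`, built-in exceptions replace the IQL-specific ones, and the context information is passed down into list elements, which the method in the diff does not do.

```python
import ast
from typing import Any, List, Optional, Sequence, Type


class UserContext:
    """Hypothetical context object supplied by the caller."""

    def __init__(self, city: str) -> None:
        self.city = city


def parse_arg(arg: ast.expr, contexts: Sequence[object], context_class: Optional[Type]) -> Any:
    if isinstance(arg, ast.List):
        return [parse_arg(item, contexts, context_class) for item in arg.elts]
    if isinstance(arg, ast.Call) and isinstance(arg.func, ast.Name) and arg.func.id == "Context":
        if context_class is None:
            raise TypeError("this filter does not accept a context")
        for obj in contexts:
            if isinstance(obj, context_class):
                return obj  # the placeholder is replaced by the caller's object
        raise LookupError(f"no {context_class.__name__} instance was provided")
    if not isinstance(arg, ast.Constant):
        raise ValueError(f"unsupported argument: {ast.dump(arg)}")
    return arg.value


def parse_call_args(source: str, contexts: Sequence[object], context_class: Optional[Type]) -> List[Any]:
    """Parse one call expression and resolve each of its positional arguments."""
    call = ast.parse(source, mode="eval").body
    return [parse_arg(a, contexts, context_class) for a in call.args]


user = UserContext("Rome")
args = parse_call_args('filter_by(Context(), ["a", 1], 5)', [user], UserContext)
```

The order of the checks matters: the placeholder test must run before the `ast.Constant` test, because a `Context()` call would otherwise be rejected as an unsupported argument.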

@staticmethod
14 changes: 11 additions & 3 deletions src/dbally/iql/_query.py
@@ -1,4 +1,8 @@
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Iterable, List, Optional

from typing_extensions import Self

from dbally.context.context import CustomContext

from ..audit.event_tracker import EventTracker
from . import syntax
@@ -28,16 +32,20 @@ async def parse(
source: str,
allowed_functions: List["ExposedFunction"],
event_tracker: Optional[EventTracker] = None,
) -> "IQLQuery":
contexts: Optional[Iterable[CustomContext]] = None,
) -> Self:
"""
Parse IQL string to IQLQuery object.

Args:
source: IQL string that needs to be parsed
allowed_functions: list of IQL functions that are allowed for this query
event_tracker: EventTracker object to track events
contexts: An iterable (typically a list) of context objects, each being
an instance of a subclass of BaseCallerContext.
Returns:
IQLQuery object
"""
root = await IQLProcessor(source, allowed_functions, event_tracker=event_tracker).process()

root = await IQLProcessor(source, allowed_functions, contexts, event_tracker).process()
return cls(root=root, source=source)