Skip to content

Commit

Permalink
Call PDL from Python API (#7)
Browse files Browse the repository at this point in the history
* Add `pdl.pdl.exec_*` functions with examples
  • Loading branch information
mandel authored Sep 4, 2024
1 parent bbbb04e commit c900d02
Show file tree
Hide file tree
Showing 9 changed files with 291 additions and 78 deletions.
8 changes: 8 additions & 0 deletions examples/sdk/hello.pdl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
document:
- Hello,
- model: watsonx/ibm/granite-20b-code-instruct
parameters:
stop:
- '!'
include_stop_sequence: true
- "\n"
21 changes: 21 additions & 0 deletions examples/sdk/hello_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pdl.pdl import exec_dict

hello = {
"document": [
"Hello,",
{
"model": "watsonx/ibm/granite-20b-code-instruct",
"parameters": {"stop": ["!"], "include_stop_sequence": True},
},
"\n",
]
}


def main():
result = exec_dict(hello)
print(result)


if __name__ == "__main__":
main()
10 changes: 10 additions & 0 deletions examples/sdk/hello_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from pdl.pdl import exec_file


def main():
result = exec_file("./hello.pdl")
print(result)


if __name__ == "__main__":
main()
26 changes: 26 additions & 0 deletions examples/sdk/hello_prog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from pdl.pdl import exec_program
from pdl.pdl_ast import DocumentBlock, LitellmModelBlock, LitellmParameters, Program

hello = Program(
DocumentBlock(
document=[
"Hello,",
LitellmModelBlock(
model="watsonx/ibm/granite-20b-code-instruct",
parameters=LitellmParameters(
stop=["!"], include_stop_sequence=True # pyright: ignore
),
),
"\n",
]
)
)


def main():
result = exec_program(hello)
print(result)


if __name__ == "__main__":
main()
21 changes: 21 additions & 0 deletions examples/sdk/hello_str.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pdl.pdl import exec_str

HELLO = """
document:
- Hello,
- model: watsonx/ibm/granite-20b-code-instruct
parameters:
stop:
- '!'
include_stop_sequence: true
- "\n"
"""


def main():
result = exec_str(HELLO)
print(result)


if __name__ == "__main__":
main()
116 changes: 115 additions & 1 deletion pdl/pdl.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,125 @@
import argparse
import json
from typing import Any, Optional, TypedDict

import yaml
from pydantic.json_schema import models_json_schema

from . import pdl_interpreter
from .pdl_ast import PdlBlock, PdlBlocks, Program
from .pdl_ast import (
LocationType,
PdlBlock,
PdlBlocks,
Program,
RoleType,
ScopeType,
empty_block_location,
)
from .pdl_interpreter import InterpreterState, process_prog
from .pdl_parser import parse_file, parse_str


class InterpreterConfig(TypedDict, total=False):
"""Configuration parameters of the PDL interpreter."""

yield_output: bool
"""Print the program messages during the execution.
"""
batch: int
"""Execution type:
- 0: streaming
- 1: non-streaming
"""
role: RoleType
"""Default role.
"""


def exec_program(
prog: Program,
config: Optional[InterpreterConfig] = None,
scope: Optional[ScopeType] = None,
loc: Optional[LocationType] = None,
):
"""Execute a PDL program given as a value of type `pdl.pdl_ast.Program`.
Args:
prog: Program to execute.
config: Interpreter configuration. Defaults to None.
scope: Environment defining the initial variables in scope to execute the program. Defaults to None.
loc: Source code location mapping. Defaults to None.
Returns:
Return the final result.
"""
config = config or {}
state = InterpreterState(**config)
scope = scope or {}
loc = loc or empty_block_location
result = process_prog(state, scope, prog, loc)
return result


def exec_dict(
prog: dict[str, Any],
config: Optional[InterpreterConfig] = None,
scope: Optional[ScopeType] = None,
loc: Optional[LocationType] = None,
):
"""Execute a PDL program given as a dictionary.
Args:
prog: Program to execute.
config: Interpreter configuration. Defaults to None.
scope: Environment defining the initial variables in scope to execute the program. Defaults to None.
loc: Source code location mapping. Defaults to None.
Returns:
Return the final result.
"""
program = Program.model_validate(prog)
result = exec_program(program, config, scope, loc)
return result


def exec_str(
prog: str,
config: Optional[InterpreterConfig] = None,
scope: Optional[ScopeType] = None,
):
"""Execute a PDL program given as YAML string.
Args:
prog: Program to execute.
config: Interpreter configuration. Defaults to None.
scope: Environment defining the initial variables in scope to execute the program. Defaults to None.
Returns:
Return the final result.
"""
program, loc = parse_str(prog)
result = exec_program(program, config, scope, loc)
return result


def exec_file(
prog: str,
config: Optional[InterpreterConfig] = None,
scope: Optional[ScopeType] = None,
):
"""Execute a PDL program given as YAML file.
Args:
prog: Program to execute.
config: Interpreter configuration. Defaults to None.
scope: Environment defining the initial variables in scope to execute the program. Defaults to None.
Returns:
Return the final result.
"""
program, loc = parse_file(prog)
result = exec_program(program, config, scope, loc)
return result


def main():
Expand Down
129 changes: 69 additions & 60 deletions pdl/pdl_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from .pdl_dumper import block_to_dict, dump_yaml
from .pdl_llms import BamModel, LitellmModel, WatsonxModel
from .pdl_location_utils import append, get_loc_string
from .pdl_parser import PDLParseError, parse_program
from .pdl_parser import PDLParseError, parse_file
from .pdl_scheduler import ModelCallMessage, OutputMessage, YieldMessage, schedule
from .pdl_schema_validator import type_check_args, type_check_spec

Expand Down Expand Up @@ -105,9 +105,8 @@ def generate(
if log_file is None:
log_file = "log.txt"
try:
prog, line_table = parse_program(pdl_file)
prog, loc = parse_file(pdl_file)
state = InterpreterState(yield_output=True)
loc = LocationType(path=[], file=pdl_file, table=line_table)
_, _, _, trace = process_prog(state, initial_scope, prog, loc)
with open(log_file, "w", encoding="utf-8") as log_fp:
for line in state.log:
Expand Down Expand Up @@ -515,61 +514,10 @@ def step_block_body(
result = closure
background = []
trace = closure.model_copy(update={})
case CallBlock(call=f):
result = None
background = []
args, errors = process_expr(scope, block.args, append(loc, "args"))
if len(errors) != 0:
trace = handle_error(
block, append(loc, "args"), None, errors, block.model_copy()
)
closure_expr, errors = process_expr(scope, block.call, append(loc, "call"))
if len(errors) != 0:
trace = handle_error(
block, append(loc, "call"), None, errors, block.model_copy()
)
closure = get_var(closure_expr, scope)
if closure is None:
trace = handle_error(
block,
append(loc, "call"),
f"Function is undefined: {f}",
[],
block.model_copy(),
)
else:
argsloc = append(loc, "args")
type_errors = type_check_args(args, closure.function, argsloc)
if len(type_errors) > 0:
trace = handle_error(
block,
argsloc,
f"Type errors during function call to {f}",
type_errors,
block.model_copy(),
)
else:
f_body = closure.returns
f_scope = closure.scope | {"context": scope["context"]} | args
funloc = LocationType(
file=closure.location.file,
path=closure.location.path + ["return"],
table=loc.table,
)
result, background, _, f_trace = yield from step_blocks(
IterationType.SEQUENCE, state, f_scope, f_body, funloc
)
trace = block.model_copy(update={"trace": f_trace})
if closure.spec is not None:
errors = type_check_spec(result, closure.spec, funloc)
if len(errors) > 0:
trace = handle_error(
block,
loc,
f"Type errors in result of function call to {f}",
errors,
trace,
)
case CallBlock():
result, background, scope, trace = yield from step_call(
state, scope, block, loc
)
case EmptyBlock():
result = ""
background = []
Expand Down Expand Up @@ -1074,6 +1022,68 @@ def call_python(code: str, scope: dict) -> Any:
return result


def step_call(
state: InterpreterState, scope: ScopeType, block: CallBlock, loc: LocationType
) -> Generator[
YieldMessage, Any, tuple[Any, Messages, ScopeType, CallBlock | ErrorBlock]
]:
result = None
background: Messages = []
args, errors = process_expr(scope, block.args, append(loc, "args"))
if len(errors) != 0:
trace = handle_error(
block, append(loc, "args"), None, errors, block.model_copy()
)
closure_expr, errors = process_expr(scope, block.call, append(loc, "call"))
if len(errors) != 0:
trace = handle_error(
block, append(loc, "call"), None, errors, block.model_copy()
)
closure = get_var(closure_expr, scope)
if closure is None:
trace = handle_error(
block,
append(loc, "call"),
f"Function is undefined: {block.call}",
[],
block.model_copy(),
)
else:
argsloc = append(loc, "args")
type_errors = type_check_args(args, closure.function, argsloc)
if len(type_errors) > 0:
trace = handle_error(
block,
argsloc,
f"Type errors during function call to {closure_expr}",
type_errors,
block.model_copy(),
)
else:
f_body = closure.returns
f_scope = closure.scope | {"context": scope["context"]} | args
funloc = LocationType(
file=closure.location.file,
path=closure.location.path + ["return"],
table=loc.table,
)
result, background, _, f_trace = yield from step_blocks(
IterationType.SEQUENCE, state, f_scope, f_body, funloc
)
trace = block.model_copy(update={"trace": f_trace})
if closure.spec is not None:
errors = type_check_spec(result, closure.spec, funloc)
if len(errors) > 0:
trace = handle_error(
block,
loc,
f"Type errors in result of function call to {closure_expr}",
errors,
trace,
)
return result, background, scope, trace


def process_input(
state: InterpreterState, scope: ScopeType, block: ReadBlock, loc: LocationType
) -> tuple[str, Messages, ScopeType, ReadBlock | ErrorBlock]:
Expand Down Expand Up @@ -1121,8 +1131,7 @@ def step_include(
YieldMessage, Any, tuple[Any, Messages, ScopeType, IncludeBlock | ErrorBlock]
]:
try:
prog, line_table = parse_program(block.include)
newloc = LocationType(file=block.include, path=[], table=line_table)
prog, newloc = parse_file(block.include)
result, background, scope, trace = yield from step_block(
state, scope, prog.root, newloc
)
Expand Down
4 changes: 2 additions & 2 deletions pdl/pdl_location_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ def get_paths(
return ret


def get_line_map(file) -> dict[str, int]:
def get_line_map(prog: str) -> dict[str, int]:
indentation = []
fields = []
is_array_item = []
for line in file.readlines(): # line numbers are off by one
for line in prog.split("\n"): # line numbers are off by one
fields.append(
line.strip().split(":")[0].replace("-", "").strip()
if line.find(":") != -1
Expand Down
Loading

0 comments on commit c900d02

Please sign in to comment.