Skip to content

Commit

Permalink
feat: replace naive evaluation with lark
Browse files Browse the repository at this point in the history
* fixing expression by enforcing @ expression
* adding array support
  • Loading branch information
LeonardHd committed Nov 29, 2023
1 parent 1c81e43 commit e043b07
Show file tree
Hide file tree
Showing 11 changed files with 899 additions and 414 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class ExpressionEvaluationValueError(Exception):
    """Raised when a value seen during expression evaluation does not have the expected form."""
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,25 @@
import inspect
from typing import Callable

from lark import Token, Transformer, Tree

from lark import Discard, Token, Transformer

from azure_data_factory_testing_framework.exceptions.activity_not_found_error import ActivityNotFoundError
from azure_data_factory_testing_framework.exceptions.dataset_parameter_not_found_error import (
DatasetParameterNotFoundError,
)
from azure_data_factory_testing_framework.exceptions.expression_evaluation_value_error import (
ExpressionEvaluationValueError,
)
from azure_data_factory_testing_framework.exceptions.expression_parameter_not_found_error import (
ExpressionParameterNotFoundError,
)
from azure_data_factory_testing_framework.exceptions.linked_service_parameter_not_found_error import (
LinkedServiceParameterNotFoundError,
)
from azure_data_factory_testing_framework.exceptions.state_iteration_item_not_set_error import (
StateIterationItemNotSetError,
)
from azure_data_factory_testing_framework.exceptions.variable_not_found_error import VariableNotFoundError
from azure_data_factory_testing_framework.functions.functions_repository import FunctionsRepository
from azure_data_factory_testing_framework.state.pipeline_run_state import PipelineRunState
from azure_data_factory_testing_framework.state.run_parameter import RunParameter
Expand All @@ -16,17 +33,72 @@ def __init__(self, state: PipelineRunState) -> None:
self.state: PipelineRunState = state
super().__init__()

def literal(self, value: list[Token]):
def LITERAL_LETTER(self, token: Token):  # noqa: N802
    """Return the text carried by a LITERAL_LETTER token as a plain ``str``."""
    text = token.value
    return str(text)

def LITERAL_INT(self, token: Token):  # noqa: N802
    """Convert the text of a LITERAL_INT token to an ``int``."""
    digits = token.value
    return int(digits)

def LITERAL_FLOAT(self, token: Token):  # noqa: N802
    """Convert the text of a LITERAL_FLOAT token to a ``float``."""
    number_text = token.value
    return float(number_text)

def LITERAL_SINGLE_QUOTED_STRING(self, token: Token):  # noqa: N802
    """Return the token text unchanged (quotes are NOT stripped here) as a plain ``str``."""
    raw_text = token.value
    return str(raw_text)

def LITERAL_BOOLEAN(self, token: Token):  # noqa: N802
    """Convert the text of a LITERAL_BOOLEAN token to a ``bool``.

    The token value is text, so ``bool(token.value)`` would be ``True`` for any
    non-empty string, including ``"false"``. The text is compared against
    ``"true"`` (case-insensitively) instead.
    """
    return str(token.value).strip().lower() == "true"

def LITERAL_NULL(self, token: Token):  # noqa: N802
    """Map a LITERAL_NULL token to ``None``; the token itself is not inspected."""
    return

def literal_evaluation(self, value: list[Token, str, int, float, bool]):
    """Return the single child of a literal evaluation.

    Raises:
        ExpressionEvaluationValueError: if ``value`` does not hold exactly one item.
    """
    if len(value) == 1:
        return value[0]
    raise ExpressionEvaluationValueError()

def parameter_name(self, value: list[Token]):
    """Return the ``.value`` of the first child token."""
    name_token = value[0]
    return name_token.value
def EXPRESSION_NULL(self, token: Token):  # noqa: N802
    """Map an EXPRESSION_NULL token to ``None``; the token itself is not inspected."""
    return

def EXPRESSION_STRING(self, token: Token):  # noqa: N802
    """Return the content of a single-quoted expression string.

    The first and last characters (the delimiting quotes) are removed, then
    every doubled single quote (``''``) in the remaining text is collapsed to
    one. The delimiters must be stripped *before* unescaping: doing it the
    other way round lets a delimiter pair up with an adjacent escaped quote,
    so ``''''`` (a string holding one quote) would come out empty.
    """
    inner = str(token.value)[1:-1]  # drop the delimiting quotes first
    return inner.replace("''", "'")  # then unescape doubled single quotes

def pipeline_reference(self, value: list[Token]):
pipeline_reference_property = value[0]
pipeline_reference_property_parameter = value[1]
def EXPRESSION_INTEGER(self, token: Token):  # noqa: N802
    """Convert the text of an EXPRESSION_INTEGER token to an ``int``."""
    digits = token.value
    return int(digits)

def EXPRESSION_FLOAT(self, token: Token):  # noqa: N802
    """Convert the text of an EXPRESSION_FLOAT token to a ``float``."""
    number_text = token.value
    return float(number_text)

def EXPRESSION_BOOLEAN(self, token: Token):  # noqa: N802
    """Convert the text of an EXPRESSION_BOOLEAN token to a ``bool``.

    The token value is text, so ``bool(token.value)`` would be ``True`` for any
    non-empty string, including ``"false"``. The text is compared against
    ``"true"`` (case-insensitively) instead.
    """
    return str(token.value).strip().lower() == "true"

def EXPRESSION_WS(self, token: Token):  # noqa: N802
    """Return lark's ``Discard`` marker for an EXPRESSION_WS token.

    The token content is ignored, so whitespace inside expressions is discarded.
    """
    marker = Discard
    return marker

def EXPRESSION_ARRAY_INDEX(self, token: Token):  # noqa: N802
    """Turn the token's text into an integer index, in place, and return the token.

    The first and last characters of the text are dropped (presumably the
    square brackets around the index) and the remainder is parsed with ``int``.
    """
    inner_text = token.value[1:-1]
    token.value = int(inner_text)
    return token

def expression_pipeline_reference(self, value: list[Token, str, int, float, bool]):
if not isinstance(value[0], Token):
raise ExpressionEvaluationValueError()

if not isinstance(value[1], Token):
raise ExpressionEvaluationValueError()

pipeline_reference_property: Token = value[0]
pipeline_reference_property_parameter: Token = value[1]

if not (
pipeline_reference_property.type == "EXPRESSION_PIPELINE_PROPERTY"
and pipeline_reference_property_parameter.type == "EXPRESSION_PARAMETER_NAME"
):
raise ExpressionEvaluationValueError()

# TODO: need to improve this
global_parameters: list[RunParameter] = list(
filter(lambda p: p.type == RunParameterType.Global, self.state.parameters)
)
Expand All @@ -39,108 +111,86 @@ def pipeline_reference(self, value: list[Token]):
first = list(filter(lambda p: p.name == pipeline_reference_property_parameter, global_parameters))

if len(first) == 0:
raise Exception("Parameter not found")
raise ExpressionParameterNotFoundError(pipeline_reference_property_parameter)

return first[0].value

# Look up a variable by name in self.state.variables and return its value.
# The name is read from the first child token and its first/last characters
# (the quotes) are stripped before the lookup.
# NOTE(review): two `def` headers and two `raise` statements appear back to back
# in this span; it reads like the removed and added sides of a diff rendered
# together. Confirm which lines are live before editing.
def variable_reference(self, value: list[Token]):
def expression_variable_reference(self, value: list[Token, str, int, float, bool]):
variable_name = value[0].value
variable_name = variable_name[1:-1] # remove quotes

# variable_property = value[1].value

# Keep only the variables whose name matches; an empty result means "not found".
variable = list(filter(lambda p: p.name == variable_name, self.state.variables))

if len(variable) == 0:
raise Exception("Variable not found")
raise VariableNotFoundError(variable_name)
return variable[0].value

# Resolve a dataset reference by name and return the matching entry's value.
# NOTE(review): this span mixes two lookups (self.state.datasets vs. parameters
# filtered by RunParameterType.Dataset), two `def` headers and two raise/return
# pairs; it reads like removed and added diff lines rendered together.
# Confirm which lines are live before editing.
def dataset_reference(self, value: list[Token]):
def expression_dataset_reference(self, value: list[Token, str, int, float, bool]):
dataset_name = value[0].value
dataset_property = value[1].value

dataset = list(filter(lambda p: p.name == dataset_name, self.state.datasets))
dataset_name = dataset_name[1:-1] # remove quotes
# Narrow the run parameters to dataset-typed ones, then match on name.
datasets = list(filter(lambda p: p.type == RunParameterType.Dataset, self.state.parameters))
dataset = list(filter(lambda p: p.name == dataset_name, datasets))

if len(dataset) == 0:
raise Exception("Dataset not found")
return dataset[0].value[dataset_property]
raise DatasetParameterNotFoundError(dataset_name)
return dataset[0].value

# Resolve a linked-service reference by name and return the matching entry's value.
# NOTE(review): this span mixes two lookups (self.state.linked_services vs.
# parameters filtered by RunParameterType.LinkedService), two `def` headers and
# two raise/return pairs; it reads like removed and added diff lines rendered
# together. Confirm which lines are live before editing.
def linked_service_reference(self, value: list[Token]):
def expression_linked_service_reference(self, value: list[Token, str, int, float, bool]):
linked_service_name = value[0].value
linked_service_property = value[1].value

linked_service = list(filter(lambda p: p.name == linked_service_name, self.state.linked_services))
linked_service_name = linked_service_name[1:-1] # remove quotes
# Narrow the run parameters to linked-service-typed ones, then match on name.
linked_services = list(filter(lambda p: p.type == RunParameterType.LinkedService, self.state.parameters))
linked_service = list(filter(lambda p: p.name == linked_service_name, linked_services))

if len(linked_service) == 0:
raise Exception("Linked service not found")
return linked_service[0].value[linked_service_property]
raise LinkedServiceParameterNotFoundError(linked_service_name)
return linked_service[0].value

# Resolve an activity reference: strip the quotes from the activity name, fetch
# the scoped activity result from the run state, then index into it with the
# remaining child values.
# NOTE(review): this span contains two `def` headers for the activity lookup, a
# stray `def item_reference` header in the middle, and both a generic
# `raise Exception` and a typed raise; it reads like removed and added diff
# lines rendered together. Confirm which lines are live before editing.
def activity_reference(self, value: list[Token]):
def expression_activity_reference(self, value: list[Token, str, int, float, bool]):
activity_name = value[0].value
activity_name = activity_name[1:-1] # remove quotes
# activity_property = value[1].value
activity_property = value[1]
# activity_property_parameter = value[2].value
activity_property_parameter = value[2]
# Everything from index 2 onward is treated as a chain of fields to index into.
property_fields = value[2:]

activity = self.state.try_get_scoped_activity_result_by_name(activity_name)
if activity is None:
raise Exception("Activity not found")
return activity[activity_property][activity_property_parameter]
raise ActivityNotFoundError(activity_name)

def item_reference(self, value: list[Token]):
# Walk the field chain one level at a time, using each field's `.value` as the key.
activity_property_parameter = activity[activity_property]
for field in property_fields:
field_value = field.value
activity_property_parameter = activity_property_parameter[field_value]
return activity_property_parameter

# Return the current iteration item from the run state; `value` is not used.
# NOTE(review): two consecutive `raise` statements follow the None check (a
# generic Exception and StateIterationItemNotSetError); this looks like the old
# and new side of a diff shown together. Confirm which one is live.
def expression_item_reference(self, value: list[Token, str, int, float, bool]):
item = self.state.iteration_item
if item is None:
raise Exception("Item not found")
raise StateIterationItemNotSetError()
return item

def boolean(self, value: list[Token]):  # noqa: ANN401, ANN201, ANN001
    """Convert the first child token's text to a ``bool``.

    The token value is text, so ``bool(...)`` on it would be ``True`` for any
    non-empty string, including ``"false"``. The text is compared against
    ``"true"`` (case-insensitively) instead.
    """
    return str(value[0].value).strip().lower() == "true"

def integer(self, value: list[Token]):
    """Convert the first child token's text to an ``int``."""
    first_token = value[0]
    return int(first_token.value)

def float(self, value: list[Token]):  # noqa: A003
    """Convert the first child to a floating-point number.

    ``value`` is indexed like a list, so it is annotated as one, matching the
    sibling callbacks. The builtin is reached through ``builtins`` because this
    callback's own name shadows ``float`` whenever the function is defined at
    module scope, which would make the inner call hit the callback itself.
    """
    import builtins

    return builtins.float(value[0])

def single_quoted_string(self, value: list[Token]):
    """Return the first child token's text without its first and last character (the quotes)."""
    quoted: str = value[0].value
    return quoted[1:-1]

def string(self, value: list[Token]):
    """Return the first child: its ``.value`` when it is a ``Token``, otherwise the child itself."""
    first = value[0]
    return first.value if isinstance(first, Token) else first

def function_parameters(self, value):
    """Collect function-call parameters into a list.

    Each ``Tree`` child is run through ``self.transform``; every other child is
    passed through unchanged.

    Raises:
        Exception: if ``value`` is not a list.
    """
    if not isinstance(value, list):
        raise Exception("Unexpected value type")
    return [self.transform(child) if isinstance(child, Tree) else child for child in value]

def parameter(self, value):
    """Return the second child (index 1) of the parameter node."""
    second_child = value[1]
    return second_child

def expression(self, value):
    """Return the first child of the expression node."""
    # TODO: array support is still missing here; only the first child is used.
    first_child = value[0]
    return first_child

def function_call(self, expression):
fn = expression[0]
fn_parameters = expression[1]
array_index = expression[2] if len(expression) > 2 else None

def expression_function_parameters(self, values: list[Token, str, int, float, bool]):
    """Pass the already-transformed parameter list through unchanged (same object)."""
    parameters = values
    return parameters

def expression_parameter(self, values: list[Token, str, int, float, bool]):
    """Return the single child of an expression parameter.

    Raises:
        ExpressionEvaluationValueError: if ``values`` does not hold exactly one item.
    """
    if len(values) == 1:
        return values[0]
    raise ExpressionEvaluationValueError

def expression_evaluation(self, values: list[Token, str, int, float, bool, list]):
    """Apply a chain of array indices to an evaluated value.

    ``values[0]`` is the evaluated value and ``values[1]`` is an iterable of
    index entries. ``None`` entries are skipped; every other entry must be a
    ``Token`` of type ``EXPRESSION_ARRAY_INDEX`` whose ``.value`` is used to
    index into the running result.

    Raises:
        ExpressionEvaluationValueError: if an entry is neither ``None`` nor an
            ``EXPRESSION_ARRAY_INDEX`` token.
    """
    result, index_entries = values[0], values[1]
    for entry in index_entries:
        if entry is None:
            continue
        is_index_token = isinstance(entry, Token) and entry.type == "EXPRESSION_ARRAY_INDEX"
        if not is_index_token:
            raise ExpressionEvaluationValueError()
        result = result[entry.value]
    return result

def expression_array_indices(self, values: list[Token, str, int, float, bool]):
    """Pass the collected array-index children through unchanged (same object)."""
    indices = values
    return indices

def expression_function_call(self, values: list[Token, str, int, float, bool]):
fn = values[0]
fn_parameters = values[1]
function: Callable = FunctionsRepository.functions.get(fn.value)

pos_or_keyword_parameters = []
Expand All @@ -156,7 +206,4 @@ def function_call(self, expression):
var_positional_values = fn_parameters[len(pos_or_keyword_parameters) :] # should be 0 or 1

result = function(*pos_or_keyword_values, *var_positional_values)
if array_index is not None:
result = result[array_index]

return result
Loading

0 comments on commit e043b07

Please sign in to comment.