Skip to content

Commit

Permalink
Add template field tests to AWS operators part1 (apache#42183)
Browse files Browse the repository at this point in the history
* adding template_fields tests in operators
  • Loading branch information
gopidesupavan authored Sep 12, 2024
1 parent b7a4e4d commit f9d0315
Show file tree
Hide file tree
Showing 17 changed files with 400 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tests/providers/amazon/aws/operators/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from airflow.utils import timezone
from airflow.utils.timezone import datetime
from airflow.utils.types import DagRunType
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

TEST_DAG_ID = "unit_tests"
DEFAULT_DATE = datetime(2018, 1, 1)
Expand Down Expand Up @@ -397,3 +398,6 @@ def mock_get_table_metadata(CatalogName, DatabaseName, TableName):
run_facets={"externalQuery": ExternalQueryRunFacet(externalQueryId="12345", source="awsathena")},
)
assert op.get_openlineage_facets_on_complete(None) == expected_lineage

def test_template_fields(self):
validate_template_fields(self.athena)
27 changes: 27 additions & 0 deletions tests/providers/amazon/aws/operators/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
BedrockInvokeModelOperator,
BedrockRaGOperator,
)
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
Expand Down Expand Up @@ -176,6 +177,9 @@ def test_ensure_unique_job_name(self, _, side_effect, ensure_unique_name, mock_c
bedrock_hook.get_waiter.assert_not_called()
self.operator.defer.assert_not_called()

def test_template_fields(self):
validate_template_fields(self.operator)


class TestBedrockCreateProvisionedModelThroughputOperator:
MODEL_ARN = "testProvisionedModelArn"
Expand Down Expand Up @@ -222,6 +226,9 @@ def test_provisioned_model_wait_combinations(
assert bedrock_hook.get_waiter.call_count == wait_for_completion
assert self.operator.defer.call_count == deferrable

def test_template_fields(self):
validate_template_fields(self.operator)


class TestBedrockCreateKnowledgeBaseOperator:
KNOWLEDGE_BASE_ID = "knowledge_base_id"
Expand Down Expand Up @@ -288,6 +295,9 @@ def test_returns_id(self, mock_conn):

assert result == self.KNOWLEDGE_BASE_ID

def test_template_fields(self):
validate_template_fields(self.operator)


class TestBedrockCreateDataSourceOperator:
DATA_SOURCE_ID = "data_source_id"
Expand Down Expand Up @@ -317,6 +327,9 @@ def test_id_returned(self, mock_conn):

assert result == self.DATA_SOURCE_ID

def test_template_fields(self):
validate_template_fields(self.operator)


class TestBedrockIngestDataOperator:
INGESTION_JOB_ID = "ingestion_job_id"
Expand Down Expand Up @@ -348,6 +361,9 @@ def test_id_returned(self, mock_conn):

assert result == self.INGESTION_JOB_ID

def test_template_fields(self):
validate_template_fields(self.operator)


class TestBedrockRaGOperator:
VECTOR_SEARCH_CONFIG = {"filter": {"equals": {"key": "some key", "value": "some value"}}}
Expand Down Expand Up @@ -520,3 +536,14 @@ def test_external_sources_build_rag_config(self, prompt_template):
**expected_config_without_template,
**expected_config_template,
}

def test_template_fields(self):
op = BedrockRaGOperator(
task_id="test_rag",
input="some text prompt",
source_type="EXTERNAL_SOURCES",
model_arn=self.MODEL_ARN,
knowledge_base_id=self.KNOWLEDGE_BASE_ID,
vector_search_config=self.VECTOR_SEARCH_CONFIG,
)
validate_template_fields(op)
28 changes: 28 additions & 0 deletions tests/providers/amazon/aws/operators/test_cloud_formation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
CloudFormationDeleteStackOperator,
)
from airflow.utils import timezone
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

DEFAULT_DATE = timezone.datetime(2019, 1, 1)
DEFAULT_ARGS = {"owner": "airflow", "start_date": DEFAULT_DATE}
Expand Down Expand Up @@ -87,6 +88,20 @@ def test_create_stack(self, mocked_hook_client):
StackName=stack_name, TemplateBody=template_body, TimeoutInMinutes=timeout
)

def test_template_fields(self):
op = CloudFormationCreateStackOperator(
task_id="cf_create_stack_init",
stack_name="fake-stack",
cloudformation_parameters={},
# Generic hooks parameters
aws_conn_id="fake-conn-id",
region_name="eu-west-1",
verify=True,
botocore_config={"read_timeout": 42},
)

validate_template_fields(op)


class TestCloudFormationDeleteStackOperator:
def test_init(self):
Expand Down Expand Up @@ -125,3 +140,16 @@ def test_delete_stack(self, mocked_hook_client):
operator.execute(MagicMock())

mocked_hook_client.delete_stack.assert_any_call(StackName=stack_name)

def test_template_fields(self):
op = CloudFormationDeleteStackOperator(
task_id="cf_delete_stack_init",
stack_name="fake-stack",
# Generic hooks parameters
aws_conn_id="fake-conn-id",
region_name="us-east-1",
verify=False,
botocore_config={"read_timeout": 42},
)

validate_template_fields(op)
7 changes: 7 additions & 0 deletions tests/providers/amazon/aws/operators/test_comprehend.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ComprehendStartPiiEntitiesDetectionJobOperator,
)
from airflow.utils.types import NOTSET
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
Expand Down Expand Up @@ -163,6 +164,9 @@ def test_start_pii_entities_detection_job_wait_combinations(
assert comprehend_hook.get_waiter.call_count == wait_for_completion
assert self.operator.defer.call_count == deferrable

def test_template_fields(self):
validate_template_fields(self.operator)


class TestComprehendCreateDocumentClassifierOperator:
CLASSIFIER_ARN = (
Expand Down Expand Up @@ -259,3 +263,6 @@ def test_create_document_classifier_wait_combinations(
assert response == self.CLASSIFIER_ARN
assert comprehend_hook.get_waiter.call_count == wait_for_completion
assert self.operator.defer.call_count == deferrable

def test_template_fields(self):
validate_template_fields(self.operator)
5 changes: 5 additions & 0 deletions tests/providers/amazon/aws/operators/test_datasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from airflow.utils import timezone
from airflow.utils.timezone import datetime
from airflow.utils.types import DagRunType
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

TEST_DAG_ID = "unit_tests"
DEFAULT_DATE = datetime(2018, 1, 1)
Expand Down Expand Up @@ -363,6 +364,10 @@ def test_return_value(self, mock_get_conn, session, clean_dags_and_dagruns):
# ### Check mocks:
mock_get_conn.assert_called()

def test_template_fields(self, mock_get_conn):
self.set_up_operator()
validate_template_fields(self.datasync)


@mock_aws
@mock.patch.object(DataSyncHook, "get_conn")
Expand Down
64 changes: 64 additions & 0 deletions tests/providers/amazon/aws/operators/test_dms.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)
from airflow.utils import timezone
from airflow.utils.types import DagRunType
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields

TASK_ARN = "test_arn"

Expand Down Expand Up @@ -121,6 +122,18 @@ def test_create_task_with_migration_type(

assert dms_hook.get_task_status(TASK_ARN) == "ready"

def test_template_fields(self):
op = DmsCreateTaskOperator(
task_id="create_task",
**self.TASK_DATA,
aws_conn_id="fake-conn-id",
region_name="ca-west-1",
verify=True,
botocore_config={"read_timeout": 42},
)

validate_template_fields(op)


class TestDmsDeleteTaskOperator:
TASK_DATA = {
Expand Down Expand Up @@ -174,6 +187,19 @@ def test_delete_task(

assert dms_hook.get_task_status(TASK_ARN) == "deleting"

def test_template_fields(self):
op = DmsDeleteTaskOperator(
task_id="delete_task",
replication_task_arn=TASK_ARN,
# Generic hooks parameters
aws_conn_id="fake-conn-id",
region_name="us-east-1",
verify=False,
botocore_config={"read_timeout": 42},
)

validate_template_fields(op)


class TestDmsDescribeTasksOperator:
FILTER = {"Name": "replication-task-arn", "Values": [TASK_ARN]}
Expand Down Expand Up @@ -267,6 +293,18 @@ def test_describe_tasks_return_value(self, mock_conn, mock_describe_replication_
assert marker is None
assert response == self.MOCK_RESPONSE

def test_template_fields(self):
op = DmsDescribeTasksOperator(
task_id="describe_tasks",
describe_tasks_kwargs={"Filters": [self.FILTER]},
# Generic hooks parameters
aws_conn_id="fake-conn-id",
region_name="eu-west-2",
verify="/foo/bar/spam.egg",
botocore_config={"read_timeout": 42},
)
validate_template_fields(op)


class TestDmsStartTaskOperator:
TASK_DATA = {
Expand Down Expand Up @@ -324,6 +362,19 @@ def test_start_task(

assert dms_hook.get_task_status(TASK_ARN) == "starting"

def test_template_fields(self):
op = DmsStartTaskOperator(
task_id="start_task",
replication_task_arn=TASK_ARN,
# Generic hooks parameters
aws_conn_id="fake-conn-id",
region_name="us-west-1",
verify=False,
botocore_config={"read_timeout": 42},
)

validate_template_fields(op)


class TestDmsStopTaskOperator:
TASK_DATA = {
Expand Down Expand Up @@ -376,3 +427,16 @@ def test_stop_task(
mock_stop_replication_task.assert_called_once_with(replication_task_arn=TASK_ARN)

assert dms_hook.get_task_status(TASK_ARN) == "stopping"

def test_template_fields(self):
op = DmsStopTaskOperator(
task_id="stop_task",
replication_task_arn=TASK_ARN,
# Generic hooks parameters
aws_conn_id="fake-conn-id",
region_name="eu-west-1",
verify=True,
botocore_config={"read_timeout": 42},
)

validate_template_fields(op)
51 changes: 51 additions & 0 deletions tests/providers/amazon/aws/operators/test_ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
EC2StopInstanceOperator,
EC2TerminateInstanceOperator,
)
from tests.providers.amazon.aws.utils.test_template_fields import validate_template_fields


class BaseEc2TestClass:
Expand Down Expand Up @@ -87,6 +88,13 @@ def test_create_multiple_instances(self):
for id in instance_ids:
assert ec2_hook.get_instance_state(instance_id=id) == "running"

def test_template_fields(self):
ec2_operator = EC2CreateInstanceOperator(
task_id="test_create_instance",
image_id="test_image_id",
)
validate_template_fields(ec2_operator)


class TestEC2TerminateInstanceOperator(BaseEc2TestClass):
def test_init(self):
Expand Down Expand Up @@ -140,6 +148,13 @@ def test_terminate_multiple_instances(self):
for id in instance_ids:
assert ec2_hook.get_instance_state(instance_id=id) == "terminated"

def test_template_fields(self):
ec2_operator = EC2TerminateInstanceOperator(
task_id="test_terminate_instance",
instance_ids="test_image_id",
)
validate_template_fields(ec2_operator)


class TestEC2StartInstanceOperator(BaseEc2TestClass):
def test_init(self):
Expand Down Expand Up @@ -175,6 +190,17 @@ def test_start_instance(self):
# assert instance state is running
assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running"

def test_template_fields(self):
ec2_operator = EC2StartInstanceOperator(
task_id="task_test",
instance_id="i-123abc",
aws_conn_id="aws_conn_test",
region_name="region-test",
check_interval=3,
)

validate_template_fields(ec2_operator)


class TestEC2StopInstanceOperator(BaseEc2TestClass):
def test_init(self):
Expand Down Expand Up @@ -210,6 +236,17 @@ def test_stop_instance(self):
# assert instance state is running
assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "stopped"

def test_template_fields(self):
ec2_operator = EC2StopInstanceOperator(
task_id="task_test",
instance_id="i-123abc",
aws_conn_id="aws_conn_test",
region_name="region-test",
check_interval=3,
)

validate_template_fields(ec2_operator)


class TestEC2HibernateInstanceOperator(BaseEc2TestClass):
def test_init(self):
Expand Down Expand Up @@ -322,6 +359,13 @@ def test_cannot_hibernate_some_instances(self):
for id in instance_ids:
assert ec2_hook.get_instance_state(instance_id=id) == "running"

def test_template_fields(self):
ec2_operator = EC2HibernateInstanceOperator(
task_id="task_test",
instance_ids="i-123abc",
)
validate_template_fields(ec2_operator)


class TestEC2RebootInstanceOperator(BaseEc2TestClass):
def test_init(self):
Expand Down Expand Up @@ -372,3 +416,10 @@ def test_reboot_multiple_instances(self):
terminate_instance.execute(None)
for id in instance_ids:
assert ec2_hook.get_instance_state(instance_id=id) == "running"

def test_template_fields(self):
ec2_operator = EC2RebootInstanceOperator(
task_id="task_test",
instance_ids="i-123abc",
)
validate_template_fields(ec2_operator)
Loading

0 comments on commit f9d0315

Please sign in to comment.