From 3044c290896daf2c0dfdda471dfd70da6ab769f6 Mon Sep 17 00:00:00 2001
From: Vitor Bellini
Date: Sat, 13 Jul 2024 12:10:47 -0300
Subject: [PATCH] refac db-connection

---
 fastetl/custom_functions/fast_etl.py         | 20 ++++++----
 .../custom_functions/utils/db_connection.py  | 39 +++++++++----------
 fastetl/hooks/db_to_db_hook.py               |  8 ++--
 fastetl/operators/db_to_db_operator.py       | 26 +++++--------
 4 files changed, 43 insertions(+), 50 deletions(-)

diff --git a/fastetl/custom_functions/fast_etl.py b/fastetl/custom_functions/fast_etl.py
index a23331e..f868de2 100644
--- a/fastetl/custom_functions/fast_etl.py
+++ b/fastetl/custom_functions/fast_etl.py
@@ -258,8 +258,8 @@ def copy_db_to_db(
     """

     # validate connections
-    source = SourceConnection(**source)
-    destination = DestinationConnection(**destination)
+    source = SourceConnection(source)
+    destination = DestinationConnection(destination)

     # create table if not exists in destination db
     if not source.query:
@@ -621,13 +621,17 @@ def _divide_chunks(l, n):
     if copy_table_comments:
         _copy_table_comments(
             source=SourceConnection(
-                conn_id=source_conn_id,
-                schema=source_table_name.split(".")[0],
-                table=source_table_name.split(".")[1],
+                {
+                    "conn_id": source_conn_id,
+                    "schema": source_table_name.split(".")[0],
+                    "table": source_table_name.split(".")[1],
+                }
             ),
             destination=DestinationConnection(
-                conn_id=destination_conn_id,
-                schema=dest_table_name.split(".")[0],
-                table=dest_table_name.split(".")[1],
+                {
+                    "conn_id": destination_conn_id,
+                    "schema": dest_table_name.split(".")[0],
+                    "table": dest_table_name.split(".")[1],
+                }
             ),
         )

diff --git a/fastetl/custom_functions/utils/db_connection.py b/fastetl/custom_functions/utils/db_connection.py
index 02f5fd8..9908258 100644
--- a/fastetl/custom_functions/utils/db_connection.py
+++ b/fastetl/custom_functions/utils/db_connection.py
@@ -86,24 +86,21 @@ class SourceConnection:
         conn_type (str): Connection type/provider.
     """

-    def __init__(
-        self,
-        conn_id: str,
-        schema: str = None,
-        table: str = None,
-        query: str = None,
-    ):
-        if not conn_id:
+    def __init__(self, params: dict):
+        self.conn_id = params.get("conn_id", None)
+        self.schema = params.get("schema", None)
+        self.table = params.get("table", None)
+        self.query = params.get("query", None)
+
+        if not self.conn_id:
             raise ValueError("conn_id argument cannot be empty")
-        if not query and not (schema or table):
+        if not self.query and not (
+            self.schema or self.table
+        ):
             raise ValueError("must provide either schema and table or query")

-        self.conn_id = conn_id
-        self.schema = schema
-        self.table = table
-        self.query = query
-        self.conn_type = get_conn_type(conn_id)
-        conn_values = BaseHook.get_connection(conn_id)
+        self.conn_type = get_conn_type(self.conn_id)
+        conn_values = BaseHook.get_connection(self.conn_id)
         self.conn_database = conn_values.schema


@@ -124,12 +121,12 @@ class DestinationConnection:
         conn_type (str): Connection type/provider.
     """

-    def __init__(self, conn_id: str, schema: str, table: str):
-        self.conn_id = conn_id
-        self.schema = schema
-        self.table = table
-        self.conn_type = get_conn_type(conn_id)
-        conn_values = BaseHook.get_connection(conn_id)
+    def __init__(self, params: dict):
+        self.conn_id = params.get("conn_id", None)
+        self.schema = params.get("schema", None)
+        self.table = params.get("table", None)
+        self.conn_type = get_conn_type(self.conn_id)
+        conn_values = BaseHook.get_connection(self.conn_id)
         self.conn_database = conn_values.schema


diff --git a/fastetl/hooks/db_to_db_hook.py b/fastetl/hooks/db_to_db_hook.py
index 993b670..a22020b 100644
--- a/fastetl/hooks/db_to_db_hook.py
+++ b/fastetl/hooks/db_to_db_hook.py
@@ -51,14 +51,14 @@ def incremental_copy(
         copy_table_comments: bool = False,
     ):
         sync_db_2_db(
-            source_conn_id=self.source.conn_id,
-            destination_conn_id=self.destination.conn_id,
-            source_schema=self.source.schema,
+            source_conn_id=self.source["conn_id"],
+            destination_conn_id=self.destination["conn_id"],
+            source_schema=self.source["schema"],
             source_exc_schema=self.source.get("source_exc_schema", None),
             source_exc_table=self.source.get("source_exc_table", None),
             source_exc_column=self.source.get("source_exc_column", None),
             select_sql=self.source.get("query", None),
-            destination_schema=self.destination.schema,
+            destination_schema=self.destination["schema"],
             increment_schema=self.destination.get("increment_schema", None),
             table=table,
             date_column=date_column,
diff --git a/fastetl/operators/db_to_db_operator.py b/fastetl/operators/db_to_db_operator.py
index 5e1c470..eb985e1 100644
--- a/fastetl/operators/db_to_db_operator.py
+++ b/fastetl/operators/db_to_db_operator.py
@@ -98,7 +98,6 @@ class DbToDbOperator(BaseOperator):

     template_fields = ["source"]

-    @apply_defaults
     def __init__(
         self,
         source: Dict[str, str],
@@ -127,6 +126,15 @@ def __init__(
         self.key_column = key_column
         self.since_datetime = since_datetime
         self.sync_exclusions = sync_exclusions
+
+        # rename if schema_name is present
+        if source.get("schema_name", None):
+            source["schema"] = source.pop("schema_name")
+        if destination.get("schema_name", None):
+            destination["schema"] = destination.pop("schema_name")
+        self.source = source
+        self.destination = destination
+
         # any value that needs to be the same for inlets and outlets
         key = str(random.randint(10000000, 99999999))
         if source.get("om_service", None):
@@ -135,22 +143,6 @@ def __init__(
             self.outlets = [
                 OMEntity(entity=Table, fqn=self._get_fqn(destination), key=key)
             ]
-        # rename if schema_name is present
-        if source.get("schema_name", None):
-            source["schema"] = source.pop("schema_name")
-        if destination.get("schema_name", None):
-            destination["schema"] = destination.pop("schema_name")
-        # filter to keys accepted by DbToDbHook
-        self.source = {
-            key: source[key]
-            for key in ["conn_id", "schema", "table", "query"]
-            if key in source
-        }
-        self.destination = {
-            key: destination[key]
-            for key in ["conn_id", "schema", "table"]
-            if key in destination
-        }

     def _get_fqn(self, data):
         data["database"] = BaseHook.get_connection(data["conn_id"]).schema