From f4337d505b1a785b80f91051075e3310aa857e10 Mon Sep 17 00:00:00 2001
From: Bichitra Kumar Sahoo <32828151+bichitra95@users.noreply.github.com>
Date: Wed, 7 Jun 2023 11:02:05 +0530
Subject: [PATCH 1/5] fix: column_metadata_catalog_column method to default value table_catalog

---
 soda/athena/soda/data_sources/athena_data_source.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/soda/athena/soda/data_sources/athena_data_source.py b/soda/athena/soda/data_sources/athena_data_source.py
index cfd710061..479fd6a64 100644
--- a/soda/athena/soda/data_sources/athena_data_source.py
+++ b/soda/athena/soda/data_sources/athena_data_source.py
@@ -100,10 +100,6 @@ def quote_column(self, column_name: str) -> str:
     def regex_replace_flags(self) -> str:
         return ""
 
-    @staticmethod
-    def column_metadata_catalog_column() -> str:
-        return "table_schema"
-
     def default_casify_table_name(self, identifier: str) -> str:
         return identifier.lower()
 

From 21cabc7c65b8e190b5bc342e96a31bea35545263 Mon Sep 17 00:00:00 2001
From: Bichitra Kumar Sahoo <32828151+bichitra95@users.noreply.github.com>
Date: Mon, 24 Jul 2023 17:37:47 +0530
Subject: [PATCH 2/5] fix databricks datatype cast and support for int type decimal

---
 soda/spark/soda/data_sources/spark_data_source.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/soda/spark/soda/data_sources/spark_data_source.py b/soda/spark/soda/data_sources/spark_data_source.py
index b66b60eae..033f4ed82 100644
--- a/soda/spark/soda/data_sources/spark_data_source.py
+++ b/soda/spark/soda/data_sources/spark_data_source.py
@@ -425,6 +425,7 @@ class SparkDataSource(SparkSQLBase):
 
     def __init__(self, logs: Logs, data_source_name: str, data_source_properties: dict):
         super().__init__(logs, data_source_name, data_source_properties)
+        self.NUMERIC_TYPES_FOR_PROFILING = ["integer", "int", "double", "float", "decimal"]
 
         self.method = data_source_properties.get("method", "hive")
         self.host = data_source_properties.get("host", "localhost")
@@ -474,3 +475,6 @@ def connect(self):
             self.connection = connection
         except Exception as e:
             raise DataSourceConnectionError(self.type, e)
+
+    def cast_to_text(self, expr: str) -> str:
+        return f"CAST({expr} AS VARCHAR(100))"

From 6d8e8b0ef1af0e68cc556554dc21fa1a0b9f1777 Mon Sep 17 00:00:00 2001
From: Bichitra Kumar Sahoo <32828151+bichitra95@users.noreply.github.com>
Date: Wed, 6 Sep 2023 14:17:30 +0530
Subject: [PATCH 3/5] Databricks: Profiling support for bigint datatypes

---
 soda/spark/soda/data_sources/spark_data_source.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/soda/spark/soda/data_sources/spark_data_source.py b/soda/spark/soda/data_sources/spark_data_source.py
index 033f4ed82..31d7dac93 100644
--- a/soda/spark/soda/data_sources/spark_data_source.py
+++ b/soda/spark/soda/data_sources/spark_data_source.py
@@ -425,7 +425,7 @@ class SparkDataSource(SparkSQLBase):
 
     def __init__(self, logs: Logs, data_source_name: str, data_source_properties: dict):
         super().__init__(logs, data_source_name, data_source_properties)
-        self.NUMERIC_TYPES_FOR_PROFILING = ["integer", "int", "double", "float", "decimal"]
+        self.NUMERIC_TYPES_FOR_PROFILING = ["integer", "int", "double", "float", "decimal", "bigint"]
 
        self.method = data_source_properties.get("method", "hive")
         self.host = data_source_properties.get("host", "localhost")
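Taken together, patches 2 and 3 give the Spark/Databricks data source a cast target that Databricks accepts (it has no TEXT type, hence VARCHAR(100)) and a wider set of numeric types eligible for profiling. For context, here is a minimal sketch of how a profiler might consume these two members when rendering a frequent-values query. This is an illustration only, not the soda-core implementation: the `build_frequent_values_sql` helper, its parameters, and the prefix-matching strategy are all hypothetical.

```python
# Hypothetical profiler-side consumer of the two members patched above;
# not the actual soda-core implementation.

NUMERIC_TYPES_FOR_PROFILING = ["integer", "int", "double", "float", "decimal", "bigint"]


def cast_to_text(expr: str) -> str:
    # Databricks has no TEXT type, so the patched override casts to VARCHAR(100).
    return f"CAST({expr} AS VARCHAR(100))"


def build_frequent_values_sql(table: str, column: str, column_type: str) -> str:
    # Prefix matching (an assumption here) lets "decimal" also cover
    # parameterized types such as decimal(10,2); "bigint" only matches
    # after patch 3 extends the list.
    if not any(column_type.lower().startswith(t) for t in NUMERIC_TYPES_FOR_PROFILING):
        raise ValueError(f"{column_type} is not a numeric type for profiling")
    # Values are cast to text so every profiled column yields uniform output.
    return (
        f"SELECT {cast_to_text(column)} AS value, COUNT(*) AS frequency "
        f"FROM {table} GROUP BY {column} ORDER BY frequency DESC LIMIT 10"
    )


print(build_frequent_values_sql("demo.orders", "amount", "decimal(10,2)"))
```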
From fc7cff55c608254ea552036f18d6f8635ca9ad22 Mon Sep 17 00:00:00 2001
From: sharma-shreyas
Date: Wed, 12 Jun 2024 16:13:49 +0530
Subject: [PATCH 4/5] add support for externalId for athena connection

---
 soda/athena/soda/data_sources/athena_data_source.py | 2 ++
 soda/core/soda/common/aws_credentials.py            | 3 +++
 2 files changed, 5 insertions(+)

diff --git a/soda/athena/soda/data_sources/athena_data_source.py b/soda/athena/soda/data_sources/athena_data_source.py
index e07ed05bb..e4135c336 100644
--- a/soda/athena/soda/data_sources/athena_data_source.py
+++ b/soda/athena/soda/data_sources/athena_data_source.py
@@ -34,6 +34,7 @@ def __init__(
             session_token=data_source_properties.get("session_token"),
             region_name=data_source_properties.get("region_name"),
             profile_name=data_source_properties.get("profile_name"),
+            external_id=data_source_properties.get("external_id")
         )
 
     def connect(self):
@@ -45,6 +46,7 @@ def connect(self):
             s3_staging_dir=self.athena_staging_dir,
             region_name=self.aws_credentials.region_name,
             role_arn=self.aws_credentials.role_arn,
+            external_id=self.aws_credentials.external_id,
             catalog_name=self.catalog,
             work_group=self.work_group,
             schema_name=self.schema,
diff --git a/soda/core/soda/common/aws_credentials.py b/soda/core/soda/common/aws_credentials.py
index 9bd7ac1ac..8dcac8ca9 100644
--- a/soda/core/soda/common/aws_credentials.py
+++ b/soda/core/soda/common/aws_credentials.py
@@ -12,10 +12,12 @@ def __init__(
         session_token: Optional[str] = None,
         profile_name: Optional[str] = None,
         region_name: Optional[str] = "eu-west-1",
+        external_id: Optional[str] = None,
     ):
         self.access_key_id = access_key_id
         self.secret_access_key = secret_access_key
         self.role_arn = role_arn
+        self.external_id = external_id
         self.session_token = session_token
         self.profile_name = profile_name
         self.region_name = region_name
@@ -32,6 +34,7 @@ def from_configuration(cls, configuration: dict):
             access_key_id=access_key_id,
             secret_access_key=configuration.get("secret_access_key"),
             role_arn=configuration.get("role_arn"),
+            external_id=configuration.get("external_id"),
             session_token=configuration.get("session_token"),
             profile_name=configuration.get("profile_name"),
             region_name=configuration.get("region", "eu-west-1"),
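Patch 4 plumbs `external_id` through configuration into the Athena connection and the credentials object, and patch 5 below forwards it to STS when assuming a role. One caveat when reusing this pattern: boto3 raises a ParamValidationError if `ExternalId=None` is passed, so callers without a configured external ID should omit the parameter entirely. A minimal sketch of a defensive variant of that call; the `assume_role_safely` helper is illustrative, not soda-core code:

```python
from typing import Optional

import boto3


def assume_role_safely(role_arn: str, session_name: str, external_id: Optional[str] = None) -> dict:
    """Assume an IAM role, passing ExternalId only when one is configured."""
    sts = boto3.client("sts")
    kwargs = {"RoleArn": role_arn, "RoleSessionName": session_name}
    # When present, ExternalId must match the sts:ExternalId condition in the
    # role's trust policy (the standard confused-deputy mitigation for
    # cross-account access); boto3 rejects ExternalId=None outright.
    if external_id:
        kwargs["ExternalId"] = external_id
    return sts.assume_role(**kwargs)["Credentials"]
```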
From 653d537ddcd3da8d4316f3809f89585c468d5446 Mon Sep 17 00:00:00 2001
From: Divyanshu Patel
Date: Fri, 11 Oct 2024 17:55:48 +0530
Subject: [PATCH 5/5] iam role fixes

---
 soda/core/soda/common/aws_credentials.py                |  2 +-
 .../redshift/soda/data_sources/redshift_data_source.py  | 10 +++++++---
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/soda/core/soda/common/aws_credentials.py b/soda/core/soda/common/aws_credentials.py
index 8dcac8ca9..36ca1957f 100644
--- a/soda/core/soda/common/aws_credentials.py
+++ b/soda/core/soda/common/aws_credentials.py
@@ -58,7 +58,7 @@ def assume_role(self, role_session_name: str):
             aws_session_token=self.session_token,
         )
 
-        assumed_role_object = self.sts_client.assume_role(RoleArn=self.role_arn, RoleSessionName=role_session_name)
+        assumed_role_object = self.sts_client.assume_role(RoleArn=self.role_arn, ExternalId=self.external_id, RoleSessionName=role_session_name)
         credentials_dict = assumed_role_object["Credentials"]
         return AwsCredentials(
             region_name=self.region_name,
diff --git a/soda/redshift/soda/data_sources/redshift_data_source.py b/soda/redshift/soda/data_sources/redshift_data_source.py
index 9f07bacd3..f4e5f850d 100644
--- a/soda/redshift/soda/data_sources/redshift_data_source.py
+++ b/soda/redshift/soda/data_sources/redshift_data_source.py
@@ -22,6 +22,9 @@ def __init__(self, logs: Logs, data_source_name: str, data_source_properties: di
         self.connect_timeout = data_source_properties.get("connection_timeout_sec")
         self.username = data_source_properties.get("username")
         self.password = data_source_properties.get("password")
+        self.dbuser = data_source_properties.get("dbuser")
+        self.dbname = data_source_properties.get("dbname")
+        self.cluster_id = data_source_properties.get("cluster_id")
 
         if not self.username or not self.password:
             aws_credentials = AwsCredentials(
@@ -31,6 +34,7 @@ def __init__(self, logs: Logs, data_source_name: str, data_source_properties: di
                 session_token=data_source_properties.get("session_token"),
                 region_name=data_source_properties.get("region", "eu-west-1"),
                 profile_name=data_source_properties.get("profile_name"),
+                external_id=data_source_properties.get("external_id"),
             )
             self.username, self.password = self.__get_cluster_credentials(aws_credentials)
 
@@ -60,9 +64,9 @@ def __get_cluster_credentials(self, aws_credentials: AwsCredentials):
             aws_session_token=resolved_aws_credentials.session_token,
         )
 
-        cluster_name = self.host.split(".")[0]
-        username = self.username
-        db_name = self.database
+        cluster_name = self.cluster_id if self.cluster_id else self.host.split(".")[0]
+        username = self.dbuser if self.dbuser else self.username
+        db_name = self.dbname if self.dbname else self.database
         cluster_creds = client.get_cluster_credentials(
             DbUser=username, DbName=db_name, ClusterIdentifier=cluster_name, AutoCreate=False, DurationSeconds=3600
         )
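For reference, the temporary-credentials flow that patch 5 parameterizes (separate `dbuser`, `dbname`, and `cluster_id` instead of deriving everything from `host` and `username`) can be exercised standalone with boto3. The identifiers below are placeholders; note that GetClusterCredentials returns a short-lived DbUser/DbPassword pair, so long-running callers must refresh before the Expiration timestamp:

```python
import boto3

# Placeholder identifiers; substitute values from your data source config.
client = boto3.client("redshift", region_name="eu-west-1")

# Mirrors the call in __get_cluster_credentials: ask Redshift for a
# temporary database user/password pair valid for DurationSeconds.
creds = client.get_cluster_credentials(
    DbUser="soda_scan",               # maps to the new `dbuser` property
    DbName="analytics",               # maps to the new `dbname` property
    ClusterIdentifier="my-cluster",   # maps to the new `cluster_id` property
    AutoCreate=False,                 # do not create the DB user if missing
    DurationSeconds=3600,             # credentials expire after one hour
)

print(creds["DbUser"])      # e.g. "IAM:soda_scan"
print(creds["Expiration"])  # datetime when the password stops working
# creds["DbPassword"] holds the temporary password used to log in.
```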