From ceca07921199d5feaa7f6b12b3e42182d66358c5 Mon Sep 17 00:00:00 2001 From: Leo Schick <67712864+leo-schick@users.noreply.github.com> Date: Tue, 31 Jan 2023 14:44:03 +0100 Subject: [PATCH] Refactor cursor_context calls (#68) * add connect, cusor_context to DB * refactor cursor context * add tests * add deprecation warnings --- mara_db/bigquery.py | 21 +++++----- mara_db/databricks.py | 17 ++------ mara_db/dbs.py | 72 +++++++++++++++++++++++++++++++++ mara_db/mysql.py | 19 ++------- mara_db/postgresql.py | 19 +++------ mara_db/sqlserver.py | 23 ++--------- tests/db_test_helper.py | 23 ++++++++++- tests/mssql/test_mssql.py | 16 ++++++++ tests/postgres/test_postgres.py | 16 ++++++++ tests/test_databricks.py | 16 ++++++++ 10 files changed, 165 insertions(+), 77 deletions(-) diff --git a/mara_db/bigquery.py b/mara_db/bigquery.py index ba5cfcb..a024e88 100644 --- a/mara_db/bigquery.py +++ b/mara_db/bigquery.py @@ -2,6 +2,7 @@ import contextlib import typing +from warnings import warn import mara_db.dbs import sys @@ -35,18 +36,14 @@ def bigquery_client(db: typing.Union[str, mara_db.dbs.BigQueryDB]) -> 'google.cl def bigquery_cursor_context(db: typing.Union[str, mara_db.dbs.BigQueryDB]) \ -> 'google.cloud.bigquery.dbapi.cursor.Cursor': """Creates a context with a bigquery cursor for a database alias""" - client = bigquery_client(db) - - from google.cloud.bigquery.dbapi.connection import Connection - - connection = Connection(client) - cursor = connection.cursor() # type: google.cloud.bigquery.dbapi.cursor.Cursor - try: - yield cursor - connection.commit() - except Exception as e: - connection.close() - raise e + warn('Function bigquery_cursor_context(db) is deprecated. Please use db.cursor_context() instead.') + + if isinstance(db, str): + db = mara_db.dbs.db(db) + + assert (isinstance(db, mara_db.dbs.BigQueryDB)) + + return db.cursor_context() def create_bigquery_table_from_postgresql_query( diff --git a/mara_db/databricks.py b/mara_db/databricks.py index 4d1a578..7f79a31 100644 --- a/mara_db/databricks.py +++ b/mara_db/databricks.py @@ -2,6 +2,7 @@ import contextlib import typing +from warnings import warn import mara_db.dbs @@ -9,23 +10,11 @@ @contextlib.contextmanager def databricks_cursor_context(db: typing.Union[str, mara_db.dbs.DatabricksDB]) \ -> 'databricks.sql.client.Cursor': - from databricks_dbapi import odbc + warn('Function databricks_cursor_context(db) is deprecated. Please use db.cursor_context() instead.') if isinstance(db, str): db = mara_db.dbs.db(db) assert (isinstance(db, mara_db.dbs.DatabricksDB)) - connection = odbc.connect( - host=db.host, - http_path=db.http_path, - token=db.access_token, - driver_path=db.odbc_driver_path) - - cursor = connection.cursor() # type: databricks.sql.client.Cursor - try: - yield cursor - connection.commit() - except Exception as e: - connection.close() - raise e + return db.cursor_context() diff --git a/mara_db/dbs.py b/mara_db/dbs.py index 91353d2..00171c1 100644 --- a/mara_db/dbs.py +++ b/mara_db/dbs.py @@ -1,5 +1,6 @@ """Abstract definition of database connections""" +import contextlib import functools import pathlib @@ -28,6 +29,37 @@ def sqlalchemy_url(self): """Returns the SQLAlchemy url for a database""" raise NotImplementedError(f'Please implement sqlalchemy_url for type "{self.__class__.__name__}"') + def connect(self) -> object: + """ + Constructor for creating a connection to the database. + The returned connection object is PIP-249 compatible (DB-API). + + See also: https://peps.python.org/pep-0249/#connection-objects + """ + raise NotImplementedError(f'Please implement connect for type "{self.__class__.__name__}"') + + @contextlib.contextmanager + def cursor_context(self) -> object: + """ + A single iteration with a cursor context. When the iteration is + closed, a commit is executed on the cursor. + + Example usage: + with db.cursor_context() as c: + c.execute('UPDATE table SET table.c1 = 1 WHERE table.id = 5') + """ + connection = self.connect() + try: + cursor = connection.cursor() + yield cursor + connection.commit() + except Exception: + connection.rollback() + raise + finally: + cursor.close() + connection.close() + class PostgreSQLDB(DB): def __init__(self, host: str = None, port: int = None, database: str = None, @@ -53,6 +85,11 @@ def __init__(self, host: str = None, port: int = None, database: str = None, def sqlalchemy_url(self): return (f'postgresql+psycopg2://{self.user}{":" + self.password if self.password else ""}@{self.host}' + f'{":" + str(self.port) if self.port else ""}/{self.database}') + + def connect(self) -> 'psycopg2.extensions.cursor': + import psycopg2 + return psycopg2.connect(dbname=self.database, user=self.user, password=self.password, + host=self.host, port=self.port) class RedshiftDB(PostgreSQLDB): @@ -106,6 +143,14 @@ def sqlalchemy_url(self): url += '/' + self.dataset return url + def connect(self): + from google.oauth2.service_account import Credentials + from google.cloud.bigquery.client import Client + from google.cloud.bigquery.dbapi.connection import Connection + credentials = Credentials.from_service_account_file(self.service_account_json_file_name) + client = Client(project=credentials.project_id, credentials=credentials, location=self.location) + return Connection(client) + class MysqlDB(DB): def __init__(self, host: str = None, port: int = None, database: str = None, @@ -117,6 +162,12 @@ def __init__(self, host: str = None, port: int = None, database: str = None, self.password = password self.ssl = ssl self.charset = charset + + def connect(self) -> 'MySQLdb.cursors.Cursor': + import MySQLdb.cursors # requires https://github.com/PyMySQL/mysqlclient-python + return MySQLdb.connect( + host=self.host, user=self.user, passwd=self.password, db=self.database, port=self.port, + cursorclass=MySQLdb.cursors.Cursor) class SQLServerDB(DB): @@ -156,6 +207,14 @@ def sqlalchemy_url(self): driver = self.odbc_driver.replace(' ','+') return f'mssql+pyodbc://{self.user}:{urllib.parse.quote(self.password)}@{self.host}:{port}/{self.database}?driver={driver}' + def connect(self) -> 'pyodbc.Cursor': + import pyodbc # requires https://github.com/mkleehammer/pyodbc/wiki/Install + server = self.host + if self.port: # connecting via TCP/IP port + server = f"{server},{self.port}" + return pyodbc.connect(f"DRIVER={{{self.odbc_driver}}};SERVER={server};DATABASE={self.database};UID={self.user};PWD={self.password}" \ + + (';Encrypt=YES;TrustServerCertificate=YES' if self.trust_server_certificate else '')) + class SqshSQLServerDB(SQLServerDB): def __init__(self, host: str = None, port: int = None, database: str = None, @@ -200,6 +259,7 @@ def sqlalchemy_url(self): return super().sqlalchemy_url \ + ('&TrustServerCertificate=yes' if self.trust_server_certificate else '') + class OracleDB(DB): def __init__(self, host: str = None, port: int = 0, endpoint: str = None, user: str = None, password: str = None): self.host = host @@ -216,6 +276,10 @@ def __init__(self, file_name: pathlib.Path) -> None: @property def sqlalchemy_url(self): return f'sqlite:///{self.file_name}' + + def connect(self): + import sqlite3 + return sqlite3.connect(database=self.file_name) class SnowflakeDB(DB): @@ -267,3 +331,11 @@ def __init__(self, host: str = None, http_path: str = None, access_token: str = @property def sqlalchemy_url(self): return f"databricks+connector://token:{self.access_token}@{self.host}:443/" + + def connect(self): + from databricks_dbapi import odbc + return odbc.connect( + host=self.host, + http_path=self.http_path, + token=self.access_token, + driver_path=self.odbc_driver_path) diff --git a/mara_db/mysql.py b/mara_db/mysql.py index c00fd06..bb51842 100644 --- a/mara_db/mysql.py +++ b/mara_db/mysql.py @@ -2,6 +2,7 @@ import contextlib import typing +from warnings import warn import mara_db.dbs @@ -9,25 +10,11 @@ @contextlib.contextmanager def mysql_cursor_context(db: typing.Union[str, mara_db.dbs.MysqlDB]) -> 'MySQLdb.cursors.Cursor': """Creates a context with a mysql-client cursor for a database alias or database""" - import MySQLdb.cursors # requires https://github.com/PyMySQL/mysqlclient-python + warn('Function mysql_cursor_context(db) is deprecated. Please use db.cursor_context() instead.') if isinstance(db, str): db = mara_db.dbs.db(db) assert (isinstance(db, mara_db.dbs.MysqlDB)) - cursor = None - connection = MySQLdb.connect( - host=db.host, user=db.user, passwd=db.password, db=db.database, port=db.port, - cursorclass=MySQLdb.cursors.Cursor) - try: - cursor = connection.cursor() - yield cursor - except Exception: - connection.rollback() - raise - else: - connection.commit() - finally: - cursor.close() - connection.close() + return db.cursor_context() diff --git a/mara_db/postgresql.py b/mara_db/postgresql.py index fca64c8..06aaa9d 100644 --- a/mara_db/postgresql.py +++ b/mara_db/postgresql.py @@ -2,28 +2,19 @@ import contextlib import typing +from warnings import warn + import mara_db.dbs @contextlib.contextmanager def postgres_cursor_context(db: typing.Union[str, mara_db.dbs.PostgreSQLDB]) -> 'psycopg2.extensions.cursor': """Creates a context with a psycopg2 cursor for a database alias""" - import psycopg2 - import psycopg2.extensions + warn('Function databricks_cursor_context(db) is deprecated. Please use db.cursor_context() instead.') if isinstance(db, str): db = mara_db.dbs.db(db) assert (isinstance(db, mara_db.dbs.PostgreSQLDB)) - connection = psycopg2.connect(dbname=db.database, user=db.user, password=db.password, - host=db.host, port=db.port) # type: psycopg2.extensions.connection - cursor = connection.cursor() # type: psycopg2.extensions.cursor - try: - yield cursor - connection.commit() - except Exception as e: - connection.rollback() - raise e - finally: - cursor.close() - connection.close() + + return db.cursor_context() diff --git a/mara_db/sqlserver.py b/mara_db/sqlserver.py index f71927e..39e5c52 100644 --- a/mara_db/sqlserver.py +++ b/mara_db/sqlserver.py @@ -2,6 +2,7 @@ import contextlib import typing +from warnings import warn import mara_db.dbs @@ -9,29 +10,11 @@ @contextlib.contextmanager def sqlserver_cursor_context(db: typing.Union[str, mara_db.dbs.SQLServerDB]) -> 'pyodbc.Cursor': """Creates a context with a pyodbc-client cursor for a database alias or database""" - import pyodbc # requires https://github.com/mkleehammer/pyodbc/wiki/Install + warn('Function sqlserver_cursor_context(db) is deprecated. Please use db.cursor_context() instead.') if isinstance(db, str): db = mara_db.dbs.db(db) assert (isinstance(db, mara_db.dbs.SQLServerDB)) - cursor = None - - server = db.host - if db.port: # connecting via TCP/IP port - server = f"{server},{db.port}" - - connection = pyodbc.connect(f"DRIVER={{{db.odbc_driver}}};SERVER={server};DATABASE={db.database};UID={db.user};PWD={db.password}" \ - + (';Encrypt=YES;TrustServerCertificate=YES' if db.trust_server_certificate else '')) - try: - cursor = connection.cursor() - yield cursor - except Exception: - connection.rollback() - raise - else: - connection.commit() - finally: - cursor.close() - connection.close() + return db.cursor_context() diff --git a/tests/db_test_helper.py b/tests/db_test_helper.py index d0e551f..94feb91 100644 --- a/tests/db_test_helper.py +++ b/tests/db_test_helper.py @@ -26,7 +26,7 @@ def db_replace_placeholders(db: dbs.DB, docker_ip: str, docker_port: int) -> dbs Basic tests which can be used for different DB engines. """ -def _test_sqlalchemy(db): +def _test_sqlalchemy(db: dbs.DB): """ A simple test to check if the SQLAlchemy connection works """ @@ -40,3 +40,24 @@ def _test_sqlalchemy(db): # the SELECT of a scalar value without a table is # appropriately formatted for the backend assert conn.scalar(select(1)) == 1 + +def _test_connect(db: dbs.DB): + connection = db.connect() + cursor = connection.cursor() + try: + cursor.execute('SELECT 1') + row = cursor.fetchone() + assert row[0] == 1 + connection.commit() + except Exception as e: + connection.rollback() + raise e + finally: + cursor.close() + connection.close() + +def _test_cursor_context(db: dbs.DB): + with db.cursor_context() as cursor: + cursor.execute('SELECT 1') + row = cursor.fetchone() + assert row[0] == 1 diff --git a/tests/mssql/test_mssql.py b/tests/mssql/test_mssql.py index 78b68f2..bfdcfb3 100644 --- a/tests/mssql/test_mssql.py +++ b/tests/mssql/test_mssql.py @@ -56,6 +56,22 @@ def test_mssql_sqlalchemy(mssql_db): _test_sqlalchemy(mssql_db) +def test_mssql_connect(mssql_db): + """ + A simple test to check if the connect API works. + """ + from ..db_test_helper import _test_connect + _test_connect(mssql_db) + + +def test_mssql_cursor_context(mssql_db): + """ + A simple test to check if the cursor context of the db works. + """ + from ..db_test_helper import _test_cursor_context + _test_cursor_context(mssql_db) + + """ ################################################################################################################################# diff --git a/tests/postgres/test_postgres.py b/tests/postgres/test_postgres.py index 4d5fc1a..945c68b 100644 --- a/tests/postgres/test_postgres.py +++ b/tests/postgres/test_postgres.py @@ -150,3 +150,19 @@ def test_postgres_sqlalchemy(postgres_db): """ from ..db_test_helper import _test_sqlalchemy _test_sqlalchemy(postgres_db) + + +def test_postgres_connect(postgres_db): + """ + A simple test to check if the connect API works. + """ + from ..db_test_helper import _test_connect + _test_connect(postgres_db) + + +def test_postgres_cursor_context(postgres_db): + """ + A simple test to check if the cursor context of the db works. + """ + from ..db_test_helper import _test_cursor_context + _test_cursor_context(postgres_db) diff --git a/tests/test_databricks.py b/tests/test_databricks.py index 6c31dad..266e798 100644 --- a/tests/test_databricks.py +++ b/tests/test_databricks.py @@ -41,3 +41,19 @@ def test_databricks_sqlalchemy(): engine = sqlalchemy_engine.engine(DATABRICKS_DB) with engine.connect() as con: con.execute(statement = text("SELECT 1")) + + +def test_databricks_connect(): + """ + A simple test to check if the connect API works. + """ + from .db_test_helper import _test_connect + _test_connect(DATABRICKS_DB) + + +def test_databricks_cursor_context(): + """ + A simple test to check if the cursor context of the db works. + """ + from .db_test_helper import _test_cursor_context + _test_cursor_context(DATABRICKS_DB)