Skip to content

Commit

Permalink
Refactor cursor_context calls (#68)
Browse files Browse the repository at this point in the history
* add connect, cusor_context to DB
* refactor cursor context
* add tests
* add deprecation warnings
  • Loading branch information
leo-schick authored Jan 31, 2023
1 parent f072dcf commit ceca079
Show file tree
Hide file tree
Showing 10 changed files with 165 additions and 77 deletions.
21 changes: 9 additions & 12 deletions mara_db/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import contextlib
import typing
from warnings import warn

import mara_db.dbs
import sys
Expand Down Expand Up @@ -35,18 +36,14 @@ def bigquery_client(db: typing.Union[str, mara_db.dbs.BigQueryDB]) -> 'google.cl
def bigquery_cursor_context(db: typing.Union[str, mara_db.dbs.BigQueryDB]) \
-> 'google.cloud.bigquery.dbapi.cursor.Cursor':
"""Creates a context with a bigquery cursor for a database alias"""
client = bigquery_client(db)

from google.cloud.bigquery.dbapi.connection import Connection

connection = Connection(client)
cursor = connection.cursor() # type: google.cloud.bigquery.dbapi.cursor.Cursor
try:
yield cursor
connection.commit()
except Exception as e:
connection.close()
raise e
warn('Function bigquery_cursor_context(db) is deprecated. Please use db.cursor_context() instead.')

if isinstance(db, str):
db = mara_db.dbs.db(db)

assert (isinstance(db, mara_db.dbs.BigQueryDB))

return db.cursor_context()


def create_bigquery_table_from_postgresql_query(
Expand Down
17 changes: 3 additions & 14 deletions mara_db/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,19 @@

import contextlib
import typing
from warnings import warn

import mara_db.dbs


@contextlib.contextmanager
def databricks_cursor_context(db: typing.Union[str, mara_db.dbs.DatabricksDB]) \
-> 'databricks.sql.client.Cursor':
from databricks_dbapi import odbc
warn('Function databricks_cursor_context(db) is deprecated. Please use db.cursor_context() instead.')

if isinstance(db, str):
db = mara_db.dbs.db(db)

assert (isinstance(db, mara_db.dbs.DatabricksDB))

connection = odbc.connect(
host=db.host,
http_path=db.http_path,
token=db.access_token,
driver_path=db.odbc_driver_path)

cursor = connection.cursor() # type: databricks.sql.client.Cursor
try:
yield cursor
connection.commit()
except Exception as e:
connection.close()
raise e
return db.cursor_context()
72 changes: 72 additions & 0 deletions mara_db/dbs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Abstract definition of database connections"""

import contextlib
import functools
import pathlib

Expand Down Expand Up @@ -28,6 +29,37 @@ def sqlalchemy_url(self):
"""Returns the SQLAlchemy url for a database"""
raise NotImplementedError(f'Please implement sqlalchemy_url for type "{self.__class__.__name__}"')

def connect(self) -> object:
"""
Constructor for creating a connection to the database.
The returned connection object is PIP-249 compatible (DB-API).
See also: https://peps.python.org/pep-0249/#connection-objects
"""
raise NotImplementedError(f'Please implement connect for type "{self.__class__.__name__}"')

@contextlib.contextmanager
def cursor_context(self) -> object:
"""
A single iteration with a cursor context. When the iteration is
closed, a commit is executed on the cursor.
Example usage:
with db.cursor_context() as c:
c.execute('UPDATE table SET table.c1 = 1 WHERE table.id = 5')
"""
connection = self.connect()
try:
cursor = connection.cursor()
yield cursor
connection.commit()
except Exception:
connection.rollback()
raise
finally:
cursor.close()
connection.close()


class PostgreSQLDB(DB):
def __init__(self, host: str = None, port: int = None, database: str = None,
Expand All @@ -53,6 +85,11 @@ def __init__(self, host: str = None, port: int = None, database: str = None,
def sqlalchemy_url(self):
return (f'postgresql+psycopg2://{self.user}{":" + self.password if self.password else ""}@{self.host}'
+ f'{":" + str(self.port) if self.port else ""}/{self.database}')

def connect(self) -> 'psycopg2.extensions.cursor':
import psycopg2
return psycopg2.connect(dbname=self.database, user=self.user, password=self.password,
host=self.host, port=self.port)


class RedshiftDB(PostgreSQLDB):
Expand Down Expand Up @@ -106,6 +143,14 @@ def sqlalchemy_url(self):
url += '/' + self.dataset
return url

def connect(self):
from google.oauth2.service_account import Credentials
from google.cloud.bigquery.client import Client
from google.cloud.bigquery.dbapi.connection import Connection
credentials = Credentials.from_service_account_file(self.service_account_json_file_name)
client = Client(project=credentials.project_id, credentials=credentials, location=self.location)
return Connection(client)


class MysqlDB(DB):
def __init__(self, host: str = None, port: int = None, database: str = None,
Expand All @@ -117,6 +162,12 @@ def __init__(self, host: str = None, port: int = None, database: str = None,
self.password = password
self.ssl = ssl
self.charset = charset

def connect(self) -> 'MySQLdb.cursors.Cursor':
import MySQLdb.cursors # requires https://github.com/PyMySQL/mysqlclient-python
return MySQLdb.connect(
host=self.host, user=self.user, passwd=self.password, db=self.database, port=self.port,
cursorclass=MySQLdb.cursors.Cursor)


class SQLServerDB(DB):
Expand Down Expand Up @@ -156,6 +207,14 @@ def sqlalchemy_url(self):
driver = self.odbc_driver.replace(' ','+')
return f'mssql+pyodbc://{self.user}:{urllib.parse.quote(self.password)}@{self.host}:{port}/{self.database}?driver={driver}'

def connect(self) -> 'pyodbc.Cursor':
import pyodbc # requires https://github.com/mkleehammer/pyodbc/wiki/Install
server = self.host
if self.port: # connecting via TCP/IP port
server = f"{server},{self.port}"
return pyodbc.connect(f"DRIVER={{{self.odbc_driver}}};SERVER={server};DATABASE={self.database};UID={self.user};PWD={self.password}" \
+ (';Encrypt=YES;TrustServerCertificate=YES' if self.trust_server_certificate else ''))


class SqshSQLServerDB(SQLServerDB):
def __init__(self, host: str = None, port: int = None, database: str = None,
Expand Down Expand Up @@ -200,6 +259,7 @@ def sqlalchemy_url(self):
return super().sqlalchemy_url \
+ ('&TrustServerCertificate=yes' if self.trust_server_certificate else '')


class OracleDB(DB):
def __init__(self, host: str = None, port: int = 0, endpoint: str = None, user: str = None, password: str = None):
self.host = host
Expand All @@ -216,6 +276,10 @@ def __init__(self, file_name: pathlib.Path) -> None:
@property
def sqlalchemy_url(self):
return f'sqlite:///{self.file_name}'

def connect(self):
import sqlite3
return sqlite3.connect(database=self.file_name)


class SnowflakeDB(DB):
Expand Down Expand Up @@ -267,3 +331,11 @@ def __init__(self, host: str = None, http_path: str = None, access_token: str =
@property
def sqlalchemy_url(self):
return f"databricks+connector://token:{self.access_token}@{self.host}:443/"

def connect(self):
from databricks_dbapi import odbc
return odbc.connect(
host=self.host,
http_path=self.http_path,
token=self.access_token,
driver_path=self.odbc_driver_path)
19 changes: 3 additions & 16 deletions mara_db/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,19 @@

import contextlib
import typing
from warnings import warn

import mara_db.dbs


@contextlib.contextmanager
def mysql_cursor_context(db: typing.Union[str, mara_db.dbs.MysqlDB]) -> 'MySQLdb.cursors.Cursor':
"""Creates a context with a mysql-client cursor for a database alias or database"""
import MySQLdb.cursors # requires https://github.com/PyMySQL/mysqlclient-python
warn('Function mysql_cursor_context(db) is deprecated. Please use db.cursor_context() instead.')

if isinstance(db, str):
db = mara_db.dbs.db(db)

assert (isinstance(db, mara_db.dbs.MysqlDB))

cursor = None
connection = MySQLdb.connect(
host=db.host, user=db.user, passwd=db.password, db=db.database, port=db.port,
cursorclass=MySQLdb.cursors.Cursor)
try:
cursor = connection.cursor()
yield cursor
except Exception:
connection.rollback()
raise
else:
connection.commit()
finally:
cursor.close()
connection.close()
return db.cursor_context()
19 changes: 5 additions & 14 deletions mara_db/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,19 @@

import contextlib
import typing
from warnings import warn

import mara_db.dbs


@contextlib.contextmanager
def postgres_cursor_context(db: typing.Union[str, mara_db.dbs.PostgreSQLDB]) -> 'psycopg2.extensions.cursor':
"""Creates a context with a psycopg2 cursor for a database alias"""
import psycopg2
import psycopg2.extensions
warn('Function databricks_cursor_context(db) is deprecated. Please use db.cursor_context() instead.')

if isinstance(db, str):
db = mara_db.dbs.db(db)

assert (isinstance(db, mara_db.dbs.PostgreSQLDB))
connection = psycopg2.connect(dbname=db.database, user=db.user, password=db.password,
host=db.host, port=db.port) # type: psycopg2.extensions.connection
cursor = connection.cursor() # type: psycopg2.extensions.cursor
try:
yield cursor
connection.commit()
except Exception as e:
connection.rollback()
raise e
finally:
cursor.close()
connection.close()

return db.cursor_context()
23 changes: 3 additions & 20 deletions mara_db/sqlserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,19 @@

import contextlib
import typing
from warnings import warn

import mara_db.dbs


@contextlib.contextmanager
def sqlserver_cursor_context(db: typing.Union[str, mara_db.dbs.SQLServerDB]) -> 'pyodbc.Cursor':
"""Creates a context with a pyodbc-client cursor for a database alias or database"""
import pyodbc # requires https://github.com/mkleehammer/pyodbc/wiki/Install
warn('Function sqlserver_cursor_context(db) is deprecated. Please use db.cursor_context() instead.')

if isinstance(db, str):
db = mara_db.dbs.db(db)

assert (isinstance(db, mara_db.dbs.SQLServerDB))

cursor = None

server = db.host
if db.port: # connecting via TCP/IP port
server = f"{server},{db.port}"

connection = pyodbc.connect(f"DRIVER={{{db.odbc_driver}}};SERVER={server};DATABASE={db.database};UID={db.user};PWD={db.password}" \
+ (';Encrypt=YES;TrustServerCertificate=YES' if db.trust_server_certificate else ''))
try:
cursor = connection.cursor()
yield cursor
except Exception:
connection.rollback()
raise
else:
connection.commit()
finally:
cursor.close()
connection.close()
return db.cursor_context()
23 changes: 22 additions & 1 deletion tests/db_test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def db_replace_placeholders(db: dbs.DB, docker_ip: str, docker_port: int) -> dbs
Basic tests which can be used for different DB engines.
"""

def _test_sqlalchemy(db):
def _test_sqlalchemy(db: dbs.DB):
"""
A simple test to check if the SQLAlchemy connection works
"""
Expand All @@ -40,3 +40,24 @@ def _test_sqlalchemy(db):
# the SELECT of a scalar value without a table is
# appropriately formatted for the backend
assert conn.scalar(select(1)) == 1

def _test_connect(db: dbs.DB):
connection = db.connect()
cursor = connection.cursor()
try:
cursor.execute('SELECT 1')
row = cursor.fetchone()
assert row[0] == 1
connection.commit()
except Exception as e:
connection.rollback()
raise e
finally:
cursor.close()
connection.close()

def _test_cursor_context(db: dbs.DB):
with db.cursor_context() as cursor:
cursor.execute('SELECT 1')
row = cursor.fetchone()
assert row[0] == 1
16 changes: 16 additions & 0 deletions tests/mssql/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,22 @@ def test_mssql_sqlalchemy(mssql_db):
_test_sqlalchemy(mssql_db)


def test_mssql_connect(mssql_db):
"""
A simple test to check if the connect API works.
"""
from ..db_test_helper import _test_connect
_test_connect(mssql_db)


def test_mssql_cursor_context(mssql_db):
"""
A simple test to check if the cursor context of the db works.
"""
from ..db_test_helper import _test_cursor_context
_test_cursor_context(mssql_db)



"""
#################################################################################################################################
Expand Down
16 changes: 16 additions & 0 deletions tests/postgres/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,19 @@ def test_postgres_sqlalchemy(postgres_db):
"""
from ..db_test_helper import _test_sqlalchemy
_test_sqlalchemy(postgres_db)


def test_postgres_connect(postgres_db):
"""
A simple test to check if the connect API works.
"""
from ..db_test_helper import _test_connect
_test_connect(postgres_db)


def test_postgres_cursor_context(postgres_db):
"""
A simple test to check if the cursor context of the db works.
"""
from ..db_test_helper import _test_cursor_context
_test_cursor_context(postgres_db)
Loading

0 comments on commit ceca079

Please sign in to comment.