Skip to content

Commit

Permalink
fix cursor context and add corresponding test
Browse files Browse the repository at this point in the history
  • Loading branch information
leo-schick committed Jan 31, 2023
1 parent 501f56a commit 184c181
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 4 deletions.
5 changes: 3 additions & 2 deletions mara_db/dbs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Abstract definition of database connections"""

import contextlib
import functools
import pathlib

Expand Down Expand Up @@ -37,6 +38,7 @@ def connect(self) -> object:
"""
raise NotImplementedError(f'Please implement connect for type "{self.__class__.__name__}"')

@contextlib.contextmanager
def cursor_context(self) -> object:
"""
A single iteration with a cursor context. When the iteration is
Expand All @@ -50,11 +52,10 @@ def cursor_context(self) -> object:
try:
cursor = connection.cursor()
yield cursor
connection.commit()
except Exception:
connection.rollback()
raise
else:
connection.commit()
finally:
cursor.close()
connection.close()
Expand Down
10 changes: 8 additions & 2 deletions tests/db_test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def db_replace_placeholders(db: dbs.DB, docker_ip: str, docker_port: int) -> dbs
Basic tests which can be used for different DB engines.
"""

def _test_sqlalchemy(db):
def _test_sqlalchemy(db: dbs.DB):
"""
A simple test to check if the SQLAlchemy connection works
"""
Expand All @@ -41,7 +41,7 @@ def _test_sqlalchemy(db):
# appropriately formatted for the backend
assert conn.scalar(select(1)) == 1

def _test_connect(db):
def _test_connect(db: dbs.DB):
connection = db.connect()
cursor = connection.cursor()
try:
Expand All @@ -55,3 +55,9 @@ def _test_connect(db):
finally:
cursor.close()
connection.close()

def _test_cursor_context(db: dbs.DB):
with db.cursor_context() as cursor:
cursor.execute('SELECT 1')
row = cursor.fetchone()
assert row[0] == 1
8 changes: 8 additions & 0 deletions tests/mssql/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ def test_mssql_connect(mssql_db):
_test_connect(mssql_db)


def test_mssql_cursor_context(mssql_db):
"""
A simple test to check if the cursor context of the db works.
"""
from ..db_test_helper import _test_cursor_context
_test_cursor_context(mssql_db)



"""
#################################################################################################################################
Expand Down
8 changes: 8 additions & 0 deletions tests/postgres/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,11 @@ def test_postgres_connect(postgres_db):
"""
from ..db_test_helper import _test_connect
_test_connect(postgres_db)


def test_postgres_cursor_context(postgres_db):
"""
A simple test to check if the cursor context of the db works.
"""
from ..db_test_helper import _test_cursor_context
_test_cursor_context(postgres_db)
8 changes: 8 additions & 0 deletions tests/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,11 @@ def test_databricks_connect():
"""
from .db_test_helper import _test_connect
_test_connect(DATABRICKS_DB)


def test_databricks_cursor_context():
"""
A simple test to check if the cursor context of the db works.
"""
from .db_test_helper import _test_cursor_context
_test_cursor_context(DATABRICKS_DB)

0 comments on commit 184c181

Please sign in to comment.