Skip to content

Commit

Permalink
Merge pull request #201 from octoenergy/add-sqlalchemy-2-support
Browse files Browse the repository at this point in the history
Add sqlalchemy 2 support
  • Loading branch information
matt-fleming authored Apr 6, 2023
2 parents 28b8eb6 + f34d6a0 commit 0b93676
Show file tree
Hide file tree
Showing 18 changed files with 717 additions and 588 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.2.0] - 2023-04-06
### Added
- Add support for SQLAlchemy 2.0 by correctly handling `sqlalchemy.engine.url.URL`s, which are now immutable.
Older versions of tentaclio should pin `sqlalchemy < 2.0` to avoid this issue.

## [1.1.0] - 2023-04-05
### Changed
- Credential files error reporting to help users identify the credentials issues
Expand Down
8 changes: 4 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ functional-sftp:
pipenv run pytest tests/functional/sftp

format:
black -l 99 src
black -l 99 tests
isort -rc src
isort -rc tests
pipenv run black -l 99 src
pipenv run black -l 99 tests
pipenv run isort src
pipenv run isort tests

# Deployment

Expand Down
3 changes: 0 additions & 3 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ python_version = "3.9"
[dev-packages]
# Symlink to project root
tentaclio = {editable = true,path = "."}

# Linting
black = "*"
isort = "*"
Expand All @@ -18,13 +17,11 @@ mypy = "*"
pydocstyle = "*"
types-pyyaml = "*"
types-requests = "*"

# Testing
moto = "*"
pytest = "*"
pytest-cov = "*"
pytest-mock = "*"

# Releasing
twine = "*"
secretstorage = "*"
1,186 changes: 648 additions & 538 deletions Pipfile.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from setuptools.command.install import install


VERSION = "1.1.0"
VERSION = "1.2.0"

REPO_ROOT = pathlib.Path(__file__).parent

Expand Down
6 changes: 3 additions & 3 deletions src/tentaclio/clients/ftp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _connect(self) -> ftplib.FTP:
# Stream methods:

@decorators.check_conn
def get(self, writer: protocols.ByteWriter, file_path: str = None) -> None:
def get(self, writer: protocols.ByteWriter, file_path: Optional[str] = None) -> None:
"""Write the contents of a remote file into the passed writer.
Arguments:
Expand Down Expand Up @@ -162,7 +162,7 @@ def _connect(self) -> pysftp.Connection:
# Stream methods:

@decorators.check_conn
def get(self, writer: protocols.ByteWriter, file_path: str = None) -> None:
def get(self, writer: protocols.ByteWriter, file_path: Optional[str] = None) -> None:
"""Write the contents of a remote file into the passed writer.
Arguments:
Expand All @@ -180,7 +180,7 @@ def get(self, writer: protocols.ByteWriter, file_path: str = None) -> None:
self.conn.getfo(remote_path, writer)

@decorators.check_conn
def put(self, reader: protocols.ByteReader, file_path: str = None) -> None:
def put(self, reader: protocols.ByteReader, file_path: Optional[str] = None) -> None:
"""Write the contents of the reader into the remote file.
Arguments:
Expand Down
25 changes: 15 additions & 10 deletions src/tentaclio/clients/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ class HTTPClient(base_client.BaseClient["HTTPClient"]):
endpoint: str

def __init__(
self, url: str, default_timeout: float = None, default_headers: dict = None
self,
url: str,
default_timeout: Optional[float] = None,
default_headers: Optional[dict] = None,
) -> None:
"""Create a new http/https client based on the passed url and extra params."""
# Default connection timeout at 10''
Expand Down Expand Up @@ -81,9 +84,9 @@ def _connect(self) -> requests.Session:
def get(
self,
writer: protocols.ByteWriter,
endpoint: str = None,
params: dict = None,
options: dict = None,
endpoint: Optional[str] = None,
params: Optional[dict] = None,
options: Optional[dict] = None,
) -> None:
"""Read the contents from the url and write them into the provided writer.
Expand All @@ -103,9 +106,9 @@ def get(
def put(
self,
reader: protocols.ByteReader,
endpoint: str = None,
params: dict = None,
options: dict = None,
endpoint: Optional[str] = None,
params: Optional[dict] = None,
options: Optional[dict] = None,
) -> None:
"""Write the contents of the provided reader into the url using POST.
Expand All @@ -132,8 +135,8 @@ def _build_request(
self,
method: str,
url: str,
default_data: protocols.Reader = None,
default_params: dict = None,
default_data: Optional[protocols.Reader] = None,
default_params: Optional[dict] = None,
):
data: Union[protocols.Reader, list] = default_data or []
params = default_params or {}
Expand All @@ -149,7 +152,9 @@ def _build_request(

return self.conn.prepare_request(request)

def _send_request(self, request: requests.PreparedRequest, default_options: dict = None):
def _send_request(
self, request: requests.PreparedRequest, default_options: Optional[dict] = None
):
options = default_options or {}

response = self.conn.send(
Expand Down
22 changes: 13 additions & 9 deletions src/tentaclio/clients/sqla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from typing import Container, Generator, Optional, Union

import pandas as pd
from sqlalchemy import text
from sqlalchemy.engine import Connection, CursorResult, Engine, create_engine
from sqlalchemy.engine import url as sqla_url
from sqlalchemy.engine.url import URL as sqla_url
from sqlalchemy.orm import session, sessionmaker
from sqlalchemy.sql.schema import MetaData

Expand All @@ -20,7 +21,7 @@
__all__ = ["SQLAlchemyClient", "bound_session", "atomic_session"]


SessionGenerator = Generator[None, session.Session, None]
SessionGenerator = Generator[session.Session, None, None]


class _TrueContainer(Container[str]):
Expand Down Expand Up @@ -57,7 +58,10 @@ class SQLAlchemyClient(base_client.BaseClient["SQLAlchemyClient"]):
port: Optional[int]

def __init__(
self, url: Union[str, urls.URL], execution_options: dict = None, connect_args: dict = None
self,
url: Union[str, urls.URL],
execution_options: Optional[dict] = None,
connect_args: Optional[dict] = None,
) -> None:
"""Create sqlalchemy client based on the passed url.
Expand All @@ -84,16 +88,16 @@ def _extract_url_params(self) -> None:
# Connection methods:

def _connect(self) -> Connection:

parsed_url = sqla_url.URL(
parsed_url = sqla_url.create(
drivername=self.drivername,
username=self.username,
password=self.password,
host=self.host,
port=self.port,
database=self.database,
query=self.url_query,
)
if self.url.query_string:
parsed_url.update_query_string(self.url.query_string)
if self.engine is None:
self.engine = create_engine(
parsed_url,
Expand Down Expand Up @@ -124,14 +128,14 @@ def query(self, sql_query: str, **kwargs) -> CursorResult:
This will not commit any changes to the database.
"""
return self.conn.execute(sql_query, **kwargs)
return self.conn.execute(text(sql_query), **kwargs)

@decorators.check_conn
def execute(self, sql_query: str, **kwargs) -> None:
"""Execute a raw SQL query command."""
trans = self.conn.begin()
try:
self.conn.execute(sql_query, **kwargs)
self.conn.execute(text(sql_query), **kwargs)
except Exception:
trans.rollback()
raise
Expand All @@ -141,7 +145,7 @@ def execute(self, sql_query: str, **kwargs) -> None:
# Dataframe methods:

@decorators.check_conn
def get_df(self, sql_query: str, params: dict = None, **kwargs) -> pd.DataFrame:
def get_df(self, sql_query: str, params: Optional[dict] = None, **kwargs) -> pd.DataFrame:
"""Run a raw SQL query and return a data frame."""
return pd.read_sql(sql_query, self.conn, params=params, **kwargs)

Expand Down
12 changes: 10 additions & 2 deletions src/tentaclio/hooks/slack_hook.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Slack http hook."""
import io
import json
from typing import Optional

from tentaclio.clients import http_client

Expand All @@ -16,7 +17,11 @@ def __init__(self, url: str) -> None:
self.url = url

def notify(
self, message: str, channel: str = None, username: str = None, icon_emoji: str = None
self,
message: str,
channel: Optional[str] = None,
username: Optional[str] = None,
icon_emoji: Optional[str] = None,
) -> None:
"""Send a notification to slack."""
body = self._build_request_body(
Expand All @@ -31,7 +36,10 @@ def notify(

@staticmethod
def _build_request_body(
message: str, channel: str = None, username: str = None, icon_emoji: str = None
message: str,
channel: Optional[str] = None,
username: Optional[str] = None,
icon_emoji: Optional[str] = None,
) -> str:
# Fetch message payload
payload = dict(text=message)
Expand Down
4 changes: 2 additions & 2 deletions src/tentaclio/streams/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Main entry points to tentaclio-io."""
from typing import ContextManager, Union
from typing import ContextManager, Optional, Union

from tentaclio import protocols
from tentaclio.credentials import authenticate
Expand All @@ -17,7 +17,7 @@
]


def open(url: str, mode: str = None, **kwargs) -> AnyContextStreamerReaderWriter:
def open(url: str, mode: Optional[str] = None, **kwargs) -> AnyContextStreamerReaderWriter:
"""Open a url and return a reader or writer depending on mode.
Arguments:
Expand Down
21 changes: 14 additions & 7 deletions src/tentaclio/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class URL:
_port: Optional[int] = None
_path: str
_query: Optional[Dict[str, str]] = None
_query_string: Optional[str] = None

def __init__(self, url: str) -> None:
"""Create a url by parsing the parametre."""
Expand All @@ -57,6 +58,7 @@ def _parse_url(self) -> None:
self._hostname = parsed_url.hostname
self._port = parsed_url.port
self._path = parsed_url.path
self._query_string = parsed_url.query

# Replace %xx escapes - ONLY for username & password
if parsed_url.username and self._username:
Expand All @@ -74,13 +76,13 @@ def _parse_url(self) -> None:

def copy(
self,
scheme: str = None,
username: str = None,
password: str = None,
hostname: str = None,
port: int = None,
path: str = None,
query: str = None,
scheme: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
hostname: Optional[str] = None,
port: Optional[int] = None,
path: Optional[str] = None,
query: Optional[str] = None,
) -> "URL":
"""Copy this url optionally overwriting the provided components."""
return URL.from_components(
Expand Down Expand Up @@ -168,6 +170,11 @@ def query(self) -> Optional[Dict[str, str]]:
"""Access the query."""
return self._query

@property
def query_string(self) -> Optional[str]:
"""Access the query string."""
return self._query_string

@property
def url(self) -> str:
"""Return the original url."""
Expand Down
1 change: 0 additions & 1 deletion tests/unit/clients/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def test_inactive_client_connection(self):
url = "file:///path"

class TestClient(base_client.BaseClient):

allowed_schemes = ["file"]

def connect(self):
Expand Down
2 changes: 0 additions & 2 deletions tests/unit/clients/test_ftp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def test_parsing_ftp_url(self, url, username, password, hostname, port, path):
@pytest.mark.parametrize("url,path", [("ftp://:@localhost", None)])
def test_get_invalid_path(self, url, path, mocked_ftp_conn):
with ftp_client.FTPClient(url) as client:

with pytest.raises(exceptions.FTPError):
client.get(io.StringIO(), file_path=path)

Expand Down Expand Up @@ -149,7 +148,6 @@ def test_invalid_scheme(self, url):
@pytest.mark.parametrize("url,path", [("sftp://:@localhost", None)])
def test_get_invalid_path(self, url, path, mocked_sftp_conn):
with ftp_client.SFTPClient(url) as client:

with pytest.raises(exceptions.FTPError):
client.get(io.StringIO(), file_path=path)

Expand Down
2 changes: 0 additions & 2 deletions tests/unit/clients/test_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def test_get_invalid_endpoint(self, url, path, mocker, mocked_http_conn):
mocked_writer = mocker.Mock()

with pytest.raises(exceptions.HTTPError):

with http_client.HTTPClient(url) as client:
client.get(mocked_writer, endpoint=path)

Expand All @@ -73,7 +72,6 @@ def test_get_invalid_endpoint(self, url, path, mocker, mocked_http_conn):
)
def test_fetching_url_endpoint(self, base_url, endpoint, auth, full_url):
with http_client.HTTPClient(base_url) as client:

assert client.conn.auth == auth
assert client._fetch_url(endpoint) == full_url

Expand Down
1 change: 0 additions & 1 deletion tests/unit/credentials/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@


def test_authenticate(mocker):

injector = credentials.CredentialsInjector()
injector.register_credentials(urls.URL("ftp://user:[email protected]"))
mock_cred = mocker.patch("tentaclio.credentials.api.load_credentials_injector")
Expand Down
1 change: 0 additions & 1 deletion tests/unit/streams/test_csv_db_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def csv_data():


def test_dump_csv(csv_data, csv_dumper):

with csv_db_stream.DatabaseCsvWriter(csv_dumper, "my_table") as writer:
writer.write(csv_data.getvalue())

Expand Down
3 changes: 2 additions & 1 deletion tests/unit/streams/test_stream_handler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import io
from typing import Optional

import pytest

Expand All @@ -9,7 +10,7 @@

class FakeClient(base_client.BaseClient["FakeClient"]):
# clients only understand bytes
def __init__(self, url: URL, message: bytearray = None, *args, **kwargs):
def __init__(self, url: URL, message: Optional[bytearray] = None, *args, **kwargs):
self._writer = io.BytesIO()
self._message = message or bytes("hello", encoding="utf-8")
self._closed = False
Expand Down
1 change: 0 additions & 1 deletion tests/unit/test_urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class TestURL:

# Generic parsing rules:
def test_missing_url(self):
with pytest.raises(urls.URLError):
Expand Down

0 comments on commit 0b93676

Please sign in to comment.