Skip to content

Commit

Permalink
fix: client now correctly configured?
Browse files Browse the repository at this point in the history
  • Loading branch information
sg-s committed Nov 25, 2024
1 parent f126e29 commit ff0388b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 28 deletions.
1 change: 1 addition & 0 deletions src/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def authenticate() -> dict:
return tokens


@beartype
def refresh_tokens(api_refresh_token: str) -> str:
"""Refresh the access token for the Deep Origin OS
Expand Down
58 changes: 30 additions & 28 deletions src/data_hub/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pathlib import Path

from beartype import beartype
from beartype.typing import Optional
from box import Box
from deeporigin import auth
from deeporigin.exceptions import DeepOriginException
Expand Down Expand Up @@ -67,7 +68,7 @@ def _get_client_methods() -> set:

def _get_default_client(
*,
client=None,
access_token: Optional[str] = None,
refresh: bool = True,
use_async: bool = False,
):
Expand All @@ -84,35 +85,36 @@ def _get_default_client(
"""
if client is None:

if access_token is None:
tokens = auth.get_tokens(refresh=refresh)
access_token = tokens["access"]

import httpx
from deeporigin.config import get_value
import httpx
from deeporigin.config import get_value

value = get_value()
value = get_value()

org_id = value["organization_id"]
base_url = httpx.URL.join(
value["api_endpoint"],
value["nucleus_api_route"],
)
org_id = value["organization_id"]
base_url = httpx.URL.join(
value["api_endpoint"],
value["nucleus_api_route"],
)

if use_async:
from deeporigin_data import AsyncDeeporiginData
if use_async:
from deeporigin_data import AsyncDeeporiginData

client = AsyncDeeporiginData(
token=access_token,
org_id=org_id,
base_url=base_url,
)
else:
client = DeeporiginData(
token=access_token,
org_id=org_id,
base_url=base_url,
).with_raw_response
client = AsyncDeeporiginData(
token=access_token,
org_id=org_id,
base_url=base_url,
)
else:
client = DeeporiginData(
token=access_token,
org_id=org_id,
base_url=base_url,
).with_raw_response

return client

Expand Down Expand Up @@ -148,9 +150,13 @@ def dynamic_function(
except AuthenticationError as error:
if "expired token" in error.message:
print("⚠️ Token expired. Refreshing credentials...")
tokens = auth.read_cached_tokens()

tokens = auth.get_tokens(refresh=False)
tokens["access"] = auth.refresh_tokens(tokens["refresh"])

# create a new client with the new access token
client = _get_default_client(access_token=tokens["access"])

# cache to disk
auth.cache_tokens(tokens)

Expand All @@ -159,10 +165,6 @@ def dynamic_function(
# token from disk
auth.get_tokens.cache_clear()

# configure the client to use the new access
# token
client.token = tokens["access"]

method = _get_method(client, method_path)
response = method(**kwargs)
else:
Expand Down

0 comments on commit ff0388b

Please sign in to comment.