From 0d615fcec6854a8a9546a3b8eb6080cde605d786 Mon Sep 17 00:00:00 2001 From: Srinivas Gorur-Shandilya Date: Thu, 11 Apr 2024 09:18:26 -0400 Subject: [PATCH] refactor(managed-data): switch to new nucleus_api_route --- src/config/__init__.py | 14 +++++++++++--- src/config/default.yml | 7 ++++--- src/do_api.py | 1 - src/managed_data/_api.py | 3 ++- src/utils.py | 14 ++++++++++++++ tests/test_low_level.py | 4 +++- 6 files changed, 34 insertions(+), 9 deletions(-) diff --git a/src/config/__init__.py b/src/config/__init__.py index c480e81f..0a5b6b6f 100644 --- a/src/config/__init__.py +++ b/src/config/__init__.py @@ -8,6 +8,13 @@ from ..exceptions import DeepOriginException CONFIG_DIR = pathlib.Path(__file__).parent +CONFIG_YML_LOCATION = os.path.expanduser( + os.path.join( + "~", + ".deep-origin", + "config.yml", + ) +) __all__ = ["get_value"] @@ -15,7 +22,7 @@ @functools.cache def get_value( user_config_filenames: collections.abc.Iterable[str] = ( - os.path.expanduser(os.path.join("~", ".deep-origin", "config.yml")), + CONFIG_YML_LOCATION, os.path.join(".deep-origin", "config.yml"), ), ) -> confuse.templates.AttrDict: @@ -47,8 +54,9 @@ def get_value( "organization_id": confuse.String(), "bench_id": confuse.String(), "env": confuse.String(), - "nucleus_api_endpoint": confuse.String(), - "api_endpoint": confuse.String(), + "api_endpoint": confuse.Optional(confuse.String()), + "nucleus_api_route": confuse.String(), + "graphql_api_route": confuse.String(), "auth_domain": confuse.String(), "auth_device_code_endpoint": confuse.String(), "auth_token_endpoint": confuse.String(), diff --git a/src/config/default.yml b/src/config/default.yml index 3d9a8a6e..b18cc717 100644 --- a/src/config/default.yml +++ b/src/config/default.yml @@ -2,10 +2,11 @@ env: local organization_id: null bench_id: null api_endpoint: null -nucleus_api_endpoint: null +nucleus_api_route: nucleus-api/api/ +graphql_api_route: api/graphql/ auth_domain: null -auth_device_code_endpoint: /oauth/device/code -auth_token_endpoint: /oauth/token +auth_device_code_endpoint: oauth/device/code/ +auth_token_endpoint: oauth/token/ auth_audience: null auth_grant_type: urn:ietf:params:oauth:grant-type:device_code auth_client_id: null diff --git a/src/do_api.py b/src/do_api.py index ea7358fb..e3bc584d 100644 --- a/src/do_api.py +++ b/src/do_api.py @@ -31,7 +31,6 @@ def get_do_api_tokens() -> tuple[str, str]: config = get_config() if os.path.isfile(config.api_tokens_filename): - print(f"file exists @ {config.api_tokens_filename}") tokens = read_cached_do_api_tokens() refresh_token = tokens["refresh"] diff --git a/src/managed_data/_api.py b/src/managed_data/_api.py index 27f9476a..d3a915e7 100644 --- a/src/managed_data/_api.py +++ b/src/managed_data/_api.py @@ -9,9 +9,10 @@ from deeporigin import cache_do_api_tokens, get_do_api_tokens from deeporigin.config import get_value from deeporigin.exceptions import DeepOriginException +from deeporigin.utils import _nucleus_url -API_URL = get_value()["nucleus_api_endpoint"] ORG_ID = get_value()["organization_id"] +API_URL = _nucleus_url() @beartype diff --git a/src/utils.py b/src/utils.py index 363fee4a..3bc873f0 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,4 +1,8 @@ import os +from urllib.parse import urljoin + +from beartype import beartype +from deeporigin.config import get_value __all__ = [ "expand_user", @@ -8,6 +12,16 @@ PREFIX = "deeporigin://" +@beartype +def _nucleus_url() -> str: + """returns URL for nucleus API endpoint""" + return urljoin( + get_value()["api_endpoint"], + get_value()["nucleus_api_route"], + ) + + +@beartype def expand_user(path, user_home_dirname: str = os.path.expanduser("~")) -> str: """Expand paths that start with `~` by replacing it the user's home directory diff --git a/tests/test_low_level.py b/tests/test_low_level.py index 174c56ec..d1fdaec2 100644 --- a/tests/test_low_level.py +++ b/tests/test_low_level.py @@ -8,7 +8,9 @@ from deeporigin.config import get_value from deeporigin.exceptions import DeepOriginException from deeporigin.managed_data import _api, api +from deeporigin.utils import _nucleus_url +API_URL = _nucleus_url() # constants row_description_keys = { "id", @@ -62,7 +64,7 @@ # if we're running against a real instance, determine # if the database has any files in it -if get_value()["nucleus_api_endpoint"] != MOCK_URL: +if API_URL != MOCK_URL: df = api.get_dataframe(DB_NAME) file_ids = df.attrs["file_ids"] if len(file_ids) == 0: