From 3305e94d7d62e32e7c6e7fdd6e5d11ed58a54051 Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Tue, 17 Sep 2024 10:46:36 -0300 Subject: [PATCH] Update how we are storing overdrive session token (#2060) --- src/palace/manager/api/overdrive.py | 82 +++++------- tests/manager/api/test_overdrive.py | 186 ++++++++++++++-------------- 2 files changed, 127 insertions(+), 141 deletions(-) diff --git a/src/palace/manager/api/overdrive.py b/src/palace/manager/api/overdrive.py index 8913151e7..0c4424240 100644 --- a/src/palace/manager/api/overdrive.py +++ b/src/palace/manager/api/overdrive.py @@ -9,7 +9,7 @@ import urllib.parse from collections.abc import Iterable from threading import RLock -from typing import Any +from typing import Any, NamedTuple from urllib.parse import quote, urlsplit, urlunsplit import dateutil @@ -222,6 +222,11 @@ class OverdriveChildSettings(BaseSettings): ) +class OverdriveToken(NamedTuple): + token: str + expires: datetime.datetime + + class OverdriveAPI( PatronActivityCirculationAPI[OverdriveSettings, OverdriveLibrarySettings], CirculationInternalFormatsMixin, @@ -418,9 +423,8 @@ def __init__(self, _db, collection): self._hosts = self._determine_hosts(server_nickname=self._server_nickname) - # This is set by an access to .token, or by a call to - # check_creds() or refresh_creds(). - self._token = None + # This is set by access to .token + self._token: OverdriveToken | None = None # This is set by an access to .collection_token self._collection_token = None @@ -455,10 +459,26 @@ def endpoint(self, url: str, **kwargs) -> str: return url % kwargs @property - def token(self): - if not self._token: - self.check_creds() - return self._token + def token(self) -> str: + if (token := self._token) is not None and utc_now() < token.expires: + return token.token + + return self._refresh_token().token + + def _refresh_token(self) -> OverdriveToken: + """Get an overdrive bearer token.""" + with self.lock: + response = self.token_post( + self.TOKEN_ENDPOINT, + dict(grant_type="client_credentials"), + allowed_response_codes=[200], + ) + data = response.json() + access_token = data["access_token"] + expires_in = data["expires_in"] * 0.9 + expires = utc_now() + datetime.timedelta(seconds=expires_in) + self._token = OverdriveToken(token=access_token, expires=expires) + return self._token @property def collection_token(self): @@ -468,7 +488,6 @@ def collection_token(self): credentials are working. """ if not self._collection_token: - self.check_creds() library = self.get_library() error = library.get("errorCode") if error: @@ -511,42 +530,6 @@ def advantage_library_id(self): return self.OVERDRIVE_MAIN_ACCOUNT_ID return int(self._library_id) - def check_creds(self, force_refresh=False): - """If the Bearer Token has expired, update it.""" - with self.lock: - refresh_on_lookup = self.refresh_creds - if force_refresh: - refresh_on_lookup = lambda x: x - - credential = self.credential_object(refresh_on_lookup) - if force_refresh: - self.refresh_creds(credential) - self._token = credential.credential - - def credential_object(self, refresh): - """Look up the Credential object that allows us to use - the Overdrive API. - """ - return Credential.lookup( - self._db, - DataSource.OVERDRIVE, - None, - None, - refresh, - collection=self.collection, - ) - - def refresh_creds(self, credential): - """Fetch a new Bearer Token and update the given Credential object.""" - response = self.token_post( - self.TOKEN_ENDPOINT, - dict(grant_type="client_credentials"), - allowed_response_codes=[200], - ) - data = response.json() - self._update_credential(credential, data) - self._token = credential.credential - def get( self, url: str, extra_headers={}, exception_on_401=False ) -> tuple[int, CaseInsensitiveDict, bytes]: @@ -570,8 +553,8 @@ def get( response, ) else: - # Refresh the token and try again. - self.check_creds(True) + # Force a refresh of the token and try again. + self._refresh_token() return self.get(url, extra_headers, True) else: return status_code, headers, content @@ -834,8 +817,7 @@ def hosts(self) -> dict[str, str]: def _run_self_tests(self, _db): result = self.run_test( "Checking global Client Authentication privileges", - self.check_creds, - force_refresh=True, + self._refresh_token, ) yield result if not result.success: @@ -1061,6 +1043,8 @@ def _process_checkout_error(self, patron, pin, licensepool, error): code = error.get("errorCode", code) if code == "NoCopiesAvailable": # Clearly our info is out of date. + # TODO: This shouldn't be happening in the web thread, it should really be happening in a + # background job. No need to block the user while we update the license pool. self.update_licensepool(identifier.identifier) raise NoAvailableCopies() diff --git a/tests/manager/api/test_overdrive.py b/tests/manager/api/test_overdrive.py index b623ff685..38826eb2b 100644 --- a/tests/manager/api/test_overdrive.py +++ b/tests/manager/api/test_overdrive.py @@ -426,58 +426,63 @@ def test_401_on_get_refreshes_bearer_token( # The bearer token has been updated. assert "new bearer token" == fixture.api.token - def test_credential_refresh_success( - self, overdrive_api_fixture: OverdriveAPIFixture - ): - fixture = overdrive_api_fixture - + def test_token(self, overdrive_api_fixture: OverdriveAPIFixture): """Verify the process of refreshing the Overdrive bearer token.""" - # Perform the initial credential check. - fixture.api.check_creds() - credential = fixture.api.credential_object(lambda x: x) - assert "bearer token" == credential.credential - assert fixture.api.token == credential.credential + api = overdrive_api_fixture.api - fixture.api.access_token_response = fixture.api.mock_access_token_response( - "new bearer token" - ) + # Initially the token is None + assert len(api.access_token_requests) == 0 + + # Accessing the token triggers a refresh + assert api.token == "bearer token" + assert len(api.access_token_requests) == 1 + + # Mock the token response + api.access_token_response = api.mock_access_token_response("new bearer token") + + # Accessing the token again won't refresh, because the old token is still valid + assert api.token == "bearer token" + assert len(api.access_token_requests) == 1 - # Refresh the credentials and the token will change to - # the mocked value. - fixture.api.refresh_creds(credential) - assert "new bearer token" == credential.credential - assert fixture.api.token == credential.credential + # However if the token expires we will get a new one + assert api._token is not None + api._token = api._token._replace(expires=utc_now() - timedelta(seconds=1)) + + assert api.token == "new bearer token" + assert len(api.access_token_requests) == 2 def test_401_after_token_refresh_raises_error( self, overdrive_api_fixture: OverdriveAPIFixture ): fixture = overdrive_api_fixture + api = fixture.api - assert "bearer token" == fixture.api.token + # Our initial token value is "bearer token". + assert api.token == "bearer token" # We try to GET and receive a 401. - fixture.api.queue_response(401) + api.queue_response(401) # We refresh the bearer token. - fixture.api.access_token_response = fixture.api.mock_access_token_response( - "new bearer token" - ) + api.access_token_response = api.mock_access_token_response("new bearer token") # Then we retry the GET but we get another 401. - fixture.api.queue_response(401) - - credential = fixture.api.credential_object(lambda x: x) - fixture.api.refresh_creds(credential) + api.queue_response(401) # That raises a BadResponseException - with pytest.raises(BadResponseException) as excinfo: - fixture.api.get_library() - assert "Bad response from" in str(excinfo.value) - assert "Something's wrong with the Overdrive OAuth Bearer Token!" in str( - excinfo.value - ) + with pytest.raises( + BadResponseException, + match="Bad response from .*: Something's wrong with the Overdrive OAuth Bearer Token", + ): + api.get_library() + + # We refreshed the token in the process. + assert fixture.api.token == "new bearer token" + + # We made two requests + assert len(api.requests) == 2 - def test_401_during_refresh_raises_error( + def test_401_during__refresh_token_raises_error( self, overdrive_api_fixture: OverdriveAPIFixture ): fixture = overdrive_api_fixture @@ -486,10 +491,11 @@ def test_401_during_refresh_raises_error( raised. """ fixture.api.access_token_response = MockRequestsResponse(401, {}, "") - with pytest.raises(BadResponseException) as excinfo: - fixture.api.refresh_creds(None) - assert "Got status code 401" in str(excinfo.value) - assert "can only continue on: 200." in str(excinfo.value) + with pytest.raises( + BadResponseException, + match="Got status code 401 .* can only continue on: 200.", + ): + fixture.api._refresh_token() def test_advantage_differences(self, overdrive_api_fixture: OverdriveAPIFixture): transaction = overdrive_api_fixture.db @@ -624,36 +630,30 @@ def test__run_self_tests( # methods. db = overdrive_api_fixture.db - class Mock(MockOverdriveAPI): - "Mock every method used by OverdriveAPI._run_self_tests." - - # First we will call check_creds() to get a fresh credential. - mock_credential = object() - - def check_creds(self, force_refresh=False): - self.check_creds_called_with = force_refresh - return self.mock_credential + # Mock every method used by OverdriveAPI._run_self_tests. + api = MockOverdriveAPI(db.session, overdrive_api_fixture.collection) - # Then we will call get_advantage_accounts(). - mock_advantage_accounts = [object(), object()] + # First we will call get_token + mock_refresh_token = create_autospec(api._refresh_token) + api._refresh_token = mock_refresh_token - def get_advantage_accounts(self): - return self.mock_advantage_accounts - - # Then we will call get() on the _all_products_link. - def get(self, url, extra_headers, exception_on_401=False): - self.get_called_with = (url, extra_headers, exception_on_401) - return 200, {}, json.dumps(dict(totalItems=2010)) + # Then we will call get_advantage_accounts(). + mock_get_advantage_accounts = create_autospec( + api.get_advantage_accounts, return_value=[object(), object()] + ) + api.get_advantage_accounts = mock_get_advantage_accounts - # Finally, for every library associated with this - # collection, we'll call get_patron_credential() using - # the credentials of that library's test patron. - mock_patron_credential = object() - get_patron_credential_called_with = [] + # Then we will call get() on the _all_products_link. + mock_get = create_autospec( + api.get, return_value=(200, {}, json.dumps(dict(totalItems=2010))) + ) + api.get = mock_get - def get_patron_credential(self, patron, pin): - self.get_patron_credential_called_with.append((patron, pin)) - return self.mock_patron_credential + # Finally, for every library associated with this + # collection, we'll call get_patron_credential() using + # the credentials of that library's test patron. + mock_get_patron_credential = create_autospec(api.get_patron_credential) + api.get_patron_credential = mock_get_patron_credential # Now let's make sure two Libraries have access to this # Collection -- one library with a default patron and one @@ -665,7 +665,6 @@ def get_patron_credential(self, patron, pin): db.simple_auth_integration(with_default_patron) # Now that everything is set up, run the self-test. - api = Mock(db.session, overdrive_api_fixture.collection) results = sorted(api._run_self_tests(db.session), key=lambda x: x.name) [ no_patron_credential, @@ -678,42 +677,45 @@ def get_patron_credential(self, patron, pin): # Verify that each test method was called and returned the # expected SelfTestResult object. assert ( - "Checking global Client Authentication privileges" == global_privileges.name + global_privileges.name == "Checking global Client Authentication privileges" ) - assert True == global_privileges.success - assert api.mock_credential == global_privileges.result + assert global_privileges.success is True + assert global_privileges.result == mock_refresh_token.return_value - assert "Looking up Overdrive Advantage accounts" == advantage.name - assert True == advantage.success - assert "Found 2 Overdrive Advantage account(s)." == advantage.result + assert advantage.name == "Looking up Overdrive Advantage accounts" + assert advantage.success is True + assert advantage.result == "Found 2 Overdrive Advantage account(s)." + mock_get_advantage_accounts.assert_called_once() - assert "Counting size of collection" == collection_size.name - assert True == collection_size.success - assert "2010 item(s) in collection" == collection_size.result - url, headers, error_on_401 = api.get_called_with - assert api._all_products_link == url + assert collection_size.name == "Counting size of collection" + assert collection_size.success is True + assert collection_size.result == "2010 item(s) in collection" + mock_get.assert_called_once_with(api._all_products_link, {}) assert ( - "Acquiring test patron credentials for library %s" % no_default_patron.name - == no_patron_credential.name + no_patron_credential.name + == f"Acquiring test patron credentials for library {no_default_patron.name}" ) - assert False == no_patron_credential.success - assert "Library has no test patron configured." == str( - no_patron_credential.exception + assert no_patron_credential.success is False + assert ( + str(no_patron_credential.exception) + == "Library has no test patron configured." ) assert ( - "Checking Patron Authentication privileges, using test patron for library %s" - % with_default_patron.name - == default_patron_credential.name + default_patron_credential.name + == f"Checking Patron Authentication privileges, using test patron for library {with_default_patron.name}" + ) + assert default_patron_credential.success is True + assert ( + default_patron_credential.result == mock_get_patron_credential.return_value ) - assert True == default_patron_credential.success - assert api.mock_patron_credential == default_patron_credential.result # Although there are two libraries associated with this # collection, get_patron_credential was only called once, because # one of the libraries doesn't have a default patron. - [(patron1, password1)] = api.get_patron_credential_called_with + mock_get_patron_credential.assert_called_once() + (patron1, password1) = mock_get_patron_credential.call_args.args assert "username1" == patron1.authorization_identifier assert "password1" == password1 @@ -727,16 +729,16 @@ def test_run_self_tests_short_circuit( work we won't be able to instantiate the OverdriveAPI class. """ - def explode(*args, **kwargs): - raise Exception("Failure!") - - overdrive_api_fixture.api.check_creds = explode + api = overdrive_api_fixture.api + api._refresh_token = create_autospec( + api._refresh_token, side_effect=Exception("Failure!") + ) # Only one test will be run. [check_creds] = overdrive_api_fixture.api._run_self_tests( overdrive_api_fixture.db.session ) - assert "Failure!" == str(check_creds.exception) + assert str(check_creds.exception) == "Failure!" def test_default_notification_email_address( self,