Skip to content

Commit

Permalink
Update how we are storing overdrive session token (#2060)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathangreen committed Sep 17, 2024
1 parent 0c3ae76 commit 3305e94
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 141 deletions.
82 changes: 33 additions & 49 deletions src/palace/manager/api/overdrive.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import urllib.parse
from collections.abc import Iterable
from threading import RLock
from typing import Any
from typing import Any, NamedTuple
from urllib.parse import quote, urlsplit, urlunsplit

import dateutil
Expand Down Expand Up @@ -222,6 +222,11 @@ class OverdriveChildSettings(BaseSettings):
)


class OverdriveToken(NamedTuple):
token: str
expires: datetime.datetime


class OverdriveAPI(
PatronActivityCirculationAPI[OverdriveSettings, OverdriveLibrarySettings],
CirculationInternalFormatsMixin,
Expand Down Expand Up @@ -418,9 +423,8 @@ def __init__(self, _db, collection):

self._hosts = self._determine_hosts(server_nickname=self._server_nickname)

# This is set by an access to .token, or by a call to
# check_creds() or refresh_creds().
self._token = None
# This is set by access to .token
self._token: OverdriveToken | None = None

# This is set by an access to .collection_token
self._collection_token = None
Expand Down Expand Up @@ -455,10 +459,26 @@ def endpoint(self, url: str, **kwargs) -> str:
return url % kwargs

@property
def token(self):
if not self._token:
self.check_creds()
return self._token
def token(self) -> str:
if (token := self._token) is not None and utc_now() < token.expires:
return token.token

return self._refresh_token().token

def _refresh_token(self) -> OverdriveToken:
"""Get an overdrive bearer token."""
with self.lock:
response = self.token_post(
self.TOKEN_ENDPOINT,
dict(grant_type="client_credentials"),
allowed_response_codes=[200],
)
data = response.json()
access_token = data["access_token"]
expires_in = data["expires_in"] * 0.9
expires = utc_now() + datetime.timedelta(seconds=expires_in)
self._token = OverdriveToken(token=access_token, expires=expires)
return self._token

@property
def collection_token(self):
Expand All @@ -468,7 +488,6 @@ def collection_token(self):
credentials are working.
"""
if not self._collection_token:
self.check_creds()
library = self.get_library()
error = library.get("errorCode")
if error:
Expand Down Expand Up @@ -511,42 +530,6 @@ def advantage_library_id(self):
return self.OVERDRIVE_MAIN_ACCOUNT_ID
return int(self._library_id)

def check_creds(self, force_refresh=False):
"""If the Bearer Token has expired, update it."""
with self.lock:
refresh_on_lookup = self.refresh_creds
if force_refresh:
refresh_on_lookup = lambda x: x

credential = self.credential_object(refresh_on_lookup)
if force_refresh:
self.refresh_creds(credential)
self._token = credential.credential

def credential_object(self, refresh):
"""Look up the Credential object that allows us to use
the Overdrive API.
"""
return Credential.lookup(
self._db,
DataSource.OVERDRIVE,
None,
None,
refresh,
collection=self.collection,
)

def refresh_creds(self, credential):
"""Fetch a new Bearer Token and update the given Credential object."""
response = self.token_post(
self.TOKEN_ENDPOINT,
dict(grant_type="client_credentials"),
allowed_response_codes=[200],
)
data = response.json()
self._update_credential(credential, data)
self._token = credential.credential

def get(
self, url: str, extra_headers={}, exception_on_401=False
) -> tuple[int, CaseInsensitiveDict, bytes]:
Expand All @@ -570,8 +553,8 @@ def get(
response,
)
else:
# Refresh the token and try again.
self.check_creds(True)
# Force a refresh of the token and try again.
self._refresh_token()
return self.get(url, extra_headers, True)
else:
return status_code, headers, content
Expand Down Expand Up @@ -834,8 +817,7 @@ def hosts(self) -> dict[str, str]:
def _run_self_tests(self, _db):
result = self.run_test(
"Checking global Client Authentication privileges",
self.check_creds,
force_refresh=True,
self._refresh_token,
)
yield result
if not result.success:
Expand Down Expand Up @@ -1061,6 +1043,8 @@ def _process_checkout_error(self, patron, pin, licensepool, error):
code = error.get("errorCode", code)
if code == "NoCopiesAvailable":
# Clearly our info is out of date.
# TODO: This shouldn't be happening in the web thread, it should really be happening in a
# background job. No need to block the user while we update the license pool.
self.update_licensepool(identifier.identifier)
raise NoAvailableCopies()

Expand Down
Loading

0 comments on commit 3305e94

Please sign in to comment.