diff --git a/src/palace_tools/cli/download_feed.py b/src/palace_tools/cli/download_feed.py index 62d2017..49d7685 100644 --- a/src/palace_tools/cli/download_feed.py +++ b/src/palace_tools/cli/download_feed.py @@ -6,8 +6,7 @@ import typer import xmltodict -from palace_tools.feeds import axis, opds, overdrive -from palace_tools.feeds.opds import write_json +from palace_tools.feeds import axis, opds, opds1, overdrive from palace_tools.utils.typer import run_typer_app_as_main app = typer.Typer() @@ -102,7 +101,23 @@ def download_opds( """Download OPDS 2 feed.""" publications = opds.fetch(url, username, password, authentication) with output_file.open("w") as file: - write_json(file, publications) + opds.write_json(file, publications) + + +@app.command("opds1") +def download_opds1( + username: str = typer.Option(None, "--username", "-u", help="Username"), + password: str = typer.Option(None, "--password", "-p", help="Password"), + authentication: opds.AuthType = typer.Option( + opds.AuthType.NONE, "--auth", "-a", help="Authentication type" + ), + url: str = typer.Argument(..., help="URL of feed", metavar="URL"), + output_file: Path = typer.Argument( + ..., help="Output file", writable=True, file_okay=True, dir_okay=False + ), +) -> None: + """Download OPDS 1.x feed.""" + opds1.fetch(url, username, password, authentication, output_file) def main() -> None: diff --git a/src/palace_tools/feeds/opds.py b/src/palace_tools/feeds/opds.py index e80b96d..5a0ce3e 100644 --- a/src/palace_tools/feeds/opds.py +++ b/src/palace_tools/feeds/opds.py @@ -4,9 +4,9 @@ import math import sys from base64 import b64encode -from collections.abc import Generator, Mapping +from collections.abc import Callable, Generator, Mapping from enum import Enum -from typing import Any, TextIO +from typing import Any, NamedTuple, TextIO import httpx from rich.progress import MofNCompleteColumn, Progress, SpinnerColumn @@ -18,6 +18,12 @@ class AuthType(Enum): NONE = "none" +class OpdsLinkTuple(NamedTuple): + type: str + href: str + rel: str + + class OAuthAuth(httpx.Auth): # Implementation of OPDS auth document OAuth client credentials flow for httpx # See: @@ -26,40 +32,49 @@ class OAuthAuth(httpx.Auth): requires_response_body = True - def __init__(self, username: str, password: str) -> None: + def __init__( + self, + username: str, + password: str, + *, + feed_url: str, + parse_links: Callable[[str], dict[str, OpdsLinkTuple]] | None = None, + ) -> None: self.username = username self.password = password + self.feed_url = feed_url + self.parse_links = parse_links + self.token: str | None = None + self.oauth_url: str | None = None @staticmethod - def _get_oauth_url_from_auth_document( - url: str, auth_document: Mapping[str, Any] - ) -> str: + def _get_oauth_url_from_auth_document(auth_document: Mapping[str, Any]) -> str: auth_types: list[dict[str, Any]] = auth_document.get("authentication", []) - oauth_authentication = [ - tlinks - for t in auth_types - if t.get("type") == "http://opds-spec.org/auth/oauth/client_credentials" - and (tlinks := t.get("links")) is not None - ] - if not oauth_authentication: - print(f"Unable to find supported authentication type ({url})") - print(f"Auth document: {json.dumps(auth_document)}") + try: + [links] = [ + tlinks + for t in auth_types + if t.get("type") == "http://opds-spec.org/auth/oauth/client_credentials" + and (tlinks := t.get("links")) is not None + ] + except (ValueError, TypeError): + print("Unable to find supported authentication type") + print(f"Auth document: {json.dumps(auth_document, indent=2)}") sys.exit(-1) - links = oauth_authentication[0] - auth_links: list[str] = [ - lhref - for l in links - if l.get("rel") == "authenticate" and (lhref := l.get("href")) is not None - ] - if len(auth_links) != 1: - print(f"Unable to find valid authentication link ({url})") - print( - f"Found {len(auth_links)} authentication links. Auth document: {json.dumps(auth_document)}" - ) + try: + [auth_link] = [ + lhref + for l in links + if l.get("rel") == "authenticate" + and (lhref := l.get("href")) is not None + ] + except (ValueError, TypeError): + print("Unable to find valid authentication link") + print(f"Auth document: {json.dumps(auth_document, indent=2)}") sys.exit(-1) - return auth_links[0] + return auth_link # type: ignore[no-any-return] @staticmethod def _oauth_token_request(url: str, username: str, password: str) -> httpx.Request: @@ -70,43 +85,84 @@ def _oauth_token_request(url: str, username: str, password: str) -> httpx.Reques "POST", url, headers=headers, data={"grant_type": "client_credentials"} ) + def refresh_auth_url(self) -> Generator[httpx.Request, httpx.Response, None]: + response = yield httpx.Request("GET", self.feed_url) + if response.status_code == 200 and self.parse_links is not None: + links = self.parse_links(response.text) + auth_doc_url = links.get("http://opds-spec.org/auth/document") + if auth_doc_url is None: + print("No auth document link found") + print(links) + sys.exit(-1) + auth_doc_response = yield httpx.Request("GET", auth_doc_url.href) + if auth_doc_response.status_code != 200: + error_and_exit(auth_doc_response) + elif response.status_code == 401: + auth_doc_response = response + else: + error_and_exit(response) + + if ( + auth_doc_response.headers.get("Content-Type") + != "application/vnd.opds.authentication.v1.0+json" + ): + error_and_exit(auth_doc_response, "Invalid content type") + + self.oauth_url = self._get_oauth_url_from_auth_document( + auth_doc_response.json() + ) + + def refresh_token(self) -> Generator[httpx.Request, httpx.Response, None]: + if self.oauth_url is None: + yield from self.refresh_auth_url() + + # This should never happen, but we assert for sanity and mypy + assert self.oauth_url is not None + + response = yield self._oauth_token_request( + self.oauth_url, self.username, self.password + ) + if response.status_code != 200: + error_and_exit(response) + if (access_token := response.json().get("access_token")) is None: + print("No access token in response") + print(response.text) + sys.exit(-1) + self.token = access_token + def auth_flow( self, request: httpx.Request ) -> Generator[httpx.Request, httpx.Response, None]: - if self.token is not None: - request.headers["Authorization"] = f"Bearer {self.token}" + token_refreshed = False + if self.oauth_url is None or self.token is None: + yield from self.refresh_token() + token_refreshed = True + + # This should never happen, but we assert it for mypy and our sanity + assert self.token is not None + + request.headers["Authorization"] = f"Bearer {self.token}" response = yield request - if ( - response.status_code == 401 - and response.headers.get("Content-Type") - == "application/vnd.opds.authentication.v1.0+json" - ): - oauth_url = self._get_oauth_url_from_auth_document( - str(request.url), response.json() - ) - response = yield self._oauth_token_request( - oauth_url, self.username, self.password - ) - if response.status_code != 200: - print(f"Error: {response.status_code}") - print(response.text) - sys.exit(-1) - if (access_token := response.json().get("access_token")) is None: - print("No access token in response") - print(response.text) - sys.exit(-1) - self.token = access_token + + if response.status_code == 401 and not token_refreshed: + yield from self.refresh_token() request.headers["Authorization"] = f"Bearer {self.token}" yield request +def error_and_exit(response: httpx.Response, detail: str = "") -> None: + print(f"Error: {detail}") + print(f"Request: {response.request.method} {response.request.url}") + print(f"Status code: {response.status_code}") + print(f"Headers: {json.dumps(dict(response.headers), indent=4)}") + print(f"Body: {response.text}") + sys.exit(-1) + + def make_request(session: httpx.Client, url: str) -> dict[str, Any]: response = session.get(url) if response.status_code != 200: - print(f"Error: {response.status_code}") - print(f"Headers: {json.dumps(dict(response.headers), indent=4)}") - print(response.text) - sys.exit(-1) + error_and_exit(response) return response.json() # type: ignore[no-any-return] @@ -132,7 +188,7 @@ def fetch( if auth_type == AuthType.BASIC: client.auth = httpx.BasicAuth(username, password) elif auth_type == AuthType.OAUTH: - client.auth = OAuthAuth(username, password) + client.auth = OAuthAuth(username, password, feed_url=url) elif auth_type != AuthType.NONE: print("Username and password are required for authentication") sys.exit(-1) diff --git a/src/palace_tools/feeds/opds1.py b/src/palace_tools/feeds/opds1.py new file mode 100644 index 0000000..adbe0cd --- /dev/null +++ b/src/palace_tools/feeds/opds1.py @@ -0,0 +1,69 @@ +import sys +from pathlib import Path +from xml.etree import ElementTree + +import httpx +from rich.progress import MofNCompleteColumn, Progress, SpinnerColumn + +from palace_tools.feeds.opds import AuthType, OAuthAuth, OpdsLinkTuple, error_and_exit + + +def parse_links(feed: str) -> dict[str, OpdsLinkTuple]: + feed_element = ElementTree.fromstring(feed) + return { + rel: OpdsLinkTuple(type=link_type, href=href, rel=rel) + for link in feed_element.findall("{http://www.w3.org/2005/Atom}link") + if (rel := link.get("rel")) is not None + and (link_type := link.get("type")) is not None + and (href := link.get("href")) is not None + } + + +def make_request(session: httpx.Client, url: str) -> str: + response = session.get(url) + if response.status_code != 200: + error_and_exit(response) + return response.text + + +def fetch( + url: str, + username: str | None, + password: str | None, + auth_type: AuthType, + output_file: Path, +) -> None: + # Create a session to fetch the documents + client = httpx.Client() + + client.headers.update( + { + "Accept": "application/atom+xml;profile=opds-catalog;kind=acquisition,application/atom+xml;q=0.9,application/xml;q=0.8,*/*;q=0.1", + "User-Agent": "Palace", + } + ) + client.timeout = httpx.Timeout(30.0) + + if username and password: + if auth_type == AuthType.BASIC: + client.auth = httpx.BasicAuth(username, password) + elif auth_type == AuthType.OAUTH: + client.auth = OAuthAuth( + username, password, feed_url=url, parse_links=parse_links + ) + elif auth_type != AuthType.NONE: + print("Username and password are required for authentication") + sys.exit(-1) + + next_url: str | None = url + with output_file.open("w") as file: + with Progress( + SpinnerColumn(), *Progress.get_default_columns(), MofNCompleteColumn() + ) as progress: + download_task = progress.add_task(f"Downloading Feed", total=None) + while next_url is not None: + response = make_request(client, next_url) + file.write(response) + links = parse_links(response) + next_url = links.get("next") and links["next"].href + progress.update(download_task, advance=1)