Skip to content

Commit

Permalink
Add OPDS2 reaper as bin script.
Browse files Browse the repository at this point in the history
  • Loading branch information
tdilauro committed Sep 9, 2024
1 parent 81b7ce9 commit dcbb226
Showing 1 changed file with 190 additions and 0 deletions.
190 changes: 190 additions & 0 deletions bin/opds2_reaper_monitor
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
#!/usr/bin/env python
"""Remove availability of items no longer present in OPDS 2.0 import collections."""
import json
from typing import Any

from webpub_manifest_parser.opds2 import OPDS2FeedParserFactory

from palace.manager.core.coverage import CoverageFailure
from palace.manager.core.metadata_layer import TimestampData
from palace.manager.core.opds2_import import (
OPDS2API,
OPDS2Importer,
OPDS2ImportMonitor,
RWPMManifestParser,
)
from palace.manager.scripts.input import CollectionInputScript
from palace.manager.sqlalchemy.model.collection import Collection
from palace.manager.sqlalchemy.model.edition import Edition
from palace.manager.sqlalchemy.model.identifier import Identifier
from palace.manager.sqlalchemy.model.licensing import LicensePool


def main():
reaper_script = OPDS2ReaperScript(
importer_class=OPDS2Importer,
monitor_class=OPDS2ReaperMonitor,
protocol=OPDS2Importer.NAME,
parser=RWPMManifestParser(OPDS2FeedParserFactory()),
)

reaper_script.run()


class OPDS2ReaperMonitor(OPDS2ImportMonitor):
"""Monitor to make unavailable any license pools without a matching identifier in the feed."""

def __init__(
self,
*args: Any,
dry_run: bool = False,
**import_class_kwargs: Any,
) -> None:
self.seen_identifiers: set[str] = set()
self.missing_id_count = 0
self.publication_count = 0
self.dry_run = dry_run
super().__init__(*args, **import_class_kwargs)

def feed_contains_new_data(self, feed: bytes | str) -> bool:
# Always return True so that the reaper will crawl the entire feed.
return True

def import_one_feed(
self, feed: bytes | str
) -> tuple[list[Edition], dict[str, list[CoverageFailure]]]:
# Collect all the identifiers in the given feed page.
feed_obj = json.loads(feed)
publications: list[dict[str, Any]] = feed_obj["publications"]
identifiers = list(
filter(
None,
(pub.get("metadata", {}).get("identifier") for pub in publications),
)
)

self.publication_count += len(publications)
self.missing_id_count += len(publications) - len(identifiers)
self.seen_identifiers.update(identifiers)

# No editions / coverage failures, since we're just reaping.
return [], {}

def run_once(self, progress: TimestampData) -> TimestampData:
"""Check to see if any identifiers we know about are no longer
present on the remote. If there are any, remove them.
:param progress: A TimestampData, ignored.
"""
super().run_once(progress)

# Convert feed identifiers to our identifiers, so we can find them.
# Unlike the import case, we don't want to create identifiers, if
# they don't already exist.
identifiers, failures = Identifier.parse_urns(
self._db, self.seen_identifiers, autocreate=False
)
identifier_ids = [x.id for x in list(identifiers.values())]

collection_license_pools_qu = self._db.query(LicensePool).filter(
LicensePool.collection_id == self.collection.id
)
collection_license_pools = collection_license_pools_qu.count()

unlimited_access_license_pools_qu = collection_license_pools_qu.filter(
LicensePool.licenses_available == LicensePool.UNLIMITED_ACCESS
)
unlimited_access_license_pools = unlimited_access_license_pools_qu.count()

# At this point we've gone through the feed and collected all the identifiers.
# If there's anything we didn't see, we know it's no longer available.
to_be_reaped_qu = unlimited_access_license_pools_qu.join(Identifier).filter(
~Identifier.id.in_(identifier_ids)
)
reap_count = to_be_reaped_qu.count()
self.log.info(
f"Reaping {reap_count} of {unlimited_access_license_pools} unlimited (of {collection_license_pools} total) license pools from collection '{self.collection.name}'. "
f"Feed contained {self.publication_count} publication entries, {len(self.seen_identifiers)} unique identifiers, {self.missing_id_count} missing identifiers. "
f"Unable to parse {len(failures)} of {len(self.seen_identifiers)} identifiers."
)

if self.dry_run:
# TODO: Need to prevent timestamp update for dry runs.
self.log.info(
"Dry run. No license pools were reaped, but {} were eligible."
)
else:
for pool in to_be_reaped_qu:
pool.unlimited_access = False

achievements = f"License pools removed: {reap_count}. Failures parsing identifiers from feed: {len(failures)}."
return TimestampData(achievements=achievements)


class OPDS2ReaperScript(CollectionInputScript):
"""Import all books from the OPDS feed associated with a collection."""

name = "Reap books from a collection, if not present in its associate feed."

IMPORTER_CLASS = OPDS2Importer
MONITOR_CLASS: type[OPDS2ReaperMonitor] = OPDS2ReaperMonitor
PROTOCOL = OPDS2API.label()

def __init__(
self,
_db=None,
importer_class=None,
monitor_class=None,
protocol=None,
*args,
**kwargs,
):
super().__init__(_db, *args, **kwargs)
self.importer_class = importer_class or self.IMPORTER_CLASS
self.monitor_class = monitor_class or self.MONITOR_CLASS
self.protocol = protocol or self.PROTOCOL
self.importer_kwargs = kwargs

@classmethod
def arg_parser(cls):
parser = super().arg_parser()
parser.add_argument(
"--dry-run",
"-n",
help="Don't actually reap any books. Just report the statistics.",
dest="dry_run",
action="store_true",
)
parser.add_argument(
"--all-collections-for-protocol",
"-a",
help="Use all collections with associate protocol(self.protocol), if no collections specified..",
dest="all_protocol_collections",
action="store_true",
)
return parser

def do_run(self, cmd_args=None):
parsed = self.parse_command_line(self._db, cmd_args=cmd_args)
collections: list[Collection] = parsed.collections
if not collections and parsed.all_protocol_collections:
collections = list(Collection.by_protocol(self._db, self.protocol))
for collection in collections:
self.run_monitor(
collection,
dry_run=parsed.dry_run,
)

def run_monitor(self, collection, *, dry_run=False):
monitor = self.monitor_class(
self._db,
collection,
import_class=self.importer_class,
dry_run=dry_run,
**self.importer_kwargs,
)
monitor.run()


if __name__ == "__main__":
main()

0 comments on commit dcbb226

Please sign in to comment.