#!/usr/bin/env python
"""Remove availability of items no longer present in OPDS 2.0 import collections."""
import json
from typing import Any

from webpub_manifest_parser.opds2 import OPDS2FeedParserFactory

from palace.manager.core.coverage import CoverageFailure
from palace.manager.core.metadata_layer import TimestampData
from palace.manager.core.opds2_import import (
    OPDS2API,
    OPDS2Importer,
    OPDS2ImportMonitor,
    RWPMManifestParser,
)
from palace.manager.scripts.input import CollectionInputScript
from palace.manager.sqlalchemy.model.collection import Collection
from palace.manager.sqlalchemy.model.edition import Edition
from palace.manager.sqlalchemy.model.identifier import Identifier
from palace.manager.sqlalchemy.model.licensing import LicensePool


def main() -> None:
    """Entry point: build and run the OPDS 2.0 reaper script."""
    reaper_script = OPDS2ReaperScript(
        importer_class=OPDS2Importer,
        monitor_class=OPDS2ReaperMonitor,
        protocol=OPDS2Importer.NAME,
        parser=RWPMManifestParser(OPDS2FeedParserFactory()),
    )

    reaper_script.run()


class OPDS2ReaperMonitor(OPDS2ImportMonitor):
    """Monitor to make unavailable any license pools without a matching identifier in the feed."""

    def __init__(
        self,
        *args: Any,
        dry_run: bool = False,
        **import_class_kwargs: Any,
    ) -> None:
        """Track feed statistics while crawling.

        :param dry_run: If True, report what would be reaped without
            actually changing any license pools.
        """
        # Every identifier string seen in the feed; pools whose identifier
        # is NOT in this set at the end of the crawl are eligible for reaping.
        self.seen_identifiers: set[str] = set()
        # Publications in the feed that carried no identifier at all.
        self.missing_id_count = 0
        # Total publication entries encountered across all feed pages.
        self.publication_count = 0
        self.dry_run = dry_run
        super().__init__(*args, **import_class_kwargs)

    def feed_contains_new_data(self, feed: bytes | str) -> bool:
        # Always return True so that the reaper will crawl the entire feed.
        return True

    def import_one_feed(
        self, feed: bytes | str
    ) -> tuple[list[Edition], dict[str, list[CoverageFailure]]]:
        """Collect all the identifiers in the given feed page.

        Unlike a real import, this only records which identifiers are
        present; it never creates or updates editions.
        """
        feed_obj = json.loads(feed)
        publications: list[dict[str, Any]] = feed_obj["publications"]
        # Publications without a metadata identifier are filtered out here
        # and counted separately below.
        identifiers = list(
            filter(
                None,
                (pub.get("metadata", {}).get("identifier") for pub in publications),
            )
        )

        self.publication_count += len(publications)
        self.missing_id_count += len(publications) - len(identifiers)
        self.seen_identifiers.update(identifiers)

        # No editions / coverage failures, since we're just reaping.
        return [], {}

    def run_once(self, progress: TimestampData) -> TimestampData:
        """Check to see if any identifiers we know about are no longer
        present on the remote. If there are any, remove them.

        :param progress: A TimestampData, ignored.
        """
        # Crawl the entire feed, populating self.seen_identifiers.
        super().run_once(progress)

        # Convert feed identifiers to our identifiers, so we can find them.
        # Unlike the import case, we don't want to create identifiers, if
        # they don't already exist.
        identifiers, failures = Identifier.parse_urns(
            self._db, self.seen_identifiers, autocreate=False
        )
        identifier_ids = [x.id for x in identifiers.values()]

        collection_license_pools_qu = self._db.query(LicensePool).filter(
            LicensePool.collection_id == self.collection.id
        )
        collection_license_pools = collection_license_pools_qu.count()

        # Only unlimited-access pools are candidates: those are the ones
        # whose availability is derived purely from feed presence.
        unlimited_access_license_pools_qu = collection_license_pools_qu.filter(
            LicensePool.licenses_available == LicensePool.UNLIMITED_ACCESS
        )
        unlimited_access_license_pools = unlimited_access_license_pools_qu.count()

        # At this point we've gone through the feed and collected all the identifiers.
        # If there's anything we didn't see, we know it's no longer available.
        to_be_reaped_qu = unlimited_access_license_pools_qu.join(Identifier).filter(
            ~Identifier.id.in_(identifier_ids)
        )
        reap_count = to_be_reaped_qu.count()
        self.log.info(
            f"Reaping {reap_count} of {unlimited_access_license_pools} unlimited (of {collection_license_pools} total) license pools from collection '{self.collection.name}'. "
            f"Feed contained {self.publication_count} publication entries, {len(self.seen_identifiers)} unique identifiers, {self.missing_id_count} missing identifiers. "
            f"Unable to parse {len(failures)} of {len(self.seen_identifiers)} identifiers."
        )

        if self.dry_run:
            # TODO: Need to prevent timestamp update for dry runs.
            # Fixed: the original message had an unfilled "{}" placeholder;
            # report the actual number of eligible pools.
            self.log.info(
                f"Dry run. No license pools were reaped, but {reap_count} were eligible."
            )
        else:
            for pool in to_be_reaped_qu:
                pool.unlimited_access = False

        achievements = f"License pools removed: {reap_count}. Failures parsing identifiers from feed: {len(failures)}."
        return TimestampData(achievements=achievements)


class OPDS2ReaperScript(CollectionInputScript):
    """Import all books from the OPDS feed associated with a collection."""

    name = "Reap books from a collection, if not present in its associated feed."

    IMPORTER_CLASS = OPDS2Importer
    MONITOR_CLASS: type[OPDS2ReaperMonitor] = OPDS2ReaperMonitor
    PROTOCOL = OPDS2API.label()

    def __init__(
        self,
        _db=None,
        importer_class=None,
        monitor_class=None,
        protocol=None,
        *args,
        **kwargs,
    ):
        """Configure which importer/monitor/protocol to use, falling back
        to the class-level defaults when not supplied.

        Remaining keyword arguments are forwarded to the monitor when it
        is constructed in run_monitor().
        """
        super().__init__(_db, *args, **kwargs)
        self.importer_class = importer_class or self.IMPORTER_CLASS
        self.monitor_class = monitor_class or self.MONITOR_CLASS
        self.protocol = protocol or self.PROTOCOL
        self.importer_kwargs = kwargs

    @classmethod
    def arg_parser(cls):
        """Extend the base collection-input parser with reaper options."""
        parser = super().arg_parser()
        parser.add_argument(
            "--dry-run",
            "-n",
            help="Don't actually reap any books. Just report the statistics.",
            dest="dry_run",
            action="store_true",
        )
        parser.add_argument(
            "--all-collections-for-protocol",
            "-a",
            help="Use all collections with the associated protocol, if no collections are specified.",
            dest="all_protocol_collections",
            action="store_true",
        )
        return parser

    def do_run(self, cmd_args=None) -> None:
        """Run the reaper monitor once for each selected collection.

        Collections come from the command line; if none were given and
        --all-collections-for-protocol was passed, every collection with
        this script's protocol is used.
        """
        parsed = self.parse_command_line(self._db, cmd_args=cmd_args)
        collections: list[Collection] = parsed.collections
        if not collections and parsed.all_protocol_collections:
            collections = list(Collection.by_protocol(self._db, self.protocol))
        for collection in collections:
            self.run_monitor(
                collection,
                dry_run=parsed.dry_run,
            )

    def run_monitor(self, collection: Collection, *, dry_run: bool = False) -> None:
        """Construct and run the reaper monitor for one collection."""
        monitor = self.monitor_class(
            self._db,
            collection,
            import_class=self.importer_class,
            dry_run=dry_run,
            **self.importer_kwargs,
        )
        monitor.run()


if __name__ == "__main__":
    main()