-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
190 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
#!/usr/bin/env python | ||
"""Remove availability of items no longer present in OPDS 2.0 import collections.""" | ||
import json | ||
from typing import Any | ||
|
||
from webpub_manifest_parser.opds2 import OPDS2FeedParserFactory | ||
|
||
from palace.manager.core.coverage import CoverageFailure | ||
from palace.manager.core.metadata_layer import TimestampData | ||
from palace.manager.core.opds2_import import ( | ||
OPDS2API, | ||
OPDS2Importer, | ||
OPDS2ImportMonitor, | ||
RWPMManifestParser, | ||
) | ||
from palace.manager.scripts.input import CollectionInputScript | ||
from palace.manager.sqlalchemy.model.collection import Collection | ||
from palace.manager.sqlalchemy.model.edition import Edition | ||
from palace.manager.sqlalchemy.model.identifier import Identifier | ||
from palace.manager.sqlalchemy.model.licensing import LicensePool | ||
|
||
|
||
def main():
    """Build the OPDS 2.0 reaper script and run it.

    The script is wired up with the stock OPDS 2.0 importer, the reaper
    monitor defined in this module, and an RWPM manifest parser backed by
    the OPDS 2.0 feed parser factory.
    """
    manifest_parser = RWPMManifestParser(OPDS2FeedParserFactory())

    script = OPDS2ReaperScript(
        importer_class=OPDS2Importer,
        monitor_class=OPDS2ReaperMonitor,
        protocol=OPDS2Importer.NAME,
        parser=manifest_parser,
    )
    script.run()
|
||
|
||
class OPDS2ReaperMonitor(OPDS2ImportMonitor):
    """Monitor to make unavailable any license pools without a matching identifier in the feed.

    Instead of importing publications, each feed page is only scanned for
    publication identifiers. Once the parent monitor has finished its run,
    every unlimited-access license pool in the collection whose identifier
    was not seen in the feed is marked as no longer unlimited-access.
    """

    def __init__(
        self,
        *args: Any,
        dry_run: bool = False,
        **import_class_kwargs: Any,
    ) -> None:
        """Set up the reaper's bookkeeping, then defer to the import monitor.

        :param args: Positional arguments passed through to the parent monitor.
        :param dry_run: When True, report statistics but do not modify any
            license pools.
        :param import_class_kwargs: Keyword arguments passed through to the
            parent monitor.
        """
        # Identifiers (as found in the feed) collected across all feed pages.
        self.seen_identifiers: set[str] = set()
        # Number of publication entries that carried no usable identifier.
        self.missing_id_count = 0
        # Total number of publication entries seen in the feed.
        self.publication_count = 0
        self.dry_run = dry_run
        # State is initialized before the parent constructor runs, so it is
        # already in place if the parent touches any overridden method.
        super().__init__(*args, **import_class_kwargs)

    def feed_contains_new_data(self, feed: bytes | str) -> bool:
        """Report that every feed page contains new data.

        :param feed: The raw feed page; ignored.
        :return: Always True.
        """
        # Always return True so that the reaper will crawl the entire feed.
        return True

    def import_one_feed(
        self, feed: bytes | str
    ) -> tuple[list[Edition], dict[str, list[CoverageFailure]]]:
        """Record the identifiers found on one feed page; import nothing.

        :param feed: A single feed page as JSON text. It must contain a
            top-level ``publications`` list.
        :return: An empty edition list and an empty failure mapping.
        :raises KeyError: If the page has no ``publications`` key. This is
            deliberately not softened: silently treating a malformed page as
            empty could cause titles that are still available to be reaped.
        """
        # Collect all the identifiers in the given feed page.
        feed_obj = json.loads(feed)
        publications: list[dict[str, Any]] = feed_obj["publications"]

        # Keep only truthy identifiers; entries without a metadata block or
        # with a missing/empty identifier are counted as "missing" below.
        identifiers = [
            identifier
            for pub in publications
            if (identifier := pub.get("metadata", {}).get("identifier"))
        ]

        self.publication_count += len(publications)
        self.missing_id_count += len(publications) - len(identifiers)
        self.seen_identifiers.update(identifiers)

        # No editions / coverage failures, since we're just reaping.
        return [], {}

    def run_once(self, progress: TimestampData) -> TimestampData:
        """Check to see if any identifiers we know about are no longer
        present on the remote. If there are any, remove them.

        :param progress: A TimestampData, passed to the parent monitor and
            otherwise ignored.
        :return: A TimestampData whose ``achievements`` summarizes how many
            license pools were reaped (or, for a dry run, how many would
            have been) and how many feed identifiers could not be parsed.
        """
        # NOTE(review): the parent run is presumably what walks the feed and
        # calls import_one_feed for each page -- confirm against
        # OPDS2ImportMonitor.
        super().run_once(progress)

        # Convert feed identifiers to our identifiers, so we can find them.
        # Unlike the import case, we don't want to create identifiers, if
        # they don't already exist.
        identifiers, failures = Identifier.parse_urns(
            self._db, self.seen_identifiers, autocreate=False
        )
        identifier_ids = [identifier.id for identifier in identifiers.values()]

        collection_license_pools_qu = self._db.query(LicensePool).filter(
            LicensePool.collection_id == self.collection.id
        )
        collection_license_pools = collection_license_pools_qu.count()

        # Only unlimited-access pools are candidates for reaping.
        unlimited_access_license_pools_qu = collection_license_pools_qu.filter(
            LicensePool.licenses_available == LicensePool.UNLIMITED_ACCESS
        )
        unlimited_access_license_pools = unlimited_access_license_pools_qu.count()

        # At this point we've gone through the feed and collected all the identifiers.
        # If there's anything we didn't see, we know it's no longer available.
        to_be_reaped_qu = unlimited_access_license_pools_qu.join(Identifier).filter(
            ~Identifier.id.in_(identifier_ids)
        )
        reap_count = to_be_reaped_qu.count()
        self.log.info(
            f"Reaping {reap_count} of {unlimited_access_license_pools} unlimited (of {collection_license_pools} total) license pools from collection '{self.collection.name}'. "
            f"Feed contained {self.publication_count} publication entries, {len(self.seen_identifiers)} unique identifiers, {self.missing_id_count} missing identifiers. "
            f"Unable to parse {len(failures)} of {len(self.seen_identifiers)} identifiers."
        )

        if self.dry_run:
            # TODO: Need to prevent timestamp update for dry runs.
            self.log.info(
                f"Dry run. No license pools were reaped, but {reap_count} were eligible."
            )
            # Don't claim removals that never happened.
            achievements = f"Dry run: {reap_count} license pools would have been removed. Failures parsing identifiers from feed: {len(failures)}."
        else:
            for pool in to_be_reaped_qu:
                pool.unlimited_access = False
            achievements = f"License pools removed: {reap_count}. Failures parsing identifiers from feed: {len(failures)}."

        return TimestampData(achievements=achievements)
|
||
|
||
class OPDS2ReaperScript(CollectionInputScript):
    """Run the reaper monitor over one or more OPDS 2.0 collections.

    For each selected collection a monitor is built and run, so that books
    no longer present in the collection's feed can be reaped. Collections
    come from the command line, or -- when none are given and
    ``--all-collections-for-protocol`` is set -- from every collection that
    uses this script's protocol.
    """

    name = "Reap books from a collection, if not present in its associate feed."

    # Defaults used when the constructor is not given explicit overrides.
    IMPORTER_CLASS = OPDS2Importer
    MONITOR_CLASS: type[OPDS2ReaperMonitor] = OPDS2ReaperMonitor
    PROTOCOL = OPDS2API.label()

    def __init__(
        self,
        _db=None,
        importer_class=None,
        monitor_class=None,
        protocol=None,
        *args,
        **kwargs,
    ):
        """Store the importer/monitor/protocol choices for later runs.

        :param _db: Database session, handed to the parent script.
        :param importer_class: Importer class given to each monitor;
            falls back to ``IMPORTER_CLASS`` when falsy.
        :param monitor_class: Monitor class to instantiate per collection;
            falls back to ``MONITOR_CLASS`` when falsy.
        :param protocol: Protocol used to look up collections; falls back
            to ``PROTOCOL`` when falsy.
        :param args: Extra positional arguments for the parent script.
        :param kwargs: Extra keyword arguments; passed to the parent script
            and also forwarded to every monitor that gets created.
        """
        super().__init__(_db, *args, **kwargs)
        self.importer_class = (
            importer_class if importer_class else self.IMPORTER_CLASS
        )
        self.monitor_class = monitor_class if monitor_class else self.MONITOR_CLASS
        self.protocol = protocol if protocol else self.PROTOCOL
        self.importer_kwargs = kwargs

    @classmethod
    def arg_parser(cls):
        """Extend the parent argument parser with the reaper's two flags.

        :return: The parser, with ``--dry-run`` and
            ``--all-collections-for-protocol`` added as boolean switches.
        """
        parser = super().arg_parser()
        # (option strings, help text, destination attribute)
        boolean_switches = (
            (
                ("--dry-run", "-n"),
                "Don't actually reap any books. Just report the statistics.",
                "dry_run",
            ),
            (
                ("--all-collections-for-protocol", "-a"),
                "Use all collections with associate protocol(self.protocol), if no collections specified..",
                "all_protocol_collections",
            ),
        )
        for option_strings, help_text, destination in boolean_switches:
            parser.add_argument(
                *option_strings,
                help=help_text,
                dest=destination,
                action="store_true",
            )
        return parser

    def do_run(self, cmd_args=None):
        """Parse the command line and run the monitor on each collection.

        :param cmd_args: Command-line arguments to parse, passed through to
            ``parse_command_line``.
        """
        parsed = self.parse_command_line(self._db, cmd_args=cmd_args)
        targets: list[Collection] = parsed.collections
        if not targets:
            # Nothing was named explicitly; optionally fall back to every
            # collection using this script's protocol.
            if parsed.all_protocol_collections:
                targets = list(Collection.by_protocol(self._db, self.protocol))
        for target in targets:
            self.run_monitor(target, dry_run=parsed.dry_run)

    def run_monitor(self, collection, *, dry_run=False):
        """Build a monitor for ``collection`` and run it.

        :param collection: The collection to reap.
        :param dry_run: Passed to the monitor as its ``dry_run`` argument.
        """
        self.monitor_class(
            self._db,
            collection,
            import_class=self.importer_class,
            dry_run=dry_run,
            **self.importer_kwargs,
        ).run()
|
||
|
||
# Only run the reaper when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()