From e957ab3de81f9e7ae991f8624417949065deb305 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonah=20Br=C3=BCchert?= Date: Sun, 10 Nov 2024 22:51:15 +0100 Subject: [PATCH] fetch: Fix typing issues --- src/fetch.py | 20 +++++++++++++------- src/transitland.py | 14 ++++++++------ 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/fetch.py b/src/fetch.py index bbba2f64..01e29bb4 100755 --- a/src/fetch.py +++ b/src/fetch.py @@ -8,7 +8,7 @@ from datetime import datetime, timezone from utils import eprint from zipfile import ZipFile -from typing import Optional +from typing import Optional, Any from zoneinfo import ZoneInfo import email.utils @@ -57,7 +57,9 @@ def check_feed_timeframe_valid(zip_content: bytes) -> bool: if "feed_info.txt" not in z.namelist(): return True - feed_timezone = ZoneInfo(get_feed_timezone(z)) + tz = get_feed_timezone(z) + assert tz + feed_timezone = ZoneInfo(tz) with z.open("feed_info.txt", "r") as a: with io.TextIOWrapper(a) as at: @@ -98,7 +100,7 @@ def fetch_source(self, dest_path: Path, source: Source) -> bool: return self.fetch_source(dest_path, http_source) case HttpSource(): - request_options = { + request_options: dict[str, Any] = { "verify": not source.options.ignore_tls_errors, "timeout": 5 } @@ -168,14 +170,15 @@ def fetch_source(self, dest_path: Path, source: Source) -> bool: last_modified_server = email.utils.parsedate_to_datetime( server_headers["last-modified"]) + content: bytes if "#" in download_url: # if URL contains #, treat the path after # as an embedded ZIP file sub_path = download_url.partition("#")[2] zipfile = ZipFile(io.BytesIO(response.content)) - content: bytes = zipfile.read(sub_path) + content = zipfile.read(sub_path) else: - content: bytes = response.content + content = response.content # Only write file if the new version changed. Helps to at least # skip postprocessing with servers that don't send a @@ -262,12 +265,14 @@ def fetch(self, metadata: Path): # Resolve transitland sources to http / url sources match source: case TransitlandSource(): - source = self.transitland_atlas.source_by_id(source) + http_source = self.transitland_atlas.source_by_id(source) # Transitland source type that we cannot handle - if not source: + if not http_source: continue + source = http_source + validate_source_name(source.name) download_name = f"{region_name}_{source.name}" @@ -312,6 +317,7 @@ def fetch(self, metadata: Path): return errors + if __name__ == "__main__": fetcher = Fetcher() diff --git a/src/transitland.py b/src/transitland.py index 24b8dde8..0ea51132 100644 --- a/src/transitland.py +++ b/src/transitland.py @@ -39,6 +39,14 @@ def source_by_id(self, source: TransitlandSource) -> Union[Source, None]: result.drop_too_fast_trips = source.drop_too_fast_trips result.function = source.function result.drop_shapes = source.drop_shapes + + if source.url_override: + result.url_override = source.url_override + + if source.proxy: + result.url_override = "https://gtfsproxy.fwan.it/" + \ + source.transitland_atlas_id + elif "realtime_trip_updates" in feed["urls"]: result = UrlSource() result.name = source.name @@ -58,10 +66,4 @@ def source_by_id(self, source: TransitlandSource) -> Union[Source, None]: if "url" in feed["license"]: result.license.url = feed["license"]["url"] - if source.url_override: - result.url_override = source.url_override - - if source.proxy: - result.url_override = "https://gtfsproxy.fwan.it/" + source.transitland_atlas_id - return result