Skip to content

Commit

Permalink
fetch: Fix typing issues
Browse files (browse the repository at this point in the history)
  • Loading branch information
jbruechert committed Nov 10, 2024
1 parent 064ebb5 commit e957ab3
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
20 changes: 13 additions & 7 deletions src/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datetime import datetime, timezone
from utils import eprint
from zipfile import ZipFile
from typing import Optional
from typing import Optional, Any
from zoneinfo import ZoneInfo

import email.utils
Expand Down Expand Up @@ -57,7 +57,9 @@ def check_feed_timeframe_valid(zip_content: bytes) -> bool:
if "feed_info.txt" not in z.namelist():
return True

feed_timezone = ZoneInfo(get_feed_timezone(z))
tz = get_feed_timezone(z)
assert tz
feed_timezone = ZoneInfo(tz)

with z.open("feed_info.txt", "r") as a:
with io.TextIOWrapper(a) as at:
Expand Down Expand Up @@ -98,7 +100,7 @@ def fetch_source(self, dest_path: Path, source: Source) -> bool:

return self.fetch_source(dest_path, http_source)
case HttpSource():
request_options = {
request_options: dict[str, Any] = {
"verify": not source.options.ignore_tls_errors,
"timeout": 5
}
Expand Down Expand Up @@ -168,14 +170,15 @@ def fetch_source(self, dest_path: Path, source: Source) -> bool:
last_modified_server = email.utils.parsedate_to_datetime(
server_headers["last-modified"])

content: bytes
if "#" in download_url:
# if URL contains #, treat the path after # as an embedded ZIP file
sub_path = download_url.partition("#")[2]
zipfile = ZipFile(io.BytesIO(response.content))

content: bytes = zipfile.read(sub_path)
content = zipfile.read(sub_path)
else:
content: bytes = response.content
content = response.content

# Only write file if the new version changed. Helps to at least
# skip postprocessing with servers that don't send a
Expand Down Expand Up @@ -262,12 +265,14 @@ def fetch(self, metadata: Path):
# Resolve transitland sources to http / url sources
match source:
case TransitlandSource():
source = self.transitland_atlas.source_by_id(source)
http_source = self.transitland_atlas.source_by_id(source)

# Transitland source type that we cannot handle
if not source:
if not http_source:
continue

source = http_source

validate_source_name(source.name)
download_name = f"{region_name}_{source.name}"

Expand Down Expand Up @@ -312,6 +317,7 @@ def fetch(self, metadata: Path):

return errors


if __name__ == "__main__":
fetcher = Fetcher()

Expand Down
14 changes: 8 additions & 6 deletions src/transitland.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ def source_by_id(self, source: TransitlandSource) -> Union[Source, None]:
result.drop_too_fast_trips = source.drop_too_fast_trips
result.function = source.function
result.drop_shapes = source.drop_shapes

if source.url_override:
result.url_override = source.url_override

if source.proxy:
result.url_override = "https://gtfsproxy.fwan.it/" + \
source.transitland_atlas_id

elif "realtime_trip_updates" in feed["urls"]:
result = UrlSource()
result.name = source.name
Expand All @@ -58,10 +66,4 @@ def source_by_id(self, source: TransitlandSource) -> Union[Source, None]:
if "url" in feed["license"]:
result.license.url = feed["license"]["url"]

if source.url_override:
result.url_override = source.url_override

if source.proxy:
result.url_override = "https://gtfsproxy.fwan.it/" + source.transitland_atlas_id

return result

0 comments on commit e957ab3

Please sign in to comment.