From d81b7d08292935c93746edda891bd87e55788789 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon-Martin=20Schr=C3=B6der?= Date: Fri, 18 Oct 2024 23:16:40 +0200 Subject: [PATCH] Archive improvements (#18) - Archive: Path interface & compress_hint - archive.read_tsv: dtype precedence: file header, column defaults, parameter - Default column dtypes - read_tsv: Peek at header, then read accordingly - read_tsv: Detect duplicate column names - Archive.validate --- docs/index.rst | 15 +- src/pyecotaxa/archive.py | 413 +++++++++++++++++++++++++++++---------- tests/test_archive.py | 48 +++-- 3 files changed, 337 insertions(+), 139 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index b80ae8c..c9e90df 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,8 +1,3 @@ -.. pyecotaxa documentation master file, created by - sphinx-quickstart on Fri Jan 29 22:02:29 2021. - You can adapt this file completely to your liking, but it should at least - contain the root `toctree` directive. - pyecotaxa ========= @@ -24,12 +19,4 @@ Reading and writing archives .. toctree:: :maxdepth: 2 :caption: Contents: - - - -.. Indices and tables -.. ================== - -.. * :ref:`genindex` -.. * :ref:`modindex` -.. * :ref:`search` + :hidden: diff --git a/src/pyecotaxa/archive.py b/src/pyecotaxa/archive.py index a65094c..b99316f 100644 --- a/src/pyecotaxa/archive.py +++ b/src/pyecotaxa/archive.py @@ -1,122 +1,201 @@ """Read and write EcoTaxa archives and individual EcoTaxa TSV files.""" +import collections +import csv import fnmatch import io import pathlib +import posixpath +import shutil import tarfile -import warnings import zipfile -from typing import IO, Callable, List, Union +from io import BufferedReader, BytesIO, IOBase +from typing import ( + IO, + Any, + BinaryIO, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) -import numpy as np import pandas as pd +from tqdm.auto import tqdm __all__ = ["read_tsv", "write_tsv"] -def _fix_types(dataframe, enforce_types): - header = dataframe.columns.get_level_values(0) - types = dataframe.columns.get_level_values(1) +DEFAULT_DTYPES = { + "img_file_name": str, + "img_rank": int, + "object_id": str, + "object_link": str, + "object_lat": float, + "object_lon": float, + "object_date": str, + "object_time": str, + "object_annotation_date": str, + "object_annotation_time": str, + "object_annotation_category": str, + "object_annotation_category_id": "Int64", + "object_annotation_person_name": str, + "object_annotation_person_email": str, + "object_annotation_status": str, + "process_id": str, + "acq_id": str, + "sample_id": str, +} - dataframe.columns = header +VALID_PREFIXES = {"object", "sample", "acq", "process", "img"} - float_cols = [] - text_cols = [] - for c, t in zip(header, types): - if t == "[f]": - float_cols.append(c) - elif t == "[t]": - text_cols.append(c) - else: - # If the first row contains other values than [f] or [t], - # it is not a type header but a normal line of values and has to be inserted into the dataframe. - # This is the case for "General export". 
- # Clean up empty fields - types = [None if t.startswith("Unnamed") else t for t in types] +def _parse_tsv_header( + f: IOBase, encoding: str +) -> Tuple[Optional[Sequence[str]], Dict[str, Any], int]: + skiprows = 0 - # Prepend the current "types" to the dataframe - row0 = pd.DataFrame([types], columns=header).astype(dataframe.dtypes) + header: List[str] = [] + while len(header) < 2: + line = f.readline() - if enforce_types: - warnings.warn( - "enforce_types=True, but no type header was found.", stacklevel=3 - ) + if not line: + break + + skiprows += 1 + + if isinstance(line, bytes): + line = line.decode(encoding) - return pd.concat((row0, dataframe), ignore_index=True) + if not line.startswith("#"): + header.append(line) - if enforce_types: - # Enforce [f] types - dataframe[float_cols] = dataframe[float_cols].astype(float) - dataframe[text_cols] = dataframe[text_cols].fillna("").astype(str) + if not header: + return None, {}, 0 - return dataframe + csv_reader = csv.reader(header, delimiter="\t") + names = next(csv_reader) -def _apply_usecols( - df: pd.DataFrame, usecols: Union[Callable, List[str]] -) -> pd.DataFrame: - if callable(usecols): - columns = [c for c in df.columns.get_level_values(0) if usecols(c)] + try: + maybe_types = next(csv_reader) + except StopIteration: + # No second line + return names, {}, skiprows + + if len(names) != len(maybe_types): + raise ValueError("Number of names does not match number of types") + + # Second line *might* contain types + if all(t in ("[t]", "[f]") for t in maybe_types): + # Infer dtype + dtype = {n: str for n, t in zip(names, maybe_types) if t == "[t]"} else: - columns = [c for c in df.columns.get_level_values(0) if c in usecols] + # This wasn't a type row after all + dtype = {} + skiprows -= 1 - return df[columns] + return names, dtype, skiprows def read_tsv( - filepath_or_buffer, + fn_or_f: Union[str, pathlib.Path, IOBase], encoding: str = "utf-8-sig", - enforce_types=False, - usecols: Union[None, Callable, List[str]] = None, + dtype=None, **kwargs, -) -> pd.DataFrame: +): """ - Read an individual EcoTaxa TSV file. + Read a TSV (Tab-Separated Values) file into a pandas DataFrame. + + This function reads a TSV file and processes the type header (if provided) + to ensure the correct dtype for each column. + It supports handling files from both file paths and file-like objects. + + The dtype of each column is determined by the following precedence order: + 1. `dtype` parameter (if provided explicitly). + 2. DEFAULT_DTYPES, containing the correct dtype for well-known columns. + 3. File header (if the file includes a type header). Args: - filepath_or_buffer (str, path object or file-like object): ... - encoding: Encoding of the TSV file. - With the default "utf-8-sig", both UTF8 and signed UTF8 can be read. - enforce_types: Enforce the column dtypes provided in the header. - Usually, it is desirable to allow pandas to infer the column dtypes. - usecols: List of strings or callable. - **kwargs: Additional kwargs are passed to :func:`pandas:pandas.read_csv`. + fn_or_f (str, pathlib.Path, or file-like): + The file path or file-like object to read the TSV from. + encoding (str, optional): + The encoding to use for reading the file. Defaults to "utf-8-sig". + dtype (dict, optional): + A dictionary specifying the data types of columns. Defaults to `None`, + which uses the default types. + **kwargs: + Additional keyword arguments passed to `pandas.read_csv()`. Returns: - A Pandas :class:`~pandas:pandas.DataFrame`. 
+ A pandas DataFrame containing the TSV data. + + Notes: + The function detects duplicate column names and raises an error if found. + It fills NaN values in string columns with empty strings. """ + must_close = False + f: BinaryIO - if usecols is not None: - chunksize = kwargs.pop("chunksize", 10000) - - # Read a few rows a time - dataframe: pd.DataFrame = pd.concat( - [ - _apply_usecols(chunk, usecols) - for chunk in pd.read_csv( - filepath_or_buffer, - sep="\t", - encoding=encoding, - header=[0, 1], - chunksize=chunksize, - **kwargs, - ) - ] - ) # type: ignore + if dtype is None: + dtype = DEFAULT_DTYPES else: - if kwargs.pop("chunksize", None) is not None: - warnings.warn("Parameter chunksize is ignored.") + dtype = {**DEFAULT_DTYPES, **dtype} - dataframe: pd.DataFrame = pd.read_csv( - filepath_or_buffer, sep="\t", encoding=encoding, header=[0, 1], **kwargs - ) # type: ignore + if isinstance(fn_or_f, str): + fn_or_f = pathlib.Path(fn_or_f) - return _fix_types(dataframe, enforce_types) + if hasattr(fn_or_f, "open"): + f = fn_or_f.open("r", encoding=encoding) # type: ignore + must_close = True + else: + f = fn_or_f # type: ignore + + try: + if f.seekable(): + # We can just rewind after inspecting the header + names, header_dtype, skiprows = _parse_tsv_header(f, encoding) + f.seek(0) + else: + # Make sure that we can peek into the file + if not hasattr(f, "peek"): + f = BufferedReader(f) # type: ignore + + # Peek the first 8kb and inspect + header_f = BytesIO(f.peek(8 * 1024)) # type: ignore + names, header_dtype, skiprows = _parse_tsv_header(header_f, encoding) + + dtype = {**header_dtype, **dtype} + + # Detect duplicate names + duplicate_names = [ + f"'{name}' ({count}x)" + for name, count in collections.Counter(names).items() + if count > 1 + ] + if duplicate_names: + raise ValueError( + "TSV file contains duplicate column names: " + + (", ".join(duplicate_names)) + ) + + dataframe = pd.read_csv(f, sep="\t", names=names, dtype=dtype, skiprows=skiprows, **kwargs) # type: ignore + + for c, dt in dataframe.dtypes.items(): + if pd.api.types.is_string_dtype(dt): + dataframe[c] = dataframe[c].fillna("") + + return dataframe + finally: + if must_close: + f.close() def _dtype_to_ecotaxa(dtype): - if np.issubdtype(dtype, np.number): + if pd.api.types.is_numeric_dtype(dtype): return "[f]" return "[t]" @@ -127,35 +206,59 @@ def write_tsv( path_or_buf=None, encoding="utf-8", type_header=True, + formatters: Optional[Mapping] = None, **kwargs, ): """ - Write an individual EcoTaxa TSV file. + Write a pandas DataFrame to a TSV (Tab-Separated Values) file in EcoTaxa format. + + This function writes a DataFrame to a TSV file. Optionally, it includes a type + header that specifies the data types for each column, which is required for + compatibility with EcoTaxa. Args: - dataframe: A pandas DataFrame. - path_or_buf (str, path object or file-like object): ... - encoding: Encoding of the TSV file. - With the default "utf-8", both UTF8 and signed UTF8 readers can read the file. - enforce_types: Enforce the column dtypes provided in the header. - Usually, it is desirable to allow pandas to infer the column dtypes. - type_header (bool, default true): Include the type header ([t]/[f]). - This is required for a successful import into EcoTaxa. - - Return: - None or str - - If path_or_buf is None, returns the resulting csv format as a string. Otherwise returns None. + dataframe (pd.DataFrame): + The pandas DataFrame to be written to the TSV file. 
+ path_or_buf (str, pathlib.Path, file-like, optional): + The file path or file-like object where the TSV will be written. If None, + the function returns the TSV content as a string. Defaults to None. + encoding (str, optional): + The encoding to use for writing the file. Defaults to "utf-8". + type_header (bool, optional): + Whether to include a type header specifying the data types for each column. + Defaults to True. + formatters (Optional[Mapping], optional): + A dictionary specifying formatting functions to apply to columns. + Defaults to None. + **kwargs: + Additional keyword arguments passed to `pandas.DataFrame.to_csv()`. + + Returns: + If `path_or_buf` is provided, the function returns None. If `path_or_buf` + is None, it returns the TSV content as a string. """ - if type_header: - # Make a copy before changing the index - dataframe = dataframe.copy() + if formatters is None: + formatters = {} + dataframe = dataframe.copy(deep=False) + + # Calculate type header before formatting values + ecotaxa_types = [_dtype_to_ecotaxa(dt) for dt in dataframe.dtypes] + + # Apply formatting + for col in dataframe.columns: + fmt = formatters.get(col) + + if fmt is None: + continue + + dataframe[col] = dataframe[col].apply(fmt) + + if type_header: # Inject types into header - type_header = [_dtype_to_ecotaxa(dt) for dt in dataframe.dtypes] dataframe.columns = pd.MultiIndex.from_tuples( - list(zip(dataframe.columns, type_header)) + list(zip(dataframe.columns, ecotaxa_types)) ) return dataframe.to_csv( @@ -186,6 +289,22 @@ def __len__(self): return len(self.tsv_fns) +class ArchivePath: + def __init__(self, archive: "Archive", filename) -> None: + self.archive = archive + self.filename = filename + + def open(self, mode="r", compress_hint=True) -> IO: + return self.archive.open(self.filename, mode, compress_hint) + + def __truediv__(self, filename): + return ArchivePath(self.archive, posixpath.join(self.filename, filename)) + + +class ValidationError(Exception): + pass + + class Archive: """ A generic archive reader and writer for ZIP and TAR archives. @@ -210,6 +329,8 @@ def __new__(cls, archive_fn: Union[str, pathlib.Path], mode: str = "r"): raise UnknownArchiveError(f"No handler found to write {archive_fn}") + raise ValueError("Unknown mode: {mode}") + @staticmethod def is_readable(archive_fn) -> bool: raise NotImplementedError() # pragma: no cover @@ -217,16 +338,14 @@ def is_readable(archive_fn) -> bool: def __init__(self, archive_fn: Union[str, pathlib.Path], mode: str = "r"): raise NotImplementedError() # pragma: no cover - def open(self, member_fn, mode="r") -> IO: + def open(self, member_fn, mode="r", compress_hint=True) -> IO: """ + Open an archive member. + Raises: - MemberNotFoundError if a member was not found + MemberNotFoundError if mode=="r" and the member was not found. 
""" - raise NotImplementedError() # pragma: no cover - def write_member( - self, member_fn, fileobj_or_bytes: Union[IO, bytes], compress_hint=True - ): raise NotImplementedError() # pragma: no cover def find(self, pattern) -> List[str]: @@ -258,6 +377,77 @@ def iter_tsv(self, **kwargs): """ return _TSVIterator(self, self.find("*.tsv"), kwargs) + def __truediv__(self, key): + return ArchivePath(self, key) + + def add_images( + self, df: pd.DataFrame, src: Union[str, "Archive", pathlib.Path], progress=False + ): + """Add images referenced in df from src.""" + + if isinstance(src, str): + src = pathlib.Path(src) + + for img_file_name in tqdm(df["img_file_name"], disable=not progress): + with (src / img_file_name).open() as f_src, self.open( + img_file_name, "w" + ) as f_dst: + shutil.copyfileobj(f_src, f_dst) + + def validate(self): + """Mimic the validation done by EcoTaxa.""" + + # ecotaxa_back/py/BO/Bundle.py:43 + MAX_FILES = 2000 + + tsv: pd.DataFrame + + for i, (tsv_fn, tsv) in enumerate(self.iter_tsv(), start=1): + if i > MAX_FILES: + raise ValidationError( + f"Archive contains too many files, max. is {MAX_FILES}" + ) + + errors = [] + + # Validate columns (validate_structure) + # ecotaxa_back/py/BO/TSVFile.py:873 + for c in tsv.columns: + if c in DEFAULT_DTYPES: + # This is a known field + # TODO: Check dtype + continue + + try: + prefix, name = c.split("_", 1) + except ValueError: + errors.append( + f"Invalid field '{c}', format must be '_'" + ) + continue + + if prefix not in VALID_PREFIXES: + errors.append(f"Invalid prefix '{prefix}' for column '{c}'") + continue + + # Ensure that each used prefix contains at least an ID + for prefix in ["object", "acq", "process", "sample"]: + expected_id = f"{prefix}_id" + prefix_columns = [c for c in tsv.columns if c.startswith(prefix)] + if prefix_columns and expected_id not in tsv.columns: + errors.append( + f"Field {expected_id} is mandatory as there are some '{prefix}' columns: {sorted(prefix_columns)}." + ) + + if errors: + raise ValidationError( + f"Invalid structure in {tsv_fn}:\n" + ("\n".join(errors)) + ) + + # TODO: Validate contents (validate_content) + # ecotaxa_back/py/BO/TSVFile.py:967 + ... 
+ class _TarIO(io.BytesIO): def __init__(self, archive: "TarArchive", member_fn) -> None: @@ -297,7 +487,10 @@ def __init__(self, archive_fn: Union[str, pathlib.Path], mode: str = "r"): def close(self): self._tar.close() - def open(self, member_fn, mode="r") -> IO: + def open(self, member_fn, mode="r", compress_hint=True) -> IO: + # tar does not compress files individually + del compress_hint + if mode == "r": try: fp = self._tar.extractfile(self._resolve_member(member_fn)) @@ -331,6 +524,9 @@ def _resolve_member(self, member): def write_member( self, member_fn: str, fileobj_or_bytes: Union[IO, bytes], compress_hint=True ): + # tar does not compress files individually + del compress_hint + if isinstance(fileobj_or_bytes, bytes): fileobj_or_bytes = io.BytesIO(fileobj_or_bytes) @@ -341,6 +537,7 @@ def write_member( tar_info = self._tar.gettarinfo(arcname=member_fn, fileobj=fileobj_or_bytes) self._tar.addfile(tar_info, fileobj=fileobj_or_bytes) + self._members[tar_info.name] = tar_info def members(self): return self._tar.getnames() @@ -359,9 +556,17 @@ def __init__(self, archive_fn: Union[str, pathlib.Path], mode: str = "r"): def members(self): return self._zip.namelist() - def open(self, member_fn, mode="r") -> IO: + def open(self, member_fn: str, mode="r", compress_hint=True) -> IO: + if mode == "w" and not compress_hint: + # Disable compression + member = zipfile.ZipInfo(member_fn) + member.compress_type = zipfile.ZIP_STORED + else: + # Let ZipFile.open select compression and compression level + member = member_fn + try: - return self._zip.open(member_fn, mode) + return self._zip.open(member, mode) except KeyError as exc: raise MemberNotFoundError( f"{member_fn} not in {self._zip.filename}" diff --git a/tests/test_archive.py b/tests/test_archive.py index e02e3c7..5fa6999 100644 --- a/tests/test_archive.py +++ b/tests/test_archive.py @@ -1,4 +1,3 @@ -import contextlib import io import pathlib import tarfile @@ -13,23 +12,19 @@ from pyecotaxa.archive import Archive, MemberNotFoundError, read_tsv, write_tsv -@pytest.mark.parametrize("enforce_types", [True, False]) @pytest.mark.parametrize("type_header", [True, False]) -def test_read_tsv(enforce_types, type_header): +def test_read_tsv(type_header): if type_header: file_content = "a\tb\tc\td\n[t]\t[f]\t[t]\t[t]\n1\t2.0\ta\t\n3\t4.0\tb\t" else: file_content = "a\tb\tc\td\n1\t2.0\ta\t\n3\t4.0\tb\t" - with contextlib.ExitStack() as ctx: - if enforce_types and not type_header: - ctx.enter_context(pytest.warns(UserWarning)) - dataframe = read_tsv(StringIO(file_content), enforce_types=enforce_types) + dataframe = read_tsv(StringIO(file_content)) assert len(dataframe) == 2 assert list(dataframe.columns) == ["a", "b", "c", "d"] - if type_header and enforce_types: + if type_header: assert [dt.kind for dt in dataframe.dtypes] == ["O", "f", "O", "O"] assert_series_equal(dataframe["d"], pd.Series(["", ""]), check_names=False) else: @@ -39,25 +34,21 @@ def test_read_tsv(enforce_types, type_header): ) -@pytest.mark.parametrize("enforce_types", [True, False]) @pytest.mark.parametrize("type_header", [True, False]) -def test_read_tsv_usecols(enforce_types, type_header): +def test_read_tsv_usecols(type_header): if type_header: file_content = "a\tb\tc\td\n[t]\t[f]\t[t]\t[t]\n1\t2.0\ta\t\n3\t4.0\tb\t" else: file_content = "a\tb\tc\td\n1\t2.0\ta\t\n3\t4.0\tb\t" - with contextlib.ExitStack() as ctx: - if enforce_types and not type_header: - ctx.enter_context(pytest.warns(UserWarning)) - dataframe = read_tsv( - StringIO(file_content), enforce_types=enforce_types, 
usecols=("a", "b") - ) + dataframe = read_tsv( + StringIO(file_content), usecols=("a", "b") + ) assert len(dataframe) == 2 assert list(dataframe.columns) == ["a", "b"] - if type_header and enforce_types: + if type_header: assert [dt.kind for dt in dataframe.dtypes] == ["O", "f"] else: assert [dt.kind for dt in dataframe.dtypes] == ["i", "f"] @@ -66,15 +57,30 @@ def test_read_tsv_usecols(enforce_types, type_header): @pytest.mark.parametrize("type_header", [True, False]) def test_write_tsv(type_header): dataframe = pd.DataFrame( - {"i": [1, 2, 3], "O": ["a", "b", "c"], "f": [1.0, 2.0, 3.0]} + { + "i": [1, 2, 3], + "O": ["a", "b", "c"], + "f": [1.0, 2.0, 3.0], + "object_annotation_category_id": pd.array([4, None, 6], dtype="Int64"), + } ) + dataframe_orig = dataframe.copy(deep=True) content = write_tsv(dataframe, type_header=type_header) + # Check that dataframe was not altered + assert_frame_equal(dataframe, dataframe_orig) + if type_header: - assert content == "i\tO\tf\n[f]\t[t]\t[f]\n1\ta\t1.0\n2\tb\t2.0\n3\tc\t3.0\n" + assert ( + content + == "i\tO\tf\tobject_annotation_category_id\n[f]\t[t]\t[f]\t[f]\n1\ta\t1.0\t4\n2\tb\t2.0\t\n3\tc\t3.0\t6\n" + ) else: - assert content == "i\tO\tf\n1\ta\t1.0\n2\tb\t2.0\n3\tc\t3.0\n" + assert ( + content + == "i\tO\tf\tobject_annotation_category_id\n1\ta\t1.0\t4\n2\tb\t2.0\t\n3\tc\t3.0\t6\n" + ) # Check round tripping dataframe2 = read_tsv(StringIO(content)) @@ -84,7 +90,7 @@ def test_write_tsv(type_header): def test_empty_str_column(): file_content = "a\tb\tc\n[t]\t[f]\t[t]\n\t2.0\ta" - dataframe = read_tsv(StringIO(file_content), enforce_types=True) + dataframe = read_tsv(StringIO(file_content)) assert len(dataframe) == 1 assert [dt.kind for dt in dataframe.dtypes] == ["O", "f", "O"]