diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index aef58f9e..3811f55f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,13 +24,13 @@ repos: - id: end-of-file-fixer - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.1.5" + rev: "v0.1.6" hooks: - id: ruff args: ["--fix", "--show-fixes"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.6.1 + rev: v1.7.1 hooks: - id: mypy files: '^src/decaylanguage/(decay|dec|utils)/' diff --git a/src/decaylanguage/dec/dec.py b/src/decaylanguage/dec/dec.py index ca6fef3f..c2bfebef 100644 --- a/src/decaylanguage/dec/dec.py +++ b/src/decaylanguage/dec/dec.py @@ -55,7 +55,7 @@ from .. import data from .._compat.typing import Self -from ..decay.decay import _expand_decay_modes +from ..decay.decay import DecayModeDict, _expand_decay_modes from ..utils import charge_conjugate_name from .enums import PhotosEnum @@ -747,7 +747,7 @@ def list_decay_modes(self, mother: str, pdg_name: bool = False) -> list[list[str def _decay_mode_details( self, decay_mode: Tree, display_photos_keyword: bool = True - ) -> tuple[float, list[str], str, str | list[str | Any]]: + ) -> DecayModeDict: """ Parse a decay mode (Tree instance) and return the relevant bits of information in it. @@ -768,7 +768,9 @@ def _decay_mode_details( if display_photos_keyword and list(decay_mode.find_data("photos")): model = "PHOTOS " + model - return (bf, fsp_names, model, model_params) + return DecayModeDict( + bf=bf, fs=fsp_names, model=model, model_params=model_params + ) def print_decay_modes( self, @@ -864,11 +866,13 @@ def print_decay_modes( ls_dict = {} for dm in dms: - bf, fsp_names, model, model_params = self._decay_mode_details( - dm, display_photos_keyword + dmdict = self._decay_mode_details(dm, display_photos_keyword) + model_params = [str(i) for i in dmdict["model_params"]] + ls_dict[dmdict["bf"]] = ( + dmdict["fs"], + dmdict["model"], + model_params, ) - model_params = [str(i) for i in model_params] - ls_dict[bf] = (fsp_names, model, model_params) dec_details = list(ls_dict.values()) ls_attrs_aligned = list( @@ -937,7 +941,7 @@ def build_decay_chains( self, mother: str, stable_particles: list[str] | set[str] | tuple[str] | tuple[()] = (), - ) -> dict[str, list[dict[str, float | str | list[Any]]]]: + ) -> dict[str, list[DecayModeDict]]: """ Iteratively build the entire decay chains of a given mother particle, optionally considering, on the fly, certain particles as stable. @@ -992,14 +996,12 @@ def build_decay_chains( >>> p.build_decay_chains('D+', stable_particles=['pi0']) # doctest: +SKIP {'D+': [{'bf': 1.0, 'fs': ['K-', 'pi+', 'pi+', 'pi0'], 'model': 'PHSP', 'model_params': ''}]} """ - keys = ("bf", "fs", "model", "model_params") info = [] for dm in self._find_decay_modes(mother): - list_dm_details = self._decay_mode_details(dm, display_photos_keyword=False) - d = dict(zip(keys, list_dm_details)) + d = self._decay_mode_details(dm, display_photos_keyword=False) - for i, fs in enumerate(d["fs"]): # type: ignore[arg-type, var-annotated] + for i, fs in enumerate(d["fs"]): if fs in stable_particles: continue @@ -1008,6 +1010,7 @@ def build_decay_chains( # if fs does not have decays defined in the parsed file # _n_dms = len(self._find_decay_modes(fs)) + assert isinstance(fs, str) _info = self.build_decay_chains(fs, stable_particles) d["fs"][i] = _info # type: ignore[index] except DecayNotFound: @@ -1015,7 +1018,7 @@ def build_decay_chains( info.append(d) - return {mother: info} # type: ignore[dict-item] + return {mother: info} def __repr__(self) -> str: if self._parsed_dec_file is not None: diff --git a/src/decaylanguage/decay/decay.py b/src/decaylanguage/decay/decay.py index 7bfc4d0a..d500acec 100644 --- a/src/decaylanguage/decay/decay.py +++ b/src/decaylanguage/decay/decay.py @@ -7,15 +7,16 @@ from __future__ import annotations from collections import Counter +from collections.abc import Sequence from copy import deepcopy from itertools import product -from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List from particle import PDGID, ParticleNotFound from particle.converters import EvtGenName2PDGIDBiMap from particle.exceptions import MatchingIDNotFound -from .._compat.typing import Self +from .._compat.typing import Self, TypedDict from ..utils import DescriptorFormat, charge_conjugate_name if TYPE_CHECKING: @@ -24,6 +25,16 @@ CounterStr = Counter +class DecayModeDict(TypedDict): + bf: float + fs: Sequence[str | DecayChainDict] + model: str + model_params: str | Sequence[str | Any] + + +DecayChainDict = Dict[str, List[DecayModeDict]] + + class DaughtersDict(CounterStr): """ Class holding a decay final state as a dictionary. @@ -187,8 +198,12 @@ class DecayMode: def __init__( self, bf: float = 0, - daughters: None - | (DaughtersDict | dict[str, int] | list[str] | tuple[str] | str) = None, + daughters: DaughtersDict + | dict[str, int] + | list[str] + | tuple[str] + | str + | None = None, **info: Any, ) -> None: """ @@ -241,6 +256,8 @@ def __init__( True """ self.bf = bf + if daughters is None and "fs" in info: + self.daughters = DaughtersDict(info.pop("fs")) self.daughters = DaughtersDict(daughters) self.metadata: dict[str, str | None] = {"model": "", "model_params": ""} @@ -249,7 +266,7 @@ def __init__( @classmethod def from_dict( cls, - decay_mode_dict: dict[str, int | float | str | list[str]], + decay_mode_dict: DecayModeDict, ) -> Self: """ Constructor from a dictionary of the form @@ -285,13 +302,10 @@ def from_dict( dm = deepcopy(decay_mode_dict) # Ensure the input dict has the 2 required keys 'bf' and 'fs' - try: - bf = dm.pop("bf") - daughters = dm.pop("fs") - except KeyError as e: - raise RuntimeError("Input not in the expected format!") from e + if not dm.keys() >= {"bf", "fs"}: + raise RuntimeError("Input not in the expected format! Needs 'bf' and 'fs'") - return cls(bf=bf, daughters=daughters, **dm) # type: ignore[arg-type] + return cls(**dm) @classmethod def from_pdgids( @@ -436,10 +450,6 @@ def __str__(self) -> str: return repr(self) -DecayModeDict = Dict[str, Union[float, str, List[Any]]] -DecayChainDict = Dict[str, List[DecayModeDict]] - - def _has_no_subdecay(ds: list[Any]) -> bool: """ Internal function to check whether the input list @@ -894,9 +904,9 @@ def _print( for i_decay in decay_dict[mother]: print(prefix, arrow if depth > 0 else "", mother, sep="") # noqa: T201 fsps = i_decay["fs"] - n = len(list(fsps)) # type: ignore[arg-type] + n = len(list(fsps)) depth += 1 - for j, fsp in enumerate(fsps): # type: ignore[arg-type] + for j, fsp in enumerate(fsps): prefix = bar if (link and depth > 1) else "" if last: prefix = prefix + " " * indent * (depth - 1) + " " diff --git a/tests/dec/test_dec.py b/tests/dec/test_dec.py index 198ec571..bf488c5b 100644 --- a/tests/dec/test_dec.py +++ b/tests/dec/test_dec.py @@ -413,7 +413,12 @@ def test_decay_mode_details(): p.parse() tree_Dp = p._find_decay_modes("D+")[0] - output = (1.0, ["K-", "pi+", "pi+", "pi0"], "PHSP", "") + output = { + "bf": 1.0, + "fs": ["K-", "pi+", "pi+", "pi0"], + "model": "PHSP", + "model_params": "", + } assert p._decay_mode_details(tree_Dp, display_photos_keyword=False) == output