Skip to content

Commit

Permalink
fix: support passing though dm, nicer typing
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Schreiner <[email protected]>
  • Loading branch information
pre-commit-ci[bot] authored and henryiii committed Nov 28, 2023
1 parent c9b6f7b commit 0a7acc3
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 33 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ repos:
- id: end-of-file-fixer

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.1.5"
rev: "v0.1.6"
hooks:
- id: ruff
args: ["--fix", "--show-fixes"]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.6.1
rev: v1.7.1
hooks:
- id: mypy
files: '^src/decaylanguage/(decay|dec|utils)/'
Expand Down
29 changes: 16 additions & 13 deletions src/decaylanguage/dec/dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@

from .. import data
from .._compat.typing import Self
from ..decay.decay import _expand_decay_modes
from ..decay.decay import DecayModeDict, _expand_decay_modes
from ..utils import charge_conjugate_name
from .enums import PhotosEnum

Expand Down Expand Up @@ -747,7 +747,7 @@ def list_decay_modes(self, mother: str, pdg_name: bool = False) -> list[list[str

def _decay_mode_details(
self, decay_mode: Tree, display_photos_keyword: bool = True
) -> tuple[float, list[str], str, str | list[str | Any]]:
) -> DecayModeDict:
"""
Parse a decay mode (Tree instance)
and return the relevant bits of information in it.
Expand All @@ -768,7 +768,9 @@ def _decay_mode_details(
if display_photos_keyword and list(decay_mode.find_data("photos")):
model = "PHOTOS " + model

return (bf, fsp_names, model, model_params)
return DecayModeDict(
bf=bf, fs=fsp_names, model=model, model_params=model_params
)

def print_decay_modes(
self,
Expand Down Expand Up @@ -864,11 +866,13 @@ def print_decay_modes(

ls_dict = {}
for dm in dms:
bf, fsp_names, model, model_params = self._decay_mode_details(
dm, display_photos_keyword
dmdict = self._decay_mode_details(dm, display_photos_keyword)
model_params = [str(i) for i in dmdict["model_params"]]
ls_dict[dmdict["bf"]] = (
dmdict["fs"],
dmdict["model"],
model_params,
)
model_params = [str(i) for i in model_params]
ls_dict[bf] = (fsp_names, model, model_params)

dec_details = list(ls_dict.values())
ls_attrs_aligned = list(
Expand Down Expand Up @@ -937,7 +941,7 @@ def build_decay_chains(
self,
mother: str,
stable_particles: list[str] | set[str] | tuple[str] | tuple[()] = (),
) -> dict[str, list[dict[str, float | str | list[Any]]]]:
) -> dict[str, list[DecayModeDict]]:
"""
Iteratively build the entire decay chains of a given mother particle,
optionally considering, on the fly, certain particles as stable.
Expand Down Expand Up @@ -992,14 +996,12 @@ def build_decay_chains(
>>> p.build_decay_chains('D+', stable_particles=['pi0']) # doctest: +SKIP
{'D+': [{'bf': 1.0, 'fs': ['K-', 'pi+', 'pi+', 'pi0'], 'model': 'PHSP', 'model_params': ''}]}
"""
keys = ("bf", "fs", "model", "model_params")

info = []
for dm in self._find_decay_modes(mother):
list_dm_details = self._decay_mode_details(dm, display_photos_keyword=False)
d = dict(zip(keys, list_dm_details))
d = self._decay_mode_details(dm, display_photos_keyword=False)

for i, fs in enumerate(d["fs"]): # type: ignore[arg-type, var-annotated]
for i, fs in enumerate(d["fs"]):
if fs in stable_particles:
continue

Expand All @@ -1008,14 +1010,15 @@ def build_decay_chains(
# if fs does not have decays defined in the parsed file
# _n_dms = len(self._find_decay_modes(fs))

assert isinstance(fs, str)
_info = self.build_decay_chains(fs, stable_particles)
d["fs"][i] = _info # type: ignore[index]
except DecayNotFound:
pass

info.append(d)

return {mother: info} # type: ignore[dict-item]
return {mother: info}

def __repr__(self) -> str:
if self._parsed_dec_file is not None:
Expand Down
44 changes: 27 additions & 17 deletions src/decaylanguage/decay/decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@
from __future__ import annotations

from collections import Counter
from collections.abc import Sequence
from copy import deepcopy
from itertools import product
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List

from particle import PDGID, ParticleNotFound
from particle.converters import EvtGenName2PDGIDBiMap
from particle.exceptions import MatchingIDNotFound

from .._compat.typing import Self
from .._compat.typing import Self, TypedDict
from ..utils import DescriptorFormat, charge_conjugate_name

if TYPE_CHECKING:
Expand All @@ -24,6 +25,16 @@
CounterStr = Counter


class DecayModeDict(TypedDict):
bf: float
fs: Sequence[str | DecayChainDict]
model: str
model_params: str | Sequence[str | Any]


DecayChainDict = Dict[str, List[DecayModeDict]]


class DaughtersDict(CounterStr):
"""
Class holding a decay final state as a dictionary.
Expand Down Expand Up @@ -187,8 +198,12 @@ class DecayMode:
def __init__(
self,
bf: float = 0,
daughters: None
| (DaughtersDict | dict[str, int] | list[str] | tuple[str] | str) = None,
daughters: DaughtersDict
| dict[str, int]
| list[str]
| tuple[str]
| str
| None = None,
**info: Any,
) -> None:
"""
Expand Down Expand Up @@ -241,6 +256,8 @@ def __init__(
True
"""
self.bf = bf
if daughters is None and "fs" in info:
self.daughters = DaughtersDict(info.pop("fs"))
self.daughters = DaughtersDict(daughters)

self.metadata: dict[str, str | None] = {"model": "", "model_params": ""}
Expand All @@ -249,7 +266,7 @@ def __init__(
@classmethod
def from_dict(
cls,
decay_mode_dict: dict[str, int | float | str | list[str]],
decay_mode_dict: DecayModeDict,
) -> Self:
"""
Constructor from a dictionary of the form
Expand Down Expand Up @@ -285,13 +302,10 @@ def from_dict(
dm = deepcopy(decay_mode_dict)

# Ensure the input dict has the 2 required keys 'bf' and 'fs'
try:
bf = dm.pop("bf")
daughters = dm.pop("fs")
except KeyError as e:
raise RuntimeError("Input not in the expected format!") from e
if not dm.keys() >= {"bf", "fs"}:
raise RuntimeError("Input not in the expected format! Needs 'bf' and 'fs'")

return cls(bf=bf, daughters=daughters, **dm) # type: ignore[arg-type]
return cls(**dm)

@classmethod
def from_pdgids(
Expand Down Expand Up @@ -436,10 +450,6 @@ def __str__(self) -> str:
return repr(self)


DecayModeDict = Dict[str, Union[float, str, List[Any]]]
DecayChainDict = Dict[str, List[DecayModeDict]]


def _has_no_subdecay(ds: list[Any]) -> bool:
"""
Internal function to check whether the input list
Expand Down Expand Up @@ -894,9 +904,9 @@ def _print(
for i_decay in decay_dict[mother]:
print(prefix, arrow if depth > 0 else "", mother, sep="") # noqa: T201
fsps = i_decay["fs"]
n = len(list(fsps)) # type: ignore[arg-type]
n = len(list(fsps))
depth += 1
for j, fsp in enumerate(fsps): # type: ignore[arg-type]
for j, fsp in enumerate(fsps):
prefix = bar if (link and depth > 1) else ""
if last:
prefix = prefix + " " * indent * (depth - 1) + " "
Expand Down
7 changes: 6 additions & 1 deletion tests/dec/test_dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,12 @@ def test_decay_mode_details():
p.parse()

tree_Dp = p._find_decay_modes("D+")[0]
output = (1.0, ["K-", "pi+", "pi+", "pi0"], "PHSP", "")
output = {
"bf": 1.0,
"fs": ["K-", "pi+", "pi+", "pi0"],
"model": "PHSP",
"model_params": "",
}
assert p._decay_mode_details(tree_Dp, display_photos_keyword=False) == output


Expand Down

0 comments on commit 0a7acc3

Please sign in to comment.