Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fmt] Format topology module and tests #4849

Merged
merged 12 commits into from
Dec 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 51 additions & 27 deletions package/MDAnalysis/topology/CRDParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@
Type and mass are not longer guessed here. Until 3.0 these will still be
set by default through through universe.guess_TopologyAttrs() API.
"""
format = 'CRD'

format = "CRD"

def parse(self, **kwargs):
"""Create the Topology object
Expand All @@ -102,8 +103,10 @@
----
Could use the resnum and temp factor better
"""
extformat = FORTRANReader('2I10,2X,A8,2X,A8,3F20.10,2X,A8,2X,A8,F20.10')
stdformat = FORTRANReader('2I5,1X,A4,1X,A4,3F10.5,1X,A4,1X,A4,F10.5')
extformat = FORTRANReader(
"2I10,2X,A8,2X,A8,3F20.10,2X,A8,2X,A8,F20.10"
)
stdformat = FORTRANReader("2I5,1X,A4,1X,A4,3F10.5,1X,A4,1X,A4,F10.5")

atomids = []
atomnames = []
Expand All @@ -116,21 +119,36 @@
with openany(self.filename) as crd:
for linenum, line in enumerate(crd):
# reading header
if line.split()[0] == '*':
if line.split()[0] == "*":
continue
elif line.split()[-1] == 'EXT' and int(line.split()[0]):
elif line.split()[-1] == "EXT" and int(line.split()[0]):
r = extformat
continue
elif line.split()[0] == line.split()[-1] and line.split()[0] != '*':
elif (
line.split()[0] == line.split()[-1]
and line.split()[0] != "*"
):
r = stdformat
continue
# anything else should be an atom
try:
(serial, resnum, resName, name,
x, y, z, segid, resid, tempFactor) = r.read(line)
(
serial,
resnum,
resName,
name,
x,
y,
z,
segid,
resid,
tempFactor,
) = r.read(line)
except Exception:
errmsg = (f"Check CRD format at line {linenum + 1}: "
f"{line.rstrip()}")
errmsg = (

Check warning on line 148 in package/MDAnalysis/topology/CRDParser.py

View check run for this annotation

Codecov / codecov/patch

package/MDAnalysis/topology/CRDParser.py#L148

Added line #L148 was not covered by tests
f"Check CRD format at line {linenum + 1}: "
f"{line.rstrip()}"
)
raise ValueError(errmsg) from None

atomids.append(serial)
Expand All @@ -150,22 +168,28 @@
resnums = np.array(resnums, dtype=np.int32)
segids = np.array(segids, dtype=object)

atom_residx, (res_resids, res_resnames, res_resnums, res_segids) = change_squash(
(resids, resnames), (resids, resnames, resnums, segids))
res_segidx, (seg_segids,) = change_squash(
(res_segids,), (res_segids,))

top = Topology(len(atomids), len(res_resids), len(seg_segids),
attrs=[
Atomids(atomids),
Atomnames(atomnames),
Tempfactors(tempfactors),
Resids(res_resids),
Resnames(res_resnames),
Resnums(res_resnums),
Segids(seg_segids),
],
atom_resindex=atom_residx,
residue_segindex=res_segidx)
atom_residx, (res_resids, res_resnames, res_resnums, res_segids) = (
change_squash(
(resids, resnames), (resids, resnames, resnums, segids)
)
)
res_segidx, (seg_segids,) = change_squash((res_segids,), (res_segids,))

top = Topology(
len(atomids),
len(res_resids),
len(seg_segids),
attrs=[
Atomids(atomids),
Atomnames(atomnames),
Tempfactors(tempfactors),
Resids(res_resids),
Resnames(res_resnames),
Resnums(res_resnums),
Segids(seg_segids),
],
atom_resindex=atom_residx,
residue_segindex=res_segidx,
)

return top
20 changes: 10 additions & 10 deletions package/MDAnalysis/topology/DLPolyParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class ConfigParser(TopologyReaderBase):
Removed type and mass guessing (attributes guessing takes place now
through universe.guess_TopologyAttrs() API).
"""
format = 'CONFIG'

format = "CONFIG"

def parse(self, **kwargs):
with openany(self.filename) as inf:
Expand Down Expand Up @@ -117,10 +118,9 @@ def parse(self, **kwargs):
Atomids(ids),
Resids(np.array([1])),
Resnums(np.array([1])),
Segids(np.array(['SYSTEM'], dtype=object)),
Segids(np.array(["SYSTEM"], dtype=object)),
]
top = Topology(n_atoms, 1, 1,
attrs=attrs)
top = Topology(n_atoms, 1, 1, attrs=attrs)

return top

Expand All @@ -130,7 +130,8 @@ class HistoryParser(TopologyReaderBase):

.. versionadded:: 0.10.1
"""
format = 'HISTORY'

format = "HISTORY"

def parse(self, **kwargs):
with openany(self.filename) as inf:
Expand All @@ -143,10 +144,10 @@ def parse(self, **kwargs):
line = inf.readline()
while not (len(line.split()) == 4 or len(line.split()) == 5):
line = inf.readline()
if line == '':
if line == "":
raise EOFError("End of file reached when reading HISTORY.")

while line and not line.startswith('timestep'):
while line and not line.startswith("timestep"):
name = line[:8].strip()
names.append(name)
try:
Expand Down Expand Up @@ -179,9 +180,8 @@ def parse(self, **kwargs):
Atomids(ids),
Resids(np.array([1])),
Resnums(np.array([1])),
Segids(np.array(['SYSTEM'], dtype=object)),
Segids(np.array(["SYSTEM"], dtype=object)),
]
top = Topology(n_atoms, 1, 1,
attrs=attrs)
top = Topology(n_atoms, 1, 1, attrs=attrs)

return top
99 changes: 54 additions & 45 deletions package/MDAnalysis/topology/DMSParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@

class Atomnums(AtomAttr):
"""The number for each Atom"""
attrname = 'atomnums'
singular = 'atomnum'

attrname = "atomnums"
singular = "atomnum"


class DMSParser(TopologyReaderBase):
Expand Down Expand Up @@ -100,7 +101,8 @@ class DMSParser(TopologyReaderBase):
through universe.guess_TopologyAttrs() API).

"""
format = 'DMS'

format = "DMS"

def parse(self, **kwargs):
"""Parse DMS file *filename* and return the Topology object"""
Expand All @@ -121,28 +123,29 @@ def dict_factory(cursor, row):
attrs = {}

# Row factories for different data types
facs = {np.int32: lambda c, r: r[0],
np.float32: lambda c, r: r[0],
object: lambda c, r: str(r[0].strip())}
facs = {
np.int32: lambda c, r: r[0],
np.float32: lambda c, r: r[0],
object: lambda c, r: str(r[0].strip()),
}

with sqlite3.connect(self.filename) as con:
# Selecting single column, so just strip tuple
for attrname, dt in [
('id', np.int32),
('anum', np.int32),
('mass', np.float32),
('charge', np.float32),
('name', object),
('resname', object),
('resid', np.int32),
('chain', object),
('segid', object),
("id", np.int32),
("anum", np.int32),
("mass", np.float32),
("charge", np.float32),
("name", object),
("resname", object),
("resid", np.int32),
("chain", object),
("segid", object),
]:
try:
cur = con.cursor()
cur.row_factory = facs[dt]
cur.execute('SELECT {} FROM particle'
''.format(attrname))
cur.execute("SELECT {} FROM particle" "".format(attrname))
vals = cur.fetchall()
except sqlite3.DatabaseError:
errmsg = "Failed reading the atoms from DMS Database"
Expand All @@ -152,7 +155,7 @@ def dict_factory(cursor, row):

try:
cur.row_factory = dict_factory
cur.execute('SELECT * FROM bond')
cur.execute("SELECT * FROM bond")
bonds = cur.fetchall()
except sqlite3.DatabaseError:
errmsg = "Failed reading the bonds from DMS Database"
Expand All @@ -161,44 +164,46 @@ def dict_factory(cursor, row):
bondlist = []
bondorder = {}
for b in bonds:
desc = tuple(sorted([b['p0'], b['p1']]))
desc = tuple(sorted([b["p0"], b["p1"]]))
bondlist.append(desc)
bondorder[desc] = b['order']
attrs['bond'] = bondlist
attrs['bondorder'] = bondorder
bondorder[desc] = b["order"]
attrs["bond"] = bondlist
attrs["bondorder"] = bondorder

topattrs = []
# Bundle in Atom level objects
for attr, cls in [
('id', Atomids),
('anum', Atomnums),
('mass', Masses),
('charge', Charges),
('name', Atomnames),
('chain', ChainIDs),
("id", Atomids),
("anum", Atomnums),
("mass", Masses),
("charge", Charges),
("name", Atomnames),
("chain", ChainIDs),
]:
topattrs.append(cls(attrs[attr]))

# Residues
atom_residx, (res_resids,
res_resnums,
res_resnames,
res_segids) = change_squash(
(attrs['resid'], attrs['resname'], attrs['segid']),
(attrs['resid'],
attrs['resid'].copy(),
attrs['resname'],
attrs['segid']),
atom_residx, (res_resids, res_resnums, res_resnames, res_segids) = (
change_squash(
(attrs["resid"], attrs["resname"], attrs["segid"]),
(
attrs["resid"],
attrs["resid"].copy(),
attrs["resname"],
attrs["segid"],
),
)
)

n_residues = len(res_resids)
topattrs.append(Resids(res_resids))
topattrs.append(Resnums(res_resnums))
topattrs.append(Resnames(res_resnames))

if any(res_segids) and not any(val is None for val in res_segids):
res_segidx, (res_segids,) = change_squash((res_segids,),
(res_segids,))
res_segidx, (res_segids,) = change_squash(
(res_segids,), (res_segids,)
)

uniq_seg = np.unique(res_segids)
idx2seg = {idx: res_segids[idx] for idx in res_segidx}
Expand All @@ -211,14 +216,18 @@ def dict_factory(cursor, row):
topattrs.append(Segids(res_segids))
else:
n_segments = 1
topattrs.append(Segids(np.array(['SYSTEM'], dtype=object)))
topattrs.append(Segids(np.array(["SYSTEM"], dtype=object)))
res_segidx = None

topattrs.append(Bonds(attrs['bond']))
topattrs.append(Bonds(attrs["bond"]))

top = Topology(len(attrs['id']), n_residues, n_segments,
attrs=topattrs,
atom_resindex=atom_residx,
residue_segindex=res_segidx)
top = Topology(
len(attrs["id"]),
n_residues,
n_segments,
attrs=topattrs,
atom_resindex=atom_residx,
residue_segindex=res_segidx,
)

return top
Loading
Loading