Skip to content

Commit

Permalink
Fixed bug that made --print output inconsitent (#112)
Browse files Browse the repository at this point in the history
* Added tests for --no-hydrogen/filters and --print is consistency

* Discovered and fixed printed structure, when filtering atoms (#111)

* Updated test coverage to cover bad --print order, Restructured main to be consistent with print

* Added new --print options for printing only atoms used in RMSD calculation/rotation

* Removed out-commented code and added clarification comment

---------

Co-authored-by: Takafumi Shiraogawa <[email protected]>
  • Loading branch information
charnley and takafumi-shiraogawa authored Nov 23, 2024
1 parent 69f9239 commit a543e55
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 66 deletions.
130 changes: 77 additions & 53 deletions rmsd/calculate_rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1412,6 +1412,9 @@ def set_coordinates(atoms: ndarray, V: ndarray, title: str = "", decimals: int =
"""
N, D = V.shape

if N != len(atoms):
raise ValueError("Mismatch between expected atoms and coordinate size")

fmt = "{:<2}" + (" {:15." + str(decimals) + "f}") * 3

out = list()
Expand Down Expand Up @@ -1818,6 +1821,14 @@ def parse_arguments(arguments: Optional[Union[str, List[str]]] = None) -> argpar
),
)

parser.add_argument(
"--print-only-rmsd-atoms",
action="store_true",
help=(
"Print only atoms used in finding optimal RMSD calculation (relevant if filtering e.g. Hydrogens)"
),
)

args = parser.parse_args(arguments)

# Check illegal combinations
Expand Down Expand Up @@ -1877,35 +1888,35 @@ def parse_arguments(arguments: Optional[Union[str, List[str]]] = None) -> argpar
return args


def main(args: Optional[List[str]] = None):
def main(args: Optional[List[str]] = None) -> str:

# Parse arguments
settings = parse_arguments(args)

# As default, load the extension as format
# Parse pdb.gz and xyz.gz as pdb and xyz formats
p_all_atoms, p_all = get_coordinates(
p_atoms, p_coord = get_coordinates(
settings.structure_a,
settings.format,
is_gzip=settings.format_is_gzip,
return_atoms_as_int=True,
)

q_all_atoms, q_all = get_coordinates(
q_atoms, q_coord = get_coordinates(
settings.structure_b,
settings.format,
is_gzip=settings.format_is_gzip,
return_atoms_as_int=True,
)

p_size = p_all.shape[0]
q_size = q_all.shape[0]
p_size = p_coord.shape[0]
q_size = q_coord.shape[0]

if not p_size == q_size:
print("error: Structures not same size")
sys.exit()

if np.count_nonzero(p_all_atoms != q_all_atoms) and not settings.reorder:
if np.count_nonzero(p_atoms != q_atoms) and not settings.reorder:
msg = """
error: Atoms are not in the same order.
Expand All @@ -1923,12 +1934,11 @@ def main(args: Optional[List[str]] = None):
# Set local view
p_view: Optional[ndarray] = None
q_view: Optional[ndarray] = None
use_view: bool = True

if settings.ignore_hydrogen:
assert type(p_all_atoms[0]) != str
assert type(q_all_atoms[0]) != str
p_view = np.where(p_all_atoms != 1) # type: ignore
q_view = np.where(q_all_atoms != 1) # type: ignore
p_view = np.where(p_atoms != 1) # type: ignore
q_view = np.where(q_atoms != 1) # type: ignore

elif settings.remove_idx:
index = np.array(list(set(range(p_size)) - set(settings.remove_idx)))
Expand All @@ -1939,26 +1949,27 @@ def main(args: Optional[List[str]] = None):
p_view = settings.add_idx
q_view = settings.add_idx

else:
use_view = False

# Set local view
if p_view is None:
p_coord = copy.deepcopy(p_all)
q_coord = copy.deepcopy(q_all)
p_atoms = copy.deepcopy(p_all_atoms)
q_atoms = copy.deepcopy(q_all_atoms)
if use_view:
p_coord_sub = copy.deepcopy(p_coord[p_view])
q_coord_sub = copy.deepcopy(q_coord[q_view])
p_atoms_sub = copy.deepcopy(p_atoms[p_view])
q_atoms_sub = copy.deepcopy(q_atoms[q_view])

else:
assert p_view is not None
assert q_view is not None
p_coord = copy.deepcopy(p_all[p_view])
q_coord = copy.deepcopy(q_all[q_view])
p_atoms = copy.deepcopy(p_all_atoms[p_view])
q_atoms = copy.deepcopy(q_all_atoms[q_view])
p_coord_sub = copy.deepcopy(p_coord)
q_coord_sub = copy.deepcopy(q_coord)
p_atoms_sub = copy.deepcopy(p_atoms)
q_atoms_sub = copy.deepcopy(q_atoms)

# Recenter to centroid
p_cent = centroid(p_coord)
q_cent = centroid(q_coord)
p_coord -= p_cent
q_coord -= q_cent
p_cent_sub = centroid(p_coord_sub)
q_cent_sub = centroid(q_coord_sub)
p_coord_sub -= p_cent_sub
q_coord_sub -= q_cent_sub

rmsd_method: RmsdCallable
reorder_method: Optional[ReorderCallable]
Expand All @@ -1985,7 +1996,7 @@ def main(args: Optional[List[str]] = None):
reorder_method = reorder_distance

# Save the resulting RMSD
result_rmsd = None
result_rmsd: Optional[float] = None

# Collect changes to be done on q coords
q_swap = None
Expand All @@ -1995,21 +2006,21 @@ def main(args: Optional[List[str]] = None):
if settings.use_reflections:

result_rmsd, q_swap, q_reflection, q_review = check_reflections(
p_atoms,
q_atoms,
p_coord,
q_coord,
p_atoms_sub,
q_atoms_sub,
p_coord_sub,
q_coord_sub,
reorder_method=reorder_method,
rmsd_method=rmsd_method,
)

elif settings.use_reflections_keep_stereo:

result_rmsd, q_swap, q_reflection, q_review = check_reflections(
p_atoms,
q_atoms,
p_coord,
q_coord,
p_atoms_sub,
q_atoms_sub,
p_coord_sub,
q_coord_sub,
reorder_method=reorder_method,
rmsd_method=rmsd_method,
keep_stereo=True,
Expand All @@ -2023,42 +2034,55 @@ def main(args: Optional[List[str]] = None):
# If there is a reorder, then apply before print
if q_review is not None:

q_all_atoms = q_all_atoms[q_review]
q_atoms = q_atoms[q_review]
q_coord = q_coord[q_review]
q_atoms_sub = q_atoms_sub[q_review]
q_coord_sub = q_coord_sub[q_review]

assert all(
p_atoms == q_atoms
p_atoms_sub == q_atoms_sub
), "error: Structure not aligned. Please submit bug report at http://github.com/charnley/rmsd"

# Calculate the RMSD value
if result_rmsd is None:
result_rmsd = rmsd_method(p_coord_sub, q_coord_sub)

# print result
if settings.output:

if q_swap is not None:
q_coord = q_coord[:, q_swap]
q_coord_sub = q_coord_sub[:, q_swap]

if q_reflection is not None:
q_coord = np.dot(q_coord, np.diag(q_reflection))
q_coord_sub = np.dot(q_coord_sub, np.diag(q_reflection))

q_coord -= centroid(q_coord)
U = kabsch(q_coord_sub, p_coord_sub)

# Rotate q coordinates
# TODO Should actually follow rotation method
q_coord = kabsch_rotate(q_coord, p_coord)
if settings.print_only_rmsd_atoms or not use_view:
q_coord_sub = np.dot(q_coord_sub, U)
q_coord_sub += p_cent_sub
return set_coordinates(
q_atoms_sub,
q_coord_sub,
title=f"Rotated '{settings.structure_b}' to match '{settings.structure_a}', with a RMSD of {result_rmsd:.8f}",
)

# center q on p's original coordinates
q_coord += p_cent
# Swap, reflect, rotate and re-center on the full atom and coordinate set
q_coord -= q_cent_sub

# done and done
xyz = set_coordinates(q_all_atoms, q_coord, title=f"{settings.structure_b} - modified")
return xyz
if q_swap is not None:
q_coord = q_coord[:, q_swap]

else:
if q_reflection is not None:
q_coord = np.dot(q_coord, np.diag(q_reflection))

if not result_rmsd:
result_rmsd = rmsd_method(p_coord, q_coord)
q_coord = np.dot(q_coord, U)
q_coord += p_cent_sub
return set_coordinates(
q_atoms,
q_coord,
title=f"Rotated {settings.structure_b} to match {settings.structure_a}, with RMSD of {result_rmsd:.8f}",
)

return result_rmsd
return str(result_rmsd)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions tests/resources/issue93/b.xyz
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
C -0.2769166422 -0.7099045701 -1.2865831838
C -0.0318894674 0.6967020886 -1.8794695568
C 0.6647089780 -1.0315870523 -0.0954617436
H 0.7080595367 -2.0732382850 -2.5032982174
H -1.3132498520 -0.6901271514 -0.9163939649
C 0.2611763277 -0.4666702466 1.2320987930
C -0.0191751268 -1.0867015822 2.4345274135
C -0.2291714337 1.0136818561 2.8036493128
H -0.8594037813 -2.4661239083 -2.2309783928
H 0.7080595367 -2.0732382850 -2.5032982174
H -1.3132498520 -0.6901271514 -0.9163939649
H 1.6889557061 -0.7206820514 -0.3611372211
H 0.6818980659 -2.1236321622 0.0197593401
H 0.2183924342 1.5973044617 0.7751919144
Expand Down
51 changes: 48 additions & 3 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from context import RESOURCE_PATH, call_main

import rmsd as rmsdlib
from rmsd.calculate_rmsd import get_coordinates_pdb, get_coordinates_xyz, get_coordinates_xyz_lines


def test_print_reflection_reorder() -> None:
Expand Down Expand Up @@ -59,16 +60,19 @@ def test_print_reflection_reorder() -> None:

# Main call print, check rmsd is still the same
# Note, that --print is translating b to a center
args = f"--use-reflections --reorder --print {filename_a} {filename_b}"
stdout = call_main(args.split())
_, coord = rmsdlib.get_coordinates_xyz_lines(stdout)
_args = f"--use-reflections --reorder --print {filename_a} {filename_b}"
_stdout: str = rmsdlib.main(_args.split())
atoms, coord = rmsdlib.get_coordinates_xyz_lines(_stdout.split("\n"), return_atoms_as_int=True)
coord -= rmsdlib.centroid(coord) # fix translation
print(coord)
print(atoms)
print(atoms_b)

rmsd_check1 = rmsdlib.kabsch_rmsd(coord, coord_a)
rmsd_check2 = rmsdlib.rmsd(coord, coord_a)
print(rmsd_check1)
print(rmsd_check2)
print(result_rmsd)
np.testing.assert_almost_equal(rmsd_check2, rmsd_check1)
np.testing.assert_almost_equal(rmsd_check2, result_rmsd)

Expand Down Expand Up @@ -136,3 +140,44 @@ def test_ignore() -> None:
rmsdlib.main(f"{filename_a} {filename_b} --remove-idx 0 5".split())

rmsdlib.main(f"{filename_a} {filename_b} --add-idx 0 1 2 3 4".split())


def test_print_match_no_hydrogen() -> None:

filename_a = RESOURCE_PATH / "CHEMBL3039407_order.xyz"
filename_b = RESOURCE_PATH / "CHEMBL3039407_order.xyz"

cmd = f"--no-hydrogen --print {filename_a} {filename_b}"
print(cmd)
out = rmsdlib.main(cmd.split()).split("\n")
atoms1, coord1 = get_coordinates_xyz_lines(out)

print(atoms1)
print(len(atoms1))

assert len(atoms1) == 60
assert coord1.shape
assert "H" in atoms1

cmd = f"--print {filename_a} {filename_b}"
out = rmsdlib.main(cmd.split()).split("\n")
atoms2, coord2 = get_coordinates_xyz_lines(out)

print(atoms2)
print(len(atoms2))

assert len(atoms2) == 60
assert coord2.shape
assert "H" in atoms2

out = rmsdlib.main(
f"--no-hydrogen --print --print-only-rmsd-atoms {filename_a} {filename_b}".split()
).split("\n")
atoms1, coord1 = get_coordinates_xyz_lines(out)

print(atoms1)
print(len(atoms1))

assert len(atoms1) == 30
assert coord1.shape
assert "H" not in atoms1
24 changes: 16 additions & 8 deletions tests/test_reorder_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,32 @@ def test_reorder_print_and_rmsd() -> None:

filename_a = RESOURCE_PATH / "issue93" / "a.xyz"
filename_b = RESOURCE_PATH / "issue93" / "b.xyz"
atoms_a, coord_a = get_coordinates_xyz(filename_a)
atoms_b, coord_b = get_coordinates_xyz(filename_b)

# Get reorder rmsd
args = ["--reorder", f"{filename_a}", f"{filename_b}"]
stdout = call_main(args)
rmsd_ab = float(stdout[-1])
rmsd_ab = float(rmsdlib.main(f"--reorder {filename_a} {filename_b}".split()))
print(rmsd_ab)
assert isinstance(rmsd_ab, float)

# Get printed structure
stdout = call_main(args + ["--print"])
stdout = rmsdlib.main(f"--reorder --print {filename_a} {filename_b}".split())
print(stdout)
atoms_c, coord_c = get_coordinates_xyz_lines(stdout.split("\n"))

atoms_a, coord_a = get_coordinates_xyz(filename_a)
atoms_c, coord_c = get_coordinates_xyz_lines(stdout)
coord_c -= rmsdlib.centroid(coord_c)
coord_a -= rmsdlib.centroid(coord_a)

print(coord_a)
print(atoms_a)
print(atoms_b)
print(atoms_c)

print(coord_a)
print(coord_b)
print(coord_c)
print(atoms_c)

assert (atoms_a == atoms_c).all()
assert (atoms_a != atoms_b).any()

rmsd_ac = rmsdlib.rmsd(coord_a, coord_c)
print(rmsd_ac)
Expand Down

0 comments on commit a543e55

Please sign in to comment.