Skip to content

Commit

Permalink
enabled axes in sort function, more natural for tuple
Browse files Browse the repository at this point in the history
Signed-off-by: Nick Papior <[email protected]>
  • Loading branch information
zerothi committed Dec 20, 2024
1 parent 7c2df9a commit 866a97b
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 38 deletions.
61 changes: 31 additions & 30 deletions src/sisl/_core/_ufuncs_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def sort(
r"""Sort atoms in a nested fashion according to various criteria
There are many ways to sort a `Geometry`.
- by Cartesian coordinates, `axis`
- by Cartesian coordinates, `axes`/`axis`
- by lattice vectors, `lattice`
- by user defined vectors, `vector`
- by grouping atoms, `group`
Expand All @@ -192,15 +192,15 @@ def sort(
Default, all atoms will be sorted.
ret_atoms : bool, optional
return a list of list for the groups of atoms that have been sorted.
axis : int or tuple of int, optional
axis / axes : int or tuple of int, optional
sort coordinates according to Cartesian coordinates, if a tuple of
ints is passed it will be equivalent to ``sort(axis0=axis[0], axis1=axis[1])``.
ints is passed it will be equivalent to ``sort(axes=axes) == sort(axis0=axes[0], axis1=axes[1])``.
This behaves differently than `numpy.lexsort`!
lattice : int or tuple of int, optional
sort coordinates according to lattice vectors, if a tuple of
ints is passed it will be equivalent to ``sort(lattice0=lattice[0], lattice1=lattice[1])``.
Note that before sorting we multiply the fractional coordinates by the length of the
lattice vector. This ensures that `atol` is meaningful for both `axis` and `lattice` since
lattice vector. This ensures that `atol` is meaningful for both `axes` and `lattice` since
they will be on the same order of magnitude.
This behaves differently than `numpy.lexsort`!
vector : Coord, optional
Expand Down Expand Up @@ -248,8 +248,8 @@ def sort(
Notes
-----
The order of arguments is also the sorting order. ``sort(axis=0, lattice=0)`` is different
from ``sort(lattice=0, axis=0)``
The order of arguments is also the sorting order. ``sort(axes=0, lattice=0)`` is different
from ``sort(lattice=0, axes=0)``
All arguments may be suffixed with integers. This allows multiple keyword arguments
to control sorting algorithms
Expand All @@ -271,49 +271,49 @@ def sort(
Sort according to :math:`x` coordinate
>>> geom.sort(axis=0)
>>> geom.sort(axes=0)
Sort according to :math:`z`, then :math:`x` for each group created from first sort
>>> geom.sort(axis=(2, 0))
>>> geom.sort(axes=(2, 0))
Sort according to :math:`z`, then first lattice vector
>>> geom.sort(axis=2, lattice=0)
>>> geom.sort(axes=2, lattice=0)
Sort according to :math:`z` (ascending), then first lattice vector (descending)
>>> geom.sort(axis=2, ascend=False, lattice=0)
>>> geom.sort(axes=2, ascend=False, lattice=0)
Sort according to :math:`z` (descending), then first lattice vector (ascending)
Note how integer suffixes has no importance.
>>> geom.sort(ascend1=False, axis=2, ascend0=True, lattice=0)
>>> geom.sort(ascend1=False, axes=2, ascend0=True, lattice=0)
Sort only atoms ``range(1, 5)`` first by :math:`z`, then by first lattice vector
>>> geom.sort(axis=2, lattice=0, atoms=np.arange(1, 5))
>>> geom.sort(axes=2, lattice=0, atoms=np.arange(1, 5))
Sort two groups of atoms ``[range(1, 5), range(5, 10)]`` (individually) by :math:`z` coordinate
>>> geom.sort(axis=2, atoms=[np.arange(1, 5), np.arange(5, 10)])
>>> geom.sort(axes=2, atoms=[np.arange(1, 5), np.arange(5, 10)])
The returned sorting indices may be used for manual sorting. Note
however, that this requires one to perform a sorting for all atoms.
In such a case the following sortings are equal.
>>> geom0, atoms0 = geom.sort(axis=2, lattice=0, ret_atoms=True)
>>> _, atoms1 = geom.sort(axis=2, ret_atoms=True)
>>> geom0, atoms0 = geom.sort(axes=2, lattice=0, ret_atoms=True)
>>> _, atoms1 = geom.sort(axes=2, ret_atoms=True)
>>> geom1, atoms1 = geom.sort(lattice=0, atoms=atoms1, ret_atoms=True)
>>> geom2 = geom.sub(np.concatenate(atoms0))
>>> geom3 = geom.sub(np.concatenate(atoms1))
>>> assert geom0 == geom1
>>> assert geom0 == geom2
>>> assert geom0 == geom3
Default sorting is equivalent to ``axis=(0, 1, 2)``
Default sorting is equivalent to ``axes=(0, 1, 2)``
>>> assert geom.sort() == geom.sort(axis=(0, 1, 2))
>>> assert geom.sort() == geom.sort(axes=(0, 1, 2))
Sort along a user defined vector ``[2.2, 1., 0.]``
Expand All @@ -322,7 +322,7 @@ def sort(
Integer specification has no influence on the order of operations.
It is _always_ the keyword argument order that determines the operation.
>>> assert geom.sort(axis2=1, axis0=0, axis1=2) == geom.sort(axis=(1, 0, 2))
>>> assert geom.sort(axis2=1, axis0=0, axis1=2) == geom.sort(axes=(1, 0, 2))
Sort by atomic numbers
Expand All @@ -331,8 +331,8 @@ def sort(
One may group several elements together on an equal footing (``None`` means all non-mentioned elements)
The order of the groups are important (the first two are _not_ equal, the last three _are_ equal)
>>> geom.sort(group=('symbol', 'C'), axis=2) # C will be sorted along z
>>> geom.sort(axis=1, atoms='C', axis1=2) # all along y, then C sorted along z
>>> geom.sort(group=('symbol', 'C'), axes=2) # C will be sorted along z
>>> geom.sort(axes=1, atoms='C', axes1=2) # all along y, then C sorted along z
>>> geom.sort(group=('symbol', 'C', None)) # C, [B, N]
>>> geom.sort(group=('symbol', None, 'C')) # [B, N], C
>>> geom.sort(group=('symbol', ['N', 'B'], 'C')) # [B, N], C (B and N unaltered order)
Expand All @@ -355,20 +355,20 @@ def sort(
>>> x = np.arange(5) * 0.1
>>> x[3:] -= 0.095
y = z = np.zeros(5)
geom = si.Geometry(np.stack((x, y, z), axis=1))
geom = si.Geometry(np.stack((x, y, z), axes=1))
>>> geom.xyz[:, 0]
[0. 0.1 0.2 0.205 0.305]
In this case a high tolerance (``atol>0.005``) would group atoms 2 and 3
together
>>> geom.sort(atol=0.01, axis=0, ret_atoms=True)[1]
>>> geom.sort(atol=0.01, axes=0, ret_atoms=True)[1]
[[0], [1], [2, 3], [4]]
However, a very low tolerance will not find these two as atoms close
to each other.
>>> geom.sort(atol=0.001, axis=0, ret_atoms=True)[1]
>>> geom.sort(atol=0.001, axes=0, ret_atoms=True)[1]
[[0], [1], [2], [3], [4]]
"""

Expand Down Expand Up @@ -447,15 +447,16 @@ def _sort(val, atoms, **kwargs):
# Functions allowed by external users
funcs = dict()

def _axis(axis, atoms, **kwargs):
def _axes(axes, atoms, **kwargs):
"""Cartesian coordinate sort"""
if isinstance(axis, int):
axis = (axis,)
for ax in axis:
atoms = _sort(geometry.xyz[:, ax], atoms, **kwargs)
if isinstance(axes, int):
axes = (axes,)
for axis in axes:
atoms = _sort(geometry.xyz[:, axis], atoms, **kwargs)
return atoms

funcs["axis"] = _axis
funcs["axis"] = _axes
funcs["axes"] = _axes

def _lattice(lattice, atoms, **kwargs):
"""
Expand Down Expand Up @@ -657,7 +658,7 @@ def update_flag(kw, arg, val):

# In case the user just did geometry.sort, it will default to sort x, y, z
if len(kwargs) == 0:
kwargs["axis"] = (0, 1, 2)
kwargs["axes"] = (0, 1, 2)

for key_int, method in kwargs.items():
key = stripint(key_int)
Expand Down
2 changes: 1 addition & 1 deletion src/sisl/_core/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4415,7 +4415,7 @@ def __call__(self, parser, ns, values, option_string=None):
nargs=1,
metavar="SORT",
action=Sort,
help='Semi-colon separated options for sort, please always encapsulate in quotation ["axis=0;descend;lattice=(1, 2);group=Z"].',
help='Semi-colon separated options for sort, please always encapsulate in quotation ["axes=0;descend;lattice=(1, 2);group=Z"].',
)

# Print some common information about the
Expand Down
12 changes: 6 additions & 6 deletions src/sisl/_core/tests/test_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1608,18 +1608,18 @@ def test_geometry_sort_simple():
atol = 1e-9

for i in [0, 1, 2]:
s = bi.sort(axis=i)
s = bi.sort(axes=i)
assert np.all(np.diff(s.xyz[:, i]) >= -atol)
s = bi.sort(lattice=i)
assert np.all(np.diff(s.fxyz[:, i] * bi.lattice.length[i]) >= -atol)

s, idx = bi.sort(axis=0, lattice=1, ret_atoms=True)
s, idx = bi.sort(axes=0, lattice=1, ret_atoms=True)
assert np.all(np.diff(s.xyz[:, 0]) >= -atol)
for ix in idx:
assert np.all(np.diff(bi.fxyz[ix, 1]) >= -atol)

s, idx = bi.sort(
axis=0, ascending=False, lattice=1, vector=[0, 0, 1], ret_atoms=True
axes=0, ascending=False, lattice=1, vector=[0, 0, 1], ret_atoms=True
)
assert np.all(np.diff(s.xyz[:, 0]) >= -atol)
for ix in idx:
Expand All @@ -1635,18 +1635,18 @@ def test_geometry_sort_int():
atol = 1e-9

for i in [0, 1, 2]:
s = bi.sort(axis0=i)
s = bi.sort(axes0=i)
assert np.all(np.diff(s.xyz[:, i]) >= -atol)
s = bi.sort(lattice3=i)
assert np.all(np.diff(s.fxyz[:, i] * bi.lattice.length[i]) >= -atol)

s, idx = bi.sort(axis12314=0, lattice0=1, ret_atoms=True)
s, idx = bi.sort(axes12314=0, lattice0=1, ret_atoms=True)
assert np.all(np.diff(s.xyz[:, 0]) >= -atol)
for ix in idx:
assert np.all(np.diff(bi.fxyz[ix, 1]) >= -atol)

s, idx = bi.sort(
ascending1=True, axis15=0, ascending0=False, lattice235=1, ret_atoms=True
ascending1=True, axes15=0, ascending0=False, lattice235=1, ret_atoms=True
)
assert np.all(np.diff(s.xyz[:, 0]) >= -atol)
for ix in idx:
Expand Down
2 changes: 1 addition & 1 deletion src/sisl/geom/nanoribbon.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def nanoribbon(
ribbon.set_lattice([ribbon.cell[1, 0], -ribbon.cell[0, 1], ribbon.cell[2, 2]])

# Sort along x, then y
ribbon = ribbon.sort(axis=(0, 1))
ribbon = ribbon.sort(axes=(0, 1))

if kind == "chiral":
# continue with the zigzag ribbon as building block
Expand Down

0 comments on commit 866a97b

Please sign in to comment.