Skip to content

Commit

Permalink
Added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pfebrer committed Feb 20, 2024
1 parent ceed46a commit 715d43c
Show file tree
Hide file tree
Showing 6 changed files with 425 additions and 32 deletions.
5 changes: 3 additions & 2 deletions src/sisl/viz/plots/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@


def atomic_matrix_plot(
matrix: Union[np.ndarray, sisl.SparseCSR, spmatrix],
matrix: Union[
np.ndarray, sisl.SparseCSR, sisl.SparseAtom, sisl.SparseOrbital, spmatrix
],
dim: int = 0,
isc: Optional[int] = None,
fill_value: Optional[float] = None,
Expand Down Expand Up @@ -121,7 +123,6 @@ def atomic_matrix_plot(
geometry,
matrix_mode=mode,
constrain_axes=constrain_axes,
draw_supercells=draw_supercells,
set_labels=set_labels,
)

Expand Down
10 changes: 10 additions & 0 deletions src/sisl/viz/plots/tests/test_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import sisl
from sisl.viz.plots import atomic_matrix_plot


def test_atomic_matrix_plot():

graphene = sisl.geom.graphene()
H = sisl.Hamiltonian(graphene)

atomic_matrix_plot(H)
85 changes: 57 additions & 28 deletions src/sisl/viz/plotters/matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Union
from typing import List, Literal, Union

import numpy as np

Expand All @@ -11,11 +11,30 @@
def draw_matrix_separators(
line: Union[bool, dict],
geometry: sisl.Geometry,
matrix_mode: str,
separator_mode: str,
matrix_mode: Literal["orbitals", "atoms"],
separator_mode: Literal["orbitals", "atoms", "supercells"],
draw_supercells: bool = True,
showlegend: bool = True,
) -> List[dict]:
"""Returns the actions to draw separators in a matrix.
Parameters
----------
line:
If False, no lines are drawn.
If True, the default line style is used, which depends on `separator_mode`.
If a dictionary, it must contain the line style.
geometry:
The geometry associated to the matrix.
matrix_mode:
Whether the elements of the matrix belong to orbitals or atoms.
separator_mode:
What the separators should separate.
draw_supercells:
Whether to draw separators for the whole matrix (not just the unit cell).
showlegend:
Show the separator lines in the legend.
"""
# Orbital separators don't make sense if it is an atom matrix.
if separator_mode == "orbitals" and matrix_mode == "atoms":
return []
Expand Down Expand Up @@ -106,26 +125,38 @@ def draw_matrix_separators(
def set_matrix_axes(
matrix,
geometry: sisl.Geometry,
matrix_mode: str,
matrix_mode: Literal["orbitals", "atoms"],
constrain_axes: bool = True,
draw_supercells: bool = True,
set_labels: bool = False,
):
) -> List[dict]:
"""Configure the axes of a matrix plot
Parameters
----------
matrix:
The matrix that is plotted.
geometry:
The geometry associated to the matrix
matrix_mode:
Whether the elements of the matrix belong to orbitals or atoms.
constrain_axes:
Whether to try to constrain the axes to the domain of the matrix.
set_labels:
Whether to set the axis labels for each element of the matrix.
"""
actions = []

actions.append(plot_actions.set_axes_equal())

x_kwargs = {}
y_kwargs = {}

if constrain_axes:
actions.append(
plot_actions.set_axis(
axis="y", range=[matrix.shape[0] - 0.5, -0.5], constrain="domain"
)
)
actions.append(
plot_actions.set_axis(
axis="x", range=[-0.5, matrix.shape[1] - 0.5], constrain="domain"
)
)
x_kwargs["range"] = [-0.5, matrix.shape[1] - 0.5]
x_kwargs["constrain"] = "domain"

y_kwargs["range"] = [matrix.shape[0] - 0.5, -0.5]
y_kwargs["constrain"] = "domain"

if set_labels:
if matrix_mode == "orbitals":
Expand All @@ -143,17 +174,15 @@ def set_matrix_axes(
else:
ticks = np.arange(matrix.shape[0]).astype(str)

Check warning on line 175 in src/sisl/viz/plotters/matrix.py

View check run for this annotation

Codecov / codecov/patch

src/sisl/viz/plotters/matrix.py#L175

Added line #L175 was not covered by tests

actions.append(
plot_actions.set_axis(
axis="y", ticktext=ticks, tickvals=np.arange(matrix.shape[0])
)
)
actions.append(
plot_actions.set_axis(
axis="x",
ticktext=np.tile(ticks, geometry.n_s),
tickvals=np.arange(matrix.shape[1]),
)
)
x_kwargs["ticktext"] = np.tile(ticks, geometry.n_s)
x_kwargs["tickvals"] = np.arange(matrix.shape[1])

y_kwargs["ticktext"] = ticks
y_kwargs["tickvals"] = np.arange(matrix.shape[0])

if len(x_kwargs) > 0:
actions.append(plot_actions.set_axis(axis="x", **x_kwargs))
if len(y_kwargs) > 0:
actions.append(plot_actions.set_axis(axis="y", **y_kwargs))

return actions
129 changes: 129 additions & 0 deletions src/sisl/viz/plotters/tests/test_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import itertools

import numpy as np
import pytest

import sisl
from sisl.viz.plotters.matrix import draw_matrix_separators, set_matrix_axes


def test_draw_matrix_separators_empty():
C = sisl.Atom(
"C",
orbitals=[
sisl.AtomicOrbital("2s"),
sisl.AtomicOrbital("2px"),
sisl.AtomicOrbital("2py"),
sisl.AtomicOrbital("2pz"),
],
)
geom = sisl.geom.graphene(atoms=C)

# Check combinations that should give no lines
assert draw_matrix_separators(False, geom, "orbitals", "orbitals") == []
assert draw_matrix_separators(True, geom, "atoms", "orbitals") == []


@pytest.mark.parametrize(
"draw_supercells,separator_mode",
itertools.product([True, False], ["atoms", "orbitals", "supercells"]),
)
def test_draw_matrix_separators(draw_supercells, separator_mode):
C = sisl.Atom(
"C",
orbitals=[
sisl.AtomicOrbital("2s"),
sisl.AtomicOrbital("2px"),
sisl.AtomicOrbital("2py"),
sisl.AtomicOrbital("2pz"),
],
)
geom = sisl.geom.graphene(atoms=C)

lines = draw_matrix_separators(
{"color": "red"},
geom,
"orbitals",
separator_mode=separator_mode,
draw_supercells=draw_supercells,
)

if not draw_supercells and separator_mode == "supercells":
assert len(lines) == 0
return

assert len(lines) == 1
assert isinstance(lines[0], dict)
action = lines[0]
assert action["method"] == "draw_line"
# Check that the number of points in the line is fine
n_expected_points = {
("atoms", False): 6,
("atoms", True): 30,
("orbitals", False): 12,
("orbitals", True): 60,
("supercells", True): 24,
}[separator_mode, draw_supercells]

assert action["kwargs"]["x"].shape == (n_expected_points,)
assert action["kwargs"]["y"].shape == (n_expected_points,)

assert action["kwargs"]["line"]["color"] == "red"


def test_set_matrix_axes():

C = sisl.Atom(
"C",
orbitals=[
sisl.AtomicOrbital("2s"),
sisl.AtomicOrbital("2px"),
sisl.AtomicOrbital("2py"),
sisl.AtomicOrbital("2pz"),
],
)
geom = sisl.geom.graphene(atoms=C)

matrix = np.zeros((geom.no, geom.no * geom.n_s))

actions = set_matrix_axes(
matrix, geom, "orbitals", constrain_axes=False, set_labels=False
)
assert len(actions) == 1
assert actions[0]["method"] == "set_axes_equal"

# Test without labels
actions = set_matrix_axes(
matrix, geom, "orbitals", constrain_axes=True, set_labels=False
)
assert len(actions) == 3
assert actions[0]["method"] == "set_axes_equal"
assert actions[1]["method"] == "set_axis"
assert actions[1]["kwargs"]["axis"] == "x"
assert actions[1]["kwargs"]["range"] == [-0.5, geom.no * geom.n_s - 0.5]
assert "tickvals" not in actions[1]["kwargs"]
assert "ticktext" not in actions[1]["kwargs"]

assert actions[2]["method"] == "set_axis"
assert actions[2]["kwargs"]["axis"] == "y"
assert actions[2]["kwargs"]["range"] == [geom.no - 0.5, -0.5]
assert "tickvals" not in actions[2]["kwargs"]
assert "ticktext" not in actions[2]["kwargs"]

# Test with labels
actions = set_matrix_axes(
matrix, geom, "orbitals", constrain_axes=True, set_labels=True
)
assert len(actions) == 3
assert actions[0]["method"] == "set_axes_equal"
assert actions[1]["method"] == "set_axis"
assert actions[1]["kwargs"]["axis"] == "x"
assert actions[1]["kwargs"]["range"] == [-0.5, geom.no * geom.n_s - 0.5]
assert np.all(actions[1]["kwargs"]["tickvals"] == np.arange(geom.no * geom.n_s))
assert len(actions[1]["kwargs"]["ticktext"]) == geom.no * geom.n_s

assert actions[2]["method"] == "set_axis"
assert actions[2]["kwargs"]["axis"] == "y"
assert actions[2]["kwargs"]["range"] == [geom.no - 0.5, -0.5]
assert np.all(actions[2]["kwargs"]["tickvals"] == np.arange(geom.no))
assert len(actions[2]["kwargs"]["ticktext"]) == geom.no
77 changes: 75 additions & 2 deletions src/sisl/viz/processors/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,17 @@
import sisl


def get_orbital_sets_positions(atoms: List[sisl.Atom]):
def get_orbital_sets_positions(atoms: List[sisl.Atom]) -> List[List[int]]:
"""Gets the orbital indices where an orbital set starts for each atom.
An "orbital set" is a group of 2l + 1 orbitals with an angular momentum l
and different m.
Parameters
----------
atoms :
List of atoms for which the orbital sets positions are desired.
"""
specie_orb_sets = []
for at in atoms:
orbitals = at.orbitals
Expand All @@ -27,6 +37,17 @@ def get_geometry_from_matrix(
matrix: Union[sisl.SparseCSR, sisl.SparseAtom, sisl.SparseOrbital, np.ndarray],
geometry: Optional[sisl.Geometry] = None,
):
"""Returns the geometry associated to a matrix.
Parameters
----------
matrix :
The matrix for which the geometry is desired, which may have
an associated geometry.
geometry :
The geometry to be returned. This is to be used when we already
have a geometry and we don't want to extract it from the matrix.
"""
if geometry is not None:
pass
elif hasattr(matrix, "geometry"):
Expand All @@ -41,6 +62,23 @@ def matrix_as_array(
isc: Optional[int] = None,
fill_value: Optional[float] = None,
) -> np.ndarray:
"""Converts any type of matrix to a numpy array.
Parameters
----------
matrix :
The matrix to be converted.
dim :
If the matrix is a sisl sparse matrix and it has a third dimension, the
index to get in that third dimension.
isc :
If the matrix is a sisl SparseAtom or SparseOrbital, the index of the
cell within the auxiliary supercell.
If None, the whole matrix is returned.
fill_value :
If the matrix is a sparse matrix, the value to fill the unset elements.
"""
if isinstance(matrix, (sisl.SparseCSR, sisl.SparseAtom, sisl.SparseOrbital)):
if dim is None:
if isinstance(matrix, (sisl.SparseAtom, sisl.SparseOrbital)):
Expand All @@ -63,8 +101,25 @@ def matrix_as_array(


def determine_color_midpoint(
matrix: np.ndarray, cmid: Optional[float], crange: Optional[Tuple[float, float]]
matrix: np.ndarray,
cmid: Optional[float] = None,
crange: Optional[Tuple[float, float]] = None,
) -> Optional[float]:
"""Determines the midpoint of a colorscale given a matrix of values.
If ``cmid`` or ``crange`` are provided, this function just returns ``cmid``.
However, if none of them are provided, it returns 0 if the matrix has both
positive and negative values, and None otherwise.
Parameters
----------
matrix :
The matrix of values for which the colorscale is to be determined.
cmid :
Possible already determined midpoint.
crange :
Possible already determined range.
"""
if cmid is not None:
return cmid
elif crange is not None:
Expand All @@ -76,10 +131,28 @@ def determine_color_midpoint(


def get_matrix_mode(matrix) -> Literal["atoms", "orbitals"]:
"""Returns what the elements of the matrix represent.
If the matrix is a sisl SparseAtom, the elements are atoms.
Otherwise, they are assumed to be orbitals.
Parameters
----------
matrix :
The matrix for which the mode is desired.
"""
return "atoms" if isinstance(matrix, sisl.SparseAtom) else "orbitals"


def sanitize_matrix_arrows(arrows: Union[dict, List[dict]]) -> List[dict]:
"""Sanitizes an ``arrows`` argument to a list of sanitized specifications.
Parameters
----------
arrows :
The arrows argument to be sanitized. If it is a dictionary, it is converted to a list
with a single element.
"""
if isinstance(arrows, dict):
arrows = [arrows]

Expand Down
Loading

0 comments on commit 715d43c

Please sign in to comment.