diff --git a/src/sisl/viz/plots/matrix.py b/src/sisl/viz/plots/matrix.py index 9a9c265015..7476cf38a8 100644 --- a/src/sisl/viz/plots/matrix.py +++ b/src/sisl/viz/plots/matrix.py @@ -20,7 +20,9 @@ def atomic_matrix_plot( - matrix: Union[np.ndarray, sisl.SparseCSR, spmatrix], + matrix: Union[ + np.ndarray, sisl.SparseCSR, sisl.SparseAtom, sisl.SparseOrbital, spmatrix + ], dim: int = 0, isc: Optional[int] = None, fill_value: Optional[float] = None, @@ -121,7 +123,6 @@ def atomic_matrix_plot( geometry, matrix_mode=mode, constrain_axes=constrain_axes, - draw_supercells=draw_supercells, set_labels=set_labels, ) diff --git a/src/sisl/viz/plots/tests/test_matrix.py b/src/sisl/viz/plots/tests/test_matrix.py new file mode 100644 index 0000000000..e506b9b9d7 --- /dev/null +++ b/src/sisl/viz/plots/tests/test_matrix.py @@ -0,0 +1,10 @@ +import sisl +from sisl.viz.plots import atomic_matrix_plot + + +def test_atomic_matrix_plot(): + + graphene = sisl.geom.graphene() + H = sisl.Hamiltonian(graphene) + + atomic_matrix_plot(H) diff --git a/src/sisl/viz/plotters/matrix.py b/src/sisl/viz/plotters/matrix.py index cc997af5a3..fbcd7e57b1 100644 --- a/src/sisl/viz/plotters/matrix.py +++ b/src/sisl/viz/plotters/matrix.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import List, Literal, Union import numpy as np @@ -11,11 +11,30 @@ def draw_matrix_separators( line: Union[bool, dict], geometry: sisl.Geometry, - matrix_mode: str, - separator_mode: str, + matrix_mode: Literal["orbitals", "atoms"], + separator_mode: Literal["orbitals", "atoms", "supercells"], draw_supercells: bool = True, showlegend: bool = True, ) -> List[dict]: + """Returns the actions to draw separators in a matrix. + + Parameters + ---------- + line: + If False, no lines are drawn. + If True, the default line style is used, which depends on `separator_mode`. + If a dictionary, it must contain the line style. + geometry: + The geometry associated to the matrix. + matrix_mode: + Whether the elements of the matrix belong to orbitals or atoms. + separator_mode: + What the separators should separate. + draw_supercells: + Whether to draw separators for the whole matrix (not just the unit cell). + showlegend: + Show the separator lines in the legend. + """ # Orbital separators don't make sense if it is an atom matrix. if separator_mode == "orbitals" and matrix_mode == "atoms": return [] @@ -106,26 +125,38 @@ def draw_matrix_separators( def set_matrix_axes( matrix, geometry: sisl.Geometry, - matrix_mode: str, + matrix_mode: Literal["orbitals", "atoms"], constrain_axes: bool = True, - draw_supercells: bool = True, set_labels: bool = False, -): +) -> List[dict]: + """Configure the axes of a matrix plot + + Parameters + ---------- + matrix: + The matrix that is plotted. + geometry: + The geometry associated to the matrix + matrix_mode: + Whether the elements of the matrix belong to orbitals or atoms. + constrain_axes: + Whether to try to constrain the axes to the domain of the matrix. + set_labels: + Whether to set the axis labels for each element of the matrix. + """ actions = [] actions.append(plot_actions.set_axes_equal()) + x_kwargs = {} + y_kwargs = {} + if constrain_axes: - actions.append( - plot_actions.set_axis( - axis="y", range=[matrix.shape[0] - 0.5, -0.5], constrain="domain" - ) - ) - actions.append( - plot_actions.set_axis( - axis="x", range=[-0.5, matrix.shape[1] - 0.5], constrain="domain" - ) - ) + x_kwargs["range"] = [-0.5, matrix.shape[1] - 0.5] + x_kwargs["constrain"] = "domain" + + y_kwargs["range"] = [matrix.shape[0] - 0.5, -0.5] + y_kwargs["constrain"] = "domain" if set_labels: if matrix_mode == "orbitals": @@ -143,17 +174,15 @@ def set_matrix_axes( else: ticks = np.arange(matrix.shape[0]).astype(str) - actions.append( - plot_actions.set_axis( - axis="y", ticktext=ticks, tickvals=np.arange(matrix.shape[0]) - ) - ) - actions.append( - plot_actions.set_axis( - axis="x", - ticktext=np.tile(ticks, geometry.n_s), - tickvals=np.arange(matrix.shape[1]), - ) - ) + x_kwargs["ticktext"] = np.tile(ticks, geometry.n_s) + x_kwargs["tickvals"] = np.arange(matrix.shape[1]) + + y_kwargs["ticktext"] = ticks + y_kwargs["tickvals"] = np.arange(matrix.shape[0]) + + if len(x_kwargs) > 0: + actions.append(plot_actions.set_axis(axis="x", **x_kwargs)) + if len(y_kwargs) > 0: + actions.append(plot_actions.set_axis(axis="y", **y_kwargs)) return actions diff --git a/src/sisl/viz/plotters/tests/test_matrix.py b/src/sisl/viz/plotters/tests/test_matrix.py new file mode 100644 index 0000000000..92b2ac1425 --- /dev/null +++ b/src/sisl/viz/plotters/tests/test_matrix.py @@ -0,0 +1,129 @@ +import itertools + +import numpy as np +import pytest + +import sisl +from sisl.viz.plotters.matrix import draw_matrix_separators, set_matrix_axes + + +def test_draw_matrix_separators_empty(): + C = sisl.Atom( + "C", + orbitals=[ + sisl.AtomicOrbital("2s"), + sisl.AtomicOrbital("2px"), + sisl.AtomicOrbital("2py"), + sisl.AtomicOrbital("2pz"), + ], + ) + geom = sisl.geom.graphene(atoms=C) + + # Check combinations that should give no lines + assert draw_matrix_separators(False, geom, "orbitals", "orbitals") == [] + assert draw_matrix_separators(True, geom, "atoms", "orbitals") == [] + + +@pytest.mark.parametrize( + "draw_supercells,separator_mode", + itertools.product([True, False], ["atoms", "orbitals", "supercells"]), +) +def test_draw_matrix_separators(draw_supercells, separator_mode): + C = sisl.Atom( + "C", + orbitals=[ + sisl.AtomicOrbital("2s"), + sisl.AtomicOrbital("2px"), + sisl.AtomicOrbital("2py"), + sisl.AtomicOrbital("2pz"), + ], + ) + geom = sisl.geom.graphene(atoms=C) + + lines = draw_matrix_separators( + {"color": "red"}, + geom, + "orbitals", + separator_mode=separator_mode, + draw_supercells=draw_supercells, + ) + + if not draw_supercells and separator_mode == "supercells": + assert len(lines) == 0 + return + + assert len(lines) == 1 + assert isinstance(lines[0], dict) + action = lines[0] + assert action["method"] == "draw_line" + # Check that the number of points in the line is fine + n_expected_points = { + ("atoms", False): 6, + ("atoms", True): 30, + ("orbitals", False): 12, + ("orbitals", True): 60, + ("supercells", True): 24, + }[separator_mode, draw_supercells] + + assert action["kwargs"]["x"].shape == (n_expected_points,) + assert action["kwargs"]["y"].shape == (n_expected_points,) + + assert action["kwargs"]["line"]["color"] == "red" + + +def test_set_matrix_axes(): + + C = sisl.Atom( + "C", + orbitals=[ + sisl.AtomicOrbital("2s"), + sisl.AtomicOrbital("2px"), + sisl.AtomicOrbital("2py"), + sisl.AtomicOrbital("2pz"), + ], + ) + geom = sisl.geom.graphene(atoms=C) + + matrix = np.zeros((geom.no, geom.no * geom.n_s)) + + actions = set_matrix_axes( + matrix, geom, "orbitals", constrain_axes=False, set_labels=False + ) + assert len(actions) == 1 + assert actions[0]["method"] == "set_axes_equal" + + # Test without labels + actions = set_matrix_axes( + matrix, geom, "orbitals", constrain_axes=True, set_labels=False + ) + assert len(actions) == 3 + assert actions[0]["method"] == "set_axes_equal" + assert actions[1]["method"] == "set_axis" + assert actions[1]["kwargs"]["axis"] == "x" + assert actions[1]["kwargs"]["range"] == [-0.5, geom.no * geom.n_s - 0.5] + assert "tickvals" not in actions[1]["kwargs"] + assert "ticktext" not in actions[1]["kwargs"] + + assert actions[2]["method"] == "set_axis" + assert actions[2]["kwargs"]["axis"] == "y" + assert actions[2]["kwargs"]["range"] == [geom.no - 0.5, -0.5] + assert "tickvals" not in actions[2]["kwargs"] + assert "ticktext" not in actions[2]["kwargs"] + + # Test with labels + actions = set_matrix_axes( + matrix, geom, "orbitals", constrain_axes=True, set_labels=True + ) + assert len(actions) == 3 + assert actions[0]["method"] == "set_axes_equal" + assert actions[1]["method"] == "set_axis" + assert actions[1]["kwargs"]["axis"] == "x" + assert actions[1]["kwargs"]["range"] == [-0.5, geom.no * geom.n_s - 0.5] + assert np.all(actions[1]["kwargs"]["tickvals"] == np.arange(geom.no * geom.n_s)) + assert len(actions[1]["kwargs"]["ticktext"]) == geom.no * geom.n_s + + assert actions[2]["method"] == "set_axis" + assert actions[2]["kwargs"]["axis"] == "y" + assert actions[2]["kwargs"]["range"] == [geom.no - 0.5, -0.5] + assert np.all(actions[2]["kwargs"]["tickvals"] == np.arange(geom.no)) + assert len(actions[2]["kwargs"]["ticktext"]) == geom.no diff --git a/src/sisl/viz/processors/matrix.py b/src/sisl/viz/processors/matrix.py index 630858d039..b7beb5f784 100644 --- a/src/sisl/viz/processors/matrix.py +++ b/src/sisl/viz/processors/matrix.py @@ -6,7 +6,17 @@ import sisl -def get_orbital_sets_positions(atoms: List[sisl.Atom]): +def get_orbital_sets_positions(atoms: List[sisl.Atom]) -> List[List[int]]: + """Gets the orbital indices where an orbital set starts for each atom. + + An "orbital set" is a group of 2l + 1 orbitals with an angular momentum l + and different m. + + Parameters + ---------- + atoms : + List of atoms for which the orbital sets positions are desired. + """ specie_orb_sets = [] for at in atoms: orbitals = at.orbitals @@ -27,6 +37,17 @@ def get_geometry_from_matrix( matrix: Union[sisl.SparseCSR, sisl.SparseAtom, sisl.SparseOrbital, np.ndarray], geometry: Optional[sisl.Geometry] = None, ): + """Returns the geometry associated to a matrix. + + Parameters + ---------- + matrix : + The matrix for which the geometry is desired, which may have + an associated geometry. + geometry : + The geometry to be returned. This is to be used when we already + have a geometry and we don't want to extract it from the matrix. + """ if geometry is not None: pass elif hasattr(matrix, "geometry"): @@ -41,6 +62,23 @@ def matrix_as_array( isc: Optional[int] = None, fill_value: Optional[float] = None, ) -> np.ndarray: + """Converts any type of matrix to a numpy array. + + Parameters + ---------- + matrix : + The matrix to be converted. + dim : + If the matrix is a sisl sparse matrix and it has a third dimension, the + index to get in that third dimension. + isc : + If the matrix is a sisl SparseAtom or SparseOrbital, the index of the + cell within the auxiliary supercell. + + If None, the whole matrix is returned. + fill_value : + If the matrix is a sparse matrix, the value to fill the unset elements. + """ if isinstance(matrix, (sisl.SparseCSR, sisl.SparseAtom, sisl.SparseOrbital)): if dim is None: if isinstance(matrix, (sisl.SparseAtom, sisl.SparseOrbital)): @@ -63,8 +101,25 @@ def matrix_as_array( def determine_color_midpoint( - matrix: np.ndarray, cmid: Optional[float], crange: Optional[Tuple[float, float]] + matrix: np.ndarray, + cmid: Optional[float] = None, + crange: Optional[Tuple[float, float]] = None, ) -> Optional[float]: + """Determines the midpoint of a colorscale given a matrix of values. + + If ``cmid`` or ``crange`` are provided, this function just returns ``cmid``. + However, if none of them are provided, it returns 0 if the matrix has both + positive and negative values, and None otherwise. + + Parameters + ---------- + matrix : + The matrix of values for which the colorscale is to be determined. + cmid : + Possible already determined midpoint. + crange : + Possible already determined range. + """ if cmid is not None: return cmid elif crange is not None: @@ -76,10 +131,28 @@ def determine_color_midpoint( def get_matrix_mode(matrix) -> Literal["atoms", "orbitals"]: + """Returns what the elements of the matrix represent. + + If the matrix is a sisl SparseAtom, the elements are atoms. + Otherwise, they are assumed to be orbitals. + + Parameters + ---------- + matrix : + The matrix for which the mode is desired. + """ return "atoms" if isinstance(matrix, sisl.SparseAtom) else "orbitals" def sanitize_matrix_arrows(arrows: Union[dict, List[dict]]) -> List[dict]: + """Sanitizes an ``arrows`` argument to a list of sanitized specifications. + + Parameters + ---------- + arrows : + The arrows argument to be sanitized. If it is a dictionary, it is converted to a list + with a single element. + """ if isinstance(arrows, dict): arrows = [arrows] diff --git a/src/sisl/viz/processors/tests/test_matrix.py b/src/sisl/viz/processors/tests/test_matrix.py new file mode 100644 index 0000000000..9052fc2fcd --- /dev/null +++ b/src/sisl/viz/processors/tests/test_matrix.py @@ -0,0 +1,151 @@ +import numpy as np +import pytest + +import sisl +from sisl.viz.processors.matrix import ( + determine_color_midpoint, + get_geometry_from_matrix, + get_matrix_mode, + get_orbital_sets_positions, + matrix_as_array, + sanitize_matrix_arrows, +) + +pytestmark = [pytest.mark.viz, pytest.mark.processors] + + +def test_orbital_positions(): + + C = sisl.Atom( + 6, + orbitals=[ + sisl.AtomicOrbital("2s"), + sisl.AtomicOrbital("2px"), + sisl.AtomicOrbital("2py"), + sisl.AtomicOrbital("2pz"), + sisl.AtomicOrbital("2px"), + sisl.AtomicOrbital("2py"), + sisl.AtomicOrbital("2pz"), + ], + ) + + H = sisl.Atom(1, orbitals=[sisl.AtomicOrbital("1s")]) + + positions = get_orbital_sets_positions([C, H]) + + assert len(positions) == 2 + + assert positions[0] == [0, 1, 4] + assert positions[1] == [0] + + +def test_get_geometry_from_matrix(): + + geom = sisl.geom.graphene() + + matrix = sisl.Hamiltonian(geom) + + assert get_geometry_from_matrix(matrix) is geom + + geom_copy = geom.copy() + + assert get_geometry_from_matrix(matrix, geom_copy) is geom_copy + + # Check that if we pass something without an associated geometry + # but we provide a geometry it will work + assert get_geometry_from_matrix(np.array([1, 2]), geom) is geom + + +def test_matrix_as_array(): + + matrix = sisl.SparseCSR((2, 2, 2)) + + matrix[0, 0, 0] = 1 + matrix[0, 0, 1] = 2 + + array = matrix_as_array(matrix, fill_value=0) + assert np.allclose(array, np.array([[1, 0], [0, 0]])) + + array = matrix_as_array(matrix, dim=1, fill_value=0) + assert np.allclose(array, np.array([[2, 0], [0, 0]])) + + array = matrix_as_array(matrix) + assert array[0, 0] == 1 + assert np.isnan(array).sum() == 3 + + # Check that it can work with auxiliary supercells + geom = sisl.geom.graphene( + atoms=sisl.Atom("C", orbitals=[sisl.AtomicOrbital("2pz")]) + ) + matrix = sisl.Hamiltonian(geom) + + array = matrix_as_array(matrix) + assert array.shape == matrix.shape[:-1] + + array = matrix_as_array(matrix, isc=1) + assert array.shape == (geom.no, geom.no) + + # Check that a numpy array is kept untouched + matrix = np.array([[1, 2], [3, 4]]) + assert np.allclose(matrix_as_array(matrix), matrix) + + +def test_determine_color_midpoint(): + + # With the matrix containing only positive values + matrix = np.array([1, 2]) + + assert determine_color_midpoint(matrix) is None + assert determine_color_midpoint(matrix, cmid=1, crange=(0, 1)) == 1 + assert determine_color_midpoint(matrix, crange=(0, 1)) is None + + # With the matrix containing only negative values + matrix = np.array([-1, -2]) + + assert determine_color_midpoint(matrix) is None + assert determine_color_midpoint(matrix, cmid=1, crange=(0, 1)) == 1 + assert determine_color_midpoint(matrix, crange=(0, 1)) is None + + # With the matrix containing both positive and negative values + matrix = np.array([-1, 1]) + + assert determine_color_midpoint(matrix) == 0 + assert determine_color_midpoint(matrix, cmid=1, crange=(-1, 1)) == 1 + assert determine_color_midpoint(matrix, crange=(-1, 1)) is None + + +def test_get_matrix_mode(): + + geom = sisl.geom.graphene() + + matrix = sisl.SparseAtom(geom) + assert get_matrix_mode(matrix) == "atoms" + + matrix = sisl.Hamiltonian(geom) + assert get_matrix_mode(matrix) == "orbitals" + + matrix = sisl.SparseCSR((2, 2)) + assert get_matrix_mode(matrix) == "orbitals" + + matrix = np.array([[1, 2], [3, 4]]) + assert get_matrix_mode(matrix) == "orbitals" + + +def test_sanitize_matrix_arrows(): + + arrows = {} + assert sanitize_matrix_arrows(arrows) == [{"center": "middle"}] + + geom = sisl.geom.graphene() + data = sisl.Hamiltonian(geom, dim=2) + data[0, 0, 0] = 1 + data[0, 0, 1] = 2 + + arrows = [{"data": data}] + sanitized = sanitize_matrix_arrows(arrows) + + assert len(sanitized) == 1 + assert sanitized[0]["data"].shape == data.shape + san_data = sanitized[0]["data"] + assert san_data[0, 0, 0] == 1 + assert san_data[0, 0, 1] == -2