torch-spex computes spherical expansions of atomic neighbourhoods in torch. It provides both a ready-to-use calculator, spex.SphericalExpansion, and the building blocks required to implement custom expansions. It is fully compatible with TorchScript and also has a metatensor calculator. As of now, outputs are precisely equivalent to those of rascaline for matching settings, while typically outperforming it by a significant margin on GPUs.
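A spherical expansion is commonly used as the input for machine learning models of molecules and materials, but can, in principle, be used for other tasks involving learning on labelled point clouds. In the language of atomistic ML, a spherical expansion computes a fixed-size representation of the arrangement of neighbouring atoms j around each central atom i:
$$A^{(l)}_{i,m,n,c} = \sum_{j} R_{nl}(r_{ij}) \, S_{lm}(\mathbf{r}_{ij}) \, C_c(Z_j) \, f_{\text{cut}}(r_{ij})$$
where the sum runs over all neighbours j within the cutoff radius, R is the radial embedding, S the angular embedding, C the embedding of the neighbour's species Z_j, and f_cut the cutoff function. The output is a list of torch.Tensor, one per angular channel l, with every tensor arranged as [i, m, n, c].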
From a technical standpoint, we decompose the problem exactly as discussed above: spex.species provides options for C, spex.radial for R, and spex.angular for S. Finally, spex.cutoff provides options for the cutoff function, and spex.SphericalExpansion multiplies it all together and performs the sum over neighbours.
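To make the decomposition concrete, here is a minimal pure-Python sketch of the sum over neighbours for a single central atom, restricted to l = 0 and l = 1. It is not how spex computes anything: the Gaussian radial basis, the one-hot species embedding over two species, the shifted-cosine form of the cutoff, and the unnormalised real harmonics (1 for l = 0, and (y, z, x)/r for l = 1) are all assumptions chosen for brevity, and every name in it is hypothetical.

```python
import math

CUTOFF = 5.0
WIDTH = 0.5
SPECIES = [1, 8]            # hypothetical: H and O, one-hot embedded
CENTERS = [1.0, 2.0, 3.0]  # hypothetical Gaussian radial basis centres


def cutoff_fn(r):
    # assumed shifted cosine: 1 below CUTOFF - WIDTH, 0 at and beyond CUTOFF
    if r >= CUTOFF:
        return 0.0
    if r <= CUTOFF - WIDTH:
        return 1.0
    x = (r - (CUTOFF - WIDTH)) / WIDTH
    return 0.5 * (1.0 + math.cos(math.pi * x))


def radial(r):
    # R_n(r): one value per radial channel n (independent of l in this toy)
    return [math.exp(-((r - c) ** 2)) for c in CENTERS]


def angular(v):
    # S_lm: unnormalised real harmonics for l = 0 and l = 1, m ordered as (y, z, x)
    x, y, z = v
    r = math.sqrt(x * x + y * y + z * z)
    return [[1.0], [y / r, z / r, x / r]]


def species_embedding(z):
    # C_c(Z): one-hot over SPECIES
    return [1.0 if z == s else 0.0 for s in SPECIES]


def expand(neighbours):
    """neighbours: list of (displacement vector, species) around one centre.

    Returns one nested list per l, indexed as [m][n][c].
    """
    out = [
        [[[0.0] * len(SPECIES) for _ in CENTERS] for _ in range(2 * l + 1)]
        for l in range(2)
    ]
    for v, z in neighbours:
        r = math.sqrt(sum(x * x for x in v))
        fc = cutoff_fn(r)
        if fc == 0.0:
            continue  # neighbour outside the cutoff contributes nothing
        R, S, C = radial(r), angular(v), species_embedding(z)
        for l in range(2):
            for m in range(2 * l + 1):
                for n in range(len(CENTERS)):
                    for c in range(len(SPECIES)):
                        out[l][m][n][c] += R[n] * S[l][m] * C[c] * fc
    return out


# usage: the third neighbour lies beyond the cutoff and is ignored
A = expand([((1.0, 0.0, 0.0), 1), ((0.0, 2.0, 0.0), 8), ((6.0, 0.0, 0.0), 8)])
```

Each neighbour contributes an outer product of its radial, angular, and species vectors, scaled by the cutoff, which is why the result has a fixed size regardless of how many neighbours there are.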
Species embeddings (spex.species):
- Alchemical: Typical embedding of fixed size; it can be understood as mapping elements to linear combinations of "pseudo species".
- Orthogonal: One-hot embedding, keeping each species in a separate subspace.
Radial embeddings (spex.radial):
- LaplacianEigenstates and PhysicalBasis: Physics-inspired basis functions (splined for efficiency).
- Bernstein: Bernstein polynomials.
Angular embeddings (spex.angular):
- SphericalHarmonics: (real) spherical harmonics (provided by sphericart).
- SolidHarmonics: solid harmonics (i.e., spherical harmonics multiplied by r**l).
Cutoff functions (spex.cutoff):
- ShiftedCosine: Cosine shifted such that it goes from 1 to 0 within a certain width of the cutoff radius.
- Step: Step function, or hard cutoff. Not advisable in practice, as it makes the potential-energy surface not continuously differentiable and the resulting force field non-conservative.
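As a hedged illustration of the two cutoff behaviours described above, the sketch below implements a shifted cosine and a step cutoff in plain Python. The exact functional form used by spex.cutoff.ShiftedCosine is an assumption here (the common choice 0.5 * (1 + cos(pi * x)) over the last width of the cutoff radius), and the function names are hypothetical.

```python
import math


def shifted_cosine(r, cutoff, width):
    # Equal to 1 up to cutoff - width, then decays smoothly to 0 at cutoff.
    # Its first derivative also vanishes at both ends of the switching region.
    onset = cutoff - width
    if r <= onset:
        return 1.0
    if r >= cutoff:
        return 0.0
    x = (r - onset) / width  # runs from 0 to 1 across the switching region
    return 0.5 * (1.0 + math.cos(math.pi * x))


def step(r, cutoff):
    # Hard cutoff: jumps from 1 to 0 at the cutoff radius.
    return 1.0 if r < cutoff else 0.0


# just inside the cutoff, the shifted cosine is already ~0 while the step is still 1
eps = 1e-6
print(shifted_cosine(5.0 - eps, 5.0, 0.5), step(5.0 - eps, 5.0))
```

The step version jumps from 1 to 0 at the cutoff, so an atom crossing the cutoff radius changes the expansion discontinuously; this is the behaviour that makes it a poor choice for force fields.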
The interfaces for each of these types of components are defined in the READMEs of the respective sub-packages. For custom components, please make sure to implement the corresponding interface.
spex requires torch and sphericart (with torch support) to be installed manually.
Using the metatensor interface requires the torch version of metatensor (metatensor[torch]).
Running the tests additionally requires rascaline (with torch support).
# install the appropriate version of torch for your setup
pip install "sphericart[torch] @ git+https://github.com/lab-cosmo/sphericart"
pip install "metatensor[torch]"
# make sure that the rust compiler is available
pip install git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch
Once these dependencies are present, you should be able to install spex as usual:
git clone git@github.com:sirmarcel/spex-dev.git
cd spex-dev
pip install -e .
# or (to install pytest, etc)
pip install -e .[dev]
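After installing, a quick way to check which of the optional dependencies are importable is a short stdlib-only script like the one below. The module paths (sphericart.torch, metatensor.torch, rascaline.torch) are our assumption of the import names for the packages installed above; adjust them if your versions differ.

```python
import importlib.util


def available(module_names):
    """Map each dotted module name to whether Python can locate it.

    For dotted names, find_spec imports the parent package along the way.
    """
    result = {}
    for name in module_names:
        try:
            result[name] = importlib.util.find_spec(name) is not None
        except (ImportError, ValueError):
            # a missing (or non-package) parent raises instead of returning None
            result[name] = False
    return result


if __name__ == "__main__":
    deps = ["torch", "sphericart.torch", "metatensor.torch", "rascaline.torch", "spex"]
    for name, ok in available(deps).items():
        print(f"{name:20s} {'ok' if ok else 'missing'}")
```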
Regrettably, there is currently no nicely rendered documentation for this package. Instead, you are expected and encouraged to look into the code itself, where you will find docstrings on all public-facing functions, as well as docstrings at the sub-package level (in the __init__.py files) that explain in more detail what is going on.
spex components can all be instantiated from dicts of the form {ClassName: {"arg": 1, ...}}, similar to featomic (formerly rascaline) and inspired by specable. This feature is used heavily in SphericalExpansion, which accepts this style of dict to specify the different embeddings to use. If ClassName contains a dot, for example mything.SpecialRadial, spex will try to import SpecialRadial from mything, so we get a basic plug-in system for free! In all cases, the "inner" dict is simply splatted (**kwargs) into the __init__ of ClassName.
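The spec convention can be illustrated with a small resolver. This is a hypothetical sketch of the mechanism described above, not spex's actual from_dict: a bare ClassName is looked up in a registry (a plain dict standing in for spex's own components), a dotted name is imported with importlib, and the inner dict is splatted into the constructor. Accepting a bare string mirrors the angular="SphericalHarmonics" usage in the example further down.

```python
import importlib


def resolve(spec, registry):
    """Instantiate {ClassName: {kwargs}} (hypothetical helper, not part of spex)."""
    if isinstance(spec, str):
        spec = {spec: {}}  # a bare name with no arguments
    (name, kwargs), = spec.items()  # exactly one key is expected
    if "." in name:
        module_name, _, class_name = name.rpartition(".")
        cls = getattr(importlib.import_module(module_name), class_name)
    else:
        cls = registry[name]
    return cls(**kwargs)  # the inner dict is splatted into __init__


# usage: a dotted name is imported, a bare name comes from the registry
class Step:  # hypothetical stand-in for a built-in component
    def __init__(self, cutoff=5.0):
        self.cutoff = cutoff


frac = resolve({"fractions.Fraction": {"numerator": 3, "denominator": 4}}, registry={})
step = resolve("Step", registry={"Step": Step})
print(frac, step.cutoff)  # 3/4 5.0
```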
Here is a full example for a SphericalExpansion:
from spex import SphericalExpansion
exp = SphericalExpansion(
    5.0,  # cutoff radius
    max_angular=3,  # l = 0, 1, 2, 3
    radial={"LaplacianEigenstates": {"max_radial": 8}},
    angular="SphericalHarmonics",
    species={"Alchemical": {"pseudo_species": 4}},
    cutoff_function={"ShiftedCosine": {"width": 0.5}},
)
Equivalently, we can write this out as a .yaml file, for example spex.yaml:
spex.SphericalExpansion:
  cutoff: 5.0
  max_angular: 3
  radial:
    LaplacianEigenstates:
      max_radial: 8
  angular: SphericalHarmonics
  species:
    Alchemical:
      pseudo_species: 4
  cutoff_function:
    ShiftedCosine:
      width: 0.5
From this .yaml file, we can instantiate the expansion with
from spex import from_dict, read_yaml
exp = from_dict(read_yaml("spex.yaml"))
This is already one half of everything we need for a general tool to save and load torch models: torch lets us save the weights (and other parameters) of a torch.nn.Module with module.state_dict(), but it doesn't store how to instantiate a template to load the weights into. The lightweight dict-based approach here provides exactly that. For convenience, spex puts the two together: spex.save creates a folder with params.torch for the weights and model.yaml for the dict (which we call a spec), and spex.load instantiates the model and loads the weights. This is not needed for most of the components in spex, since currently only some radial basis functions have learnable parameters, but it can also be used to improvise checkpointing for experiments.
There are two ways to customise the expansion: (a) with custom embeddings, or (b) in other ways. (a) is supported by the plugin system: you can write a class that conforms to the interfaces defined in the respective sub-packages of spex (spex.radial, spex.angular, spex.species), and then simply pass, for example, radial={"mypackage.MyClass": {"my_arg": ...}} to SphericalExpansion. (Note that mypackage can also be a mypackage.py in the same folder as your current script.) All other customisations, (b), are not supported intrinsically by spex; you are expected to copy spherical_expansion.py and hack away. Do not hesitate to ask if you have trouble with any particular plan; we're happy to help.
spex only computes a spherical expansion and not any downstream descriptors. The well-known SOAP descriptor can be obtained, for example, via:
expansion = exp(R_ij, i, j, species) # -> [[i, m, n, c], ...]
soap = [torch.einsum("imnc,imNC->inNcC", e, e) for e in expansion] # -> [[i, n1, n2, c1, c2], [...], ...]
Note that this may produce very large features. You may want to consider other, "contracted", approaches where inner instead of outer products are performed. How to do this is beyond the scope of this readme. :)
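The einsum above contracts the m index, which is what makes the result rotationally invariant: a rotation of the structure acts on each l block of a real spherical expansion as an orthogonal matrix over m, and a sum over m of products is unchanged by any orthogonal matrix. The pure-Python sketch below reproduces the same contraction for one l block and checks this with random numbers standing in for a real l = 1 expansion and an arbitrary orthogonal 3×3 matrix (a rotation mixing the first two m components); all names are hypothetical.

```python
import math
import random


def power_spectrum(e):
    """Pure-Python equivalent of torch.einsum("imnc,imNC->inNcC", e, e) for nested lists."""
    I, M, N, C = len(e), len(e[0]), len(e[0][0]), len(e[0][0][0])
    return [[[[[sum(e[i][m][n][c] * e[i][m][n2][c2] for m in range(M))
                for c2 in range(C)] for c in range(C)]
              for n2 in range(N)] for n in range(N)] for i in range(I)]


def rotate_m(e, Q):
    # apply the orthogonal matrix Q along the m index: e'[i][m] = sum_k Q[m][k] e[i][k]
    return [[[[sum(Q[m][k] * e[i][k][n][c] for k in range(len(Q)))
               for c in range(len(e[i][0][0]))]
              for n in range(len(e[i][0]))]
             for m in range(len(Q))]
            for i in range(len(e))]


def flatten(x):
    return [v for item in x for v in flatten(item)] if isinstance(x, list) else [x]


random.seed(0)
# [i, m, n, c] with 2 centres, 3 m values (l = 1), 3 radial and 2 species channels
e1 = [[[[random.gauss(0, 1) for c in range(2)] for n in range(3)] for m in range(3)]
      for i in range(2)]
t = 0.7
Q = [[math.cos(t), -math.sin(t), 0.0], [math.sin(t), math.cos(t), 0.0], [0.0, 0.0, 1.0]]
p, p_rot = power_spectrum(e1), power_spectrum(rotate_m(e1, Q))
max_diff = max(abs(a - b) for a, b in zip(flatten(p), flatten(p_rot)))
print(f"max deviation after rotation: {max_diff:.1e}")
```

The output has n² c² entries per centre and l, which is where the size concern above comes from.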
spex uses ruff for formatting. Please use the pre-commit hook to make sure that any contributions are formatted correctly, or run manually:
ruff format . && ruff check --fix .
We generally adhere to the Google style guidelines. Note in particular how docstrings are formatted, and that docstrings are only encouraged for the public-facing API, unless required to explain something. All docstrings, and particularly comments, are expected to be kept to a minimum and to be concise. Make sure you edit your LLM output accordingly.
Please review the development readme in spex/README.md for further information.