Skip to content

Commit

Permalink
Overhaul
Browse files Browse the repository at this point in the history
  • Loading branch information
rosecers committed Jan 20, 2024
1 parent 9de555c commit 82eea56
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 46 deletions.
2 changes: 1 addition & 1 deletion anisoap/representations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .ellipsoidal_density_projection import EllipsoidalDensityProjection
from .radial_basis import RadialBasis
from .radial_basis import MonomialBasis, GTORadialBasis
42 changes: 24 additions & 18 deletions anisoap/representations/ellipsoidal_density_projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from scipy.spatial.transform import Rotation
from tqdm.auto import tqdm

from anisoap.representations.radial_basis import RadialBasis
from anisoap.representations.radial_basis import GTORadialBasis, MonomialBasis
from anisoap.utils.moment_generator import *
from anisoap.utils.spherical_to_cartesian import spherical_to_cartesian

Expand Down Expand Up @@ -420,6 +420,8 @@ def __init__(
max_radial=None,
rotation_key="quaternion",
rotation_type="quaternion",
basis_rcond=None,
basis_tol=None,
):
# Store the input variables
self.max_angular = max_angular
Expand All @@ -433,28 +435,32 @@ def __init__(
raise NotImplementedError("Sorry! Gradients have not yet been implemented")
#

# Initialize the radial basis class
if radial_basis_name not in ["monomial", "gto"]:
raise NotImplementedError(
f"{self.radial_basis_name} is not an implemented basis"
". Try 'monomial' or 'gto'"
)
if radial_gaussian_width != None and radial_basis_name != "gto":
raise ValueError("Gaussian width can only be provided with GTO basis")
elif radial_gaussian_width is None and radial_basis_name == "gto":
raise ValueError("Gaussian width must be provided with GTO basis")
elif type(radial_gaussian_width) == int:
raise ValueError(
"radial_gaussian_width is set as an integer, which could cause overflow errors. Pass in float."
)

radial_hypers = {}
radial_hypers["radial_basis"] = radial_basis_name.lower() # lower case
radial_hypers["radial_gaussian_width"] = radial_gaussian_width
radial_hypers["max_angular"] = max_angular
radial_hypers["cutoff_radius"] = cutoff_radius
radial_hypers["max_radial"] = max_radial
self.radial_basis = RadialBasis(**radial_hypers)
radial_hypers["rcond"] = basis_rcond
radial_hypers["tol"] = basis_tol

# Initialize the radial basis class
if radial_basis_name == "gto":
self.radial_basis = GTORadialBasis(**radial_hypers)
elif radial_basis_name == "monomial":
self.radial_basis = MonomialBasis(**radial_hypers)
else:
raise NotImplementedError(
f"{self.radial_basis_name} is not an implemented basis"
". Try 'monomial' or 'gto'"
)
# if radial_gaussian_width != None and radial_basis_name != "gto":
# raise ValueError("Gaussian width can only be provided with GTO basis")
# elif radial_gaussian_width is None and radial_basis_name == "gto":
# raise ValueError("Gaussian width must be provided with GTO basis")
# elif type(radial_gaussian_width) == int:
# raise ValueError(
# "radial_gaussian_width is set as an integer, which could cause overflow errors. Pass in float."
# )

self.num_ns = self.radial_basis.get_num_radial_functions()
self.sph_to_cart = spherical_to_cartesian(self.max_angular, self.num_ns)
Expand Down
Loading

0 comments on commit 82eea56

Please sign in to comment.