Skip to content

Commit

Permalink
always solve in float64
Browse files Browse the repository at this point in the history
  • Loading branch information
Linux-cpp-lisp committed Dec 20, 2022
1 parent 4bb345d commit 5a365e0
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
14 changes: 9 additions & 5 deletions nequip/utils/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@


def solver(X, y, alpha: Optional[float] = 0.001, stride: Optional[int] = 1, **kwargs):

dtype = torch.get_default_dtype()
# results are in the same "units" as y, so same dtype too:
dtype_out = y.dtype
# always solve in float64 for numerical stability
dtype = torch.float64
X = X[::stride].to(dtype)
y = y[::stride].to(dtype)

Expand Down Expand Up @@ -40,7 +42,7 @@ def solver(X, y, alpha: Optional[float] = 0.001, stride: Optional[int] = 1, **kw

logging.debug(f"Ridge Regression, residue {sigma2}")

return mean, cov
return mean.to(dtype_out), cov.to(dtype_out)


def down_sampling_by_composition(
Expand All @@ -61,8 +63,10 @@ def down_sampling_by_composition(
id_end = torch.cat((node_icomp + 1, torch.as_tensor([len(sort_by)])))

n_points = len(percentage)
new_X = torch.zeros((n_types * n_points, X.shape[1]))
new_y = torch.zeros((n_types * n_points))
new_X = torch.zeros(
(n_types * n_points, X.shape[1]), dtype=X.dtype, device=X.device
)
new_y = torch.zeros((n_types * n_points), dtype=y.dtype, device=y.device)
for i in range(n_types):
ids = sort_by[id_start[i] : id_end[i]]
for j, p in enumerate(percentage):
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/utils/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ def test_random(full_rank, alpha, per_species_set):
if alpha == 0 and not full_rank:
return

torch.manual_seed(0)
rng = torch.Generator().manual_seed(343)

ref_mean, ref_std, E, n_samples, n_dim = per_species_set

dtype = torch.get_default_dtype()

X = torch.randint(low=1, high=10, size=(n_samples, n_dim)).to(dtype)
X = torch.randint(low=1, high=10, size=(n_samples, n_dim), generator=rng).to(
torch.get_default_dtype()
)
if not full_rank:
X[:, n_dim - 2] = X[:, n_dim - 1] * 2
y = (X * E).sum(axis=-1)
Expand Down

0 comments on commit 5a365e0

Please sign in to comment.