[Bug] slogdet for matrices of dimension larger than 1000 #91

Open
neel-maniar opened this issue Jun 5, 2024 · 0 comments
Labels
bug Something isn't working

Comments

@neel-maniar

🐛 Bug

There is an issue with beartype type checking that causes cola.linalg.slogdet to return an error when taking the determinant of a matrix larger than $1000\times 1000$.

For example, run the code below with dim=1000 and with dim=1001. The bug occurs regardless of whether the line Sigma = cola.PSD(Sigma) is included.

To reproduce

**Code snippet to reproduce**

import cola
import cola.linalg
import jax.numpy as jnp
import jax.random as jr
from jax import config

config.update("jax_enable_x64", True)

dim = 1001
master_key = jr.key(0)

A = jr.normal(master_key, (dim, dim))
Sigma = A @ A.T  # Ensure Sigma is PSD
Sigma += 0.1 * jnp.eye(dim)  # Ensure Sigma has strictly positive determinant
Sigma = cola.ops.Dense(Sigma)  # Convert Sigma to a cola Dense object
Sigma = cola.PSD(Sigma)  # Tell cola that Sigma is PSD

print("Signed Log Determinant:")
print("-----------------------")
print("(jax.numpy)", jnp.linalg.slogdet(Sigma.to_dense()))
print("(cola)", cola.linalg.slogdet(Sigma))

**Stack trace/error message**

Signed Log Determinant:
-----------------------
(jax.numpy) SlogdetResult(sign=Array(1., dtype=float64), logabsdet=Array(5935.35245205, dtype=float64))
C:\Users\neelm\miniconda3\envs\gp\lib\site-packages\beartype\_util\hint\pep\utilpeptest.py:311: BeartypeDecorHintPep585DeprecationWarning: PEP 484 type hint typing.Callable deprecated by PEP 585. This hint is scheduled for removal in the first Python version released after October 5th, 2025. To resolve this, import this hint from "beartype.typing" rather than "typing". For further commentary and alternatives, see also:
    https://beartype.readthedocs.io/en/latest/api_roar/#pep-585-deprecations
  warn(

Expected Behavior

With dim=1000, I get the expected behaviour:

Signed Log Determinant:
-----------------------
(jax.numpy) SlogdetResult(sign=Array(1., dtype=float64), logabsdet=Array(5926.90396938, dtype=float64))
(cola) (Array(1., dtype=float64), Array(5926.90396938, dtype=float64))
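As an independent cross-check that does not go through CoLA or beartype, the log-determinant of a symmetric positive definite matrix can be read off its Cholesky factor: $\Sigma = LL^\top$ gives $\log\det\Sigma = 2\sum_i \log L_{ii}$, with sign always $+1$. The sketch below assumes NumPy as a stand-in for JAX (jax.numpy exposes the same linalg.cholesky and linalg.slogdet functions). make_psd, cholesky_logdet and compare are hypothetical helpers, not part of CoLA, and because NumPy's random generator replaces jax.random the values will not match the numbers above.

```python
import numpy as np


def make_psd(dim, seed=0, jitter=0.1):
    # Mirrors the reproduction script: Sigma = A A^T + jitter * I.
    # Assumption: NumPy's generator replaces jax.random, so values differ from the script.
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((dim, dim))
    return A @ A.T + jitter * np.eye(dim)


def cholesky_logdet(sigma):
    # For symmetric positive definite Sigma = L L^T:
    # log det(Sigma) = 2 * sum(log(diag(L))), and the sign is always +1.
    L = np.linalg.cholesky(sigma)
    return 1.0, 2.0 * np.sum(np.log(np.diag(L)))


def compare(candidate, dims, rtol=1e-8):
    # Runs `candidate` (any function returning (sign, logabsdet)) for each
    # dimension and records whether it raised or disagreed with the reference.
    report = {}
    for dim in dims:
        sigma = make_psd(dim)
        ref_sign, ref_logdet = cholesky_logdet(sigma)
        try:
            sign, logdet = candidate(sigma)
        except Exception as exc:  # record the failure instead of stopping the sweep
            report[dim] = f"raised {type(exc).__name__}: {exc}"
            continue
        ok = float(sign) == ref_sign and np.isclose(float(logdet), ref_logdet, rtol=rtol)
        report[dim] = "ok" if ok else f"mismatch: got {logdet}, expected {ref_logdet}"
    return report


print(compare(np.linalg.slogdet, [1000, 1001]))
```

Passing np.linalg.slogdet gives a baseline. Wrapping the call from the reproduction script, e.g. `lambda s: cola.linalg.slogdet(cola.PSD(cola.ops.Dense(jnp.asarray(s))))`, should report whether dim=1001 raises or disagrees with the reference (this wrapper is untested here).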

System information

CoLA Version: 0.0.5
JAX Version: 0.4.28
Computer OS: Windows 11 Home

Additional context

This could be related to #41, but it seems to be a different error.

There also seem to be accuracy issues with cola.solve for matrices larger than 1000 by 1000.
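One way to quantify the cola.solve accuracy concern is to compare the relative residual $\|Ax - b\|_2 / \|b\|_2$ of the returned solution at dim=1000 and dim=1001. A minimal sketch, assuming NumPy; relative_residual and check_solver are hypothetical helpers, and solve is any function taking (A, b) and returning x.

```python
import numpy as np


def relative_residual(A, x, b):
    # ||A x - b||_2 / ||b||_2; values near machine epsilon indicate
    # the solver returned a solution consistent with the system.
    return np.linalg.norm(A @ x - b) / np.linalg.norm(b)


def check_solver(solve, dims, seed=0, jitter=0.1):
    # `solve(A, b)` is the solver under test; returns the residual per dimension.
    rng = np.random.default_rng(seed)
    out = {}
    for dim in dims:
        M = rng.standard_normal((dim, dim))
        A = M @ M.T + jitter * np.eye(dim)  # same construction as Sigma in the script
        b = rng.standard_normal(dim)
        x = np.asarray(solve(A, b))
        out[dim] = relative_residual(A, x, b)
    return out


print(check_solver(np.linalg.solve, [1000, 1001]))
```

np.linalg.solve serves as the baseline; a wrapper around the CoLA solve named above (its exact signature is not verified here) can be passed the same way. A small residual does not by itself bound the error in x, which also scales with the condition number of A.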

@neel-maniar neel-maniar added the bug Something isn't working label Jun 5, 2024
@neel-maniar neel-maniar changed the title [Bug] slogdet [Bug] slogdet for matrices of dimension larger than 1000 Jun 5, 2024