Skip to content

Commit

Permalink
More optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
shinaoka committed Jun 5, 2024
1 parent 97a9949 commit 92693ed
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 7 deletions.
24 changes: 20 additions & 4 deletions src/admmsolver/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,14 @@ def __matmul__(self, other: Union[MatrixBase, np.ndarray]) -> Union[MatrixBase,
(self.shape[0], other.shape[1])
)
elif isinstance(other, PartialDiagonalMatrix):
return DenseMatrix(self._diagonals[:, None] * other.asmatrix())
diags = self.diagonals.reshape(other.matrix.shape[0], -1)
if np.allclose(diags, diags[:, 0:1]):
return PartialDiagonalMatrix(
diags[:, 0][:, None] * other.matrix.asmatrix(),
other.rest_dims
)
else:
return DenseMatrix(self._diagonals[:, None] * other.asmatrix())
elif isinstance(other, ScaledIdentityMatrix):
return self @ other.to_diagonal_matrix()
else:
Expand Down Expand Up @@ -339,10 +346,13 @@ def __matmul__(self, other) -> Union[np.ndarray, MatrixBase]:
assert isinstance(other, MatrixBase) or isinstance(other, np.ndarray)
if isinstance(other, np.ndarray):
return self.matvec(other)
elif isinstance(other, PartialDiagonalMatrix) and self.rest_dims == other.rest_dims:

if isinstance(other, PartialDiagonalMatrix) and self.rest_dims == other.rest_dims:
return PartialDiagonalMatrix(self.matrix@other.matrix, self.rest_dims)
else:
return DenseMatrix(self.asmatrix() @ other.asmatrix())
if isinstance(other, ScaledIdentityMatrix) and other.is_diagonal():
return PartialDiagonalMatrix(other.coeff * self.matrix, self.rest_dims)

return DenseMatrix(self.asmatrix() @ other.asmatrix())

def __mul__(self, other) -> 'PartialDiagonalMatrix':
if type(other) in [float, complex, np.float64, np.complex128]:
Expand Down Expand Up @@ -459,6 +469,12 @@ def _add_DiagonalMatrix_PartialDiagonalMatrix(a, b):
return DenseMatrix(a.asmatrix() + b.asmatrix())


def _add_PartialDiagonalMatrix_PartialDiagonalMatrix(a, b):
if a.rest_dims == b.rest_dims:
return PartialDiagonalMatrix(a.matrix + b.matrix, a.rest_dims)
return DenseMatrix(a.asmatrix() + b.asmatrix())


def _add_DenseMatrix_DenseMatrix(a, b):
return DenseMatrix(a.asmatrix() + b.asmatrix())

Expand Down
32 changes: 29 additions & 3 deletions test/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def _randn_cmplx(*shape) -> np.ndarray:
return np.random.randn(*shape) + 1j * np.random.randn(*shape)


def test_matmal():
def test_matmul():
np.random.seed(100)

# (12, 12) * (12, 4)
Expand All @@ -33,7 +33,7 @@ def test_matmal():
np.testing.assert_allclose(lr.asmatrix(), l.asmatrix() @ r.asmatrix())


def test_matmal2():
def test_matmul():
np.random.seed(100)

# (4, 12) * (12, 12)
Expand Down Expand Up @@ -124,6 +124,32 @@ def test_DiagonalMatrix_PartialDiagonalMatrix():
np.testing.assert_allclose(ab.asmatrix(), a.asmatrix() + b.asmatrix())


def test_PartialDiagonalMatrix_PartialDiagonalMatrix():
np.random.seed(100)
n = 3
a = PartialDiagonalMatrix(_randn_cmplx(n, n), (2, 2))
b = PartialDiagonalMatrix(_randn_cmplx(n, n), (2, 2))
ab = a + b
assert isinstance(ab, PartialDiagonalMatrix)
np.testing.assert_allclose(ab.asmatrix(), a.asmatrix() + b.asmatrix())


def test_matmul_DiagonalMatrix_PartialDiagonalMatrix():
np.random.seed(100)

n = 3
diags_ = np.random.randn(n)
diags = np.zeros((n, 4))
for i in range(4):
diags[:, i] = diags_

a = DiagonalMatrix(diags.ravel())
b = PartialDiagonalMatrix(_randn_cmplx(n, n), (2, 2))

ab = a @ b
assert isinstance(ab, PartialDiagonalMatrix)
np.testing.assert_allclose(ab.asmatrix(), a.asmatrix() @ b.asmatrix())

def test_inv():
np.random.seed(100)

Expand Down Expand Up @@ -207,7 +233,7 @@ def test_batched_matvec(n, m):
np.testing.assert_allclose(mv, m.asmatrix()@vec)


def test_matmal_diagonal():
def test_matmul_diagonal():
np.random.seed(100)
a = DiagonalMatrix(np.random.randn(2), shape=(4,2))
b = DiagonalMatrix(np.random.randn(2), shape=(2,4))
Expand Down

0 comments on commit 92693ed

Please sign in to comment.