From d23513a60a504b08aabf6320aac6c581d3cdc365 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Mon, 11 Nov 2024 23:00:12 +0000 Subject: [PATCH 1/8] rename alf-qleace --- concept_erasure/__init__.py | 3 + concept_erasure/alf_qleace.py | 366 ++++++++++++++++++++++++++++++++++ 2 files changed, 369 insertions(+) create mode 100644 concept_erasure/alf_qleace.py diff --git a/concept_erasure/__init__.py b/concept_erasure/__init__.py index 1e71bed..4572cc2 100644 --- a/concept_erasure/__init__.py +++ b/concept_erasure/__init__.py @@ -1,3 +1,4 @@ +from .alf_qleace import AlfQLeaceEraser, AlfQLeaceFitter from .concept_scrubber import ConceptScrubber from .groupby import GroupedTensor, groupby from .leace import ErasureMethod, LeaceEraser, LeaceFitter @@ -24,4 +25,6 @@ "QuadraticEraser", "QuadraticFitter", "QuantileNormalizer", + "AlfQLeaceEraser", + "AlfQLeaceFitter", ] diff --git a/concept_erasure/alf_qleace.py b/concept_erasure/alf_qleace.py new file mode 100644 index 0000000..eb61cc7 --- /dev/null +++ b/concept_erasure/alf_qleace.py @@ -0,0 +1,366 @@ +from dataclasses import dataclass +from typing import Literal + +import torch +from torch import Tensor + +from .caching import cached_property, invalidates_cache +from .groupby import groupby +from .shrinkage import optimal_linear_shrinkage + +ErasureMethod = Literal["leace", "orth"] + + +@dataclass(frozen=True) +class AlfQLeaceEraser: + """QLEACE eraser that erases concepts from a representation. First applies LEACE, + then applies pair-wise QLEACE using a projection matrix optimized to the class with + the covariance most divergent from the mean covariance. + + Since the LEACE projection matrix is guaranteed to be a rank k - 1 perturbation of + the identity, we store it implicitly in the d x k matrices `proj_left` and + `proj_right`. The full matrix is given by `torch.eye(d) - proj_left @ proj_right`. + + The ALF-QLEACE projection matrix is guaranteed to be a rank 1 perturbation of the + identity, given by torch.eye(d) - alf_qleace_vec @ alf_qleace_vec. + """ + + proj_left: Tensor + proj_right: Tensor + bias: Tensor | None + alf_qleace_vec: Tensor + + @classmethod + def fit(cls, x: Tensor, z: Tensor, **kwargs) -> "AlfQLeaceEraser": + """Convenience method to fit a LeaceEraser on data and return it.""" + return AlfQLeaceFitter.fit(x, z, **kwargs).eraser + + @property + def P(self) -> Tensor: + """The LEACE projection matrix.""" + eye = torch.eye( + self.proj_left.shape[0], + device=self.proj_left.device, + dtype=self.proj_left.dtype, + ) + return eye - self.proj_left @ self.proj_right + + @property + def Q(self) -> Tensor: + """The ALF-QLEACE projection matrix.""" + eye = torch.eye( + self.alf_qleace_vec.shape[0], + device=self.alf_qleace_vec.device, + dtype=self.alf_qleace_vec.dtype, + ) + return eye - torch.outer(self.alf_qleace_vec, self.alf_qleace_vec) + + def __call__(self, x: Tensor) -> Tensor: + """Apply the projection to the input tensor.""" + delta = x - self.bias if self.bias is not None else x + + # Ensure we do the matmul in the most efficient order. 
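+        # proj_right is (k, d) and proj_left is (d, k) with k << d, so grouping
+        # the product as (delta @ proj_right.mH) @ proj_left.mH costs O(ndk),
+        # versus O(nd^2) if the full d x d projection were materialized first.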
+ x_ = x - (delta @ self.proj_right.mH) @ self.proj_left.mH + + # Apply the ALF-QLEACE projection + v = self.alf_qleace_vec + x_ = x_ - torch.einsum("i,bi->bi", v, (v @ x_.mH).unsqueeze(1)) + + return x_.type_as(x) + + def to(self, device: torch.device | str) -> "AlfQLeaceEraser": + """Move eraser to a new device.""" + return AlfQLeaceEraser( + self.proj_left.to(device), + self.proj_right.to(device), + self.bias.to(device) if self.bias is not None else None, + self.alf_qleace_vec.to(device), + ) + + +class AlfQLeaceFitter: + """Fits LEACE plus a linear transform that surgically erases the direction of + maximum covariance from a representation. + + This class implements Least-squares Concept Erasure (LEACE) from + https://arxiv.org/abs/2306.03819. You can also use a slightly simpler orthogonal + projection-based method by setting `method="orth"`. + + This class stores all the covariance statistics needed to compute the QLEACE eraser. + This allows the statistics to be updated incrementally with `update()`. + """ + + global_mean_x: Tensor + """Running mean of X.""" + + global_mean_z: Tensor + """Running mean of Z.""" + + sigma_xz_: Tensor + """Unnormalized cross-covariance matrix X^T Z.""" + + sigma_xx_: Tensor | None + """Unnormalized covariance matrix X^T X.""" + + sigma_xx_z_: Tensor + """Unnormalized cross-covariance matrix X^T X for each class Z""" + + global_n: Tensor + """Number of X samples seen so far.""" + + @classmethod + def fit(cls, x: Tensor, z: Tensor, **kwargs) -> "AlfQLeaceFitter": + """Convenience method to fit a LeaceFitter on data and return it.""" + n, d = x.shape + _, k = z.reshape(n, -1).shape + + fitter = AlfQLeaceFitter(d, k, device=x.device, dtype=x.dtype, **kwargs) + return fitter.update(x, z) + + def __init__( + self, + x_dim: int, + z_dim: int, + method: ErasureMethod = "leace", + *, + affine: bool = True, + constrain_cov_trace: bool = True, + device: str | torch.device | None = None, + dtype: torch.dtype | None = None, + shrinkage: bool = True, + svd_tol: float = 0.01, + ): + """Initialize a `LeaceFitter`. + + Args: + x_dim: Dimensionality of the representation. + z_dim: Dimensionality of the concept. + affine: Whether to use a bias term to ensure the unconditional mean of the + features remains the same after erasure. + constrain_cov_trace: Whether to constrain the trace of the covariance of X + after erasure to be no greater than before erasure. This is especially + useful when injecting the scrubbed features back into a model. Without + this constraint, the norm of the model's hidden states may diverge in + some cases. + device: Device to put the statistics on. + dtype: Data type to use for the statistics. + shrinkage: Whether to use shrinkage to estimate the covariance matrix of X. + svd_tol: Singular values under this threshold are truncated, both during + the phase where we do SVD on the cross-covariance matrix, and at the + phase where we compute the pseudoinverse of the projected covariance + matrix. Higher values are more numerically stable and result in less + damage to the representation, but may leave trace correlations intact. + """ + super().__init__() + + self.x_dim = x_dim + self.z_dim = z_dim + + self.affine = affine + self.constrain_cov_trace = constrain_cov_trace + self.method = method + self.shrinkage = shrinkage + + assert svd_tol > 0.0, "`svd_tol` must be positive for numerical stability." 
+ self.svd_tol = svd_tol + + self.global_mean_x = torch.zeros(x_dim, device=device, dtype=dtype) + self.global_mean_z = torch.zeros(z_dim, device=device, dtype=dtype) + + self.global_n = torch.tensor(0, device=device) + self.sigma_xz_ = torch.zeros(x_dim, z_dim, device=device, dtype=dtype) + + self.sigma_xx_ = torch.zeros(x_dim, x_dim, device=device, dtype=dtype) + + self.mean_x = torch.zeros(z_dim, x_dim, device=device, dtype=dtype) + self.n = torch.zeros(z_dim, device=device) + self.sigma_xx_z_ = torch.zeros(z_dim, x_dim, x_dim, device=device, dtype=dtype) + + @torch.no_grad() + @invalidates_cache("eraser") + def update(self, x: Tensor, z: Tensor) -> "AlfQLeaceFitter": + """Update the running statistics with a new batch of data.""" + + # Update the QLEACE-specific statistics + x_for_quadratic = x.flatten(0, -2).type_as(self.mean_x) + label_encoded_z = torch.argmax(z, dim=1) + for label, group in groupby(x_for_quadratic, label_encoded_z, dim=0): + self.update_single(group, label) + + # Update the LEACE statistics + d, c = self.sigma_xz_.shape + + x = x.reshape(-1, d).type_as(self.global_mean_x) + + n, d2 = x.shape + + assert d == d2, f"Unexpected number of features {d2}" + self.global_n += n + + # Welford's online algorithm + delta_x = x - self.global_mean_x + self.global_mean_x += delta_x.sum(dim=0) / self.global_n + delta_x2 = x - self.global_mean_x + + # Update the covariance matrix of X if needed (for LEACE) + if self.method == "leace": + assert self.sigma_xx_ is not None + self.sigma_xx_.addmm_(delta_x.mH, delta_x2) + + z = z.reshape(n, -1).type_as(x) + assert z.shape[-1] == c, f"Unexpected number of classes {z.shape[-1]}" + + delta_z = z - self.global_mean_z + self.global_mean_z += delta_z.sum(dim=0) / self.global_n + delta_z2 = z - self.global_mean_z + + # Update the cross-covariance matrix + self.sigma_xz_.addmm_(delta_x.mH, delta_z2) + + return self + + @torch.no_grad() + @invalidates_cache("eraser") + def update_single(self, x: Tensor, z: int) -> "AlfQLeaceFitter": + """Update the running statistics with `x`, all sampled from class `z`.""" + x = x.flatten(0, -2).type_as(self.mean_x) + + self.n[z] += len(x) + + # Welford's online algorithm + delta_x = x - self.mean_x[z] + self.mean_x[z] += delta_x.sum(dim=0) / self.n[z] + delta_x2 = x - self.mean_x[z] + + self.sigma_xx_z_[z].addmm_(delta_x.mH, delta_x2) + + return self + + @cached_property + def eraser(self) -> AlfQLeaceEraser: + """Erasure function lazily computed given the current statistics.""" + eye = torch.eye( + self.x_dim, device=self.global_mean_x.device, dtype=self.global_mean_x.dtype + ) + + # Compute QLEACE component + # Compute the (covariance - mean covariance) matrix difference for each class + self.sigma_xx_z_.shape + mean_sigma_xx_z = self.sigma_xx_z_.mean(dim=0) + sigma_xx_z_diffs = self.sigma_xx_z_ - mean_sigma_xx_z + + # Find the class that has the difference with the largest singular + # value (spectral norm) + svds: list[tuple[Tensor, Tensor, Tensor]] = [ + torch.svd_lowrank(sigma_xx_z_diffs[i], q=1) for i in range(self.z_dim) + ] + spectral_norms = torch.stack([svd[1][0] for svd in svds]) + z_idx = spectral_norms.argmax() + + # Select the principal direction associated with the singular value + U, S, Vh = svds[z_idx] + principal_direction = U[:, 0] + + # Projection collapses the principal direction + proj_qleace = eye - torch.outer(principal_direction, principal_direction) + + assert torch.isclose( + principal_direction.norm(p=2), torch.tensor(1.0), rtol=1e-5 + ) + assert torch.allclose(proj_qleace @ 
proj_qleace, proj_qleace, rtol=1e-5) + del proj_qleace + + # Compute LEACE component + # Compute the whitening and unwhitening matrices + sigma = self.sigma_xx + + # Find the transformation that minimizes + L, V = torch.linalg.eigh(sigma) + + # Threshold used by torch.linalg.pinv + mask = L > (L[-1] * sigma.shape[-1] * torch.finfo(L.dtype).eps) + + # Assuming PSD; account for numerical error + L.clamp_min_(0.0) + + W = V * torch.where(mask, L.rsqrt(), 0.0) @ V.mH + W_inv = V * torch.where(mask, L.sqrt(), 0.0) @ V.mH + + u, s, _ = torch.linalg.svd(W @ self.sigma_xz, full_matrices=False) + + # Throw away singular values that are too small + u *= s > self.svd_tol + + proj_left = W_inv @ u + proj_right = u.mH @ W + + if self.constrain_cov_trace: + P = eye - proj_left @ proj_right + + # Prevent the covariance trace from increasing + sigma = self.sigma_xx + old_trace = torch.trace(sigma) + new_trace = torch.trace(P @ sigma @ P.mH) + + # If applying the projection matrix increases the variance, this might + # cause instability, especially when erasure is applied multiple times. + # We regularize toward the orthogonal projection matrix to avoid this. + if new_trace.real > old_trace.real: + Q = eye - u @ u.mH + + # Set up the variables for the quadratic equation + x = new_trace + y = 2 * torch.trace(P @ sigma @ Q.mH) + z = torch.trace(Q @ sigma @ Q.mH) + w = old_trace + + # Solve for the mixture of P and Q that makes the trace equal to the + # trace of the original covariance matrix + discr = torch.sqrt( + 4 * w * x - 4 * w * y + 4 * w * z - 4 * x * z + y**2 + ) + alpha1 = (-y / 2 + z - discr / 2) / (x - y + z) + alpha2 = (-y / 2 + z + discr / 2) / (x - y + z) + + # Choose the positive root + alpha = torch.where(alpha1.real > 0, alpha1, alpha2).clamp(0, 1) + P = alpha * P + (1 - alpha) * Q + + # TODO: Avoid using SVD here + u, s, vh = torch.linalg.svd(eye - P) + proj_left = u * s.sqrt() + proj_right = vh * s.sqrt() + + return AlfQLeaceEraser( + proj_left, + proj_right, + bias=self.global_mean_x if self.affine else None, + alf_qleace_vec=principal_direction, + ) + + @property + def sigma_xx(self) -> Tensor: + """The covariance matrix of X.""" + assert self.global_n > 1, "Call update() before accessing sigma_xx" + assert ( + self.sigma_xx_ is not None + ), "Covariance statistics are not being tracked for X" + + # Accumulated numerical error may cause this to be slightly non-symmetric + S_hat = (self.sigma_xx_ + self.sigma_xx_.mH) / 2 + + # Apply Random Matrix Theory-based shrinkage + if self.shrinkage: + return optimal_linear_shrinkage( + S_hat / self.global_n, self.global_n, inplace=True + ) + + # Just apply Bessel's correction + else: + return S_hat / (self.global_n - 1) + + @property + def sigma_xz(self) -> Tensor: + """The cross-covariance matrix.""" + assert self.global_n > 1, "Call update() with labels before accessing sigma_xz" + return self.sigma_xz_ / (self.global_n - 1) From d4a93d8d3b6976cd2aefa68859adcda641658f70 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 12 Nov 2024 11:18:28 +0000 Subject: [PATCH 2/8] Compute ALF-Q after LEACE --- concept_erasure/alf_qleace.py | 63 ++++++++++++++++++++--------------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/concept_erasure/alf_qleace.py b/concept_erasure/alf_qleace.py index eb61cc7..f20cc2d 100644 --- a/concept_erasure/alf_qleace.py +++ b/concept_erasure/alf_qleace.py @@ -243,33 +243,6 @@ def eraser(self) -> AlfQLeaceEraser: self.x_dim, device=self.global_mean_x.device, dtype=self.global_mean_x.dtype ) - # Compute QLEACE 
component - # Compute the (covariance - mean covariance) matrix difference for each class - self.sigma_xx_z_.shape - mean_sigma_xx_z = self.sigma_xx_z_.mean(dim=0) - sigma_xx_z_diffs = self.sigma_xx_z_ - mean_sigma_xx_z - - # Find the class that has the difference with the largest singular - # value (spectral norm) - svds: list[tuple[Tensor, Tensor, Tensor]] = [ - torch.svd_lowrank(sigma_xx_z_diffs[i], q=1) for i in range(self.z_dim) - ] - spectral_norms = torch.stack([svd[1][0] for svd in svds]) - z_idx = spectral_norms.argmax() - - # Select the principal direction associated with the singular value - U, S, Vh = svds[z_idx] - principal_direction = U[:, 0] - - # Projection collapses the principal direction - proj_qleace = eye - torch.outer(principal_direction, principal_direction) - - assert torch.isclose( - principal_direction.norm(p=2), torch.tensor(1.0), rtol=1e-5 - ) - assert torch.allclose(proj_qleace @ proj_qleace, proj_qleace, rtol=1e-5) - del proj_qleace - # Compute LEACE component # Compute the whitening and unwhitening matrices sigma = self.sigma_xx @@ -331,6 +304,42 @@ def eraser(self) -> AlfQLeaceEraser: proj_left = u * s.sqrt() proj_right = vh * s.sqrt() + # Compute ALF-Q component + + # Apply LEACE to the class-conditional covariance matrices + eye = torch.eye( + proj_left.shape[0], + device=proj_left.device, + dtype=proj_left.dtype, + ) + P = eye - proj_left @ proj_right + + leaced_sigma_xx_z_ = torch.stack( + [P @ self.sigma_xx_z_[i] @ P for i in range(self.z_dim)] + ) + + # Compute the (covariance - mean covariance) matrix difference for each class + mean_sigma_xx_z = leaced_sigma_xx_z_.mean(dim=0) + sigma_xx_z_diffs = leaced_sigma_xx_z_ - mean_sigma_xx_z + + # Find the class that has the difference with the largest singular value + batch_svd = torch.vmap( + lambda x: torch.svd_lowrank(x, q=1, niter=10), randomness="different" + ) + U, S, Vh = batch_svd(sigma_xx_z_diffs) + max_idx = torch.argmax(S.squeeze()) + + # Save the first principal direction of the largest covariance difference + principal_direction = U.squeeze()[max_idx] + assert torch.isclose( + principal_direction.norm(p=2), torch.tensor(1.0), rtol=1e-5 + ) + + # This projection collapses the principal direction + proj_qleace = eye - torch.outer(principal_direction, principal_direction) + assert torch.allclose(proj_qleace @ proj_qleace, proj_qleace, rtol=1e-5) + del proj_qleace + return AlfQLeaceEraser( proj_left, proj_right, From 64493869d0a4eb386c4b50300dff6de135910bee Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Wed, 13 Nov 2024 01:30:45 +0000 Subject: [PATCH 3/8] Try erasing multiple principal directions --- concept_erasure/alf_qleace.py | 57 ++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/concept_erasure/alf_qleace.py b/concept_erasure/alf_qleace.py index f20cc2d..6373a9c 100644 --- a/concept_erasure/alf_qleace.py +++ b/concept_erasure/alf_qleace.py @@ -28,7 +28,7 @@ class AlfQLeaceEraser: proj_left: Tensor proj_right: Tensor bias: Tensor | None - alf_qleace_vec: Tensor + alf_qleace_vecs: Tensor @classmethod def fit(cls, x: Tensor, z: Tensor, **kwargs) -> "AlfQLeaceEraser": @@ -49,11 +49,11 @@ def P(self) -> Tensor: def Q(self) -> Tensor: """The ALF-QLEACE projection matrix.""" eye = torch.eye( - self.alf_qleace_vec.shape[0], - device=self.alf_qleace_vec.device, - dtype=self.alf_qleace_vec.dtype, + self.alf_qleace_vecs.shape[1], + device=self.alf_qleace_vecs.device, + dtype=self.alf_qleace_vecs.dtype, ) - return eye - torch.outer(self.alf_qleace_vec, 
self.alf_qleace_vec) + return eye - (self.alf_qleace_vecs.mH @ self.alf_qleace_vecs) def __call__(self, x: Tensor) -> Tensor: """Apply the projection to the input tensor.""" @@ -63,8 +63,8 @@ def __call__(self, x: Tensor) -> Tensor: x_ = x - (delta @ self.proj_right.mH) @ self.proj_left.mH # Apply the ALF-QLEACE projection - v = self.alf_qleace_vec - x_ = x_ - torch.einsum("i,bi->bi", v, (v @ x_.mH).unsqueeze(1)) + v = self.alf_qleace_vecs + x_ = x_ - (v @ x.mH).mH @ v return x_.type_as(x) @@ -74,7 +74,7 @@ def to(self, device: torch.device | str) -> "AlfQLeaceEraser": self.proj_left.to(device), self.proj_right.to(device), self.bias.to(device) if self.bias is not None else None, - self.alf_qleace_vec.to(device), + self.alf_qleace_vecs.to(device), ) @@ -239,6 +239,7 @@ def update_single(self, x: Tensor, z: int) -> "AlfQLeaceFitter": @cached_property def eraser(self) -> AlfQLeaceEraser: """Erasure function lazily computed given the current statistics.""" + n_dims = 10 eye = torch.eye( self.x_dim, device=self.global_mean_x.device, dtype=self.global_mean_x.dtype ) @@ -314,37 +315,37 @@ def eraser(self) -> AlfQLeaceEraser: ) P = eye - proj_left @ proj_right - leaced_sigma_xx_z_ = torch.stack( + transformed_sigma_xx_z_ = torch.stack( [P @ self.sigma_xx_z_[i] @ P for i in range(self.z_dim)] ) - # Compute the (covariance - mean covariance) matrix difference for each class - mean_sigma_xx_z = leaced_sigma_xx_z_.mean(dim=0) - sigma_xx_z_diffs = leaced_sigma_xx_z_ - mean_sigma_xx_z + principal_directions = [] + for _ in range(n_dims): + # Compute the class conditional covariance differences from the mean + mean_sigma_xx_z = transformed_sigma_xx_z_.mean(dim=0) + sigma_xx_z_diffs = transformed_sigma_xx_z_ - mean_sigma_xx_z - # Find the class that has the difference with the largest singular value - batch_svd = torch.vmap( - lambda x: torch.svd_lowrank(x, q=1, niter=10), randomness="different" - ) - U, S, Vh = batch_svd(sigma_xx_z_diffs) - max_idx = torch.argmax(S.squeeze()) + batch_svd = torch.vmap( + lambda x: torch.svd_lowrank(x, q=1, niter=10), randomness="different" + ) + U, S, Vh = batch_svd(sigma_xx_z_diffs) - # Save the first principal direction of the largest covariance difference - principal_direction = U.squeeze()[max_idx] - assert torch.isclose( - principal_direction.norm(p=2), torch.tensor(1.0), rtol=1e-5 - ) + max_idx = torch.argmax(S.squeeze()) + principal_directions.append(U.squeeze()[max_idx]) - # This projection collapses the principal direction - proj_qleace = eye - torch.outer(principal_direction, principal_direction) - assert torch.allclose(proj_qleace @ proj_qleace, proj_qleace, rtol=1e-5) - del proj_qleace + # Transform the class-conditional covariance matrices for the next iteration + proj_qleace = eye - torch.outer( + principal_directions[-1], principal_directions[-1] + ) + transformed_sigma_xx_z_ = torch.stack( + [proj_qleace @ sigma @ proj_qleace for sigma in transformed_sigma_xx_z_] + ) return AlfQLeaceEraser( proj_left, proj_right, bias=self.global_mean_x if self.affine else None, - alf_qleace_vec=principal_direction, + alf_qleace_vecs=torch.stack(principal_directions), ) @property From 8f7f6608291db90b4fb8eb2d7c87d27523ebf0c0 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Mon, 16 Dec 2024 05:20:36 +0000 Subject: [PATCH 4/8] Support erasing up to a specified worst-case covariance difference or intervention rank --- concept_erasure/alf_qleace.py | 79 ++++++++++++++++++++++++----------- concept_erasure/quadratic.py | 8 ++++ 2 files changed, 62 insertions(+), 25 deletions(-) 
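Notes: the target_erasure / max_rank logic added below greedily removes one
principal direction per iteration until the worst-case deviation of the
class-conditional covariances from their mean falls below
(1 - target_erasure) of its initial Frobenius norm. A minimal standalone
sketch of that greedy deflation (hypothetical helper name `greedy_deflate`;
the patch itself folds this loop into `AlfQLeaceFitter.eraser`):

import torch

def greedy_deflate(
    sigmas: torch.Tensor,  # (k, d, d) class-conditional covariances
    target_erasure: float = 0.9,
    max_rank: int | None = None,
) -> torch.Tensor:
    k, d, _ = sigmas.shape
    base = (sigmas - sigmas.mean(0)).norm(dim=(1, 2)).max()
    dirs = []
    for _ in range(max_rank or d):
        diffs = sigmas - sigmas.mean(0)
        # Leading singular vector of each class's deviation matrix
        U, S, _ = torch.svd_lowrank(diffs, q=1, niter=10)
        v = U[S.squeeze(-1).argmax(), :, 0]  # worst class's top direction
        dirs.append(v)
        P = torch.eye(d) - torch.outer(v, v)  # collapses direction v
        sigmas = P @ sigmas @ P  # deflate every class covariance
        cur = (sigmas - sigmas.mean(0)).norm(dim=(1, 2)).max()
        if cur < (1 - target_erasure) * base:
            break
    return torch.stack(dirs)  # (r, d), with r <= max_rank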
diff --git a/concept_erasure/alf_qleace.py b/concept_erasure/alf_qleace.py index 6373a9c..36cee99 100644 --- a/concept_erasure/alf_qleace.py +++ b/concept_erasure/alf_qleace.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Literal import torch from torch import Tensor @@ -8,8 +7,6 @@ from .groupby import groupby from .shrinkage import optimal_linear_shrinkage -ErasureMethod = Literal["leace", "orth"] - @dataclass(frozen=True) class AlfQLeaceEraser: @@ -70,12 +67,12 @@ def __call__(self, x: Tensor) -> Tensor: def to(self, device: torch.device | str) -> "AlfQLeaceEraser": """Move eraser to a new device.""" - return AlfQLeaceEraser( - self.proj_left.to(device), - self.proj_right.to(device), - self.bias.to(device) if self.bias is not None else None, - self.alf_qleace_vecs.to(device), - ) + self.proj_left = self.proj_left.to(device) + self.proj_right = self.proj_right.to(device) + self.bias = self.bias.to(device) if self.bias is not None else None + self.alf_qleace_vecs = self.alf_qleace_vecs.to(device) + + return self class AlfQLeaceFitter: @@ -83,8 +80,7 @@ class AlfQLeaceFitter: maximum covariance from a representation. This class implements Least-squares Concept Erasure (LEACE) from - https://arxiv.org/abs/2306.03819. You can also use a slightly simpler orthogonal - projection-based method by setting `method="orth"`. + https://arxiv.org/abs/2306.03819. This class stores all the covariance statistics needed to compute the QLEACE eraser. This allows the statistics to be updated incrementally with `update()`. @@ -99,7 +95,7 @@ class AlfQLeaceFitter: sigma_xz_: Tensor """Unnormalized cross-covariance matrix X^T Z.""" - sigma_xx_: Tensor | None + sigma_xx_: Tensor """Unnormalized covariance matrix X^T X.""" sigma_xx_z_: Tensor @@ -121,7 +117,6 @@ def __init__( self, x_dim: int, z_dim: int, - method: ErasureMethod = "leace", *, affine: bool = True, constrain_cov_trace: bool = True, @@ -129,6 +124,8 @@ def __init__( dtype: torch.dtype | None = None, shrinkage: bool = True, svd_tol: float = 0.01, + max_rank: int | None = None, + target_erasure: float = 0.9, ): """Initialize a `LeaceFitter`. @@ -150,6 +147,9 @@ def __init__( phase where we compute the pseudoinverse of the projected covariance matrix. Higher values are more numerically stable and result in less damage to the representation, but may leave trace correlations intact. + target_erasure: Fraction of the worst-case covariance difference between + classes to erase. Higher values result in higher rank interventions. + max_rank: Maximum rank of the intervention. """ super().__init__() @@ -158,9 +158,9 @@ def __init__( self.affine = affine self.constrain_cov_trace = constrain_cov_trace - self.method = method self.shrinkage = shrinkage - + self.target_erasure = target_erasure + self.max_rank = max_rank assert svd_tol > 0.0, "`svd_tol` must be positive for numerical stability." 
self.svd_tol = svd_tol @@ -202,10 +202,8 @@ def update(self, x: Tensor, z: Tensor) -> "AlfQLeaceFitter": self.global_mean_x += delta_x.sum(dim=0) / self.global_n delta_x2 = x - self.global_mean_x - # Update the covariance matrix of X if needed (for LEACE) - if self.method == "leace": - assert self.sigma_xx_ is not None - self.sigma_xx_.addmm_(delta_x.mH, delta_x2) + # Update the covariance matrix of X for LEACE + self.sigma_xx_.addmm_(delta_x.mH, delta_x2) z = z.reshape(n, -1).type_as(x) assert z.shape[-1] == c, f"Unexpected number of classes {z.shape[-1]}" @@ -239,7 +237,6 @@ def update_single(self, x: Tensor, z: int) -> "AlfQLeaceFitter": @cached_property def eraser(self) -> AlfQLeaceEraser: """Erasure function lazily computed given the current statistics.""" - n_dims = 10 eye = torch.eye( self.x_dim, device=self.global_mean_x.device, dtype=self.global_mean_x.dtype ) @@ -318,17 +315,22 @@ def eraser(self) -> AlfQLeaceEraser: transformed_sigma_xx_z_ = torch.stack( [P @ self.sigma_xx_z_[i] @ P for i in range(self.z_dim)] ) + base_cov_diff_norm = ( + (transformed_sigma_xx_z_ - transformed_sigma_xx_z_.mean(dim=0)) + .norm(dim=(1, 2)) + .max() + ) + # Erase to max_rank principal directions to minimize the worst-case + # covariance difference between classes principal_directions = [] - for _ in range(n_dims): + max_rank = self.max_rank or transformed_sigma_xx_z_.flatten(1).shape[1] + for i in range(max_rank): # Compute the class conditional covariance differences from the mean mean_sigma_xx_z = transformed_sigma_xx_z_.mean(dim=0) sigma_xx_z_diffs = transformed_sigma_xx_z_ - mean_sigma_xx_z - batch_svd = torch.vmap( - lambda x: torch.svd_lowrank(x, q=1, niter=10), randomness="different" - ) - U, S, Vh = batch_svd(sigma_xx_z_diffs) + U, S, Vh = torch.svd_lowrank(sigma_xx_z_diffs, q=1, niter=10) max_idx = torch.argmax(S.squeeze()) principal_directions.append(U.squeeze()[max_idx]) @@ -341,6 +343,19 @@ def eraser(self) -> AlfQLeaceEraser: [proj_qleace @ sigma @ proj_qleace for sigma in transformed_sigma_xx_z_] ) + current_cov_diff_norm = ( + (transformed_sigma_xx_z_ - transformed_sigma_xx_z_.mean(dim=0)) + .norm(dim=(1, 2)) + .max() + ) + + if current_cov_diff_norm < (1 - self.target_erasure) * base_cov_diff_norm: + print( + f"Found rank {i + 1} intervention to reduce worst-case covariance\ + difference norm by {self.target_erasure:.0%}" + ) + break + return AlfQLeaceEraser( proj_left, proj_right, @@ -374,3 +389,17 @@ def sigma_xz(self) -> Tensor: """The cross-covariance matrix.""" assert self.global_n > 1, "Call update() with labels before accessing sigma_xz" return self.sigma_xz_ / (self.global_n - 1) + + def to(self, device: torch.device | str) -> "AlfQLeaceFitter": + """Move fitter to a new device.""" + self.global_mean_x = self.global_mean_x.to(device) + self.global_mean_z = self.global_mean_z.to(device) + self.global_n = self.global_n.to(device) + + self.sigma_xz_ = self.sigma_xz_.to(device) + self.sigma_xx_ = self.sigma_xx_.to(device) + self.mean_x = self.mean_x.to(device) + self.n = self.n.to(device) + self.sigma_xx_z_ = self.sigma_xx_z_.to(device) + + return self diff --git a/concept_erasure/quadratic.py b/concept_erasure/quadratic.py index 480049f..ed93f12 100644 --- a/concept_erasure/quadratic.py +++ b/concept_erasure/quadratic.py @@ -208,3 +208,11 @@ def sigma_xx(self) -> Tensor: # Just apply Bessel's correction else: return S_hat / (n - 1) + + def to(self, device: torch.device | str) -> "QuadraticFitter": + """Move fitter to a new device.""" + self.mean_x = self.mean_x.to(device) + 
self.n = self.n.to(device) + self.sigma_xx_ = self.sigma_xx_.to(device) + + return self From df412845662840a6dbf3c45eaba8fcc7da1c1398 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 17 Dec 2024 01:54:31 +0000 Subject: [PATCH 5/8] Fix bug --- concept_erasure/alf_qleace.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/concept_erasure/alf_qleace.py b/concept_erasure/alf_qleace.py index 36cee99..cb45ff9 100644 --- a/concept_erasure/alf_qleace.py +++ b/concept_erasure/alf_qleace.py @@ -54,6 +54,7 @@ def Q(self) -> Tensor: def __call__(self, x: Tensor) -> Tensor: """Apply the projection to the input tensor.""" + # Delta comes from the mean centered input distribution delta = x - self.bias if self.bias is not None else x # Ensure we do the matmul in the most efficient order. @@ -61,18 +62,18 @@ def __call__(self, x: Tensor) -> Tensor: # Apply the ALF-QLEACE projection v = self.alf_qleace_vecs - x_ = x_ - (v @ x.mH).mH @ v + x_ = x_ - (v @ x_.mH).mH @ v return x_.type_as(x) def to(self, device: torch.device | str) -> "AlfQLeaceEraser": """Move eraser to a new device.""" - self.proj_left = self.proj_left.to(device) - self.proj_right = self.proj_right.to(device) - self.bias = self.bias.to(device) if self.bias is not None else None - self.alf_qleace_vecs = self.alf_qleace_vecs.to(device) - - return self + return AlfQLeaceEraser( + self.proj_left.to(device), + self.proj_right.to(device), + self.bias.to(device) if self.bias is not None else None, + self.alf_qleace_vecs.to(device), + ) class AlfQLeaceFitter: From 605e9430403d4499d8f4a873a44ef6e8e4104cc6 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Wed, 18 Dec 2024 04:40:49 +0000 Subject: [PATCH 6/8] fixup --- concept_erasure/alf_qleace.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/concept_erasure/alf_qleace.py b/concept_erasure/alf_qleace.py index cb45ff9..ed7e789 100644 --- a/concept_erasure/alf_qleace.py +++ b/concept_erasure/alf_qleace.py @@ -59,10 +59,7 @@ def __call__(self, x: Tensor) -> Tensor: # Ensure we do the matmul in the most efficient order. x_ = x - (delta @ self.proj_right.mH) @ self.proj_left.mH - - # Apply the ALF-QLEACE projection - v = self.alf_qleace_vecs - x_ = x_ - (v @ x_.mH).mH @ v + x_ = x_ - (x_ @ self.alf_qleace_vecs.mH) @ self.alf_qleace_vecs return x_.type_as(x) From 18afdbef12326d07efe89e2925f85bd2625754c2 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Wed, 22 Jan 2025 23:22:17 +0000 Subject: [PATCH 7/8] Update eigh --- concept_erasure/alf_qleace.py | 230 ++++++++++++++++----------- concept_erasure/leace.py | 16 +- concept_erasure/optimal_transport.py | 6 +- 3 files changed, 158 insertions(+), 94 deletions(-) diff --git a/concept_erasure/alf_qleace.py b/concept_erasure/alf_qleace.py index ed7e789..8203315 100644 --- a/concept_erasure/alf_qleace.py +++ b/concept_erasure/alf_qleace.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Literal import torch from torch import Tensor @@ -7,6 +8,8 @@ from .groupby import groupby from .shrinkage import optimal_linear_shrinkage +LinearErasureMethod = Literal["leace", "orth", "none"] + @dataclass(frozen=True) class AlfQLeaceEraser: @@ -22,8 +25,8 @@ class AlfQLeaceEraser: identity, given by torch.eye(d) - alf_qleace_vec @ alf_qleace_vec. 
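     With several directions stacked in `alf_qleace_vecs` (shape (r, d)), this
     generalizes to the rank-r perturbation
     torch.eye(d) - alf_qleace_vecs.mH @ alf_qleace_vecs.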
""" - proj_left: Tensor - proj_right: Tensor + proj_left: Tensor | None + proj_right: Tensor | None bias: Tensor | None alf_qleace_vecs: Tensor @@ -33,8 +36,11 @@ def fit(cls, x: Tensor, z: Tensor, **kwargs) -> "AlfQLeaceEraser": return AlfQLeaceFitter.fit(x, z, **kwargs).eraser @property - def P(self) -> Tensor: + def P(self) -> Tensor | None: """The LEACE projection matrix.""" + if self.proj_left is None: + return None + eye = torch.eye( self.proj_left.shape[0], device=self.proj_left.device, @@ -55,22 +61,34 @@ def Q(self) -> Tensor: def __call__(self, x: Tensor) -> Tensor: """Apply the projection to the input tensor.""" # Delta comes from the mean centered input distribution - delta = x - self.bias if self.bias is not None else x - # Ensure we do the matmul in the most efficient order. - x_ = x - (delta @ self.proj_right.mH) @ self.proj_left.mH + if self.proj_left is not None: + delta = x - self.bias if self.bias is not None else x + # Ensure we do the matmul in the most efficient order. + x_ = x - (delta @ self.proj_right.mH) @ self.proj_left.mH + else: + x_ = x + x_ = x_ - (x_ @ self.alf_qleace_vecs.mH) @ self.alf_qleace_vecs return x_.type_as(x) def to(self, device: torch.device | str) -> "AlfQLeaceEraser": """Move eraser to a new device.""" - return AlfQLeaceEraser( - self.proj_left.to(device), - self.proj_right.to(device), - self.bias.to(device) if self.bias is not None else None, - self.alf_qleace_vecs.to(device), - ) + if self.proj_left is not None: + return AlfQLeaceEraser( + self.proj_left.to(device), + self.proj_right.to(device), + self.bias.to(device) if self.bias is not None else None, + self.alf_qleace_vecs.to(device), + ) + else: + return AlfQLeaceEraser( + None, + None, + None, + self.alf_qleace_vecs.to(device), + ) class AlfQLeaceFitter: @@ -93,7 +111,7 @@ class AlfQLeaceFitter: sigma_xz_: Tensor """Unnormalized cross-covariance matrix X^T Z.""" - sigma_xx_: Tensor + sigma_xx_: Tensor | None """Unnormalized covariance matrix X^T X.""" sigma_xx_z_: Tensor @@ -104,7 +122,7 @@ class AlfQLeaceFitter: @classmethod def fit(cls, x: Tensor, z: Tensor, **kwargs) -> "AlfQLeaceFitter": - """Convenience method to fit a LeaceFitter on data and return it.""" + """Convenience method to fit a AlfQLeaceFitter on data and return it.""" n, d = x.shape _, k = z.reshape(n, -1).shape @@ -115,6 +133,7 @@ def __init__( self, x_dim: int, z_dim: int, + method: LinearErasureMethod = "leace", *, affine: bool = True, constrain_cov_trace: bool = True, @@ -125,11 +144,12 @@ def __init__( max_rank: int | None = None, target_erasure: float = 0.9, ): - """Initialize a `LeaceFitter`. + """Initialize a `AlfQLeaceFitter`. Args: x_dim: Dimensionality of the representation. z_dim: Dimensionality of the concept. + method: Type of projection matrix to use. affine: Whether to use a bias term to ensure the unconditional mean of the features remains the same after erasure. 
constrain_cov_trace: Whether to constrain the trace of the covariance of X @@ -153,7 +173,7 @@ def __init__( self.x_dim = x_dim self.z_dim = z_dim - + self.method = method self.affine = affine self.constrain_cov_trace = constrain_cov_trace self.shrinkage = shrinkage @@ -168,7 +188,12 @@ def __init__( self.global_n = torch.tensor(0, device=device) self.sigma_xz_ = torch.zeros(x_dim, z_dim, device=device, dtype=dtype) - self.sigma_xx_ = torch.zeros(x_dim, x_dim, device=device, dtype=dtype) + if self.method == "leace": + self.sigma_xx_ = torch.zeros(x_dim, x_dim, device=device, dtype=dtype) + elif self.method == "orth" or self.method == "none": + self.sigma_xx_ = None + else: + raise ValueError(f"Unknown projection type {self.method}") self.mean_x = torch.zeros(z_dim, x_dim, device=device, dtype=dtype) self.n = torch.zeros(z_dim, device=device) @@ -201,7 +226,9 @@ def update(self, x: Tensor, z: Tensor) -> "AlfQLeaceFitter": delta_x2 = x - self.global_mean_x # Update the covariance matrix of X for LEACE - self.sigma_xx_.addmm_(delta_x.mH, delta_x2) + if self.method == "leace": + assert self.sigma_xx_ is not None + self.sigma_xx_.addmm_(delta_x.mH, delta_x2) z = z.reshape(n, -1).type_as(x) assert z.shape[-1] == c, f"Unexpected number of classes {z.shape[-1]}" @@ -239,81 +266,95 @@ def eraser(self) -> AlfQLeaceEraser: self.x_dim, device=self.global_mean_x.device, dtype=self.global_mean_x.dtype ) - # Compute LEACE component - # Compute the whitening and unwhitening matrices - sigma = self.sigma_xx - - # Find the transformation that minimizes - L, V = torch.linalg.eigh(sigma) - - # Threshold used by torch.linalg.pinv - mask = L > (L[-1] * sigma.shape[-1] * torch.finfo(L.dtype).eps) - - # Assuming PSD; account for numerical error - L.clamp_min_(0.0) - - W = V * torch.where(mask, L.rsqrt(), 0.0) @ V.mH - W_inv = V * torch.where(mask, L.sqrt(), 0.0) @ V.mH - - u, s, _ = torch.linalg.svd(W @ self.sigma_xz, full_matrices=False) - - # Throw away singular values that are too small - u *= s > self.svd_tol - - proj_left = W_inv @ u - proj_right = u.mH @ W - - if self.constrain_cov_trace: + proj_left = None + proj_right = None + if self.method != "none": + # Compute LEACE component + # Compute the whitening and unwhitening matrices + if self.method == "leace": + sigma = self.sigma_xx + L, V = torch.linalg.eigh(sigma.double()) + L, V = L.to(sigma.dtype), V.to(sigma.dtype) + torch.allclose(sigma, sigma.T, rtol=1e-05, atol=1e-08) + + # Threshold used by torch.linalg.pinv + mask = L > (L[-1] * sigma.shape[-1] * torch.finfo(L.dtype).eps) + + # Assuming PSD; account for numerical error + L.clamp_min_(0.0) + + W = V * torch.where(mask, L.rsqrt(), 0.0) @ V.mH + W_inv = V * torch.where(mask, L.sqrt(), 0.0) @ V.mH + else: + W, W_inv = eye, eye + + u, s, _ = torch.linalg.svd(W @ self.sigma_xz, full_matrices=False) + + # Throw away singular values that are too small + u *= s > self.svd_tol + + proj_left = W_inv @ u + proj_right = u.mH @ W + + if self.constrain_cov_trace and self.method == "leace": + P = eye - proj_left @ proj_right + + # Prevent the covariance trace from increasing + sigma = self.sigma_xx + old_trace = torch.trace(sigma) + new_trace = torch.trace(P @ sigma @ P.mH) + + # If applying the projection matrix increases the variance, this might + # cause instability, especially when erasure is applied multiple times. + # We regularize toward the orthogonal projection matrix to avoid this. 
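+                # With M(a) = a * P + (1 - a) * Q, the erased covariance trace
+                # tr(M(a) @ sigma @ M(a).mH) equals the quadratic
+                # a^2 * x + a * (1 - a) * y + (1 - a)^2 * z in the variables set
+                # below; we pick the root in [0, 1] where this equals w.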
+ if new_trace.real > old_trace.real: + Q = eye - u @ u.mH + + # Set up the variables for the quadratic equation + x = new_trace + y = 2 * torch.trace(P @ sigma @ Q.mH) + z = torch.trace(Q @ sigma @ Q.mH) + w = old_trace + + # Solve for the mixture of P and Q that makes the trace equal to the + # trace of the original covariance matrix + discr = torch.sqrt( + 4 * w * x - 4 * w * y + 4 * w * z - 4 * x * z + y**2 + ) + alpha1 = (-y / 2 + z - discr / 2) / (x - y + z) + alpha2 = (-y / 2 + z + discr / 2) / (x - y + z) + + # Choose the positive root + alpha = torch.where(alpha1.real > 0, alpha1, alpha2).clamp(0, 1) + P = alpha * P + (1 - alpha) * Q + + # TODO: Avoid using SVD here + u, s, vh = torch.linalg.svd(eye - P) + proj_left = u * s.sqrt() + proj_right = vh * s.sqrt() + + # Apply LEACE to the class-conditional covariance matrices + eye = torch.eye( + proj_left.shape[0], + device=proj_left.device, + dtype=proj_left.dtype, + ) P = eye - proj_left @ proj_right + else: + P = eye - # Prevent the covariance trace from increasing - sigma = self.sigma_xx - old_trace = torch.trace(sigma) - new_trace = torch.trace(P @ sigma @ P.mH) - - # If applying the projection matrix increases the variance, this might - # cause instability, especially when erasure is applied multiple times. - # We regularize toward the orthogonal projection matrix to avoid this. - if new_trace.real > old_trace.real: - Q = eye - u @ u.mH - - # Set up the variables for the quadratic equation - x = new_trace - y = 2 * torch.trace(P @ sigma @ Q.mH) - z = torch.trace(Q @ sigma @ Q.mH) - w = old_trace - - # Solve for the mixture of P and Q that makes the trace equal to the - # trace of the original covariance matrix - discr = torch.sqrt( - 4 * w * x - 4 * w * y + 4 * w * z - 4 * x * z + y**2 - ) - alpha1 = (-y / 2 + z - discr / 2) / (x - y + z) - alpha2 = (-y / 2 + z + discr / 2) / (x - y + z) - - # Choose the positive root - alpha = torch.where(alpha1.real > 0, alpha1, alpha2).clamp(0, 1) - P = alpha * P + (1 - alpha) * Q + # """Apply the projection to the input tensor.""" + # delta = x - self.bias if self.bias is not None else x - # TODO: Avoid using SVD here - u, s, vh = torch.linalg.svd(eye - P) - proj_left = u * s.sqrt() - proj_right = vh * s.sqrt() + # # Ensure we do the matmul in the most efficient order. 
+ # x_ = x - (x @ self.proj_right.mH) @ self.proj_left.mH + # return x_.type_as(x) # Compute ALF-Q component - - # Apply LEACE to the class-conditional covariance matrices - eye = torch.eye( - proj_left.shape[0], - device=proj_left.device, - dtype=proj_left.dtype, - ) - P = eye - proj_left @ proj_right - transformed_sigma_xx_z_ = torch.stack( [P @ self.sigma_xx_z_[i] @ P for i in range(self.z_dim)] ) - base_cov_diff_norm = ( + base_max_cov_diff_norm = ( (transformed_sigma_xx_z_ - transformed_sigma_xx_z_.mean(dim=0)) .norm(dim=(1, 2)) .max() @@ -322,8 +363,10 @@ def eraser(self) -> AlfQLeaceEraser: # Erase to max_rank principal directions to minimize the worst-case # covariance difference between classes principal_directions = [] - max_rank = self.max_rank or transformed_sigma_xx_z_.flatten(1).shape[1] - for i in range(max_rank): + max_rank = self.max_rank or transformed_sigma_xx_z_.shape[-1] + from tqdm import tqdm + + for i in tqdm(range(max_rank)): # Compute the class conditional covariance differences from the mean mean_sigma_xx_z = transformed_sigma_xx_z_.mean(dim=0) sigma_xx_z_diffs = transformed_sigma_xx_z_ - mean_sigma_xx_z @@ -341,13 +384,16 @@ def eraser(self) -> AlfQLeaceEraser: [proj_qleace @ sigma @ proj_qleace for sigma in transformed_sigma_xx_z_] ) - current_cov_diff_norm = ( + current_max_cov_diff_norm = ( (transformed_sigma_xx_z_ - transformed_sigma_xx_z_.mean(dim=0)) .norm(dim=(1, 2)) .max() ) - if current_cov_diff_norm < (1 - self.target_erasure) * base_cov_diff_norm: + if ( + current_max_cov_diff_norm + < (1 - self.target_erasure) * base_max_cov_diff_norm + ): print( f"Found rank {i + 1} intervention to reduce worst-case covariance\ difference norm by {self.target_erasure:.0%}" @@ -395,8 +441,10 @@ def to(self, device: torch.device | str) -> "AlfQLeaceFitter": self.global_n = self.global_n.to(device) self.sigma_xz_ = self.sigma_xz_.to(device) - self.sigma_xx_ = self.sigma_xx_.to(device) - self.mean_x = self.mean_x.to(device) + if self.sigma_xx_ is not None: + self.sigma_xx_ = self.sigma_xx_.to(device) + if self.mean_x is not None: + self.mean_x = self.mean_x.to(device) self.n = self.n.to(device) self.sigma_xx_z_ = self.sigma_xx_z_.to(device) diff --git a/concept_erasure/leace.py b/concept_erasure/leace.py index ce58f36..ee8e119 100644 --- a/concept_erasure/leace.py +++ b/concept_erasure/leace.py @@ -192,7 +192,9 @@ def eraser(self) -> LeaceEraser: # Compute the whitening and unwhitening matrices if self.method == "leace": sigma = self.sigma_xx - L, V = torch.linalg.eigh(sigma) + L, V = torch.linalg.eigh(sigma.double()) + L, V = L.to(sigma.dtype), V.to(sigma.dtype) + torch.allclose(sigma, sigma.T, rtol=1e-05, atol=1e-08) # Threshold used by torch.linalg.pinv mask = L > (L[-1] * sigma.shape[-1] * torch.finfo(L.dtype).eps) @@ -278,3 +280,15 @@ def sigma_xz(self) -> Tensor: """The cross-covariance matrix.""" assert self.n > 1, "Call update() with labels before accessing sigma_xz" return self.sigma_xz_ / (self.n - 1) + + def to(self, device: str | torch.device | None = None) -> "LeaceFitter": + """Move the fitter to a new device.""" + self.mean_x = self.mean_x.to(device) + self.mean_z = self.mean_z.to(device) + self.sigma_xz_ = self.sigma_xz_.to(device) + self.n = self.n.to(device) + + if self.n > 1 and self.sigma_xx_ is not None: + self.sigma_xx_ = self.sigma_xx_.to(device) + + return self diff --git a/concept_erasure/optimal_transport.py b/concept_erasure/optimal_transport.py index 1f5acf1..2432192 100644 --- a/concept_erasure/optimal_transport.py +++ 
b/concept_erasure/optimal_transport.py @@ -12,14 +12,16 @@ def is_positive_definite(A: Tensor) -> Tensor: @torch.jit.script def psd_sqrt(A: Tensor) -> Tensor: """Compute the unique p.s.d. square root of a positive semidefinite matrix.""" - L, U = torch.linalg.eigh(A) + L, U = torch.linalg.eigh(A.double()) + L, U = L.to(A.dtype), U.to(A.dtype) L = L[..., None, :].clamp_min(0.0) return U * L.sqrt() @ U.mH def psd_sqrt_rsqrt(A: Tensor) -> tuple[Tensor, Tensor]: """Efficiently compute both the p.s.d. sqrt & pinv sqrt of p.s.d. matrix `A`.""" - L, U = torch.linalg.eigh(A) + L, U = torch.linalg.eigh(A.double()) + L, U = L.to(A.dtype), U.to(A.dtype) L = L[..., None, :].clamp_min(0.0) # Square root is easy From 67052e729ebe8314a2b9fc6b85db4fd3b73f9eda Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Mon, 27 Jan 2025 01:06:35 +0000 Subject: [PATCH 8/8] Add random eraser --- concept_erasure/re.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 concept_erasure/re.py diff --git a/concept_erasure/re.py b/concept_erasure/re.py new file mode 100644 index 0000000..1e02320 --- /dev/null +++ b/concept_erasure/re.py @@ -0,0 +1,37 @@ +import torch +from torch import Tensor + + +class RandomEraser: + """Random eraser that projects a random subspace to the nullspace.""" + + basis: Tensor + + def __init__(self, ndims: int, erase_dims: int, **kwargs): + """Create a RandomEraser that projects erase_dims dimensions to zero. + + Args: + ndims: Total dimensions of the input space + erase_dims: Number of dimensions to project to zero + """ + # Create a random orthonormal basis + rand_basis = torch.randn(ndims, ndims) + Q, R = torch.linalg.qr(rand_basis) + + # Take the first erase_dims columns to get basis for subspace to nullify + Q = Q[:, :erase_dims] + + self.basis = Q + + def __call__(self, x: Tensor) -> Tensor: + """Apply the projection to the input tensor.""" + result = x - (x @ self.basis) @ self.basis.T + torch.testing.assert_close( + result, result - (result @ self.basis) @ self.basis.T + ) + return result + + def to(self, device: torch.device | str) -> "RandomEraser": + """Move eraser to a new device.""" + self.basis = self.basis.to(device) + return self
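Illustrative end-to-end usage of the erasers added in this series (a sketch
on synthetic data; default fitter settings, and tqdm being installed, are
assumed):

import torch

from concept_erasure import AlfQLeaceEraser
from concept_erasure.re import RandomEraser

n, d, k = 2048, 64, 4
x = torch.randn(n, d)
z = torch.nn.functional.one_hot(torch.randint(0, k, (n,)), k).float()

# LEACE plus rank-limited quadratic erasure
eraser = AlfQLeaceEraser.fit(x, z, max_rank=16)
x_erased = eraser(x)
assert x_erased.shape == x.shape

# Random-subspace baseline of comparable rank
baseline = RandomEraser(ndims=d, erase_dims=16)
x_base = baseline(x)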