Linear Discriminant Analysis MVP #268

Open · wants to merge 4 commits into base: main
7 changes: 5 additions & 2 deletions elk/training/common.py
@@ -18,7 +18,7 @@ class FitterConfig(Serializable, decode_into_subclasses=True):
 @dataclass
 class Reporter(PlattMixin):
     weight: Tensor
-    eraser: LeaceEraser
+    eraser: LeaceEraser | None = None

     def __post_init__(self):
         # Platt scaling parameters
@@ -27,5 +27,8 @@ def __post_init__(self):

     def __call__(self, hiddens: Tensor) -> Tensor:
         """Return the predicted log odds on input `x`."""
-        raw_scores = self.eraser(hiddens) @ self.weight.mT
+        if self.eraser is not None:
+            hiddens = self.eraser(hiddens)
+
+        raw_scores = hiddens @ self.weight.mT
         return raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
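
With the eraser now optional, a Reporter can be built from a bare weight tensor. A minimal sketch of the call path above, assuming `__post_init__` leaves the Platt parameters `scale` and `bias` at usable defaults; the weight and batch shapes here are purely illustrative:

import torch

from elk.training.common import Reporter

hiddens = torch.randn(8, 512)                    # [batch, dim]
reporter = Reporter(weight=torch.randn(1, 512))  # eraser defaults to None

# No erasure step is applied; the score is an affine function of hiddens @ weight.mT
log_odds = reporter(hiddens)
print(log_odds.shape)  # torch.Size([8])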
98 changes: 98 additions & 0 deletions elk/training/lda.py
@@ -0,0 +1,98 @@
"""An ELK reporter network."""

from dataclasses import dataclass

import torch
import torch.nn.functional as F
from concept_erasure import optimal_linear_shrinkage
from einops import rearrange
from torch import Tensor

from ..utils.math_util import cov, cov_mean_fused
from .common import FitterConfig, Reporter


@dataclass
class LdaConfig(FitterConfig):
"""Configuration for an LdaFitter."""

anchor_gamma: float = 1.0
"""Gamma parameter for anchor regression."""

invariance_weight: float = 0.5
"""Weight of the prompt invariance term in the loss."""

l2_penalty: float = 0.0

def __post_init__(self):
assert self.anchor_gamma >= 0, "anchor_gamma must be non-negative"
assert 0 <= self.invariance_weight <= 1, "invariance_weight must be in [0, 1]"
assert self.l2_penalty >= 0, "l2_penalty must be non-negative"


class LdaFitter:
"""Linear Discriminant Analysis (LDA)"""

config: LdaConfig

def __init__(self, cfg: LdaConfig):
super().__init__()
self.config = cfg

def fit(self, hiddens: Tensor, labels: Tensor) -> Reporter:
"""Fit the probe to the contrast set `hiddens`.

Args:
hiddens: The contrast set of shape [batch, variants, (choices,) dim].
labels: Integer labels of shape [batch].
"""
n, v, *_ = hiddens.shape
assert n == labels.shape[0], "hiddens and labels must have the same batch size"

# This is a contrast set; create a true-false label for each element
if len(hiddens.shape) == 4:
hiddens = rearrange(hiddens, "n v k d -> (n k) v d")
labels = F.one_hot(labels.long()).flatten()

n = len(labels)
counts = (labels.sum(), n - labels.sum())
else:
counts = torch.bincount(labels)
assert len(counts) == 2, "Only binary classification is supported for now"

# Construct targets for the least-squares dual problem
z = torch.where(labels.bool(), n / counts[0], -n / counts[1]).unsqueeze(1)

# Adjust X and Z for anchor regression <https://arxiv.org/abs/1801.06229>
gamma = self.config.anchor_gamma
if gamma != 1.0:
# Implicitly compute n x n orthogonal projection onto the column space of
# the anchor variables without materializing the whole matrix. Since the
# anchors are one-hot, it turns out this is equivalent to adding a multiple
# of the anchor-conditional means.
# In general you're supposed to adjust the labels too, but we don't need
# to do that because by construction the anchor-conditional means of the
# labels are already all zero.
hiddens = hiddens + (gamma**0.5 - 1) * hiddens.mean(0)

# We can decompose the covariance matrix into the sum of the within-cluster
# covariance and the between-cluster covariance. This allows us to put extra
# weight on the within-cluster variance to encourage invariance to the prompt.
# NOTE: We're not applying shrinkage to each cluster covariance matrix because
# we're averaging over them, which should reduce the variance of the estimate
# a lot. Shrinkage could make MSE worse in this case.
S_between = optimal_linear_shrinkage(cov(hiddens.mean(1)), n)
S_within = cov_mean_fused(hiddens)

# Convex combination but multiply by 2 to keep the same scale
alpha = 2 * self.config.invariance_weight
S = alpha * S_within + (2 - alpha) * S_between

# Add ridge penalty
torch.linalg.diagonal(S).add_(self.config.l2_penalty)

# Broadcast the labels across variants
sigma_xz = cov(hiddens, z.expand_as(hiddens[..., 0]).unsqueeze(-1))
w = torch.linalg.solve(S, sigma_xz.squeeze(-1))

return Reporter(w[None])
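
To make the expected shapes and configuration knobs concrete, here is a hypothetical end-to-end sketch on synthetic data. It assumes the elk package and its concept_erasure dependency are importable; the tensor sizes and config values are made up for illustration and follow the docstring of `LdaFitter.fit`:

import torch

from elk.training.lda import LdaConfig, LdaFitter

n, v, k, d = 128, 4, 2, 16                 # batch, variants, choices, hidden dim
hiddens = torch.randn(n, v, k, d)
labels = torch.randint(0, 2, (n,))

# Upweight the within-cluster (prompt-invariance) term and add a small ridge penalty.
cfg = LdaConfig(anchor_gamma=1.0, invariance_weight=0.6, l2_penalty=1e-3)
reporter = LdaFitter(cfg).fit(hiddens, labels)

# The fitted reporter scores every element of the contrast set.
scores = reporter(hiddens)  # shape [n, v, k]

The target coding used for `z` (n / n_pos for positive elements, -n / n_neg for negative ones) is the standard one under which the two-class least-squares solution is proportional to the Fisher LDA direction, which is why solving S w = sigma_xz recovers the discriminant up to scale.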
11 changes: 10 additions & 1 deletion elk/training/train.py
@@ -18,14 +18,20 @@
 from .ccs_reporter import CcsConfig, CcsReporter
 from .common import FitterConfig
 from .eigen_reporter import EigenFitter, EigenFitterConfig
+from .lda import LdaConfig, LdaFitter


 @dataclass
 class Elicit(Run):
     """Full specification of a reporter training run."""

     net: FitterConfig = subgroups(
-        {"ccs": CcsConfig, "eigen": EigenFitterConfig}, default="eigen"
+        {
+            "ccs": CcsConfig,
+            "eigen": EigenFitterConfig,
+            "lda": LdaConfig,
+        },
+        default="eigen",
     )
     """Config for building the reporter network."""

@@ -74,6 +80,7 @@ def apply_to_layer(
         if not all(other_h.shape[-2] == k for other_h, _, _ in rest):
             raise ValueError("All datasets must have the same number of classes")

+        train_loss = None
         reporter_dir, lr_dir = self.create_models_dir(assert_type(Path, self.out_dir))
         train_loss = None

@@ -111,6 +118,8 @@ def apply_to_layer(
                 torch.cat(label_list),
                 torch.cat(hidden_list),
             )
+        elif isinstance(self.net, LdaConfig):
+            reporter = LdaFitter(self.net).fit(first_train_h, train_gt)
         else:
             raise ValueError(f"Unknown reporter config type: {type(self.net)}")

4 changes: 2 additions & 2 deletions elk/utils/__init__.py
@@ -10,14 +10,14 @@
 )
 from .gpu_utils import select_usable_devices
 from .hf_utils import instantiate_model, instantiate_tokenizer, is_autoregressive
-from .math_util import batch_cov, cov_mean_fused, stochastic_round_constrained
+from .math_util import cov, cov_mean_fused, stochastic_round_constrained
 from .pretty import Color, colorize
 from .tree_utils import pytree_map
 from .typing import assert_type, float_to_int16, int16_to_float32

 __all__ = [
     "assert_type",
-    "batch_cov",
+    "cov",
     "Color",
     "colorize",
     "cov_mean_fused",
28 changes: 19 additions & 9 deletions elk/utils/math_util.py
@@ -5,18 +5,28 @@
 from torch import Tensor


-@torch.jit.script
-def batch_cov(x: Tensor) -> Tensor:
-    """Compute a batch of covariance matrices.
+def cov(
+    x: Tensor, y: Tensor | None = None, dim: int | None = None, unbiased: bool = False
+) -> Tensor:
+    """Compute the (cross-)covariance matrix for `x` (and `y`).

     Args:
-        x: A tensor of shape [..., n, d].
-
-    Returns:
-        A tensor of shape [..., d, d].
+        x: A tensor of shape [*, d].
+        y: An optional tensor of shape [*, k]. If not provided, defaults to `x`.
+        dim: The dimension to reduce over. If not provided, defaults to all but the
+            last dimension.
+        unbiased: Whether to use Bessel's correction.
     """
-    x_ = x - x.mean(dim=-2, keepdim=True)
-    return x_.mT @ x_ / x_.shape[-2]
+    if y is None:
+        y = x
+    if dim is None:
+        dim = 0
+        x = x.flatten(0, -2)
+        y = y.flatten(0, -2)
+
+    x = x - x.mean(dim)
+    y = y - y.mean(dim)
+    return x.T @ y / (x.shape[dim] - unbiased)


@torch.jit.script
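
A quick sanity check of the new helper, as a sketch (assumes the elk package is importable): with `unbiased=True` the result should agree with `torch.cov`, which expects variables along rows, and with `dim=None` the leading dimensions are flattened into a single pooled estimate.

import torch

from elk.utils import cov

x = torch.randn(1000, 8, dtype=torch.float64)
# Bessel-corrected covariance agrees with torch.cov (which takes [vars, obs] input).
torch.testing.assert_close(cov(x, unbiased=True), torch.cov(x.T))

# With dim=None, a [batch, n, d] input is flattened over its leading dims, so the
# result is a single pooled d x d matrix rather than one matrix per batch element.
y = torch.randn(10, 100, 8, dtype=torch.float64)
assert cov(y).shape == (8, 8)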
4 changes: 2 additions & 2 deletions tests/test_eigen_reporter.py
@@ -1,7 +1,7 @@
 import torch

 from elk.training import EigenFitter, EigenFitterConfig
-from elk.utils import batch_cov, cov_mean_fused
+from elk.utils import cov, cov_mean_fused


 def test_eigen_reporter():
@@ -31,7 +31,7 @@ def test_eigen_reporter():

     # Check that the streaming covariance is correct
     neg_centroids, pos_centroids = x_neg.mean(dim=1), x_pos.mean(dim=1)
-    true_cov = 0.5 * (batch_cov(neg_centroids) + batch_cov(pos_centroids))
+    true_cov = 0.5 * (cov(neg_centroids) + cov(pos_centroids))
     torch.testing.assert_close(reporter.intercluster_cov, true_cov)

     # Check that the streaming negative covariance is correct
4 changes: 2 additions & 2 deletions tests/test_math.py
@@ -6,12 +6,12 @@
 from hypothesis import given
 from hypothesis import strategies as st

-from elk.utils import batch_cov, cov_mean_fused, stochastic_round_constrained
+from elk.utils import cov, cov_mean_fused, stochastic_round_constrained


 def test_cov_mean_fused():
     X = torch.randn(10, 500, 100, dtype=torch.float64)
-    cov_gt = batch_cov(X).mean(dim=0)
+    cov_gt = cov(X).mean(dim=0)
     cov_fused = cov_mean_fused(X)
     assert torch.allclose(cov_gt, cov_fused)