Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[P1] Adding in constant source intervention support with new tests #59

Merged
merged 3 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pyvene/models/configuration_intervenable_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
"intervenable_layer intervenable_representation_type "
"intervenable_unit max_number_of_units "
"intervenable_low_rank_dimension "
"subspace_partition group_key intervention_link_key",
defaults=(0, "block_output", "pos", 1, None, None, None, None),
"subspace_partition group_key intervention_link_key intervenable_moe "
"source_representation",
defaults=(0, "block_output", "pos", 1, None, None, None, None, None, None),
)


Expand Down
82 changes: 51 additions & 31 deletions pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,10 @@ def __init__(self, intervenable_config, model, **kwargs):
get_internal_model_type(model), model.config, representation
),
proj_dim=representation.intervenable_low_rank_dimension,
# we can partition the subspace, and intervene on subspace
# additional args
subspace_partition=representation.subspace_partition,
use_fast=self.use_fast,
source_representation=representation.source_representation,
)
if representation.intervention_link_key in self._intervention_pointers:
self._intervention_reverse_link[
Expand All @@ -129,9 +131,10 @@ def __init__(self, intervenable_config, model, **kwargs):
get_internal_model_type(model), model.config, representation
),
proj_dim=representation.intervenable_low_rank_dimension,
# we can partition the subspace, and intervene on subspace
# additional args
subspace_partition=representation.subspace_partition,
use_fast=self.use_fast,
source_representation=representation.source_representation,
)
# we cache the intervention for sharing if the key is not None
if representation.intervention_link_key is not None:
Expand Down Expand Up @@ -803,8 +806,9 @@ def hook_callback(model, args, kwargs, output=None):
if not self.is_model_stateless:
selected_output = selected_output.clone()


if isinstance(
intervention,
intervention,
CollectIntervention
):
intervened_representation = do_intervention(
Expand All @@ -820,16 +824,24 @@ def hook_callback(model, args, kwargs, output=None):
# no-op to the output

else:
intervened_representation = do_intervention(
selected_output,
self._reconcile_stateful_cached_activations(
key,
if intervention.is_source_constant:
intervened_representation = do_intervention(
selected_output,
unit_locations_base[key_i],
),
intervention,
subspaces[key_i] if subspaces is not None else None,
)
None,
intervention,
subspaces[key_i] if subspaces is not None else None,
)
else:
intervened_representation = do_intervention(
selected_output,
self._reconcile_stateful_cached_activations(
key,
selected_output,
unit_locations_base[key_i],
),
intervention,
subspaces[key_i] if subspaces is not None else None,
)

# setter can produce hot activations for shared subspace interventions if linked
if key in self._intervention_reverse_link:
Expand Down Expand Up @@ -873,10 +885,10 @@ def _input_validation(
):
"""Fail fast input validation"""
if self.mode == "parallel":
assert "sources->base" in unit_locations
assert "sources->base" in unit_locations or "base" in unit_locations
elif activations_sources is None and self.mode == "serial":
assert "sources->base" not in unit_locations

# sources may contain None, but length should match
if sources is not None:
if len(sources) != len(self._intervention_group):
Expand Down Expand Up @@ -982,10 +994,7 @@ def _wait_for_forward_with_parallel_intervention(
for intervenable_key in intervenable_keys:
# skip in case smart jump
if intervenable_key in self.activations or \
isinstance(
self.interventions[intervenable_key][0],
CollectIntervention
):
self.interventions[intervenable_key][0].is_source_constant:
set_handlers = self._intervention_setter(
[intervenable_key],
[
Expand Down Expand Up @@ -1054,10 +1063,7 @@ def _wait_for_forward_with_serial_intervention(
for intervenable_key in intervenable_keys:
# skip in case smart jump
if intervenable_key in self.activations or \
isinstance(
self.interventions[intervenable_key][0],
CollectIntervention
):
self.interventions[intervenable_key][0].is_source_constant:
# set with intervened activation to source_i+1
set_handlers = self._intervention_setter(
[intervenable_key],
Expand All @@ -1080,21 +1086,30 @@ def _broadcast_unit_locations(
batch_size,
unit_locations
):
_unit_locations = copy.deepcopy(unit_locations)
_unit_locations = {}
for k, v in unit_locations.items():
# special broadcast for base-only interventions
is_base_only = False
if k == "base":
is_base_only = True
k = "sources->base"
if isinstance(v, int):
_unit_locations[k] = ([[[v]]*batch_size], [[[v]]*batch_size])
self.use_fast = True
elif isinstance(v[0], int) and isinstance(v[1], int):
elif len(v) == 2 and isinstance(v[0], int) and isinstance(v[1], int):
_unit_locations[k] = ([[[v[0]]]*batch_size], [[[v[1]]]*batch_size])
self.use_fast = True
elif isinstance(v[0], list) and isinstance(v[1], list):
pass # we don't support broadcast here yet.
elif len(v) == 2 and v[0] == None and isinstance(v[1], int):
_unit_locations[k] = (None, [[[v[1]]]*batch_size])
self.use_fast = True
elif len(v) == 2 and isinstance(v[0], int) and v[1] == None:
_unit_locations[k] = ([[[v[0]]]*batch_size], None)
self.use_fast = True
else:
raise ValueError(
f"unit_locations {unit_locations} contains invalid format."
)

if is_base_only:
_unit_locations[k] = (None, v)
else:
_unit_locations[k] = v
return _unit_locations

def forward(
Expand Down Expand Up @@ -1173,12 +1188,15 @@ def forward(
self._cleanup_states()

# if no source inputs, we are calling a simple forward
if sources is None and activations_sources is None:
if sources is None and activations_sources is None \
and unit_locations is None:
return self.model(**base), None

unit_locations = self._broadcast_unit_locations(
get_batch_size(base), unit_locations)

sources = [None] if sources is None else sources

self._input_validation(
base,
sources,
Expand Down Expand Up @@ -1287,6 +1305,8 @@ def generate(
unit_locations = self._broadcast_unit_locations(
get_batch_size(base), unit_locations)

sources = [None] if sources is None else None

self._input_validation(
base,
sources,
Expand Down
21 changes: 21 additions & 0 deletions pyvene/models/intervention_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,19 @@ def __repr__(self):
def __str__(self):
    # Human-readable form: the cached state dict rendered as 4-space
    # indented JSON (assumes self.state_dict is JSON-serializable —
    # TODO confirm against the enclosing class, which is not visible here).
    return json.dumps(self.state_dict, indent=4)
def broadcast_tensor(x, target_shape):
    """Broadcast tensor ``x`` along new leading dimensions to ``target_shape``.

    ``x`` is collapsed to shape ``(1, ..., 1, last_dim)`` and then expanded,
    so the result is a zero-copy view that repeats ``x`` over every leading
    dimension of ``target_shape``. Assumes ``x``'s total element count equals
    its last dimension (i.e. it is effectively 1-D) — otherwise the reshape
    below raises.

    Args:
        x: tensor carrying the values to broadcast.
        target_shape: desired output shape; its last dimension must equal
            ``x.shape[-1]``.

    Returns:
        A view of ``x`` expanded to ``target_shape``.

    Raises:
        ValueError: if the last dimension of ``target_shape`` does not match
            ``x``'s last dimension.
    """
    # Ensure the last dimension of target_shape matches x's size
    if target_shape[-1] != x.shape[-1]:
        raise ValueError("The last dimension of target_shape must match the size of x")

    # Create a shape for reshaping x that is compatible with target_shape
    reshape_shape = [1] * (len(target_shape) - 1) + [x.shape[-1]]

    # Use `reshape` rather than `view`: identical result for contiguous
    # tensors, but it also accepts non-contiguous inputs (e.g. slices),
    # where `view` would raise a RuntimeError.
    return x.reshape(*reshape_shape).expand(*target_shape)

def _do_intervention_by_swap(
base,
source,
Expand All @@ -50,6 +62,15 @@ def _do_intervention_by_swap(
"""The basic do function that guards interventions"""
if mode == "collect":
assert source is None
# auto broadcast
if base.shape != source.shape:
try:
source = broadcast_tensor(source, base.shape)
except:
raise ValueError(
f"source with shape {source.shape} cannot be broadcasted "
f"into base with shape {base.shape}."
)
# interchange
if use_fast:
if subspaces is not None:
Expand Down
Loading