Skip to content

Commit

Permalink
Merge pull request #70 from stanfordnlp/zen/updategenerate
Browse files Browse the repository at this point in the history
Model generation API simplified; tech debt from redundant variables cleaned up
  • Loading branch information
frankaging authored Jan 19, 2024
2 parents d0cddb8 + 9a6af4e commit fbc2c7e
Show file tree
Hide file tree
Showing 6 changed files with 309 additions and 183 deletions.
1 change: 1 addition & 0 deletions pyvene/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,6 @@
from .models.gru.modelings_intervenable_gru import create_gru
from .models.gru.modelings_intervenable_gru import create_gru_lm
from .models.gru.modelings_intervenable_gru import create_gru_classifier
from .models.gru.modelings_gru import GRUConfig
from .models.llama.modelings_intervenable_llama import create_llama
from .models.mlp.modelings_intervenable_mlp import create_mlp_classifier
35 changes: 27 additions & 8 deletions pyvene/models/configuration_intervenable_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import json, warnings
import json, warnings, torch
from collections import OrderedDict, namedtuple
from typing import Any, List, Mapping, Optional

Expand All @@ -14,36 +14,52 @@
"intervenable_unit max_number_of_units "
"intervenable_low_rank_dimension "
"subspace_partition group_key intervention_link_key intervenable_moe "
"source_representation",
defaults=(0, "block_output", "pos", 1, None, None, None, None, None, None),
"source_representation hidden_source_representation",
defaults=(
0, "block_output", "pos", 1,
None, None, None, None, None, None, None),
)


class IntervenableConfig(PretrainedConfig):
def __init__(
self,
intervenable_model_type=None,
intervenable_representations=[IntervenableRepresentationConfig()],
intervenable_interventions_type=VanillaIntervention,
mode="parallel",
intervenable_interventions=[None],
sorted_keys=None,
intervention_dimensions=None,
intervenable_model_type=None,
**kwargs,
):
self.intervenable_model_type = intervenable_model_type
self.intervenable_representations = intervenable_representations
if isinstance(intervenable_representations, list):
self.intervenable_representations = intervenable_representations
else:
self.intervenable_representations = [intervenable_representations]
self.intervenable_interventions_type = intervenable_interventions_type
self.mode = mode
self.intervenable_interventions = intervenable_interventions
self.sorted_keys = sorted_keys
self.intervention_dimensions = intervention_dimensions
self.intervenable_model_type = intervenable_model_type
super().__init__(**kwargs)

def __repr__(self):
intervenable_representations = []
for reprs in self.intervenable_representations:
if isinstance(reprs, list):
reprs = IntervenableRepresentationConfig(*reprs)
new_d = {}
for k, v in reprs._asdict().items():
if type(v) not in {str, int, list, tuple, dict} and v is not None and v != [None]:
new_d[k] = "PLACEHOLDER"
else:
new_d[k] = v
intervenable_representations += [new_d]
_repr = {
"intervenable_model_type": str(self.intervenable_model_type),
"intervenable_representations": tuple(self.intervenable_representations),
"intervenable_representations": tuple(intervenable_representations),
"intervenable_interventions_type": str(
self.intervenable_interventions_type
),
Expand All @@ -52,9 +68,12 @@ def __repr__(self):
str(intervenable_intervention)
for intervenable_intervention in self.intervenable_interventions
],
"sorted_keys": tuple(self.sorted_keys),
"sorted_keys": tuple(self.sorted_keys) if self.sorted_keys is not None else str(self.sorted_keys),
"intervention_dimensions": str(self.intervention_dimensions),
}
_repr_string = json.dumps(_repr, indent=4)

return f"IntervenableConfig\n{_repr_string}"

def __str__(self):
    """Return the same JSON-formatted config dump produced by ``__repr__``."""
    # str() and repr() are intentionally identical for this config class.
    return repr(self)
74 changes: 42 additions & 32 deletions pyvene/models/intervenable_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,12 @@ def __init__(self, intervenable_config, model, **kwargs):
if type(intervention_type) != list
else intervention_type[i]
)
other_medata = representation._asdict()
other_medata["use_fast"] = self.use_fast
intervention = intervention_function(
get_intervenable_dimension(
get_internal_model_type(model), model.config, representation
),
proj_dim=representation.intervenable_low_rank_dimension,
# additional args
subspace_partition=representation.subspace_partition,
use_fast=self.use_fast,
source_representation=representation.source_representation,
), **other_medata
)
if representation.intervention_link_key in self._intervention_pointers:
self._intervention_reverse_link[
Expand All @@ -125,25 +122,13 @@ def __init__(self, intervenable_config, model, **kwargs):
intervention = self._intervention_pointers[
representation.intervention_link_key
]
else:
intervention = intervention_function(
get_intervenable_dimension(
get_internal_model_type(model), model.config, representation
),
proj_dim=representation.intervenable_low_rank_dimension,
# additional args
subspace_partition=representation.subspace_partition,
use_fast=self.use_fast,
source_representation=representation.source_representation,
)
# we cache the intervention for sharing if the key is not None
if representation.intervention_link_key is not None:
self._intervention_pointers[
representation.intervention_link_key
] = intervention
self._intervention_reverse_link[
_key
] = f"link#{representation.intervention_link_key}"
elif representation.intervention_link_key is not None:
self._intervention_pointers[
representation.intervention_link_key
] = intervention
self._intervention_reverse_link[
_key
] = f"link#{representation.intervention_link_key}"
if isinstance(
intervention,
CollectIntervention
Expand Down Expand Up @@ -419,12 +404,34 @@ def save(
)
saving_config.intervenable_interventions_type = []
saving_config.intervention_dimensions = []

# handle constant source reprs if passed in.
serialized_intervenable_representations = []
for reprs in saving_config.intervenable_representations:
serialized_reprs = {}
for k, v in reprs._asdict().items():
if k == "hidden_source_representation":
continue
if k == "source_representation":
# hidden flag only set here
if v is not None:
serialized_reprs["hidden_source_representation"] = True
serialized_reprs[k] = None
else:
serialized_reprs[k] = v
serialized_intervenable_representations += [
IntervenableRepresentationConfig(**serialized_reprs)
]
saving_config.intervenable_representations = \
serialized_intervenable_representations

for k, v in self.interventions.items():
intervention = v[0]
saving_config.intervenable_interventions_type += [str(type(intervention))]
binary_filename = f"intkey_{k}.bin"
# save intervention binary file
if isinstance(intervention, TrainableIntervention):
if isinstance(intervention, TrainableIntervention) or \
intervention.source_representation is not None:
logging.warn(f"Saving trainable intervention to {binary_filename}.")
torch.save(
intervention.state_dict(),
Expand Down Expand Up @@ -514,17 +521,21 @@ def load(load_directory, model, local_directory=None, from_huggingface_hub=False
for i, (k, v) in enumerate(intervenable.interventions.items()):
intervention = v[0]
binary_filename = f"intkey_{k}.bin"
if isinstance(intervention, TrainableIntervention):
if isinstance(intervention, TrainableIntervention) or \
intervention.is_source_constant:
if not os.path.exists(load_directory) or from_huggingface_hub:
hf_hub_download(
repo_id=load_directory,
filename=binary_filename,
cache_dir=local_directory,
)
logging.warn(f"Loading trainable intervention from {binary_filename}.")
intervention.load_state_dict(
torch.load(os.path.join(load_directory, binary_filename))
)
saved_state_dict = torch.load(os.path.join(load_directory, binary_filename))
if intervention.is_source_constant:
intervention.register_buffer(
'source_representation', saved_state_dict['source_representation']
)
intervention.load_state_dict(saved_state_dict)
intervention.interchange_dim = saving_config.intervention_dimensions[i]

return intervenable
Expand Down Expand Up @@ -805,8 +816,7 @@ def hook_callback(model, args, kwargs, output=None):
# TODO: need to figure out why clone is needed
if not self.is_model_stateless:
selected_output = selected_output.clone()



if isinstance(
intervention,
CollectIntervention
Expand Down
19 changes: 10 additions & 9 deletions pyvene/models/intervention_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,16 @@ def _do_intervention_by_swap(
"""The basic do function that guards interventions"""
if mode == "collect":
assert source is None
# auto broadcast
if base.shape != source.shape:
try:
source = broadcast_tensor(source, base.shape)
except:
raise ValueError(
f"source with shape {source.shape} cannot be broadcasted "
f"into base with shape {base.shape}."
)
else:
# auto broadcast
if base.shape != source.shape:
try:
source = broadcast_tensor(source, base.shape)
except:
raise ValueError(
f"source with shape {source.shape} cannot be broadcasted "
f"into base with shape {base.shape}."
)
# interchange
if use_fast:
if subspaces is not None:
Expand Down
Loading

0 comments on commit fbc2c7e

Please sign in to comment.