Merge pull request #69 from stanfordnlp/zen/updategenerate
[Minor] update broadcast in generation
frankaging authored Jan 18, 2024
2 parents 8b0ff47 + 62749d4 commit d0cddb8
Showing 2 changed files with 52 additions and 111 deletions.
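
For readers skimming the diff, the shorthands below are what changed: a bare int, an (int, int) pair, or a (None, int) pair in unit_locations is expanded into full per-example position lists, and after this commit the expansion is also copied once per intervention group. A minimal sketch of the resulting shapes, using hypothetical sizes and mirroring the expressions in _broadcast_unit_locations rather than importing anything from pyvene:

    # Sketch of the broadcast this commit introduces (illustrative only).
    batch_size = 2
    group_size = 3  # stands in for len(self._intervention_group)

    # {"base": 0} now expands to a base-only location, one copy per intervention:
    expanded = (None, [[[0]] * batch_size] * group_size)
    assert len(expanded[1]) == group_size     # one entry per intervention
    assert len(expanded[1][0]) == batch_size  # one position list per example
    assert expanded[1][0][0] == [0]           # intervene at position 0

    # {"sources->base": (4, 5)} maps source position 4 onto base position 5:
    source_locs = [[[4]] * batch_size] * group_size
    base_locs = [[[5]] * batch_size] * group_size
    print(source_locs[0], base_locs[0])  # [[4], [4]] [[5], [5]]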
39 changes: 25 additions & 14 deletions pyvene/models/intervenable_base.py
@@ -890,13 +890,13 @@ def _input_validation(
             assert "sources->base" not in unit_locations
 
         # sources may contain None, but length should match
-        if sources is not None:
+        if sources is not None and not (len(sources) == 1 and sources[0] == None):
             if len(sources) != len(self._intervention_group):
                 raise ValueError(
                     f"Source length {len(sources)} is not "
                     f"equal to intervention length {len(self._intervention_group)}."
                 )
-        else:
+        elif activations_sources is not None:
             if len(activations_sources) != len(self._intervention_group):
                 raise ValueError(
                     f"Source activations length {len(activations_sources)} is not "
@@ -1084,6 +1084,7 @@ def _wait_for_forward_with_serial_intervention(
     def _broadcast_unit_locations(
         self,
         batch_size,
+        intervention_group_size,
         unit_locations
     ):
         _unit_locations = {}
@@ -1094,16 +1095,25 @@ def _broadcast_unit_locations(
                 is_base_only = True
                 k = "sources->base"
             if isinstance(v, int):
-                _unit_locations[k] = ([[[v]]*batch_size], [[[v]]*batch_size])
+                if is_base_only:
+                    _unit_locations[k] = (None, [[[v]]*batch_size]*intervention_group_size)
+                else:
+                    _unit_locations[k] = (
+                        [[[v]]*batch_size]*intervention_group_size,
+                        [[[v]]*batch_size]*intervention_group_size
+                    )
                 self.use_fast = True
             elif len(v) == 2 and isinstance(v[0], int) and isinstance(v[1], int):
-                _unit_locations[k] = ([[[v[0]]]*batch_size], [[[v[1]]]*batch_size])
+                _unit_locations[k] = (
+                    [[[v[0]]]*batch_size]*intervention_group_size,
+                    [[[v[1]]]*batch_size]*intervention_group_size
+                )
                 self.use_fast = True
             elif len(v) == 2 and v[0] == None and isinstance(v[1], int):
-                _unit_locations[k] = (None, [[[v[1]]]*batch_size])
+                _unit_locations[k] = (None, [[[v[1]]]*batch_size]*intervention_group_size)
                 self.use_fast = True
             elif len(v) == 2 and isinstance(v[0], int) and v[1] == None:
-                _unit_locations[k] = ([[[v[0]]]*batch_size], None)
+                _unit_locations[k] = ([[[v[0]]]*batch_size]*intervention_group_size, None)
                 self.use_fast = True
             else:
                 if is_base_only:
@@ -1193,9 +1203,9 @@ def forward(
             return self.model(**base), None
 
         unit_locations = self._broadcast_unit_locations(
-            get_batch_size(base), unit_locations)
+            get_batch_size(base), len(self._intervention_group), unit_locations)
 
-        sources = [None] if sources is None else sources
+        sources = [None]*len(self._intervention_group) if sources is None else sources
 
         self._input_validation(
             base,
@@ -1265,7 +1275,7 @@ def generate(
         sources: Optional[List] = None,
         unit_locations: Optional[Dict] = None,
         activations_sources: Optional[Dict] = None,
-        intervene_on_prompt: bool = True,
+        intervene_on_prompt: bool = False,
         subspaces: Optional[List] = None,
         **kwargs,
     ):
@@ -1298,14 +1308,15 @@
 
         self._intervene_on_prompt = intervene_on_prompt
         self._is_generation = True
-
-        if sources is None and activations_sources is None:
-            return self.model.generate(inputs=base["input_ids"], **kwargs), None
+
+        if not intervene_on_prompt and unit_locations is None:
+            # that means we intervene on every generated token
+            unit_locations = {"base": 0}
 
         unit_locations = self._broadcast_unit_locations(
-            get_batch_size(base), unit_locations)
+            get_batch_size(base), len(self._intervention_group), unit_locations)
 
-        sources = [None] if sources is None else None
+        sources = [None]*len(self._intervention_group) if sources is None else sources
 
         self._input_validation(
             base,
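Taken together, the generate changes mean that when unit_locations is omitted and intervene_on_prompt is left at its new False default, the intervention fires at every decoding step rather than on the prompt. A usage sketch under those assumptions; intervenable and inputs are stand-ins for the tutorial's model and tokenized prompt, and the intervened value is assumed to come from a source_representation stored in the config, as in the notebook below:

    # Illustrative call against the updated generate() contract; not from the diff.
    base_outputs, counterfactual_outputs = intervenable.generate(
        inputs,                     # dict with "input_ids", as in the tutorial
        intervene_on_prompt=False,  # new default: prompt tokens left untouched
        max_length=64,
        num_beams=1,
    )
    # unit_locations was omitted, so it is broadcast to {"base": 0} internally,
    # i.e. the intervention applies at position 0 of each newly generated step.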
124 changes: 27 additions & 97 deletions tutorials/advanced_tutorials/Intervened_Model_Generation.ipynb
@@ -39,31 +39,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
    "id": "93b918c2",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "[2024-01-11 01:38:26,971] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "try:\n",
     "    # This library is our indicator that the required installs\n",
     "    # need to be done.\n",
     "    import pyvene\n",
     "\n",
     "except ModuleNotFoundError:\n",
-    "    !pip install git+https://github.com/frankaging/pyvene.git"
+    "    !pip install git+https://github.com/stanfordnlp/pyvene.git"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 3,
    "id": "aa6a75e7",
    "metadata": {},
    "outputs": [],
@@ -92,7 +84,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 4,
    "id": "354ac7e7",
    "metadata": {},
    "outputs": [
@@ -110,7 +102,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "id": "3e2e363d",
    "metadata": {},
    "outputs": [
@@ -163,7 +155,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "id": "22544382",
    "metadata": {},
    "outputs": [
@@ -177,16 +169,14 @@
    ],
    "source": [
     "def activation_addition_position_config(\n",
-    "    model_type, intervention_type, start_layer_idx, end_layer_idx\n",
+    "    intervention_type, start_layer_idx, end_layer_idx, embedding,\n",
     "):\n",
     "    intervenable_config = IntervenableConfig(\n",
-    "        intervenable_model_type=model_type,\n",
     "        intervenable_representations=[\n",
     "            IntervenableRepresentationConfig(\n",
     "                i, # layer\n",
     "                intervention_type, # intervention type\n",
-    "                \"pos\", # intervention unit\n",
-    "                1, # max number of unit\n",
+    "                source_representation=embedding\n",
     "            )\n",
     "            for i in range(start_layer_idx, end_layer_idx)\n",
     "        ],\n",
@@ -196,66 +186,14 @@
     "\n",
     "\n",
     "config, tokenizer, tinystory = create_gpt_neo()\n",
-    "intervenable_config = activation_addition_position_config(\n",
-    "    type(tinystory), \"mlp_output\", 0, 4\n",
-    ")\n",
-    "intervenable = IntervenableModel(intervenable_config, tinystory)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 6,
-   "id": "8b835b5b",
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "['layer.0.repr.mlp_output.unit.pos.nunit.1#0',\n",
-       " 'layer.1.repr.mlp_output.unit.pos.nunit.1#0',\n",
-       " 'layer.2.repr.mlp_output.unit.pos.nunit.1#0',\n",
-       " 'layer.3.repr.mlp_output.unit.pos.nunit.1#0']"
-      ]
-     },
-     "execution_count": 6,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
+    "\n",
     "sad_token_id = tokenizer(\" Sad\")[\"input_ids\"][0]\n",
     "happy_token_id = tokenizer(\" Happy\")[\"input_ids\"][0]\n",
     "\n",
-    "beta = 0.3 # very hacky way to control the effect\n",
-    "sad_embedding = (\n",
-    "    tinystory.transformer.wte(torch.tensor(sad_token_id))\n",
-    "    .clone()\n",
-    "    .unsqueeze(0)\n",
-    "    .unsqueeze(0)\n",
-    ") # make it a fake batch\n",
-    "happy_embedding = (\n",
-    "    tinystory.transformer.wte(torch.tensor(happy_token_id))\n",
-    "    .clone()\n",
-    "    .unsqueeze(0)\n",
-    "    .unsqueeze(0)\n",
-    ") # make it a fake batch\n",
+    "beta = 0.3\n",
+    "sad_embedding = tinystory.transformer.wte(torch.tensor(sad_token_id))\n",
+    "happy_embedding = tinystory.transformer.wte(torch.tensor(happy_token_id))\n",
     "sad_embedding *= beta\n",
-    "happy_embedding *= beta\n",
-    "\n",
-    "activations_sad_sources = dict(\n",
-    "    zip(\n",
-    "        intervenable.sorted_intervenable_keys,\n",
-    "        [sad_embedding] * len(intervenable.sorted_intervenable_keys),\n",
-    "    )\n",
-    ")\n",
-    "activations_happy_sources = dict(\n",
-    "    zip(\n",
-    "        intervenable.sorted_intervenable_keys,\n",
-    "        [happy_embedding] * len(intervenable.sorted_intervenable_keys),\n",
-    "    )\n",
-    ")\n",
-    "# we intervene on all of the mlp output\n",
-    "intervenable.sorted_intervenable_keys"
+    "happy_embedding *= beta"
    ]
   },
-  {
@@ -298,19 +236,16 @@
     }
    ],
    "source": [
+    "intervenable_config = activation_addition_position_config(\n",
+    "    \"mlp_output\", 0, 4, sad_embedding\n",
+    ")\n",
+    "intervenable = IntervenableModel(intervenable_config, tinystory)\n",
+    "\n",
     "base = \"Once upon a time there was\"\n",
     "\n",
     "inputs = tokenizer(base, return_tensors=\"pt\")\n",
     "base_outputs, counterfactual_outputs = intervenable.generate(\n",
     "    inputs,\n",
-    "    unit_locations={\n",
-    "        \"sources->base\": (\n",
-    "            [[[0]]] * tinystory.config.num_layers, # a single token embeddings\n",
-    "            [[[0]]] * tinystory.config.num_layers, # the last token of the prompt\n",
-    "        )\n",
-    "    },\n",
-    "    activations_sources=activations_sad_sources,\n",
-    "    intervene_on_prompt=False,\n",
     "    max_length=256,\n",
     "    num_beams=1,\n",
     ")\n",
@@ -370,18 +305,21 @@
     }
    ],
    "source": [
+    "intervenable_config = activation_addition_position_config(\n",
+    "    \"mlp_output\", 0, 4, happy_embedding\n",
+    ")\n",
+    "intervenable = IntervenableModel(intervenable_config, tinystory)\n",
+    "\n",
     "base = \"Once upon a time there was\"\n",
     "\n",
     "inputs = tokenizer(base, return_tensors=\"pt\")\n",
     "base_outputs, counterfactual_outputs = intervenable.generate(\n",
     "    inputs,\n",
     "    unit_locations={\n",
     "        \"sources->base\": (\n",
-    "            [[[0]]] * tinystory.config.num_layers, # a single token embeddings\n",
-    "            [[[0]]] * tinystory.config.num_layers, # the last token of the prompt\n",
+    "            None, 0\n",
     "        )\n",
-    "    },\n",
-    "    activations_sources=activations_happy_sources,\n",
+    "    }, # this is less broadcast\n",
     "    intervene_on_prompt=False,\n",
     "    max_length=256,\n",
     "    num_beams=1,\n",
@@ -391,14 +329,6 @@
     ")\n",
     "print(counterfactual_text)"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "76422448",
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  }
  ],
  "metadata": {
