Commit

update all tutorials (no testing)
frankaging committed Feb 2, 2025
1 parent 8895ed5 · commit b2aa098
Showing 9 changed files with 31 additions and 33 deletions.
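Every hunk in this commit makes the same mechanical change: tutorial code stops indexing into the values of interventions with [0] and uses each value directly. A minimal before/after sketch of the pattern, assuming (as the diff implies but never states) that IntervenableModel.interventions now maps each key straight to the intervention module rather than to an (intervention, hook) tuple:

# Sketch of the access-pattern change applied across all nine files below.
# Assumption (implied by the diff): the values of `interventions` are now the
# intervention modules themselves, not tuples.
def collect_rotation_params(intervenable):
    optimizer_params = []
    for key, intervention in intervenable.interventions.items():
        # before this commit: intervention = intervenable.interventions[key][0]
        optimizer_params.append({"params": intervention.rotate_layer.parameters()})
    return optimizer_params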
6 changes: 3 additions & 3 deletions tutorials/advanced_tutorials/Boundless_DAS.ipynb
@@ -422,8 +422,8 @@
 "warm_up_steps = 0.1 * t_total\n",
 "optimizer_params = []\n",
 "for k, v in intervenable.interventions.items():\n",
-"    optimizer_params += [{\"params\": v[0].rotate_layer.parameters()}]\n",
-"    optimizer_params += [{\"params\": v[0].intervention_boundaries, \"lr\": 1e-2}]\n",
+"    optimizer_params += [{\"params\": v.rotate_layer.parameters()}]\n",
+"    optimizer_params += [{\"params\": v.intervention_boundaries, \"lr\": 1e-2}]\n",
 "optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)\n",
 "scheduler = get_linear_schedule_with_warmup(\n",
 "    optimizer, num_warmup_steps=warm_up_steps, num_training_steps=t_total\n",
@@ -470,7 +470,7 @@
 "    loss = loss_fct(shift_logits, shift_labels)\n",
 "\n",
 "    for k, v in intervenable.interventions.items():\n",
-"        boundary_loss = 1.0 * v[0].intervention_boundaries.sum()\n",
+"        boundary_loss = 1.0 * v.intervention_boundaries.sum()\n",
 "        loss += boundary_loss\n",
 "\n",
 "    return loss"
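Taken together, the two hunks above are one Boundless DAS training recipe: the rotation parameters and the learned intervention boundaries get separate optimizer parameter groups (the boundaries with their own learning rate), and the summed boundary widths are added to the language-modeling loss so the learned subspace is pushed to stay small. A condensed sketch of that objective, assuming loss_fct, shift_logits, and shift_labels are defined as in the notebook:

# Condensed sketch of the Boundless DAS objective shown above; loss_fct,
# shift_logits, and shift_labels are assumed from the notebook cell.
def boundless_das_loss(intervenable, loss_fct, shift_logits, shift_labels):
    loss = loss_fct(shift_logits, shift_labels)
    for k, v in intervenable.interventions.items():
        # penalize wide boundaries so the intervened subspace stays minimal
        loss = loss + 1.0 * v.intervention_boundaries.sum()
    return loss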
2 changes: 1 addition & 1 deletion tutorials/advanced_tutorials/DAS_Main_Introduction.ipynb
@@ -1438,7 +1438,7 @@
 "t_total = int(len(dataset) * epochs)\n",
 "optimizer_params = []\n",
 "for k, v in intervenable.interventions.items():\n",
-"    optimizer_params += [{\"params\": v[0].rotate_layer.parameters()}]\n",
+"    optimizer_params += [{\"params\": v.rotate_layer.parameters()}]\n",
 "    break\n",
 "optimizer = torch.optim.Adam(optimizer_params, lr=0.001)\n",
 "\n",
8 changes: 4 additions & 4 deletions tutorials/advanced_tutorials/IOI_with_DAS.ipynb
@@ -7302,7 +7302,7 @@
 "source": [
 "intervention = boundless_das_intervenable.interventions[\n",
 "    \"layer.8.repr.attention_value_output.unit.pos.nunit.1#0\"\n",
-"][0]\n",
+"]\n",
 "boundary_mask = sigmoid_boundary(\n",
 "    intervention.intervention_population.repeat(1, 1),\n",
 "    0.0,\n",
@@ -12476,7 +12476,7 @@
 "source": [
 "intervention = das_intervenable.interventions[\n",
 "    \"layer.8.repr.attention_value_output.unit.pos.nunit.1#0\"\n",
-"][0]\n",
+"]\n",
 "learned_weights = intervention.rotate_layer.weight\n",
 "headwise_learned_weights = torch.chunk(learned_weights, chunks=12, dim=0)\n",
 "\n",
@@ -17401,7 +17401,7 @@
 "source": [
 "intervention = boundless_das_intervenable.interventions[\n",
 "    \"layer.9.repr.attention_value_output.unit.pos.nunit.1#0\"\n",
-"][0]\n",
+"]\n",
 "boundary_mask = sigmoid_boundary(\n",
 "    intervention.intervention_population.repeat(1, 1),\n",
 "    0.0,\n",
@@ -23344,7 +23344,7 @@
 "source": [
 "intervention = das_intervenable.interventions[\n",
 "    \"layer.9.repr.attention_value_output.unit.pos.nunit.1#0\"\n",
-"][0]\n",
+"]\n",
 "learned_weights = intervention.rotate_layer.weight\n",
 "headwise_learned_weights = torch.chunk(learned_weights, chunks=12, dim=0)\n",
 "\n",
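All four hunks above fetch a single trained intervention back out of the model by its key for post-hoc analysis, and the two rotate_layer readouts then split the learned rotation into per-head blocks: GPT-2 has 12 attention heads, so chunking the weight matrix into 12 pieces along dim 0 yields one block per head. A sketch of that readout, with the key taken from the diff; the per-head norms are an assumed way to summarize each block:

import torch

# Sketch of the head-wise readout above. The key and the 12-way split follow
# the diff; ranking heads by block norm is an assumption for illustration.
intervention = das_intervenable.interventions[
    "layer.8.repr.attention_value_output.unit.pos.nunit.1#0"
]
learned_weights = intervention.rotate_layer.weight
headwise_learned_weights = torch.chunk(learned_weights, chunks=12, dim=0)
head_norms = [w.norm().item() for w in headwise_learned_weights]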
6 changes: 3 additions & 3 deletions tutorials/advanced_tutorials/IOI_with_Mask_Intervention.ipynb
@@ -256,7 +256,7 @@
 "def calculate_loss_with_mask(logits, labels, intervenable, coeff=1):\n",
 "    loss = calculate_loss(logits, labels)\n",
 "    for k, v in intervenable.interventions.items():\n",
-"        mask_loss = coeff * torch.norm(v[0].mask, 1)\n",
+"        mask_loss = coeff * torch.norm(v.mask, 1)\n",
 "        loss += mask_loss\n",
 "    return loss\n",
 "\n",
@@ -363,8 +363,8 @@
 "    eval_preds += [counterfactual_outputs.logits]\n",
 "eval_metrics = compute_metrics(eval_preds, eval_labels)\n",
 "for k, v in pv_gpt2.interventions.items():\n",
-"    mask = v[0].mask\n",
-"    temperature = v[0].temperature\n",
+"    mask = v.mask\n",
+"    temperature = v.temperature\n",
 "    break\n",
 "print(eval_metrics)\n",
 "print(\n",
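Both hunks above touch the mask intervention's two learnable pieces: a per-position mask and a temperature that anneals the mask toward a hard 0/1 selection. The first hunk adds an L1 penalty on the mask so that only a few positions stay active after training. A sketch of that penalty in isolation, with coeff trading task loss against sparsity as in the notebook:

import torch

# Sketch of the L1 sparsity term from calculate_loss_with_mask above.
def mask_sparsity_loss(intervenable, coeff=1.0):
    loss = 0.0
    for k, v in intervenable.interventions.items():
        loss = loss + coeff * torch.norm(v.mask, 1)  # L1 norm of the mask
    return loss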
2 changes: 1 addition & 1 deletion tutorials/advanced_tutorials/MQNLI.ipynb
@@ -1402,7 +1402,7 @@
 "\n",
 "optimizer_params = []\n",
 "for k, v in intervenable.interventions.items():\n",
-"    optimizer_params += [{\"params\": v[0].rotate_layer.parameters()}]\n",
+"    optimizer_params += [{\"params\": v.rotate_layer.parameters()}]\n",
 "    break\n",
 "optimizer = torch.optim.Adam(optimizer_params, lr=0.001)\n",
 "\n",
2 changes: 1 addition & 1 deletion tutorials/advanced_tutorials/Probing_Gender.ipynb
@@ -905,7 +905,7 @@
 "    optimizer_params = []\n",
 "    for k, v in intervenable.interventions.items():\n",
 "        try:\n",
-"            optimizer_params.append({\"params\": v[0].rotate_layer.parameters()})\n",
+"            optimizer_params.append({\"params\": v.rotate_layer.parameters()})\n",
 "        except:\n",
 "            pass\n",
 "    optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)\n",
22 changes: 11 additions & 11 deletions tutorials/advanced_tutorials/Voting_Mechanism.ipynb
@@ -330,9 +330,9 @@
 "optimizer_params = []\n",
 "for k, v in pv_llama.interventions.items():\n",
 "    optimizer_params += [\n",
-"        {\"params\": v[0].rotate_layer.parameters()}]\n",
+"        {\"params\": v.rotate_layer.parameters()}]\n",
 "    optimizer_params += [\n",
-"        {\"params\": v[0].intervention_boundaries, \"lr\": 1e-2}]\n",
+"        {\"params\": v.intervention_boundaries, \"lr\": 1e-2}]\n",
 "optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)\n",
 "scheduler = get_linear_schedule_with_warmup(\n",
 "    optimizer, num_warmup_steps=warm_up_steps,\n",
@@ -351,7 +351,7 @@
 "    loss = loss_fct(shift_logits, shift_labels)\n",
 "\n",
 "    for k, v in pv_llama.interventions.items():\n",
-"        boundary_loss = 1.0 * v[0].intervention_boundaries.sum()\n",
+"        boundary_loss = 1.0 * v.intervention_boundaries.sum()\n",
 "        loss += boundary_loss\n",
 "\n",
 "    return loss"
@@ -481,7 +481,7 @@
 "source": [
 "torch.save(\n",
 "    pv_llama.interventions[\n",
-"        f\"layer.{layer}.comp.block_output.unit.pos.nunit.1#0\"][0].state_dict(), \n",
+"        f\"layer.{layer}.comp.block_output.unit.pos.nunit.1#0\"].state_dict(), \n",
 "    f\"./tutorial_data/layer.{layer}.pos.{token_position}.bin\"\n",
 ")"
 ]
@@ -522,9 +522,9 @@
 "pv_llama = pv.IntervenableModel(pv_config, llama)\n",
 "pv_llama.set_device(\"cuda\")\n",
 "pv_llama.disable_model_gradients()\n",
-"pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'][0].load_state_dict(\n",
+"pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'].load_state_dict(\n",
 "    torch.load('./tutorial_data/layer.15.pos.75.bin'))\n",
-"pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#1'][0].load_state_dict(\n",
+"pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#1'].load_state_dict(\n",
 "    torch.load('./tutorial_data/layer.15.pos.80.bin'))"
 ]
 },
@@ -665,11 +665,11 @@
 "for loc in [78, 75, 80, [75, 80]]:\n",
 "    if loc == 78:\n",
 "        print(\"[control] intervening location: \", loc)\n",
-"        pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'][0].load_state_dict(\n",
+"        pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'].load_state_dict(\n",
 "            torch.load('./tutorial_data/layer.15.pos.78.bin'))\n",
 "    else:\n",
 "        print(\"intervening location: \", loc)\n",
-"        pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'][0].load_state_dict(\n",
+"        pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'].load_state_dict(\n",
 "            torch.load('./tutorial_data/layer.15.pos.75.bin'))\n",
 "    # evaluation on the test set\n",
 "    collected_probs = []\n",
@@ -1382,9 +1382,9 @@
 "optimizer_params = []\n",
 "for k, v in pv_llama.interventions.items():\n",
 "    optimizer_params += [\n",
-"        {\"params\": v[0].rotate_layer.parameters()}]\n",
+"        {\"params\": v.rotate_layer.parameters()}]\n",
 "    optimizer_params += [\n",
-"        {\"params\": v[0].intervention_boundaries, \"lr\": 1e-2}]\n",
+"        {\"params\": v.intervention_boundaries, \"lr\": 1e-2}]\n",
 "optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)\n",
 "scheduler = get_linear_schedule_with_warmup(\n",
 "    optimizer, num_warmup_steps=warm_up_steps,\n",
@@ -1403,7 +1403,7 @@
 "    loss = loss_fct(shift_logits, shift_labels)\n",
 "\n",
 "    for k, v in pv_llama.interventions.items():\n",
-"        boundary_loss = 1.0 * v[0].intervention_boundaries.sum()\n",
+"        boundary_loss = 1.0 * v.intervention_boundaries.sum()\n",
 "        loss += boundary_loss\n",
 "\n",
 "    return loss"
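Besides the optimizer and loss hunks, this notebook saves each trained intervention's state_dict to disk and reloads it into a freshly built IntervenableModel, so rotations trained at one token position (75, 78, or 80) can be swapped in at evaluation time. A sketch of that round trip, using the key format and paths from the diff; pv_llama and layer are assumed from the notebook:

import torch

# Sketch of the save/reload round trip shown above; pv_llama and layer are
# assumed to be defined as in the notebook.
key = f"layer.{layer}.comp.block_output.unit.pos.nunit.1#0"

# save after training
torch.save(pv_llama.interventions[key].state_dict(),
           f"./tutorial_data/layer.{layer}.pos.75.bin")

# later: load into a freshly constructed IntervenableModel
pv_llama.interventions[key].load_state_dict(
    torch.load(f"./tutorial_data/layer.{layer}.pos.75.bin"))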
12 changes: 5 additions & 7 deletions tutorials/advanced_tutorials/tutorial_ioi_utils.py
@@ -519,7 +519,7 @@ def single_d_low_rank_das_position_config(
 def calculate_boundless_das_loss(logits, labels, intervenable):
     loss = calculate_loss(logits, labels)
     for k, v in intervenable.interventions.items():
-        boundary_loss = 2.0 * v[0].intervention_boundaries.sum()
+        boundary_loss = 2.0 * v.intervention_boundaries.sum()
         loss += boundary_loss
     return loss
 
@@ -658,9 +658,9 @@ def find_variable_at(
     if do_boundless_das:
         optimizer_params = []
         for k, v in intervenable.interventions.items():
-            optimizer_params += [{"params": v[0].rotate_layer.parameters()}]
+            optimizer_params += [{"params": v.rotate_layer.parameters()}]
             optimizer_params += [
-                {"params": v[0].intervention_boundaries, "lr": 0.5}
+                {"params": v.intervention_boundaries, "lr": 0.5}
             ]
         optimizer = torch.optim.Adam(optimizer_params, lr=initial_lr)
         target_total_step = int(len(D_train) / batch_size) * n_epochs
@@ -759,9 +759,7 @@ def find_variable_at(
                     temperature_schedule[total_step]
                 )
                 for k, v in intervenable.interventions.items():
-                    intervention_boundaries = v[
-                        0
-                    ].intervention_boundaries.sum()
+                    intervention_boundaries = v.intervention_boundaries.sum()
                 total_step += 1
 
         # eval
@@ -828,7 +826,7 @@ def find_variable_at(
 
     if do_boundless_das:
         for k, v in intervenable.interventions.items():
-            intervention_boundaries = v[0].intervention_boundaries.sum()
+            intervention_boundaries = v.intervention_boundaries.sum()
         data.append(
             {
                 "pos": aligning_pos,
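The last two hunks in this file only read the learned boundary width back out, once inside the training loop and once when logging results, so the summed width can be reported next to each position's alignment score. A sketch of that logging step; data and aligning_pos come from the surrounding find_variable_at, and the "boundary_width" field name is hypothetical:

# Sketch of the boundary-width logging above; data, aligning_pos, and the
# "boundary_width" field name are assumptions, not taken from the file.
for k, v in intervenable.interventions.items():
    intervention_boundaries = v.intervention_boundaries.sum()
data.append({
    "pos": aligning_pos,
    "boundary_width": intervention_boundaries.item(),
})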
@@ -167,7 +167,7 @@
 "    )\n",
 "    intervenable = IntervenableModel(config, gpt)\n",
 "    for k, v in intervenable.interventions.items():\n",
-"        v[0].set_interchange_dim(768)\n",
+"        v.set_interchange_dim(768)\n",
 "    for pos_i in range(len(base.input_ids[0])):\n",
 "        _, counterfactual_outputs = intervenable(\n",
 "            base,\n",
@@ -194,7 +194,7 @@
 "    )\n",
 "    intervenable = IntervenableModel(config, gpt)\n",
 "    for k, v in intervenable.interventions.items():\n",
-"        v[0].set_interchange_dim(768)\n",
+"        v.set_interchange_dim(768)\n",
 "    for pos_i in range(len(base.input_ids[0])):\n",
 "        _, counterfactual_outputs = intervenable(\n",
 "            base,\n",
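Both hunks in this last file (its name is truncated in the view above) configure vanilla interchange interventions before sweeping over token positions: each intervention's interchange dimension is set to 768, GPT-2's full hidden size, so entire representation vectors are swapped between the base and source runs. A sketch of that setup, assuming config and gpt are defined as in the notebook:

from pyvene import IntervenableModel

# Sketch of the setup above; config and gpt are assumed from the notebook.
# 768 is GPT-2's hidden size, so the whole hidden vector is interchanged.
intervenable = IntervenableModel(config, gpt)
for k, v in intervenable.interventions.items():
    v.set_interchange_dim(768)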
