From b2aa098acb4738f09215b02409c46e8732582a17 Mon Sep 17 00:00:00 2001 From: frankaging Date: Sun, 2 Feb 2025 13:38:26 -0800 Subject: [PATCH] update all tutorials (no testing) --- .../advanced_tutorials/Boundless_DAS.ipynb | 6 ++--- .../DAS_Main_Introduction.ipynb | 2 +- .../advanced_tutorials/IOI_with_DAS.ipynb | 8 +++---- .../IOI_with_Mask_Intervention.ipynb | 6 ++--- tutorials/advanced_tutorials/MQNLI.ipynb | 2 +- .../advanced_tutorials/Probing_Gender.ipynb | 2 +- .../advanced_tutorials/Voting_Mechanism.ipynb | 22 +++++++++---------- .../advanced_tutorials/tutorial_ioi_utils.py | 12 +++++----- ...Subspace_Partition_with_Intervention.ipynb | 4 ++-- 9 files changed, 31 insertions(+), 33 deletions(-) diff --git a/tutorials/advanced_tutorials/Boundless_DAS.ipynb b/tutorials/advanced_tutorials/Boundless_DAS.ipynb index 7b40c536..7f6a84eb 100644 --- a/tutorials/advanced_tutorials/Boundless_DAS.ipynb +++ b/tutorials/advanced_tutorials/Boundless_DAS.ipynb @@ -422,8 +422,8 @@ "warm_up_steps = 0.1 * t_total\n", "optimizer_params = []\n", "for k, v in intervenable.interventions.items():\n", - " optimizer_params += [{\"params\": v[0].rotate_layer.parameters()}]\n", - " optimizer_params += [{\"params\": v[0].intervention_boundaries, \"lr\": 1e-2}]\n", + " optimizer_params += [{\"params\": v.rotate_layer.parameters()}]\n", + " optimizer_params += [{\"params\": v.intervention_boundaries, \"lr\": 1e-2}]\n", "optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)\n", "scheduler = get_linear_schedule_with_warmup(\n", " optimizer, num_warmup_steps=warm_up_steps, num_training_steps=t_total\n", @@ -470,7 +470,7 @@ " loss = loss_fct(shift_logits, shift_labels)\n", "\n", " for k, v in intervenable.interventions.items():\n", - " boundary_loss = 1.0 * v[0].intervention_boundaries.sum()\n", + " boundary_loss = 1.0 * v.intervention_boundaries.sum()\n", " loss += boundary_loss\n", "\n", " return loss" diff --git a/tutorials/advanced_tutorials/DAS_Main_Introduction.ipynb b/tutorials/advanced_tutorials/DAS_Main_Introduction.ipynb index 5115e93b..adb11ff5 100644 --- a/tutorials/advanced_tutorials/DAS_Main_Introduction.ipynb +++ b/tutorials/advanced_tutorials/DAS_Main_Introduction.ipynb @@ -1438,7 +1438,7 @@ "t_total = int(len(dataset) * epochs)\n", "optimizer_params = []\n", "for k, v in intervenable.interventions.items():\n", - " optimizer_params += [{\"params\": v[0].rotate_layer.parameters()}]\n", + " optimizer_params += [{\"params\": v.rotate_layer.parameters()}]\n", " break\n", "optimizer = torch.optim.Adam(optimizer_params, lr=0.001)\n", "\n", diff --git a/tutorials/advanced_tutorials/IOI_with_DAS.ipynb b/tutorials/advanced_tutorials/IOI_with_DAS.ipynb index d5b39266..c216e354 100644 --- a/tutorials/advanced_tutorials/IOI_with_DAS.ipynb +++ b/tutorials/advanced_tutorials/IOI_with_DAS.ipynb @@ -7302,7 +7302,7 @@ "source": [ "intervention = boundless_das_intervenable.interventions[\n", " \"layer.8.repr.attention_value_output.unit.pos.nunit.1#0\"\n", - "][0]\n", + "]\n", "boundary_mask = sigmoid_boundary(\n", " intervention.intervention_population.repeat(1, 1),\n", " 0.0,\n", @@ -12476,7 +12476,7 @@ "source": [ "intervention = das_intervenable.interventions[\n", " \"layer.8.repr.attention_value_output.unit.pos.nunit.1#0\"\n", - "][0]\n", + "]\n", "learned_weights = intervention.rotate_layer.weight\n", "headwise_learned_weights = torch.chunk(learned_weights, chunks=12, dim=0)\n", "\n", @@ -17401,7 +17401,7 @@ "source": [ "intervention = boundless_das_intervenable.interventions[\n", " \"layer.9.repr.attention_value_output.unit.pos.nunit.1#0\"\n", - "][0]\n", + "]\n", "boundary_mask = sigmoid_boundary(\n", " intervention.intervention_population.repeat(1, 1),\n", " 0.0,\n", @@ -23344,7 +23344,7 @@ "source": [ "intervention = das_intervenable.interventions[\n", " \"layer.9.repr.attention_value_output.unit.pos.nunit.1#0\"\n", - "][0]\n", + "]\n", "learned_weights = intervention.rotate_layer.weight\n", "headwise_learned_weights = torch.chunk(learned_weights, chunks=12, dim=0)\n", "\n", diff --git a/tutorials/advanced_tutorials/IOI_with_Mask_Intervention.ipynb b/tutorials/advanced_tutorials/IOI_with_Mask_Intervention.ipynb index 53d727ca..aa7736c5 100644 --- a/tutorials/advanced_tutorials/IOI_with_Mask_Intervention.ipynb +++ b/tutorials/advanced_tutorials/IOI_with_Mask_Intervention.ipynb @@ -256,7 +256,7 @@ "def calculate_loss_with_mask(logits, labels, intervenable, coeff=1):\n", " loss = calculate_loss(logits, labels)\n", " for k, v in intervenable.interventions.items():\n", - " mask_loss = coeff * torch.norm(v[0].mask, 1)\n", + " mask_loss = coeff * torch.norm(v.mask, 1)\n", " loss += mask_loss\n", " return loss\n", "\n", @@ -363,8 +363,8 @@ " eval_preds += [counterfactual_outputs.logits]\n", "eval_metrics = compute_metrics(eval_preds, eval_labels)\n", "for k, v in pv_gpt2.interventions.items():\n", - " mask = v[0].mask\n", - " temperature = v[0].temperature\n", + " mask = v.mask\n", + " temperature = v.temperature\n", " break\n", "print(eval_metrics)\n", "print(\n", diff --git a/tutorials/advanced_tutorials/MQNLI.ipynb b/tutorials/advanced_tutorials/MQNLI.ipynb index fd83a5aa..ba6980ab 100644 --- a/tutorials/advanced_tutorials/MQNLI.ipynb +++ b/tutorials/advanced_tutorials/MQNLI.ipynb @@ -1402,7 +1402,7 @@ "\n", "optimizer_params = []\n", "for k, v in intervenable.interventions.items():\n", - " optimizer_params += [{\"params\": v[0].rotate_layer.parameters()}]\n", + " optimizer_params += [{\"params\": v.rotate_layer.parameters()}]\n", " break\n", "optimizer = torch.optim.Adam(optimizer_params, lr=0.001)\n", "\n", diff --git a/tutorials/advanced_tutorials/Probing_Gender.ipynb b/tutorials/advanced_tutorials/Probing_Gender.ipynb index 0bd32733..1a741ded 100644 --- a/tutorials/advanced_tutorials/Probing_Gender.ipynb +++ b/tutorials/advanced_tutorials/Probing_Gender.ipynb @@ -905,7 +905,7 @@ " optimizer_params = []\n", " for k, v in intervenable.interventions.items():\n", " try:\n", - " optimizer_params.append({\"params\": v[0].rotate_layer.parameters()})\n", + " optimizer_params.append({\"params\": v.rotate_layer.parameters()})\n", " except:\n", " pass\n", " optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)\n", diff --git a/tutorials/advanced_tutorials/Voting_Mechanism.ipynb b/tutorials/advanced_tutorials/Voting_Mechanism.ipynb index 5553f34c..24204d49 100644 --- a/tutorials/advanced_tutorials/Voting_Mechanism.ipynb +++ b/tutorials/advanced_tutorials/Voting_Mechanism.ipynb @@ -330,9 +330,9 @@ "optimizer_params = []\n", "for k, v in pv_llama.interventions.items():\n", " optimizer_params += [\n", - " {\"params\": v[0].rotate_layer.parameters()}]\n", + " {\"params\": v.rotate_layer.parameters()}]\n", " optimizer_params += [\n", - " {\"params\": v[0].intervention_boundaries, \"lr\": 1e-2}]\n", + " {\"params\": v.intervention_boundaries, \"lr\": 1e-2}]\n", "optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)\n", "scheduler = get_linear_schedule_with_warmup(\n", " optimizer, num_warmup_steps=warm_up_steps,\n", @@ -351,7 +351,7 @@ " loss = loss_fct(shift_logits, shift_labels)\n", "\n", " for k, v in pv_llama.interventions.items():\n", - " boundary_loss = 1.0 * v[0].intervention_boundaries.sum()\n", + " boundary_loss = 1.0 * v.intervention_boundaries.sum()\n", " loss += boundary_loss\n", "\n", " return loss" @@ -481,7 +481,7 @@ "source": [ "torch.save(\n", " pv_llama.interventions[\n", - " f\"layer.{layer}.comp.block_output.unit.pos.nunit.1#0\"][0].state_dict(), \n", + " f\"layer.{layer}.comp.block_output.unit.pos.nunit.1#0\"].state_dict(), \n", " f\"./tutorial_data/layer.{layer}.pos.{token_position}.bin\"\n", ")" ] @@ -522,9 +522,9 @@ "pv_llama = pv.IntervenableModel(pv_config, llama)\n", "pv_llama.set_device(\"cuda\")\n", "pv_llama.disable_model_gradients()\n", - "pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'][0].load_state_dict(\n", + "pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'].load_state_dict(\n", " torch.load('./tutorial_data/layer.15.pos.75.bin'))\n", - "pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#1'][0].load_state_dict(\n", + "pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#1'].load_state_dict(\n", " torch.load('./tutorial_data/layer.15.pos.80.bin'))" ] }, @@ -665,11 +665,11 @@ "for loc in [78, 75, 80, [75, 80]]:\n", " if loc == 78:\n", " print(\"[control] intervening location: \", loc)\n", - " pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'][0].load_state_dict(\n", + " pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'].load_state_dict(\n", " torch.load('./tutorial_data/layer.15.pos.78.bin'))\n", " else:\n", " print(\"intervening location: \", loc)\n", - " pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'][0].load_state_dict(\n", + " pv_llama.interventions[f'layer.{layer}.comp.block_output.unit.pos.nunit.1#0'].load_state_dict(\n", " torch.load('./tutorial_data/layer.15.pos.75.bin'))\n", " # evaluation on the test set\n", " collected_probs = []\n", @@ -1382,9 +1382,9 @@ "optimizer_params = []\n", "for k, v in pv_llama.interventions.items():\n", " optimizer_params += [\n", - " {\"params\": v[0].rotate_layer.parameters()}]\n", + " {\"params\": v.rotate_layer.parameters()}]\n", " optimizer_params += [\n", - " {\"params\": v[0].intervention_boundaries, \"lr\": 1e-2}]\n", + " {\"params\": v.intervention_boundaries, \"lr\": 1e-2}]\n", "optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)\n", "scheduler = get_linear_schedule_with_warmup(\n", " optimizer, num_warmup_steps=warm_up_steps,\n", @@ -1403,7 +1403,7 @@ " loss = loss_fct(shift_logits, shift_labels)\n", "\n", " for k, v in pv_llama.interventions.items():\n", - " boundary_loss = 1.0 * v[0].intervention_boundaries.sum()\n", + " boundary_loss = 1.0 * v.intervention_boundaries.sum()\n", " loss += boundary_loss\n", "\n", " return loss" diff --git a/tutorials/advanced_tutorials/tutorial_ioi_utils.py b/tutorials/advanced_tutorials/tutorial_ioi_utils.py index cb81320b..2da1aed5 100644 --- a/tutorials/advanced_tutorials/tutorial_ioi_utils.py +++ b/tutorials/advanced_tutorials/tutorial_ioi_utils.py @@ -519,7 +519,7 @@ def single_d_low_rank_das_position_config( def calculate_boundless_das_loss(logits, labels, intervenable): loss = calculate_loss(logits, labels) for k, v in intervenable.interventions.items(): - boundary_loss = 2.0 * v[0].intervention_boundaries.sum() + boundary_loss = 2.0 * v.intervention_boundaries.sum() loss += boundary_loss return loss @@ -658,9 +658,9 @@ def find_variable_at( if do_boundless_das: optimizer_params = [] for k, v in intervenable.interventions.items(): - optimizer_params += [{"params": v[0].rotate_layer.parameters()}] + optimizer_params += [{"params": v.rotate_layer.parameters()}] optimizer_params += [ - {"params": v[0].intervention_boundaries, "lr": 0.5} + {"params": v.intervention_boundaries, "lr": 0.5} ] optimizer = torch.optim.Adam(optimizer_params, lr=initial_lr) target_total_step = int(len(D_train) / batch_size) * n_epochs @@ -759,9 +759,7 @@ def find_variable_at( temperature_schedule[total_step] ) for k, v in intervenable.interventions.items(): - intervention_boundaries = v[ - 0 - ].intervention_boundaries.sum() + intervention_boundaries = v.intervention_boundaries.sum() total_step += 1 # eval @@ -828,7 +826,7 @@ def find_variable_at( if do_boundless_das: for k, v in intervenable.interventions.items(): - intervention_boundaries = v[0].intervention_boundaries.sum() + intervention_boundaries = v.intervention_boundaries.sum() data.append( { "pos": aligning_pos, diff --git a/tutorials/basic_tutorials/Subspace_Partition_with_Intervention.ipynb b/tutorials/basic_tutorials/Subspace_Partition_with_Intervention.ipynb index 8f950656..2d1c4820 100644 --- a/tutorials/basic_tutorials/Subspace_Partition_with_Intervention.ipynb +++ b/tutorials/basic_tutorials/Subspace_Partition_with_Intervention.ipynb @@ -167,7 +167,7 @@ " )\n", " intervenable = IntervenableModel(config, gpt)\n", " for k, v in intervenable.interventions.items():\n", - " v[0].set_interchange_dim(768)\n", + " v.set_interchange_dim(768)\n", " for pos_i in range(len(base.input_ids[0])):\n", " _, counterfactual_outputs = intervenable(\n", " base,\n", @@ -194,7 +194,7 @@ " )\n", " intervenable = IntervenableModel(config, gpt)\n", " for k, v in intervenable.interventions.items():\n", - " v[0].set_interchange_dim(768)\n", + " v.set_interchange_dim(768)\n", " for pos_i in range(len(base.input_ids[0])):\n", " _, counterfactual_outputs = intervenable(\n", " base,\n",