Skip to content

Commit

Permalink
Fix pyvene 101 notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
PinetreePantry committed Jan 30, 2025
1 parent 3c1aabe commit 8695a09
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pyvene_101.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1877,7 +1877,7 @@
" tokenizer(\"The capital of Italy is\", return_tensors=\"pt\"),\n",
"]\n",
"base_outputs, counterfactual_outputs = pv_gpt2(\n",
" base, sources, {\"sources->base\": ([[[3]]], [[[3]]])}\n",
" base, sources, {\"sources->base\": ([[[3]]], [[[3]]])}, output_original_output=True\n",
")\n",
"print(counterfactual_outputs.last_hidden_state - base_outputs.last_hidden_state)\n",
"# call backward will put gradients on model's weights\n",
Expand Down Expand Up @@ -2785,7 +2785,7 @@
" model=resnet\n",
")\n",
"intervened_outputs = pv_resnet(\n",
" base_inputs, [source_inputs], return_dict=True\n",
" base_inputs, [source_inputs], return_dict=True, output_original_output=True\n",
")\n",
"(intervened_outputs.intervened_outputs.logits - intervened_outputs.original_outputs.logits).sum()"
]
Expand Down Expand Up @@ -2842,7 +2842,7 @@
")\n",
"\n",
"intervened_outputs = pv_resnet(\n",
" base_inputs, [source_inputs], return_dict=True\n",
" base_inputs, [source_inputs], return_dict=True, output_original_output=True\n",
")\n",
"(intervened_outputs.intervened_outputs.logits - intervened_outputs.original_outputs.logits).sum()"
]
Expand Down

0 comments on commit 8695a09

Please sign in to comment.