Skip to content

Commit

Permalink
runs ruff, pre commit on files
Browse files Browse the repository at this point in the history
  • Loading branch information
billbrod committed Dec 12, 2024
1 parent 873e513 commit e2fbd7e
Show file tree
Hide file tree
Showing 2 changed files with 476 additions and 235 deletions.
101 changes: 66 additions & 35 deletions examples/pooled_texture_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,28 +28,30 @@
}
],
"source": [
"import plenoptic as po\n",
"import torch\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import einops\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"\n",
"import plenoptic as po\n",
"\n",
"# so that relative sizes of axes created by po.imshow and others look right\n",
"plt.rcParams['figure.dpi'] = 72\n",
"plt.rcParams[\"figure.dpi\"] = 72\n",
"\n",
"# set seed for reproducibility\n",
"po.tools.set_seed(1)\n",
"\n",
"# use GPU if available\n",
"DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"%matplotlib inline\n",
"\n",
"\n",
"# Animation-related settings\n",
"plt.rcParams['animation.html'] = 'html5'\n",
"plt.rcParams[\"animation.html\"] = \"html5\"\n",
"# use single-threaded ffmpeg for animation writer\n",
"plt.rcParams['animation.writer'] = 'ffmpeg'\n",
"plt.rcParams['animation.ffmpeg_args'] = ['-threads', '1']"
"plt.rcParams[\"animation.writer\"] = \"ffmpeg\"\n",
"plt.rcParams[\"animation.ffmpeg_args\"] = [\"-threads\", \"1\"]"
]
},
{
Expand Down Expand Up @@ -111,9 +113,9 @@
" mask_sz = int(256 // len(mask))\n",
" for j, m in enumerate(mask):\n",
" if i == 0:\n",
" m[..., j*mask_sz:(j+1)*mask_sz, :] = 1\n",
" m[..., j * mask_sz : (j + 1) * mask_sz, :] = 1\n",
" else:\n",
" m[..., j*mask_sz:(j+1)*mask_sz] = 1\n",
" m[..., j * mask_sz : (j + 1) * mask_sz] = 1\n",
" masks.append(torch.cat(mask, 0))\n",
"print([m.shape for m in masks])"
]
Expand Down Expand Up @@ -144,8 +146,8 @@
}
],
"source": [
"masked = einops.einsum(*masks, img, 'm0 h w, m1 h w, b c h w -> b c m0 m1 h w')\n",
"po.imshow(einops.rearrange(masked, 'b c m0 m1 h w -> b (c m0 m1) h w'));"
"masked = einops.einsum(*masks, img, \"m0 h w, m1 h w, b c h w -> b c m0 m1 h w\")\n",
"po.imshow(einops.rearrange(masked, \"b c m0 m1 h w -> b (c m0 m1) h w\"));"
]
},
{
Expand Down Expand Up @@ -179,7 +181,7 @@
"for i, m in enumerate(mask):\n",
" x = int(i // np.sqrt(len(mask)))\n",
" y = int(i % np.sqrt(len(mask)))\n",
" m[..., x*mask_sz:(x+1)*mask_sz, y*mask_sz:(y+1)*mask_sz] = 1\n",
" m[..., x * mask_sz : (x + 1) * mask_sz, y * mask_sz : (y + 1) * mask_sz] = 1\n",
"single_masks = torch.cat(mask, 0)\n",
"single_masks.shape"
]
Expand All @@ -202,8 +204,8 @@
}
],
"source": [
"masked = einops.einsum(single_masks, img, 'm0 h w, b c h w -> b c m0 h w')\n",
"po.imshow(einops.rearrange(masked, 'b c m0 h w -> b (c m0) h w'));"
"masked = einops.einsum(single_masks, img, \"m0 h w, b c h w -> b c m0 h w\")\n",
"po.imshow(einops.rearrange(masked, \"b c m0 h w -> b (c m0) h w\"));"
]
},
{
Expand Down Expand Up @@ -237,7 +239,7 @@
}
],
"source": [
"einops.einsum(*masks, 'm0 h w, m1 h w -> m0 m1')"
"einops.einsum(*masks, \"m0 h w, m1 h w -> m0 m1\")"
]
},
{
Expand Down Expand Up @@ -269,8 +271,8 @@
}
],
"source": [
"masks = [m/64 for m in masks]\n",
"einops.einsum(*masks, 'm0 h w, m1 h w -> m0 m1')"
"masks = [m / 64 for m in masks]\n",
"einops.einsum(*masks, \"m0 h w, m1 h w -> m0 m1\")"
]
},
{
Expand Down Expand Up @@ -444,7 +446,11 @@
}
],
"source": [
"met = po.synth.MetamerCTF(img, ps_mask, loss_function=po.tools.optim.l2_norm,)"
"met = po.synth.MetamerCTF(\n",
" img,\n",
" ps_mask,\n",
" loss_function=po.tools.optim.l2_norm,\n",
")"
]
},
{
Expand All @@ -470,7 +476,9 @@
}
],
"source": [
"met.synthesize(max_iter=500, store_progress=10, change_scale_criterion=None, ctf_iters_to_check=7)"
"met.synthesize(\n",
" max_iter=500, store_progress=10, change_scale_criterion=None, ctf_iters_to_check=7\n",
")"
]
},
{
Expand Down Expand Up @@ -507,7 +515,9 @@
}
],
"source": [
"po.synth.metamer.plot_synthesis_status(met, width_ratios={'plot_representation_error': 2});"
"po.synth.metamer.plot_synthesis_status(\n",
" met, width_ratios={\"plot_representation_error\": 2}\n",
");"
]
},
{
Expand Down Expand Up @@ -7428,7 +7438,7 @@
}
],
"source": [
"po.synth.metamer.animate(met, width_ratios={'plot_representation_error': 2})"
"po.synth.metamer.animate(met, width_ratios={\"plot_representation_error\": 2})"
]
},
{
Expand Down Expand Up @@ -7458,7 +7468,11 @@
],
"source": [
"img = po.data.einstein().to(DEVICE)\n",
"met = po.synth.MetamerCTF(img, ps_mask, loss_function=po.tools.optim.l2_norm,)"
"met = po.synth.MetamerCTF(\n",
" img,\n",
" ps_mask,\n",
" loss_function=po.tools.optim.l2_norm,\n",
")"
]
},
{
Expand All @@ -7476,7 +7490,9 @@
}
],
"source": [
"met.synthesize(max_iter=500, store_progress=10, change_scale_criterion=None, ctf_iters_to_check=7)"
"met.synthesize(\n",
" max_iter=500, store_progress=10, change_scale_criterion=None, ctf_iters_to_check=7\n",
")"
]
},
{
Expand Down Expand Up @@ -7505,7 +7521,9 @@
}
],
"source": [
"po.synth.metamer.plot_synthesis_status(met, width_ratios={'plot_representation_error': 2});"
"po.synth.metamer.plot_synthesis_status(\n",
" met, width_ratios={\"plot_representation_error\": 2}\n",
");"
]
},
{
Expand All @@ -7528,7 +7546,8 @@
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('../../pooling-windows/')\n",
"\n",
"sys.path.append(\"../../pooling-windows/\")\n",
"import pooling"
]
},
Expand Down Expand Up @@ -7556,7 +7575,7 @@
}
],
"source": [
"pw = pooling.PoolingWindows(.5, img.shape[-2:], max_eccentricity=5)\n",
"pw = pooling.PoolingWindows(0.5, img.shape[-2:], max_eccentricity=5)\n",
"pooling_windows = [pw.ecc_windows[0], pw.angle_windows[0]]"
]
},
Expand Down Expand Up @@ -7617,7 +7636,9 @@
}
],
"source": [
"plt.stem(po.to_numpy(einops.einsum(*pooling_windows, 'm1 h w, m2 h w -> m1 m2').flatten()));"
"plt.stem(\n",
" po.to_numpy(einops.einsum(*pooling_windows, \"m1 h w, m2 h w -> m1 m2\").flatten())\n",
");"
]
},
{
Expand Down Expand Up @@ -7695,7 +7716,9 @@
],
"source": [
"pooling_windows = [pw.ecc_windows[0][:-1], pw.angle_windows[0]]\n",
"plt.stem(po.to_numpy(einops.einsum(*pooling_windows, 'm1 h w, m2 h w -> m1 m2').flatten()));"
"plt.stem(\n",
" po.to_numpy(einops.einsum(*pooling_windows, \"m1 h w, m2 h w -> m1 m2\").flatten())\n",
");"
]
},
{
Expand Down Expand Up @@ -7723,7 +7746,11 @@
"metadata": {},
"outputs": [],
"source": [
"met = po.synth.MetamerCTF(img, ps_mask, loss_function=po.tools.optim.l2_norm,)"
"met = po.synth.MetamerCTF(\n",
" img,\n",
" ps_mask,\n",
" loss_function=po.tools.optim.l2_norm,\n",
")"
]
},
{
Expand All @@ -7741,7 +7768,9 @@
}
],
"source": [
"met.synthesize(max_iter=500, store_progress=10, change_scale_criterion=None, ctf_iters_to_check=7)"
"met.synthesize(\n",
" max_iter=500, store_progress=10, change_scale_criterion=None, ctf_iters_to_check=7\n",
")"
]
},
{
Expand All @@ -7762,7 +7791,9 @@
}
],
"source": [
"po.synth.metamer.plot_synthesis_status(met, width_ratios={'plot_representation_error': 2});"
"po.synth.metamer.plot_synthesis_status(\n",
" met, width_ratios={\"plot_representation_error\": 2}\n",
");"
]
},
{
Expand Down Expand Up @@ -14163,7 +14194,7 @@
}
],
"source": [
"po.synth.metamer.animate(met, width_ratios={'plot_representation_error': 2})"
"po.synth.metamer.animate(met, width_ratios={\"plot_representation_error\": 2})"
]
},
{
Expand Down Expand Up @@ -14192,7 +14223,7 @@
}
],
"source": [
"fig = po.imshow([img, met.metamer], title=['Target image', 'Metamer']);\n",
"fig = po.imshow([img, met.metamer], title=[\"Target image\", \"Metamer\"])\n",
"for ax in fig.axes:\n",
" pw.plot_windows(ax, subset=False)"
]
Expand Down
Loading

0 comments on commit e2fbd7e

Please sign in to comment.