Skip to content

Commit

Permalink
Add convergence check
Browse files Browse the repository at this point in the history
  • Loading branch information
mfschubert committed Nov 28, 2023
1 parent 1e22989 commit 4d4aed6
Showing 1 changed file with 58 additions and 6 deletions.
64 changes: 58 additions & 6 deletions notebooks/sorter_challenge.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"import invrs_opt\n",
"\n",
"from invrs_gym import challenges\n",
"from invrs_gym.utils import initializers"
"from invrs_gym import utils"
]
},
{
Expand All @@ -52,7 +52,7 @@
" )\n",
"\n",
"def density_initializer(key, seed_density):\n",
" density = initializers.noisy_density_initializer(\n",
" density = utils.initializers.noisy_density_initializer(\n",
" key=key,\n",
" seed_density=seed_density,\n",
" relative_mean=0.5,\n",
Expand Down Expand Up @@ -118,11 +118,11 @@
"plt.subplot(131)\n",
"plt.plot(loss_values)\n",
"ax = plt.subplot(132)\n",
"plt.imshow(common._density_array(initial_params[\"density_metasurface\"]))\n",
"plt.imshow(utils.transforms.rescaled_density_array(initial_params[\"density_metasurface\"], 0, 1))\n",
"ax.axis(False)\n",
"plt.colorbar()\n",
"ax = plt.subplot(133)\n",
"plt.imshow(common._density_array(params[\"density_metasurface\"]))\n",
"plt.imshow(utils.transforms.rescaled_density_array(params[\"density_metasurface\"], 0, 1))\n",
"ax.axis(False)\n",
"plt.colorbar()\n",
"\n",
Expand All @@ -141,7 +141,7 @@
"source": [
"# Plot the transmission into each of the four quadrants\n",
"\n",
"plt.figure(figsize=(10, 5))\n",
"plt.figure(figsize=(8, 3))\n",
"plt.subplot(121)\n",
"plt.imshow(response.transmission)\n",
"plt.clim([0, 0.5])\n",
Expand Down Expand Up @@ -170,7 +170,28 @@
"metadata": {},
"outputs": [],
"source": [
"print(metrics)"
"# Check for convergence by re-simulating the optimized structure for various\n",
"# expansions, i.e. with fewer and with more terms included.\n",
"\n",
"from fmmax import basis\n",
"\n",
"approximate_num_terms = [400, 800, 1200, 1600, 2000]\n",
"responses = []\n",
"for num in approximate_num_terms:\n",
" expansion = basis.generate_expansion(\n",
" primitive_lattice_vectors=basis.LatticeVectors(\n",
" u=basis.X * challenge.component.spec.pitch,\n",
" v=basis.Y * challenge.component.spec.pitch,\n",
" ),\n",
" approximate_num_terms=num,\n",
" truncation=basis.Truncation.CIRCULAR,\n",
" )\n",
" responses.append(\n",
" challenge.component.response(\n",
" params=params,\n",
" expansion=expansion,\n",
" )\n",
" )"
]
},
{
Expand All @@ -179,6 +200,37 @@
"id": "3719cf29-5f05-442c-bb31-43f23920fe53",
"metadata": {},
"outputs": [],
"source": [
"for num, (response, aux) in zip(approximate_num_terms, responses):\n",
" plt.figure(figsize=(8, 3))\n",
" plt.subplot(121)\n",
" plt.imshow(response.transmission)\n",
" plt.clim([0, 0.5])\n",
" plt.colorbar()\n",
" plt.title(f\"approximate_num_terms={num}\", fontsize=10)\n",
" \n",
" sz = aux[\"poynting_flux_z\"]\n",
" ax = plt.subplot(243)\n",
" ax.imshow(sz[..., 0])\n",
" ax.axis(False)\n",
" ax = plt.subplot(244)\n",
" ax.imshow(sz[..., 1])\n",
" ax.axis(False)\n",
" ax = plt.subplot(247)\n",
" ax.imshow(sz[..., 2])\n",
" ax.axis(False)\n",
" ax = plt.subplot(248)\n",
" ax.imshow(sz[..., 3])\n",
" ax.axis(False)\n",
" plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e92f9c86-ad60-4215-b3a1-71c7203e0058",
"metadata": {},
"outputs": [],
"source": []
}
],
Expand Down

0 comments on commit 4d4aed6

Please sign in to comment.