From 4d4aed6b019c3833af9e88430a906a9b6bfb0f55 Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Tue, 28 Nov 2023 08:59:04 -0800 Subject: [PATCH] Add convergence check --- notebooks/sorter_challenge.ipynb | 64 +++++++++++++++++++++++++++++--- 1 file changed, 58 insertions(+), 6 deletions(-) diff --git a/notebooks/sorter_challenge.ipynb b/notebooks/sorter_challenge.ipynb index 2a441c8..648785f 100644 --- a/notebooks/sorter_challenge.ipynb +++ b/notebooks/sorter_challenge.ipynb @@ -25,7 +25,7 @@ "import invrs_opt\n", "\n", "from invrs_gym import challenges\n", - "from invrs_gym.utils import initializers" + "from invrs_gym import utils" ] }, { @@ -52,7 +52,7 @@ " )\n", "\n", "def density_initializer(key, seed_density):\n", - " density = initializers.noisy_density_initializer(\n", + " density = utils.initializers.noisy_density_initializer(\n", " key=key,\n", " seed_density=seed_density,\n", " relative_mean=0.5,\n", @@ -118,11 +118,11 @@ "plt.subplot(131)\n", "plt.plot(loss_values)\n", "ax = plt.subplot(132)\n", - "plt.imshow(common._density_array(initial_params[\"density_metasurface\"]))\n", + "plt.imshow(utils.transforms.rescaled_density_array(initial_params[\"density_metasurface\"], 0, 1))\n", "ax.axis(False)\n", "plt.colorbar()\n", "ax = plt.subplot(133)\n", - "plt.imshow(common._density_array(params[\"density_metasurface\"]))\n", + "plt.imshow(utils.transforms.rescaled_density_array(params[\"density_metasurface\"], 0, 1))\n", "ax.axis(False)\n", "plt.colorbar()\n", "\n", @@ -141,7 +141,7 @@ "source": [ "# Plot the transmission into each of the four quadrants\n", "\n", - "plt.figure(figsize=(10, 5))\n", + "plt.figure(figsize=(8, 3))\n", "plt.subplot(121)\n", "plt.imshow(response.transmission)\n", "plt.clim([0, 0.5])\n", @@ -170,7 +170,28 @@ "metadata": {}, "outputs": [], "source": [ - "print(metrics)" + "# Check for convergence by re-simulating the optimized structure for various\n", + "# expansions, i.e. with fewer and with more terms included.\n", + "\n", + "from fmmax import basis\n", + "\n", + "approximate_num_terms = [400, 800, 1200, 1600, 2000]\n", + "responses = []\n", + "for num in approximate_num_terms:\n", + " expansion = basis.generate_expansion(\n", + " primitive_lattice_vectors=basis.LatticeVectors(\n", + " u=basis.X * challenge.component.spec.pitch,\n", + " v=basis.Y * challenge.component.spec.pitch,\n", + " ),\n", + " approximate_num_terms=num,\n", + " truncation=basis.Truncation.CIRCULAR,\n", + " )\n", + " responses.append(\n", + " challenge.component.response(\n", + " params=params,\n", + " expansion=expansion,\n", + " )\n", + " )" ] }, { @@ -179,6 +200,37 @@ "id": "3719cf29-5f05-442c-bb31-43f23920fe53", "metadata": {}, "outputs": [], + "source": [ + "for num, (response, aux) in zip(approximate_num_terms, responses):\n", + " plt.figure(figsize=(8, 3))\n", + " plt.subplot(121)\n", + " plt.imshow(response.transmission)\n", + " plt.clim([0, 0.5])\n", + " plt.colorbar()\n", + " plt.title(f\"approximate_num_terms={num}\", fontsize=10)\n", + " \n", + " sz = aux[\"poynting_flux_z\"]\n", + " ax = plt.subplot(243)\n", + " ax.imshow(sz[..., 0])\n", + " ax.axis(False)\n", + " ax = plt.subplot(244)\n", + " ax.imshow(sz[..., 1])\n", + " ax.axis(False)\n", + " ax = plt.subplot(247)\n", + " ax.imshow(sz[..., 2])\n", + " ax.axis(False)\n", + " ax = plt.subplot(248)\n", + " ax.imshow(sz[..., 3])\n", + " ax.axis(False)\n", + " plt.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e92f9c86-ad60-4215-b3a1-71c7203e0058", + "metadata": {}, + "outputs": [], "source": [] } ],