diff --git a/test.ipynb b/test.ipynb new file mode 100644 index 000000000..3dd574b9b --- /dev/null +++ b/test.ipynb @@ -0,0 +1,203 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from aiida import load_profile\n", + "\n", + "_ = load_profile()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "from time import sleep\n", + "\n", + "import ipywidgets as ipw\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import traitlets as tl\n", + "\n", + "from aiida import orm\n", + "from aiidalab_qe.common.widgets import LoadingWidget\n", + "\n", + "\n", + "class Model(tl.HasTraits):\n", + " node_id = tl.Unicode()\n", + " node = tl.Instance(orm.Node, allow_none=True)\n", + "\n", + " bands = tl.Instance(\n", + " np.ndarray,\n", + " default_value=np.array([]),\n", + " )\n", + " kpoints = tl.Instance(\n", + " np.ndarray,\n", + " default_value=np.array([]),\n", + " )\n", + "\n", + " title = \"\"\n", + "\n", + " def load_node(self):\n", + " try:\n", + " self.node = orm.load_node(self.node_id)\n", + " except Exception as e:\n", + " print(f\"Failed to load node: {e}\")\n", + "\n", + "\n", + "class BandsViewer(ipw.VBox):\n", + " def __init__(self, vid, model: Model, **kwargs):\n", + " super().__init__(\n", + " children=[LoadingWidget(\"Loading data fetcher\")],\n", + " **kwargs,\n", + " )\n", + "\n", + " self._id = vid\n", + " self._model = model\n", + " self._model.observe(\n", + " self._on_node_uuid_change,\n", + " \"node_id\",\n", + " )\n", + "\n", + " self.rendered = False\n", + "\n", + " def render(self):\n", + " if self.rendered:\n", + " return\n", + "\n", + " self.bands = ipw.Label(value=\"Fetching attribute 1...\")\n", + " ipw.dlink(\n", + " (self._model, \"bands\"),\n", + " (self.bands, \"value\"),\n", + " lambda bands: f\"Bands count: {len(bands)}\",\n", + " )\n", + "\n", + " self.kpoints = ipw.Label(value=\"Fetching attribute 2...\")\n", + " ipw.dlink(\n", + " (self._model, \"kpoints\"),\n", + " (self.kpoints, \"value\"),\n", + " lambda kpoints: f\"Kpoints count: {len(kpoints)}\",\n", + " )\n", + "\n", + " self.plot_button = ipw.Button(description=\"Plot bands\")\n", + " ipw.dlink(\n", + " (self._model, \"node\"), (self.plot_button, \"disabled\"), lambda node: not node\n", + " )\n", + " self.plot_button.on_click(lambda _: asyncio.create_task(self.fetch_and_plot()))\n", + "\n", + " self.plot_area = ipw.Output()\n", + "\n", + " self.children = [\n", + " ipw.HBox(\n", + " children=[\n", + " ipw.Label(str(self._id)),\n", + " ipw.VBox(\n", + " children=[\n", + " self.bands,\n", + " self.kpoints,\n", + " self.plot_button,\n", + " ]\n", + " ),\n", + " ]\n", + " ),\n", + " self.plot_area,\n", + " ]\n", + "\n", + " self.rendered = True\n", + "\n", + " def _on_node_uuid_change(self, change):\n", + " if change[\"new\"]:\n", + " self.node = self._model.load_node()\n", + "\n", + " async def fetch_bands(self):\n", + " await asyncio.sleep(1)\n", + " node = self._model.node\n", + " bands = node.outputs.bands.bands.band_structure.get_array(\"bands\")\n", + " self._model.bands = bands\n", + "\n", + " async def fetch_kpoints(self):\n", + " await asyncio.sleep(2)\n", + " node = self._model.node\n", + " kpoints = node.outputs.bands.bands.band_structure.get_array(\"kpoints\")\n", + " self._model.kpoints = kpoints\n", + "\n", + " async def plot_bands(self):\n", + " await asyncio.sleep(0.5)\n", + " self.plot_area.clear_output()\n", + " with self.plot_area:\n", + " print(len(self._model.bands))\n", + " plt.title(self._model.title)\n", + " _ = plt.plot(self._model.bands)\n", + " plt.show()\n", + "\n", + " async def fetch_and_plot(self):\n", + " await asyncio.gather(\n", + " self.fetch_bands(),\n", + " self.fetch_kpoints(),\n", + " )\n", + " await self.plot_bands()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "models = []\n", + "for i in range(1, 5):\n", + " model = Model()\n", + " models.append(model)\n", + " model.title = f\"Band structure {i}\"\n", + " loader = BandsViewer(vid=i, model=model)\n", + " display(loader)\n", + " sleep(0.5)\n", + " loader.render()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for model in models:\n", + " model.node_id = \"5216\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}