Skip to content

Commit

Permalink
Commit poc notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
edan-bainglass committed Nov 18, 2024
1 parent 4380a7b commit 6b0db88
Showing 1 changed file with 203 additions and 0 deletions.
203 changes: 203 additions & 0 deletions test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from aiida import load_profile\n",
"\n",
"_ = load_profile()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"from time import sleep\n",
"\n",
"import ipywidgets as ipw\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import traitlets as tl\n",
"\n",
"from aiida import orm\n",
"from aiidalab_qe.common.widgets import LoadingWidget\n",
"\n",
"\n",
"class Model(tl.HasTraits):\n",
" node_id = tl.Unicode()\n",
" node = tl.Instance(orm.Node, allow_none=True)\n",
"\n",
" bands = tl.Instance(\n",
" np.ndarray,\n",
" default_value=np.array([]),\n",
" )\n",
" kpoints = tl.Instance(\n",
" np.ndarray,\n",
" default_value=np.array([]),\n",
" )\n",
"\n",
" title = \"\"\n",
"\n",
" def load_node(self):\n",
" try:\n",
" self.node = orm.load_node(self.node_id)\n",
" except Exception as e:\n",
" print(f\"Failed to load node: {e}\")\n",
"\n",
"\n",
"class BandsViewer(ipw.VBox):\n",
" def __init__(self, vid, model: Model, **kwargs):\n",
" super().__init__(\n",
" children=[LoadingWidget(\"Loading data fetcher\")],\n",
" **kwargs,\n",
" )\n",
"\n",
" self._id = vid\n",
" self._model = model\n",
" self._model.observe(\n",
" self._on_node_uuid_change,\n",
" \"node_id\",\n",
" )\n",
"\n",
" self.rendered = False\n",
"\n",
" def render(self):\n",
" if self.rendered:\n",
" return\n",
"\n",
" self.bands = ipw.Label(value=\"Fetching attribute 1...\")\n",
" ipw.dlink(\n",
" (self._model, \"bands\"),\n",
" (self.bands, \"value\"),\n",
" lambda bands: f\"Bands count: {len(bands)}\",\n",
" )\n",
"\n",
" self.kpoints = ipw.Label(value=\"Fetching attribute 2...\")\n",
" ipw.dlink(\n",
" (self._model, \"kpoints\"),\n",
" (self.kpoints, \"value\"),\n",
" lambda kpoints: f\"Kpoints count: {len(kpoints)}\",\n",
" )\n",
"\n",
" self.plot_button = ipw.Button(description=\"Plot bands\")\n",
" ipw.dlink(\n",
" (self._model, \"node\"), (self.plot_button, \"disabled\"), lambda node: not node\n",
" )\n",
" self.plot_button.on_click(lambda _: asyncio.create_task(self.fetch_and_plot()))\n",
"\n",
" self.plot_area = ipw.Output()\n",
"\n",
" self.children = [\n",
" ipw.HBox(\n",
" children=[\n",
" ipw.Label(str(self._id)),\n",
" ipw.VBox(\n",
" children=[\n",
" self.bands,\n",
" self.kpoints,\n",
" self.plot_button,\n",
" ]\n",
" ),\n",
" ]\n",
" ),\n",
" self.plot_area,\n",
" ]\n",
"\n",
" self.rendered = True\n",
"\n",
" def _on_node_uuid_change(self, change):\n",
" if change[\"new\"]:\n",
" self.node = self._model.load_node()\n",
"\n",
" async def fetch_bands(self):\n",
" await asyncio.sleep(1)\n",
" node = self._model.node\n",
" bands = node.outputs.bands.bands.band_structure.get_array(\"bands\")\n",
" self._model.bands = bands\n",
"\n",
" async def fetch_kpoints(self):\n",
" await asyncio.sleep(2)\n",
" node = self._model.node\n",
" kpoints = node.outputs.bands.bands.band_structure.get_array(\"kpoints\")\n",
" self._model.kpoints = kpoints\n",
"\n",
" async def plot_bands(self):\n",
" await asyncio.sleep(0.5)\n",
" self.plot_area.clear_output()\n",
" with self.plot_area:\n",
" print(len(self._model.bands))\n",
" plt.title(self._model.title)\n",
" _ = plt.plot(self._model.bands)\n",
" plt.show()\n",
"\n",
" async def fetch_and_plot(self):\n",
" await asyncio.gather(\n",
" self.fetch_bands(),\n",
" self.fetch_kpoints(),\n",
" )\n",
" await self.plot_bands()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"models = []\n",
"for i in range(1, 5):\n",
" model = Model()\n",
" models.append(model)\n",
" model.title = f\"Band structure {i}\"\n",
" loader = BandsViewer(vid=i, model=model)\n",
" display(loader)\n",
" sleep(0.5)\n",
" loader.render()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for model in models:\n",
" model.node_id = \"5216\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 6b0db88

Please sign in to comment.