From 1451ba82bac222c213a43b3c1dc6f042617c426f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?L=C3=A9on=20van=20Velzen?= <leonvanvelzen@protonmail.com>
Date: Mon, 30 Dec 2024 20:19:23 +0100
Subject: [PATCH] Throw better errors when Excitation is applied to wrong mesh
 types (lines/triangles)

---
 traceon/excitation.py | 53 +++++++++++++++++++++++++++++++++++++------
 1 file changed, 46 insertions(+), 7 deletions(-)

diff --git a/traceon/excitation.py b/traceon/excitation.py
index ea912ee3..c67a2578 100644
--- a/traceon/excitation.py
+++ b/traceon/excitation.py
@@ -19,6 +19,8 @@
 import numpy as np
 
 from .backend import N_QUAD_2D
+from .logging import log_error
+from . import excitation as E
 
 class Symmetry(IntEnum):
     """Symmetry to be used for solver. Used when deciding which formulas to use in the Boundary Element Method. The currently
@@ -93,6 +95,14 @@ def __str__(self):
         return f'<Traceon Excitation,\n\t' \
             + '\n\t'.join([f'{n}={v} ({t})' for n, (t, v) in self.excitation_types.items()]) \
             + '>'
+
+    def _ensure_electrode_is_lines(self, excitation_type, name):
+        assert name in self.electrodes, f"Electrode '{name}' is not present in the mesh"
+        assert name in self.mesh.physical_to_lines, f"Adding {excitation_type} excitation in {self.symmetry} symmetry is only supported if electrode '{name}' consists of lines"
+    
+    def _ensure_electrode_is_triangles(self, excitation_type, name):
+        assert name in self.electrodes, f"Electrode '{name}' is not present in the mesh"
+        assert name in self.mesh.physical_to_triangles, f"Adding {excitation_type} excitation in {self.symmetry} symmetry is only supported if electrode '{name}' consists of triangles"
      
     def add_voltage(self, **kwargs):
         """
@@ -108,7 +118,12 @@ def add_voltage(self, **kwargs):
         
         """
         for name, voltage in kwargs.items():
-            assert name in self.electrodes, f'Cannot add {name} to excitation, since it\'s not present in the mesh'
+             
+            if self.symmetry == E.Symmetry.RADIAL:
+                self._ensure_electrode_is_lines('voltage', name)
+            elif self.symmetry == E.Symmetry.THREE_D:
+                self._ensure_electrode_is_triangles('voltage', name)
+            
             if isinstance(voltage, int) or isinstance(voltage, float):
                 self.excitation_types[name] = (ExcitationType.VOLTAGE_FIXED, voltage)
             elif callable(voltage):
@@ -130,11 +145,11 @@ def add_current(self, **kwargs):
         """
         if self.symmetry == Symmetry.RADIAL:
             for name, current in kwargs.items():
-                assert name in self.mesh.physical_to_triangles.keys(), "Current should be applied to triangles in radial symmetry"
+                self._ensure_electrode_is_triangles("current", name)
                 self.excitation_types[name] = (ExcitationType.CURRENT, current)
         elif self.symmetry == Symmetry.THREE_D:
             for name, current in kwargs.items():
-                assert name in self.mesh.physical_to_lines.keys(), "Current should be applied to lines in 3D symmetry"
+                self._ensure_electrode_is_lines("current", name)
                 self.excitation_types[name] = (ExcitationType.CURRENT, current)
         else:
             raise ValueError('Symmetry should be one of RADIAL or THREE_D')
@@ -162,7 +177,11 @@ def add_magnetostatic_potential(self, **kwargs):
             calling the function as `add_magnetostatic_potential(lens=50)` assigns a 50A value to the geometry elements part of the 'lens' physical group.
         """
         for name, pot in kwargs.items():
-            assert name in self.electrodes, f'Cannot add {name} to excitation, since it\'s not present in the mesh'
+            if self.symmetry == E.Symmetry.RADIAL:
+                self._ensure_electrode_is_lines('magnetostatic potential', name)
+            elif self.symmetry == E.Symmetry.THREE_D:
+                self._ensure_electrode_is_triangles('magnetostatic potential', name)
+             
             self.excitation_types[name] = (ExcitationType.MAGNETOSTATIC_POT, pot)
 
     def add_magnetizable(self, **kwargs):
@@ -178,7 +197,11 @@ def add_magnetizable(self, **kwargs):
         """
 
         for name, permeability in kwargs.items():
-            assert name in self.electrodes, f'Cannot add {name} to excitation, since it\'s not present in the mesh'
+            if self.symmetry == E.Symmetry.RADIAL:
+                self._ensure_electrode_is_lines('magnetizable', name)
+            elif self.symmetry == E.Symmetry.THREE_D:
+                self._ensure_electrode_is_triangles('magnetizable', name)
+
             self.excitation_types[name] = (ExcitationType.MAGNETIZABLE, permeability)
      
     def add_dielectric(self, **kwargs):
@@ -193,7 +216,11 @@ def add_dielectric(self, **kwargs):
          
         """
         for name, permittivity in kwargs.items():
-            assert name in self.electrodes, f'Cannot add {name} to excitation, since it\'s not present in the mesh'
+            if self.symmetry == E.Symmetry.RADIAL:
+                self._ensure_electrode_is_lines('dielectric', name)
+            elif self.symmetry == E.Symmetry.THREE_D:
+                self._ensure_electrode_is_triangles('dielectric', name)
+
             self.excitation_types[name] = (ExcitationType.DIELECTRIC, permittivity)
 
     def add_electrostatic_boundary(self, *args, ensure_inward_normals=True):
@@ -212,6 +239,12 @@ def add_electrostatic_boundary(self, *args, ensure_inward_normals=True):
             for electrode in args:
                 self.mesh.ensure_inward_normals(electrode)
         
+        for name in args:
+            if self.symmetry == E.Symmetry.RADIAL:
+                self._ensure_electrode_is_lines('electrostatic boundary', name)
+            elif self.symmetry == E.Symmetry.THREE_D:
+                self._ensure_electrode_is_triangles('electrostatic boundary', name)
+
         self.add_dielectric(**{a:0 for a in args})
     
     def add_magnetostatic_boundary(self, *args, ensure_inward_normals=True):
@@ -230,7 +263,13 @@ def add_magnetostatic_boundary(self, *args, ensure_inward_normals=True):
             for electrode in args:
                 print('flipping normals', electrode)
                 self.mesh.ensure_inward_normals(electrode)
-        
+         
+        for name in args:
+            if self.symmetry == E.Symmetry.RADIAL:
+                self._ensure_electrode_is_lines('magnetostatic boundary', name)
+            elif self.symmetry == E.Symmetry.THREE_D:
+                self._ensure_electrode_is_triangles('magnetostatic boundary', name)
+         
         self.add_magnetizable(**{a:0 for a in args})
     
     def _split_for_superposition(self):