From 43d743f3b356699910296513eee4bf4648a843cb Mon Sep 17 00:00:00 2001 From: Carlo Antonio Pignedoli Date: Sat, 22 Feb 2025 07:04:05 +0000 Subject: [PATCH 1/2] not using rdkit anymore --- nanoribbon/viewers/cdxml2gnr.py | 666 ++++++++++++++++++++------------ 1 file changed, 415 insertions(+), 251 deletions(-) diff --git a/nanoribbon/viewers/cdxml2gnr.py b/nanoribbon/viewers/cdxml2gnr.py index 4c2f869..945ed03 100644 --- a/nanoribbon/viewers/cdxml2gnr.py +++ b/nanoribbon/viewers/cdxml2gnr.py @@ -1,308 +1,472 @@ -"""Widget to convert CDXML to nanoribbons.""" +"""Widget to convert CDXML to planar structures""" +import xml.etree.ElementTree as ET +import nglview as nv import ase -import ase.neighborlist import ipywidgets as ipw -import nglview import numpy as np -import rdkit -import scipy -import sklearn.decomposition -import traitlets as tl -from IPython.display import clear_output +import traitlets as tr +from ase import Atoms +from ase.data import chemical_symbols, covalent_radii +from ase.neighborlist import NeighborList +from scipy.spatial.distance import pdist class CdxmlUpload2GnrWidget(ipw.VBox): - """Class that allows to upload structures from user's computer.""" + """Widget for uploading CDXML files and converting them to ASE Atoms structures.""" - structure = tl.Instance(ase.Atoms, allow_none=True) + structure = tr.Instance(ase.Atoms, allow_none=True) def __init__(self, title="CDXML to GNR", description="Upload Structure"): self.title = title - self.original_structure = None - self.selection = set() + + # File upload widget for .cdxml files self.file_upload = ipw.FileUpload( - description=description, multiple=False, layout={"width": "initial"} - ) - supported_formats = ipw.HTML( - """ - Supported structure formats: ".cdxml" - """ + description=description, + multiple=False, + layout={"width": "initial"}, + accept=".cdxml", ) - self.file_upload.observe(self._on_file_upload, names="value") - self.create_cell_btn = ipw.Button( - description="create GNR", button_style="info", disabled=True + # Additional widgets + self.nunits = ipw.Text(description="N units", value="Infinite", disabled=True) + self.create_button = ipw.Button( + description="Create model", + button_style="success", ) - self.create_cell_btn.on_click(self._on_cell_button_pressed) + self.create_button.on_click(self._on_button_click) - self.viewer = nglview.NGLWidget() - self.viewer.stage.set_parameters(mouse_preset="pymol") - self.viewer.observe(self._on_picked, names="picked") - self.allmols = ipw.Dropdown( - options=[None], description="Select mol", value=None, disabled=True + supported_formats = ipw.HTML( + """ + + Supported structure formats: ".cdxml" + + """ ) - self.allmols.observe(self._on_sketch_selected, names="value") - self.select_two = ipw.HTML("") - self.picked_out = ipw.Output() - self.cell_button_out = ipw.Output() + + # Output message widget + self.output_message = ipw.HTML(value="") + + # Initialize the widget layout super().__init__( children=[ self.file_upload, + self.nunits, supported_formats, - self.allmols, - self.select_two, - self.viewer, - self.picked_out, - self.create_cell_btn, - self.cell_button_out, + self.create_button, + self.output_message, ] ) - @staticmethod - def guess_scaling_factor(atoms): - """Scaling factor to correct the bond length.""" - - # Set bounding box as cell. - c_x = 1.5 * (np.amax(atoms.positions[:, 0]) - np.amin(atoms.positions[:, 0])) - c_y = 1.5 * (np.amax(atoms.positions[:, 1]) - np.amin(atoms.positions[:, 1])) - c_z = 15.0 - atoms.cell = (c_x, c_y, c_z) - atoms.pbc = (True, True, True) - - # Calculate all atom-atom distances. - c_atoms = [a for a in atoms if a.symbol[0] == "C"] - n_atoms = len(c_atoms) - dists = np.zeros([n_atoms, n_atoms]) - for i, atom_a in enumerate(c_atoms): - for j, atom_b in enumerate(c_atoms): - dists[i, j] = np.linalg.norm(atom_a.position - atom_b.position) - - # Find bond distances to closest neighbor. - dists += np.diag([np.inf] * n_atoms) # Don't consider diagonal. - bonds = np.amin(dists, axis=1) - - # Average bond distance. - avg_bond = float(scipy.stats.mode(bonds)[0]) - - # Scale box to match equilibrium carbon-carbon bond distance. - cc_eq = 1.4313333333 - return cc_eq / avg_bond + # Internal state + self.structure = None + self.crossing_points = None + self.cdxml_atoms = None + self.atoms = None @staticmethod - def scale(atoms, s): - """Scale atomic positions by the `factor`.""" - c_x, c_y, c_z = atoms.cell - atoms.set_cell((s * c_x, s * c_y, c_z), scale_atoms=True) - atoms.center() - return atoms - - @staticmethod - def rdkit2ase(mol): - """Converts rdkit molecule into ase Atoms""" - species = [ - ase.data.chemical_symbols[atm.GetAtomicNum()] for atm in mol.GetAtoms() - ] - pos = np.asarray(list(mol.GetConformer().GetPositions())) - pca = sklearn.decomposition.PCA(n_components=3) - posnew = pca.fit_transform(pos) - atoms = ase.Atoms(species, positions=posnew) - sys_size = np.ptp(atoms.positions, axis=0) - atoms.rotate(-90, "z") # cdxml are rotated - atoms.pbc = True - atoms.cell = sys_size + 10 - atoms.center() - - return atoms + def add_hydrogen_atoms(atoms: Atoms) -> tuple[str, Atoms]: + """Add missing hydrogen atoms to the Atoms object based on covalent radii.""" + message = "" - @staticmethod - def construct_cell(atoms, id1, id2): - """Construct periodic cell based on two selected equivalent atoms.""" + neighbor_list = NeighborList( + [covalent_radii[atom.number] for atom in atoms], + bothways=True, + self_interaction=False, + ) + neighbor_list.update(atoms) - pos = [ - [atoms[id2].x, atoms[id2].y], - [atoms[id1].x, atoms[id1].y], - [atoms[id2].x + int(1), atoms[id1].y], + need_hydrogen = [ + atom.index + for atom in atoms + if len(neighbor_list.get_neighbors(atom.index)[0]) < 3 + and atom.symbol in {"C", "N"} ] - vec = [np.array(pos[0]) - np.array(pos[1]), np.array(pos[2]) - np.array(pos[1])] - c_x = np.linalg.norm(vec[0]) + message = f"Added missing Hydrogen atoms: {need_hydrogen}." - angle = ( - np.math.atan2(np.linalg.det([vec[0], vec[1]]), np.dot(vec[0], vec[1])) - * 180.0 - / np.pi - ) - if np.abs(angle) > 0.01: - atoms.euler_rotate( - center=atoms[id1].position, phi=-angle, theta=0.0, psi=0.0 - ) + for index in need_hydrogen: + vec = np.zeros(3) + indices, offsets = neighbor_list.get_neighbors(atoms[index].index) + for i, offset in zip(indices, offsets): + vec += -atoms[index].position + ( + atoms.positions[i] + np.dot(offset, atoms.get_cell()) + ) + vec = -vec / np.linalg.norm(vec) * 1.1 + atoms[index].position + atoms.append(ase.Atom("H", vec)) - c_y = 15.0 + np.amax(atoms.positions[:, 1]) - np.amin(atoms.positions[:, 1]) - c_z = 15.0 + np.amax(atoms.positions[:, 2]) - np.amin(atoms.positions[:, 2]) + return message, atoms - atoms.cell = (c_x, c_y, c_z) - atoms.pbc = (True, True, True) - atoms.center() - atoms.wrap(eps=0.001) + def _on_file_upload(self, change=None): + """Handles the file upload event and converts CDXML to ASE Atoms.""" + self.nunits.value = "Infinite" + self.nunits.disabled = True + + uploaded_file = list(self.file_upload.value.values())[0] + cdxml_content = uploaded_file["content"].decode("utf-8") + try: + self.atoms = self.cdxml_to_ase_from_string(cdxml_content) + ( + self.crossing_points, + self.cdxml_atoms, + self.nunits.disabled, + ) = self.extract_crossing_and_atom_positions(cdxml_content) + except ValueError as exc: + self.output_message.value = f"Error: {exc}" + except Exception as exc: + self.output_message.value = f"Unexpected error: {exc}" + + # Clear the file upload widget + self.file_upload.value.clear() - # Remove redundant atoms. ORIGINAL - tobedel = [] + def _on_button_click(self, _=None): + """Handles the creation of the ASE model when 'Create model' button is clicked.""" + if not self.atoms: + self.output_message.value = "Error: No atoms available to process." + return - n_l = ase.neighborlist.NeighborList( - [ase.data.covalent_radii[a.number] for a in atoms], - bothways=False, - self_interaction=False, - ) - n_l.update(atoms) + atoms = self.atoms.copy() + - for atm in atoms: - indices, offsets = n_l.get_neighbors(atm.index) - for i, offset in zip(indices, offsets): - dist = np.linalg.norm( - atm.position - - (atoms.positions[i] + np.dot(offset, atoms.get_cell())) - ) - if dist < 0.4: - tobedel.append(atoms[i].index) + if self.crossing_points is not None: + crossing_points = self.transform_points( + self.cdxml_atoms, atoms.positions, self.crossing_points + ) + atoms = self.align_and_trim_atoms( + atoms, np.array(crossing_points), units=self.nunits.value + ) + else: + self.output_message.value = "Error: No 'crossing points' found." + return - del atoms[tobedel] + if self.nunits.disabled: + extra_cell = 15.0 + atoms.cell = (np.ptp(atoms.positions, axis=0)) + extra_cell + atoms.center() - # Find unit cell and apply it. + if self.nunits.value == "Infinite": + atoms.pbc = True - # Add Hydrogens. - n_l = ase.neighborlist.NeighborList( - [ase.data.covalent_radii[a.number] for a in atoms], - bothways=True, - self_interaction=False, - ) - n_l.update(atoms) + self.output_message.value, self.structure = self.add_hydrogen_atoms(atoms) + + @staticmethod + def cdxml_to_ase_from_string(cdxml_content: str, target_cc_distance: float = 1.43) -> ase.Atoms: + """ + Converts CDXML content provided as a string into an ASE Atoms object, + scaling coordinates so that the smallest C-C distance is target_cc_distance (default: 1.43 Å). + Atoms without an 'Element' attribute are considered Carbon ('C'). + Atoms with an 'Element' attribute use the periodic table symbol. + + Args: + cdxml_content (str): The content of the CDXML file as a string. + target_cc_distance (float): Desired minimum C-C distance (default: 1.43 Å). + + Returns: + Atoms: An ASE Atoms object with scaled coordinates. + """ + # Parse the CDXML content from the string + root = ET.fromstring(cdxml_content) + + # Extract atom data from 'n' elements + symbols = [] + positions = [] + + for atom in root.findall('.//n'): + # Determine the element symbol + if 'Element' in atom.attrib: + # Convert atomic number to element symbol using ASE's chemical_symbols + element_number = int(atom.get('Element')) + if element_number < len(chemical_symbols): + element = chemical_symbols[element_number] + else: + raise ValueError(f"Unknown atomic number {element_number} in CDXML content.") + else: + # Default to Carbon ('C') if no Element attribute is present + element = 'C' - need_hydrogen = [] - for atm in atoms: - if len(n_l.get_neighbors(atm.index)[0]) < 3: - if atm.symbol == "C" or atm.symbol == "N": - need_hydrogen.append(atm.index) + symbols.append(element) - print("Added missing Hydrogen atoms: ", need_hydrogen) + # Get 2D coordinates from 'p' attribute and assume z=0 + p = atom.get('p', '0 0').split() + x, y = float(p[0]), float(p[1]) + positions.append([x, y, 0.0]) - for atm in need_hydrogen: - vec = np.zeros(3) - indices, offsets = n_l.get_neighbors(atoms[atm].index) - for i, offset in zip(indices, offsets): - vec += -atoms[atm].position + ( - atoms.positions[i] + np.dot(offset, atoms.get_cell()) - ) - vec = -vec / np.linalg.norm(vec) * 1.1 + atoms[atm].position - atoms.append(ase.Atom("H", vec)) + if not symbols or not positions: + raise ValueError("No valid atoms found in the CDXML content.") - return atoms + # Convert positions to a numpy array + positions = np.array(positions) - def _on_picked(self, _=None): - """When an attom is picked.""" + # Find the smallest C-C distance + carbon_indices = [i for i, sym in enumerate(symbols) if sym == 'C'] + if len(carbon_indices) < 2: + raise ValueError("Not enough Carbon atoms to calculate C-C distance.") - if "atom1" not in self.viewer.picked.keys(): - return # did not click on atom - self.create_cell_btn.disabled = True + # Calculate pairwise distances between all Carbon atoms + carbon_positions = positions[carbon_indices] + cc_distances = pdist(carbon_positions) - with self.picked_out: - clear_output() + # Find the minimum C-C distance + min_cc_distance = np.min(cc_distances) - self.viewer.component_0.remove_ball_and_stick() - self.viewer.component_0.remove_ball_and_stick() - self.viewer.add_ball_and_stick() + # Scale coordinates to set the minimum C-C distance to target_cc_distance + scale_factor = target_cc_distance / min_cc_distance + positions *= scale_factor - idx = self.viewer.picked["atom1"]["index"] + # Create an ASE Atoms object with the scaled positions + ase_atoms = ase.Atoms(symbols=symbols, positions=positions) - # Toggle. - if idx in self.selection: - self.selection.remove(idx) - else: - self.selection.add(idx) + return ase_atoms + + @staticmethod + def transform_points(set1, set2, points): + """ + Transform a set of points based on the scaling and rotation that aligns set1 to set2. + + Args: + set1 (list of tuples): Reference set of points (e.g., [(x1, y1), ...]). + set2 (list of tuples): Transformed set of points (e.g., [(x2, y2), ...]). + points (list of tuples): Points to transform (e.g., [(px1, py1), ...]). + + Returns: + list of tuples: Transformed points. + """ + # set1 = np.array(set1) + # set2 = np.array(set2) + # points = np.array(points) + + # Compute centroids of set1 and set2 + centroid1 = np.mean(set1, axis=0) + centroid2 = np.mean(set2, axis=0) + + # Center the sets around their centroids + centered_set1 = set1 - centroid1 + centered_set2 = set2 - centroid2 + + # Compute the scaling factor + norm1 = np.linalg.norm(centered_set1, axis=1).mean() + norm2 = np.linalg.norm(centered_set2, axis=1).mean() + scale = norm2 / norm1 + + # Compute the rotation matrix using Singular Value Decomposition (SVD) + cross_covariance = np.dot(centered_set1.T, centered_set2) + u_matrix, _, vt_matrix = np.linalg.svd(cross_covariance) + rotation_m = np.dot(vt_matrix.T, u_matrix.T) # Rotation matrix + + # Apply the scaling and rotation to the points + transformed_points = ( + scale * np.dot(points - centroid1, rotation_m.T) + centroid2 + ) - if len(self.selection) == 2: - self.create_cell_btn.disabled = False + return transformed_points.tolist() + + @staticmethod + def max_extension_points(points): + """ + Given a list of points, checks whether the maximum extension is along the x-axis or y-axis, + and returns two points accordingly. + + Args: + points (list of tuple): List of points as (x, y, z) coordinates. + + Returns: + tuple: Two points as ((x1, y1, z1), (x2, y2, z2)) + """ + # Unpack x, y, and z coordinates + x_coords = [p[0] for p in points] + y_coords = [p[1] for p in points] + + # Calculate the range along x and y + x_range = max(x_coords) - min(x_coords) + y_range = max(y_coords) - min(y_coords) + + # Determine the points based on the largest range + if x_range >= y_range: + minx, maxx = min(x_coords), max(x_coords) + return np.array([[minx - 7.5, 0, 0], [maxx + 7.5, 0, 0]]) + else: + miny, maxy = min(y_coords), max(y_coords) + return np.array([[0, miny - 7.5, 0], [0, maxy + 7.5, 0]]) + + + + def extract_crossing_and_atom_positions(self, cdxml_content: str): + """ + Extract the first two crossing points such that the vector connecting them is aligned with the unit vector. + + Args: + cdxml_content (str): The content of the CDXML file as a string. + + Returns: + tuple: Three values: + - crossing_points_pair: A tuple of two crossing points that are aligned with the unit vector. + The unit_vector is a vector with positive x and y, parallel to the vector connecting the square brackets + - atom_positions: Atom positions as a numpy array of shape (M, 3). + - isnotperiodic (bool): Indicates whether the structure is non-periodic. + """ + root = ET.fromstring(cdxml_content) + + # Parse all atom positions + atom_positions = [] + atom_id_map = {} + for node in root.findall(".//n"): + atom_id = node.get("id") + if atom_id and "p" in node.attrib: + position = tuple(map(float, node.attrib["p"].split())) + atom_positions.append((position[0], position[1], 0.0)) # Add z=0 + atom_id_map[atom_id] = len(atom_positions) - 1 # Map atom ID to index + + atom_positions = np.array(atom_positions) + + # Parse crossing bonds and compute crossing points + crossing_points = [] + + for crossing in root.findall(".//crossingbond"): + bond_id = crossing.get("BondID") + + if bond_id: + bond = root.find(f".//b[@id='{bond_id}']") + if bond is not None: + start_id = bond.get("B") + end_id = bond.get("E") + if start_id in atom_id_map and end_id in atom_id_map: + start_pos = atom_positions[atom_id_map[start_id]] + end_pos = atom_positions[atom_id_map[end_id]] + + midpoint = ( + (start_pos[0] + end_pos[0]) / 2, + (start_pos[1] + end_pos[1]) / 2, + 0.0, # Add z=0 + ) + crossing_points.append(midpoint) + + crossing_points = np.array(crossing_points) + + # Parse square parentheses + brackets = [] + for graphic in root.findall(".//graphic[@BracketType='Square']"): + if "BoundingBox" in graphic.attrib: + bb = list(map(float, graphic.attrib["BoundingBox"].split())) + x_min, y_min, x_max, y_max = bb + + midpoint = ((x_min + x_max) / 2, (y_min + y_max) / 2, 0.0) + brackets.append(midpoint) + + isnotperiodic = True + if len(brackets) == 0: + twopoints = self.max_extension_points(atom_positions) + return twopoints, atom_positions, isnotperiodic + + if len(brackets) == 2: + brackets = np.array(brackets) + vector = brackets[1] - brackets[0] + unit_vector = vector[:2] / np.linalg.norm(vector[:2]) + + if unit_vector[0] < 0 or unit_vector[1] < 0: + unit_vector = -unit_vector + + for i in range(len(crossing_points)): + for j in range(i + 1, len(crossing_points)): + vector = crossing_points[j][:2] - crossing_points[i][:2] + unit_test_vector = vector / np.linalg.norm(vector) + if np.dot(unit_test_vector, unit_vector) > 0.99: + isnotperiodic = False + return crossing_points[[i, j]], atom_positions, isnotperiodic + + return None, atom_positions, True - # if(selection): - sel_str = ",".join([str(i) for i in sorted(self.selection)]) - print("Selected atoms: " + sel_str) - self.viewer.add_representation( - "ball+stick", selection="@" + sel_str, color="red", aspectRatio=3.0 - ) - self.viewer.picked = ( - {} - ) # reset, otherwise immidiately selecting same atom again won't create change event - def _on_file_upload(self, change=None): - """When file upload button is pressed.""" - self.create_cell_btn.disabled = True - - fname, item = next(iter(change["new"].items())) - frmt = fname.split(".")[-1] - if frmt == "cdxml": - options = [ - self.rdkit2ase(mol) - for mol in rdkit.Chem.MolsFromCDXML(item["content"].decode("ascii")) - ] - self.allmols.options = [ - (f"{i}: " + mol.get_chemical_formula(), mol) - for i, mol in enumerate(options) + @staticmethod + def align_and_trim_atoms(atoms, crossing_points, units=None): + """ + Aligns an ASE Atoms object with the x-axis based on two crossing points, + trims or replicates atoms based on specified x-bounds, sets a new unit cell, and centers the structure. + + Args: + atoms (ASE.Atoms): The ASE Atoms object to transform. + crossing_points (numpy.ndarray): A 2x3 NumPy array containing two crossing points. + n_units (int, optional): Number of units to replicate along the x-axis. If None, trims atoms. + + Returns: + ASE.Atoms: The transformed ASE Atoms object. + """ + # Ensure crossing_points is a 2x3 array + assert crossing_points.shape == ( + 2, + 3, + ), "crossing_points must be a 2x3 NumPy array." + + # Calculate the vector connecting the two crossing points + vector = crossing_points[1] - crossing_points[0] + norm_vector = np.linalg.norm(vector) + + # Calculate rotation angle to align the vector with the x-axis + angle = np.arctan2(vector[1], vector[0]) + + # Rotate the atoms and crossing points + atoms.rotate(-np.degrees(angle), "z", center=(0, 0, 0)) + rotation_matrix = np.array( + [ + [np.cos(-angle), -np.sin(-angle), 0], + [np.sin(-angle), np.cos(-angle), 0], + [0, 0, 1], ] - self.allmols.disabled = False - - def _on_sketch_selected(self, change=None): - self.structure = None # needed to empty view in second viewer - if self.allmols.value is None: - return - self.create_cell_btn.disabled = True - atoms = self.allmols.value - factor = self.guess_scaling_factor(atoms) - struct = self.scale(atoms, factor) - self.select_two.value = ( - "

Select two equivalent atoms that define the basis vector

" ) - self.original_structure = struct.copy() - - if hasattr(self.viewer, "component_0"): - self.viewer.component_0.remove_ball_and_stick() - cid = self.viewer.component_0.id - self.viewer.remove_component(cid) - - # Empty selection. - self.selection = set() + rotated_crossing_points = np.dot(crossing_points, rotation_matrix.T) + + # Define x-bounds based on crossing points + x_min = min(rotated_crossing_points[:, 0]) + x_max = max(rotated_crossing_points[:, 0]) + 0.1 + + # Extract positions and define the mask for atoms within bounds + positions = atoms.get_positions() + mask = (positions[:, 0] > x_min) & (positions[:, 0] <= x_max) + bounded_atoms = atoms[mask].copy() + mask = positions[:, 0] <= x_min + tail_atoms = atoms[mask].copy() + mask = positions[:, 0] > x_max + head_atoms = atoms[mask].copy() + try: + n_units = int(units) + except ValueError: + n_units = None + + if n_units is None or n_units < 1: + # Trim atoms based on x-bounds + atoms = bounded_atoms + else: + # Replicate atoms for n_units + replicated_atoms = bounded_atoms.copy() + for ni in range(1, n_units): + shifted_positions = bounded_atoms.get_positions() + np.array( + [ni * norm_vector, 0, 0] + ) + replicated_atoms += ase.Atoms( + bounded_atoms.get_chemical_symbols(), positions=shifted_positions + ) - # Add new component. - self.viewer.add_component(nglview.ASEStructure(struct)) # adds ball+stick - self.viewer.center() - self.viewer.handle_resize() - self.file_upload.value.clear() + # Add atoms shifted beyond xmax + shifted_positions = head_atoms.get_positions() + np.array( + [(n_units - 1) * norm_vector, 0, 0] + ) + replicated_atoms += ase.Atoms( + head_atoms.get_chemical_symbols(), positions=shifted_positions + ) + replicated_atoms += tail_atoms + atoms = replicated_atoms + + # Set the new unit cell + if n_units is None or n_units < 1: + l1 = norm_vector + atoms.set_periodic=True + else: + l1 = ( + np.ptp(atoms.get_positions()[:, 0]) + 15.0 + ) # Size in x-direction + 10 Å + l2 = 15.0 + np.ptp(atoms.get_positions()[:, 1]) # Size in y-direction + 15 Å + l3 = 15.0 # Fixed value + atoms.set_cell([l1, l2, l3]) + atoms.center() - def _on_cell_button_pressed(self, _=None): - """Generate GNR button pressed.""" - self.create_cell_btn.disabled = True - with self.cell_button_out: - clear_output() - if len(self.selection) != 2: - print("You must select exactly two atoms") - return - - id1 = sorted(self.selection)[0] - id2 = sorted(self.selection)[1] - incoming_struct = self.original_structure.copy() - self.structure = self.construct_cell(self.original_structure, id1, id2) - self.original_structure = incoming_struct.copy() - - if hasattr(self.viewer, "component_0"): - self.viewer.component_0.remove_ball_and_stick() - cid = self.viewer.component_0.id - self.viewer.remove_component(cid) - # Empty selection. - self.selection = set() - - # Add new component. - self.viewer.add_component( - nglview.ASEStructure(self.original_structure) - ) # adds ball+stick - self.viewer.center() - self.viewer.handle_resize() + return atoms From bc9b9530b85f29e95845505da8a344170f99e7f2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 22 Feb 2025 07:06:00 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nanoribbon/viewers/cdxml2gnr.py | 37 +++++++++++++++++---------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/nanoribbon/viewers/cdxml2gnr.py b/nanoribbon/viewers/cdxml2gnr.py index 945ed03..eab5b60 100644 --- a/nanoribbon/viewers/cdxml2gnr.py +++ b/nanoribbon/viewers/cdxml2gnr.py @@ -1,9 +1,10 @@ """Widget to convert CDXML to planar structures""" import xml.etree.ElementTree as ET -import nglview as nv + import ase import ipywidgets as ipw +import nglview as nv import numpy as np import traitlets as tr from ase import Atoms @@ -127,7 +128,6 @@ def _on_button_click(self, _=None): return atoms = self.atoms.copy() - if self.crossing_points is not None: crossing_points = self.transform_points( @@ -149,9 +149,11 @@ def _on_button_click(self, _=None): atoms.pbc = True self.output_message.value, self.structure = self.add_hydrogen_atoms(atoms) - + @staticmethod - def cdxml_to_ase_from_string(cdxml_content: str, target_cc_distance: float = 1.43) -> ase.Atoms: + def cdxml_to_ase_from_string( + cdxml_content: str, target_cc_distance: float = 1.43 + ) -> ase.Atoms: """ Converts CDXML content provided as a string into an ASE Atoms object, scaling coordinates so that the smallest C-C distance is target_cc_distance (default: 1.43 Å). @@ -172,23 +174,25 @@ def cdxml_to_ase_from_string(cdxml_content: str, target_cc_distance: float = 1.4 symbols = [] positions = [] - for atom in root.findall('.//n'): + for atom in root.findall(".//n"): # Determine the element symbol - if 'Element' in atom.attrib: + if "Element" in atom.attrib: # Convert atomic number to element symbol using ASE's chemical_symbols - element_number = int(atom.get('Element')) + element_number = int(atom.get("Element")) if element_number < len(chemical_symbols): element = chemical_symbols[element_number] else: - raise ValueError(f"Unknown atomic number {element_number} in CDXML content.") + raise ValueError( + f"Unknown atomic number {element_number} in CDXML content." + ) else: # Default to Carbon ('C') if no Element attribute is present - element = 'C' + element = "C" symbols.append(element) # Get 2D coordinates from 'p' attribute and assume z=0 - p = atom.get('p', '0 0').split() + p = atom.get("p", "0 0").split() x, y = float(p[0]), float(p[1]) positions.append([x, y, 0.0]) @@ -199,7 +203,7 @@ def cdxml_to_ase_from_string(cdxml_content: str, target_cc_distance: float = 1.4 positions = np.array(positions) # Find the smallest C-C distance - carbon_indices = [i for i, sym in enumerate(symbols) if sym == 'C'] + carbon_indices = [i for i, sym in enumerate(symbols) if sym == "C"] if len(carbon_indices) < 2: raise ValueError("Not enough Carbon atoms to calculate C-C distance.") @@ -217,8 +221,8 @@ def cdxml_to_ase_from_string(cdxml_content: str, target_cc_distance: float = 1.4 # Create an ASE Atoms object with the scaled positions ase_atoms = ase.Atoms(symbols=symbols, positions=positions) - return ase_atoms - + return ase_atoms + @staticmethod def transform_points(set1, set2, points): """ @@ -260,7 +264,7 @@ def transform_points(set1, set2, points): ) return transformed_points.tolist() - + @staticmethod def max_extension_points(points): """ @@ -289,8 +293,6 @@ def max_extension_points(points): miny, maxy = min(y_coords), max(y_coords) return np.array([[0, miny - 7.5, 0], [0, maxy + 7.5, 0]]) - - def extract_crossing_and_atom_positions(self, cdxml_content: str): """ Extract the first two crossing points such that the vector connecting them is aligned with the unit vector. @@ -376,7 +378,6 @@ def extract_crossing_and_atom_positions(self, cdxml_content: str): return None, atom_positions, True - @staticmethod def align_and_trim_atoms(atoms, crossing_points, units=None): """ @@ -459,7 +460,7 @@ def align_and_trim_atoms(atoms, crossing_points, units=None): # Set the new unit cell if n_units is None or n_units < 1: l1 = norm_vector - atoms.set_periodic=True + atoms.set_periodic = True else: l1 = ( np.ptp(atoms.get_positions()[:, 0]) + 15.0