Skip to content

Commit

Permalink
Program to refactor topology.py Issue #788
Browse files Browse the repository at this point in the history
  • Loading branch information
gumyr committed Nov 16, 2024
1 parent e1da165 commit e06337a
Showing 1 changed file with 291 additions and 0 deletions.
291 changes: 291 additions & 0 deletions tools/refactor_topo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
from pathlib import Path
import libcst as cst
from typing import List, Set, Dict, Union
from pprint import pprint
from rope.base.project import Project
from rope.refactor.importutils import ImportOrganizer


class ImportCollector(cst.CSTVisitor):
def __init__(self):
self.imports: Set[str] = set()

def visit_Import(self, node: cst.Import) -> None:
for name in node.names:
# Create a proper statement line
stmt = cst.SimpleStatementLine([node])
self.imports.add(cst.Module([stmt]).code)

def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
# Create a proper statement line
stmt = cst.SimpleStatementLine([node])
self.imports.add(cst.Module([stmt]).code)


class ClassExtractor(cst.CSTVisitor):
def __init__(self, class_names_to_extract: List[str]):
self.class_names = class_names_to_extract
self.extracted_classes: Dict[str, cst.ClassDef] = {}

def visit_ClassDef(self, node: cst.ClassDef) -> None:
if node.name.value in self.class_names:
self.extracted_classes[node.name.value] = node


class MixinClassExtractor(cst.CSTVisitor):
def __init__(self):
self.extracted_classes: Dict[str, cst.ClassDef] = {}

def visit_ClassDef(self, node: cst.ClassDef) -> None:
if "Mixin" in node.name.value:
self.extracted_classes[node.name.value] = node


class StandaloneFunctionAndVariableCollector(cst.CSTVisitor):
def __init__(self):
self.functions: List[cst.FunctionDef] = []
self.current_scope_level = 0 # Track nesting level

def visit_ClassDef(self, node: cst.ClassDef) -> None:
# Entering a new class scope, increase nesting level
self.current_scope_level += 1

def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
if self.current_scope_level > 0:
self.current_scope_level -= 1

def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
if self.current_scope_level == 0:
self.functions.append(node)


class GlobalVariableExtractor(cst.CSTVisitor):
def __init__(self):
# Store the global variable assignments
self.global_variables: List[cst.Assign] = []

def visit_Module(self, node: cst.Module) -> None:
# Visit all assignments at the module level
for statement in node.body:
if isinstance(statement, cst.SimpleStatementLine):
for assign in statement.body:
if isinstance(assign, cst.Assign):
self.global_variables.append(assign)


def write_topo_class_files(
extracted_classes: Dict[str, cst.ClassDef],
imports: Set[str],
output_dir: Path,
) -> None:
"""
Write files for each group of classes:
1. Separate modules for "Shape", "Compound", "Solid", "Face" + "Shell", "Edge" + "Wire", and "Vertex"
2. "ShapeList" is extracted into its own module and imported by all modules except "Shape"
"""
# Create output directory if it doesn't exist
output_dir.mkdir(parents=True, exist_ok=True)

# Sort imports for consistency
imports_code = "\n".join(imports)

# Define class groupings based on layers
class_groups = {
"shape": ["Shape"],
"vertex": ["Vertex"],
"edge_wire": ["Mixin1D", "Edge", "Wire"],
"face_shell": ["Face", "Shell"],
"solid": ["Mixin3D", "Solid"],
"compound": ["Compound"],
"shape_list": ["ShapeList"],
}

# Write ShapeList class separately
if "ShapeList" in extracted_classes:
class_file = output_dir / "shape_list.py"
shape_list_class = extracted_classes["ShapeList"]
shape_list_module = cst.Module(
body=[*cst.parse_module(imports_code).body, shape_list_class]
)
class_file.write_text(shape_list_module.code)
print(f"Created {class_file}")

for group_name, class_names in class_groups.items():
if group_name == "shape_list":
continue

group_classes = [
extracted_classes[name] for name in class_names if name in extracted_classes
]
if not group_classes:
continue

# Add imports for base classes based on layer dependencies
additional_imports = ["from .utils import *"]
if group_name != "shape":
additional_imports.append("from .shape import Shape")
additional_imports.append("from .shape_list import ShapeList")
if group_name in ["edge_wire", "face_shell", "solid", "compound"]:
additional_imports.append("from .vertex import Vertex")
if group_name in ["face_shell", "solid", "compound"]:
additional_imports.append("from .edge_wire import Edge, Wire")
if group_name in ["solid", "compound"]:
additional_imports.append("from .face_shell import Face, Shell")
if group_name == "compound":
additional_imports.append("from .solid import Solid")

# Create class file (e.g., face_shell.py)
class_file = output_dir / f"{group_name}.py"
all_imports_code = "\n".join([imports_code, *additional_imports])
class_module = cst.Module(
body=[*cst.parse_module(all_imports_code).body, *group_classes]
)
class_file.write_text(class_module.code)
print(f"Created {class_file}")

# Create __init__.py to make it a proper package
init_file = output_dir / "__init__.py"
init_content = []
for group_name in class_groups.keys():
if group_name != "shape_list":
init_content.append(f"from .{group_name} import *")

init_file.write_text("\n".join(init_content))
print(f"Created {init_file}")


def write_utils_file(
source_tree: cst.Module, imports: Set[str], output_dir: Path
) -> None:
"""
Extract and write standalone functions and global variables to a utils.py file.
Args:
source_tree: The parsed source tree
imports: Set of import statements
output_dir: Directory to write the utils file
"""
# Collect standalone functions and global variables
function_collector = StandaloneFunctionAndVariableCollector()
source_tree.visit(function_collector)

variable_collector = GlobalVariableExtractor()
source_tree.visit(variable_collector)

# Create utils file
utils_file = output_dir / "utils.py"

# Prepare the module body
module_body = []

# Add imports
imports_tree = cst.parse_module("\n".join(sorted(imports)))
module_body.extend(imports_tree.body)

# Add global variables with newlines
for var in variable_collector.global_variables:
module_body.append(var)
module_body.append(cst.EmptyLine(indent=False))

# Add a newline between variables and functions
if variable_collector.global_variables and function_collector.functions:
module_body.append(cst.EmptyLine(indent=False))

# Add functions
module_body.extend(function_collector.functions)

# Create the module
utils_module = cst.Module(body=module_body)

# Write the file
utils_file.write_text(utils_module.code)
print(f"Created {utils_file}")


def remove_unused_imports(file_path: Path, project: Project) -> None:
"""Remove unused imports from a Python file using rope.
Args:
file_path: Path to the Python file to clean imports
project: Rope project instance to refresh and use for cleaning
"""
# Get the relative file path from the project root
relative_path = file_path.relative_to(project.address)

# Refresh the project to recognize new files
project.validate()

# Get the resource (file) to work on
resource = project.get_resource(str(relative_path))

# Create import organizer
import_organizer = ImportOrganizer(project)

# Get and apply the changes
changes = import_organizer.organize_imports(resource)
if changes:
changes.do()
print(f"Cleaned imports in {file_path}")
else:
print(f"No unused imports found in {file_path}")


def main():
# Define paths
script_dir = Path(__file__).parent
topo_file = script_dir / "topology.py"
output_dir = script_dir / "topology"

# Define classes to extract
class_names = [
"Shape",
"Compound",
"Solid",
"Shell",
"Face",
"Wire",
"Edge",
"Vertex",
"Mixin0D",
"Mixin1D",
"Mixin2D",
"Mixin3D",
"MixinCompound",
"ShapeList",
]

# Parse source file and collect imports
source_tree = cst.parse_module(topo_file.read_text())
collector = ImportCollector()
source_tree.visit(collector)

# Extract classes
extractor = ClassExtractor(class_names)
source_tree.visit(extractor)

# Extract mixin classes
mixin_extractor = MixinClassExtractor()
source_tree.visit(mixin_extractor)

# Write the class files
write_topo_class_files(
extracted_classes=extractor.extracted_classes,
imports=collector.imports,
output_dir=output_dir,
)

# Write the utils file
write_utils_file(
source_tree=source_tree, imports=collector.imports, output_dir=output_dir
)

# Create a Rope project instance
project = Project(str(script_dir))

# Clean up imports
for file in output_dir.glob("*.py"):
remove_unused_imports(file, project)


if __name__ == "__main__":
main()

0 comments on commit e06337a

Please sign in to comment.