-
-
Notifications
You must be signed in to change notification settings - Fork 101
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Program to refactor topology.py Issue #788
- Loading branch information
Showing
1 changed file
with
291 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,291 @@ | ||
from pathlib import Path | ||
import libcst as cst | ||
from typing import List, Set, Dict, Union | ||
from pprint import pprint | ||
from rope.base.project import Project | ||
from rope.refactor.importutils import ImportOrganizer | ||
|
||
|
||
class ImportCollector(cst.CSTVisitor): | ||
def __init__(self): | ||
self.imports: Set[str] = set() | ||
|
||
def visit_Import(self, node: cst.Import) -> None: | ||
for name in node.names: | ||
# Create a proper statement line | ||
stmt = cst.SimpleStatementLine([node]) | ||
self.imports.add(cst.Module([stmt]).code) | ||
|
||
def visit_ImportFrom(self, node: cst.ImportFrom) -> None: | ||
# Create a proper statement line | ||
stmt = cst.SimpleStatementLine([node]) | ||
self.imports.add(cst.Module([stmt]).code) | ||
|
||
|
||
class ClassExtractor(cst.CSTVisitor): | ||
def __init__(self, class_names_to_extract: List[str]): | ||
self.class_names = class_names_to_extract | ||
self.extracted_classes: Dict[str, cst.ClassDef] = {} | ||
|
||
def visit_ClassDef(self, node: cst.ClassDef) -> None: | ||
if node.name.value in self.class_names: | ||
self.extracted_classes[node.name.value] = node | ||
|
||
|
||
class MixinClassExtractor(cst.CSTVisitor): | ||
def __init__(self): | ||
self.extracted_classes: Dict[str, cst.ClassDef] = {} | ||
|
||
def visit_ClassDef(self, node: cst.ClassDef) -> None: | ||
if "Mixin" in node.name.value: | ||
self.extracted_classes[node.name.value] = node | ||
|
||
|
||
class StandaloneFunctionAndVariableCollector(cst.CSTVisitor): | ||
def __init__(self): | ||
self.functions: List[cst.FunctionDef] = [] | ||
self.current_scope_level = 0 # Track nesting level | ||
|
||
def visit_ClassDef(self, node: cst.ClassDef) -> None: | ||
# Entering a new class scope, increase nesting level | ||
self.current_scope_level += 1 | ||
|
||
def leave_ClassDef(self, original_node: cst.ClassDef) -> None: | ||
if self.current_scope_level > 0: | ||
self.current_scope_level -= 1 | ||
|
||
def visit_FunctionDef(self, node: cst.FunctionDef) -> None: | ||
if self.current_scope_level == 0: | ||
self.functions.append(node) | ||
|
||
|
||
class GlobalVariableExtractor(cst.CSTVisitor): | ||
def __init__(self): | ||
# Store the global variable assignments | ||
self.global_variables: List[cst.Assign] = [] | ||
|
||
def visit_Module(self, node: cst.Module) -> None: | ||
# Visit all assignments at the module level | ||
for statement in node.body: | ||
if isinstance(statement, cst.SimpleStatementLine): | ||
for assign in statement.body: | ||
if isinstance(assign, cst.Assign): | ||
self.global_variables.append(assign) | ||
|
||
|
||
def write_topo_class_files( | ||
extracted_classes: Dict[str, cst.ClassDef], | ||
imports: Set[str], | ||
output_dir: Path, | ||
) -> None: | ||
""" | ||
Write files for each group of classes: | ||
1. Separate modules for "Shape", "Compound", "Solid", "Face" + "Shell", "Edge" + "Wire", and "Vertex" | ||
2. "ShapeList" is extracted into its own module and imported by all modules except "Shape" | ||
""" | ||
# Create output directory if it doesn't exist | ||
output_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
# Sort imports for consistency | ||
imports_code = "\n".join(imports) | ||
|
||
# Define class groupings based on layers | ||
class_groups = { | ||
"shape": ["Shape"], | ||
"vertex": ["Vertex"], | ||
"edge_wire": ["Mixin1D", "Edge", "Wire"], | ||
"face_shell": ["Face", "Shell"], | ||
"solid": ["Mixin3D", "Solid"], | ||
"compound": ["Compound"], | ||
"shape_list": ["ShapeList"], | ||
} | ||
|
||
# Write ShapeList class separately | ||
if "ShapeList" in extracted_classes: | ||
class_file = output_dir / "shape_list.py" | ||
shape_list_class = extracted_classes["ShapeList"] | ||
shape_list_module = cst.Module( | ||
body=[*cst.parse_module(imports_code).body, shape_list_class] | ||
) | ||
class_file.write_text(shape_list_module.code) | ||
print(f"Created {class_file}") | ||
|
||
for group_name, class_names in class_groups.items(): | ||
if group_name == "shape_list": | ||
continue | ||
|
||
group_classes = [ | ||
extracted_classes[name] for name in class_names if name in extracted_classes | ||
] | ||
if not group_classes: | ||
continue | ||
|
||
# Add imports for base classes based on layer dependencies | ||
additional_imports = ["from .utils import *"] | ||
if group_name != "shape": | ||
additional_imports.append("from .shape import Shape") | ||
additional_imports.append("from .shape_list import ShapeList") | ||
if group_name in ["edge_wire", "face_shell", "solid", "compound"]: | ||
additional_imports.append("from .vertex import Vertex") | ||
if group_name in ["face_shell", "solid", "compound"]: | ||
additional_imports.append("from .edge_wire import Edge, Wire") | ||
if group_name in ["solid", "compound"]: | ||
additional_imports.append("from .face_shell import Face, Shell") | ||
if group_name == "compound": | ||
additional_imports.append("from .solid import Solid") | ||
|
||
# Create class file (e.g., face_shell.py) | ||
class_file = output_dir / f"{group_name}.py" | ||
all_imports_code = "\n".join([imports_code, *additional_imports]) | ||
class_module = cst.Module( | ||
body=[*cst.parse_module(all_imports_code).body, *group_classes] | ||
) | ||
class_file.write_text(class_module.code) | ||
print(f"Created {class_file}") | ||
|
||
# Create __init__.py to make it a proper package | ||
init_file = output_dir / "__init__.py" | ||
init_content = [] | ||
for group_name in class_groups.keys(): | ||
if group_name != "shape_list": | ||
init_content.append(f"from .{group_name} import *") | ||
|
||
init_file.write_text("\n".join(init_content)) | ||
print(f"Created {init_file}") | ||
|
||
|
||
def write_utils_file( | ||
source_tree: cst.Module, imports: Set[str], output_dir: Path | ||
) -> None: | ||
""" | ||
Extract and write standalone functions and global variables to a utils.py file. | ||
Args: | ||
source_tree: The parsed source tree | ||
imports: Set of import statements | ||
output_dir: Directory to write the utils file | ||
""" | ||
# Collect standalone functions and global variables | ||
function_collector = StandaloneFunctionAndVariableCollector() | ||
source_tree.visit(function_collector) | ||
|
||
variable_collector = GlobalVariableExtractor() | ||
source_tree.visit(variable_collector) | ||
|
||
# Create utils file | ||
utils_file = output_dir / "utils.py" | ||
|
||
# Prepare the module body | ||
module_body = [] | ||
|
||
# Add imports | ||
imports_tree = cst.parse_module("\n".join(sorted(imports))) | ||
module_body.extend(imports_tree.body) | ||
|
||
# Add global variables with newlines | ||
for var in variable_collector.global_variables: | ||
module_body.append(var) | ||
module_body.append(cst.EmptyLine(indent=False)) | ||
|
||
# Add a newline between variables and functions | ||
if variable_collector.global_variables and function_collector.functions: | ||
module_body.append(cst.EmptyLine(indent=False)) | ||
|
||
# Add functions | ||
module_body.extend(function_collector.functions) | ||
|
||
# Create the module | ||
utils_module = cst.Module(body=module_body) | ||
|
||
# Write the file | ||
utils_file.write_text(utils_module.code) | ||
print(f"Created {utils_file}") | ||
|
||
|
||
def remove_unused_imports(file_path: Path, project: Project) -> None: | ||
"""Remove unused imports from a Python file using rope. | ||
Args: | ||
file_path: Path to the Python file to clean imports | ||
project: Rope project instance to refresh and use for cleaning | ||
""" | ||
# Get the relative file path from the project root | ||
relative_path = file_path.relative_to(project.address) | ||
|
||
# Refresh the project to recognize new files | ||
project.validate() | ||
|
||
# Get the resource (file) to work on | ||
resource = project.get_resource(str(relative_path)) | ||
|
||
# Create import organizer | ||
import_organizer = ImportOrganizer(project) | ||
|
||
# Get and apply the changes | ||
changes = import_organizer.organize_imports(resource) | ||
if changes: | ||
changes.do() | ||
print(f"Cleaned imports in {file_path}") | ||
else: | ||
print(f"No unused imports found in {file_path}") | ||
|
||
|
||
def main(): | ||
# Define paths | ||
script_dir = Path(__file__).parent | ||
topo_file = script_dir / "topology.py" | ||
output_dir = script_dir / "topology" | ||
|
||
# Define classes to extract | ||
class_names = [ | ||
"Shape", | ||
"Compound", | ||
"Solid", | ||
"Shell", | ||
"Face", | ||
"Wire", | ||
"Edge", | ||
"Vertex", | ||
"Mixin0D", | ||
"Mixin1D", | ||
"Mixin2D", | ||
"Mixin3D", | ||
"MixinCompound", | ||
"ShapeList", | ||
] | ||
|
||
# Parse source file and collect imports | ||
source_tree = cst.parse_module(topo_file.read_text()) | ||
collector = ImportCollector() | ||
source_tree.visit(collector) | ||
|
||
# Extract classes | ||
extractor = ClassExtractor(class_names) | ||
source_tree.visit(extractor) | ||
|
||
# Extract mixin classes | ||
mixin_extractor = MixinClassExtractor() | ||
source_tree.visit(mixin_extractor) | ||
|
||
# Write the class files | ||
write_topo_class_files( | ||
extracted_classes=extractor.extracted_classes, | ||
imports=collector.imports, | ||
output_dir=output_dir, | ||
) | ||
|
||
# Write the utils file | ||
write_utils_file( | ||
source_tree=source_tree, imports=collector.imports, output_dir=output_dir | ||
) | ||
|
||
# Create a Rope project instance | ||
project = Project(str(script_dir)) | ||
|
||
# Clean up imports | ||
for file in output_dir.glob("*.py"): | ||
remove_unused_imports(file, project) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |