Skip to content

Commit

Permalink
fix: Fixed bugs for analyzing packages and api data creation. (#27)
Browse files Browse the repository at this point in the history
### Summary of Changes
* Fixed a bug in `_get_api.py` which prevented analyzing packages with
different path lengths.
* Fixed a bug in `_ast_visitor.py` for the api data creation, in which
the id of modules did not correctly represent their path.
* Adjusted test snapshots

---------

Co-authored-by: megalinter-bot <[email protected]>
  • Loading branch information
Masara and megalinter-bot authored Nov 10, 2023
1 parent 0651f5b commit 80215a3
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 15 deletions.
66 changes: 57 additions & 9 deletions src/safeds_stubgen/api_analyzer/_ast_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,16 @@ def enter_moduledef(self, node: MypyFile) -> None:
elif isinstance(definition, ExpressionStmt) and isinstance(definition.expr, StrExpr):
docstring = definition.expr.value

# Create module id to get the full path
id_ = self._create_module_id(node.fullname)

# If we are checking a package node.name will be the package name, but since we get import information from
# the __init__.py file we set the name to __init__
if is_package:
name = "__init__"
id_ += f"/{name}"
else:
name = node.name
id_ = self.__get_id(name)

# Remember module, so we can later add classes and global functions
module = Module(
Expand All @@ -125,7 +128,7 @@ def leave_moduledef(self, _: MypyFile) -> None:
self.api.add_module(module)

def enter_classdef(self, node: ClassDef) -> None:
id_ = self.__get_id(node.name)
id_ = self._create_id_from_stack(node.name)
name = node.name

# Get docstring
Expand Down Expand Up @@ -172,7 +175,7 @@ def leave_classdef(self, _: ClassDef) -> None:

def enter_funcdef(self, node: FuncDef) -> None:
name = node.name
function_id = self.__get_id(name)
function_id = self._create_id_from_stack(name)

is_public = self.is_public(name, node.fullname)
is_static = node.is_static
Expand Down Expand Up @@ -229,7 +232,7 @@ def leave_funcdef(self, _: FuncDef) -> None:
parent.add_method(function)

def enter_enumdef(self, node: ClassDef) -> None:
id_ = self.__get_id(node.name)
id_ = self._create_id_from_stack(node.name)
self.__declaration_stack.append(
Enum(
id=id_,
Expand Down Expand Up @@ -489,7 +492,7 @@ def create_attribute(
docstring = self.docstring_parser.get_attribute_documentation(parent, name)

# Remove __init__ for attribute ids
id_ = self.__get_id(name).replace("__init__/", "")
id_ = self._create_id_from_stack(name).replace("__init__/", "")

return Attribute(
id=id_,
Expand Down Expand Up @@ -605,6 +608,37 @@ def add_reexports(self, module: Module) -> None:

# #### Misc. utilities

def _create_module_id(self, qname: str) -> str:
"""Create an ID for the module object.
Creates the module ID while discarding possible unnecessary information from the module qname.
Paramters
---------
qname : str
The qualified name of the module
Returns
-------
str
ID of the module
"""
package_name = self.api.package

if package_name not in qname:
raise ValueError("Package name could not be found in the qualified name of the module.")

# We have to split the qname of the module at the first occurence of the package name and reconnect it while
# discarding everything in front of it. This is necessary since the qname could contain unwanted information.
module_id = qname.split(f"{package_name}", 1)[-1]

if module_id.startswith("."):
module_id = module_id[1:]

# Replaces dots with slashes and add the package name at the start of the id, since we removed it
module_id = f"/{module_id.replace('.', '/')}" if module_id else ""
return f"{package_name}{module_id}"

def is_public(self, name: str, qualified_name: str) -> bool:
if name.startswith("_") and not name.endswith("__"):
return False
Expand All @@ -625,10 +659,24 @@ def is_public(self, name: str, qualified_name: str) -> bool:
# The slicing is necessary so __init__ functions are not excluded (already handled in the first condition).
return all(not it.startswith("_") for it in qualified_name.split(".")[:-1])

def __get_id(self, name: str) -> str:
segments = [self.api.package]
segments += [
it.name
def _create_id_from_stack(self, name: str) -> str:
"""Create an ID for a new object using previous objects of the stack.
Creates an ID by connecting the previous objects of the __declaration_stack stack and the new objects name,
which is on the highest level.
Paramters
---------
name : str
The name of the new object which lies on the highest level.
Returns
-------
str
ID of the object
"""
segments = [
it.id if isinstance(it, Module) else it.name # Special case, to get the module path info the id
for it in self.__declaration_stack
if not isinstance(it, list) # Check for the linter, on runtime can never be list type
]
Expand Down
2 changes: 1 addition & 1 deletion src/safeds_stubgen/api_analyzer/_get_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _get_mypy_ast(files: list[str], package_paths: list[Path], root: Path) -> li
# Check mypy data key root start
parts = root.parts
graph_keys = list(result.graph.keys())
root_start_after = 0
root_start_after = -1
for i in range(len(parts)):
if ".".join(parts[i:]) in graph_keys:
root_start_after = i
Expand Down
6 changes: 3 additions & 3 deletions tests/safeds_stubgen/__snapshots__/test_main.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@
'description': '',
'full_docstring': '',
}),
'id': 'test_package/another_module/AnotherClass',
'id': 'test_package/another_path/another_module/AnotherClass',
'is_public': True,
'methods': list([
]),
Expand Down Expand Up @@ -1977,7 +1977,7 @@
}),
dict({
'classes': list([
'test_package/another_module/AnotherClass',
'test_package/another_path/another_module/AnotherClass',
]),
'docstring': '''
Another Module Docstring.
Expand All @@ -1989,7 +1989,7 @@
]),
'functions': list([
]),
'id': 'test_package/another_module',
'id': 'test_package/another_path/another_module',
'name': 'another_module',
'qualified_imports': list([
]),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2756,7 +2756,7 @@
# name: test_modules_another_module
dict({
'classes': list([
'test_package/another_module/AnotherClass',
'test_package/another_path/another_module/AnotherClass',
]),
'docstring': '''
Another Module Docstring.
Expand All @@ -2768,7 +2768,7 @@
]),
'functions': list([
]),
'id': 'test_package/another_module',
'id': 'test_package/another_path/another_module',
'name': 'another_module',
'qualified_imports': list([
]),
Expand Down
75 changes: 75 additions & 0 deletions tests/safeds_stubgen/api_analyzer/test_api_visitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations

import pytest

# noinspection PyProtectedMember
from safeds_stubgen.api_analyzer._api import API

# noinspection PyProtectedMember
from safeds_stubgen.api_analyzer._ast_visitor import MyPyAstVisitor
from safeds_stubgen.docstring_parsing import PlaintextDocstringParser


@pytest.mark.parametrize(
("qname", "expected_id", "package_name"),
[
(
"some.path.package_name.src.data",
"package_name/src/data",
"package_name",
),
(
"some.path.package_name",
"package_name",
"package_name",
),
(
"some.path.no_package",
"",
"package_name",
),
(
"",
"",
"package_name",
),
(
"some.package_name.package_name.src.data",
"package_name/package_name/src/data",
"package_name",
),
(
"some.path.package_name.src.package_name",
"package_name/src/package_name",
"package_name",
),
(
"some.package_name.package_name.src.package_name",
"package_name/package_name/src/package_name",
"package_name",
),
],
ids=[
"With unneeded data",
"Without unneeded data",
"No package name in qname",
"No qname",
"Package name twice in qname 1",
"Package name twice in qname 2",
"Package name twice in qname 3",
],
)
def test__create_module_id(qname: str, expected_id: str, package_name: str) -> None:
api = API(
distribution="dist_name",
package=package_name,
version="1.3",
)

visitor = MyPyAstVisitor(PlaintextDocstringParser(), api)
if not expected_id:
with pytest.raises(ValueError, match="Package name could not be found in the qualified name of the module."):
visitor._create_module_id(qname)
else:
module_id = visitor._create_module_id(qname)
assert module_id == expected_id

0 comments on commit 80215a3

Please sign in to comment.