Skip to content

Commit

Permalink
hotfix for multi-lines imports %load_node (#4068)
Browse files Browse the repository at this point in the history
* add new test case - failed as expected

Signed-off-by: Nok <[email protected]>

* add notes

Signed-off-by: Nok <[email protected]>

* Add a fix for multi-lines imports

Signed-off-by: Nok Lam Chan <[email protected]>

* fix

Signed-off-by: Nok Lam Chan <[email protected]>

* fix linter

Signed-off-by: Nok Lam Chan <[email protected]>

* fix lint

Signed-off-by: Nok <[email protected]>

---------

Signed-off-by: Nok <[email protected]>
Signed-off-by: Nok Lam Chan <[email protected]>
  • Loading branch information
noklam authored Aug 14, 2024
1 parent 9cdddef commit 2a97dd4
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 6 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

## Bug fixes and other changes
* Moved `_find_run_command()` and `_find_run_command_in_plugins()` from `__main__.py` in the project template to the framework itself.
* Fixed a bug where `%load_node` breaks with multi-lines import statements.

## Breaking changes to the API

Expand Down
26 changes: 23 additions & 3 deletions kedro/ipython/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,10 +360,30 @@ def _prepare_imports(node_func: Callable) -> str:
if python_file:
import_statement = []
with open(python_file) as file:
# Handle multiline imports, i.e.
# from lib import (
# a,
# b,
# c
# )
# This will not work with all edge cases but good enough with common cases that
# are formatted automatically by black, ruff etc.
inside_bracket = False
# Parse any line start with from or import statement
for line in file.readlines():
if line.startswith("from") or line.startswith("import"):
import_statement.append(line.strip())

for _ in file.readlines():
line = _.strip()
if not inside_bracket:
# The common case
if line.startswith("from") or line.startswith("import"):
import_statement.append(line)
if line.endswith("("):
inside_bracket = True
# Inside multi-lines import, append everything.
else:
import_statement.append(line)
if line.endswith(")"):
inside_bracket = False

clean_imports = "\n".join(import_statement).strip()
return clean_imports
Expand Down
3 changes: 2 additions & 1 deletion tests/ipython/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
from kedro.pipeline import node
from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline

from . import dummy_function_fixtures # It is needed for the inspect module
from . import dummy_function_fixtures
from .dummy_function_fixtures import (
dummy_function,
dummy_function_with_loop,
dummy_function_with_variable_length,
dummy_nested_function,
)
from .dummy_multiline_fixtures import dummy_multiline_import_function # noqa: F401

# Constants
PACKAGE_NAME = "fake_package_name"
Expand Down
18 changes: 18 additions & 0 deletions tests/ipython/dummy_multiline_fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# ruff: noqa
# multi-lines import
from logging import (
INFO,
DEBUG,
WARN,
ERROR,
)


def dummy_multiline_import_function(dummy_input, my_input):
"""
Returns True if input is not
"""
# this is an in-line comment in the body of the function
random_assignment = "Added for a longer function"
random_assignment += "make sure to modify variable"
return not dummy_input
20 changes: 18 additions & 2 deletions tests/ipython/test_ipython.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
)
from kedro.pipeline.modular_pipeline import pipeline as modular_pipeline

from .conftest import dummy_function, dummy_function_with_loop, dummy_nested_function
from .conftest import (
dummy_function,
dummy_function_with_loop,
dummy_multiline_import_function,
dummy_nested_function,
)


class TestLoadKedroObjects:
Expand Down Expand Up @@ -338,7 +343,7 @@ def test_node_not_found(self, dummy_pipelines):
in str(excinfo.value)
)

def test_prepare_imports(self, mocker, dummy_module_literal):
def test_prepare_imports(self, mocker):
func_imports = """import logging # noqa
from logging import config # noqa
import logging as dummy_logging # noqa
Expand All @@ -347,6 +352,17 @@ def test_prepare_imports(self, mocker, dummy_module_literal):
result = _prepare_imports(dummy_function)
assert result == func_imports

def test_prepare_imports_multiline(self, mocker):
func_imports = """from logging import (
INFO,
DEBUG,
WARN,
ERROR,
)"""

result = _prepare_imports(dummy_multiline_import_function)
assert result == func_imports

def test_prepare_imports_func_not_found(self, mocker):
mocker.patch("inspect.getsourcefile", return_value=None)

Expand Down

0 comments on commit 2a97dd4

Please sign in to comment.