Skip to content

Commit

Permalink
Fix missing on_load trigger for folder-based plugins (apache#15208)
Browse files Browse the repository at this point in the history
  • Loading branch information
jedcunningham authored Apr 6, 2021
1 parent 042be2e commit 97b7780
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 9 deletions.
23 changes: 14 additions & 9 deletions airflow/plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,23 @@ def is_valid_plugin(plugin_obj):
return False


def register_plugin(plugin_instance):
"""
Start plugin load and register it after success initialization
:param plugin_instance: subclass of AirflowPlugin
"""
global plugins # pylint: disable=global-statement
plugin_instance.on_load()
plugins.append(plugin_instance)


def load_entrypoint_plugins():
"""
Load and register plugins AirflowPlugin subclasses from the entrypoints.
The entry_point group should be 'airflow.plugins'.
"""
global import_errors # pylint: disable=global-statement
global plugins # pylint: disable=global-statement

log.debug("Loading plugins from entrypoints")

Expand All @@ -202,10 +212,8 @@ def load_entrypoint_plugins():
continue

plugin_instance = plugin_class()
if callable(getattr(plugin_instance, 'on_load', None)):
plugin_instance.on_load()
plugin_instance.source = EntryPointSource(entry_point, dist)
plugins.append(plugin_instance)
plugin_instance.source = EntryPointSource(entry_point, dist)
register_plugin(plugin_instance)
except Exception as e: # pylint: disable=broad-except
log.exception("Failed to import plugin %s", entry_point.name)
import_errors[entry_point.module] = str(e)
Expand All @@ -214,11 +222,9 @@ def load_entrypoint_plugins():
def load_plugins_from_plugin_directory():
"""Load and register Airflow Plugins from plugins directory"""
global import_errors # pylint: disable=global-statement
global plugins # pylint: disable=global-statement
log.debug("Loading plugins from directory: %s", settings.PLUGINS_FOLDER)

for file_path in find_path_from_directory(settings.PLUGINS_FOLDER, ".airflowignore"):

if not os.path.isfile(file_path):
continue
mod_name, file_ext = os.path.splitext(os.path.split(file_path)[-1])
Expand All @@ -236,8 +242,7 @@ def load_plugins_from_plugin_directory():
for mod_attr_value in (m for m in mod.__dict__.values() if is_valid_plugin(m)):
plugin_instance = mod_attr_value()
plugin_instance.source = PluginsDirectorySource(file_path)
plugins.append(plugin_instance)

register_plugin(plugin_instance)
except Exception as e: # pylint: disable=broad-except
log.exception(e)
log.error('Failed to import plugin %s', file_path)
Expand Down
7 changes: 7 additions & 0 deletions tests/plugins/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,10 @@ class MockPluginB(AirflowPlugin):

class MockPluginC(AirflowPlugin):
name = 'plugin-c'


class AirflowTestOnLoadPlugin(AirflowPlugin):
name = 'preload'

def on_load(self, *args, **kwargs):
self.name = 'postload'
47 changes: 47 additions & 0 deletions tests/plugins/test_plugins_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,32 @@
# under the License.
import importlib
import logging
import os
import sys
import tempfile
from unittest import mock

import pytest

from airflow.hooks.base import BaseHook
from airflow.plugins_manager import AirflowPlugin
from airflow.www import app as application
from tests.test_utils.config import conf_vars
from tests.test_utils.mock_plugins import mock_plugin_manager

py39 = sys.version_info >= (3, 9)
importlib_metadata = 'importlib.metadata' if py39 else 'importlib_metadata'

ON_LOAD_EXCEPTION_PLUGIN = """
from airflow.plugins_manager import AirflowPlugin
class AirflowTestOnLoadExceptionPlugin(AirflowPlugin):
name = 'preload'
def on_load(self, *args, **kwargs):
raise Exception("oops")
"""


class TestPluginsRBAC:
@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -146,6 +159,40 @@ class TestPropertyHook(BaseHook):
assert caplog.records[-1].levelname == 'DEBUG'
assert caplog.records[-1].msg == 'Loading %d plugin(s) took %.2f seconds'

def test_loads_filesystem_plugins(self, caplog):
from airflow import plugins_manager

with mock.patch('airflow.plugins_manager.plugins', []):
plugins_manager.load_plugins_from_plugin_directory()

assert 5 == len(plugins_manager.plugins)
for plugin in plugins_manager.plugins:
if 'AirflowTestOnLoadPlugin' not in str(plugin):
continue
assert 'postload' == plugin.name
break
else:
pytest.fail("Wasn't able to find a registered `AirflowTestOnLoadPlugin`")

assert caplog.record_tuples == []

def test_loads_filesystem_plugins_exception(self, caplog):
from airflow import plugins_manager

with mock.patch('airflow.plugins_manager.plugins', []):
with tempfile.TemporaryDirectory() as tmpdir:
with open(os.path.join(tmpdir, 'testplugin.py'), "w") as f:
f.write(ON_LOAD_EXCEPTION_PLUGIN)

with conf_vars({('core', 'plugins_folder'): tmpdir}):
plugins_manager.load_plugins_from_plugin_directory()

assert plugins_manager.plugins == []

received_logs = caplog.text
assert 'Failed to import plugin' in received_logs
assert 'testplugin.py' in received_logs

def test_should_warning_about_incompatible_plugins(self, caplog):
class AirflowAdminViewsPlugin(AirflowPlugin):
name = "test_admin_views_plugin"
Expand Down

0 comments on commit 97b7780

Please sign in to comment.