Skip to content

Commit

Permalink
Add support for arbitrary json in conn uri format (apache#15100)
Browse files Browse the repository at this point in the history
Currently, in the Airflow web UI and the CLI, you can store arbitrary (e.g. nested) JSON in the `extra` field, but the URI format can only handle primitive key-value pairs. This PR adds support for arbitrary JSON in the URI format.

Co-authored-by: Daniel Standish <[email protected]>
Co-authored-by: Ash Berlin-Taylor <[email protected]>
  • Loading branch information
3 people authored Apr 14, 2021
1 parent 7490c6b commit a4c4e61
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 38 deletions.
19 changes: 16 additions & 3 deletions airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ class Connection(Base, LoggingMixin): # pylint: disable=too-many-instance-attri
:type uri: str
"""

EXTRA_KEY = '__extra__'

__tablename__ = "connection"

id = Column(Integer(), primary_key=True)
Expand Down Expand Up @@ -161,7 +163,11 @@ def _parse_from_uri(self, uri: str):
self.password = unquote(uri_parts.password) if uri_parts.password else uri_parts.password
self.port = uri_parts.port
if uri_parts.query:
self.extra = json.dumps(dict(parse_qsl(uri_parts.query, keep_blank_values=True)))
query = dict(parse_qsl(uri_parts.query, keep_blank_values=True))
if self.EXTRA_KEY in query:
self.extra = query[self.EXTRA_KEY]
else:
self.extra = json.dumps(query)

def get_uri(self) -> str:
"""Return connection in URI format"""
Expand Down Expand Up @@ -194,8 +200,15 @@ def get_uri(self) -> str:

uri += host_block

if self.extra_dejson:
uri += f'?{urlencode(self.extra_dejson)}'
if self.extra:
try:
query = urlencode(self.extra_dejson)
except TypeError:
query = None
if query and self.extra_dejson == dict(parse_qsl(query, keep_blank_values=True)):
uri += '?' + query
else:
uri += '?' + urlencode({self.EXTRA_KEY: self.extra})

return uri

Expand Down
68 changes: 52 additions & 16 deletions docs/apache-airflow/howto/connection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,6 @@ In general, Airflow's URI format is like so:
my-conn-type://my-login:my-password@my-host:5432/my-schema?param1=val1&param2=val2
.. note::

The params ``param1`` and ``param2`` are just examples; you may supply arbitrary urlencoded json-serializable data there.

The above URI would produce a ``Connection`` object equivalent to the following:

.. code-block:: python
Expand All @@ -232,17 +228,6 @@ The above URI would produce a ``Connection`` object equivalent to the following:
extra=json.dumps(dict(param1='val1', param2='val2'))
)
You can verify a URI is parsed correctly like so:

.. code-block:: pycon
>>> from airflow.models.connection import Connection
>>> c = Connection(uri='my-conn-type://my-login:my-password@my-host:5432/my-schema?param1=val1&param2=val2')
>>> print(c.login)
my-login
>>> print(c.password)
my-password
.. _generating_connection_uri:

Expand Down Expand Up @@ -289,12 +274,63 @@ Additionally, if you have created a connection, you can use ``airflow connection
.. _manage-connections-connection-types:

Encoding arbitrary JSON
^^^^^^^^^^^^^^^^^^^^^^^

Some JSON structures cannot be urlencoded without loss. For such JSON, ``get_uri``
will store the entire ``extra`` string under the URL query param ``__extra__``.

For example:

.. code-block:: pycon
>>> extra_dict = {'my_val': ['list', 'of', 'values'], 'extra': {'nested': {'json': 'val'}}}
>>> c = Connection(
>>> conn_type='scheme',
>>> host='host/location',
>>> schema='schema',
>>> login='user',
>>> password='password',
>>> port=1234,
>>> extra=json.dumps(extra_dict),
>>> )
>>> uri = c.get_uri()
>>> uri
'scheme://user:password@host%2Flocation:1234/schema?__extra__=%7B%22my_val%22%3A+%5B%22list%22%2C+%22of%22%2C+%22values%22%5D%2C+%22extra%22%3A+%7B%22nested%22%3A+%7B%22json%22%3A+%22val%22%7D%7D%7D'
And we can verify that it returns the same dictionary:

.. code-block:: pycon
>>> new_c = Connection(uri=uri)
>>> new_c.extra_dejson == extra_dict
True
But for the most common case of storing only key-value pairs, plain url encoding is used.

You can verify a URI is parsed correctly like so:

.. code-block:: pycon
>>> from airflow.models.connection import Connection
>>> c = Connection(uri='my-conn-type://my-login:my-password@my-host:5432/my-schema?param1=val1&param2=val2')
>>> print(c.login)
my-login
>>> print(c.password)
my-password
Handling of special characters in connection params
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. note::

This process is automated as described in section :ref:`Generating a Connection URI <generating_connection_uri>`.
Use the convenience method ``Connection.get_uri`` when generating a connection
as described in section :ref:`Generating a Connection URI <generating_connection_uri>`.
This section is for informational purposes only.

Special handling is required for certain characters when building a URI manually.

Expand Down
61 changes: 57 additions & 4 deletions tests/models/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,61 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
),
description='with extras',
),
UriTestCaseConfig(
test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?' '__extra__=single+value',
test_conn_attributes=dict(
conn_type='scheme',
host='host/location',
schema='schema',
login='user',
password='password',
port=1234,
extra='single value',
),
description='with extras single value',
),
UriTestCaseConfig(
test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?'
'__extra__=arbitrary+string+%2A%29%2A%24',
test_conn_attributes=dict(
conn_type='scheme',
host='host/location',
schema='schema',
login='user',
password='password',
port=1234,
extra='arbitrary string *)*$',
),
description='with extra non-json',
),
UriTestCaseConfig(
test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?'
'__extra__=%5B%22list%22%2C+%22of%22%2C+%22values%22%5D',
test_conn_attributes=dict(
conn_type='scheme',
host='host/location',
schema='schema',
login='user',
password='password',
port=1234,
extra_dejson=['list', 'of', 'values'],
),
description='with extras list',
),
UriTestCaseConfig(
test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?'
'__extra__=%7B%22my_val%22%3A+%5B%22list%22%2C+%22of%22%2C+%22values%22%5D%2C+%22extra%22%3A+%7B%22nested%22%3A+%7B%22json%22%3A+%22val%22%7D%7D%7D', # noqa: E501 # pylint: disable=C0301
test_conn_attributes=dict(
conn_type='scheme',
host='host/location',
schema='schema',
login='user',
password='password',
port=1234,
extra_dejson={'my_val': ['list', 'of', 'values'], 'extra': {'nested': {'json': 'val'}}},
),
description='with nested json',
),
UriTestCaseConfig(
test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?extra1=a%20value&extra2=',
test_conn_attributes=dict(
Expand Down Expand Up @@ -351,11 +406,9 @@ def test_connection_get_uri_from_conn(self, test_config: UriTestCaseConfig):
for conn_attr, expected_val in test_config.test_conn_attributes.items():
actual_val = getattr(new_conn, conn_attr)
if expected_val is None:
assert expected_val is None
if isinstance(expected_val, dict):
assert expected_val == actual_val
assert actual_val is None
else:
assert expected_val == actual_val
assert actual_val == expected_val

@parameterized.expand(
[
Expand Down
41 changes: 26 additions & 15 deletions tests/secrets/test_local_filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,10 @@ def test_missing_file(self, mock_exists):

@parameterized.expand(
(
("""CONN_A: 'mysql://host_a'""", {"CONN_A": "mysql://host_a"}),
(
"""CONN_A: 'mysql://host_a'""",
{"CONN_A": {'conn_type': 'mysql', 'host': 'host_a'}},
),
(
"""
conn_a: mysql://hosta
Expand All @@ -216,28 +219,36 @@ def test_missing_file(self, mock_exists):
password: None
port: 1234
extra_dejson:
extra__google_cloud_platform__keyfile_dict:
a: b
arbitrary_dict:
a: b
extra__google_cloud_platform__keyfile_dict: '{"a": "b"}'
extra__google_cloud_platform__keyfile_path: asaa""",
{
"conn_a": "mysql://hosta",
"conn_b": ''.join(
"""scheme://Login:None@host:1234/lschema?
extra__google_cloud_platform__keyfile_dict=%7B%27a%27%3A+%27b%27%7D
&extra__google_cloud_platform__keyfile_path=asaa""".split()
),
"conn_a": {'conn_type': 'mysql', 'host': 'hosta'},
"conn_b": {
'conn_type': 'scheme',
'host': 'host',
'schema': 'lschema',
'login': 'Login',
'password': 'None',
'port': 1234,
'extra_dejson': {
'arbitrary_dict': {"a": "b"},
'extra__google_cloud_platform__keyfile_dict': '{"a": "b"}',
'extra__google_cloud_platform__keyfile_path': 'asaa',
},
},
},
),
)
)
def test_yaml_file_should_load_connection(self, file_content, expected_connection_uris):
def test_yaml_file_should_load_connection(self, file_content, expected_attrs_dict):
with mock_local_file(file_content):
connections_by_conn_id = local_filesystem.load_connections_dict("a.yaml")
connection_uris_by_conn_id = {
conn_id: connection.get_uri() for conn_id, connection in connections_by_conn_id.items()
}

assert expected_connection_uris == connection_uris_by_conn_id
for conn_id, connection in connections_by_conn_id.items():
expected_attrs = expected_attrs_dict[conn_id]
actual_attrs = {k: getattr(connection, k) for k in expected_attrs.keys()}
assert actual_attrs == expected_attrs

@parameterized.expand(
(
Expand Down

0 comments on commit a4c4e61

Please sign in to comment.