Oozie + Datastage + misc improvements (#27)
- [datastage] fix/feat/doc: rm defusedXML dependency, add tests+docs, add more JQ
- [oozie] feat: update demo translation
- [oozie] fix: add airflow as o2a dep :(
- [datastage] fix: return sql as list, if multiple
- [datastage] fix: improve sql extraction
- [various] chore: standardize 'demo' msg, add/improve unit tests
1 parent 2e37762 · commit 479be42
Showing 14 changed files with 1,013 additions and 281 deletions.
@@ -1,4 +1,7 @@
 r"""Demonstration of translating AutoSys JIL files into Airflow DAGs.
+Contact Astronomer @ https://astronomer.io/contact for access to our full translation.
+```pycon
 >>> translation_ruleset.test('''
 ... insert_job: foo.job
@@ -12,7 +15,7 @@
 from airflow import DAG
 from airflow.providers.ssh.operators.ssh import SSHOperator
 from pendulum import DateTime, Timezone
-with DAG(dag_id='foo_job', schedule=None, start_date=DateTime(1970, 1, 1, 0, 0, 0), catchup=False, default_args={'owner': '[email protected]'}):
+with DAG(dag_id='foo_job', schedule=None, start_date=DateTime(1970, 1, 1, 0, 0, 0), catchup=False, default_args={'owner': '[email protected]'}, doc_md=...):
     foo_job_task = SSHOperator(task_id='foo_job', ssh_conn_id='bar', command='"C:\\ldhe\\cxl\\TidalDB\\startApp.cmd" "arg1" "arg2" "arg3"', doc='Foo Job')
 ```
@@ -121,7 +124,7 @@ def basic_dag_rule(val: dict) -> OrbiterDAG | None:
     ... # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
     from airflow import DAG
     from pendulum import DateTime, Timezone
-    with DAG(dag_id='foo_job', schedule=None, start_date=DateTime(1970, 1, 1, 0, 0, 0), catchup=False, default_args={'owner': '[email protected]'}):
+    with DAG(dag_id='foo_job', schedule=None, start_date=DateTime(1970, 1, 1, 0, 0, 0), catchup=False, default_args={'owner': '[email protected]'}, doc_md=...):
     ```
     """
@@ -137,6 +140,9 @@ def basic_dag_rule(val: dict) -> OrbiterDAG | None:
     return OrbiterDAG(
         dag_id=dag_id,
         file_path=dag_id + ".py",
+        doc_md="**Created via [Orbiter](https://astronomer.github.io/orbiter) w/ Demo Translation Ruleset**.\n"
+        "Contact Astronomer @ [[email protected]](mailto:[email protected]) "
+        "or at [astronomer.io/contact](https://www.astronomer.io/contact/) for more!",
         **default_args,
     )
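The notice added here is duplicated verbatim in the DataStage ruleset below. A minimal sketch of the "standardize 'demo' msg" chore from the commit message, assuming a shared constant (the constant name and helper are illustrative, not part of this commit):

```python
from orbiter.objects.dag import OrbiterDAG

# Hypothetical shared constant: both demo rulesets could import this
# instead of repeating the markdown string inline.
DEMO_DOC_MD = (
    "**Created via [Orbiter](https://astronomer.github.io/orbiter) "
    "w/ Demo Translation Ruleset**.\n"
    "Contact Astronomer @ [astronomer.io/contact](https://www.astronomer.io/contact/) for more!"
)


def demo_dag(dag_id: str, **kwargs) -> OrbiterDAG:
    """Build an OrbiterDAG that always carries the demo notice."""
    return OrbiterDAG(dag_id=dag_id, file_path=f"{dag_id}.py", doc_md=DEMO_DOC_MD, **kwargs)
```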
@@ -1,17 +1,108 @@
 """Demo translation ruleset for DataStage XML files to Airflow DAGs
 Contact Astronomer @ https://astronomer.io/contact for access to our full translation
+```pycon
+>>> translation_ruleset.test(input_value='''<?xml version="1.0" encoding="UTF-8"?>
+... <DSExport>
+... <Job Identifier="DataStage_Job" DateModified="2020-11-27" TimeModified="05.07.33">
+... <Record Identifier="V253S0" Type="CustomStage" Readonly="0">
+... <Property Name="Name">SELECT_TABLE</Property>
+... <Collection Name="Properties" Type="CustomProperty">
+... <SubRecord>
+... <Property Name="Name">XMLProperties</Property>
+... <Property Name="Value" PreFormatted="1"><?xml version='1.0'
+... encoding='UTF-16'?><Properties
+... version='1.1'><Common><Context
+... type='int'>1</Context><Variant
+... type='string'>1.0</Variant><DescriptorVersion
+... type='string'>1.0</DescriptorVersion><PartitionType
+... type='int'>-1</PartitionType><RCP
+... type='int'>0</RCP></Common><Connection><URL
+... modified='1'
+... type='string'><![CDATA[jdbc:snowflake://xyz.us-east-1.snowflakecomputing.com/?&warehouse=#XYZ_DB.$snowflake_wh#&db=#DB.$schema#]]></URL><Username
+... modified='1'
+... type='string'><![CDATA[#DB.$snowflake_userid#]]></Username><Password
+... modified='1'
+... type='string'><![CDATA[#DB.$snowflake_passwd#]]></Password><Attributes
+... modified='1'
+... type='string'><![CDATA[]]></Attributes></Connection><Usage><ReadMode
+... type='int'><![CDATA[0]]></ReadMode><GenerateSQL
+... modified='1'
+... type='bool'><![CDATA[0]]></GenerateSQL><EnableQuotedIDs
+... type='bool'><![CDATA[0]]></EnableQuotedIDs><SQL><SelectStatement
+... collapsed='1' modified='1'
+... type='string'><![CDATA[
+... Select 1 as Dummy from db.schema.table Limit 1;]]><ReadFromFileSelect
+... type='bool'><![CDATA[0]]></ReadFromFileSelect></SelectStatement><EnablePartitionedReads
+... type='bool'><![CDATA[0]]></EnablePartitionedReads></SQL><Transaction><RecordCount
+... type='int'><![CDATA[2000]]></RecordCount><IsolationLevel
+... type='int'><![CDATA[0]]></IsolationLevel><AutocommitMode
+... modified='1'
+... type='int'><![CDATA[1]]></AutocommitMode><EndOfWave
+... type='int'><![CDATA[0]]></EndOfWave><BeginEnd
+... collapsed='1'
+... type='bool'><![CDATA[0]]></BeginEnd></Transaction><Session><ArraySize
+... type='int'><![CDATA[1]]></ArraySize><FetchSize
+... type='int'><![CDATA[0]]></FetchSize><ReportSchemaMismatch
+... type='bool'><![CDATA[0]]></ReportSchemaMismatch><DefaultLengthForColumns
+... type='int'><![CDATA[200]]></DefaultLengthForColumns><DefaultLengthForLongColumns
+... type='int'><![CDATA[20000]]></DefaultLengthForLongColumns><CharacterSetForNonUnicodeColumns
+... collapsed='1'
+... type='int'><![CDATA[0]]></CharacterSetForNonUnicodeColumns><KeepConductorConnectionAlive
+... type='bool'><![CDATA[1]]></KeepConductorConnectionAlive></Session><BeforeAfter
+... modified='1' type='bool'><![CDATA[1]]><BeforeSQL
+... collapsed='1' modified='1'
+... type='string'><![CDATA[]]><ReadFromFileBeforeSQL
+... type='bool'><![CDATA[0]]></ReadFromFileBeforeSQL><FailOnError
+... type='bool'><![CDATA[1]]></FailOnError></BeforeSQL><AfterSQL
+... collapsed='1' modified='1'
+... type='string'><![CDATA[
+... SELECT * FROM db.schema.table;
+... ]]><ReadFromFileAfterSQL
+... type='bool'><![CDATA[0]]></ReadFromFileAfterSQL><FailOnError
+... type='bool'><![CDATA[1]]></FailOnError></AfterSQL><BeforeSQLNode
+... type='string'><![CDATA[]]><ReadFromFileBeforeSQLNode
+... type='bool'><![CDATA[0]]></ReadFromFileBeforeSQLNode><FailOnError
+... type='bool'><![CDATA[1]]></FailOnError></BeforeSQLNode><AfterSQLNode
+... type='string'><![CDATA[]]><ReadFromFileAfterSQLNode
+... type='bool'><![CDATA[0]]></ReadFromFileAfterSQLNode><FailOnError
+... type='bool'><![CDATA[1]]></FailOnError></AfterSQLNode></BeforeAfter><Java><ConnectorClasspath
+... type='string'><![CDATA[$(DSHOME)/../DSComponents/bin/ccjdbc.jar;$(DSHOME)]]></ConnectorClasspath><HeapSize
+... modified='1'
+... type='int'><![CDATA[1024]]></HeapSize><ConnectorOtherOptions
+... type='string'><![CDATA[-Dcom.ibm.is.cc.options=noisfjars]]></ConnectorOtherOptions></Java><LimitRows
+... collapsed='1'
+... type='bool'><![CDATA[0]]></LimitRows></Usage></Properties
+... ></Property>
+... </SubRecord>
+... </Collection>
+... </Record>
+... </Job>
+... </DSExport>''').dags['data_stage_job']  # doctest: +ELLIPSIS
+from airflow import DAG
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
+from pendulum import DateTime, Timezone
+with DAG(dag_id='data_stage_job', schedule=None, start_date=DateTime(1970, 1, 1, 0, 0, 0), catchup=False, doc_md=...):
+    select_table_task = SQLExecuteQueryOperator(task_id='select_table', conn_id='DB', sql=['Select 1 as Dummy from db.schema.table Limit 1;', 'SELECT * FROM db.schema.table;'])
 """
 from __future__ import annotations
-from itertools import pairwise
-from defusedxml import ElementTree
-import inflection
+
+import json
+from itertools import pairwise
+
+from loguru import logger
+
 import jq
 from orbiter.file_types import FileTypeXML
 from orbiter.objects import conn_id
 from orbiter.objects.dag import OrbiterDAG
 from orbiter.objects.operators.empty import OrbiterEmptyOperator
 from orbiter.objects.operators.sql import OrbiterSQLExecuteQueryOperator
 from orbiter.objects.task import OrbiterOperator
-from orbiter.objects.task_group import OrbiterTaskGroup
+from orbiter.objects.task import OrbiterTaskDependency
+from orbiter.objects.task_group import OrbiterTaskGroup
 from orbiter.rules import (
     dag_filter_rule,
     dag_rule,
@@ -32,25 +123,51 @@


 @dag_filter_rule
-def basic_dag_filter(val: dict) -> list | None:
-    """Filter input down to a list of dictionaries that can be processed by the `@dag_rules`"""
-    return val["DSExport"][0]["Job"]
+def basic_dag_filter(val: dict) -> list[dict] | None:
+    """Get `Job` objects from within a parent `DSExport` object
+    ```pycon
+    >>> basic_dag_filter({"DSExport": [{"Job": [{'@Identifier': 'foo'}, {'@Identifier': 'bar'}]}]})
+    [{'@Identifier': 'foo'}, {'@Identifier': 'bar'}]
+    ```
+    """
+    if ds_export := val.get("DSExport"):
+        return [job for export in ds_export for job in (export.get("Job") or [])]


 @dag_rule
 def basic_dag_rule(val: dict) -> OrbiterDAG | None:
-    """Translate input into an `OrbiterDAG`"""
-    try:
-        dag_id = val["@Identifier"]
-        dag_id = inflection.underscore(dag_id)
-        return OrbiterDAG(dag_id=dag_id, file_path=f"{dag_id}.py")
-    except Exception:
-        return None
+    """Translate input into an `OrbiterDAG`, using the `Identifier` as the DAG ID
+    ```pycon
+    >>> basic_dag_rule({"@Identifier": "demo.extract_sample_currency_data"})  # doctest: +ELLIPSIS
+    from airflow import DAG
+    from pendulum import DateTime, Timezone
+    with DAG(dag_id='demo.extract_sample_currency_data', schedule=None, start_date=DateTime(1970, 1, 1, 0, 0, 0), catchup=False, doc_md=...):
+    ```
+    """
+    if dag_id := val.get("@Identifier"):
+        return OrbiterDAG(
+            dag_id=dag_id,
+            file_path=f"{dag_id}.py",
+            doc_md="**Created via [Orbiter](https://astronomer.github.io/orbiter) w/ Demo Translation Ruleset**.\n"
+            "Contact Astronomer @ [[email protected]](mailto:[email protected]) "
+            "or at [astronomer.io/contact](https://www.astronomer.io/contact/) for more!",
+        )


 @task_filter_rule
 def basic_task_filter(val: dict) -> list | None:
-    """Filter input down to a list of dictionaries that can be processed by the `@task_rules`"""
+    """Filter input down to a list of dictionaries with `@Type=CustomStage`
+    ```pycon
+    >>> basic_task_filter({"Record": [{"@Type": "CustomStage"}, {"@Type": "SomethingElse"}]})
+    [{'@Type': 'CustomStage'}]
+    ```
+    """
+    if isinstance(val, dict):
+        val = json.loads(json.dumps(val, default=str))  # pre-serialize values, for JQ
     try:
@@ -64,67 +181,91 @@ def basic_task_filter(val: dict) -> list | None:
     return None
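A note on the `json.loads(json.dumps(val, default=str))` line above: jq only operates on JSON-ready values, and a parsed XML tree can carry non-JSON Python types, so the round-trip coerces them to strings first. A minimal sketch (the date field is illustrative, not from the diff):

```python
import json
from datetime import date

import jq

# default=str turns the date into "2020-11-27", which jq can then consume.
val = {"Record": [{"@Type": "CustomStage", "@DateModified": date(2020, 11, 27)}]}
val = json.loads(json.dumps(val, default=str))
print(jq.all('.Record[] | select(.["@Type"] == "CustomStage")', val))
# [{'@Type': 'CustomStage', '@DateModified': '2020-11-27'}]
```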


-@task_rule(priority=2)
-def basic_task_rule(val: dict) -> OrbiterOperator | OrbiterTaskGroup | None:
-    """Translate input into an Operator (e.g. `OrbiterBashOperator`). will be applied first, with a higher priority"""
-    if "task_id" in val:
-        return OrbiterEmptyOperator(task_id=val["task_id"])
-    else:
-        return None


 def task_common_args(val: dict) -> dict:
     """
     Common mappings for all tasks
     """
-    task_id: str = (
-        jq.compile(""".Property[] | select(.["@Name"] == "Name") | .["#text"]""")
-        .input_value(val)
-        .first()
-    )
-    task_id = inflection.underscore(task_id)
+    try:
+        task_id: str = (
+            jq.compile(""".Property[] | select(.["@Name"] == "Name") | .["#text"]""")
+            .input_value(val)
+            .first()
+        )
+    except ValueError:
+        task_id = "UNKNOWN"
     params = {"task_id": task_id}
     return params
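For reference, the `Name` lookup in `task_common_args` in isolation, fed a record shaped like the export in the module docstring (same jq call chain the file already uses):

```python
import jq

# Record shape borrowed from the DataStage export in the module docstring.
record = {"Property": [{"@Name": "Name", "#text": "SELECT_TABLE"}]}

task_id = (
    jq.compile('.Property[] | select(.["@Name"] == "Name") | .["#text"]')
    .input_value(record)
    .first()
)
print(task_id)  # SELECT_TABLE
```

If no `Property` matches, `.first()` has nothing to yield, which is the no-match case the `except` branch above backstops with `"UNKNOWN"`.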


-def extract_sql_statements(root):
-    sql_statements = {}
-    sql_tags = ["SelectStatement", "BeforeSQL", "AfterSQL"]
+@task_rule(priority=2)
+def _cannot_map_rule(val: dict) -> OrbiterOperator | OrbiterTaskGroup | None:
+    """Translate input into an Operator (e.g. `OrbiterBashOperator`). will be applied first, with a higher priority"""
+    return OrbiterEmptyOperator(**task_common_args(val))


+def extract_sql_statements(root: dict) -> list[str]:
+    """Find SQL Statements deeply nested
-    for tag in sql_tags:
-        elements = root.findall(f".//{tag}")
-        for elem in elements:
-            if elem.text:
-                sql_text = elem.text.strip()
-                sql_statements[tag] = sql_text
-    return sql_statements
+    Looks for text of `SelectStatement`, `BeforeSQL`, and `AfterSQL` tags
+    ```pycon
+    >>> extract_sql_statements({
+    ...     "BeforeSQL": [{"#text": ""}],
+    ...     "SelectStatement": [{"#text": "SELECT 1 as Dummy from db.schema.table Limit 1;"}],
+    ...     "AfterSQL": [{"#text": "SELECT * FROM db.schema.table;"}]
+    ... })
+    ['SELECT 1 as Dummy from db.schema.table Limit 1;', 'SELECT * FROM db.schema.table;']
+    >>> extract_sql_statements({"a": {"b": {"c": {"d": {"SelectStatement": [{"#text": "SELECT 1;"}]}}}}})
+    ['SELECT 1;']
+    ```
+    """
+    if root:
+        return [
+            sql.strip()
+            for tag in [
+                "BeforeSQL",
+                "SelectStatement",
+                "AfterSQL"
+            ]
+            for sql in jq.all(f"""recurse | select(.{tag}?) | .{tag}[]["#text"]""", root)
+            if sql and sql.strip()
+        ] or None
+    raise ValueError("No SQL Statements found")
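The rewritten extraction hinges on one jq idiom: `recurse` visits every node in the tree, `select(.TAG?)` keeps only objects carrying the tag (the trailing `?` suppresses type errors on strings and arrays), and `.TAG[]["#text"]` yields the statement bodies. A standalone check of that idiom:

```python
import jq

# recurse walks the root, the nested dicts, the list, and the leaf string;
# only the object that owns "SelectStatement" survives the select().
nested = {"a": {"b": {"SelectStatement": [{"#text": "SELECT 1;"}]}}}
program = 'recurse | select(.SelectStatement?) | .SelectStatement[]["#text"]'
print(jq.all(program, nested))  # ['SELECT 1;']
```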


 @task_rule(priority=2)
 def sql_command_rule(val) -> OrbiterSQLExecuteQueryOperator | None:
     """
-    For SQLQueryOperator.
+    Create a SQL Operator with one or more SQL Statements
+    ```pycon
+    >>> sql_command_rule({
+    ...     'Property': [{'@Name': 'Name', '#text': 'SELECT_TABLE'}],
+    ...     '@Identifier': 'V253S0',
+    ...     "Collection": [{"SubRecord": [{"Property": [{
+    ...         "@PreFormatted": "1",
+    ...         "#text": {"Properties": [{
+    ...             "Usage": [{"SQL": [{"SelectStatement": [{"#text": "SELECT 1;"}]}]}],
+    ...         }]}
+    ...     }]}]}]
+    ... })  # doctest: +ELLIPSIS
+    select_table_task = SQLExecuteQueryOperator(task_id='select_table', conn_id='DB', sql='SELECT 1;')
+    ```
+    """  # noqa: E501
     try:
-        sql: str = (
-            jq.compile(
-                """.Collection[] | .SubRecord[] | .Property[] | select(.["@PreFormatted"] == "1") | .["#text"] """
-            )
-            .input_value(val)
-            .first()
-        )
-        root = ElementTree.fromstring(sql.encode("utf-16"))
-        sql_statements = extract_sql_statements(root)
-        sql = " ".join(sql_statements.values())
-        if sql:
+        if sql := jq.first(
+            """.Collection[] | .SubRecord[] | .Property[] | select(.["@PreFormatted"] == "1") | .["#text"]""",
+            val
+        ):
             return OrbiterSQLExecuteQueryOperator(
-                sql=sql,
-                **conn_id(conn_id="snowflake_default", conn_type="snowflake"),
+                sql=stmt[0] if len(stmt := extract_sql_statements(sql)) == 1 else stmt,
+                **conn_id(conn_id="DB"),
                 **task_common_args(val),
             )
-    except StopIteration:
-        pass
+    except (StopIteration, ValueError) as e:
+        logger.debug(f"[WARNING] No SQL found in {val}, {e}")
     return None
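The inline walrus in `sql=...` both extracts the statements and picks the return shape, implementing the "return sql as list, if multiple" fix from the commit message. The same logic, unrolled:

```python
def pick_sql(stmts: list[str]) -> str | list[str]:
    # One statement -> bare string; several -> the whole list.
    # SQLExecuteQueryOperator's sql parameter accepts either form.
    return stmts[0] if len(stmts) == 1 else stmts


print(pick_sql(["SELECT 1;"]))                      # SELECT 1;
print(pick_sql(["SELECT 1;", "SELECT * FROM t;"]))  # ['SELECT 1;', 'SELECT * FROM t;']
```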
@@ -149,9 +290,7 @@ def basic_task_dependency_rule(val: OrbiterDAG) -> list | None:
     dag_filter_ruleset=DAGFilterRuleset(ruleset=[basic_dag_filter]),
     dag_ruleset=DAGRuleset(ruleset=[basic_dag_rule]),
     task_filter_ruleset=TaskFilterRuleset(ruleset=[basic_task_filter]),
-    task_ruleset=TaskRuleset(
-        ruleset=[sql_command_rule, basic_task_rule, cannot_map_rule]
-    ),
+    task_ruleset=TaskRuleset(ruleset=[sql_command_rule, _cannot_map_rule, cannot_map_rule]),
     task_dependency_ruleset=TaskDependencyRuleset(ruleset=[basic_task_dependency_rule]),
     post_processing_ruleset=PostProcessingRuleset(ruleset=[]),
 )
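For reference, the assembled ruleset is exercised the same way the module docstring shows; a usage sketch, with the import path assumed rather than taken from this diff:

```python
# Hypothetical module path; translation_ruleset.test() parses the input and
# runs the full translation, returning a project whose .dags maps
# dag_id -> OrbiterDAG (as the doctest at the top of the file shows).
from orbiter_translations.data_stage.xml_demo import translation_ruleset

with open("DataStage_Job_export.xml") as f:  # any DSExport XML dump
    result = translation_ruleset.test(input_value=f.read())

print(list(result.dags))  # e.g. ['data_stage_job']
```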