-
Notifications
You must be signed in to change notification settings - Fork 310
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Attempt to use more specific transformer when executing remote entiti…
…es (#3111) Signed-off-by: Yee Hing Tong <[email protected]>
- Loading branch information
1 parent
d732d6f
commit 7f58efe
Showing
5 changed files
with
213 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from __future__ import annotations | ||
|
||
from flytekit.models.core.types import EnumType | ||
from flytekit.models.types import LiteralType, UnionType | ||
|
||
|
||
def literal_types_match(downstream: LiteralType, upstream: LiteralType) -> bool: | ||
""" | ||
Returns if two LiteralTypes are the same. | ||
Takes into account arbitrary ordering of enums and unions, otherwise just an equivalence check. | ||
""" | ||
|
||
# If the types are exactly the same, return True | ||
if downstream == upstream: | ||
return True | ||
|
||
if downstream.collection_type: | ||
if not upstream.collection_type: | ||
return False | ||
return literal_types_match(downstream.collection_type, upstream.collection_type) | ||
|
||
if downstream.map_value_type: | ||
if not upstream.map_value_type: | ||
return False | ||
return literal_types_match(downstream.map_value_type, upstream.map_value_type) | ||
|
||
# Handle enum types | ||
if downstream.enum_type and upstream.enum_type: | ||
return _enum_types_match(downstream.enum_type, upstream.enum_type) | ||
|
||
# Handle union types | ||
if downstream.union_type and upstream.union_type: | ||
return _union_types_match(downstream.union_type, upstream.union_type) | ||
|
||
# If none of the above conditions are met, the types are not castable | ||
return False | ||
|
||
|
||
def _enum_types_match(downstream: EnumType, upstream: EnumType) -> bool: | ||
return set(upstream.values) == set(downstream.values) | ||
|
||
|
||
def _union_types_match(downstream: UnionType, upstream: UnionType) -> bool: | ||
if len(downstream.variants) != len(upstream.variants): | ||
return False | ||
|
||
down_sorted = sorted(downstream.variants, key=lambda x: str(x)) | ||
up_sorted = sorted(upstream.variants, key=lambda x: str(x)) | ||
|
||
for downstream_variant, upstream_variant in zip(down_sorted, up_sorted): | ||
if not literal_types_match(downstream_variant, upstream_variant): | ||
return False | ||
|
||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from flytekit.models.core.types import BlobType, EnumType | ||
from flytekit.models.types import LiteralType, StructuredDatasetType, UnionType, SimpleType | ||
from flytekit.core.type_match_checking import literal_types_match | ||
|
||
|
||
def test_exact_match(): | ||
lt = LiteralType(simple=SimpleType.STRING) | ||
assert literal_types_match(lt, lt) is True | ||
|
||
lt2 = LiteralType(simple=SimpleType.FLOAT) | ||
assert literal_types_match(lt, lt2) is False | ||
|
||
|
||
def test_collection_type_match(): | ||
lt1 = LiteralType(collection_type=LiteralType(SimpleType.STRING)) | ||
lt2 = LiteralType(collection_type=LiteralType(SimpleType.STRING)) | ||
assert literal_types_match(lt1, lt2) is True | ||
|
||
|
||
def test_collection_type_mismatch(): | ||
lt1 = LiteralType(collection_type=LiteralType(SimpleType.STRING)) | ||
lt2 = LiteralType(collection_type=LiteralType(SimpleType.INTEGER)) | ||
assert literal_types_match(lt1, lt2) is False | ||
|
||
|
||
def test_blob_type_match(): | ||
blob1 = LiteralType(blob=BlobType(format="csv", dimensionality=1)) | ||
blob2 = LiteralType(blob=BlobType(format="csv", dimensionality=1)) | ||
assert literal_types_match(blob1, blob2) is True | ||
|
||
from flytekit.types.pickle.pickle import FlytePickleTransformer | ||
blob1 = LiteralType(blob=BlobType(format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT, dimensionality=1)) | ||
blob2 = LiteralType(blob=BlobType(format="", dimensionality=1)) | ||
assert literal_types_match(blob1, blob2) is False | ||
|
||
|
||
def test_blob_type_mismatch(): | ||
blob1 = LiteralType(blob=BlobType(format="csv", dimensionality=1)) | ||
blob2 = LiteralType(blob=BlobType(format="json", dimensionality=1)) | ||
assert literal_types_match(blob1, blob2) is False | ||
|
||
|
||
def test_enum_type_match(): | ||
enum1 = LiteralType(enum_type=EnumType(values=["A", "B"])) | ||
enum2 = LiteralType(enum_type=EnumType(values=["B", "A"])) | ||
assert literal_types_match(enum1, enum2) is True | ||
|
||
|
||
def test_enum_type_mismatch(): | ||
enum1 = LiteralType(enum_type=EnumType(values=["A", "B"])) | ||
enum2 = LiteralType(enum_type=EnumType(values=["A", "C"])) | ||
assert literal_types_match(enum1, enum2) is False | ||
|
||
|
||
def test_structured_dataset_match(): | ||
col1 = StructuredDatasetType.DatasetColumn(name="col1", literal_type=LiteralType(simple=SimpleType.STRING)) | ||
col2 = StructuredDatasetType.DatasetColumn(name="col2", literal_type=LiteralType(simple=SimpleType.STRUCT)) | ||
|
||
dataset1 = LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=[])) | ||
dataset2 = LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=[])) | ||
assert literal_types_match(dataset1, dataset2) is True | ||
|
||
dataset1 = LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=[col1, col2])) | ||
dataset2 = LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=[])) | ||
assert literal_types_match(dataset1, dataset2) is False | ||
|
||
dataset1 = LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=[col1, col2])) | ||
dataset2 = LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=[col1, col2])) | ||
assert literal_types_match(dataset1, dataset2) is True | ||
|
||
|
||
def test_structured_dataset_mismatch(): | ||
dataset1 = LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=[])) | ||
dataset2 = LiteralType(structured_dataset_type=StructuredDatasetType(format="csv", columns=[])) | ||
assert literal_types_match(dataset1, dataset2) is False | ||
|
||
|
||
def test_union_type_match(): | ||
union1 = LiteralType(union_type=UnionType(variants=[LiteralType(SimpleType.STRING), LiteralType(SimpleType.INTEGER)])) | ||
union2 = LiteralType(union_type=UnionType(variants=[LiteralType(SimpleType.INTEGER), LiteralType(SimpleType.STRING)])) | ||
assert literal_types_match(union1, union2) is True | ||
|
||
|
||
def test_union_type_mismatch(): | ||
union1 = LiteralType(union_type=UnionType(variants=[LiteralType(SimpleType.STRING), LiteralType(SimpleType.INTEGER)])) | ||
union2 = LiteralType(union_type=UnionType(variants=[LiteralType(SimpleType.STRING), LiteralType(SimpleType.BOOLEAN)])) | ||
assert literal_types_match(union1, union2) is False |