Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attempt to use more specific transformer when executing remote entities #3111

Merged
merged 8 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions flytekit/core/type_match_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

from flytekit.models.core.types import EnumType
from flytekit.models.types import LiteralType, UnionType


def literal_types_match(downstream: LiteralType, upstream: LiteralType) -> bool:
    """
    Return whether two LiteralTypes describe the same type.

    Exact equality is tried first. Otherwise the comparison recurses into collection
    element types and map value types, and compares enums and unions without regard
    to the order of their values / variants. Anything else is treated as a mismatch.

    :param downstream: one of the two literal types being compared.
    :param upstream: the other literal type being compared.
    :return: True if the two types match, False otherwise.
    """
    # Fast path: identical types always match.
    if downstream == upstream:
        return True

    # Collections match when their element types match.
    if downstream.collection_type:
        if not upstream.collection_type:
            return False
        return literal_types_match(downstream.collection_type, upstream.collection_type)

    # Maps match when their value types match.
    if downstream.map_value_type:
        if not upstream.map_value_type:
            return False
        return literal_types_match(downstream.map_value_type, upstream.map_value_type)

    # Enums match regardless of the order in which their values are declared.
    if downstream.enum_type and upstream.enum_type:
        return _enum_types_match(downstream.enum_type, upstream.enum_type)

    # Unions match regardless of the order in which their variants are declared.
    if downstream.union_type and upstream.union_type:
        return _union_types_match(downstream.union_type, upstream.union_type)

    # None of the relaxed comparisons applied, so the types do not match.
    return False


def _enum_types_match(downstream: EnumType, upstream: EnumType) -> bool:
    """Return True when both enums declare the same set of values, ignoring order."""
    return set(upstream.values) == set(downstream.values)


def _union_types_match(downstream: UnionType, upstream: UnionType) -> bool:
    """
    Return True when the variants of both unions can be paired one-to-one.

    Variants are paired by searching for a matching partner rather than by sorting
    on ``str()``: the string form of a variant changes when, for example, a nested
    enum lists its values in a different order, so a sort-based alignment can zip
    equivalent variants against the wrong partner and report a false mismatch.
    """
    if len(downstream.variants) != len(upstream.variants):
        return False

    # Upstream variants not yet claimed by a downstream variant.
    unmatched = list(upstream.variants)
    for downstream_variant in downstream.variants:
        for idx, upstream_variant in enumerate(unmatched):
            if literal_types_match(downstream_variant, upstream_variant):
                # Consume the partner so that it cannot be matched a second time.
                del unmatched[idx]
                break
        else:
            # No remaining upstream variant matches this downstream variant.
            return False

    return True
16 changes: 15 additions & 1 deletion flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from flytekit.core.task import ReferenceTask
from flytekit.core.tracker import extract_task_module
from flytekit.core.type_engine import LiteralsResolver, TypeEngine
from flytekit.core.type_match_checking import literal_types_match
from flytekit.core.workflow import PythonFunctionWorkflow, ReferenceWorkflow, WorkflowBase, WorkflowFailurePolicy
from flytekit.exceptions import user as user_exceptions
from flytekit.exceptions.user import (
Expand Down Expand Up @@ -251,6 +252,19 @@ def _get_pickled_target_dict(
return md5_bytes, pickled_target_dict


def better_guess_type_hint(input_val: typing.Any, target_literal_type: type_models.LiteralType) -> typing.Type:
    """
    Pick the Python type hint to use for an input value.

    The transformer registered for ``type(input_val)`` is asked for a literal type; when that
    literal type matches ``target_literal_type``, the concrete ``type(input_val)`` is used as
    the hint. Otherwise the hint is guessed from ``target_literal_type`` alone.

    :param input_val: the Python value supplied for the input.
    :param target_literal_type: the literal type the input is expected to have.
    :return: ``type(input_val)`` on a match, else the result of ``TypeEngine.guess_python_type``.
    """
    value_type = type(input_val)
    # NOTE(review): the value itself, not its type, is handed to get_literal_type here;
    # confirm this is intended for every transformer.
    candidate_literal_type = TypeEngine.get_transformer(value_type).get_literal_type(input_val)
    if not literal_types_match(candidate_literal_type, target_literal_type):
        return TypeEngine.guess_python_type(target_literal_type)
    return value_type
wild-endeavor marked this conversation as resolved.
Show resolved Hide resolved


class FlyteRemote(object):
"""Main entrypoint for programmatically accessing a Flyte remote backend.

Expand Down Expand Up @@ -1514,7 +1528,7 @@ def _execute(
else:
if k not in type_hints:
try:
type_hints[k] = TypeEngine.guess_python_type(input_flyte_type_map[k].type)
type_hints[k] = better_guess_type_hint(v, input_flyte_type_map[k].type)
except ValueError:
logger.debug(f"Could not guess type for {input_flyte_type_map[k].type}, skipping...")
variable = entity.interface.inputs.get(k)
Expand Down
83 changes: 83 additions & 0 deletions tests/flytekit/unit/core/test_type_match_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import pytest
from flytekit.models.core.types import BlobType, EnumType
from flytekit.models.types import LiteralType, StructuredDatasetType, UnionType, SimpleType
from flytekit.core.type_match_checking import literal_types_match


def test_exact_match():
    """A type matches itself, and two different simple types do not match."""
    string_type = LiteralType(simple=SimpleType.STRING)
    float_type = LiteralType(simple=SimpleType.FLOAT)

    assert literal_types_match(string_type, string_type) is True
    assert literal_types_match(string_type, float_type) is False


def test_collection_type_match():
    """Collection types with the same element type match."""
    left = LiteralType(collection_type=LiteralType(SimpleType.STRING))
    right = LiteralType(collection_type=LiteralType(SimpleType.STRING))

    matched = literal_types_match(left, right)
    assert matched is True


def test_collection_type_mismatch():
    """Collection types with different element types do not match."""
    strings = LiteralType(collection_type=LiteralType(SimpleType.STRING))
    integers = LiteralType(collection_type=LiteralType(SimpleType.INTEGER))

    matched = literal_types_match(strings, integers)
    assert matched is False


def test_blob_type_match():
    """Blob types with the same format and dimensionality match."""
    left = LiteralType(blob=BlobType(format="csv", dimensionality=1))
    right = LiteralType(blob=BlobType(format="csv", dimensionality=1))

    matched = literal_types_match(left, right)
    assert matched is True


def test_blob_type_mismatch():
    """Blob types that differ in format do not match."""
    csv_blob = LiteralType(blob=BlobType(format="csv", dimensionality=1))
    json_blob = LiteralType(blob=BlobType(format="json", dimensionality=1))

    matched = literal_types_match(csv_blob, json_blob)
    assert matched is False


def test_enum_type_match():
    """Enums with the same values in a different order match."""
    forward = LiteralType(enum_type=EnumType(values=["A", "B"]))
    reverse = LiteralType(enum_type=EnumType(values=["B", "A"]))

    matched = literal_types_match(forward, reverse)
    assert matched is True


def test_enum_type_mismatch():
    """Enums with different value sets do not match."""
    a_b = LiteralType(enum_type=EnumType(values=["A", "B"]))
    a_c = LiteralType(enum_type=EnumType(values=["A", "C"]))

    matched = literal_types_match(a_b, a_c)
    assert matched is False


def test_structured_dataset_match():
    """Parquet structured datasets match only when their column lists agree."""
    string_col = StructuredDatasetType.DatasetColumn(name="col1", literal_type=LiteralType(simple=SimpleType.STRING))
    struct_col = StructuredDatasetType.DatasetColumn(name="col2", literal_type=LiteralType(simple=SimpleType.STRUCT))

    def parquet(columns):
        # Build a parquet structured dataset literal type with the given columns.
        return LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=columns))

    # Both sides without columns.
    assert literal_types_match(parquet([]), parquet([])) is True

    # Columns on one side only.
    assert literal_types_match(parquet([string_col, struct_col]), parquet([])) is False

    # Identical columns on both sides.
    assert literal_types_match(parquet([string_col, struct_col]), parquet([string_col, struct_col])) is True


def test_structured_dataset_mismatch():
    """Structured datasets that differ in format do not match."""
    parquet_ds = LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=[]))
    csv_ds = LiteralType(structured_dataset_type=StructuredDatasetType(format="csv", columns=[]))

    matched = literal_types_match(parquet_ds, csv_ds)
    assert matched is False


def test_union_type_match():
    """Unions with the same variants in a different order match."""
    str_then_int = UnionType(variants=[LiteralType(SimpleType.STRING), LiteralType(SimpleType.INTEGER)])
    int_then_str = UnionType(variants=[LiteralType(SimpleType.INTEGER), LiteralType(SimpleType.STRING)])
    left = LiteralType(union_type=str_then_int)
    right = LiteralType(union_type=int_then_str)

    assert literal_types_match(left, right) is True


def test_union_type_mismatch():
    """Unions whose variant sets differ do not match."""
    str_or_int = UnionType(variants=[LiteralType(SimpleType.STRING), LiteralType(SimpleType.INTEGER)])
    str_or_bool = UnionType(variants=[LiteralType(SimpleType.STRING), LiteralType(SimpleType.BOOLEAN)])
    left = LiteralType(union_type=str_or_int)
    right = LiteralType(union_type=str_or_bool)

    assert literal_types_match(left, right) is False
Loading