Skip to content

Commit

Permalink
Attempt to use more specific transformer when executing remote entities (#3111)
Browse files Browse the repository at this point in the history

Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Feb 11, 2025
1 parent d732d6f commit 7f58efe
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 6 deletions.
22 changes: 22 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from flytekit.core.context_manager import FlyteContext
from flytekit.core.hash import HashMethod
from flytekit.core.type_helpers import load_type_from_tag
from flytekit.core.type_match_checking import literal_types_match
from flytekit.core.utils import load_proto_from_file, str2bool, timeit
from flytekit.exceptions import user as user_exceptions
from flytekit.interaction.string_literals import literal_map_string_repr
Expand Down Expand Up @@ -2415,6 +2416,27 @@ def dataclass_from_dict(cls: type, src: typing.Dict[str, typing.Any]) -> typing.
return cls(**constructor_inputs)


def strict_type_hint_matching(input_val: typing.Any, target_literal_type: LiteralType) -> typing.Type:
    """
    Pick the Python type hint for an input value by looking at the value itself.

    The transformer registered for ``type(input_val)`` is asked for its literal type. When that literal type
    matches ``target_literal_type`` (per ``literal_types_match``), ``type(input_val)`` is returned so that the more
    specific transformer is used. Otherwise a ValueError is raised and the caller is expected to fall back to
    guessing the Python type from the literal type.

    :param input_val: The raw Python value supplied for an input.
    :param target_literal_type: The literal type the interface declares for that input.
    :return: ``type(input_val)`` when its transformer's literal type matches the target.
    :raises ValueError: If the literal types do not match. A ValueError may also come from the transformer itself,
        e.g. for ``[1, 2, 3]`` where ``type()`` only yields the bare ``list``, which won't work.
    """
    native_type = type(input_val)
    inferred_literal_type = TypeEngine.get_transformer(native_type).get_literal_type(native_type)

    # note: if no good match, the transformer will be the pickle transformer, but the type will not match unless
    # the target is the pickle type, so the caller falls back to normal guessing.
    if not literal_types_match(inferred_literal_type, target_literal_type):
        raise ValueError(
            f"Transformer for {native_type} returned literal type {inferred_literal_type} which doesn't match {target_literal_type}"
        )

    return native_type


def _check_and_covert_float(lv: Literal) -> float:
if lv.scalar.primitive.float_value is not None:
return lv.scalar.primitive.float_value
Expand Down
54 changes: 54 additions & 0 deletions flytekit/core/type_match_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

from flytekit.models.core.types import EnumType
from flytekit.models.types import LiteralType, UnionType


def literal_types_match(downstream: LiteralType, upstream: LiteralType) -> bool:
    """
    Return whether two LiteralTypes describe the same type.

    Plain equality is tried first. Beyond that, collection and map types are compared by recursing into their
    element/value types, enum types are compared as unordered sets of values, and union types are compared as
    unordered collections of variants. Anything else is a mismatch.

    :param downstream: One of the two literal types being compared.
    :param upstream: The other literal type being compared.
    :return: True if the two types are considered the same, False otherwise.
    """
    # Fast path: if the types are exactly the same, there is nothing else to check.
    if downstream == upstream:
        return True

    # Collections match only if both sides are collections with matching element types.
    if downstream.collection_type:
        if not upstream.collection_type:
            return False
        return literal_types_match(downstream.collection_type, upstream.collection_type)

    # Maps match only if both sides are maps with matching value types.
    if downstream.map_value_type:
        if not upstream.map_value_type:
            return False
        return literal_types_match(downstream.map_value_type, upstream.map_value_type)

    # Enums are compared ignoring the order in which values are declared.
    if downstream.enum_type and upstream.enum_type:
        return _enum_types_match(downstream.enum_type, upstream.enum_type)

    # Unions are compared ignoring the order in which variants are declared.
    if downstream.union_type and upstream.union_type:
        return _union_types_match(downstream.union_type, upstream.union_type)

    # None of the relaxed comparisons applied, and the types were not equal.
    return False


def _enum_types_match(downstream: EnumType, upstream: EnumType) -> bool:
    """Return True when both enums declare the same set of values, regardless of order."""
    return set(upstream.values) == set(downstream.values)


def _union_types_match(downstream: UnionType, upstream: UnionType) -> bool:
    """
    Return True when the variants of both unions can be paired one-to-one with ``literal_types_match``.

    Variants are matched independently of their position. Sorting both sides by their string form and zipping
    them is not sufficient: two equivalent variants (e.g. enums listing the same values in a different order)
    can have different string forms and therefore end up at different positions after sorting.
    """
    if len(downstream.variants) != len(upstream.variants):
        return False

    # Each upstream variant may be consumed by at most one downstream variant.
    unmatched = list(upstream.variants)
    for downstream_variant in downstream.variants:
        for idx, upstream_variant in enumerate(unmatched):
            if literal_types_match(downstream_variant, upstream_variant):
                del unmatched[idx]
                break
        else:
            # No remaining upstream variant is equivalent to this downstream variant.
            return False

    return True
9 changes: 6 additions & 3 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from flytekit.core.reference_entity import ReferenceEntity, ReferenceSpec
from flytekit.core.task import ReferenceTask
from flytekit.core.tracker import extract_task_module
from flytekit.core.type_engine import LiteralsResolver, TypeEngine
from flytekit.core.type_engine import LiteralsResolver, TypeEngine, strict_type_hint_matching
from flytekit.core.workflow import PythonFunctionWorkflow, ReferenceWorkflow, WorkflowBase, WorkflowFailurePolicy
from flytekit.exceptions import user as user_exceptions
from flytekit.exceptions.user import (
Expand Down Expand Up @@ -1514,9 +1514,12 @@ def _execute(
else:
if k not in type_hints:
try:
type_hints[k] = TypeEngine.guess_python_type(input_flyte_type_map[k].type)
type_hints[k] = strict_type_hint_matching(v, input_flyte_type_map[k].type)
except ValueError:
logger.debug(f"Could not guess type for {input_flyte_type_map[k].type}, skipping...")
developer_logger.debug(
f"Could not guess type for {input_flyte_type_map[k].type}, skipping..."
)
type_hints[k] = TypeEngine.guess_python_type(input_flyte_type_map[k].type)
variable = entity.interface.inputs.get(k)
hint = type_hints[k]
self.file_access._get_upload_signed_url_fn = functools.partial(
Expand Down
47 changes: 44 additions & 3 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import asdict, dataclass, field
from datetime import timedelta
from enum import Enum, auto
from typing import List, Optional, Type
from typing import List, Optional, Type, Dict

import mock
import msgpack
Expand All @@ -34,6 +34,10 @@
from flytekit.core.data_persistence import flyte_tmp_dir
from flytekit.core.hash import HashMethod
from flytekit.core.type_engine import (
IntTransformer,
FloatTransformer,
BoolTransformer,
StrTransformer,
DataclassTransformer,
DictTransformer,
EnumTransformer,
Expand All @@ -48,9 +52,9 @@
convert_mashumaro_json_schema_to_python_class,
dataclass_from_dict,
get_underlying_type,
is_annotated, IntTransformer,
is_annotated,
strict_type_hint_matching,
)
from flytekit.core.type_engine import *
from flytekit.exceptions import user as user_exceptions
from flytekit.models import types as model_types
from flytekit.models.annotation import TypeAnnotation
Expand Down Expand Up @@ -3777,3 +3781,40 @@ class RegularDC:
assert TypeEngine.get_transformer(RegularDC) == TypeEngine._DATACLASS_TRANSFORMER

del TypeEngine._REGISTRY[ParentDC]


def test_strict_type_matching():
    """
    strict_type_hint_matching should return the more specific registered type (MyInt) for a MyInt value, whereas
    guessing from the integer literal type alone only yields ``int``.
    """

    class MyInt:
        def __init__(self, x: int):
            self.val = x

        def __eq__(self, other):
            if not isinstance(other, MyInt):
                return False
            return other.val == self.val

    lt = LiteralType(simple=SimpleType.INTEGER)
    TypeEngine.register(
        SimpleTransformer(
            "MyInt",
            MyInt,
            lt,
            lambda x: Literal(scalar=Scalar(primitive=Primitive(integer=x.val))),
            lambda x: MyInt(x.scalar.primitive.integer),
        )
    )

    try:
        pt_guess = IntTransformer.guess_python_type(lt)
        assert pt_guess is int
        pt_better_guess = strict_type_hint_matching(MyInt(3), lt)
        assert pt_better_guess is MyInt
    finally:
        # Always unregister, even when an assertion fails, so the MyInt transformer cannot leak into other tests.
        del TypeEngine._REGISTRY[MyInt]


def test_strict_type_matching_error():
    """Strict matching of a plain list value against a List[float] literal type must raise ValueError."""
    float_values: typing.List[float] = [0.1, 0.2, 0.3, 0.4, -99999.7]
    list_of_floats_lt = TypeEngine.to_literal_type(typing.List[float])

    with pytest.raises(ValueError):
        strict_type_hint_matching(float_values, list_of_floats_lt)
87 changes: 87 additions & 0 deletions tests/flytekit/unit/core/test_type_match_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from flytekit.models.core.types import BlobType, EnumType
from flytekit.models.types import LiteralType, StructuredDatasetType, UnionType, SimpleType
from flytekit.core.type_match_checking import literal_types_match


def test_exact_match():
    """A simple type matches itself and does not match a different simple type."""
    string_type = LiteralType(simple=SimpleType.STRING)
    float_type = LiteralType(simple=SimpleType.FLOAT)

    assert literal_types_match(string_type, string_type) is True
    assert literal_types_match(string_type, float_type) is False


def test_collection_type_match():
    """Two separately built collections of strings match."""
    first, second = (LiteralType(collection_type=LiteralType(SimpleType.STRING)) for _ in range(2))
    assert literal_types_match(first, second) is True


def test_collection_type_mismatch():
    """Collections with different element types do not match."""
    strings = LiteralType(collection_type=LiteralType(SimpleType.STRING))
    integers = LiteralType(collection_type=LiteralType(SimpleType.INTEGER))
    assert literal_types_match(strings, integers) is False


def test_blob_type_match():
    """Blobs with equal format and dimensionality match; a pickle-format blob does not match an empty-format one."""
    from flytekit.types.pickle.pickle import FlytePickleTransformer

    csv_a = LiteralType(blob=BlobType(format="csv", dimensionality=1))
    csv_b = LiteralType(blob=BlobType(format="csv", dimensionality=1))
    assert literal_types_match(csv_a, csv_b) is True

    pickle_blob = LiteralType(blob=BlobType(format=FlytePickleTransformer.PYTHON_PICKLE_FORMAT, dimensionality=1))
    formatless_blob = LiteralType(blob=BlobType(format="", dimensionality=1))
    assert literal_types_match(pickle_blob, formatless_blob) is False


def test_blob_type_mismatch():
    """Blobs that differ only in format do not match."""
    csv_blob = LiteralType(blob=BlobType(format="csv", dimensionality=1))
    json_blob = LiteralType(blob=BlobType(format="json", dimensionality=1))
    assert literal_types_match(csv_blob, json_blob) is False


def test_enum_type_match():
    """Enums with the same values in a different order match."""
    forward = LiteralType(enum_type=EnumType(values=["A", "B"]))
    backward = LiteralType(enum_type=EnumType(values=["B", "A"]))
    assert literal_types_match(forward, backward) is True


def test_enum_type_mismatch():
    """Enums with different value sets do not match."""
    a_b = LiteralType(enum_type=EnumType(values=["A", "B"]))
    a_c = LiteralType(enum_type=EnumType(values=["A", "C"]))
    assert literal_types_match(a_b, a_c) is False


def test_structured_dataset_match():
    """Parquet structured datasets match only when their column lists are equal."""
    string_col = StructuredDatasetType.DatasetColumn(name="col1", literal_type=LiteralType(simple=SimpleType.STRING))
    struct_col = StructuredDatasetType.DatasetColumn(name="col2", literal_type=LiteralType(simple=SimpleType.STRUCT))

    def parquet(columns):
        # Build a parquet structured dataset literal type with the given columns.
        return LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=columns))

    # No columns on either side.
    assert literal_types_match(parquet([]), parquet([])) is True

    # Columns on one side only.
    assert literal_types_match(parquet([string_col, struct_col]), parquet([])) is False

    # Same columns on both sides.
    assert literal_types_match(parquet([string_col, struct_col]), parquet([string_col, struct_col])) is True


def test_structured_dataset_mismatch():
    """Structured datasets that differ in format do not match."""
    parquet_ds = LiteralType(structured_dataset_type=StructuredDatasetType(format="parquet", columns=[]))
    csv_ds = LiteralType(structured_dataset_type=StructuredDatasetType(format="csv", columns=[]))
    assert literal_types_match(parquet_ds, csv_ds) is False


def test_union_type_match():
    """Unions with the same variants in a different order match."""
    str_then_int = UnionType(variants=[LiteralType(SimpleType.STRING), LiteralType(SimpleType.INTEGER)])
    int_then_str = UnionType(variants=[LiteralType(SimpleType.INTEGER), LiteralType(SimpleType.STRING)])
    assert literal_types_match(LiteralType(union_type=str_then_int), LiteralType(union_type=int_then_str)) is True


def test_union_type_mismatch():
    """Unions that differ in one variant do not match."""
    str_or_int = UnionType(variants=[LiteralType(SimpleType.STRING), LiteralType(SimpleType.INTEGER)])
    str_or_bool = UnionType(variants=[LiteralType(SimpleType.STRING), LiteralType(SimpleType.BOOLEAN)])
    assert literal_types_match(LiteralType(union_type=str_or_int), LiteralType(union_type=str_or_bool)) is False

0 comments on commit 7f58efe

Please sign in to comment.