From 972163f8ea4b36e172f74a4c04dbb197723d50f1 Mon Sep 17 00:00:00 2001 From: Arun Sharma Date: Sat, 18 Jan 2025 10:36:39 -0800 Subject: [PATCH] sqlmodel: several bugfixes These were found by testing a more comprehensive use case: https://github.com/adsharma/fastapi-shopping/blob/main/models.py --- fquery/sqlmodel.py | 50 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/fquery/sqlmodel.py b/fquery/sqlmodel.py index 88daf43..2e24d13 100644 --- a/fquery/sqlmodel.py +++ b/fquery/sqlmodel.py @@ -64,7 +64,7 @@ def many_to_one(key_column=None, back_populates=None): if key_column is not None: ret.metadata["SQL"]["key_column"] = key_column if back_populates is not None: - ret.metadata["SQL"][back_populates] = back_populates + ret.metadata["SQL"]["back_populates"] = back_populates return ret @@ -134,8 +134,13 @@ def get_field_type(field, cls): type_class = field.type other_class = type_class.__args__[0] if has_many_to_one_relationship: - type_class = get_type_hints(cls)[field.name] - return Optional[other_class.__sqlmodel__] + try: + type_class = get_type_hints(cls)[field.name] + except NameError: + # TODO: log exception? + pass + else: + return Optional[other_class.__sqlmodel__] return field.type def patch_back_populates_types(field, back_populates, cls, sqlmodel_cls): @@ -145,7 +150,16 @@ def patch_back_populates_types(field, back_populates, cls, sqlmodel_cls): if has_relationship: if has_many_to_one_relationship: type_class = field.type - other_class = type_class.__args__[0].__sqlmodel__ + try: + type_class = get_type_hints(cls)[field.name] + except NameError: + # TODO: log exception? + pass + inner = type_class.__args__[0] + if isinstance(inner, ForwardRef): + # can't patch right now. Try at a later time via back_populates + return + other_class = inner.__sqlmodel__ old = other_class.__annotations__[back_populates] # Should be sqlalchemy.orm.base.Mapped[typing.List[ForwardRef('T')]] # replace it with Mapped[List[sqlmodel_cls]] @@ -156,17 +170,23 @@ def patch_back_populates_types(field, back_populates, cls, sqlmodel_cls): List[sqlmodel_cls] ] other_class.sqlmodel_rebuild() - else: - # Replace Optional['T'] with Optional[TSQLModel] - old = field.type - origin = get_origin(old) - inner = get_args(old) - if ( - origin == Union - and len(inner) - and inner[0] == ForwardRef(cls.__name__) - ): - sqlmodel_cls.__annotations__[field.name] = Optional[sqlmodel_cls] + + # Replace Optional['T'] with Optional[TSQLModel] + old = field.type + origin = get_origin(old) + inner = get_args(old) + needs_rebuild = False + if origin == Union and len(inner) and inner[0] == ForwardRef(cls.__name__): + sqlmodel_cls.__annotations__[field.name] = Optional[sqlmodel_cls] + needs_rebuild = True + + # Replace Optional[T] with Optional[TSQLModel] if T is a dataclass + if origin == Union and len(inner) and is_dataclass(inner[0]): + sqlmodel_cls.__annotations__[field.name] = Optional[inner[0].__sqlmodel__] + needs_rebuild = True + + if needs_rebuild: + sqlmodel_cls.sqlmodel_rebuild() def default_table_name(clsname: str) -> str: return inflection.underscore(inflection.pluralize(clsname))