Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Massively simplify join logic #729

Merged
merged 3 commits into from
Feb 17, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 91 additions & 100 deletions pydal/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,35 +626,51 @@ def _geoexpand(self, field, query_env):
field = field.st_astext()
return self.expand(field, query_env=query_env)

def _build_joins_for_select(self, tablenames, param):
if not isinstance(param, (tuple, list)):
param = [param]
def _build_joins_for_select(self, join_on_expr):
if not isinstance(join_on_expr, (tuple, list)):
join_on_expr = [join_on_expr]

implicit_joins = []
explicit_joins = []
tablemap = {}
for item in param:
if isinstance(item, Expression):
item = item.first

for t in join_on_expr:
if isinstance(t, Expression): # db.table.on(...)
explicit_joins.append(t)
item = t.first

# Previously, db.foo.with_alias("bar").on(db.baz.foo_id == db.foo.id)
# silently produced a CROSS JOIN, which is unintuitive.
# As a basic check, the (possibly aliased) table name must appear somewhere in the ON clause.
if item._tablename not in str(t.second):
raise ValueError(
f"In join, table is aliased as: `{t.first}`\n"
f"but the same table is mentioned without alias in ON clause: `{t.second}`"
)

elif hasattr(t, "_tablename"):
implicit_joins.append(t)
item = t
else:
raise ValueError(f"Cannot join with {t}")

key = item._tablename
if tablemap.get(key, item) is not item:
raise ValueError("Name conflict in table list: %s" % key)
tablemap[key] = item
join_tables = [t._tablename for t in param if not isinstance(t, Expression)]
join_on = [t for t in param if isinstance(t, Expression)]
tables_to_merge = {}
for t in join_on:
tables_to_merge = merge_tablemaps(tables_to_merge, self.tables(t))
join_on_tables = [t.first._tablename for t in join_on]
for t in join_on_tables:
if t in tables_to_merge:
tables_to_merge.pop(t)
important_tablenames = join_tables + join_on_tables + list(tables_to_merge)
excluded = [t for t in tablenames if t not in important_tablenames]

un_joined = {}
for t in explicit_joins:
un_join_i = self.tables(t)
un_join_i.pop(t.first._tablename, None)
un_joined = merge_tablemaps(un_joined, un_join_i)

tablemap = merge_tablemaps(tablemap, un_joined)

return (
join_tables,
join_on,
tables_to_merge,
join_on_tables,
important_tablenames,
excluded,
implicit_joins,
explicit_joins,
list(un_joined),
tablemap,
)

Expand Down Expand Up @@ -705,93 +721,68 @@ def _select_wcols(
if self.can_select_for_update is False and for_update is True:
raise SyntaxError("invalid select attribute: for_update")
#: build joins (inner, left outer) and table names
if join:
(
# FIXME? ijoin_tables is never used
ijoin_tables,
ijoin_on,
itables_to_merge,
ijoin_on_tables,
iimportant_tablenames,
iexcluded,
itablemap,
) = self._build_joins_for_select(tablemap, join)
tablemap = merge_tablemaps(tablemap, itables_to_merge)
tablemap = merge_tablemaps(tablemap, itablemap)
if left:
jointypes = [
(join, self.dialect.join),
(left, self.dialect.left_join),
]
joins = []
base_table = None
cross_join = []
for joinexpr, joinfunc in jointypes:
if not joinexpr:
continue
(
join_tables,
join_on,
tables_to_merge,
join_on_tables,
important_tablenames,
excluded,
jtablemap,
) = self._build_joins_for_select(tablemap, left)
tablemap = merge_tablemaps(tablemap, tables_to_merge)
tablemap = merge_tablemaps(tablemap, jtablemap)
implicit_joins,
explicit_joins,
not_joined,
join_tablemap,
) = self._build_joins_for_select(joinexpr)
tablemap = merge_tablemaps(tablemap, join_tablemap)

if len(not_joined) > 0:
item = not_joined.pop(0)
if base_table is None:
base_table = item

cross_join.extend(not_joined)
joins.append(
(
joinfunc,
implicit_joins,
explicit_joins,
)
)

if base_table is None:
base_table = query_tables[0]

current_scope = outer_scoped + list(tablemap)
query_env = dict(current_scope=current_scope, parent_scope=outer_scoped)
#: prepare columns and expand fields
colnames = [self._colexpand(x, query_env) for x in fields]
sql_fields = ", ".join(self._geoexpand(x, query_env) for x in fields)
table_alias = lambda name: tablemap[name].query_name(outer_scoped)[0]
if join and not left:
cross_joins = iexcluded + list(itables_to_merge)
tokens = [table_alias(cross_joins[0])]
tokens.extend(
[
self.dialect.cross_join(table_alias(t), query_env)
for t in cross_joins[1:]
]
)
tokens.extend([self.dialect.join(t, query_env) for t in ijoin_on])
sql_t = " ".join(tokens)
elif not join and left:
cross_joins = excluded + list(tables_to_merge)
tokens = [table_alias(cross_joins[0])]
tokens.extend(
[
self.dialect.cross_join(table_alias(t), query_env)
for t in cross_joins[1:]
]
)
# FIXME: WTF? This is not correct syntax at least on PostgreSQL
if join_tables:
tokens.append(
self.dialect.left_join(
",".join([table_alias(t) for t in join_tables]), query_env
)
)
tokens.extend([self.dialect.left_join(t, query_env) for t in join_on])
sql_t = " ".join(tokens)
elif join and left:
all_tables_in_query = set(
important_tablenames + iimportant_tablenames + query_tables
)
tables_in_joinon = set(join_on_tables + ijoin_on_tables)
tables_not_in_joinon = list(
all_tables_in_query.difference(tables_in_joinon)
)
tokens = [table_alias(tables_not_in_joinon[0])]
tokens.extend(
[
self.dialect.cross_join(table_alias(t), query_env)
for t in tables_not_in_joinon[1:]
]
)
tokens.extend([self.dialect.join(t, query_env) for t in ijoin_on])
# FIXME: WTF? This is not correct syntax at least on PostgreSQL
if join_tables:
tokens.append(
self.dialect.left_join(
",".join([table_alias(t) for t in join_tables]), query_env
)
)
tokens.extend([self.dialect.left_join(t, query_env) for t in join_on])

if len(joins) > 0:
tokens = [table_alias(base_table)]
tokens += [
self.dialect.cross_join(table_alias(t), query_env) for t in cross_join
]

for joinfunc, implicit, explicit in joins:
# TODO: joins without an ON condition (especially tables concatenated with a comma)
# are rarely supported, and are usually equivalent to a cross join.
# Would it be better to emit these as explicit cross joins as well?
# The current behaviour mirrors what was there before.
if len(implicit) > 0:
tokens += [
joinfunc(",".join(map(table_alias, implicit)), query_env)
]
tokens += [joinfunc(t, query_env) for t in explicit]
sql_t = " ".join(tokens)
else:
sql_t = ", ".join(table_alias(t) for t in query_tables)

#: expand query if needed
if query:
query = self.expand(query, query_env=query_env)
Expand Down