diff --git a/README.md b/README.md
index 44222ce..00b6374 100644
--- a/README.md
+++ b/README.md
@@ -158,7 +158,7 @@ pip install st-supabase-connection
 1. Import
    ```python
-   from st_supabase_connection import SupabaseConnection
+   from st_supabase_connection import SupabaseConnection, execute_query
    ```
 2. Initialize
    ```python
@@ -170,7 +170,7 @@ pip install st-supabase-connection
        key="YOUR_SUPABASE_KEY", # not needed if provided as a streamlit secret
    )
    ```
-3. Use in your app to query tables and files. Happy Streamlit-ing! :balloon:
+3. Use in your app to query tables and files, and add authentication. Happy Streamlit-ing! :balloon:
 
 ## :ok_hand: Supported methods
@@ -197,7 +197,7 @@ pip install st-supabase-connection
   Database
@@ -300,7 +300,7 @@ SyncBucket(id='new_bucket', name='new_bucket', owner='', public=True, created_at
 ### :file_cabinet: Database operations
 #### Simple query
 ```python
->>> st_supabase_client.query("*", table="countries", ttl=0).execute()
+>>> execute_query(st_supabase_client.table("countries").select("*"), ttl=0)
 APIResponse(
     data=[
         {"id": 1, "name": "Afghanistan"},
@@ -312,20 +312,25 @@ APIResponse(
 ```
 #### Query with join
 ```python
->>> st_supabase_client.query("name, teams(name)", table="users", count="exact", ttl="1h").execute()
+>>> execute_query(
+        st_supabase_client.table("users").select("name, teams(name)", count="exact"),
+        ttl="1h",
+    )
+
 APIResponse(
     data=[
         {"name": "Kiran", "teams": [{"name": "Green"}, {"name": "Blue"}]},
         {"name": "Evan", "teams": [{"name": "Blue"}]},
     ],
-    count=None,
+    count=2,
 )
 ```
 #### Filter through foreign tables
 ```python
->>> st_supabase_client.query("name, countries(*)", count="exact", table="cities", ttl=None).eq(
-        "countries.name", "Curaçao"
-    ).execute()
+>>> execute_query(
+        st_supabase_client.table("cities").select("name, countries(*)", count="exact").eq("countries.name", "Curaçao"),
+        ttl=None,
+    )
 
 APIResponse(
     data=[
@@ -348,9 +353,13 @@
 
 #### Insert rows
 ```python
->>> st_supabase_client.table("countries").insert(
-        [{"name": "Wakanda", "iso2": "WK"}, {"name": "Wadiya", "iso2": "WD"}], count="None"
-    ).execute()
+>>> execute_query(
+        st_supabase_client.table("countries").insert(
+            [{"name": "Wakanda", "iso2": "WK"}, {"name": "Wadiya", "iso2": "WD"}], count=None
+        ),
+        ttl=0,
+    )
+
 APIResponse(
     data=[
         {
diff --git a/demo/app.py b/demo/app.py
index 25d2265..de5c448 100644
--- a/demo/app.py
+++ b/demo/app.py
@@ -4,7 +4,7 @@ import streamlit as st
 
 from st_social_media_links import SocialMediaIcons
 
-from st_supabase_connection import SupabaseConnection, __version__
+from st_supabase_connection import SupabaseConnection, __version__, execute_query
 
 VERSION = __version__
 
@@ -53,7 +53,7 @@
 )
 
 if st.button(
-    "Clear the cache to fetch latest data🧹",
+    "Clear cache to fetch latest data🧹",
     use_container_width=True,
     type="primary",
 ):
@@ -298,11 +298,11 @@
         )
         ttl = None if ttl == "" else ttl
         constructed_storage_query = f"""st_supabase.{operation}("{bucket_id}", {ttl=})"""
-        st.session_state["storage_disabled"] = False if bucket_id else True
+        st.session_state["storage_disabled"] = not bucket_id
 
     elif operation in ["delete_bucket", "empty_bucket"]:
         constructed_storage_query = f"""st_supabase.{operation}("{bucket_id}")"""
-        st.session_state["storage_disabled"] = False if bucket_id else True
+        st.session_state["storage_disabled"] = not bucket_id
 
     elif operation == "create_bucket":
         col1, col2, col3, col4 = st.columns(4)
@@ -335,7 +335,7 @@
         )
 
         constructed_storage_query = f"""st_supabase.create_bucket('{bucket_id}',{name=},{file_size_limit=},allowed_mime_types={allowed_mime_types},{public=})"""
-        st.session_state["storage_disabled"] = False if bucket_id else True
+        st.session_state["storage_disabled"] = not bucket_id
 
     elif operation == "update_bucket":
         if bucket_id:
@@ -412,7 +412,7 @@
         constructed_storage_query = f"""
         st_supabase.{operation}("{bucket_id}", {source=}, {file=}, destination_path="{destination_path}", {overwrite=})
         """
-        st.session_state["storage_disabled"] = False if all([bucket_id, file]) else True
+        st.session_state["storage_disabled"] = not all([bucket_id, file])
 
     elif operation == "list_buckets":
         ttl = st.text_input(
             "Results cache duration",
@@ -441,7 +441,7 @@
         constructed_storage_query = (
f"""st_supabase.{operation}("{bucket_id}", {source_path=}, {ttl=})""" ) - st.session_state["storage_disabled"] = False if all([bucket_id, source_path]) else True + st.session_state["storage_disabled"] = bool(not all([bucket_id, source_path])) elif operation == "move": from_path = st.text_input( @@ -457,9 +457,7 @@ f"""st_supabase.{operation}("{bucket_id}", {from_path=}, {to_path=})""" ) - st.session_state["storage_disabled"] = ( - False if all([bucket_id, from_path, to_path]) else True - ) + st.session_state["storage_disabled"] = bool(not all([bucket_id, from_path, to_path])) elif operation == "remove": paths = st.text_input( "Enter the paths of the objects in the bucket to remove", @@ -468,7 +466,7 @@ ) constructed_storage_query = f"""st_supabase.{operation}("{bucket_id}", paths={paths})""" - st.session_state["storage_disabled"] = False if all([bucket_id, paths]) else True + st.session_state["storage_disabled"] = bool(not all([bucket_id, paths])) elif operation == "list_objects": lcol, rcol = st.columns([3, 1]) path = lcol.text_input( @@ -797,7 +795,7 @@ placeholder="countries", ) - lcol, mcol, rcol = st.columns(3) + lcol, mcol, rcol = st.columns([2, 2, 3]) request_builder = lcol.selectbox( "Select the query type", options=["select", "insert", "upsert", "update", "delete"], @@ -820,33 +818,59 @@ placeholder = ( value ) = """[{"name":"Wakanda","iso2":"WK"},{"name":"Wadiya","iso2":"WD"}]""" - upsert = rcol_placeholder.checkbox( + rcol1, rcol2 = rcol_placeholder.columns(2) + ttl = rcol1.text_input( + "Cache duration", + value=0, + placeholder=0, + help="Set as `0` to always fetch the latest results (recommended for DML), or leave blank to cache indefinitely.", + ) + upsert = rcol2.checkbox( label="Upsert", help="Whether the query should be an upsert", ) elif request_builder == "select": request_builder_query_label = "Enter the columns to fetch as comma-separated strings" - placeholder = value = "*" ttl = rcol_placeholder.text_input( "Result cache duration", - value=0, + value=None, placeholder=None, - help="Set as `0` to always fetch the latest results, and leave blank to cache indefinitely.", + help="Set as `0` to always fetch the latest results, or leave blank to cache indefinitely.", ) placeholder = value = "*" elif request_builder == "delete": request_builder_query_label = "Delete query" placeholder = value = "Delete does not take a request builder query" + ttl = rcol_placeholder.text_input( + "Results Cache duration", + value=0, + placeholder=0, + help="Set as `0` to always fetch the latest results (recommended for DML), or leave blank to cache indefinitely.", + ) elif request_builder == "upsert": request_builder_query_label = "Enter the rows to upsert as json (for single row) or array of jsons (for multiple rows)" placeholder = value = """{"name":"Wakanda","iso2":"WK", "continent":"Africa"}""" - ignore_duplicates = rcol_placeholder.checkbox( + rcol1, rcol2 = rcol_placeholder.columns(2) + ttl = rcol1.text_input( + "Cache duration", + value=0, + placeholder=0, + help="Set as `0` to always fetch the latest results (recommended for DML), or leave blank to cache indefinitely.", + ) + ignore_duplicates = rcol2.checkbox( label="Ignore duplicates", help="Whether duplicate rows should be ignored", ) elif request_builder == "update": request_builder_query_label = "Enter the rows to update as json (for single row) or array of jsons (for multiple rows)" placeholder = value = """{"iso3":"N/A","continent":"N/A"}""" + ttl = rcol_placeholder.text_input( + "Result cache duration", + value=0, + 
+            placeholder=0,
+            help="Set as `0` to always fetch the latest results (recommended for DML), or leave blank to cache indefinitely.",
+        )
+
     request_builder_query = st.text_input(
         label=request_builder_query_label,
         placeholder=placeholder,
@@ -854,7 +878,6 @@
         help="[RequestBuilder API reference](https://postgrest-py.readthedocs.io/en/latest/api/request_builders.html#postgrest.AsyncRequestBuilder)",
         disabled=request_builder == "delete",
     )
-
     if request_builder == "upsert" and not ignore_duplicates:
         on_conflict = st.text_input(
             label="Enter the columns to be considered UNIQUE in case of conflicts as comma-separated values",
@@ -876,7 +899,7 @@
 
     if request_builder not in ["insert", "update", "upsert"]:
         operators = st.text_input(
-            label="Chain any operators and filters you want 🔗",
+            label="Chain any modifiers and filters you want 🔗",
             value=""".eq("continent","Asia").order("name",desc=True).limit(5)""",
             placeholder=""".eq("continent","Asia").order("name",desc=True).limit(5)""",
             help="List of all available [operators](https://postgrest-py.readthedocs.io/en/latest/api/request_builders.html#postgrest.AsyncSelectRequestBuilder) and [filters](https://postgrest-py.readthedocs.io/en/latest/api/filters.html#postgrest.AsyncFilterRequestBuilder)",
@@ -887,18 +910,16 @@
     ttl = None if ttl == "" else ttl
 
     if operators:
-        if request_builder == "select":
-            constructed_db_query = f"""st_supabase.query({request_builder_query}, {table=}, {ttl=}){operators}.execute()"""
-        else:
-            constructed_db_query = f"""st_supabase.table("{table}").{request_builder}({request_builder_query}){operators}.execute()"""
+        constructed_db_query = (
+            f"""execute_query(st_supabase.table("{table}").select({request_builder_query}){operators}, {ttl=})"""
+            if request_builder == "select"
+            else f"""execute_query(st_supabase.table("{table}").{request_builder}({request_builder_query}){operators}, {ttl=})"""
+        )
+    elif request_builder == "select":
+        constructed_db_query = f"""execute_query(st_supabase.table("{table}").select({request_builder_query}), {ttl=})"""
     else:
-        if request_builder == "select":
-            constructed_db_query = (
-                f"""st_supabase.query({request_builder_query}, {table=}, {ttl=}).execute()"""
-            )
-        else:
-            constructed_db_query = f"""st_supabase.table("{table}").{request_builder}({request_builder_query}).execute()"""
-    st.write("**Constructed code**")
+        constructed_db_query = f"""execute_query(st_supabase.table("{table}").{request_builder}({request_builder_query}), {ttl=})"""
+    st.write("**Constructed query**")
     st.code(constructed_db_query)
 
     lcol, rcol = st.columns([2, 1])
@@ -909,7 +930,7 @@
     )
 
     if rcol.button(
-        "Run query 🏃",
+        "Execute query 🏃",
         use_container_width=True,
         type="primary",
         disabled=st.session_state["project"] == "demo"
@@ -921,16 +942,16 @@
         key="run_db_query",
     ):
         try:
-            data, count = eval(constructed_db_query)
+            response = eval(constructed_db_query)
 
             if count_method:
                 st.write(
-                    f"**{count[-1]}** rows {request_builder}ed. `count` does not take `limit` into account."
+                    f"**{response.count}** rows {request_builder.rstrip('e')}ed. `count` does not take `limit` into account."
                 )
             if view == "Dataframe":
-                st.dataframe(data[-1], use_container_width=True)
+                st.dataframe(response.data, use_container_width=True)
             else:
-                st.write(data[-1])
+                st.write(response.data)
         except ValueError:
             if count_method == "planned":
                 st.error(
@@ -1076,13 +1097,10 @@
             else:
                 raise Exception("No logged-in user session. Log in or sign up first.")
-        elif auth_operation == "sign_out":
-            auth_success_message = None
-
         if auth_success_message:
             st.success(auth_success_message)
 
-        if response != None:
+        if response is not None:
             with st.expander("JSON response"):
                 st.write(response.dict())
diff --git a/demo/requirements.txt b/demo/requirements.txt
index 7c837e9..f0ebfbd 100644
--- a/demo/requirements.txt
+++ b/demo/requirements.txt
@@ -1,2 +1,3 @@
 st-social-media-links
 st_supabase_connection
+streamlit<1.34 # TODO: Update app to remove components.v1 usage
diff --git a/src/st_supabase_connection/__init__.py b/src/st_supabase_connection/__init__.py
index ed30613..98b5603 100644
--- a/src/st_supabase_connection/__init__.py
+++ b/src/st_supabase_connection/__init__.py
@@ -4,14 +4,19 @@
 from datetime import timedelta
 from io import BytesIO
 from pathlib import Path
-from typing import Literal, Optional, Tuple, Union, types
-
-from postgrest import SyncSelectRequestBuilder, types
+from typing import Literal, Optional, Tuple, Union
+
+from postgrest import (
+    APIResponse,
+    SyncFilterRequestBuilder,
+    SyncQueryRequestBuilder,
+    SyncSelectRequestBuilder,
+)
 from streamlit import cache_data, cache_resource
 from streamlit.connections import BaseConnection
 from supabase import Client, create_client
 
-__version__ = "1.2.2"
+__version__ = "2.0.0"
 
 
 class SupabaseConnection(BaseConnection[Client]):
@@ -66,33 +71,6 @@ def _connect(self, **kwargs) -> None:
         self.delete_bucket = self.client.storage.delete_bucket
         self.empty_bucket = self.client.storage.empty_bucket
 
-    def query(
-        self,
-        *columns: str,
-        table: str,
-        count: Optional[types.CountMethod] = None,
-        ttl: Optional[Union[float, timedelta, str]] = None,
-    ) -> SyncSelectRequestBuilder:
-        """Run a SELECT query.
-
-        Parameters
-        ----------
-        *columns : str
-            The names of the columns to fetch.
-        table : str
-            The table to run the query on.
-        count : str
-            The method to use to get the count of rows returned. Defaults to `None`.
-        ttl : float, timedelta, str, or None
-            The maximum time to keep an entry in the cache. Defaults to `None` (cache never expires).
-        """
-
-        @cache_resource(ttl=ttl)
-        def _query(_self, *columns, table, count):
-            return _self.client.table(table).select(*columns, count=count)
-
-        return _query(self, *columns, table=table, count=count)
-
     def get_bucket(
         self,
         bucket_id: str,
@@ -501,3 +479,38 @@ def upload_to_signed_url(
         )
 
         return response.json()
+
+
+def execute_query(
+    query: Union[SyncSelectRequestBuilder, SyncQueryRequestBuilder, SyncFilterRequestBuilder],
+    ttl: Optional[Union[float, timedelta, str]] = None,
+) -> APIResponse:
+    """Execute the query with caching.
+
+    This function is a wrapper around `query.execute()` with caching enabled.
+    It works with all types of queries, but caching may lead to unexpected results for DML queries
+    (insert, update, upsert, delete). It is therefore recommended to set `ttl=0` for DML queries.
+
+    Parameters
+    ----------
+    query : SyncSelectRequestBuilder, SyncQueryRequestBuilder, SyncFilterRequestBuilder
+        The query to execute. Can contain any number of chained filters and operators.
+    ttl : float, timedelta, str, or None
+        The maximum time to keep an entry in the cache. Defaults to `None` (cache never expires).
+    """
+
+    def _hash_func(x):
+        # Cache key: the request path plus query params. The request body (json)
+        # is not hashed, so identical paths/params map to the same cache entry.
+        return hash(x.path + str(x.params))
+
+    @cache_resource(
+        ttl=ttl,
+        hash_funcs={
+            SyncSelectRequestBuilder: _hash_func,
+            SyncQueryRequestBuilder: _hash_func,
+            SyncFilterRequestBuilder: _hash_func,
+        },
+    )
+    def _execute(query):
+        # The actual network call, cached by Streamlit until `ttl` expires
+        return query.execute()
+
+    return _execute(query)
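Taken together, these changes replace the removed `SupabaseConnection.query()` helper with the standalone `execute_query()` wrapper. A minimal usage sketch of the new v2.0.0 API, mirroring the README examples (the connection name, URL/key placeholders, and the `countries` table are illustrative):

```python
import streamlit as st
from st_supabase_connection import SupabaseConnection, execute_query

# Initialize the connection (url/key may instead come from Streamlit secrets)
st_supabase = st.connection(
    name="supabase_connection",  # illustrative connection name
    type=SupabaseConnection,
    url="YOUR_SUPABASE_URL",
    key="YOUR_SUPABASE_KEY",
)

# v1.x (removed): st_supabase.query("*", table="countries", ttl=0).execute()
# v2.0.0: build the query on the native client, then run it through the cached wrapper
response = execute_query(
    st_supabase.table("countries").select("*"),
    ttl=0,  # 0 = always fetch fresh results; recommended for DML queries
)
st.dataframe(response.data)
```

Because the cache key hashes only the request path and query params, identical SELECTs within the `ttl` window return the cached `APIResponse` without a second network round trip; DML queries should pass `ttl=0` so they always execute.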