diff --git a/CHANGELOG.md b/CHANGELOG.md index abe050506bf69..38795e41a199a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,9 @@ ### Bug Fixes / Nits - chore: added a help message to makefile (#6861) +### Bug Fixes / Nits +- Fixed support for using SQLTableSchema context_str attribute (#6891) + ## [v0.7.6] - 2023-07-12 ### New Features diff --git a/docs/end_to_end_tutorials/structured_data/sql_guide.md b/docs/end_to_end_tutorials/structured_data/sql_guide.md index 3a011cd1bc8ca..985d3202c070a 100644 --- a/docs/end_to_end_tutorials/structured_data/sql_guide.md +++ b/docs/end_to_end_tutorials/structured_data/sql_guide.md @@ -70,6 +70,7 @@ construct natural language queries that are synthesized into SQL queries. Note that we need to specify the tables we want to use with this query engine. If we don't the query engine will pull all the schema context, which could overflow the context window of the LLM. + ```python query_engine = NLSQLTableQueryEngine( sql_database=sql_database, @@ -80,6 +81,7 @@ query_str = ( ) response = query_engine.query(query_str) ``` + This query engine should used in any case where you can specify the tables you want to query over beforehand, or the total size of all the table schema plus the rest of the prompt fits your context window. @@ -103,13 +105,15 @@ obj_index = ObjectIndex.from_objects( VectorStoreIndex, ) ``` + Here you can see we define our table_node_mapping, and a single SQLTableSchema with the "city_stats" table name. We pass these into the ObjectIndex constructor, along with the VectorStoreIndex class definition we want to use. This will give us a VectorStoreIndex where each Node contains table schema and other context information. You can also add any additional context information you'd like. + ```python -# manually set context text +# manually set extra context text city_stats_text = ( "This table gives information regarding the population and country of a given city.\n" "The user will query with codewords, where 'foo' corresponds to population and 'bar'" diff --git a/llama_index/indices/struct_store/sql_query.py b/llama_index/indices/struct_store/sql_query.py index afda693291c08..53cae3b1830bb 100644 --- a/llama_index/indices/struct_store/sql_query.py +++ b/llama_index/indices/struct_store/sql_query.py @@ -352,6 +352,12 @@ def _get_table_context(self, query_bundle: QueryBundle) -> str: else: raise ValueError(f"Unknown table type: {table}") table_info = self._sql_database.get_single_table_info(table_str) + + if self._context_query_kwargs.get(table_str, None) is not None: + table_opt_context = " The table description is: " + table_opt_context += self._context_query_kwargs[table_str] + table_info += table_opt_context + context_strs.append(table_info) else: @@ -359,6 +365,12 @@ def _get_table_context(self, query_bundle: QueryBundle) -> str: table_names = self._sql_database.get_usable_table_names() for table_name in table_names: table_info = self._sql_database.get_single_table_info(table_name) + + if self._context_query_kwargs.get(table_name, None) is not None: + table_opt_context = " The table description is: " + table_opt_context += self._context_query_kwargs[table_name] + table_info += table_opt_context + context_strs.append(table_info) tables_desc_str = "\n\n".join(context_strs) @@ -399,16 +411,21 @@ def _get_table_context(self, query_bundle: QueryBundle) -> str: Get tables schema + optional context as a single string. """ + context_strs = [] if self._context_str_prefix is not None: context_strs = [self._context_str_prefix] - # TODO: allow top-level context table_schema_objs = self._table_retriever.retrieve(query_bundle) - context_strs = [] for table_schema_obj in table_schema_objs: table_info = self._sql_database.get_single_table_info( table_schema_obj.table_name ) + + if table_schema_obj.context_str: + table_opt_context = " The table description is: " + table_opt_context += table_schema_obj.context_str + table_info += table_opt_context + context_strs.append(table_info) tables_desc_str = "\n\n".join(context_strs)