Skip to content

Commit

Permalink
Include context_str in string returned by _get_table_context (run-lla…
Browse files Browse the repository at this point in the history
  • Loading branch information
richardguinness authored Jul 13, 2023
1 parent baf2eca commit ba7df3f
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
### Bug Fixes / Nits
- chore: added a help message to makefile (#6861)

### Bug Fixes / Nits
- Fixed support for using SQLTableSchema context_str attribute (#6891)

## [v0.7.6] - 2023-07-12

### New Features
Expand Down
6 changes: 5 additions & 1 deletion docs/end_to_end_tutorials/structured_data/sql_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ construct natural language queries that are synthesized into SQL queries.
Note that we need to specify the tables we want to use with this query engine.
If we don't the query engine will pull all the schema context, which could
overflow the context window of the LLM.

```python
query_engine = NLSQLTableQueryEngine(
sql_database=sql_database,
Expand All @@ -80,6 +81,7 @@ query_str = (
)
response = query_engine.query(query_str)
```

This query engine should used in any case where you can specify the tables you want
to query over beforehand, or the total size of all the table schema plus the rest of
the prompt fits your context window.
Expand All @@ -103,13 +105,15 @@ obj_index = ObjectIndex.from_objects(
VectorStoreIndex,
)
```

Here you can see we define our table_node_mapping, and a single SQLTableSchema with the
"city_stats" table name. We pass these into the ObjectIndex constructor, along with the
VectorStoreIndex class definition we want to use. This will give us a VectorStoreIndex where
each Node contains table schema and other context information. You can also add any additional
context information you'd like.

```python
# manually set context text
# manually set extra context text
city_stats_text = (
"This table gives information regarding the population and country of a given city.\n"
"The user will query with codewords, where 'foo' corresponds to population and 'bar'"
Expand Down
21 changes: 19 additions & 2 deletions llama_index/indices/struct_store/sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,25 @@ def _get_table_context(self, query_bundle: QueryBundle) -> str:
else:
raise ValueError(f"Unknown table type: {table}")
table_info = self._sql_database.get_single_table_info(table_str)

if self._context_query_kwargs.get(table_str, None) is not None:
table_opt_context = " The table description is: "
table_opt_context += self._context_query_kwargs[table_str]
table_info += table_opt_context

context_strs.append(table_info)

else:
# get all tables
table_names = self._sql_database.get_usable_table_names()
for table_name in table_names:
table_info = self._sql_database.get_single_table_info(table_name)

if self._context_query_kwargs.get(table_name, None) is not None:
table_opt_context = " The table description is: "
table_opt_context += self._context_query_kwargs[table_name]
table_info += table_opt_context

context_strs.append(table_info)

tables_desc_str = "\n\n".join(context_strs)
Expand Down Expand Up @@ -399,16 +411,21 @@ def _get_table_context(self, query_bundle: QueryBundle) -> str:
Get tables schema + optional context as a single string.
"""
context_strs = []
if self._context_str_prefix is not None:
context_strs = [self._context_str_prefix]

# TODO: allow top-level context
table_schema_objs = self._table_retriever.retrieve(query_bundle)
context_strs = []
for table_schema_obj in table_schema_objs:
table_info = self._sql_database.get_single_table_info(
table_schema_obj.table_name
)

if table_schema_obj.context_str:
table_opt_context = " The table description is: "
table_opt_context += table_schema_obj.context_str
table_info += table_opt_context

context_strs.append(table_info)

tables_desc_str = "\n\n".join(context_strs)
Expand Down

0 comments on commit ba7df3f

Please sign in to comment.