Add PERSIST parameter, :settings command and modify indexing logic. Closes #8. Closes #9. Closes #10
davidmezzetti committed Aug 9, 2024
1 parent 8a1c19a commit ab5803b
Showing 2 changed files with 95 additions and 25 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -110,10 +110,13 @@ The RAG application has a number of environment variables that can be set to con
| | | arm64 : [Mistral-7B-OpenOrca-GGUF](https://huggingface.co/TheBloke/Mistral-7B-OpenOrca-GGUF) |
| EMBEDDINGS | Sets the embeddings database path | [neuml/txtai-wikipedia-slim](https://huggingface.co/NeuML/txtai-wikipedia-slim) |
| DATA | Optionally sets the input data directory | None |
| PERSIST | Optionally persist embeddings index | None |
| TOPICSBATCH | Optionally batches topic LLM queries | None |

*Note: AWQ models are only supported on `x86-64` machines*

In the application, these settings can be shown by typing `:settings`.
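For illustration, the response is a small Markdown table built from the environment variables plus the current record count. With `EMBEDDINGS=neuml/arxiv` set and nothing else, it would look roughly like this (record count elided):

|Name|Value|
|----|-----|
|RECORD COUNT|...|
|EMBEDDINGS|neuml/arxiv|
|DATA|None|
|PERSIST|None|
|LLM|None|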

See the following examples for setting this configuration with the Docker container. When running within a Python virtual environment, simply set these as environment variables.

### Llama 3.1 8B
@@ -148,4 +151,11 @@ docker run -d --gpus=all -it -p 8501:8501 -e EMBEDDINGS=neuml/arxiv neuml/rag
docker run -d --gpus=all -it -p 8501:8501 -e DATA=/data/path -v local/path:/data/path neuml/rag
```

### Persist embeddings and cache models

```
docker run -d --gpus=all -it -p 8501:8501 -e DATA=/data/path -e EMBEDDINGS=/data/embeddings \
-e PERSIST=/data/embeddings -e HF_HOME=/data/modelcache -v localdata:/data neuml/rag
```
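In this example, `EMBEDDINGS` and `PERSIST` point at the same path, so the index saved after each update is what gets loaded on the next start. `HF_HOME` keeps downloaded models in the mounted volume so they also survive container restarts.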

See the documentation for the [LLM pipeline](https://neuml.github.io/txtai/pipeline/text/llm/) and [Embeddings](https://neuml.github.io/txtai/embeddings/) for more information.
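A persisted index is a standard txtai index, so it can also be opened directly with the txtai API. A minimal sketch, assuming the `/data/embeddings` path from the example above:

```
from txtai import Embeddings

# Open the index written via PERSIST
embeddings = Embeddings()
embeddings.load("/data/embeddings")

# Confirm content made it in
print(embeddings.count())
print(embeddings.search("machine learning", 1))
```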
110 changes: 85 additions & 25 deletions rag.py
@@ -260,7 +260,7 @@ def deduplicate(self, graph, threshold):
labels[node], topics[label] = label, node
else:
# Copy edges to primary node
logger.debug(f"DUPLICATE {label} - {topicnames[pid]}")
logger.debug(f"DUPLICATE NODE: {label} - {topicnames[pid]}")
edges = graph.edges(node)
if edges:
for target, attributes in graph.edges(node).items():
@@ -335,33 +335,33 @@ def load(self):
# Embeddings database path
database = os.environ.get("EMBEDDINGS", "neuml/txtai-wikipedia-slim")

# Create a new embeddings database if:
# - A data path is provided OR
# - The database path is None
if data or not database:
# Create empty embeddings database
embeddings = Embeddings(
autoid="uuid5",
path="intfloat/e5-large",
instructions={"query": "query: ", "data": "passage: "},
content=True,
graph={"approximate": False, "minscore": 0.7},
)

# Index data directory, if provided
if data:
embeddings.index(self.stream(data))

# Create LLM-generated topics
self.infertopics(embeddings, 0)

else:
# Load existing model
# Check for existing index
if database:
logger.debug(f"LOAD INDEX: {database}")
embeddings = Embeddings()
if os.path.exists(database):
if embeddings.exists(database):
embeddings.load(database)
else:
elif not os.path.isabs(database) and embeddings.exists(
cloud={"provider": "huggingface-hub", "container": database}
):
embeddings.load(provider="huggingface-hub", container=database)
else:
logger.debug(f"NO INDEX FOUND: {database}")
embeddings = None

# Default embeddings index if not found
embeddings = embeddings if embeddings else self.create()

# Add content from data directory, if provided
if data:
logger.debug(f"INDEX DATA: {data}")
embeddings.upsert(self.stream(data))

# Create LLM-generated topics
self.infertopics(embeddings, 0)

# Save embeddings, if necessary
self.persist(embeddings)

return embeddings
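In short, `load()` now resolves the index in this order: an existing local path, a Hugging Face Hub container (for relative paths), then a newly created empty index. Any `DATA` directory is then upserted on top and the result is persisted when `PERSIST` is set.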

@@ -382,6 +382,26 @@ def addurl(self, url):
# Create LLM-generated topics
self.infertopics(self.embeddings, start)

# Save embeddings, if necessary
self.persist(self.embeddings)

def create(self):
"""
Creates a new empty Embeddings index.

Returns:
    Embeddings
"""

# Create empty embeddings database
return Embeddings(
autoid="uuid5",
path="intfloat/e5-large",
instructions={"query": "query: ", "data": "passage: "},
content=True,
graph={"approximate": False, "minscore": 0.7},
)
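As a rough sketch of how an index built with this configuration behaves, ids are deterministic uuid5 values, content is stored alongside the vectors and the e5 query/passage prefixes are applied automatically (graph settings omitted for brevity; the sample document is made up):

```
from txtai import Embeddings

embeddings = Embeddings(
    autoid="uuid5",
    path="intfloat/e5-large",
    instructions={"query": "query: ", "data": "passage: "},
    content=True,
)

# Index a sample passage and run a semantic search
embeddings.index(["txtai is an all-in-one embeddings database"])
print(embeddings.search("what is txtai?", 1))
```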

def stream(self, data):
"""
Runs a textractor pipeline and streams extracted content from a data directory.
@@ -440,6 +460,19 @@ def infertopics(self, embeddings, start):
if batch:
self.topics(embeddings, batch)

def persist(self, embeddings):
"""
Saves an embeddings index if the PERSIST parameter is set.

Args:
    embeddings: embeddings to save
"""

persist = os.environ.get("PERSIST")
if persist:
logger.debug(f"SAVE INDEX: {persist}")
embeddings.save(persist)

def topics(self, embeddings, batch):
"""
Generates a batch of topics with an LLM. Topics are set directly on the embeddings
@@ -512,6 +545,28 @@ def instructions(self):

return instructions

def settings(self):
"""
Generates a message with current settings.

Returns:
    settings
"""

# Generate config settings rows
config = "\n".join(
f"|{name}|{os.environ.get(name)}|"
for name in ["EMBEDDINGS", "DATA", "PERSIST", "LLM"]
if name
)

return (
"The following is a table with the current settings.\n"
f"|Name|Value|\n"
f"|----|-----|\n"
f"|RECORD COUNT|{self.embeddings.count()}|\n"
) + config

def run(self):
"""
Runs a Streamlit application.
@@ -550,6 +605,11 @@ def run(self):
response = f"Added _{url}_ to index"
st.write(response)

# Show settings
elif question == ":settings":
response = self.settings()
st.write(response)

else:
# Check for Graph RAG
graph = GraphContext(self.embeddings, self.context)
