Skip to content

Commit

Permalink
example: dspy demo with compile
Browse files Browse the repository at this point in the history
  • Loading branch information
AntiKnot committed Jul 10, 2024
1 parent 1d0ff1f commit f47dcde
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 18 deletions.
43 changes: 37 additions & 6 deletions examples/dspy-demo/example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import os
from functools import partial
from dotenv import load_dotenv
import dspy
from dotenv import load_dotenv
from dspy.datasets import HotPotQA
from dspy.evaluate import Evaluate
from dspy.teleprompt import BootstrapFewShot
from sentence_transformers import SentenceTransformer
from tidb_vector.integrations import TiDBVectorClient
from utils import sentence_transformer_embedding_function, TidbRM, RAG

Expand All @@ -12,11 +16,11 @@
transformer = {
# The name or path of the sentence-transformers model to use.
"model": "sentence-transformers/multi-qa-mpnet-base-dot-v1",
# The dimension of the vector generated by the embedding model.
"embed_model_dims": 768,
}

embedding_function = partial(sentence_transformer_embedding_function, transformer["model"])
embed_model = SentenceTransformer(transformer["model"], trust_remote_code=True)
embed_model_dim = embed_model.get_sentence_embedding_dimension()
embedding_function = partial(sentence_transformer_embedding_function, embed_model)

# The configuration for the TiDBVectorClient.
tidb_vector_client = TiDBVectorClient(
Expand All @@ -27,7 +31,7 @@
# mysql+pymysql://<USER>:<PASSWORD>@<HOST>:4000/<DATABASE>?ssl_ca=<CA_PATH>&ssl_verify_cert=true&ssl_verify_identity=true
connection_string=os.environ.get('TIDB_DATABASE_URL'),
# The dimension of the vector generated by the embedding model.
vector_dimension=transformer["embed_model_dims"],
vector_dimension=embed_model_dim,
# Determine whether to recreate the table if it already exists.
drop_existing_table=True,
)
Expand All @@ -51,7 +55,7 @@

print("Embedding sample data...")
documents = []
for idx, passage in enumerate(sample_data.split('\n')):
for idx, passage in enumerate(sample_data.split('\n')[:3]):
embedding = embedding_function([passage])[0]
print(idx, passage[:10], embedding[:5])
if len(passage) == 0:
Expand Down Expand Up @@ -83,6 +87,33 @@

# Build the RAG program around the TiDB-backed retriever configured above.
rag = RAG(retriever_model)

# Load a tiny HotPotQA split: 2 training examples, 5 dev examples, no test set.
# (Small sizes keep this demo fast; enlarge for a meaningful evaluation.)
dataset = HotPotQA(train_seed=1, train_size=2, eval_seed=2023, dev_size=5, test_size=0)
# Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.
trainset = [x.with_inputs('question') for x in dataset.train]
devset = [x.with_inputs('question') for x in dataset.dev]

# Baseline: score the uncompiled RAG program on the dev set with exact-match.
metric = dspy.evaluate.answer_exact_match
evaluate_on_hotpotqa = Evaluate(devset=devset[:], display_progress=True, display_table=False)
score = evaluate_on_hotpotqa(rag, metric=metric)
print('rag:', score)


# Validation logic: check that the predicted answer is correct.
# Also check that the retrieved context does contain that answer.
def validate_context_and_answer(example, pred, trace=None):
    """Metric used when compiling the RAG program with BootstrapFewShot.

    A prediction is accepted only when DSPy's exact-match check passes AND
    its passage-match check passes (i.e. the retrieved context appears to
    contain the gold answer).

    Args:
        example: A labeled DSPy example supplying the gold answer.
        pred: The program's prediction for that example.
        trace: Optional execution trace; accepted for metric-API
            compatibility but not inspected here.

    Returns:
        The conjunction of the two checks — falsy when either fails.
    """
    em_ok = dspy.evaluate.answer_exact_match(example, pred)
    pm_ok = dspy.evaluate.answer_passage_match(example, pred)
    return em_ok and pm_ok


# Set up a basic teleprompter, which will compile our RAG program.
# It bootstraps few-shot demonstrations from trainset, keeping only those
# that pass validate_context_and_answer.
teleprompter = BootstrapFewShot(metric=validate_context_and_answer)

# Compile!
compiled_rag = teleprompter.compile(rag, trainset=trainset)
# Now compiled_rag is optimized and ready to answer your new question!
# Re-run the same dev-set evaluation so the score is directly comparable
# to the uncompiled baseline printed above.
score = evaluate_on_hotpotqa(compiled_rag, metric=metric)
print('compile_rag:', score)

if __name__ == '__main__':
print("Answering the question: 'who write At My Window'...")
Expand Down
24 changes: 12 additions & 12 deletions examples/dspy-demo/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from typing import List, Optional, Union
from dsp import dotdict
from typing import Union, List, Optional
import dspy
from dsp.utils import dotdict
from sentence_transformers import SentenceTransformer

from tidb_vector.integrations import TiDBVectorClient


def sentence_transformer_embedding_function(model: str, sentences: Union[str, List[str]]):
def sentence_transformer_embedding_function(
embed_model: SentenceTransformer,
sentences: Union[str, List[str]]
) -> List[float]:
"""
Generates vector embeddings for the given text using the sentence-transformers model.
Args:
model (str): The name or path of the sentence-transformers model to use.
embed_model (SentenceTransformer): The sentence-transformers model to use.
sentences (List[str]): A list of text sentences for which to generate embeddings.
Returns:
Expand All @@ -20,11 +22,11 @@ def sentence_transformer_embedding_function(model: str, sentences: Union[str, Li
Examples:
Below is a code snippet that shows how to use this function:
```python
embeddings = sentence_transformer_embedding_function("sentence-transformers/multi-qa-mpnet-base-dot-v1", ["Hello, world!"])
        embeddings = sentence_transformer_embedding_function(embed_model, ["Hello, world!"])
```
"""
embed_model = SentenceTransformer(model, trust_remote_code=True)
return embed_model.encode(sentences)

return embed_model.encode(sentences).tolist()


class TidbRM(dspy.Retrieve):
Expand Down Expand Up @@ -66,7 +68,7 @@ def __init__(self, tidb_vector_client: TiDBVectorClient, embedding_function: Opt
super().__init__(k)
self.tidb_vector_client = tidb_vector_client
self.embedding_function = embedding_function
self.k = k
self.top_k = k

def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None, **kwargs) -> dspy.Prediction:
"""
Expand All @@ -85,10 +87,8 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No
passages = self.retrieve("Hello, world!")
```
"""
if self.embedding_function is None:
raise ValueError("embedding_function is required to use TidbRM")

query_embeddings = self.embedding_function(query_or_queries)
k = k or self.top_k
tidb_vector_res = self.tidb_vector_client.query(query_vector=query_embeddings, k=k)
passages_scores = {}
for res in tidb_vector_res:
Expand Down

0 comments on commit f47dcde

Please sign in to comment.