Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multiturn history #7851

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dspy/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def __init_subclass__(cls, **kwargs) -> None:
cls.format = with_callbacks(cls.format)
cls.parse = with_callbacks(cls.parse)

def __call__(self, lm, lm_kwargs, signature, demos, inputs):
inputs_ = self.format(signature, demos, inputs)
def __call__(self, lm, lm_kwargs, signature, demos, inputs, conversation_history=None):
inputs_ = self.format(signature, demos, inputs, conversation_history)
inputs_ = dict(prompt=inputs_) if isinstance(inputs_, str) else dict(messages=inputs_)

outputs = lm(**inputs_, **lm_kwargs)
Expand Down
40 changes: 30 additions & 10 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@ class FieldInfoWithName(NamedTuple):


class ChatAdapter(Adapter):
def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
def format(
self,
signature: Signature,
demos: list[dict[str, Any]],
inputs: dict[str, Any],
conversation_history: list[dict[str, Any]],
) -> list[dict[str, Any]]:
messages: list[dict[str, Any]] = []

# Extract demos where some of the output_fields are not filled in.
Expand All @@ -46,12 +52,19 @@ def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict
]

demos = incomplete_demos + complete_demos
conversation_history = conversation_history or []

prepared_instructions = prepare_instructions(signature)
messages.append({"role": "system", "content": prepared_instructions})

# Add the few-shot examples
for demo in demos:
messages.append(format_turn(signature, demo, role="user", incomplete=demo in incomplete_demos))
messages.append(format_turn(signature, demo, role="assistant", incomplete=demo in incomplete_demos))
# Add the chat history after few-shot examples
for message in conversation_history:
messages.append(format_turn(signature, message, role="user"))
messages.append(format_turn(signature, message, role="assistant"))

messages.append(format_turn(signature, inputs, role="user"))
messages = try_expand_image_tags(messages)
Expand Down Expand Up @@ -129,7 +142,7 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
return "\n\n".join(output).strip()


def format_turn(signature, values, role, incomplete=False):
def format_turn(signature, values, role, incomplete=False, is_conversation_history=False):
"""
Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted
so that it can instruct an LLM to generate responses conforming to the specified DSPy signature.
Expand All @@ -140,22 +153,29 @@ def format_turn(signature, values, role, incomplete=False):
that should be included in the message.
role: The role of the message, which can be either "user" or "assistant".
incomplete: If True, indicates that output field values are present in the set of specified
``values``. If False, indicates that ``values`` only contains input field values.
`values`. If False, indicates that `values` only contains input field values. Only used if
`is_conversation_history` is False.
is_conversation_history: If True, indicates that the message is part of the chat history, otherwise
it is a demo (few-shot example).

Returns:
A chat message that can be appended to a chat thread. The message contains two string fields:
``role`` ("user" or "assistant") and ``content`` (the message text).
`role` ("user" or "assistant") and `content` (the message text).
"""
if role == "user":
fields = signature.input_fields
message_prefix = (
"This is an example of the task, though some input or output fields are not supplied." if incomplete else ""
)
if incomplete and not is_conversation_history:
message_prefix = "This is an example of the task, though some input or output fields are not supplied."
else:
message_prefix = ""
else:
# Add the completed field for the assistant turn
fields = {**signature.output_fields, BuiltInCompletedOutputFieldInfo.name: BuiltInCompletedOutputFieldInfo.info}
values = {**values, BuiltInCompletedOutputFieldInfo.name: ""}
# Add the completed field or chat history for the assistant turn
fields = {**signature.output_fields}
values = {**values}
message_prefix = ""
if not is_conversation_history:
fields.update({BuiltInCompletedOutputFieldInfo.name: BuiltInCompletedOutputFieldInfo.info})
values.update({BuiltInCompletedOutputFieldInfo.name: ""})

if not incomplete and not set(values).issuperset(fields.keys()):
raise ValueError(f"Expected {fields.keys()} but got {values.keys()}")
Expand Down
47 changes: 32 additions & 15 deletions dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from dspy.adapters.base import Adapter
from dspy.adapters.image_utils import Image
from dspy.adapters.utils import parse_value, format_field_value, get_annotation_name, serialize_for_json
from dspy.adapters.utils import format_field_value, get_annotation_name, parse_value, serialize_for_json
from dspy.signatures.signature import SignatureMeta
from dspy.signatures.utils import get_dspy_field_type

Expand All @@ -30,8 +30,8 @@ class JSONAdapter(Adapter):
def __init__(self):
pass

def __call__(self, lm, lm_kwargs, signature, demos, inputs):
inputs = self.format(signature, demos, inputs)
def __call__(self, lm, lm_kwargs, signature, demos, inputs, conversation_history=None):
inputs = self.format(signature, demos, inputs, conversation_history)
inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs)

try:
Expand Down Expand Up @@ -65,7 +65,7 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs):

return values

def format(self, signature, demos, inputs):
def format(self, signature, demos, inputs, conversation_history=None):
messages = []

# Extract demos where some of the output_fields are not filled in.
Expand All @@ -78,13 +78,19 @@ def format(self, signature, demos, inputs):
]

demos = incomplete_demos + complete_demos
conversation_history = conversation_history or []

messages.append({"role": "system", "content": prepare_instructions(signature)})

for demo in demos:
messages.append(format_turn(signature, demo, role="user", incomplete=demo in incomplete_demos))
messages.append(format_turn(signature, demo, role="assistant", incomplete=demo in incomplete_demos))

# Add the chat history after few-shot examples
for message in conversation_history:
messages.append(format_turn(signature, message, role="user"))
messages.append(format_turn(signature, message, role="assistant"))

messages.append(format_turn(signature, inputs, role="user"))

return messages
Expand Down Expand Up @@ -163,7 +169,13 @@ def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -
return "\n\n".join(output).strip()


def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomplete=False) -> Dict[str, str]:
def format_turn(
signature: SignatureMeta,
values: Dict[str, Any],
role,
incomplete=False,
is_conversation_history=False,
) -> Dict[str, str]:
"""
Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted
so that it can instruct an LLM to generate responses conforming to the specified DSPy signature.
Expand All @@ -174,7 +186,10 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple
that should be included in the message.
role: The role of the message, which can be either "user" or "assistant".
incomplete: If True, indicates that output field values are present in the set of specified
``values``. If False, indicates that ``values`` only contains input field values.
`values`. If False, indicates that `values` only contains input field values. Only
relevant if `is_conversation_history` is False.
is_conversation_history: If True, indicates that the message is part of a chat history instead of a
few-shot example.
Returns:
A chat message that can be appended to a chat thread. The message contains two string fields:
``role`` ("user" or "assistant") and ``content`` (the message text).
Expand All @@ -183,25 +198,27 @@ def format_turn(signature: SignatureMeta, values: Dict[str, Any], role, incomple

if role == "user":
fields: Dict[str, FieldInfo] = signature.input_fields
if incomplete:
if incomplete and not is_conversation_history:
content.append("This is an example of the task, though some input or output fields are not supplied.")
else:
fields: Dict[str, FieldInfo] = signature.output_fields

if not incomplete:
if not incomplete and not is_conversation_history:
# For complete few-shot examples, ensure that the values contain all the fields.
field_names: KeysView = fields.keys()
if not set(values).issuperset(set(field_names)):
raise ValueError(f"Expected {field_names} but got {values.keys()}")

formatted_fields = format_fields(
role=role,
fields_with_values={
FieldInfoWithName(name=field_name, info=field_info): values.get(
fields_with_values = {}
for field_name, field_info in fields.items():
if is_conversation_history:
fields_with_values[FieldInfoWithName(name=field_name, info=field_info)] = values.get(field_name, None)
else:
fields_with_values[FieldInfoWithName(name=field_name, info=field_info)] = values.get(
field_name, "Not supplied for this particular example."
)
for field_name, field_info in fields.items()
},
)

formatted_fields = format_fields(role=role, fields_with_values=fields_with_values)
content.append(formatted_fields)

if role == "user":
Expand Down
16 changes: 13 additions & 3 deletions dspy/predict/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ def load_state(self, state):
# `excluded_keys` are fields that go through special handling.
if name not in excluded_keys:
setattr(self, name, value)

self.signature = self.signature.load_state(state["signature"])

if "extended_signature" in state: # legacy, up to and including 2.5, for CoT.
if "extended_signature" in state: # legacy, up to and including 2.5, for CoT.
raise NotImplementedError("Loading extended_signature is no longer supported in DSPy 2.6+")

return self
Expand All @@ -73,6 +73,7 @@ def forward(self, **kwargs):
assert "new_signature" not in kwargs, "new_signature is no longer a valid keyword argument."
signature = ensure_signature(kwargs.pop("signature", self.signature))
demos = kwargs.pop("demos", self.demos)
conversation_history = kwargs.pop("conversation_history", None)
config = dict(**self.config, **kwargs.pop("config", {}))

# Get the right LM to use.
Expand All @@ -93,8 +94,16 @@ def forward(self, **kwargs):
print(f"WARNING: Not all input fields were provided to module. Present: {present}. Missing: {missing}.")

import dspy

adapter = dspy.settings.adapter or dspy.ChatAdapter()
completions = adapter(lm, lm_kwargs=config, signature=signature, demos=demos, inputs=kwargs)
completions = adapter(
lm,
lm_kwargs=config,
signature=signature,
demos=demos,
conversation_history=conversation_history,
inputs=kwargs,
)

pred = Prediction.from_completions(completions, signature=signature)

Expand All @@ -113,6 +122,7 @@ def get_config(self):
def __repr__(self):
return f"{self.__class__.__name__}({self.signature})"


def serialize_object(obj):
"""
Recursively serialize a given object into a JSON-compatible format.
Expand Down
55 changes: 46 additions & 9 deletions tests/predict/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import dspy
from dspy import Predict, Signature
from dspy.utils.dummies import DummyLM
from unittest.mock import patch, MagicMock, Mock


def test_initialization_with_string_signature():
Expand Down Expand Up @@ -87,13 +88,15 @@ class TranslateToEnglish(dspy.Signature):
# Demos don't need to keep the same types after saving and loading the state.
assert new_instance.demos[0]["content"] == original_instance.demos[0].content


def test_typed_demos_after_dump_and_load_state():
class Item(pydantic.BaseModel):
name: str
quantity: int

class InventorySignature(dspy.Signature):
"""Handle inventory items and their translations."""

items: list[Item] = dspy.InputField()
language: str = dspy.InputField()
translated_items: list[Item] = dspy.OutputField()
Expand All @@ -102,16 +105,10 @@ class InventorySignature(dspy.Signature):
original_instance = Predict(InventorySignature)
original_instance.demos = [
dspy.Example(
items=[
Item(name="apple", quantity=5),
Item(name="banana", quantity=3)
],
items=[Item(name="apple", quantity=5), Item(name="banana", quantity=3)],
language="SPANISH",
translated_items=[
Item(name="manzana", quantity=5),
Item(name="plátano", quantity=3)
],
total_quantity=8
translated_items=[Item(name="manzana", quantity=5), Item(name="plátano", quantity=3)],
total_quantity=8,
).with_inputs("items", "language"),
]

Expand Down Expand Up @@ -147,6 +144,7 @@ class InventorySignature(dspy.Signature):
assert loaded_demo["translated_items"][0]["name"] == "manzana"
assert loaded_demo["translated_items"][1]["name"] == "plátano"


# def test_typed_demos_after_dump_and_load_state():
# class TypedTranslateToEnglish(dspy.Signature):
# """Translate content from a language to English."""
Expand Down Expand Up @@ -403,3 +401,42 @@ def test_load_state_chaining():
new_instance = Predict("question -> answer").load_state(state)
assert new_instance is not None
assert new_instance.demos == original.demos


@pytest.mark.parametrize("adapter_type", ["chat", "json"])
def test_call_predict_with_chat_history(adapter_type):
    """Conversation-history turns are included in the messages sent to the LM."""

    class SpyLM(dspy.LM):
        # Records every invocation so the test can inspect exactly what was sent.
        def __init__(self, *args, return_json=False, **kwargs):
            super().__init__(*args, **kwargs)
            self.calls = []
            self.return_json = return_json

        def __call__(self, prompt=None, messages=None, **kwargs):
            self.calls.append({"prompt": prompt, "messages": messages, "kwargs": kwargs})
            if self.return_json:
                return ["{'answer':'100%'}"]
            return ["[[ ## answer ## ]]\n100%!"]

    use_json = adapter_type != "chat"
    spy = SpyLM("dummy_model", return_json=use_json)
    adapter = dspy.JSONAdapter() if use_json else dspy.ChatAdapter()
    dspy.settings.configure(adapter=adapter, lm=spy)

    predictor = Predict("question -> answer")
    predictor(
        question="are you sure that's correct?",
        conversation_history=[{"question": "what's the capital of france?", "answer": "paris"}],
    )

    # One LM call whose message list is: system, history user turn,
    # history assistant turn, then the current user turn.
    assert len(spy.calls) == 1
    sent = spy.calls[0]["messages"]

    assert len(sent) == 4

    assert "what's the capital of france?" in sent[1]["content"]
    assert "paris" in sent[2]["content"]
    assert "are you sure that's correct" in sent[3]["content"]
Loading