From 436d3c8c6d10984019a17bc33dfb395cf0c1a2de Mon Sep 17 00:00:00 2001 From: gilad12-coder <gilad.mo12@gmail.com> Date: Sat, 11 Jan 2025 17:34:59 +0200 Subject: [PATCH 1/3] added the ability to process metadata feilds and add them directly to the prompt and added docstring throughtout the file --- dspy/adapters/chat_adapter.py | 215 ++++++++++++++++++++++++++++------ 1 file changed, 177 insertions(+), 38 deletions(-) diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index 3f58fb0e8..64ac1c306 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -6,12 +6,13 @@ import textwrap from collections.abc import Mapping from itertools import chain -from typing import Any, Dict, List, Literal, NamedTuple, Union, get_args, get_origin +from typing import Any, Dict, List, Literal, NamedTuple, Union, Type, get_args, get_origin import pydantic from pydantic import TypeAdapter from pydantic.fields import FieldInfo +from dsp.adapters.base_template import Field from dspy.adapters.base import Adapter from dspy.adapters.utils import find_enum_member, format_field_value from dspy.signatures.field import OutputField @@ -29,9 +30,26 @@ class FieldInfoWithName(NamedTuple): # Built-in field indicating that a chat turn has been completed. BuiltInCompletedOutputFieldInfo = FieldInfoWithName(name="completed", info=OutputField()) +# Constraints that can be applied to numeric fields. +PERMITTED_CONSTRAINTS = {"gt", "lt", "ge", "le", "multiple_of", "allow_inf_nan"} + class ChatAdapter(Adapter): + def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]: + """ + Creates a formatted list of messages to pass to the LLM as a prompt. + + Args: + signature (Signature): The signature of the task. + demos (List[Dict[str, Any]]): A list of dictionaries, each containing a demonstration for how to perform the + task (i.e., mapping from input fields to output fields). + inputs: A dictionary containing the input fields for the task. + + Returns: + A list of messages to pass to the LLM as a prompt. Each message is a dictionary with two keys: "role" (i.e., + whether the message is from the user or the assistant) and "content" (i.e., the message text). + """ messages: list[dict[str, Any]] = [] # Extract demos where some of the output_fields are not filled in. @@ -58,7 +76,7 @@ def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict messages.append(format_turn(signature, inputs, role="user")) return messages - def parse(self, signature, completion): + def parse(self, signature: Signature, completion: str, _parse_values: bool = True): sections = [(None, [])] for line in completion.splitlines(): @@ -74,10 +92,10 @@ def parse(self, signature, completion): for k, v in sections: if (k not in fields) and (k in signature.output_fields): try: - fields[k] = parse_value(v, signature.output_fields[k].annotation) + fields[k] = parse_value(v, signature.output_fields[k].annotation) if _parse_values else v except Exception as e: raise ValueError( - f"Error parsing field {k}: {e}.\n\n\t\tOn attempting to parse the value\n```\n{v}\n```" + f"Error parsing field {k}: {e}.\n\n\t\tOn attempting to parse the value\n\n{v}\n" ) if fields.keys() != signature.output_fields.keys(): @@ -86,7 +104,29 @@ def parse(self, signature, completion): return fields # TODO(PR): Looks ok? - def format_finetune_data(self, signature, demos, inputs, outputs): + def format_finetune_data( + self, + signature: Signature, + demos: List[Dict[str, Any]], + inputs: Dict[str, Any], + outputs: Dict[str, Any] + ) -> Dict[str, List[Dict[str, Any]]]: + """ + Formats the data for fine-tuning an LLM on a task. + + Args: + signature (Signature): The signature of the task. + demos (List[Dict[str, Any]]): A list of dictionaries, each containing a demonstration for how to perform the + task (i.e., mapping from input fields to output fields). + inputs: A dictionary containing the input fields for the task. + outputs: A dictionary containing the output fields for the task. + + Returns: + A dictionary containing the formatted data for fine-tuning an LLM on the task. The dictionary has a single + key, "messages", which maps to a list of messages to pass to the LLM as a prompt. Each message is a + dictionary with two keys: "role" (i.e., whether the message is from the user or the assistant) and + "content" (i.e., the message text). + """ # Get system + user messages messages = self.format(signature, demos, inputs) @@ -99,10 +139,30 @@ def format_finetune_data(self, signature, demos, inputs, outputs): # Wrap the messages in a dictionary with a "messages" key return dict(messages=messages) - def format_turn(self, signature, values, role, incomplete=False): + def format_turn( + self, + signature: Signature, + values: Dict[str, Any], + role: str, + incomplete: bool = False, + ) -> Dict[str, Any]: + """ + Formats a single turn in a chat thread. + + Args: + signature (Signature): The signature of the task. + values (Dict[str, Any]): A dictionary mapping field names to corresponding values. + role (str): The role of the message, which can be either "user" or "assistant". + incomplete (bool): If True, indicates that output field values are present in the set of specified values. + If False, indicates that values only contains input field values. + + Returns: + A dictionary representing a single turn in a chat thread. The dictionary has two keys: "role" (i.e., whether + the message is from the user or the assistant) and "content" (i.e., the message text). + """ return format_turn(signature, values, role, incomplete) - def format_fields(self, signature, values, role): + def format_fields(self, signature: Signature, values: Dict[str, Any], role: str) -> str: fields_with_values = { FieldInfoWithName(name=field_name, info=field_info): values.get( field_name, "Not supplied for this particular example." @@ -110,19 +170,23 @@ def format_fields(self, signature, values, role): for field_name, field_info in signature.fields.items() if field_name in values } - return format_fields(fields_with_values) -def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=True) -> Union[str, List[dict]]: +def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text: bool = True) -> Union[str, List[dict]]: """ - Formats the values of the specified fields according to the field's DSPy type (input or output), - annotation (e.g. str, int, etc.), and the type of the value itself. Joins the formatted values - into a single string, which is is a multiline string if there are multiple fields. + Creates a formatted representation of the fields and their values. + + Formats the values of the specified fields according to the field's DSPy type (input or output), annotation (e.g. str, + int, etc.), and the type of the value itself. Joins the formatted values into a single string, which is is a multiline + string if there are multiple fields. Args: - fields_with_values: A dictionary mapping information about a field to its corresponding - value. + fields_with_values (Dict[FieldInforWithName, Any]): A dictionary mapping information about a field to its + corresponding value. + assume_text (bool): If True, assumes that the values are text and formats them as such. If False, formats the + values as a list of dictionaries. + Returns: The joined formatted values of the fields, represented as a string or a list of dicts """ @@ -143,7 +207,17 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text= return output -def parse_value(value, annotation): +def parse_value(value: Any, annotation: Type) -> Any: + """ + Parses a value according to the specified annotation. + + Args: + value: The value to parse. + annotation: The type to which the value should be parsed. + + Returns: + The parsed value. + """ if annotation is str: return str(value) @@ -163,10 +237,12 @@ def parse_value(value, annotation): return TypeAdapter(annotation).validate_python(parsed_value) -def format_turn(signature, values, role, incomplete=False): +def format_turn(signature: Signature, values: Dict[str, Any], role: str, incomplete: bool = False) -> Dict[str, Any]: """ - Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted - so that it can instruct an LLM to generate responses conforming to the specified DSPy signature. + Constructs a new message ("turn") to append to a chat thread. + + The message is carefully formatted so that it can instruct an LLM to generate responses conforming to the specified + DSPy signature. Args: signature: The DSPy signature to which future LLM responses should conform. @@ -174,11 +250,11 @@ def format_turn(signature, values, role, incomplete=False): that should be included in the message. role: The role of the message, which can be either "user" or "assistant". incomplete: If True, indicates that output field values are present in the set of specified - ``values``. If False, indicates that ``values`` only contains input field values. + `values. If False, indicates that ``values` only contains input field values. Returns: A chat message that can be appended to a chat thread. The message contains two string fields: - ``role`` ("user" or "assistant") and ``content`` (the message text). + `role` ("user" or "assistant") and `content` (the message text). """ fields_to_collapse = [] content = [] @@ -229,8 +305,9 @@ def type_info(v): { "type": "text", "text": "Respond with the corresponding output fields, starting with the field " - + ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items()) - + ", and then ending with the marker for `[[ ## completed ## ]]`.", + + ", then ".join( + f"[[ ## {f} ## ]]{type_info(v)}" for f, v in signature.output_fields.items()) + + ", and then ending with the marker for [[ ## completed ## ]].", } ) @@ -267,12 +344,13 @@ def type_info(v): return {"role": role, "content": collapsed_messages} -def get_annotation_name(annotation): +def get_annotation_name(annotation: Type) -> str: + """Returns the name of the annotation as a string.""" origin = get_origin(annotation) args = get_args(annotation) if origin is None: - if hasattr(annotation, "__name__"): - return annotation.__name__ + if hasattr(annotation, "_name_"): + return annotation._name_ else: return str(annotation) else: @@ -280,18 +358,74 @@ def get_annotation_name(annotation): return f"{get_annotation_name(origin)}[{args_str}]" -def enumerate_fields(fields: dict) -> str: +def _format_constraint(name: str, value: Union[str, float]) -> str: + """ + Formats a constraint for a numeric field. + + Args: + name: The name of the constraint. + value: The value of the constraint. + + Returns: + The formatted constraint as a string. + """ + constraints = { + 'gt': f"greater than {value}", + 'lt': f"less than {value}", + 'ge': f"greater than or equal to {value}", + 'le': f"less than or equal to {value}", + 'multiple_of': f"a multiple of {value}", + 'allow_inf_nan': "allows infinite and NaN values" if value else "no infinite or NaN values allowed" + } + return constraints.get(name, f"{name}={value}") + + +def format_metadata_summary(field: pydantic.fields.FieldInfo) -> str: + """ + Formats a summary of the metadata for a field.""" + if not field.metadata: + return "" + metadata_parts = [str(meta) for meta in field.metadata] + if metadata_parts: + return f" [Metadata: {'; '.join(metadata_parts)}]" + return "" + + +def format_metadata_constraints(field: Field) -> str: + """Formats the constraints for a field.""" + if not hasattr(field, 'metadata') or not field.metadata: + return "" + formatted_constraints = [] + for meta in field.metadata: + constraint_names = [name for name in dir(meta) if not name.startswith('_')] + for name in constraint_names: + if hasattr(meta, name) and name in PERMITTED_CONSTRAINTS: + value = getattr(meta, name) + formatted_constraints.append(_format_constraint(name, value)) + if not formatted_constraints: + return "" + elif len(formatted_constraints) == 1: + return f" that is {formatted_constraints[0]}." + else: + *front, last = formatted_constraints + return f" that is {', '.join(front)} and {last}." + + +def enumerate_fields(fields: Dict[str, FieldInfo]) -> str: + """Enumerates the fields in a signature.""" parts = [] for idx, (k, v) in enumerate(fields.items()): - parts.append(f"{idx+1}. `{k}`") + parts.append(f"{idx + 1}. {k}") parts[-1] += f" ({get_annotation_name(v.annotation)})" parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else "" - + metadata_info = format_metadata_summary(v) + if metadata_info: + parts[-1] += metadata_info return "\n".join(parts).strip() -def move_type_to_front(d): - # Move the 'type' key to the front of the dictionary, recursively, for LLM readability/adherence. +def move_type_to_front(d: Union[Dict, List, Any]) -> Union[Dict, List, Any]: + """Moves the 'type' key to the front of the dictionary, recursively, for LLM readability/adherence.""" if isinstance(d, Mapping): return {k: move_type_to_front(v) for k, v in sorted(d.items(), key=lambda item: (item[0] != "type", item[0]))} elif isinstance(d, list): @@ -299,19 +433,22 @@ def move_type_to_front(d): return d -def prepare_schema(type_): - schema = pydantic.TypeAdapter(type_).json_schema() +def prepare_schema(type_: Type) -> Dict[str, Any]: + """Prepares a JSON schema for a given type.""" + schema: Dict[str, Any] = pydantic.TypeAdapter(type_).json_schema() schema = move_type_to_front(schema) return schema -def prepare_instructions(signature: SignatureMeta): +def prepare_instructions(signature: SignatureMeta) -> str: + """Prepares the instructions for a signature.""" parts = [] parts.append("Your input fields are:\n" + enumerate_fields(signature.input_fields)) parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields)) parts.append("All interactions will be structured in the following way, with the appropriate values filled in.") - def field_metadata(field_name, field_info): + def field_metadata(field_name: str, field_info: FieldInfo) -> str: + """Creates a formatted representation of a field's information and metadata.""" field_type = field_info.annotation if get_dspy_field_type(field_info) == "input" or field_type is str: @@ -320,9 +457,11 @@ def field_metadata(field_name, field_info): desc = "must be True or False" elif field_type in (int, float): desc = f"must be a single {field_type.__name__} value" + metadata_info = format_metadata_constraints(field_info) + if metadata_info: desc += metadata_info elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum): desc = f"must be one of: {'; '.join(field_type.__members__)}" - elif hasattr(field_type, "__origin__") and field_type.__origin__ is Literal: + elif hasattr(field_type, "_origin") and field_type.origin_ is Literal: desc = f"must be one of: {'; '.join([str(x) for x in field_type.__args__])}" else: desc = "must be pareseable according to the following JSON schema: " @@ -331,7 +470,8 @@ def field_metadata(field_name, field_info): desc = (" " * 8) + f"# note: the value you produce {desc}" if desc else "" return f"{{{field_name}}}{desc}" - def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]): + def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]) -> str: + """Formats the fields from the signature for the instructions.""" return format_fields( fields_with_values={ FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info) @@ -346,5 +486,4 @@ def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]): instructions = textwrap.dedent(signature.instructions) objective = ("\n" + " " * 8).join([""] + instructions.splitlines()) parts.append(f"In adhering to this structure, your objective is: {objective}") - - return "\n\n".join(parts).strip() + return "\n\n".join(parts).strip() \ No newline at end of file From cc2e2060857f237e16e93e7fd41396d0eb28c40a Mon Sep 17 00:00:00 2001 From: gilad12-coder <gilad.mo12@gmail.com> Date: Sat, 11 Jan 2025 17:49:48 +0200 Subject: [PATCH 2/3] quick fix to get_annotation_name --- dspy/adapters/chat_adapter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index 64ac1c306..858e97dab 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -349,8 +349,8 @@ def get_annotation_name(annotation: Type) -> str: origin = get_origin(annotation) args = get_args(annotation) if origin is None: - if hasattr(annotation, "_name_"): - return annotation._name_ + if hasattr(annotation, "__name__"): + return annotation.__name__ else: return str(annotation) else: From 0a30136f3c8caa3a5b90c671805d542d6c51659b Mon Sep 17 00:00:00 2001 From: gilad12-coder <gilad.mo12@gmail.com> Date: Sat, 11 Jan 2025 20:18:32 +0200 Subject: [PATCH 3/3] fixed some small formating issues --- dspy/adapters/chat_adapter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index 858e97dab..d4cb1cc5e 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -250,7 +250,7 @@ def format_turn(signature: Signature, values: Dict[str, Any], role: str, incompl that should be included in the message. role: The role of the message, which can be either "user" or "assistant". incomplete: If True, indicates that output field values are present in the set of specified - `values. If False, indicates that ``values` only contains input field values. + `values`. If False, indicates that `values` only contains input field values. Returns: A chat message that can be appended to a chat thread. The message contains two string fields: @@ -415,7 +415,7 @@ def enumerate_fields(fields: Dict[str, FieldInfo]) -> str: """Enumerates the fields in a signature.""" parts = [] for idx, (k, v) in enumerate(fields.items()): - parts.append(f"{idx + 1}. {k}") + parts.append(f"{idx+1}. `{k}`") parts[-1] += f" ({get_annotation_name(v.annotation)})" parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else "" metadata_info = format_metadata_summary(v) @@ -461,7 +461,7 @@ def field_metadata(field_name: str, field_info: FieldInfo) -> str: if metadata_info: desc += metadata_info elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum): desc = f"must be one of: {'; '.join(field_type.__members__)}" - elif hasattr(field_type, "_origin") and field_type.origin_ is Literal: + elif hasattr(field_type, "__origin__") and field_type.__origin__ is Literal: desc = f"must be one of: {'; '.join([str(x) for x in field_type.__args__])}" else: desc = "must be pareseable according to the following JSON schema: "