From 436d3c8c6d10984019a17bc33dfb395cf0c1a2de Mon Sep 17 00:00:00 2001
From: gilad12-coder <gilad.mo12@gmail.com>
Date: Sat, 11 Jan 2025 17:34:59 +0200
Subject: [PATCH 1/3] added the ability to process metadata feilds and add them
 directly to the prompt and added docstring throughtout the file

---
 dspy/adapters/chat_adapter.py | 215 ++++++++++++++++++++++++++++------
 1 file changed, 177 insertions(+), 38 deletions(-)

diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py
index 3f58fb0e8..64ac1c306 100644
--- a/dspy/adapters/chat_adapter.py
+++ b/dspy/adapters/chat_adapter.py
@@ -6,12 +6,13 @@
 import textwrap
 from collections.abc import Mapping
 from itertools import chain
-from typing import Any, Dict, List, Literal, NamedTuple, Union, get_args, get_origin
+from typing import Any, Dict, List, Literal, NamedTuple, Union, Type, get_args, get_origin
 
 import pydantic
 from pydantic import TypeAdapter
 from pydantic.fields import FieldInfo
 
+from dsp.adapters.base_template import Field
 from dspy.adapters.base import Adapter
 from dspy.adapters.utils import find_enum_member, format_field_value
 from dspy.signatures.field import OutputField
@@ -29,9 +30,26 @@ class FieldInfoWithName(NamedTuple):
 # Built-in field indicating that a chat turn has been completed.
 BuiltInCompletedOutputFieldInfo = FieldInfoWithName(name="completed", info=OutputField())
 
+# Constraints that can be applied to numeric fields.
+PERMITTED_CONSTRAINTS = {"gt", "lt", "ge", "le", "multiple_of", "allow_inf_nan"}
+
 
 class ChatAdapter(Adapter):
+
     def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
+        """
+        Creates a formatted list of messages to pass to the LLM as a prompt.
+
+        Args:
+            signature (Signature): The signature of the task.
+            demos (List[Dict[str, Any]]): A list of dictionaries, each containing a demonstration for how to perform the
+                task (i.e., mapping from input fields to output fields).
+            inputs: A dictionary containing the input fields for the task.
+
+        Returns:
+            A list of messages to pass to the LLM as a prompt. Each message is a dictionary with two keys: "role" (i.e.,
+                whether the message is from the user or the assistant) and "content" (i.e., the message text).
+        """
         messages: list[dict[str, Any]] = []
 
         # Extract demos where some of the output_fields are not filled in.
@@ -58,7 +76,7 @@ def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict
         messages.append(format_turn(signature, inputs, role="user"))
         return messages
 
-    def parse(self, signature, completion):
+    def parse(self, signature: Signature, completion: str, _parse_values: bool = True):
         sections = [(None, [])]
 
         for line in completion.splitlines():
@@ -74,10 +92,10 @@ def parse(self, signature, completion):
         for k, v in sections:
             if (k not in fields) and (k in signature.output_fields):
                 try:
-                    fields[k] = parse_value(v, signature.output_fields[k].annotation)
+                    fields[k] = parse_value(v, signature.output_fields[k].annotation) if _parse_values else v
                 except Exception as e:
                     raise ValueError(
-                        f"Error parsing field {k}: {e}.\n\n\t\tOn attempting to parse the value\n```\n{v}\n```"
+                        f"Error parsing field {k}: {e}.\n\n\t\tOn attempting to parse the value\n\n{v}\n"
                     )
 
         if fields.keys() != signature.output_fields.keys():
@@ -86,7 +104,29 @@ def parse(self, signature, completion):
         return fields
 
     # TODO(PR): Looks ok?
-    def format_finetune_data(self, signature, demos, inputs, outputs):
+    def format_finetune_data(
+            self,
+            signature: Signature,
+            demos: List[Dict[str, Any]],
+            inputs: Dict[str, Any],
+            outputs: Dict[str, Any]
+    ) -> Dict[str, List[Dict[str, Any]]]:
+        """
+        Formats the data for fine-tuning an LLM on a task.
+
+        Args:
+            signature (Signature): The signature of the task.
+            demos (List[Dict[str, Any]]): A list of dictionaries, each containing a demonstration for how to perform the
+                task (i.e., mapping from input fields to output fields).
+            inputs: A dictionary containing the input fields for the task.
+            outputs: A dictionary containing the output fields for the task.
+
+        Returns:
+            A dictionary containing the formatted data for fine-tuning an LLM on the task. The dictionary has a single
+                key, "messages", which maps to a list of messages to pass to the LLM as a prompt. Each message is a
+                dictionary with two keys: "role" (i.e., whether the message is from the user or the assistant) and
+                "content" (i.e., the message text).
+        """
         # Get system + user messages
         messages = self.format(signature, demos, inputs)
 
@@ -99,10 +139,30 @@ def format_finetune_data(self, signature, demos, inputs, outputs):
         # Wrap the messages in a dictionary with a "messages" key
         return dict(messages=messages)
 
-    def format_turn(self, signature, values, role, incomplete=False):
+    def format_turn(
+            self,
+            signature: Signature,
+            values: Dict[str, Any],
+            role: str,
+            incomplete: bool = False,
+    ) -> Dict[str, Any]:
+        """
+        Formats a single turn in a chat thread.
+
+        Args:
+            signature (Signature): The signature of the task.
+            values (Dict[str, Any]): A dictionary mapping field names to corresponding values.
+            role (str): The role of the message, which can be either "user" or "assistant".
+            incomplete (bool): If True, indicates that output field values are present in the set of specified values.
+                If False, indicates that values only contains input field values.
+
+        Returns:
+            A dictionary representing a single turn in a chat thread. The dictionary has two keys: "role" (i.e., whether
+                the message is from the user or the assistant) and "content" (i.e., the message text).
+        """
         return format_turn(signature, values, role, incomplete)
 
-    def format_fields(self, signature, values, role):
+    def format_fields(self, signature: Signature, values: Dict[str, Any], role: str) -> str:
         fields_with_values = {
             FieldInfoWithName(name=field_name, info=field_info): values.get(
                 field_name, "Not supplied for this particular example."
@@ -110,19 +170,23 @@ def format_fields(self, signature, values, role):
             for field_name, field_info in signature.fields.items()
             if field_name in values
         }
-
         return format_fields(fields_with_values)
 
 
-def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=True) -> Union[str, List[dict]]:
+def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text: bool = True) -> Union[str, List[dict]]:
     """
-    Formats the values of the specified fields according to the field's DSPy type (input or output),
-    annotation (e.g. str, int, etc.), and the type of the value itself. Joins the formatted values
-    into a single string, which is is a multiline string if there are multiple fields.
+    Creates a formatted representation of the fields and their values.
+
+    Formats the values of the specified fields according to the field's DSPy type (input or output), annotation (e.g. str,
+    int, etc.), and the type of the value itself. Joins the formatted values into a single string, which is is a multiline
+    string if there are multiple fields.
 
     Args:
-      fields_with_values: A dictionary mapping information about a field to its corresponding
-                          value.
+        fields_with_values (Dict[FieldInforWithName, Any]): A dictionary mapping information about a field to its
+            corresponding value.
+        assume_text (bool): If True, assumes that the values are text and formats them as such. If False, formats the
+            values as a list of dictionaries.
+
     Returns:
       The joined formatted values of the fields, represented as a string or a list of dicts
     """
@@ -143,7 +207,17 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=
         return output
 
 
-def parse_value(value, annotation):
+def parse_value(value: Any, annotation: Type) -> Any:
+    """
+    Parses a value according to the specified annotation.
+
+    Args:
+        value: The value to parse.
+        annotation: The type to which the value should be parsed.
+
+    Returns:
+        The parsed value.
+    """
     if annotation is str:
         return str(value)
 
@@ -163,10 +237,12 @@ def parse_value(value, annotation):
     return TypeAdapter(annotation).validate_python(parsed_value)
 
 
-def format_turn(signature, values, role, incomplete=False):
+def format_turn(signature: Signature, values: Dict[str, Any], role: str, incomplete: bool = False) -> Dict[str, Any]:
     """
-    Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted
-    so that it can instruct an LLM to generate responses conforming to the specified DSPy signature.
+    Constructs a new message ("turn") to append to a chat thread.
+
+    The message is carefully formatted so that it can instruct an LLM to generate responses conforming to the specified
+    DSPy signature.
 
     Args:
         signature: The DSPy signature to which future LLM responses should conform.
@@ -174,11 +250,11 @@ def format_turn(signature, values, role, incomplete=False):
             that should be included in the message.
         role: The role of the message, which can be either "user" or "assistant".
         incomplete: If True, indicates that output field values are present in the set of specified
-            ``values``. If False, indicates that ``values`` only contains input field values.
+            `values. If False, indicates that ``values` only contains input field values.
 
     Returns:
         A chat message that can be appended to a chat thread. The message contains two string fields:
-        ``role`` ("user" or "assistant") and ``content`` (the message text).
+        `role` ("user" or "assistant") and `content` (the message text).
     """
     fields_to_collapse = []
     content = []
@@ -229,8 +305,9 @@ def type_info(v):
                 {
                     "type": "text",
                     "text": "Respond with the corresponding output fields, starting with the field "
-                    + ", then ".join(f"`[[ ## {f} ## ]]`{type_info(v)}" for f, v in signature.output_fields.items())
-                    + ", and then ending with the marker for `[[ ## completed ## ]]`.",
+                            + ", then ".join(
+                        f"[[ ## {f} ## ]]{type_info(v)}" for f, v in signature.output_fields.items())
+                            + ", and then ending with the marker for [[ ## completed ## ]].",
                 }
             )
 
@@ -267,12 +344,13 @@ def type_info(v):
     return {"role": role, "content": collapsed_messages}
 
 
-def get_annotation_name(annotation):
+def get_annotation_name(annotation: Type) -> str:
+    """Returns the name of the annotation as a string."""
     origin = get_origin(annotation)
     args = get_args(annotation)
     if origin is None:
-        if hasattr(annotation, "__name__"):
-            return annotation.__name__
+        if hasattr(annotation, "_name_"):
+            return annotation._name_
         else:
             return str(annotation)
     else:
@@ -280,18 +358,74 @@ def get_annotation_name(annotation):
         return f"{get_annotation_name(origin)}[{args_str}]"
 
 
-def enumerate_fields(fields: dict) -> str:
+def _format_constraint(name: str, value: Union[str, float]) -> str:
+    """
+    Formats a constraint for a numeric field.
+
+    Args:
+        name: The name of the constraint.
+        value: The value of the constraint.
+
+    Returns:
+        The formatted constraint as a string.
+    """
+    constraints = {
+        'gt': f"greater than {value}",
+        'lt': f"less than {value}",
+        'ge': f"greater than or equal to {value}",
+        'le': f"less than or equal to {value}",
+        'multiple_of': f"a multiple of {value}",
+        'allow_inf_nan': "allows infinite and NaN values" if value else "no infinite or NaN values allowed"
+    }
+    return constraints.get(name, f"{name}={value}")
+
+
+def format_metadata_summary(field: pydantic.fields.FieldInfo) -> str:
+    """
+    Formats a summary of the metadata for a field."""
+    if not field.metadata:
+        return ""
+    metadata_parts = [str(meta) for meta in field.metadata]
+    if metadata_parts:
+        return f" [Metadata: {'; '.join(metadata_parts)}]"
+    return ""
+
+
+def format_metadata_constraints(field: Field) -> str:
+    """Formats the constraints for a field."""
+    if not hasattr(field, 'metadata') or not field.metadata:
+        return ""
+    formatted_constraints = []
+    for meta in field.metadata:
+        constraint_names = [name for name in dir(meta) if not name.startswith('_')]
+        for name in constraint_names:
+            if hasattr(meta, name) and name in PERMITTED_CONSTRAINTS:
+                value = getattr(meta, name)
+                formatted_constraints.append(_format_constraint(name, value))
+    if not formatted_constraints:
+        return ""
+    elif len(formatted_constraints) == 1:
+        return f" that is {formatted_constraints[0]}."
+    else:
+        *front, last = formatted_constraints
+        return f" that is {', '.join(front)} and {last}."
+
+
+def enumerate_fields(fields: Dict[str, FieldInfo]) -> str:
+    """Enumerates the fields in a signature."""
     parts = []
     for idx, (k, v) in enumerate(fields.items()):
-        parts.append(f"{idx+1}. `{k}`")
+        parts.append(f"{idx + 1}. {k}")
         parts[-1] += f" ({get_annotation_name(v.annotation)})"
         parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else ""
-
+        metadata_info = format_metadata_summary(v)
+        if metadata_info:
+            parts[-1] += metadata_info
     return "\n".join(parts).strip()
 
 
-def move_type_to_front(d):
-    # Move the 'type' key to the front of the dictionary, recursively, for LLM readability/adherence.
+def move_type_to_front(d: Union[Dict, List, Any]) -> Union[Dict, List, Any]:
+    """Moves the 'type' key to the front of the dictionary, recursively, for LLM readability/adherence."""
     if isinstance(d, Mapping):
         return {k: move_type_to_front(v) for k, v in sorted(d.items(), key=lambda item: (item[0] != "type", item[0]))}
     elif isinstance(d, list):
@@ -299,19 +433,22 @@ def move_type_to_front(d):
     return d
 
 
-def prepare_schema(type_):
-    schema = pydantic.TypeAdapter(type_).json_schema()
+def prepare_schema(type_: Type) -> Dict[str, Any]:
+    """Prepares a JSON schema for a given type."""
+    schema: Dict[str, Any] = pydantic.TypeAdapter(type_).json_schema()
     schema = move_type_to_front(schema)
     return schema
 
 
-def prepare_instructions(signature: SignatureMeta):
+def prepare_instructions(signature: SignatureMeta) -> str:
+    """Prepares the instructions for a signature."""
     parts = []
     parts.append("Your input fields are:\n" + enumerate_fields(signature.input_fields))
     parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields))
     parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")
 
-    def field_metadata(field_name, field_info):
+    def field_metadata(field_name: str, field_info: FieldInfo) -> str:
+        """Creates a formatted representation of a field's information and metadata."""
         field_type = field_info.annotation
 
         if get_dspy_field_type(field_info) == "input" or field_type is str:
@@ -320,9 +457,11 @@ def field_metadata(field_name, field_info):
             desc = "must be True or False"
         elif field_type in (int, float):
             desc = f"must be a single {field_type.__name__} value"
+            metadata_info = format_metadata_constraints(field_info)
+            if metadata_info: desc += metadata_info
         elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
             desc = f"must be one of: {'; '.join(field_type.__members__)}"
-        elif hasattr(field_type, "__origin__") and field_type.__origin__ is Literal:
+        elif hasattr(field_type, "_origin") and field_type.origin_ is Literal:
             desc = f"must be one of: {'; '.join([str(x) for x in field_type.__args__])}"
         else:
             desc = "must be pareseable according to the following JSON schema: "
@@ -331,7 +470,8 @@ def field_metadata(field_name, field_info):
         desc = (" " * 8) + f"# note: the value you produce {desc}" if desc else ""
         return f"{{{field_name}}}{desc}"
 
-    def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
+    def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]) -> str:
+        """Formats the fields from the signature for the instructions."""
         return format_fields(
             fields_with_values={
                 FieldInfoWithName(name=field_name, info=field_info): field_metadata(field_name, field_info)
@@ -346,5 +486,4 @@ def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
     instructions = textwrap.dedent(signature.instructions)
     objective = ("\n" + " " * 8).join([""] + instructions.splitlines())
     parts.append(f"In adhering to this structure, your objective is: {objective}")
-
-    return "\n\n".join(parts).strip()
+    return "\n\n".join(parts).strip()
\ No newline at end of file

From cc2e2060857f237e16e93e7fd41396d0eb28c40a Mon Sep 17 00:00:00 2001
From: gilad12-coder <gilad.mo12@gmail.com>
Date: Sat, 11 Jan 2025 17:49:48 +0200
Subject: [PATCH 2/3] quick fix to get_annotation_name

---
 dspy/adapters/chat_adapter.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py
index 64ac1c306..858e97dab 100644
--- a/dspy/adapters/chat_adapter.py
+++ b/dspy/adapters/chat_adapter.py
@@ -349,8 +349,8 @@ def get_annotation_name(annotation: Type) -> str:
     origin = get_origin(annotation)
     args = get_args(annotation)
     if origin is None:
-        if hasattr(annotation, "_name_"):
-            return annotation._name_
+        if hasattr(annotation, "__name__"):
+            return annotation.__name__
         else:
             return str(annotation)
     else:

From 0a30136f3c8caa3a5b90c671805d542d6c51659b Mon Sep 17 00:00:00 2001
From: gilad12-coder <gilad.mo12@gmail.com>
Date: Sat, 11 Jan 2025 20:18:32 +0200
Subject: [PATCH 3/3] fixed some small formating issues

---
 dspy/adapters/chat_adapter.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py
index 858e97dab..d4cb1cc5e 100644
--- a/dspy/adapters/chat_adapter.py
+++ b/dspy/adapters/chat_adapter.py
@@ -250,7 +250,7 @@ def format_turn(signature: Signature, values: Dict[str, Any], role: str, incompl
             that should be included in the message.
         role: The role of the message, which can be either "user" or "assistant".
         incomplete: If True, indicates that output field values are present in the set of specified
-            `values. If False, indicates that ``values` only contains input field values.
+            `values`. If False, indicates that `values` only contains input field values.
 
     Returns:
         A chat message that can be appended to a chat thread. The message contains two string fields:
@@ -415,7 +415,7 @@ def enumerate_fields(fields: Dict[str, FieldInfo]) -> str:
     """Enumerates the fields in a signature."""
     parts = []
     for idx, (k, v) in enumerate(fields.items()):
-        parts.append(f"{idx + 1}. {k}")
+        parts.append(f"{idx+1}. `{k}`")
         parts[-1] += f" ({get_annotation_name(v.annotation)})"
         parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else ""
         metadata_info = format_metadata_summary(v)
@@ -461,7 +461,7 @@ def field_metadata(field_name: str, field_info: FieldInfo) -> str:
             if metadata_info: desc += metadata_info
         elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
             desc = f"must be one of: {'; '.join(field_type.__members__)}"
-        elif hasattr(field_type, "_origin") and field_type.origin_ is Literal:
+        elif hasattr(field_type, "__origin__") and field_type.__origin__ is Literal:
             desc = f"must be one of: {'; '.join([str(x) for x in field_type.__args__])}"
         else:
             desc = "must be pareseable according to the following JSON schema: "