Skip to content

Commit

Permalink
support anthropic reasonning in agent bridge (#1422)
Browse files Browse the repository at this point in the history
Co-authored-by: jjallaire <[email protected]>
  • Loading branch information
jjallaire-aisi and jjallaire authored Feb 27, 2025
1 parent 122cbc3 commit 2e2969d
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 36 deletions.
116 changes: 82 additions & 34 deletions src/inspect_ai/model/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from inspect_ai._util.url import is_http_url
from inspect_ai.model._call_tools import parse_tool_call
from inspect_ai.model._model_output import ChatCompletionChoice, Logprobs
from inspect_ai.model._reasoning import parse_content_with_reasoning
from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo

from ._chat_message import (
Expand Down Expand Up @@ -154,14 +155,14 @@ async def openai_chat_message(
if message.tool_calls:
return ChatCompletionAssistantMessageParam(
role=message.role,
content=message.text,
content=openai_assistant_content(message),
tool_calls=[
openai_chat_tool_call_param(call) for call in message.tool_calls
],
)
else:
return ChatCompletionAssistantMessageParam(
role=message.role, content=message.text
role=message.role, content=openai_assistant_content(message)
)
elif message.role == "tool":
return ChatCompletionToolMessageParam(
Expand All @@ -181,16 +182,29 @@ async def openai_chat_messages(
return [await openai_chat_message(message, model) for message in messages]


def openai_assistant_content(message: ChatMessageAssistant) -> str:
if isinstance(message.content, str):
content = message.content
else:
content = ""
for c in message.content:
if c.type == "reasoning":
attribs = ""
if c.signature is not None:
attribs = f'{attribs} signature="{c.signature}"'
if c.redacted:
attribs = f'{attribs} redacted="true"'
content = f"{content}\n<think{attribs}>\n{c.reasoning}\n</think>\n"
elif c.type == "text":
content = f"{content}\n{c.text}"
return content


def openai_chat_choices(choices: list[ChatCompletionChoice]) -> list[Choice]:
oai_choices: list[Choice] = []

for index, choice in enumerate(choices):
if isinstance(choice.message.content, str):
content = choice.message.content
else:
content = "\n".join(
[c.text for c in choice.message.content if c.type == "text"]
)
content = openai_assistant_content(choice.message)
if choice.message.tool_calls:
tool_calls = [openai_chat_tool_call(tc) for tc in choice.message.tool_calls]
else:
Expand Down Expand Up @@ -280,35 +294,47 @@ def chat_messages_from_openai(
chat_messages: list[ChatMessage] = []

for message in messages:
content: str | list[Content] = []
if message["role"] == "system" or message["role"] == "developer":
sys_content = message["content"]
if isinstance(sys_content, str):
chat_messages.append(ChatMessageSystem(content=sys_content))
else:
chat_messages.append(
ChatMessageSystem(
content=[content_from_openai(c) for c in sys_content]
)
)
content = []
for sc in sys_content:
content.extend(content_from_openai(sc))
chat_messages.append(ChatMessageSystem(content=content))
elif message["role"] == "user":
user_content = message["content"]
if isinstance(user_content, str):
chat_messages.append(ChatMessageUser(content=user_content))
else:
chat_messages.append(
ChatMessageUser(
content=[content_from_openai(c) for c in user_content]
)
)
content = []
for uc in user_content:
content.extend(content_from_openai(uc))
chat_messages.append(ChatMessageUser(content=content))
elif message["role"] == "assistant":
# resolve content
asst_content = message.get("content", None)
if isinstance(asst_content, str):
content: str | list[Content] = asst_content
result = parse_content_with_reasoning(asst_content)
if result is not None:
content = [
ContentReasoning(
reasoning=result.reasoning,
signature=result.signature,
redacted=result.redacted,
),
ContentText(text=result.content),
]
else:
content = asst_content
elif asst_content is None:
content = message.get("refusal", None) or ""
else:
content = [content_from_openai(c) for c in asst_content]
content = []
for ac in asst_content:
content.extend(content_from_openai(ac, parse_reasoning=True))

# resolve reasoning (OpenAI doesn't suport this however OpenAI-compatible
# interfaces e.g. DeepSeek do include this field so we pluck it out)
Expand All @@ -324,9 +350,9 @@ def chat_messages_from_openai(
# return message
if "tool_calls" in message:
tool_calls: list[ToolCall] = []
for tc in message["tool_calls"]:
tool_calls.append(tool_call_from_openai(tc))
tool_names[tc["id"]] = tc["function"]["name"]
for call in message["tool_calls"]:
tool_calls.append(tool_call_from_openai(call))
tool_names[call["id"]] = call["function"]["name"]

else:
tool_calls = []
Expand All @@ -342,7 +368,9 @@ def chat_messages_from_openai(
if isinstance(tool_content, str):
content = tool_content
else:
content = [content_from_openai(c) for c in tool_content]
content = []
for tc in tool_content:
content.extend(content_from_openai(tc))
chat_messages.append(
ChatMessageTool(
content=content,
Expand All @@ -366,20 +394,40 @@ def tool_call_from_openai(tool_call: ChatCompletionMessageToolCallParam) -> Tool

def content_from_openai(
content: ChatCompletionContentPartParam | ChatCompletionContentPartRefusalParam,
) -> Content:
parse_reasoning: bool = False,
) -> list[Content]:
if content["type"] == "text":
return ContentText(text=content["text"])
text = content["text"]
if parse_reasoning:
result = parse_content_with_reasoning(text)
if result:
return [
ContentReasoning(
reasoning=result.reasoning,
signature=result.signature,
redacted=result.redacted,
),
ContentText(text=result.content),
]
else:
return [ContentText(text=text)]
else:
return [ContentText(text=text)]
elif content["type"] == "image_url":
return ContentImage(
image=content["image_url"]["url"], detail=content["image_url"]["detail"]
)
return [
ContentImage(
image=content["image_url"]["url"], detail=content["image_url"]["detail"]
)
]
elif content["type"] == "input_audio":
return ContentAudio(
audio=content["input_audio"]["data"],
format=content["input_audio"]["format"],
)
return [
ContentAudio(
audio=content["input_audio"]["data"],
format=content["input_audio"]["format"],
)
]
elif content["type"] == "refusal":
return ContentText(text=content["refusal"])
return [ContentText(text=content["refusal"])]


def chat_message_assistant_from_openai(
Expand Down
17 changes: 15 additions & 2 deletions src/inspect_ai/model/_reasoning.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,26 @@
class ContentWithReasoning(NamedTuple):
content: str
reasoning: str
signature: str | None = None
redacted: bool = False


def parse_content_with_reasoning(content: str) -> ContentWithReasoning | None:
match = re.match(r"\s*<think>(.*?)</think>(.*)", content, re.DOTALL)
# Match <think> tag with optional attributes
pattern = r'\s*<think(?:\s+signature="([^"]*)")?(?:\s+redacted="(true)")?\s*>(.*?)</think>(.*)'
match = re.match(pattern, content, re.DOTALL)

if match:
signature = match.group(1) # This will be None if not present
redacted_value = match.group(2) # This will be "true" or None
reasoning = match.group(3).strip()
content_text = match.group(4).strip()

return ContentWithReasoning(
content=match.group(2).strip(), reasoning=match.group(1).strip()
content=content_text,
reasoning=reasoning,
signature=signature,
redacted=redacted_value == "true",
)
else:
return None
79 changes: 79 additions & 0 deletions tests/model/test_reasoning_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ def test_reasoning_parse_basic():
assert result is not None
assert result.reasoning == "Simple reasoning"
assert result.content == "Normal text"
assert result.signature is None
assert result.redacted is False


def test_reasoning_parse_with_leading_whitespace():
Expand All @@ -15,34 +17,44 @@ def test_reasoning_parse_with_leading_whitespace():
assert result is not None
assert result.reasoning == "Indented reasoning"
assert result.content == "Text"
assert result.signature is None
assert result.redacted is False


def test_reasoning_parse_with_trailing_whitespace():
result = parse_content_with_reasoning("<think>Reasoning</think> \n Text \n")
assert result is not None
assert result.reasoning == "Reasoning"
assert result.content == "Text"
assert result.signature is None
assert result.redacted is False


def test_reasoning_parse_with_newlines_in_reasoning():
result = parse_content_with_reasoning("<think>Multi\nline\nreasoning</think>Text")
assert result is not None
assert result.reasoning == "Multi\nline\nreasoning"
assert result.content == "Text"
assert result.signature is None
assert result.redacted is False


def test_reasoning_parse_empty():
result = parse_content_with_reasoning("<think></think>Text")
assert result is not None
assert result.reasoning == ""
assert result.content == "Text"
assert result.signature is None
assert result.redacted is False


def test_reasoning_parse_empty_content():
result = parse_content_with_reasoning("<think>Just reasoning</think>")
assert result is not None
assert result.reasoning == "Just reasoning"
assert result.content == ""
assert result.signature is None
assert result.redacted is False


def test_reasoning_parse_whitespace_everywhere():
Expand All @@ -57,6 +69,8 @@ def test_reasoning_parse_whitespace_everywhere():
assert result is not None
assert result.reasoning == "Messy\n reasoning"
assert result.content == "Messy\n text"
assert result.signature is None
assert result.redacted is False


def test_reasoning_parse_no_think_tag():
Expand All @@ -67,3 +81,68 @@ def test_reasoning_parse_no_think_tag():
def test_reasoning_parse_unclosed_tag():
result = parse_content_with_reasoning("<think>Unclosed reasoning")
assert result is None


# New tests for signature attribute
def test_reasoning_parse_with_signature():
result = parse_content_with_reasoning(
'<think signature="45ef5ab">Reasoning with signature</think>Content'
)
assert result is not None
assert result.reasoning == "Reasoning with signature"
assert result.content == "Content"
assert result.signature == "45ef5ab"
assert result.redacted is False


# New tests for redacted attribute
def test_reasoning_parse_with_redacted():
result = parse_content_with_reasoning(
'<think redacted="true">Redacted reasoning</think>Content'
)
assert result is not None
assert result.reasoning == "Redacted reasoning"
assert result.content == "Content"
assert result.signature is None
assert result.redacted is True


# New tests for both attributes
def test_reasoning_parse_with_signature_and_redacted():
result = parse_content_with_reasoning(
'<think signature="45ef5ab" redacted="true">Both attributes</think>Content'
)
assert result is not None
assert result.reasoning == "Both attributes"
assert result.content == "Content"
assert result.signature == "45ef5ab"
assert result.redacted is True


# Test with whitespace in attributes
def test_reasoning_parse_with_whitespace_in_attributes():
result = parse_content_with_reasoning(
'<think signature="45ef5ab" redacted="true" >Whitespace in attributes</think>Content'
)
assert result is not None
assert result.reasoning == "Whitespace in attributes"
assert result.content == "Content"
assert result.signature == "45ef5ab"
assert result.redacted is True


# Test with attributes and multiline content
def test_reasoning_parse_with_attributes_and_multiline():
result = parse_content_with_reasoning("""
<think signature="45ef5ab" redacted="true">
Complex
reasoning
</think>
Content
here
""")
assert result is not None
assert result.reasoning == "Complex\n reasoning"
assert result.content == "Content\n here"
assert result.signature == "45ef5ab"
assert result.redacted is True

0 comments on commit 2e2969d

Please sign in to comment.