From 2e2969d5c59200c88339e8872c8e06e106441934 Mon Sep 17 00:00:00 2001 From: jjallaire-aisi Date: Thu, 27 Feb 2025 16:20:18 +0000 Subject: [PATCH] support anthropic reasonning in agent bridge (#1422) Co-authored-by: jjallaire --- src/inspect_ai/model/_openai.py | 116 ++++++++++++++++++++-------- src/inspect_ai/model/_reasoning.py | 17 +++- tests/model/test_reasoning_parse.py | 79 +++++++++++++++++++ 3 files changed, 176 insertions(+), 36 deletions(-) diff --git a/src/inspect_ai/model/_openai.py b/src/inspect_ai/model/_openai.py index b2459c530..0084d6196 100644 --- a/src/inspect_ai/model/_openai.py +++ b/src/inspect_ai/model/_openai.py @@ -38,6 +38,7 @@ from inspect_ai._util.url import is_http_url from inspect_ai.model._call_tools import parse_tool_call from inspect_ai.model._model_output import ChatCompletionChoice, Logprobs +from inspect_ai.model._reasoning import parse_content_with_reasoning from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo from ._chat_message import ( @@ -154,14 +155,14 @@ async def openai_chat_message( if message.tool_calls: return ChatCompletionAssistantMessageParam( role=message.role, - content=message.text, + content=openai_assistant_content(message), tool_calls=[ openai_chat_tool_call_param(call) for call in message.tool_calls ], ) else: return ChatCompletionAssistantMessageParam( - role=message.role, content=message.text + role=message.role, content=openai_assistant_content(message) ) elif message.role == "tool": return ChatCompletionToolMessageParam( @@ -181,16 +182,29 @@ async def openai_chat_messages( return [await openai_chat_message(message, model) for message in messages] +def openai_assistant_content(message: ChatMessageAssistant) -> str: + if isinstance(message.content, str): + content = message.content + else: + content = "" + for c in message.content: + if c.type == "reasoning": + attribs = "" + if c.signature is not None: + attribs = f'{attribs} signature="{c.signature}"' + if c.redacted: + attribs = f'{attribs} redacted="true"' + content = f"{content}\n\n{c.reasoning}\n\n" + elif c.type == "text": + content = f"{content}\n{c.text}" + return content + + def openai_chat_choices(choices: list[ChatCompletionChoice]) -> list[Choice]: oai_choices: list[Choice] = [] for index, choice in enumerate(choices): - if isinstance(choice.message.content, str): - content = choice.message.content - else: - content = "\n".join( - [c.text for c in choice.message.content if c.type == "text"] - ) + content = openai_assistant_content(choice.message) if choice.message.tool_calls: tool_calls = [openai_chat_tool_call(tc) for tc in choice.message.tool_calls] else: @@ -280,35 +294,47 @@ def chat_messages_from_openai( chat_messages: list[ChatMessage] = [] for message in messages: + content: str | list[Content] = [] if message["role"] == "system" or message["role"] == "developer": sys_content = message["content"] if isinstance(sys_content, str): chat_messages.append(ChatMessageSystem(content=sys_content)) else: - chat_messages.append( - ChatMessageSystem( - content=[content_from_openai(c) for c in sys_content] - ) - ) + content = [] + for sc in sys_content: + content.extend(content_from_openai(sc)) + chat_messages.append(ChatMessageSystem(content=content)) elif message["role"] == "user": user_content = message["content"] if isinstance(user_content, str): chat_messages.append(ChatMessageUser(content=user_content)) else: - chat_messages.append( - ChatMessageUser( - content=[content_from_openai(c) for c in user_content] - ) - ) + content = [] + for uc in user_content: + content.extend(content_from_openai(uc)) + chat_messages.append(ChatMessageUser(content=content)) elif message["role"] == "assistant": # resolve content asst_content = message.get("content", None) if isinstance(asst_content, str): - content: str | list[Content] = asst_content + result = parse_content_with_reasoning(asst_content) + if result is not None: + content = [ + ContentReasoning( + reasoning=result.reasoning, + signature=result.signature, + redacted=result.redacted, + ), + ContentText(text=result.content), + ] + else: + content = asst_content elif asst_content is None: content = message.get("refusal", None) or "" else: - content = [content_from_openai(c) for c in asst_content] + content = [] + for ac in asst_content: + content.extend(content_from_openai(ac, parse_reasoning=True)) # resolve reasoning (OpenAI doesn't suport this however OpenAI-compatible # interfaces e.g. DeepSeek do include this field so we pluck it out) @@ -324,9 +350,9 @@ def chat_messages_from_openai( # return message if "tool_calls" in message: tool_calls: list[ToolCall] = [] - for tc in message["tool_calls"]: - tool_calls.append(tool_call_from_openai(tc)) - tool_names[tc["id"]] = tc["function"]["name"] + for call in message["tool_calls"]: + tool_calls.append(tool_call_from_openai(call)) + tool_names[call["id"]] = call["function"]["name"] else: tool_calls = [] @@ -342,7 +368,9 @@ def chat_messages_from_openai( if isinstance(tool_content, str): content = tool_content else: - content = [content_from_openai(c) for c in tool_content] + content = [] + for tc in tool_content: + content.extend(content_from_openai(tc)) chat_messages.append( ChatMessageTool( content=content, @@ -366,20 +394,40 @@ def tool_call_from_openai(tool_call: ChatCompletionMessageToolCallParam) -> Tool def content_from_openai( content: ChatCompletionContentPartParam | ChatCompletionContentPartRefusalParam, -) -> Content: + parse_reasoning: bool = False, +) -> list[Content]: if content["type"] == "text": - return ContentText(text=content["text"]) + text = content["text"] + if parse_reasoning: + result = parse_content_with_reasoning(text) + if result: + return [ + ContentReasoning( + reasoning=result.reasoning, + signature=result.signature, + redacted=result.redacted, + ), + ContentText(text=result.content), + ] + else: + return [ContentText(text=text)] + else: + return [ContentText(text=text)] elif content["type"] == "image_url": - return ContentImage( - image=content["image_url"]["url"], detail=content["image_url"]["detail"] - ) + return [ + ContentImage( + image=content["image_url"]["url"], detail=content["image_url"]["detail"] + ) + ] elif content["type"] == "input_audio": - return ContentAudio( - audio=content["input_audio"]["data"], - format=content["input_audio"]["format"], - ) + return [ + ContentAudio( + audio=content["input_audio"]["data"], + format=content["input_audio"]["format"], + ) + ] elif content["type"] == "refusal": - return ContentText(text=content["refusal"]) + return [ContentText(text=content["refusal"])] def chat_message_assistant_from_openai( diff --git a/src/inspect_ai/model/_reasoning.py b/src/inspect_ai/model/_reasoning.py index 7e9a5befa..d341acfb5 100644 --- a/src/inspect_ai/model/_reasoning.py +++ b/src/inspect_ai/model/_reasoning.py @@ -5,13 +5,26 @@ class ContentWithReasoning(NamedTuple): content: str reasoning: str + signature: str | None = None + redacted: bool = False def parse_content_with_reasoning(content: str) -> ContentWithReasoning | None: - match = re.match(r"\s*(.*?)(.*)", content, re.DOTALL) + # Match tag with optional attributes + pattern = r'\s*(.*?)(.*)' + match = re.match(pattern, content, re.DOTALL) + if match: + signature = match.group(1) # This will be None if not present + redacted_value = match.group(2) # This will be "true" or None + reasoning = match.group(3).strip() + content_text = match.group(4).strip() + return ContentWithReasoning( - content=match.group(2).strip(), reasoning=match.group(1).strip() + content=content_text, + reasoning=reasoning, + signature=signature, + redacted=redacted_value == "true", ) else: return None diff --git a/tests/model/test_reasoning_parse.py b/tests/model/test_reasoning_parse.py index 5818da11a..3af46c4c8 100644 --- a/tests/model/test_reasoning_parse.py +++ b/tests/model/test_reasoning_parse.py @@ -6,6 +6,8 @@ def test_reasoning_parse_basic(): assert result is not None assert result.reasoning == "Simple reasoning" assert result.content == "Normal text" + assert result.signature is None + assert result.redacted is False def test_reasoning_parse_with_leading_whitespace(): @@ -15,6 +17,8 @@ def test_reasoning_parse_with_leading_whitespace(): assert result is not None assert result.reasoning == "Indented reasoning" assert result.content == "Text" + assert result.signature is None + assert result.redacted is False def test_reasoning_parse_with_trailing_whitespace(): @@ -22,6 +26,8 @@ def test_reasoning_parse_with_trailing_whitespace(): assert result is not None assert result.reasoning == "Reasoning" assert result.content == "Text" + assert result.signature is None + assert result.redacted is False def test_reasoning_parse_with_newlines_in_reasoning(): @@ -29,6 +35,8 @@ def test_reasoning_parse_with_newlines_in_reasoning(): assert result is not None assert result.reasoning == "Multi\nline\nreasoning" assert result.content == "Text" + assert result.signature is None + assert result.redacted is False def test_reasoning_parse_empty(): @@ -36,6 +44,8 @@ def test_reasoning_parse_empty(): assert result is not None assert result.reasoning == "" assert result.content == "Text" + assert result.signature is None + assert result.redacted is False def test_reasoning_parse_empty_content(): @@ -43,6 +53,8 @@ def test_reasoning_parse_empty_content(): assert result is not None assert result.reasoning == "Just reasoning" assert result.content == "" + assert result.signature is None + assert result.redacted is False def test_reasoning_parse_whitespace_everywhere(): @@ -57,6 +69,8 @@ def test_reasoning_parse_whitespace_everywhere(): assert result is not None assert result.reasoning == "Messy\n reasoning" assert result.content == "Messy\n text" + assert result.signature is None + assert result.redacted is False def test_reasoning_parse_no_think_tag(): @@ -67,3 +81,68 @@ def test_reasoning_parse_no_think_tag(): def test_reasoning_parse_unclosed_tag(): result = parse_content_with_reasoning("Unclosed reasoning") assert result is None + + +# New tests for signature attribute +def test_reasoning_parse_with_signature(): + result = parse_content_with_reasoning( + 'Reasoning with signatureContent' + ) + assert result is not None + assert result.reasoning == "Reasoning with signature" + assert result.content == "Content" + assert result.signature == "45ef5ab" + assert result.redacted is False + + +# New tests for redacted attribute +def test_reasoning_parse_with_redacted(): + result = parse_content_with_reasoning( + 'Redacted reasoningContent' + ) + assert result is not None + assert result.reasoning == "Redacted reasoning" + assert result.content == "Content" + assert result.signature is None + assert result.redacted is True + + +# New tests for both attributes +def test_reasoning_parse_with_signature_and_redacted(): + result = parse_content_with_reasoning( + 'Both attributesContent' + ) + assert result is not None + assert result.reasoning == "Both attributes" + assert result.content == "Content" + assert result.signature == "45ef5ab" + assert result.redacted is True + + +# Test with whitespace in attributes +def test_reasoning_parse_with_whitespace_in_attributes(): + result = parse_content_with_reasoning( + 'Whitespace in attributesContent' + ) + assert result is not None + assert result.reasoning == "Whitespace in attributes" + assert result.content == "Content" + assert result.signature == "45ef5ab" + assert result.redacted is True + + +# Test with attributes and multiline content +def test_reasoning_parse_with_attributes_and_multiline(): + result = parse_content_with_reasoning(""" + + Complex + reasoning + + Content + here + """) + assert result is not None + assert result.reasoning == "Complex\n reasoning" + assert result.content == "Content\n here" + assert result.signature == "45ef5ab" + assert result.redacted is True