diff --git a/src/control_flow/flow.py b/src/control_flow/flow.py index 4eb99565..1f6bbb9a 100644 --- a/src/control_flow/flow.py +++ b/src/control_flow/flow.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, Field, field_validator from control_flow.context import ctx +from control_flow.utilities.marvin import patch_marvin logger = get_logger(__name__) @@ -99,7 +100,7 @@ def wrapper( f'Executing AI flow "{fn.__name__}" on thread "{flow_obj.thread.id}"' ) - with ctx(flow=flow_obj): + with ctx(flow=flow_obj), patch_marvin(): return p_fn(*args, **kwargs) return wrapper diff --git a/src/control_flow/utilities/marvin.py b/src/control_flow/utilities/marvin.py new file mode 100644 index 00000000..709f8686 --- /dev/null +++ b/src/control_flow/utilities/marvin.py @@ -0,0 +1,86 @@ +import functools +from contextlib import contextmanager +from typing import Any, Callable + +import marvin.ai.text +from marvin.client.openai import AsyncMarvinClient +from marvin.settings import temporary_settings as temporary_marvin_settings +from openai.types.chat import ChatCompletion +from prefect import task as prefect_task + +from control_flow.utilities.prefect import create_json_artifact + +original_classify_async = marvin.classify_async +original_cast_async = marvin.cast_async +original_extract_async = marvin.extract_async +original_generate_async = marvin.generate_async +original_paint_async = marvin.paint_async +original_speak_async = marvin.speak_async +original_transcribe_async = marvin.transcribe_async + + +class AsyncControlFlowClient(AsyncMarvinClient): + async def generate_chat(self, **kwargs: Any) -> "ChatCompletion": + super_method = super().generate_chat + + @prefect_task(task_run_name="Generate OpenAI chat completion") + async def _generate_chat(**kwargs): + messages = kwargs.get("messages", []) + create_json_artifact(key="prompt", data=messages) + response = await super_method(**kwargs) + create_json_artifact(key="response", data=response) + return response + + return _generate_chat(**kwargs) + + +def generate_task(name: str, original_fn: Callable): + @functools.wraps(marvin.classify_async) + async def wrapper(*args, **kwargs): + @prefect_task(name=name) + async def inner(*args, **kwargs): + create_json_artifact(key="args", data=[args, kwargs]) + result = await original_fn(*args, **kwargs) + create_json_artifact(key="result", data=result) + return result + + # do this to avoid weirdness with async/sync behavior + return inner(*args, **kwargs) + + return wrapper + + +@contextmanager +def patch_marvin(): + with temporary_marvin_settings(default_async_client_cls=AsyncControlFlowClient): + try: + marvin.ai.text.classify_async = generate_task( + "marvin.classify", original_classify_async + ) + marvin.ai.text.cast_async = generate_task( + "marvin.cast", original_cast_async + ) + marvin.ai.text.extract_async = generate_task( + "marvin.extract", original_extract_async + ) + marvin.ai.text.generate_async = generate_task( + "marvin.generate", original_generate_async + ) + marvin.ai.images.paint_async = generate_task( + "marvin.paint", original_paint_async + ) + marvin.ai.audio.speak_async = generate_task( + "marvin.speak", original_speak_async + ) + marvin.ai.audio.transcribe_async = generate_task( + "marvin.transcribe", original_transcribe_async + ) + yield + finally: + marvin.ai.text.classify_async = original_classify_async + marvin.ai.text.cast_async = original_cast_async + marvin.ai.text.extract_async = original_extract_async + marvin.ai.text.generate_async = original_generate_async + marvin.ai.images.paint_async = original_paint_async + marvin.ai.audio.speak_async = original_speak_async + marvin.ai.audio.transcribe_async = original_transcribe_async diff --git a/src/control_flow/utilities/prefect.py b/src/control_flow/utilities/prefect.py index 1e26b364..49f5d73e 100644 --- a/src/control_flow/utilities/prefect.py +++ b/src/control_flow/utilities/prefect.py @@ -53,10 +53,7 @@ def create_json_artifact( Create a JSON artifact. """ - if isinstance(data, str): - json_data = data - else: - json_data = TypeAdapter(type(data)).dump_json(data, indent=2).decode() + json_data = TypeAdapter(type(data)).dump_json(data, indent=2).decode() create_markdown_artifact( key=key,