Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Adjust flush timing for nested generations using outermost_lock_context #183

Merged
merged 6 commits into from
Feb 12, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 70 additions & 24 deletions lilypad/generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import inspect
import json
from collections.abc import Callable, Coroutine
from contextvars import ContextVar
from collections.abc import Callable, Coroutine, Generator
from contextlib import contextmanager
from contextvars import ContextVar, Token
from functools import wraps
from typing import Any, ParamSpec, Protocol, TypeVar, overload

from fastapi.encoders import jsonable_encoder
from opentelemetry.trace import get_tracer
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace import get_tracer, get_tracer_provider
from opentelemetry.util.types import AttributeValue
from pydantic import BaseModel

Expand Down Expand Up @@ -49,6 +51,45 @@ def __call__(
...


def _get_batch_span_processor() -> BatchSpanProcessor | None:
"""Get the BatchSpanProcessor from the current TracerProvider.

Retrieve the BatchSpanProcessor from the current TracerProvider dynamically.
This avoids using a global variable by inspecting the provider's _active_span_processors.
"""
tracer_provider = get_tracer_provider()
processor = getattr(tracer_provider, "_active_span_processor", None)
if not processor:
return None
_span_processors = getattr(processor, "_span_processors", None)
if _span_processors:
for processor in _span_processors:
if isinstance(processor, BatchSpanProcessor):
return processor
return None


@contextmanager
def outermost_lock_context(enable_lock: bool) -> Generator[None, None, None]:
"""Acquire the BatchSpanProcessor's condition lock if enable_lock is True.

This context manager is intended for use in the outermost generation.
When enable_lock is True, it retrieves the current BatchSpanProcessor and acquires its
condition lock. This ensures that flush operations are synchronized and only executed
at the outermost generation level.
For inner generations (enable_lock is False), no lock is acquired.
"""
if not enable_lock:
yield
return
processor = _get_batch_span_processor()
if not processor:
yield
return
with processor.condition:
yield


def _construct_trace_attributes(
generation: GenerationPublic,
arg_types: dict[str, str],
Expand Down Expand Up @@ -181,18 +222,21 @@ async def inner_async(*args: _P.args, **kwargs: _P.kwargs) -> _R:
)
token = current_generation.set(generation)
try:
if not is_mirascope_call:
decorator = _trace(
generation=generation,
arg_types=arg_types,
arg_values=arg_values,
prompt_template="",
# Check if this is the outermost generation (no previous generation)
is_outermost = token.old_value is None
with outermost_lock_context(is_outermost):
if not is_mirascope_call:
decorator_inner = _trace(
generation=generation,
arg_types=arg_types,
arg_values=arg_values,
prompt_template="",
)
return await decorator_inner(fn)(*args, **kwargs)
decorator_inner = create_mirascope_middleware(
generation, arg_types, arg_values, True, prompt_template
)
return await decorator(fn)(*args, **kwargs)
decorator = create_mirascope_middleware(
generation, arg_types, arg_values, True, prompt_template
)
return await decorator(fn)(*args, **kwargs)
return await decorator_inner(fn)(*args, **kwargs)
finally:
current_generation.reset(token)

Expand All @@ -208,18 +252,20 @@ def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
)
token = current_generation.set(generation)
try:
if not is_mirascope_call:
decorator = _trace(
generation=generation,
arg_types=arg_types,
arg_values=arg_values,
prompt_template="",
is_outermost = token.old_value == Token.MISSING
with outermost_lock_context(is_outermost):
if not is_mirascope_call:
decorator_inner = _trace(
generation=generation,
arg_types=arg_types,
arg_values=arg_values,
prompt_template="",
)
return decorator_inner(fn)(*args, **kwargs) # pyright: ignore [reportReturnType]
decorator_inner = create_mirascope_middleware(
generation, arg_types, arg_values, False, prompt_template
)
return decorator(fn)(*args, **kwargs) # pyright: ignore [reportReturnType]
decorator = create_mirascope_middleware(
generation, arg_types, arg_values, False, prompt_template
)
return decorator(fn)(*args, **kwargs) # pyright: ignore [reportReturnType]
finally:
current_generation.reset(token)

Expand Down