diff --git a/agentops/client.py b/agentops/client.py index 80cb1cc5..c3928786 100644 --- a/agentops/client.py +++ b/agentops/client.py @@ -36,7 +36,7 @@ def __init__(self): self._pre_init_messages: List[str] = [] self._initialized: bool = False self._llm_tracker: Optional[LlmTracker] = None - self._sessions: List[Session] = active_sessions + self._sessions: List[Session] = list(active_sessions) self._config = Configuration() self._pre_init_queue = {"agents": []} diff --git a/agentops/config.py b/agentops/config.py index 7dfb574d..ad1019d8 100644 --- a/agentops/config.py +++ b/agentops/config.py @@ -1,74 +1,53 @@ -from typing import List, Optional +from dataclasses import dataclass, field +from typing import Optional, Set from uuid import UUID from .log_config import logger +# TODO: Use annotations to clarify the purpose of each attribute. +# Details are defined in a docstrings found in __init__.py, but +# it's good to have those right on the fields at class definition +@dataclass class Configuration: - def __init__(self): - self.api_key: Optional[str] = None - self.parent_key: Optional[str] = None - self.endpoint: str = "https://api.agentops.ai" - self.max_wait_time: int = 5000 - self.max_queue_size: int = 512 - self.default_tags: set[str] = set() - self.instrument_llm_calls: bool = True - self.auto_start_session: bool = True - self.skip_auto_end_session: bool = False - self.env_data_opt_out: bool = False + api_key: Optional[str] = None + parent_key: Optional[str] = None + endpoint: str = "https://api.agentops.ai" + max_wait_time: int = 5000 + max_queue_size: int = 512 + graceful_shutdown_wait_time: int = 2000 + default_tags: Set[str] = field(default_factory=set) + instrument_llm_calls: bool = True + auto_start_session: bool = True + skip_auto_end_session: bool = False + env_data_opt_out: bool = False def configure( self, client, - api_key: Optional[str] = None, - parent_key: Optional[str] = None, - endpoint: Optional[str] = None, - max_wait_time: Optional[int] = 
None, - max_queue_size: Optional[int] = None, - default_tags: Optional[List[str]] = None, - instrument_llm_calls: Optional[bool] = None, - auto_start_session: Optional[bool] = None, - skip_auto_end_session: Optional[bool] = None, - env_data_opt_out: Optional[bool] = None, + **kwargs ): - if api_key is not None: - try: - UUID(api_key) - self.api_key = api_key - except ValueError: - message = f"API Key is invalid: {{{api_key}}}.\n\t Find your API key at https://app.agentops.ai/settings/projects" - client.add_pre_init_warning(message) - logger.error(message) - - if parent_key is not None: - try: - UUID(parent_key) - self.parent_key = parent_key - except ValueError: - message = f"Parent Key is invalid: {parent_key}" - client.add_pre_init_warning(message) - logger.warning(message) - - if endpoint is not None: - self.endpoint = endpoint - - if max_wait_time is not None: - self.max_wait_time = max_wait_time - - if max_queue_size is not None: - self.max_queue_size = max_queue_size - - if default_tags is not None: - self.default_tags.update(default_tags) - - if instrument_llm_calls is not None: - self.instrument_llm_calls = instrument_llm_calls - - if auto_start_session is not None: - self.auto_start_session = auto_start_session - - if skip_auto_end_session is not None: - self.skip_auto_end_session = skip_auto_end_session - - if env_data_opt_out is not None: - self.env_data_opt_out = env_data_opt_out + # Special handling for keys that need UUID validation + for key_name in ['api_key', 'parent_key']: + if key_name in kwargs and kwargs[key_name] is not None: + try: + UUID(kwargs[key_name]) + setattr(self, key_name, kwargs[key_name]) + except ValueError: + message = ( + f"API Key is invalid: {{{kwargs[key_name]}}}.\n\t Find your API key at https://app.agentops.ai/settings/projects" + if key_name == 'api_key' + else f"Parent Key is invalid: {kwargs[key_name]}" + ) + client.add_pre_init_warning(message) + logger.error(message) if key_name == 'api_key' else 
logger.warning(message) + kwargs.pop(key_name) + + # Special handling for default_tags which needs update() instead of assignment + if 'default_tags' in kwargs and kwargs['default_tags'] is not None: + self.default_tags.update(kwargs.pop('default_tags')) + + # Handle all other attributes + for key, value in kwargs.items(): + if value is not None and hasattr(self, key): + setattr(self, key, value) diff --git a/agentops/session.py b/agentops/session.py index 58225b6c..60669702 100644 --- a/agentops/session.py +++ b/agentops/session.py @@ -1,24 +1,224 @@ -import copy -import functools +from __future__ import annotations # Allow forward references + +import datetime as dt import json +import queue import threading import time from decimal import ROUND_HALF_UP, Decimal -from termcolor import colored -from typing import Optional, List, Union +from typing import Annotated, Dict, List, Optional, Union from uuid import UUID, uuid4 -from datetime import datetime -from .exceptions import ApiServerException -from .enums import EndState -from .event import ErrorEvent, Event -from .log_config import logger +from termcolor import colored + from .config import Configuration -from .helpers import get_ISO_time, filter_unjsonable, safe_serialize +from .enums import EndState, EventType +from .event import ErrorEvent, Event +from .exceptions import ApiServerException +from .helpers import filter_unjsonable, get_ISO_time, safe_serialize from .http_client import HttpClient +from .log_config import logger +""" + +minor changes: +- `active_sessions` is now a WeakSet. Cleanup is not required. 
Tests can be simplified + +- Removed unsafe usage of __dict__ (which is very dangerous when used without __slots__) +- Introduced Session.active property +""" + +from typing import DefaultDict +from weakref import WeakSet + + +class SessionDict(DefaultDict): + session_id: UUID + # -------------- + config: Configuration + end_state: str = EndState.INDETERMINATE.value + end_state_reason: Optional[str] = None + end_timestamp: Optional[str] = None + # Create a counter dictionary with each EventType name initialized to 0 + event_counts: Dict[str, int] + host_env: Optional[dict] = None + init_timestamp: str # Will be set to get_ISO_time() during __init__ + is_running: bool = False + jwt: Optional[str] = None + tags: Optional[List[str]] = None + video: Optional[str] = None + + def __init__(self, **kwargs): + kwargs.setdefault( + "event_counts", {event_type.value: 0 for event_type in EventType} + ) + kwargs.setdefault("init_timestamp", get_ISO_time()) + super().__init__(**kwargs) -class Session: + +class SessionsCollection(WeakSet): + """ + A custom collection for managing Session objects that combines WeakSet's automatic cleanup + with list-like indexing capabilities. + + This class is needed because: + 1. We want WeakSet's automatic cleanup of unreferenced sessions + 2. We need to access sessions by index (e.g., self._sessions[0]) for backwards compatibility + 3. Standard WeakSet doesn't support indexing + """ + + def __getitem__(self, index: int) -> Session: + """ + Enable indexing into the collection (e.g., sessions[0]). + """ + # Convert to list for indexing since sets aren't ordered + items = list(self) + return items[index] + + def __iter__(self): + """ + Override the default iterator to yield sessions sorted by init_timestamp. + If init_timestamp is not available, fall back to __create_ts. + + WARNING: Using __create_ts as a fallback for ordering may lead to unexpected results + if init_timestamp is not set correctly. 
+ """ + return iter(sorted(super().__iter__(), key=lambda session: ( + session.init_timestamp if hasattr(session, 'init_timestamp') else session.__create_ts + ))) + +class SessionApi: + """ + Solely focuses on interacting with the API + + Developer notes: + Need to clarify (and define) a standard and consistent Api interface. + + The way it can be approached is by having a base `Api` class that holds common + configurations and features, while implementors provide entity-related controllers + """ + + # TODO: Decouple from standard Configuration a Session's entity own configuration. + # NOTE: pydantic-settings plays out beautifully in such setup, but it's not a requirement. + # TODO: Eventually move to apis/ + session: Session + + def __init__(self, session: Session): + self.session = session + + @property + def config(self): # Forward decl. + return self.session.config + + @property + def jwt(self) -> Optional[str]: + """Convenience property that falls back to dictionary access""" + return self.session.get("jwt") + + @jwt.setter + def jwt(self, value: Optional[str]): + self.session["jwt"] = value + + def update_session(self) -> None: + try: + payload = {"session": dict(self.session)} + res = HttpClient.post( + f"{self.config.endpoint}/v2/update_session", + json.dumps(filter_unjsonable(payload)).encode("utf-8"), + jwt=self.jwt, + ) + except ApiServerException as e: + return logger.error(f"Could not update session - {e}") + + # WARN: Unused method + def reauthorize_jwt(self) -> Union[str, None]: + payload = {"session_id": self.session.session_id} + serialized_payload = json.dumps(filter_unjsonable(payload)).encode("utf-8") + res = HttpClient.post( + f"{self.config.endpoint}/v2/reauthorize_jwt", + serialized_payload, + self.config.api_key, + ) + + logger.debug(res.body) + + if res.code != 200: + return None + + jwt = res.body.get("jwt", None) + self.jwt = jwt + return jwt + + def create_session(self, session: SessionDict): + """ + Creates a new session via API call + 
+ Returns: + tuple containing: + - success (bool): Whether the creation was successful + - jwt (Optional[str]): JWT token if successful + - session_url (Optional[str]): URL to view the session if successful + """ + payload = {"session": dict(session)} + serialized_payload = json.dumps(filter_unjsonable(payload)).encode("utf-8") + + try: + res = HttpClient.post( + f"{self.config.endpoint}/v2/create_session", + serialized_payload, + self.config.api_key, + self.config.parent_key, + ) + except ApiServerException as e: + logger.error(f"Could not start session - {e}") + return False + else: + if res.code != 200: + return False + + jwt = res.body.get("jwt", None) + self.jwt = jwt + if jwt is None: + return False + + session_url = res.body.get( + "session_url", + f"https://app.agentops.ai/drilldown?session_id={session.session_id}", + ) + + logger.info( + colored( + f"\x1b[34mSession Replay: {session_url}\x1b[0m", + "blue", + ) + ) + + return True + + def batch(self, events: List[Event]) -> None: + serialized_payload = safe_serialize(dict(events=events)).encode("utf-8") + try: + HttpClient.post( + f"{self.config.endpoint}/v2/create_events", + serialized_payload, + jwt=self.jwt, + ) + except ApiServerException as e: + return logger.error(f"Could not post events - {e}") + + # Update event counts on the session instance + for event in events: + event_type = event["event_type"] + if event_type in self.session["event_counts"]: + self.session["event_counts"][event_type] += 1 + + logger.debug("\n") + logger.debug(f"Session request to {self.config.endpoint}/v2/create_events") + logger.debug(serialized_payload) + logger.debug("\n") + + +class Session(SessionDict): """ Represents a session of events, with a start and end state. @@ -28,48 +228,96 @@ class Session: Attributes: init_timestamp (float): The timestamp for when the session started, represented as seconds since the epoch. 
- end_timestamp (float, optional): The timestamp for when the session ended, represented as seconds since the epoch. This is only set after end_session is called. - end_state (str, optional): The final state of the session. Suggested: "Success", "Fail", "Indeterminate". Defaults to "Indeterminate". + end_timestamp (float, optional): The timestamp for when the session ended, represented as seconds since the epoch. + end_state (str, optional): The final state of the session. Defaults to "Indeterminate". end_state_reason (str, optional): The reason for ending the session. - """ - def __init__( - self, - session_id: UUID, - config: Configuration, - tags: Optional[List[str]] = None, - host_env: Optional[dict] = None, - ): - self.end_timestamp = None - self.end_state: Optional[str] = "Indeterminate" - self.session_id = session_id - self.init_timestamp = get_ISO_time() - self.tags: List[str] = tags or [] - self.video: Optional[str] = None - self.end_state_reason: Optional[str] = None - self.host_env = host_env - self.config = config - self.jwt = None - self.lock = threading.Lock() - self.queue = [] - self.event_counts = { - "llms": 0, - "tools": 0, - "actions": 0, - "errors": 0, - "apis": 0, + thread: Annotated[ + EventPublisherThread, "Publishes events to the API in a background thread." 
+ ] + + def __init__(self, **kwargs): + logger.debug(f"Initializing new Session with id {kwargs.get('session_id')}") + + # Set creation timestamp + self.__create_ts = time.monotonic() + + self.api = SessionApi(self) + + self._locks = { + "lifecycle": threading.Lock(), # Controls session lifecycle operations + "events": threading.Lock(), # Protects event queue operations + "session": threading.Lock(), # Protects session state updates + "tags": threading.Lock(), # Protects tag modifications + "api": threading.Lock(), # Protects API calls + } + + self.config = kwargs.pop( + "config", Configuration() + ) # config assigned at object level + logger.debug(f"{self.__class__.__name__}: Session locks initialized") + + super().__init__( + **kwargs + ) # Initialize SessionStruct after `locks` were defined + + # Initialize queue with max size from config + self._events = queue.Queue(maxsize=self.config.max_queue_size) + self._cleanup_done = False + self._stop_flag = threading.Event() + + # Initialize conditions + self.conditions = { + "cleanup": threading.Condition(self._locks["lifecycle"]), + "changes": threading.Condition(self._locks["session"]), } - self.stop_flag = threading.Event() - self.thread = threading.Thread(target=self._run) - self.thread.daemon = True - self.thread.start() + # Initialize threads + self.publisher_thread = EventPublisherThread(self) + self.observer_thread = ChangesObserverThread(self) + + self._is_running = False + + # Add session to active_sessions + active_sessions.add(self) + + self._start_session() + + if self.is_running: + # Only start threads if session started successfully + self.publisher_thread.start() + self.observer_thread.start() + else: + self.stop() + + def __hash__(self) -> int: + return hash(self.session_id) + + @property + def session_id(self) -> UUID: + return self["session_id"] + + @property + def active(self) -> bool: + """Alias for is_running for backward compatibility""" + return self.is_running + + @property + def 
is_running(self) -> bool: + """Thread-safe access to running state""" + with self._locks["lifecycle"]: + return self._is_running - self.is_running = self._start_session() - if self.is_running == False: - self.stop_flag.set() - self.thread.join(timeout=1) + @is_running.setter + def is_running(self, value: bool): + """Thread-safe modification of running state""" + with self._locks["lifecycle"]: + self._is_running = value + + @property + def counter(self): + return self["event_counts"] def set_video(self, video: str) -> None: """ @@ -80,34 +328,93 @@ def set_video(self, video: str) -> None: """ self.video = video + def add_tags(self, tags: List[str]) -> None: + """ + Append to session tags at runtime. + + Args: + tags (List[str]): The list of tags to append. + """ + if not self.is_running: + return + + if not (isinstance(tags, list) and all(isinstance(item, str) for item in tags)): + if isinstance(tags, str): + tags = [tags] + + if self.tags is None: + self.tags = tags + else: + for tag in tags: + if tag not in self.tags: + self.tags.append(tag) + + self._publish() + + def set_tags(self, tags): + if not self.is_running: + return + + if not (isinstance(tags, list) and all(isinstance(item, str) for item in tags)): + if isinstance(tags, str): + tags = [tags] + + self.tags = tags + self._publish() + + # --- Interactors + def record(self, event: Union[Event, ErrorEvent]): + if not self.is_running: + logger.debug( + f"{self.__class__.__name__}: Attempted to record event but session is not running" + ) + return + + logger.debug(f"Recording event: {event.event_type}") + if isinstance(event, Event): + if not event.end_timestamp or event.init_timestamp == event.end_timestamp: + event.end_timestamp = get_ISO_time() # WARN: Unrestricted assignment + elif isinstance(event, ErrorEvent): + if event.trigger_event: + if ( + not event.trigger_event.end_timestamp + or event.trigger_event.init_timestamp + == event.trigger_event.end_timestamp + ): + event.trigger_event.end_timestamp = 
get_ISO_time() + + event.trigger_event_id = event.trigger_event.id + event.trigger_event_type = event.trigger_event.event_type + self._enqueue(event.trigger_event.__dict__) + event.trigger_event = None # removes trigger_event from serialization + # ^^ NOTE: Consider memento https://refactoring.guru/design-patterns/memento/python/example + + self._enqueue( + event.__dict__ + ) # WARNING: This is very dangerous. Either define Event.__slots__ or turn Event into a dataclass + def end_session( self, end_state: str = "Indeterminate", end_state_reason: Optional[str] = None, video: Optional[str] = None, ) -> Union[Decimal, None]: - - if not self.is_running: - return - if not any(end_state == state.value for state in EndState): return logger.warning( "Invalid end_state. Please use one of the EndState enums" ) - self.end_timestamp = get_ISO_time() - self.end_state = end_state - self.end_state_reason = end_state_reason - if video is not None: - self.video = video + self["end_timestamp"] = get_ISO_time() + self["end_state"] = end_state or self.end_state + self["end_state_reason"] = end_state_reason or self.end_state_reason - self.stop_flag.set() - self.thread.join(timeout=1) - self._flush_queue() - - def format_duration(start_time, end_time): - start = datetime.fromisoformat(start_time.replace("Z", "+00:00")) - end = datetime.fromisoformat(end_time.replace("Z", "+00:00")) + def __calc_elapsed(): + start = dt.datetime.fromisoformat( + self["init_timestamp"].replace("Z", "+00:00") + ) + end = dt.datetime.fromisoformat( + self["end_timestamp"].replace("Z", "+00:00") + ) duration = end - start hours, remainder = divmod(duration.total_seconds(), 3600) @@ -122,8 +429,8 @@ def format_duration(start_time, end_time): return " ".join(parts) - with self.lock: - payload = {"session": self.__dict__} + with self._locks["api"]: + payload = {"session": dict(self)} try: res = HttpClient.post( f"{self.config.endpoint}/v2/update_session", @@ -131,12 +438,13 @@ def format_duration(start_time, 
end_time): jwt=self.jwt, ) except ApiServerException as e: - return logger.error(f"Could not end session - {e}") + logger.error(f"Could not end session - {e}") + return None logger.debug(res.body) token_cost = res.body.get("token_cost", "unknown") - formatted_duration = format_duration(self.init_timestamp, self.end_timestamp) + formatted_duration = __calc_elapsed() if token_cost == "unknown" or token_cost is None: token_cost_d = Decimal(0) @@ -155,10 +463,10 @@ def format_duration(start_time, end_time): f"Session Stats - " f"{colored('Duration:', attrs=['bold'])} {formatted_duration} | " f"{colored('Cost:', attrs=['bold'])} ${formatted_cost} | " - f"{colored('LLMs:', attrs=['bold'])} {self.event_counts['llms']} | " - f"{colored('Tools:', attrs=['bold'])} {self.event_counts['tools']} | " - f"{colored('Actions:', attrs=['bold'])} {self.event_counts['actions']} | " - f"{colored('Errors:', attrs=['bold'])} {self.event_counts['errors']}" + f"{colored('LLMs:', attrs=['bold'])} {self.counter['llms']} | " + f"{colored('Tools:', attrs=['bold'])} {self.counter['tools']} | " + f"{colored('Actions:', attrs=['bold'])} {self.counter['actions']} | " + f"{colored('Errors:', attrs=['bold'])} {self.counter['errors']}" ) logger.info(analytics) @@ -174,227 +482,363 @@ def format_duration(start_time, end_time): ) ) - active_sessions.remove(self) + active_sessions.discard(self) return token_cost_d - def add_tags(self, tags: List[str]) -> None: - """ - Append to session tags at runtime. - - Args: - tags (List[str]): The list of tags to append. - """ + def create_agent( + self, name: str, agent_id: Optional[str] = None + ) -> object: # FIXME: Is this `int`, `UUID`, or `str`? 
if not self.is_running: return + if agent_id is None: + agent_id = str(uuid4()) - if not (isinstance(tags, list) and all(isinstance(item, str) for item in tags)): - if isinstance(tags, str): - tags = [tags] + payload = { + "id": agent_id, + "name": name, + } - if self.tags is None: - self.tags = tags - else: - for tag in tags: - if tag not in self.tags: - self.tags.append(tag) + serialized_payload = safe_serialize(payload).encode("utf-8") + try: + HttpClient.post( + f"{self.config.endpoint}/v2/create_agent", + serialized_payload, + jwt=self.jwt, + ) + except ApiServerException as e: + logger.error(f"Could not create agent - {e}") - self._update_session() + return agent_id - def set_tags(self, tags): + def _enqueue(self, event: dict) -> None: + """Thread-safe event enqueueing""" if not self.is_running: + logger.warning("Attempted to enqueue event but session is not running") return - if not (isinstance(tags, list) and all(isinstance(item, str) for item in tags)): - if isinstance(tags, str): - tags = [tags] - - self.tags = tags - self._update_session() + logger.debug( + f"{self.__class__.__name__} enqueueing event, current queue size: {self._events.qsize()}" + ) + try: + self._events.put_nowait( + event + ) # Use put_nowait instead of directly accessing queue + logger.debug( + f"{self.__class__.__name__} successfully enqueued event, new size: {self._events.qsize()}" + ) + except queue.Full: + logger.warning( + f"{self.__class__.__name__} event queue is full, event will be dropped" + ) - def record(self, event: Union[Event, ErrorEvent]): - if not self.is_running: - return - if isinstance(event, Event): - if not event.end_timestamp or event.init_timestamp == event.end_timestamp: - event.end_timestamp = get_ISO_time() - elif isinstance(event, ErrorEvent): - if event.trigger_event: - if ( - not event.trigger_event.end_timestamp - or event.trigger_event.init_timestamp - == event.trigger_event.end_timestamp - ): - event.trigger_event.end_timestamp = get_ISO_time() + if 
self._events.qsize() >= self.config.max_queue_size: + logger.debug( + f"{self.__class__.__name__}: Queue reached max size, triggering flush" + ) + self._flush_queue() - event.trigger_event_id = event.trigger_event.id - event.trigger_event_type = event.trigger_event.event_type - self._add_event(event.trigger_event.__dict__) - event.trigger_event = None # removes trigger_event from serialization + def _publish(self): + """Notify the ChangesObserverThread to perform the API call.""" + with self.conditions["changes"]: # Acquire the lock before notifying + self.conditions["changes"].notify() - self._add_event(event.__dict__) + def stop(self) -> None: + """ + Stops (ends) the session and initiates cleanup. + This is thread-safe + """ + with self._locks["lifecycle"]: + if not self._is_running: + return + + self._is_running = False + self._stop_flag.set() # Signal threads to stop + + # Stop threads + if hasattr(self, "publisher_thread") and self.publisher_thread: + logger.debug(f"{self.__class__.__name__}: Stopping publisher thread...") + self.publisher_thread.stop() + self.publisher_thread = None # Remove reference + + if hasattr(self, "observer_thread") and self.observer_thread: + logger.debug(f"{self.__class__.__name__}: Stopping observer thread...") + self.observer_thread.stop() + self.observer_thread = None # Remove reference + + # Flush any remaining events + if not self._events.empty(): + self._flush_queue() - def _add_event(self, event: dict) -> None: - with self.lock: - self.queue.append(event) + try: + if not self.end_timestamp: + self.end_session( + end_state="Indeterminate", + end_state_reason="Session terminated during cleanup", + ) + finally: + self._cleanup_done = True + self.conditions["cleanup"].notify_all() - if len(self.queue) >= self.config.max_queue_size: - self._flush_queue() + # Remove from active_sessions + active_sessions.discard(self) - def _reauthorize_jwt(self) -> Union[str, None]: - with self.lock: - payload = {"session_id": self.session_id} - 
serialized_payload = json.dumps(filter_unjsonable(payload)).encode("utf-8") - res = HttpClient.post( - f"{self.config.endpoint}/v2/reauthorize_jwt", - serialized_payload, - self.config.api_key, - ) + logger.debug(f"{self.__class__.__name__}: Session stopped") - logger.debug(res.body) + def pause(self) -> None: + """ + Temporarily pause event processing without stopping the session. + """ + with self._locks["lifecycle"]: + if not self._is_running: + return + self._stop_flag.set() - if res.code != 200: - return None + def resume(self) -> None: + """ + Resume a paused session. + """ + with self._locks["lifecycle"]: + if not self._is_running: + return + self._stop_flag.clear() - jwt = res.body.get("jwt", None) - self.jwt = jwt - return jwt + def __del__(self): + """Ensure cleanup runs when object is garbage collected""" + try: + self.stop() + except Exception as e: + logger.error(f"Error during session cleanup: {e}") - def _start_session(self): - self.queue = [] - with self.lock: - payload = {"session": self.__dict__} - serialized_payload = json.dumps(filter_unjsonable(payload)).encode("utf-8") + def _flush_queue(self) -> None: + """Thread-safe queue flushing""" + with self._locks["events"]: + events = [] + queue_size = self._events.qsize() + logger.debug(f"Flushing queue with {queue_size} events") - try: - res = HttpClient.post( - f"{self.config.endpoint}/v2/create_session", - serialized_payload, - self.config.api_key, - self.config.parent_key, - ) - except ApiServerException as e: - return logger.error(f"Could not start session - {e}") + while not self._events.empty(): + try: + events.append(self._events.get_nowait()) + except queue.Empty: + break - logger.debug(res.body) + if events: + logger.debug(f"{self.__class__.__name__} batching {len(events)} events") + try: + self.api.batch(events) + except Exception as e: + logger.error(f"Failed to batch events during flush: {e}") - if res.code != 200: - return False + def _start_session(self) -> bool: + """ + Initializes 
and starts the session via API call. + Thread-safe method that sets up initial session state. - jwt = res.body.get("jwt", None) - self.jwt = jwt - if jwt is None: - return False + Returns: + bool: True if session started successfully, False otherwise + """ + with self._locks["lifecycle"]: + if self._is_running: + logger.warning("Session already running") + return True - session_url = res.body.get( - "session_url", - f"https://app.agentops.ai/drilldown?session_id={self.session_id}", - ) + # Use the API class to create session + success = self.api.create_session(self) + if success: + self._is_running = True + return True - logger.info( - colored( - f"\x1b[34mSession Replay: {session_url}\x1b[0m", - "blue", - ) - ) + return False + def start(self) -> bool: + """ + Start the session if it's not already running. + Returns True if session started successfully or was already running. + """ + if self.is_running: return True - def _update_session(self) -> None: - if not self.is_running: - return - with self.lock: - payload = {"session": self.__dict__} + success = self._start_session() + if success: + self.publisher_thread.start() + self.observer_thread.start() + return success + + +class _SessionThread(threading.Thread): + """Base class for session-related threads.""" + + def __init__(self, session: Session): + super().__init__() + self.s = session + self.daemon = True # Keep as daemon thread + self._local_stop = threading.Event() + + @property + def stopping(self) -> bool: + """Check if thread should stop""" + return self._local_stop.is_set() or self.s._stop_flag.is_set() + def stop(self) -> None: + # """Signal thread to stop and wait for completion""" + self._local_stop.set() + if self.is_alive(): try: - res = HttpClient.post( - f"{self.config.endpoint}/v2/update_session", - json.dumps(filter_unjsonable(payload)).encode("utf-8"), - jwt=self.jwt, + self.join(timeout=1.0) # Increased timeout slightly + if self.is_alive(): + logger.warning( + f"{self.__class__.__name__} 
thread did not stop cleanly" + ) + except RuntimeError: + # Handle case where thread is already stopped + pass + + +class EventPublisherThread(_SessionThread): + """Polls events from Session and publishes them in batches""" + + def __init__(self, session: Session): + super().__init__(session) + self._last_batch_time = time.monotonic() + self._batch = [] + self._batch_lock = threading.Lock() + + def run(self) -> None: + logger.debug(f"{self.__class__.__name__}: started") + while not self.stopping: + try: + current_time = time.monotonic() + should_publish = False + # with self._batch_lock: + batch_size = len(self._batch) + logger.debug( + f"{self.__class__.__name__} current batch size: {batch_size}" ) - except ApiServerException as e: - return logger.error(f"Could not update session - {e}") - def _flush_queue(self) -> None: - if not self.is_running: - return - with self.lock: - queue_copy = self.queue[:] # Copy the current items - self.queue = [] + # Try to collect events up to max batch size + while ( + len(self._batch) < self.s.config.max_queue_size + and not self.stopping + ): + try: + event = self.s._events.get_nowait() + if event: # Make sure we got a valid event + self._batch.append(event) + logger.debug( + f"{self.__class__.__name__} added event to batch: {event}" + ) + except queue.Empty: + break + + new_batch_size = len(self._batch) + if new_batch_size > batch_size: + logger.debug( + f"{self.__class__.__name__} added {new_batch_size - batch_size} events to batch" + ) - if len(queue_copy) > 0: - payload = { - "events": queue_copy, - } + # Determine if we should publish based on conditions + time_elapsed = current_time - self._last_batch_time + should_publish = ( + len(self._batch) >= self.s.config.max_queue_size + or ( # Batch is full + len(self._batch) > 0 + and time_elapsed # Have events and max time elapsed + >= self.s.config.max_wait_time / 1000 + ) + or ( + len(self._batch) > 0 + and (self.s._events.empty() or self.stopping) + ) # Have events and queue 
is empty or stopping + ) - serialized_payload = safe_serialize(payload).encode("utf-8") + if should_publish and self._batch: # Only publish if we have events + try: + logger.debug( + f"{self.__class__.__name__} publishing batch of {len(self._batch)} events: {self._batch}" + ) + self.s.api.batch(self._batch[:]) # Send a copy of the batch + self._batch.clear() + self._last_batch_time = current_time + except Exception as e: + logger.error( + f"{self.__class__.__name__} failed to publish batch: {e}" + ) + + if not should_publish and not self.stopping: + # Sleep briefly to prevent tight polling + time.sleep(0.1) + + except Exception as e: + logger.error(f"{self.__class__.__name__} error: {e}") + if not self.stopping: + time.sleep(0.1) + + def stop(self) -> None: + """Ensure any remaining events are published before stopping""" + super().stop() # Set stop flag first + + with self._batch_lock: + if self._batch: try: - HttpClient.post( - f"{self.config.endpoint}/v2/create_events", - serialized_payload, - jwt=self.jwt, + logger.debug( + f"{self.__class__.__name__} final batch of {len(self._batch)} events" + ) + self.s.api.batch(self._batch) + except Exception as e: + logger.error( + f"{self.__class__.__name__} failed to publish final batch during shutdown: {e}" ) - except ApiServerException as e: - return logger.error(f"Could not post events - {e}") + finally: + self._batch.clear() - logger.debug("\n") - logger.debug( - f"Session request to {self.config.endpoint}/v2/create_events" - ) - logger.debug(serialized_payload) - logger.debug("\n") - - # Count total events created based on type - events = payload["events"] - for event in events: - event_type = event["event_type"] - if event_type == "llms": - self.event_counts["llms"] += 1 - elif event_type == "tools": - self.event_counts["tools"] += 1 - elif event_type == "actions": - self.event_counts["actions"] += 1 - elif event_type == "errors": - self.event_counts["errors"] += 1 - elif event_type == "apis": - 
class ChangesObserverThread(_SessionThread):
    """Observes changes in the session and performs API calls for event publishing.

    The thread waits on ``session.conditions["changes"]`` and calls
    ``session.api.update_session()`` each time that condition is notified.
    """

    # Seconds to wait on the condition before re-checking the stop flag.
    _WAIT_TIMEOUT = 0.5

    def run(self) -> None:
        """Push a session update for every notification until asked to stop."""
        logger.debug(f"{self.__class__.__name__}: started")
        while not self.stopping:
            try:
                changes = self.s.conditions["changes"]
                # Wait for explicit notification instead of continuous polling
                with changes:
                    # wait() returns False when the timeout expired without a
                    # notification; the timeout only exists so the stop flag
                    # is re-checked regularly.
                    notified = changes.wait(timeout=self._WAIT_TIMEOUT)

                if self.stopping:
                    break

                # Only update if explicitly notified (not due to timeout)
                # NOTE(review): a notification sent while update_session() is
                # running is not seen by this wait(); confirm that producers
                # notify again or track a pending-change flag.
                if not notified:
                    continue

                with self.s._locks["session"]:
                    if not self.stopping:
                        self.s.api.update_session()

            except Exception as e:
                logger.error(f"{self.__class__.__name__} error: {e}")
                if self.stopping:
                    break

        logger.debug(f"{self.__class__.__name__} exited")

    def stop(self) -> None:
        """Signal thread to stop"""
        logger.debug(f"{self.__class__.__name__} stopping")
        self._local_stop.set()

        # Wake the thread if it is parked in wait() so it sees the stop flag
        # immediately instead of after the wait timeout.
        changes = self.s.conditions["changes"]
        with changes:
            changes.notify_all()

        # A thread cannot join itself (join() raises RuntimeError).
        if self.is_alive() and threading.current_thread() is not self:
            self.join(timeout=0.5)


active_sessions = SessionsCollection()

__all__ = ["Session"]

if __name__ == "__main__":
    session = Session(uuid4(), config=Configuration())
    try:
        # Use session...
        session.pause()  # Temporarily pause processing
        # Do something...
        session.resume()  # Resume processing
    finally:
        session.stop()  # Explicit cleanup
+ session.pause() # Temporarily pause processing + # Do something... + session.resume() # Resume processing + finally: + session.stop() # Explicit cleanup diff --git a/pyproject.toml b/pyproject.toml index ea3df6d1..6b2286cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,3 +49,4 @@ agentops = "agentops.cli:main" [tool.pytest.ini_options] asyncio_mode = "strict" asyncio_default_fixture_loop_scope = "function" +addopts = "-s" diff --git a/test.py b/test.py new file mode 100644 index 00000000..9cddc60f --- /dev/null +++ b/test.py @@ -0,0 +1,38 @@ +import threading + +condition = threading.Condition() +shared_data = [] + +# Producer function +def producer(): + with condition: + shared_data.append("New Data") + print("Produced data, notifying consumers.") + condition.notify() # Notify one waiting consumer + +# Consumer function +def consumer(): + with condition: + condition.wait() # Wait until notified by the producer + print("Consumed:", shared_data.pop()) + +# Start threads +threading.Thread(target=consumer).start() +threading.Thread(target=producer).start() +# +# +# +# +# import os +# import time +# +# import agentops +# import dotenv +# +# dotenv.load_dotenv() +# +# agentops.init(api_key=os.environ["AGENTOPS_API_KEY"]) +# +# +# while True: +# time.sleep(0.2) diff --git a/tests/test_session.py b/tests/test_session.py index b40fd04e..a63802f2 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -1,11 +1,24 @@ +import time +from uuid import uuid4 + import pytest +from requests import patch import requests_mock -import time + import agentops +from agentops.config import Configuration from agentops import ActionEvent, Client +from agentops.config import Configuration +from agentops.session import EventPublisherThread, Session, active_sessions from agentops.singleton import clear_singletons +from unittest.mock import MagicMock + +import logging +agentops.logger.setLevel(logging.DEBUG) # TODO: Move into client initialization + + 
@pytest.fixture(autouse=True) def setup_teardown(mock_req): clear_singletons() @@ -58,7 +71,7 @@ def test_session(self, mock_req): assert len(mock_req.request_history) == 3 time.sleep(0.15) - assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt" + assert mock_req.last_request.headers["Authorization"] == "Bearer some_jwt" request_json = mock_req.last_request.json() assert request_json["events"][0]["event_type"] == self.event_type @@ -68,7 +81,7 @@ def test_session(self, mock_req): # We should have 4 requests (additional end session) assert len(mock_req.request_history) == 4 - assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt" + assert mock_req.last_request.headers["Authorization"] == "Bearer some_jwt" request_json = mock_req.last_request.json() assert request_json["session"]["end_state"] == end_state assert len(request_json["session"]["tags"]) == 0 @@ -217,7 +230,7 @@ def test_two_sessions(self, mock_req): # 5 requests: check_for_updates, 2 start_session, 2 record_event assert len(mock_req.request_history) == 5 - assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt" + assert mock_req.last_request.headers["Authorization"] == "Bearer some_jwt" request_json = mock_req.last_request.json() assert request_json["events"][0]["event_type"] == self.event_type @@ -228,7 +241,7 @@ def test_two_sessions(self, mock_req): # Additional end session request assert len(mock_req.request_history) == 6 - assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt" + assert mock_req.last_request.headers["Authorization"] == "Bearer some_jwt" request_json = mock_req.last_request.json() assert request_json["session"]["end_state"] == end_state assert len(request_json["session"]["tags"]) == 0 @@ -236,7 +249,7 @@ def test_two_sessions(self, mock_req): session_2.end_session(end_state) # Additional end session request assert len(mock_req.request_history) == 7 - assert mock_req.last_request.headers["Authorization"] 
== f"Bearer some_jwt" + assert mock_req.last_request.headers["Authorization"] == "Bearer some_jwt" request_json = mock_req.last_request.json() assert request_json["session"]["end_state"] == end_state assert len(request_json["session"]["tags"]) == 0 @@ -284,3 +297,177 @@ def test_add_tags(self, mock_req): "session-2", "session-2-added", ] + + +# +# +# class TestSessionInterrupts: +# def setup_method(self): +# self.api_key = "11111111-1111-4111-8111-111111111111" +# self.event_type = "test_event_type" +# agentops.init(api_key=self.api_key, max_wait_time=50, auto_start_session=False) +# +# def test_cleanup_on_interrupt(self, mock_req): +# session = agentops.start_session() +# assert session is not None +# +# # Record some events +# for _ in range(3): # Add a few events +# session.record(ActionEvent(self.event_type)) +# +# # Force cleanup +# session._cleanup() +# +# # Short wait to allow any pending operations +# time.sleep(0.1) +# +# # Verify session is fully stopped +# assert not session.is_running +# assert session._cleanup_done +# assert session.end_timestamp is not None +# +# # Try recording after cleanup - should be ignored +# session.record(ActionEvent(self.event_type)) +# +# # Verify we can exit cleanly +# agentops.end_all_sessions() +# +# def test_multiple_cleanup_attempts(self, mock_req): +# session = agentops.start_session() +# assert session is not None +# +# # First cleanup +# session._cleanup() +# first_end_timestamp = session.end_timestamp +# +# # Try cleaning up again +# time.sleep(0.1) +# session._cleanup() +# +# # Verify the end timestamp didn't change +# assert session.end_timestamp == first_end_timestamp +# +# # Verify only one end session request was made +# end_session_requests = [ +# req for req in mock_req.request_history if "update_session" in req.url +# ] +# assert len(end_session_requests) == 1 +# +# def test_thread_safety_during_cleanup(self, mock_req): +# session = agentops.start_session() +# assert session is not None +# +# # Start 
@pytest.fixture(autouse=True)
def cleanup_sessions():
    """Empty the global ``active_sessions`` collection after every test."""
    yield
    active_sessions.clear()  # clear is sufficient; __del__ takes care of stopping gracefully


def test_event_publisher_thread_run():
    """Exercise the three publish triggers: batch size, elapsed time, empty queue."""
    # Setup
    session = Session(session_id=uuid4(), config=Configuration())
    active_sessions.add(session)  # Add to active sessions list

    # Mock the API batch method to track calls
    session.api.batch = MagicMock()

    # Ensure session is running before starting tests
    assert session.start()

    # Everything after start() runs under try/finally so the session threads
    # are stopped even when an assertion fails.
    try:
        # Create some test events
        test_events = [
            {"id": str(uuid4()), "event_type": "test", "data": f"event_{i}"}
            for i in range(5)
        ]

        # Test Case 1: Batch size trigger
        session.config.max_queue_size = 3  # Small batch size for testing

        for event in test_events[:3]:
            session._enqueue(event)

        # Wait briefly for processing
        time.sleep(0.2)

        # Verify batch was published due to size
        session.api.batch.assert_called_once()
        published_events = session.api.batch.call_args[0][0]
        assert len(published_events) == 3
        expected_data = [f"event_{i}" for i in range(3)]
        assert all(e["data"] in expected_data for e in published_events)

        # Reset for next test
        session.api.batch.reset_mock()

        # Test Case 2: Time trigger
        session.config.max_wait_time = 100  # 100ms wait time
        session._enqueue(test_events[3])

        # Wait for time trigger
        time.sleep(0.2)

        # Verify batch was published due to time
        session.api.batch.assert_called_once()
        published_events = session.api.batch.call_args[0][0]
        assert len(published_events) == 1
        assert published_events[0]["data"] == "event_3"

        # Reset for final test
        session.api.batch.reset_mock()

        # Test Case 3: Empty queue trigger
        session._enqueue(test_events[4])

        # Wait briefly for processing
        time.sleep(0.2)

        # Verify batch was published
        session.api.batch.assert_called_once()
        published_events = session.api.batch.call_args[0][0]
        assert len(published_events) == 1
        assert published_events[0]["data"] == "event_4"
    finally:
        # Cleanup
        if session in active_sessions:
            active_sessions.discard(session)
        session.stop()


def test_event_publisher_thread_error_handling():
    """A failing ``api.batch`` must be attempted without breaking the session."""
    # Setup
    session = Session(session_id=uuid4(), config=Configuration())

    # Ensure session is running
    assert session.start()

    try:
        # Mock API to raise an exception
        session.api.batch = MagicMock(side_effect=Exception("API Error"))

        # Add test event
        test_event = {"id": str(uuid4()), "event_type": "test", "data": "error_test"}
        session._enqueue(test_event)

        # Wait briefly for processing
        time.sleep(0.2)

        # Verify the API was called. The publisher keeps a failed batch and
        # tries it again, so the exact number of attempts made during the
        # wait is not fixed; only require at least one.
        assert session.api.batch.called
    finally:
        # Cleanup
        session.stop()