diff --git a/agentops/client.py b/agentops/client.py
index 80cb1cc5..c3928786 100644
--- a/agentops/client.py
+++ b/agentops/client.py
@@ -36,7 +36,7 @@ def __init__(self):
self._pre_init_messages: List[str] = []
self._initialized: bool = False
self._llm_tracker: Optional[LlmTracker] = None
- self._sessions: List[Session] = active_sessions
+ self._sessions: List[Session] = list(active_sessions)
self._config = Configuration()
self._pre_init_queue = {"agents": []}
diff --git a/agentops/config.py b/agentops/config.py
index 7dfb574d..ad1019d8 100644
--- a/agentops/config.py
+++ b/agentops/config.py
@@ -1,74 +1,53 @@
-from typing import List, Optional
+from dataclasses import dataclass, field
+from typing import Optional, Set
from uuid import UUID
from .log_config import logger
+# TODO: Use annotations to clarify the purpose of each attribute.
+# Details are defined in docstrings found in __init__.py, but
+# it's good to have those right on the fields at class definition
+@dataclass
class Configuration:
- def __init__(self):
- self.api_key: Optional[str] = None
- self.parent_key: Optional[str] = None
- self.endpoint: str = "https://api.agentops.ai"
- self.max_wait_time: int = 5000
- self.max_queue_size: int = 512
- self.default_tags: set[str] = set()
- self.instrument_llm_calls: bool = True
- self.auto_start_session: bool = True
- self.skip_auto_end_session: bool = False
- self.env_data_opt_out: bool = False
+ api_key: Optional[str] = None
+ parent_key: Optional[str] = None
+ endpoint: str = "https://api.agentops.ai"
+ max_wait_time: int = 5000
+ max_queue_size: int = 512
+ graceful_shutdown_wait_time: int = 2000
+ default_tags: Set[str] = field(default_factory=set)
+ instrument_llm_calls: bool = True
+ auto_start_session: bool = True
+ skip_auto_end_session: bool = False
+ env_data_opt_out: bool = False
def configure(
self,
client,
- api_key: Optional[str] = None,
- parent_key: Optional[str] = None,
- endpoint: Optional[str] = None,
- max_wait_time: Optional[int] = None,
- max_queue_size: Optional[int] = None,
- default_tags: Optional[List[str]] = None,
- instrument_llm_calls: Optional[bool] = None,
- auto_start_session: Optional[bool] = None,
- skip_auto_end_session: Optional[bool] = None,
- env_data_opt_out: Optional[bool] = None,
+ **kwargs
):
- if api_key is not None:
- try:
- UUID(api_key)
- self.api_key = api_key
- except ValueError:
- message = f"API Key is invalid: {{{api_key}}}.\n\t Find your API key at https://app.agentops.ai/settings/projects"
- client.add_pre_init_warning(message)
- logger.error(message)
-
- if parent_key is not None:
- try:
- UUID(parent_key)
- self.parent_key = parent_key
- except ValueError:
- message = f"Parent Key is invalid: {parent_key}"
- client.add_pre_init_warning(message)
- logger.warning(message)
-
- if endpoint is not None:
- self.endpoint = endpoint
-
- if max_wait_time is not None:
- self.max_wait_time = max_wait_time
-
- if max_queue_size is not None:
- self.max_queue_size = max_queue_size
-
- if default_tags is not None:
- self.default_tags.update(default_tags)
-
- if instrument_llm_calls is not None:
- self.instrument_llm_calls = instrument_llm_calls
-
- if auto_start_session is not None:
- self.auto_start_session = auto_start_session
-
- if skip_auto_end_session is not None:
- self.skip_auto_end_session = skip_auto_end_session
-
- if env_data_opt_out is not None:
- self.env_data_opt_out = env_data_opt_out
+ # Special handling for keys that need UUID validation
+ for key_name in ['api_key', 'parent_key']:
+ if key_name in kwargs and kwargs[key_name] is not None:
+ try:
+ UUID(kwargs[key_name])
+ setattr(self, key_name, kwargs[key_name])
+ except ValueError:
+ message = (
+ f"API Key is invalid: {{{kwargs[key_name]}}}.\n\t Find your API key at https://app.agentops.ai/settings/projects"
+ if key_name == 'api_key'
+ else f"Parent Key is invalid: {kwargs[key_name]}"
+ )
+ client.add_pre_init_warning(message)
+            (logger.error if key_name == 'api_key' else logger.warning)(message)
+ kwargs.pop(key_name)
+
+ # Special handling for default_tags which needs update() instead of assignment
+ if 'default_tags' in kwargs and kwargs['default_tags'] is not None:
+ self.default_tags.update(kwargs.pop('default_tags'))
+
+ # Handle all other attributes
+ for key, value in kwargs.items():
+ if value is not None and hasattr(self, key):
+ setattr(self, key, value)
diff --git a/agentops/session.py b/agentops/session.py
index 58225b6c..60669702 100644
--- a/agentops/session.py
+++ b/agentops/session.py
@@ -1,24 +1,224 @@
-import copy
-import functools
+from __future__ import annotations # Allow forward references
+
+import datetime as dt
import json
+import queue
import threading
import time
from decimal import ROUND_HALF_UP, Decimal
-from termcolor import colored
-from typing import Optional, List, Union
+from typing import Annotated, Dict, List, Optional, Union
from uuid import UUID, uuid4
-from datetime import datetime
-from .exceptions import ApiServerException
-from .enums import EndState
-from .event import ErrorEvent, Event
-from .log_config import logger
+from termcolor import colored
+
from .config import Configuration
-from .helpers import get_ISO_time, filter_unjsonable, safe_serialize
+from .enums import EndState, EventType
+from .event import ErrorEvent, Event
+from .exceptions import ApiServerException
+from .helpers import filter_unjsonable, get_ISO_time, safe_serialize
from .http_client import HttpClient
+from .log_config import logger
+"""
+
+minor changes:
+- `active_sessions` is now a WeakSet. Cleanup is not required. Tests can be simplified
+
+- Removed unsafe usage of __dict__ (which is very dangerous when used without __slots__)
+- Introduced Session.active property
+"""
+
+from typing import DefaultDict
+from weakref import WeakSet
+
+
+class SessionDict(DefaultDict):
+ session_id: UUID
+ # --------------
+ config: Configuration
+ end_state: str = EndState.INDETERMINATE.value
+ end_state_reason: Optional[str] = None
+ end_timestamp: Optional[str] = None
+ # Create a counter dictionary with each EventType name initialized to 0
+ event_counts: Dict[str, int]
+ host_env: Optional[dict] = None
+ init_timestamp: str # Will be set to get_ISO_time() during __init__
+ is_running: bool = False
+ jwt: Optional[str] = None
+ tags: Optional[List[str]] = None
+ video: Optional[str] = None
+
+ def __init__(self, **kwargs):
+ kwargs.setdefault(
+ "event_counts", {event_type.value: 0 for event_type in EventType}
+ )
+ kwargs.setdefault("init_timestamp", get_ISO_time())
+ super().__init__(**kwargs)
-class Session:
+
+class SessionsCollection(WeakSet):
+ """
+ A custom collection for managing Session objects that combines WeakSet's automatic cleanup
+ with list-like indexing capabilities.
+
+ This class is needed because:
+ 1. We want WeakSet's automatic cleanup of unreferenced sessions
+ 2. We need to access sessions by index (e.g., self._sessions[0]) for backwards compatibility
+ 3. Standard WeakSet doesn't support indexing
+ """
+
+ def __getitem__(self, index: int) -> Session:
+ """
+ Enable indexing into the collection (e.g., sessions[0]).
+ """
+ # Convert to list for indexing since sets aren't ordered
+ items = list(self)
+ return items[index]
+
+ def __iter__(self):
+ """
+ Override the default iterator to yield sessions sorted by init_timestamp.
+ If init_timestamp is not available, fall back to __create_ts.
+
+ WARNING: Using __create_ts as a fallback for ordering may lead to unexpected results
+ if init_timestamp is not set correctly.
+ """
+ return iter(sorted(super().__iter__(), key=lambda session: (
+        session.init_timestamp if hasattr(session, 'init_timestamp') else session._Session__create_ts
+ )))
+
+class SessionApi:
+ """
+ Solely focuses on interacting with the API
+
+ Developer notes:
+ Need to clarify (and define) a standard and consistent Api interface.
+
+ The way it can be approached is by having a base `Api` class that holds common
+ configurations and features, while implementors provide entity-related controllers
+ """
+
+ # TODO: Decouple from standard Configuration a Session's entity own configuration.
+ # NOTE: pydantic-settings plays out beautifully in such setup, but it's not a requirement.
+ # TODO: Eventually move to apis/
+ session: Session
+
+ def __init__(self, session: Session):
+ self.session = session
+
+ @property
+ def config(self): # Forward decl.
+ return self.session.config
+
+ @property
+ def jwt(self) -> Optional[str]:
+ """Convenience property that falls back to dictionary access"""
+ return self.session.get("jwt")
+
+ @jwt.setter
+ def jwt(self, value: Optional[str]):
+ self.session["jwt"] = value
+
+ def update_session(self) -> None:
+ try:
+ payload = {"session": dict(self.session)}
+ res = HttpClient.post(
+ f"{self.config.endpoint}/v2/update_session",
+ json.dumps(filter_unjsonable(payload)).encode("utf-8"),
+ jwt=self.jwt,
+ )
+ except ApiServerException as e:
+ return logger.error(f"Could not update session - {e}")
+
+ # WARN: Unused method
+ def reauthorize_jwt(self) -> Union[str, None]:
+ payload = {"session_id": self.session.session_id}
+ serialized_payload = json.dumps(filter_unjsonable(payload)).encode("utf-8")
+ res = HttpClient.post(
+ f"{self.config.endpoint}/v2/reauthorize_jwt",
+ serialized_payload,
+ self.config.api_key,
+ )
+
+ logger.debug(res.body)
+
+ if res.code != 200:
+ return None
+
+ jwt = res.body.get("jwt", None)
+ self.jwt = jwt
+ return jwt
+
+ def create_session(self, session: SessionDict):
+ """
+ Creates a new session via API call
+
+ Returns:
+ tuple containing:
+ - success (bool): Whether the creation was successful
+ - jwt (Optional[str]): JWT token if successful
+ - session_url (Optional[str]): URL to view the session if successful
+ """
+ payload = {"session": dict(session)}
+ serialized_payload = json.dumps(filter_unjsonable(payload)).encode("utf-8")
+
+ try:
+ res = HttpClient.post(
+ f"{self.config.endpoint}/v2/create_session",
+ serialized_payload,
+ self.config.api_key,
+ self.config.parent_key,
+ )
+ except ApiServerException as e:
+ logger.error(f"Could not start session - {e}")
+ return False
+ else:
+ if res.code != 200:
+ return False
+
+ jwt = res.body.get("jwt", None)
+ self.jwt = jwt
+ if jwt is None:
+ return False
+
+ session_url = res.body.get(
+ "session_url",
+ f"https://app.agentops.ai/drilldown?session_id={session.session_id}",
+ )
+
+ logger.info(
+ colored(
+ f"\x1b[34mSession Replay: {session_url}\x1b[0m",
+ "blue",
+ )
+ )
+
+ return True
+
+ def batch(self, events: List[Event]) -> None:
+ serialized_payload = safe_serialize(dict(events=events)).encode("utf-8")
+ try:
+ HttpClient.post(
+ f"{self.config.endpoint}/v2/create_events",
+ serialized_payload,
+ jwt=self.jwt,
+ )
+ except ApiServerException as e:
+ return logger.error(f"Could not post events - {e}")
+
+ # Update event counts on the session instance
+ for event in events:
+ event_type = event["event_type"]
+ if event_type in self.session["event_counts"]:
+ self.session["event_counts"][event_type] += 1
+
+ logger.debug("\n")
+ logger.debug(f"Session request to {self.config.endpoint}/v2/create_events")
+ logger.debug(serialized_payload)
+ logger.debug("\n")
+
+
+class Session(SessionDict):
"""
Represents a session of events, with a start and end state.
@@ -28,48 +228,96 @@ class Session:
Attributes:
init_timestamp (float): The timestamp for when the session started, represented as seconds since the epoch.
- end_timestamp (float, optional): The timestamp for when the session ended, represented as seconds since the epoch. This is only set after end_session is called.
- end_state (str, optional): The final state of the session. Suggested: "Success", "Fail", "Indeterminate". Defaults to "Indeterminate".
+ end_timestamp (float, optional): The timestamp for when the session ended, represented as seconds since the epoch.
+ end_state (str, optional): The final state of the session. Defaults to "Indeterminate".
end_state_reason (str, optional): The reason for ending the session.
-
"""
- def __init__(
- self,
- session_id: UUID,
- config: Configuration,
- tags: Optional[List[str]] = None,
- host_env: Optional[dict] = None,
- ):
- self.end_timestamp = None
- self.end_state: Optional[str] = "Indeterminate"
- self.session_id = session_id
- self.init_timestamp = get_ISO_time()
- self.tags: List[str] = tags or []
- self.video: Optional[str] = None
- self.end_state_reason: Optional[str] = None
- self.host_env = host_env
- self.config = config
- self.jwt = None
- self.lock = threading.Lock()
- self.queue = []
- self.event_counts = {
- "llms": 0,
- "tools": 0,
- "actions": 0,
- "errors": 0,
- "apis": 0,
+ thread: Annotated[
+ EventPublisherThread, "Publishes events to the API in a background thread."
+ ]
+
+ def __init__(self, **kwargs):
+ logger.debug(f"Initializing new Session with id {kwargs.get('session_id')}")
+
+ # Set creation timestamp
+ self.__create_ts = time.monotonic()
+
+ self.api = SessionApi(self)
+
+ self._locks = {
+ "lifecycle": threading.Lock(), # Controls session lifecycle operations
+ "events": threading.Lock(), # Protects event queue operations
+ "session": threading.Lock(), # Protects session state updates
+ "tags": threading.Lock(), # Protects tag modifications
+ "api": threading.Lock(), # Protects API calls
+ }
+
+ self.config = kwargs.pop(
+ "config", Configuration()
+ ) # config assigned at object level
+ logger.debug(f"{self.__class__.__name__}: Session locks initialized")
+
+ super().__init__(
+ **kwargs
+ ) # Initialize SessionStruct after `locks` were defined
+
+ # Initialize queue with max size from config
+ self._events = queue.Queue(maxsize=self.config.max_queue_size)
+ self._cleanup_done = False
+ self._stop_flag = threading.Event()
+
+ # Initialize conditions
+ self.conditions = {
+ "cleanup": threading.Condition(self._locks["lifecycle"]),
+ "changes": threading.Condition(self._locks["session"]),
}
- self.stop_flag = threading.Event()
- self.thread = threading.Thread(target=self._run)
- self.thread.daemon = True
- self.thread.start()
+ # Initialize threads
+ self.publisher_thread = EventPublisherThread(self)
+ self.observer_thread = ChangesObserverThread(self)
+
+ self._is_running = False
+
+ # Add session to active_sessions
+ active_sessions.add(self)
+
+ self._start_session()
+
+ if self.is_running:
+ # Only start threads if session started successfully
+ self.publisher_thread.start()
+ self.observer_thread.start()
+ else:
+ self.stop()
+
+ def __hash__(self) -> int:
+ return hash(self.session_id)
+
+ @property
+ def session_id(self) -> UUID:
+ return self["session_id"]
+
+ @property
+ def active(self) -> bool:
+ """Alias for is_running for backward compatibility"""
+ return self.is_running
+
+ @property
+ def is_running(self) -> bool:
+ """Thread-safe access to running state"""
+ with self._locks["lifecycle"]:
+ return self._is_running
- self.is_running = self._start_session()
- if self.is_running == False:
- self.stop_flag.set()
- self.thread.join(timeout=1)
+ @is_running.setter
+ def is_running(self, value: bool):
+ """Thread-safe modification of running state"""
+ with self._locks["lifecycle"]:
+ self._is_running = value
+
+ @property
+ def counter(self):
+ return self["event_counts"]
def set_video(self, video: str) -> None:
"""
@@ -80,34 +328,93 @@ def set_video(self, video: str) -> None:
"""
self.video = video
+ def add_tags(self, tags: List[str]) -> None:
+ """
+ Append to session tags at runtime.
+
+ Args:
+ tags (List[str]): The list of tags to append.
+ """
+ if not self.is_running:
+ return
+
+ if not (isinstance(tags, list) and all(isinstance(item, str) for item in tags)):
+ if isinstance(tags, str):
+ tags = [tags]
+
+ if self.tags is None:
+ self.tags = tags
+ else:
+ for tag in tags:
+ if tag not in self.tags:
+ self.tags.append(tag)
+
+ self._publish()
+
+ def set_tags(self, tags):
+ if not self.is_running:
+ return
+
+ if not (isinstance(tags, list) and all(isinstance(item, str) for item in tags)):
+ if isinstance(tags, str):
+ tags = [tags]
+
+ self.tags = tags
+ self._publish()
+
+ # --- Interactors
+ def record(self, event: Union[Event, ErrorEvent]):
+ if not self.is_running:
+ logger.debug(
+ f"{self.__class__.__name__}: Attempted to record event but session is not running"
+ )
+ return
+
+ logger.debug(f"Recording event: {event.event_type}")
+ if isinstance(event, Event):
+ if not event.end_timestamp or event.init_timestamp == event.end_timestamp:
+ event.end_timestamp = get_ISO_time() # WARN: Unrestricted assignment
+ elif isinstance(event, ErrorEvent):
+ if event.trigger_event:
+ if (
+ not event.trigger_event.end_timestamp
+ or event.trigger_event.init_timestamp
+ == event.trigger_event.end_timestamp
+ ):
+ event.trigger_event.end_timestamp = get_ISO_time()
+
+ event.trigger_event_id = event.trigger_event.id
+ event.trigger_event_type = event.trigger_event.event_type
+ self._enqueue(event.trigger_event.__dict__)
+ event.trigger_event = None # removes trigger_event from serialization
+ # ^^ NOTE: Consider memento https://refactoring.guru/design-patterns/memento/python/example
+
+ self._enqueue(
+ event.__dict__
+ ) # WARNING: This is very dangerous. Either define Event.__slots__ or turn Event into a dataclass
+
def end_session(
self,
end_state: str = "Indeterminate",
end_state_reason: Optional[str] = None,
video: Optional[str] = None,
) -> Union[Decimal, None]:
-
- if not self.is_running:
- return
-
if not any(end_state == state.value for state in EndState):
return logger.warning(
"Invalid end_state. Please use one of the EndState enums"
)
- self.end_timestamp = get_ISO_time()
- self.end_state = end_state
- self.end_state_reason = end_state_reason
- if video is not None:
- self.video = video
+ self["end_timestamp"] = get_ISO_time()
+ self["end_state"] = end_state or self.end_state
+ self["end_state_reason"] = end_state_reason or self.end_state_reason
- self.stop_flag.set()
- self.thread.join(timeout=1)
- self._flush_queue()
-
- def format_duration(start_time, end_time):
- start = datetime.fromisoformat(start_time.replace("Z", "+00:00"))
- end = datetime.fromisoformat(end_time.replace("Z", "+00:00"))
+ def __calc_elapsed():
+ start = dt.datetime.fromisoformat(
+ self["init_timestamp"].replace("Z", "+00:00")
+ )
+ end = dt.datetime.fromisoformat(
+ self["end_timestamp"].replace("Z", "+00:00")
+ )
duration = end - start
hours, remainder = divmod(duration.total_seconds(), 3600)
@@ -122,8 +429,8 @@ def format_duration(start_time, end_time):
return " ".join(parts)
- with self.lock:
- payload = {"session": self.__dict__}
+ with self._locks["api"]:
+ payload = {"session": dict(self)}
try:
res = HttpClient.post(
f"{self.config.endpoint}/v2/update_session",
@@ -131,12 +438,13 @@ def format_duration(start_time, end_time):
jwt=self.jwt,
)
except ApiServerException as e:
- return logger.error(f"Could not end session - {e}")
+ logger.error(f"Could not end session - {e}")
+ return None
logger.debug(res.body)
token_cost = res.body.get("token_cost", "unknown")
- formatted_duration = format_duration(self.init_timestamp, self.end_timestamp)
+ formatted_duration = __calc_elapsed()
if token_cost == "unknown" or token_cost is None:
token_cost_d = Decimal(0)
@@ -155,10 +463,10 @@ def format_duration(start_time, end_time):
f"Session Stats - "
f"{colored('Duration:', attrs=['bold'])} {formatted_duration} | "
f"{colored('Cost:', attrs=['bold'])} ${formatted_cost} | "
- f"{colored('LLMs:', attrs=['bold'])} {self.event_counts['llms']} | "
- f"{colored('Tools:', attrs=['bold'])} {self.event_counts['tools']} | "
- f"{colored('Actions:', attrs=['bold'])} {self.event_counts['actions']} | "
- f"{colored('Errors:', attrs=['bold'])} {self.event_counts['errors']}"
+ f"{colored('LLMs:', attrs=['bold'])} {self.counter['llms']} | "
+ f"{colored('Tools:', attrs=['bold'])} {self.counter['tools']} | "
+ f"{colored('Actions:', attrs=['bold'])} {self.counter['actions']} | "
+ f"{colored('Errors:', attrs=['bold'])} {self.counter['errors']}"
)
logger.info(analytics)
@@ -174,227 +482,363 @@ def format_duration(start_time, end_time):
)
)
- active_sessions.remove(self)
+ active_sessions.discard(self)
return token_cost_d
- def add_tags(self, tags: List[str]) -> None:
- """
- Append to session tags at runtime.
-
- Args:
- tags (List[str]): The list of tags to append.
- """
+ def create_agent(
+ self, name: str, agent_id: Optional[str] = None
+ ) -> object: # FIXME: Is this `int`, `UUID`, or `str`?
if not self.is_running:
return
+ if agent_id is None:
+ agent_id = str(uuid4())
- if not (isinstance(tags, list) and all(isinstance(item, str) for item in tags)):
- if isinstance(tags, str):
- tags = [tags]
+ payload = {
+ "id": agent_id,
+ "name": name,
+ }
- if self.tags is None:
- self.tags = tags
- else:
- for tag in tags:
- if tag not in self.tags:
- self.tags.append(tag)
+ serialized_payload = safe_serialize(payload).encode("utf-8")
+ try:
+ HttpClient.post(
+ f"{self.config.endpoint}/v2/create_agent",
+ serialized_payload,
+ jwt=self.jwt,
+ )
+ except ApiServerException as e:
+ logger.error(f"Could not create agent - {e}")
- self._update_session()
+ return agent_id
- def set_tags(self, tags):
+ def _enqueue(self, event: dict) -> None:
+ """Thread-safe event enqueueing"""
if not self.is_running:
+ logger.warning("Attempted to enqueue event but session is not running")
return
- if not (isinstance(tags, list) and all(isinstance(item, str) for item in tags)):
- if isinstance(tags, str):
- tags = [tags]
-
- self.tags = tags
- self._update_session()
+ logger.debug(
+ f"{self.__class__.__name__} enqueueing event, current queue size: {self._events.qsize()}"
+ )
+ try:
+ self._events.put_nowait(
+ event
+ ) # Use put_nowait instead of directly accessing queue
+ logger.debug(
+ f"{self.__class__.__name__} successfully enqueued event, new size: {self._events.qsize()}"
+ )
+ except queue.Full:
+ logger.warning(
+ f"{self.__class__.__name__} event queue is full, event will be dropped"
+ )
- def record(self, event: Union[Event, ErrorEvent]):
- if not self.is_running:
- return
- if isinstance(event, Event):
- if not event.end_timestamp or event.init_timestamp == event.end_timestamp:
- event.end_timestamp = get_ISO_time()
- elif isinstance(event, ErrorEvent):
- if event.trigger_event:
- if (
- not event.trigger_event.end_timestamp
- or event.trigger_event.init_timestamp
- == event.trigger_event.end_timestamp
- ):
- event.trigger_event.end_timestamp = get_ISO_time()
+ if self._events.qsize() >= self.config.max_queue_size:
+ logger.debug(
+ f"{self.__class__.__name__}: Queue reached max size, triggering flush"
+ )
+ self._flush_queue()
- event.trigger_event_id = event.trigger_event.id
- event.trigger_event_type = event.trigger_event.event_type
- self._add_event(event.trigger_event.__dict__)
- event.trigger_event = None # removes trigger_event from serialization
+ def _publish(self):
+ """Notify the ChangesObserverThread to perform the API call."""
+ with self.conditions["changes"]: # Acquire the lock before notifying
+ self.conditions["changes"].notify()
- self._add_event(event.__dict__)
+ def stop(self) -> None:
+ """
+ Stops (ends) the session and initiates cleanup.
+ This is thread-safe
+ """
+ with self._locks["lifecycle"]:
+ if not self._is_running:
+ return
+
+ self._is_running = False
+ self._stop_flag.set() # Signal threads to stop
+
+ # Stop threads
+ if hasattr(self, "publisher_thread") and self.publisher_thread:
+ logger.debug(f"{self.__class__.__name__}: Stopping publisher thread...")
+ self.publisher_thread.stop()
+ self.publisher_thread = None # Remove reference
+
+ if hasattr(self, "observer_thread") and self.observer_thread:
+ logger.debug(f"{self.__class__.__name__}: Stopping observer thread...")
+ self.observer_thread.stop()
+ self.observer_thread = None # Remove reference
+
+ # Flush any remaining events
+ if not self._events.empty():
+ self._flush_queue()
- def _add_event(self, event: dict) -> None:
- with self.lock:
- self.queue.append(event)
+ try:
+ if not self.end_timestamp:
+ self.end_session(
+ end_state="Indeterminate",
+ end_state_reason="Session terminated during cleanup",
+ )
+ finally:
+ self._cleanup_done = True
+ self.conditions["cleanup"].notify_all()
- if len(self.queue) >= self.config.max_queue_size:
- self._flush_queue()
+ # Remove from active_sessions
+ active_sessions.discard(self)
- def _reauthorize_jwt(self) -> Union[str, None]:
- with self.lock:
- payload = {"session_id": self.session_id}
- serialized_payload = json.dumps(filter_unjsonable(payload)).encode("utf-8")
- res = HttpClient.post(
- f"{self.config.endpoint}/v2/reauthorize_jwt",
- serialized_payload,
- self.config.api_key,
- )
+ logger.debug(f"{self.__class__.__name__}: Session stopped")
- logger.debug(res.body)
+ def pause(self) -> None:
+ """
+ Temporarily pause event processing without stopping the session.
+ """
+ with self._locks["lifecycle"]:
+ if not self._is_running:
+ return
+ self._stop_flag.set()
- if res.code != 200:
- return None
+ def resume(self) -> None:
+ """
+ Resume a paused session.
+ """
+ with self._locks["lifecycle"]:
+ if not self._is_running:
+ return
+ self._stop_flag.clear()
- jwt = res.body.get("jwt", None)
- self.jwt = jwt
- return jwt
+ def __del__(self):
+ """Ensure cleanup runs when object is garbage collected"""
+ try:
+ self.stop()
+ except Exception as e:
+ logger.error(f"Error during session cleanup: {e}")
- def _start_session(self):
- self.queue = []
- with self.lock:
- payload = {"session": self.__dict__}
- serialized_payload = json.dumps(filter_unjsonable(payload)).encode("utf-8")
+ def _flush_queue(self) -> None:
+ """Thread-safe queue flushing"""
+ with self._locks["events"]:
+ events = []
+ queue_size = self._events.qsize()
+ logger.debug(f"Flushing queue with {queue_size} events")
- try:
- res = HttpClient.post(
- f"{self.config.endpoint}/v2/create_session",
- serialized_payload,
- self.config.api_key,
- self.config.parent_key,
- )
- except ApiServerException as e:
- return logger.error(f"Could not start session - {e}")
+ while not self._events.empty():
+ try:
+ events.append(self._events.get_nowait())
+ except queue.Empty:
+ break
- logger.debug(res.body)
+ if events:
+ logger.debug(f"{self.__class__.__name__} batching {len(events)} events")
+ try:
+ self.api.batch(events)
+ except Exception as e:
+ logger.error(f"Failed to batch events during flush: {e}")
- if res.code != 200:
- return False
+ def _start_session(self) -> bool:
+ """
+ Initializes and starts the session via API call.
+ Thread-safe method that sets up initial session state.
- jwt = res.body.get("jwt", None)
- self.jwt = jwt
- if jwt is None:
- return False
+ Returns:
+ bool: True if session started successfully, False otherwise
+ """
+ with self._locks["lifecycle"]:
+ if self._is_running:
+ logger.warning("Session already running")
+ return True
- session_url = res.body.get(
- "session_url",
- f"https://app.agentops.ai/drilldown?session_id={self.session_id}",
- )
+ # Use the API class to create session
+ success = self.api.create_session(self)
+ if success:
+ self._is_running = True
+ return True
- logger.info(
- colored(
- f"\x1b[34mSession Replay: {session_url}\x1b[0m",
- "blue",
- )
- )
+ return False
+ def start(self) -> bool:
+ """
+ Start the session if it's not already running.
+ Returns True if session started successfully or was already running.
+ """
+ if self.is_running:
return True
- def _update_session(self) -> None:
- if not self.is_running:
- return
- with self.lock:
- payload = {"session": self.__dict__}
+ success = self._start_session()
+ if success:
+ self.publisher_thread.start()
+ self.observer_thread.start()
+ return success
+
+
+class _SessionThread(threading.Thread):
+ """Base class for session-related threads."""
+
+ def __init__(self, session: Session):
+ super().__init__()
+ self.s = session
+ self.daemon = True # Keep as daemon thread
+ self._local_stop = threading.Event()
+
+ @property
+ def stopping(self) -> bool:
+ """Check if thread should stop"""
+ return self._local_stop.is_set() or self.s._stop_flag.is_set()
+ def stop(self) -> None:
+        """Signal thread to stop and wait for completion"""
+ self._local_stop.set()
+ if self.is_alive():
try:
- res = HttpClient.post(
- f"{self.config.endpoint}/v2/update_session",
- json.dumps(filter_unjsonable(payload)).encode("utf-8"),
- jwt=self.jwt,
+ self.join(timeout=1.0) # Increased timeout slightly
+ if self.is_alive():
+ logger.warning(
+ f"{self.__class__.__name__} thread did not stop cleanly"
+ )
+ except RuntimeError:
+ # Handle case where thread is already stopped
+ pass
+
+
+class EventPublisherThread(_SessionThread):
+ """Polls events from Session and publishes them in batches"""
+
+ def __init__(self, session: Session):
+ super().__init__(session)
+ self._last_batch_time = time.monotonic()
+ self._batch = []
+ self._batch_lock = threading.Lock()
+
+ def run(self) -> None:
+ logger.debug(f"{self.__class__.__name__}: started")
+ while not self.stopping:
+ try:
+ current_time = time.monotonic()
+ should_publish = False
+                # FIXME: _batch is read/mutated here without self._batch_lock, but stop() takes it — race
+ batch_size = len(self._batch)
+ logger.debug(
+ f"{self.__class__.__name__} current batch size: {batch_size}"
)
- except ApiServerException as e:
- return logger.error(f"Could not update session - {e}")
- def _flush_queue(self) -> None:
- if not self.is_running:
- return
- with self.lock:
- queue_copy = self.queue[:] # Copy the current items
- self.queue = []
+ # Try to collect events up to max batch size
+ while (
+ len(self._batch) < self.s.config.max_queue_size
+ and not self.stopping
+ ):
+ try:
+ event = self.s._events.get_nowait()
+ if event: # Make sure we got a valid event
+ self._batch.append(event)
+ logger.debug(
+ f"{self.__class__.__name__} added event to batch: {event}"
+ )
+ except queue.Empty:
+ break
+
+ new_batch_size = len(self._batch)
+ if new_batch_size > batch_size:
+ logger.debug(
+ f"{self.__class__.__name__} added {new_batch_size - batch_size} events to batch"
+ )
- if len(queue_copy) > 0:
- payload = {
- "events": queue_copy,
- }
+ # Determine if we should publish based on conditions
+ time_elapsed = current_time - self._last_batch_time
+ should_publish = (
+ len(self._batch) >= self.s.config.max_queue_size
+ or ( # Batch is full
+ len(self._batch) > 0
+ and time_elapsed # Have events and max time elapsed
+ >= self.s.config.max_wait_time / 1000
+ )
+ or (
+ len(self._batch) > 0
+ and (self.s._events.empty() or self.stopping)
+ ) # Have events and queue is empty or stopping
+ )
- serialized_payload = safe_serialize(payload).encode("utf-8")
+ if should_publish and self._batch: # Only publish if we have events
+ try:
+ logger.debug(
+ f"{self.__class__.__name__} publishing batch of {len(self._batch)} events: {self._batch}"
+ )
+ self.s.api.batch(self._batch[:]) # Send a copy of the batch
+ self._batch.clear()
+ self._last_batch_time = current_time
+ except Exception as e:
+ logger.error(
+ f"{self.__class__.__name__} failed to publish batch: {e}"
+ )
+
+ if not should_publish and not self.stopping:
+ # Sleep briefly to prevent tight polling
+ time.sleep(0.1)
+
+ except Exception as e:
+ logger.error(f"{self.__class__.__name__} error: {e}")
+ if not self.stopping:
+ time.sleep(0.1)
+
+ def stop(self) -> None:
+ """Ensure any remaining events are published before stopping"""
+ super().stop() # Set stop flag first
+
+ with self._batch_lock:
+ if self._batch:
try:
- HttpClient.post(
- f"{self.config.endpoint}/v2/create_events",
- serialized_payload,
- jwt=self.jwt,
+ logger.debug(
+ f"{self.__class__.__name__} final batch of {len(self._batch)} events"
+ )
+ self.s.api.batch(self._batch)
+ except Exception as e:
+ logger.error(
+ f"{self.__class__.__name__} failed to publish final batch during shutdown: {e}"
)
- except ApiServerException as e:
- return logger.error(f"Could not post events - {e}")
+ finally:
+ self._batch.clear()
- logger.debug("\n")
- logger.debug(
- f"Session request to {self.config.endpoint}/v2/create_events"
- )
- logger.debug(serialized_payload)
- logger.debug("\n")
-
- # Count total events created based on type
- events = payload["events"]
- for event in events:
- event_type = event["event_type"]
- if event_type == "llms":
- self.event_counts["llms"] += 1
- elif event_type == "tools":
- self.event_counts["tools"] += 1
- elif event_type == "actions":
- self.event_counts["actions"] += 1
- elif event_type == "errors":
- self.event_counts["errors"] += 1
- elif event_type == "apis":
- self.event_counts["apis"] += 1
-
- def _run(self) -> None:
- while not self.stop_flag.is_set():
- time.sleep(self.config.max_wait_time / 1000)
- if self.queue:
- self._flush_queue()
- def create_agent(self, name, agent_id):
- if not self.is_running:
- return
- if agent_id is None:
- agent_id = str(uuid4())
+class ChangesObserverThread(_SessionThread):
+ """Observes changes in the session and performs API calls for event publishing."""
- payload = {
- "id": agent_id,
- "name": name,
- }
+ def run(self) -> None:
+ logger.debug(f"{self.__class__.__name__}: started")
+ while not self.stopping:
+ try:
+ # Wait for explicit notification instead of continuous polling
+ with self.s.conditions["changes"]:
+ # Use wait with timeout to allow checking stopping condition
+ self.s.conditions["changes"].wait(timeout=0.5)
- serialized_payload = safe_serialize(payload).encode("utf-8")
- try:
- HttpClient.post(
- f"{self.config.endpoint}/v2/create_agent",
- serialized_payload,
- jwt=self.jwt,
- )
- except ApiServerException as e:
- return logger.error(f"Could not create agent - {e}")
+ if self.stopping:
+ break
+
+ # Only update if explicitly notified (not due to timeout)
+ with self.s._locks["session"]:
+ if not self.stopping:
+ self.s.api.update_session()
+
+ except Exception as e:
+ logger.error(f"{self.__class__.__name__} error: {e}")
+ if self.stopping:
+ break
+
+ logger.debug(f"{self.__class__.__name__} exited")
+
+ def stop(self) -> None:
+ """Signal thread to stop"""
+ logger.debug(f"{self.__class__.__name__} stopping")
+ self._local_stop.set()
+ if self.is_alive():
+ self.join(timeout=0.5)
- return agent_id
- def patch(self, func):
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- kwargs["session"] = self
- return func(*args, **kwargs)
+active_sessions = SessionsCollection()
- return wrapper
+__all__ = ["Session"]
-active_sessions: List[Session] = []
+if __name__ == "__main__":
+    session = Session(session_id=uuid4(), config=Configuration())
+ try:
+ # Use session...
+ session.pause() # Temporarily pause processing
+ # Do something...
+ session.resume() # Resume processing
+ finally:
+ session.stop() # Explicit cleanup
diff --git a/pyproject.toml b/pyproject.toml
index ea3df6d1..6b2286cc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,3 +49,4 @@ agentops = "agentops.cli:main"
[tool.pytest.ini_options]
asyncio_mode = "strict"
asyncio_default_fixture_loop_scope = "function"
+addopts = "-s"
diff --git a/test.py b/test.py
new file mode 100644
index 00000000..9cddc60f
--- /dev/null
+++ b/test.py
@@ -0,0 +1,38 @@
+import threading
+
+condition = threading.Condition()
+shared_data = []
+
+# Producer function
+def producer():
+ with condition:
+ shared_data.append("New Data")
+ print("Produced data, notifying consumers.")
+ condition.notify() # Notify one waiting consumer
+
+# Consumer function
+def consumer():
+ with condition:
+ condition.wait() # Wait until notified by the producer
+ print("Consumed:", shared_data.pop())
+
+# Start threads
+threading.Thread(target=consumer).start()
+threading.Thread(target=producer).start()
+#
+#
+#
+#
+# import os
+# import time
+#
+# import agentops
+# import dotenv
+#
+# dotenv.load_dotenv()
+#
+# agentops.init(api_key=os.environ["AGENTOPS_API_KEY"])
+#
+#
+# while True:
+# time.sleep(0.2)
diff --git a/tests/test_session.py b/tests/test_session.py
index b40fd04e..a63802f2 100644
--- a/tests/test_session.py
+++ b/tests/test_session.py
@@ -1,11 +1,24 @@
+import time
+from uuid import uuid4
+
import pytest
+from unittest.mock import patch
import requests_mock
-import time
+
import agentops
+from agentops.config import Configuration
from agentops import ActionEvent, Client
+from agentops.config import Configuration
+from agentops.session import EventPublisherThread, Session, active_sessions
from agentops.singleton import clear_singletons
+from unittest.mock import MagicMock
+
+import logging
+agentops.logger.setLevel(logging.DEBUG) # TODO: Move into client initialization
+
+
@pytest.fixture(autouse=True)
def setup_teardown(mock_req):
clear_singletons()
@@ -58,7 +71,7 @@ def test_session(self, mock_req):
assert len(mock_req.request_history) == 3
time.sleep(0.15)
- assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt"
+ assert mock_req.last_request.headers["Authorization"] == "Bearer some_jwt"
request_json = mock_req.last_request.json()
assert request_json["events"][0]["event_type"] == self.event_type
@@ -68,7 +81,7 @@ def test_session(self, mock_req):
# We should have 4 requests (additional end session)
assert len(mock_req.request_history) == 4
- assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt"
+ assert mock_req.last_request.headers["Authorization"] == "Bearer some_jwt"
request_json = mock_req.last_request.json()
assert request_json["session"]["end_state"] == end_state
assert len(request_json["session"]["tags"]) == 0
@@ -217,7 +230,7 @@ def test_two_sessions(self, mock_req):
# 5 requests: check_for_updates, 2 start_session, 2 record_event
assert len(mock_req.request_history) == 5
- assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt"
+ assert mock_req.last_request.headers["Authorization"] == "Bearer some_jwt"
request_json = mock_req.last_request.json()
assert request_json["events"][0]["event_type"] == self.event_type
@@ -228,7 +241,7 @@ def test_two_sessions(self, mock_req):
# Additional end session request
assert len(mock_req.request_history) == 6
- assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt"
+ assert mock_req.last_request.headers["Authorization"] == "Bearer some_jwt"
request_json = mock_req.last_request.json()
assert request_json["session"]["end_state"] == end_state
assert len(request_json["session"]["tags"]) == 0
@@ -236,7 +249,7 @@ def test_two_sessions(self, mock_req):
session_2.end_session(end_state)
# Additional end session request
assert len(mock_req.request_history) == 7
- assert mock_req.last_request.headers["Authorization"] == f"Bearer some_jwt"
+ assert mock_req.last_request.headers["Authorization"] == "Bearer some_jwt"
request_json = mock_req.last_request.json()
assert request_json["session"]["end_state"] == end_state
assert len(request_json["session"]["tags"]) == 0
@@ -284,3 +297,177 @@ def test_add_tags(self, mock_req):
"session-2",
"session-2-added",
]
+
+
+#
+#
+# class TestSessionInterrupts:
+# def setup_method(self):
+# self.api_key = "11111111-1111-4111-8111-111111111111"
+# self.event_type = "test_event_type"
+# agentops.init(api_key=self.api_key, max_wait_time=50, auto_start_session=False)
+#
+# def test_cleanup_on_interrupt(self, mock_req):
+# session = agentops.start_session()
+# assert session is not None
+#
+# # Record some events
+# for _ in range(3): # Add a few events
+# session.record(ActionEvent(self.event_type))
+#
+# # Force cleanup
+# session._cleanup()
+#
+# # Short wait to allow any pending operations
+# time.sleep(0.1)
+#
+# # Verify session is fully stopped
+# assert not session.is_running
+# assert session._cleanup_done
+# assert session.end_timestamp is not None
+#
+# # Try recording after cleanup - should be ignored
+# session.record(ActionEvent(self.event_type))
+#
+# # Verify we can exit cleanly
+# agentops.end_all_sessions()
+#
+# def test_multiple_cleanup_attempts(self, mock_req):
+# session = agentops.start_session()
+# assert session is not None
+#
+# # First cleanup
+# session._cleanup()
+# first_end_timestamp = session.end_timestamp
+#
+# # Try cleaning up again
+# time.sleep(0.1)
+# session._cleanup()
+#
+# # Verify the end timestamp didn't change
+# assert session.end_timestamp == first_end_timestamp
+#
+# # Verify only one end session request was made
+# end_session_requests = [
+# req for req in mock_req.request_history if "update_session" in req.url
+# ]
+# assert len(end_session_requests) == 1
+#
+# def test_thread_safety_during_cleanup(self, mock_req):
+# session = agentops.start_session()
+# assert session is not None
+#
+# # Start recording events
+# for _ in range(5):
+# session.record(ActionEvent(self.event_type))
+#
+# # Cleanup while events are being processed
+# session._cleanup()
+#
+# # Verify all events were flushed
+# time.sleep(0.15) # Allow async operations to complete
+#
+# event_requests = [
+# req for req in mock_req.request_history if "create_events" in req.url
+# ]
+# total_events = sum(len(req.json()["events"]) for req in event_requests)
+# assert total_events == 5 # All events should have been sent
+
+@pytest.fixture(autouse=True)
+def cleanup_sessions():
+ yield
+ active_sessions.clear() # clear is sufficient; __del__ takes care of stopping gracefully
+
+def test_event_publisher_thread_run():
+ # Setup
+ session = Session(session_id=uuid4(), config=Configuration())
+ active_sessions.add(session) # Add to active sessions list
+
+ # Mock the API batch method to track calls
+ session.api.batch = MagicMock()
+
+ # Ensure session is running before starting tests
+ assert session.start()
+
+ # Create some test events
+ test_events = [
+ {"id": str(uuid4()), "event_type": "test", "data": f"event_{i}"}
+ for i in range(5)
+ ]
+
+ # Test Case 1: Batch size trigger
+ session.config.max_queue_size = 3 # Small batch size for testing
+
+ # Use session's publisher thread instead of creating new ones
+ publisher = session.publisher_thread
+
+ for event in test_events[:3]:
+ session._enqueue(event)
+
+ # Wait briefly for processing
+ time.sleep(0.2)
+
+ # Verify batch was published due to size
+ session.api.batch.assert_called_once()
+ published_events = session.api.batch.call_args[0][0]
+ assert len(published_events) == 3
+ assert all(e["data"] in [f"event_{i}" for i in range(3)] for e in published_events)
+
+ # Reset for next test
+ session.api.batch.reset_mock()
+
+ # Test Case 2: Time trigger
+ session.config.max_wait_time = 100 # 100ms wait time
+ session._enqueue(test_events[3])
+
+ # Wait for time trigger
+ time.sleep(0.2)
+
+ # Verify batch was published due to time
+ session.api.batch.assert_called_once()
+ published_events = session.api.batch.call_args[0][0]
+ assert len(published_events) == 1
+ assert published_events[0]["data"] == "event_3"
+
+ # Reset for final test
+ session.api.batch.reset_mock()
+
+ # Test Case 3: Empty queue trigger
+ session._enqueue(test_events[4])
+
+ # Wait briefly for processing
+ time.sleep(0.2)
+
+ # Verify batch was published
+ session.api.batch.assert_called_once()
+ published_events = session.api.batch.call_args[0][0]
+ assert len(published_events) == 1
+ assert published_events[0]["data"] == "event_4"
+
+ # Cleanup
+ if session in active_sessions:
+ active_sessions.discard(session)
+ session.stop()
+
+def test_event_publisher_thread_error_handling():
+ # Setup
+ session = Session(session_id=uuid4(), config=Configuration())
+
+ # Ensure session is running
+ assert session.start()
+
+ # Mock API to raise an exception
+ session.api.batch = MagicMock(side_effect=Exception("API Error"))
+
+ # Add test event
+ test_event = {"id": str(uuid4()), "event_type": "test", "data": "error_test"}
+ session._enqueue(test_event)
+
+ # Wait briefly for processing
+ time.sleep(0.2)
+
+ # Verify the API was called
+ session.api.batch.assert_called_once()
+
+ # Cleanup
+ session.stop()