diff --git a/agentops/session.py b/agentops/session.py index 99905eb8..cbe3297e 100644 --- a/agentops/session.py +++ b/agentops/session.py @@ -5,7 +5,6 @@ import queue import threading import time - from decimal import ROUND_HALF_UP, Decimal from typing import Annotated, Dict, List, Optional, Union from uuid import UUID, uuid4 @@ -23,12 +22,14 @@ """ minor changes: +- `active_sessions` is now a WeakSet. Cleanup is not required. Tests can be simplified - Removed unsafe usage of __dict__ (which is very dangerous when used without __slots__) - Introduced Session.active property """ from typing import DefaultDict +from weakref import WeakSet class SessionDict(DefaultDict): @@ -223,7 +224,9 @@ def __init__(self, **kwargs): ) # config assigned at object level logger.debug(f"{self.__class__.__name__}: Session locks initialized") - super().__init__(**kwargs) # Initialize SessionStruct after `locks` were defined + super().__init__( + **kwargs + ) # Initialize SessionStruct after `locks` were defined # Initialize queue with max size from config self._events = queue.Queue(maxsize=self.config.max_queue_size) @@ -242,6 +245,9 @@ def __init__(self, **kwargs): self._is_running = False + # Add session to active_sessions + active_sessions.add(self) + self._start_session() if self.is_running: @@ -251,6 +257,9 @@ def __init__(self, **kwargs): else: self.stop() + def __hash__(self) -> int: + return hash(self.session_id) + @property def session_id(self) -> UUID: return self["session_id"] @@ -322,7 +331,9 @@ def set_tags(self, tags): # --- Interactors def record(self, event: Union[Event, ErrorEvent]): if not self.is_running: - logger.debug(f"{self.__class__.__name__}: Attempted to record event but session is not running") + logger.debug( + f"{self.__class__.__name__}: Attempted to record event but session is not running" + ) return logger.debug(f"Recording event: {event.event_type}") @@ -350,7 +361,7 @@ def record(self, event: Union[Event, ErrorEvent]): def end_session( self, - 
end_state: str = "Indeterminate", + end_state: str = "Indeterminate", end_state_reason: Optional[str] = None, video: Optional[str] = None, ) -> Union[Decimal, None]: @@ -437,8 +448,7 @@ def __calc_elapsed(): ) ) - if self in active_sessions: - active_sessions.remove(self) + active_sessions.discard(self) return token_cost_d @@ -473,15 +483,25 @@ def _enqueue(self, event: dict) -> None: logger.warning("Attempted to enqueue event but session is not running") return - logger.debug(f"{self.__class__.__name__} enqueueing event, current queue size: {self._events.qsize()}") + logger.debug( + f"{self.__class__.__name__} enqueueing event, current queue size: {self._events.qsize()}" + ) try: - self._events.put_nowait(event) # Use put_nowait instead of directly accessing queue - logger.debug(f"{self.__class__.__name__} successfully enqueued event, new size: {self._events.qsize()}") + self._events.put_nowait( + event + ) # Use put_nowait instead of directly accessing queue + logger.debug( + f"{self.__class__.__name__} successfully enqueued event, new size: {self._events.qsize()}" + ) except queue.Full: - logger.warning(f"{self.__class__.__name__} event queue is full, event will be dropped") + logger.warning( + f"{self.__class__.__name__} event queue is full, event will be dropped" + ) if self._events.qsize() >= self.config.max_queue_size: - logger.debug(f"{self.__class__.__name__}: Queue reached max size, triggering flush") + logger.debug( + f"{self.__class__.__name__}: Queue reached max size, triggering flush" + ) self._flush_queue() def _publish(self): @@ -496,20 +516,20 @@ def stop(self) -> None: with self._locks["lifecycle"]: if not self._is_running: return - + self._is_running = False self._stop_flag.set() # Signal threads to stop # Stop threads - if hasattr(self, 'publisher_thread') and self.publisher_thread: + if hasattr(self, "publisher_thread") and self.publisher_thread: logger.debug(f"{self.__class__.__name__}: Stopping publisher thread...") 
self.publisher_thread.stop() - # del self.publisher_thread + self.publisher_thread = None # Remove reference - if hasattr(self, 'observer_thread') and self.observer_thread: + if hasattr(self, "observer_thread") and self.observer_thread: logger.debug(f"{self.__class__.__name__}: Stopping observer thread...") self.observer_thread.stop() - # del self.observer_thread + self.observer_thread = None # Remove reference # Flush any remaining events if not self._events.empty(): @@ -525,7 +545,10 @@ def stop(self) -> None: with self.conditions["cleanup"]: self._cleanup_done = True self.conditions["cleanup"].notify_all() - + + # Remove from active_sessions + active_sessions.discard(self) + logger.debug(f"{self.__class__.__name__}: Session stopped") def pause(self) -> None: @@ -610,7 +633,7 @@ def _start_session(self) -> bool: if success: self._is_running = True return True - + return False def start(self) -> bool: @@ -620,7 +643,7 @@ def start(self) -> bool: """ if self.is_running: return True - + success = self._start_session() if success: self.publisher_thread.start() @@ -634,7 +657,7 @@ class _SessionThread(threading.Thread): def __init__(self, session: Session): super().__init__() self.s = session - self.daemon = True + self.daemon = True # Keep as daemon thread self._local_stop = threading.Event() @property @@ -646,7 +669,15 @@ def stop(self) -> None: """Signal thread to stop and wait for completion""" self._local_stop.set() if self.is_alive(): - self.join(timeout=0.5) # Wait up to 0.5 seconds for thread to stop + try: + self.join(timeout=1.0) # Increased timeout slightly + if self.is_alive(): + logger.warning( + f"{self.__class__.__name__} thread did not stop cleanly" + ) + except RuntimeError: + # Handle case where thread is already stopped + pass class EventPublisherThread(_SessionThread): @@ -664,51 +695,60 @@ def run(self) -> None: try: current_time = time.monotonic() should_publish = False + # with self._batch_lock: + batch_size = len(self._batch) + logger.debug( + 
f"{self.__class__.__name__} current batch size: {batch_size}" + ) - with self._batch_lock: - batch_size = len(self._batch) - logger.debug(f"{self.__class__.__name__} current batch size: {batch_size}") - - # Try to collect events up to max batch size - while ( - len(self._batch) < self.s.config.max_queue_size - and not self.stopping - ): - try: - event = self.s._events.get_nowait() - if event: # Make sure we got a valid event - self._batch.append(event) - logger.debug(f"{self.__class__.__name__} added event to batch: {event}") - except queue.Empty: - break - - new_batch_size = len(self._batch) - if new_batch_size > batch_size: - logger.debug(f"{self.__class__.__name__} added {new_batch_size - batch_size} events to batch") - - # Determine if we should publish based on conditions - time_elapsed = current_time - self._last_batch_time - should_publish = ( - len(self._batch) >= self.s.config.max_queue_size - or ( # Batch is full - len(self._batch) > 0 - and time_elapsed # Have events and max time elapsed - >= self.s.config.max_wait_time / 1000 - ) - or ( - len(self._batch) > 0 - and (self.s._events.empty() or self.stopping) - ) # Have events and queue is empty or stopping + # Try to collect events up to max batch size + while ( + len(self._batch) < self.s.config.max_queue_size + and not self.stopping + ): + try: + event = self.s._events.get_nowait() + if event: # Make sure we got a valid event + self._batch.append(event) + logger.debug( + f"{self.__class__.__name__} added event to batch: {event}" + ) + except queue.Empty: + break + + new_batch_size = len(self._batch) + if new_batch_size > batch_size: + logger.debug( + f"{self.__class__.__name__} added {new_batch_size - batch_size} events to batch" ) - if should_publish and self._batch: # Only publish if we have events - try: - logger.debug(f"{self.__class__.__name__} publishing batch of {len(self._batch)} events: {self._batch}") - self.s.api.batch(self._batch[:]) # Send a copy of the batch - self._batch.clear() - 
self._last_batch_time = current_time - except Exception as e: - logger.error(f"{self.__class__.__name__} failed to publish batch: {e}") + # Determine if we should publish based on conditions + time_elapsed = current_time - self._last_batch_time + should_publish = ( + len(self._batch) >= self.s.config.max_queue_size + or ( # Batch is full + len(self._batch) > 0 + and time_elapsed # Have events and max time elapsed + >= self.s.config.max_wait_time / 1000 + ) + or ( + len(self._batch) > 0 + and (self.s._events.empty() or self.stopping) + ) # Have events and queue is empty or stopping + ) + + if should_publish and self._batch: # Only publish if we have events + try: + logger.debug( + f"{self.__class__.__name__} publishing batch of {len(self._batch)} events: {self._batch}" + ) + self.s.api.batch(self._batch[:]) # Send a copy of the batch + self._batch.clear() + self._last_batch_time = current_time + except Exception as e: + logger.error( + f"{self.__class__.__name__} failed to publish batch: {e}" + ) if not should_publish and not self.stopping: # Sleep briefly to prevent tight polling @@ -722,14 +762,18 @@ def run(self) -> None: def stop(self) -> None: """Ensure any remaining events are published before stopping""" super().stop() # Set stop flag first - + with self._batch_lock: if self._batch: try: - logger.debug(f"{self.__class__.__name__} final batch of {len(self._batch)} events") + logger.debug( + f"{self.__class__.__name__} final batch of {len(self._batch)} events" + ) self.s.api.batch(self._batch) except Exception as e: - logger.error(f"{self.__class__.__name__} failed to publish final batch during shutdown: {e}") + logger.error( + f"{self.__class__.__name__} failed to publish final batch during shutdown: {e}" + ) finally: self._batch.clear() @@ -743,7 +787,7 @@ def run(self) -> None: try: # Use a shorter timeout and don't hold the lock continuously time.sleep(0.1) - + if self.stopping: logger.debug(f"{self.__class__.__name__} stopping (pre-lock)") break @@ -751,7 
+795,9 @@ def run(self) -> None: # Quick check with lock with self.s._locks["session"]: if self.s._events and not self.stopping: - logger.debug(f"{self.__class__.__name__}: Processing session changes") + logger.debug( + f"{self.__class__.__name__}: Processing session changes" + ) self.s.api.update_session() except Exception as e: @@ -768,8 +814,7 @@ def stop(self) -> None: if self.is_alive(): self.join(timeout=0.5) - -active_sessions: List[Session] = [] +active_sessions = WeakSet() __all__ = ["Session"] diff --git a/tests/test_session.py b/tests/test_session.py index 45ed1132..03326c1b 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -377,7 +377,7 @@ def test_add_tags(self, mock_req): def cleanup_sessions(): yield # Force cleanup of any remaining sessions - for session in active_sessions[:]: # Create a copy of the list to iterate + for session in list(active_sessions): # snapshot first: stop() discards from the WeakSet, which would mutate it mid-iteration try: session.stop() except: