weakset, reference tracking, etc
teocns committed Nov 12, 2024
1 parent 8500da2 commit 0fc5a87
Showing 2 changed files with 116 additions and 71 deletions.
185 changes: 115 additions & 70 deletions agentops/session.py
@@ -5,7 +5,6 @@
import queue
import threading
import time

from decimal import ROUND_HALF_UP, Decimal
from typing import Annotated, Dict, List, Optional, Union
from uuid import UUID, uuid4
@@ -23,12 +22,14 @@
"""
minor changes:
- `active_sessions` is now a WeakSet. Cleanup is not required. Tests can be simplified
- Removed unsafe usage of `__dict__` (error-prone on classes that do not define `__slots__`)
- Introduced Session.active property
"""

from typing import DefaultDict
from weakref import WeakSet
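
For reference, a minimal sketch (not part of this diff) of why a WeakSet removes the need for manual cleanup: entries disappear on their own once the interpreter garbage-collects the referent. The `Tracked` class here is a hypothetical stand-in for `Session`.

# Illustrative only, not part of the diff.
import gc
from weakref import WeakSet

class Tracked:  # hypothetical stand-in for Session
    pass

registry = WeakSet()
obj = Tracked()
registry.add(obj)
assert len(registry) == 1
del obj        # drop the last strong reference
gc.collect()   # CPython usually frees immediately; this forces it elsewhere
assert len(registry) == 0  # the WeakSet cleaned itself up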


class SessionDict(DefaultDict):
@@ -223,7 +224,9 @@ def __init__(self, **kwargs):
) # config assigned at object level
logger.debug(f"{self.__class__.__name__}: Session locks initialized")

super().__init__(**kwargs) # Initialize SessionStruct after `locks` were defined
super().__init__(
**kwargs
) # Initialize SessionStruct after `locks` were defined

# Initialize queue with max size from config
self._events = queue.Queue(maxsize=self.config.max_queue_size)
@@ -242,6 +245,9 @@ def __init__(self, **kwargs):

self._is_running = False

# Add session to active_sessions
active_sessions.add(self)

self._start_session()

if self.is_running:
@@ -251,6 +257,9 @@ def __init__(self, **kwargs):
else:
self.stop()

def __hash__(self) -> int:
return hash(self.session_id)
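
WeakSet members must be hashable (and weakly referenceable), and the session appears to be dict-backed (it subclasses `SessionDict`), so it inherits `__hash__ = None` from `dict` unless one is defined. Keying the hash on the immutable `session_id` keeps identity stable while the session's contents mutate. A sketch of the requirement, with a hypothetical `Item` class:

# Illustrative only: dict subclasses are unhashable unless __hash__ is defined.
from uuid import uuid4
from weakref import WeakSet

class Item(dict):  # hypothetical dict-backed object, like Session
    def __init__(self):
        super().__init__(id=uuid4())
    def __hash__(self):
        return hash(self["id"])  # key on the immutable identifier

ws = WeakSet()
item = Item()
ws.add(item)  # works; without __hash__ this raises TypeError: unhashable type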

@property
def session_id(self) -> UUID:
return self["session_id"]
@@ -322,7 +331,9 @@ def set_tags(self, tags):
# --- Interactors
def record(self, event: Union[Event, ErrorEvent]):
if not self.is_running:
logger.debug(f"{self.__class__.__name__}: Attempted to record event but session is not running")
logger.debug(
f"{self.__class__.__name__}: Attempted to record event but session is not running"
)
return

logger.debug(f"Recording event: {event.event_type}")
@@ -350,7 +361,7 @@ def record(self, event: Union[Event, ErrorEvent]):

def end_session(
self,
end_state: str = "Indeterminate",
end_state: str = "Indeterminate",
end_state_reason: Optional[str] = None,
video: Optional[str] = None,
) -> Union[Decimal, None]:
@@ -437,8 +448,7 @@ def __calc_elapsed():
)
)

if self in active_sessions:
active_sessions.remove(self)
active_sessions.discard(self)

return token_cost_d
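
The switch from a guarded `remove()` to `discard()` works because `discard()` is a no-op on a missing element, while `remove()` raises `KeyError`; the membership check becomes unnecessary. A brief illustration:

# Illustrative only: discard() never raises, remove() does.
s = set()
s.discard("missing")   # silently does nothing
try:
    s.remove("missing")
except KeyError:
    pass               # remove() requires the element to be present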

@@ -473,15 +483,25 @@ def _enqueue(self, event: dict) -> None:
logger.warning("Attempted to enqueue event but session is not running")
return

logger.debug(f"{self.__class__.__name__} enqueueing event, current queue size: {self._events.qsize()}")
logger.debug(
f"{self.__class__.__name__} enqueueing event, current queue size: {self._events.qsize()}"
)
try:
self._events.put_nowait(event) # Use put_nowait instead of directly accessing queue
logger.debug(f"{self.__class__.__name__} successfully enqueued event, new size: {self._events.qsize()}")
self._events.put_nowait(
event
) # Use put_nowait instead of directly accessing queue
logger.debug(
f"{self.__class__.__name__} successfully enqueued event, new size: {self._events.qsize()}"
)
except queue.Full:
logger.warning(f"{self.__class__.__name__} event queue is full, event will be dropped")
logger.warning(
f"{self.__class__.__name__} event queue is full, event will be dropped"
)

if self._events.qsize() >= self.config.max_queue_size:
logger.debug(f"{self.__class__.__name__}: Queue reached max size, triggering flush")
logger.debug(
f"{self.__class__.__name__}: Queue reached max size, triggering flush"
)
self._flush_queue()
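
The enqueue path leans on `queue.Queue`'s non-blocking API: `put_nowait()` raises `queue.Full` rather than blocking the caller, so a slow consumer degrades to dropped events instead of a stalled producer. A minimal sketch of the pattern:

# Illustrative only: bounded queue with drop-on-full semantics.
import queue

q = queue.Queue(maxsize=2)
for i in range(3):
    try:
        q.put_nowait(i)        # never blocks the producer
    except queue.Full:
        print(f"dropped {i}")  # the third item is dropped, not queued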

def _publish(self):
@@ -496,20 +516,20 @@ def stop(self) -> None:
with self._locks["lifecycle"]:
if not self._is_running:
return

self._is_running = False
self._stop_flag.set() # Signal threads to stop

# Stop threads
if hasattr(self, 'publisher_thread') and self.publisher_thread:
if hasattr(self, "publisher_thread") and self.publisher_thread:
logger.debug(f"{self.__class__.__name__}: Stopping publisher thread...")
self.publisher_thread.stop()
# del self.publisher_thread
self.publisher_thread = None # Remove reference

if hasattr(self, 'observer_thread') and self.observer_thread:
if hasattr(self, "observer_thread") and self.observer_thread:
logger.debug(f"{self.__class__.__name__}: Stopping observer thread...")
self.observer_thread.stop()
# del self.observer_thread
self.observer_thread = None # Remove reference

# Flush any remaining events
if not self._events.empty():
Expand All @@ -525,7 +545,10 @@ def stop(self) -> None:
with self.conditions["cleanup"]:
self._cleanup_done = True
self.conditions["cleanup"].notify_all()


# Remove from active_sessions
active_sessions.discard(self)

logger.debug(f"{self.__class__.__name__}: Session stopped")

def pause(self) -> None:
@@ -610,7 +633,7 @@ def _start_session(self) -> bool:
if success:
self._is_running = True
return True

return False

def start(self) -> bool:
@@ -620,7 +643,7 @@ def start(self) -> bool:
"""
if self.is_running:
return True

success = self._start_session()
if success:
self.publisher_thread.start()
@@ -634,7 +657,7 @@ class _SessionThread(threading.Thread):
def __init__(self, session: Session):
super().__init__()
self.s = session
self.daemon = True
self.daemon = True # Keep as daemon thread
self._local_stop = threading.Event()

@property
@@ -646,7 +669,15 @@ def stop(self) -> None:
"""Signal thread to stop and wait for completion"""
self._local_stop.set()
if self.is_alive():
self.join(timeout=0.5) # Wait up to 0.5 seconds for thread to stop
try:
self.join(timeout=1.0) # Increased timeout slightly
if self.is_alive():
logger.warning(
f"{self.__class__.__name__} thread did not stop cleanly"
)
except RuntimeError:
# join() raises RuntimeError only if the thread was never started
pass
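
The base thread pairs `daemon=True` (a hung worker cannot keep the interpreter alive at exit) with a cooperative stop: the run loop polls an `Event`, and `stop()` joins with a bounded timeout rather than waiting forever. A runnable sketch:

# Illustrative only: cooperative stop with a bounded join.
import threading
import time

class StoppableThread(threading.Thread):
    def __init__(self):
        super().__init__(daemon=True)  # never blocks interpreter exit
        self._stop_evt = threading.Event()

    def run(self):
        while not self._stop_evt.is_set():
            time.sleep(0.05)           # stand-in for real work

    def stop(self):
        self._stop_evt.set()
        self.join(timeout=1.0)         # bounded wait; thread may still be alive

t = StoppableThread()
t.start()
t.stop()
assert not t.is_alive()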


class EventPublisherThread(_SessionThread):
@@ -664,51 +695,60 @@ def run(self) -> None:
try:
current_time = time.monotonic()
should_publish = False
# with self._batch_lock:
batch_size = len(self._batch)
logger.debug(
f"{self.__class__.__name__} current batch size: {batch_size}"
)

with self._batch_lock:
batch_size = len(self._batch)
logger.debug(f"{self.__class__.__name__} current batch size: {batch_size}")

# Try to collect events up to max batch size
while (
len(self._batch) < self.s.config.max_queue_size
and not self.stopping
):
try:
event = self.s._events.get_nowait()
if event: # Make sure we got a valid event
self._batch.append(event)
logger.debug(f"{self.__class__.__name__} added event to batch: {event}")
except queue.Empty:
break

new_batch_size = len(self._batch)
if new_batch_size > batch_size:
logger.debug(f"{self.__class__.__name__} added {new_batch_size - batch_size} events to batch")

# Determine if we should publish based on conditions
time_elapsed = current_time - self._last_batch_time
should_publish = (
len(self._batch) >= self.s.config.max_queue_size
or ( # Batch is full
len(self._batch) > 0
and time_elapsed # Have events and max time elapsed
>= self.s.config.max_wait_time / 1000
)
or (
len(self._batch) > 0
and (self.s._events.empty() or self.stopping)
) # Have events and queue is empty or stopping
# Try to collect events up to max batch size
while (
len(self._batch) < self.s.config.max_queue_size
and not self.stopping
):
try:
event = self.s._events.get_nowait()
if event: # Make sure we got a valid event
self._batch.append(event)
logger.debug(
f"{self.__class__.__name__} added event to batch: {event}"
)
except queue.Empty:
break

new_batch_size = len(self._batch)
if new_batch_size > batch_size:
logger.debug(
f"{self.__class__.__name__} added {new_batch_size - batch_size} events to batch"
)

if should_publish and self._batch: # Only publish if we have events
try:
logger.debug(f"{self.__class__.__name__} publishing batch of {len(self._batch)} events: {self._batch}")
self.s.api.batch(self._batch[:]) # Send a copy of the batch
self._batch.clear()
self._last_batch_time = current_time
except Exception as e:
logger.error(f"{self.__class__.__name__} failed to publish batch: {e}")
# Determine if we should publish based on conditions
time_elapsed = current_time - self._last_batch_time
should_publish = (
len(self._batch) >= self.s.config.max_queue_size
or ( # Batch is full
len(self._batch) > 0
and time_elapsed # Have events and max time elapsed
>= self.s.config.max_wait_time / 1000
)
or (
len(self._batch) > 0
and (self.s._events.empty() or self.stopping)
) # Have events and queue is empty or stopping
)

if should_publish and self._batch: # Only publish if we have events
try:
logger.debug(
f"{self.__class__.__name__} publishing batch of {len(self._batch)} events: {self._batch}"
)
self.s.api.batch(self._batch[:]) # Send a copy of the batch
self._batch.clear()
self._last_batch_time = current_time
except Exception as e:
logger.error(
f"{self.__class__.__name__} failed to publish batch: {e}"
)

if not should_publish and not self.stopping:
# Sleep briefly to prevent tight polling
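
The publish decision above combines three triggers: the batch is full, the batch is non-empty and the max wait time has elapsed, or the batch is non-empty and the queue has drained (or the thread is stopping). Distilled into a standalone predicate for clarity (illustrative, hypothetical signature):

# Illustrative only: the three publish triggers as one predicate.
def should_publish(batch_len, max_batch, elapsed_s, max_wait_ms,
                   queue_empty, stopping):
    return (
        batch_len >= max_batch                                   # batch full
        or (batch_len > 0 and elapsed_s >= max_wait_ms / 1000)   # waited long enough
        or (batch_len > 0 and (queue_empty or stopping))         # drain what we have
    )

assert should_publish(10, 10, 0.0, 5000, False, False)  # full batch
assert should_publish(3, 10, 6.0, 5000, False, False)   # time-based flush
assert should_publish(1, 10, 0.0, 5000, True, False)    # queue drained
assert not should_publish(0, 10, 9.9, 5000, True, True) # nothing to send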
@@ -722,14 +762,18 @@ def run(self) -> None:
def stop(self) -> None:
"""Ensure any remaining events are published before stopping"""
super().stop() # Set stop flag first

with self._batch_lock:
if self._batch:
try:
logger.debug(f"{self.__class__.__name__} final batch of {len(self._batch)} events")
logger.debug(
f"{self.__class__.__name__} final batch of {len(self._batch)} events"
)
self.s.api.batch(self._batch)
except Exception as e:
logger.error(f"{self.__class__.__name__} failed to publish final batch during shutdown: {e}")
logger.error(
f"{self.__class__.__name__} failed to publish final batch during shutdown: {e}"
)
finally:
self._batch.clear()

@@ -743,15 +787,17 @@ def run(self) -> None:
try:
# Use a shorter timeout and don't hold the lock continuously
time.sleep(0.1)

if self.stopping:
logger.debug(f"{self.__class__.__name__} stopping (pre-lock)")
break

# Quick check with lock
with self.s._locks["session"]:
if self.s._events and not self.stopping:
logger.debug(f"{self.__class__.__name__}: Processing session changes")
logger.debug(
f"{self.__class__.__name__}: Processing session changes"
)
self.s.api.update_session()

except Exception as e:
@@ -768,8 +814,7 @@ def stop(self) -> None:
if self.is_alive():
self.join(timeout=0.5)


active_sessions: List[Session] = []
active_sessions = WeakSet()

__all__ = ["Session"]

2 changes: 1 addition & 1 deletion tests/test_session.py
@@ -377,7 +377,7 @@ def test_add_tags(self, mock_req):
def cleanup_sessions():
yield
# Force cleanup of any remaining sessions
for session in active_sessions[:]: # Create a copy of the list to iterate
for session in list(active_sessions):  # Snapshot the WeakSet; stop() mutates it while we iterate
try:
session.stop()
except:
