Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix pd series with datetime #48

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions agentlib/core/datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,8 @@ def convert_to_pd_series(value):
)
if isinstance(srs.index[0], str):
srs.index = srs.index.astype(float)
if isinstance(srs.index, pd.DatetimeIndex):
srs.index = pd.to_numeric(srs.index) / 10**9
return srs


Expand Down
13 changes: 11 additions & 2 deletions agentlib/modules/communicator/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,13 @@ def _send(self, payload: CommunicationDict):
"This method needs to be implemented " "individually for each communicator"
)

def short_dict(self, variable: AgentVariable) -> CommunicationDict:
def short_dict(self, variable: AgentVariable, parse_json: bool = True) -> CommunicationDict:
"""Creates a short dict serialization of the Variable.

Only contains attributes of the AgentVariable, that are relevant for other
modules or agents. For performance and privacy reasons, this function should
be called for communicators."""
if isinstance(variable.value, pd.Series):
if isinstance(variable.value, pd.Series) and parse_json:
value = variable.value.to_json()
else:
value = variable.value
Expand Down Expand Up @@ -187,6 +187,15 @@ def setup_broker(self):
"This method needs to be implemented " "individually for each communicator"
)

def _send_only_shared_variables(self, variable: AgentVariable):
"""Send only variables with field ``shared=True``"""
if not self._variable_can_be_send(variable):
return

payload = self.short_dict(variable, parse_json=self.config.parse_json)
self.logger.debug("Sending variable %s=%s", variable.alias, variable.value)
self._send(payload=payload)

def _process(self):
"""Waits for new messages, sends them to the broker."""
yield self.env.event()
Expand Down
15 changes: 12 additions & 3 deletions tests/test_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,21 @@ def test_pd_series(self):
"""Tests whether pandas series are sent correctly"""
data = {**default_data, "value": pd.Series({0: 1, 10: 2}), "type": "pd.Series"}
variable = AgentVariable(**data)
comm = LocalClient(config=self.test_config, agent=self.agent_send)
payload = comm.short_dict(variable)
var_json = comm.to_json(payload)
comm_parse = LocalClient(config=self.test_config, agent=self.agent_send)
comm_no_parse = LocalClient(config={**self.test_config, "parse_json": False}, agent=self.agent_send)

# communicator with json parsing
payload = comm_parse.short_dict(variable)
var_json = comm_parse.to_json(payload)
variable2 = AgentVariable.from_json(var_json)
pd.testing.assert_series_equal(variable.value, variable2.value)

# communicator without json parsing
payload = comm_no_parse.short_dict(variable, parse_json=comm_no_parse.config.parse_json)
payload["name"] = payload["alias"]
variable2 = AgentVariable(**payload)
pd.testing.assert_series_equal(variable.value, variable2.value)


if __name__ == "__main__":
unittest.main()
66 changes: 64 additions & 2 deletions tests/test_datamodels.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Module to test all datamodels in the agentlib"""

import unittest
from pydantic import ValidationError
from agentlib.core import datamodels, errors

import numpy as np
import pandas as pd

from agentlib.core import datamodels


class TestVariables(unittest.TestCase):
"""Class with tests for Variables"""
Expand Down Expand Up @@ -109,6 +111,66 @@ def test_pd_series(self):
self.assertEqual(var.allowed_values, [])
self.assertIsInstance(var.value, pd.Series)

def test_pd_series_index_handling(self):
"""Test pd.Series index handling, especially for datetime and numeric indices."""

# Test datetime index conversion
datetime_index = pd.date_range('2023-01-01', periods=5, freq='D')
original_series = pd.Series([1., 2., 3., 4., 5.], index=datetime_index)

# Convert to JSON (simulating sending)
series_dict = original_series.to_dict()
converted_series = datamodels.convert_to_pd_series(series_dict)

# Check that index is numeric (float)
self.assertTrue(converted_series.index.dtype == np.float64)

# Check values are preserved
pd.testing.assert_series_equal(
converted_series,
pd.Series(original_series.values, index=pd.to_numeric(datetime_index) / 10**9)
)

# Test numeric index preservation
numeric_index = [1.5, 2.5, 3.5, 4.5, 5.5]
original_series = pd.Series([1., 2., 3., 4., 5.], index=numeric_index)
series_dict = original_series.to_dict()
converted_series = datamodels.convert_to_pd_series(series_dict)

# Check that numeric index remains numeric
self.assertTrue(converted_series.index.dtype == np.float64)
pd.testing.assert_series_equal(converted_series, original_series)

# Test string index conversion
string_index = ['1.0', '2.0', '3.0', '4.0', '5.0']
original_series = pd.Series([1., 2., 3., 4., 5.], index=string_index)
series_dict = original_series.to_dict()
converted_series = datamodels.convert_to_pd_series(series_dict)

# Check that string index is converted to float
self.assertTrue(converted_series.index.dtype == np.float64)
pd.testing.assert_series_equal(
converted_series,
pd.Series(original_series.values, index=[float(x) for x in string_index])
)

# Test full variable serialization/deserialization with datetime index
datetime_series = pd.Series([1., 2., 3.],
index=pd.date_range('2023-01-01', periods=3))
datetime_series.index = pd.to_numeric(datetime_series.index) / 10**9
var = datamodels.AgentVariable(
name="test",
type="pd.Series",
value=datetime_series
)

# Serialize and deserialize
json_str = var.json()
var_after = datamodels.AgentVariable.from_json(json_str)

# Check that the index remained numeric and didn't convert back to datetime
self.assertTrue(var_after.value.index.dtype == np.float64)

def test_bounds(self):
"""Test if ub and lb work"""
var = datamodels.BaseVariable.validate_data({"name": "test"})
Expand Down