
Commit

fix
alezana committed Jan 2, 2025
1 parent 3178807 commit 71149d8
Showing 1 changed file with 7 additions and 34 deletions.
waymax/rewards/linear_transformed_reward_test.py: 41 changes (7 additions, 34 deletions)
@@ -45,62 +45,35 @@ def test_config(self):
         expected_reward = jnp.array([1.0, 0.0, 1.0])  # Using x^2 on the mock metric
         self.assertTrue(jnp.allclose(result, expected_reward))
 
-    def test_default_transform(self):
+    def test_transform_offroad(self):
         reward_config = _config.LinearTransformedRewardConfig(
             rewards={
-                "gokart_distance_to_bounds": 0.1,
+                "offroad": 0.1,
             },
             transform=defaultdict(
                 lambda: lambda x: x,
-                gokart_distance_to_bounds=lambda x: jnp.minimum(x, 0.5),
+                offroad=lambda x: jnp.minimum(x, 0.5),
             ),
         )
 
         reward = LinearTransformedReward(reward_config)
 
         # Set up mock simulation state and agent mask
-        simulator_state = test_utils.simulator_state_with_overlap()
+        simulator_state = test_utils.simulator_state_with_offroad()
         agent_mask = jnp.array([1, 1, 1])  # Assume all agents are active
 
         # Compute the reward
         result = reward.compute(simulator_state, None, agent_mask)
 
-        # Expected computation for "gokart_distance_to_bounds" using the capped transform
         # Simulating reward metric masked_values as 1.0 for simple example
-        gokart_distance_metric = jnp.array([1.0, 0.5, -1])  # Sample masked values
+        offroad_metric = jnp.array([1.0])  # Sample masked values
 
         # Apply transform and rewards calculation
-        capped_values = jnp.minimum(gokart_distance_metric, 0.5)
-        expected_reward = capped_values * 0.1  # Reward weight for "gokart_distance_to_bounds"
-
+        capped_values = jnp.minimum(offroad_metric, 0.5)
+        expected_reward = capped_values * 0.1  # Reward weight for "offroad"
         self.assertTrue(jnp.allclose(result, expected_reward))
 
-
-    def test_default_factory(self):
-        reward_config = _config.LinearTransformedRewardConfig(
-            rewards={"offroad": 0.9},
-            transform=defaultdict(
-                lambda: lambda x: x,  # Default to identity
-            ),
-        )
-
-        reward = LinearTransformedReward(reward_config)
-
-        # Set up mock simulation state and agent mask
-        simulator_state = test_utils.simulator_state_with_overlap()
-        agent_mask = jnp.array([1, 0, 1])  # Assume some agents are active
-
-        # Compute the reward
-        result = reward.compute(simulator_state, None, agent_mask)
-
-        # Expected computation using identity transform
-        gokart_progress_metric = jnp.array([1.0, -0.5, 0.7])  # Sample masked values
-
-        # No transformation means just multiply by weight
-        expected_reward = gokart_progress_metric * 0.9 * agent_mask
-
-        self.assertTrue(jnp.allclose(result, expected_reward))
-
 
 # Run the tests
 if __name__ == "__main__":
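A note on the transform=defaultdict(lambda: lambda x: x, ...) idiom this commit keeps: defaultdict takes the default factory as its first positional argument and seeds initial entries from keyword arguments, so any metric without an explicit transform falls back to identity. A minimal standalone check (the key "unknown_metric" is an arbitrary placeholder, not a real Waymax metric):

from collections import defaultdict

import jax.numpy as jnp

# Default factory yields an identity transform; "offroad" gets an explicit cap.
transform = defaultdict(lambda: lambda x: x, offroad=lambda x: jnp.minimum(x, 0.5))

print(transform["offroad"](jnp.array([1.0])))         # [0.5] -> capped at 0.5
print(transform["unknown_metric"](jnp.array([1.0])))  # [1.]  -> identity fallback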

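For readers without the surrounding file: below is a minimal sketch of the contract the tests above appear to assume, namely a reward computed as a weighted sum of per-metric values, each passed through its configured transform and zeroed for inactive agents. This is inferred from the test expectations only, not the actual Waymax/gokart implementation; in particular, the real compute takes a simulator state (plus an action argument, passed as None in the test) and evaluates metrics internally, whereas this sketch accepts precomputed metric values directly, and the class name is hypothetical.

from collections import defaultdict
from typing import Callable, Mapping

import jax.numpy as jnp


class LinearTransformedRewardSketch:
    """Hypothetical stand-in mirroring only the behavior the tests exercise."""

    def __init__(
        self,
        rewards: Mapping[str, float],
        transform: Mapping[str, Callable[[jnp.ndarray], jnp.ndarray]],
    ):
        self._rewards = rewards      # metric name -> linear weight
        self._transform = transform  # metric name -> elementwise transform

    def compute(
        self,
        metric_values: Mapping[str, jnp.ndarray],
        agent_mask: jnp.ndarray,
    ) -> jnp.ndarray:
        # Weighted sum of transformed metric values, masked to active agents.
        total = jnp.zeros(agent_mask.shape, dtype=jnp.float32)
        for name, weight in self._rewards.items():
            total = total + weight * self._transform[name](metric_values[name])
        return total * agent_mask


# Mirrors the capped-transform expectation from test_transform_offroad.
sketch = LinearTransformedRewardSketch(
    rewards={"offroad": 0.1},
    transform=defaultdict(lambda: lambda x: x, offroad=lambda x: jnp.minimum(x, 0.5)),
)
print(sketch.compute({"offroad": jnp.array([1.0, 0.2, 0.7])}, jnp.array([1, 1, 0])))
# -> [0.05 0.02 0.  ]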