Skip to content

Commit

Permalink
Fix typing issue in base system.
Browse files Browse the repository at this point in the history
  • Loading branch information
jcformanek committed Jul 24, 2024
1 parent 82d7082 commit cf3441a
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion og_marl/environments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_stats(self) -> Dict:
"""
return {}

def render(self):
def render(self) -> Any:
"""Return frame for rendering"""
return np.zeros((10, 10, 3), "float32")

Expand Down
1 change: 0 additions & 1 deletion og_marl/environments/jumanji_rware.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def step(self, actions: Dict[str, np.ndarray]) -> StepReturn:

return observations, rewards, terminals, truncations, info


def render(self) -> Any:
frame = self._state

Expand Down
6 changes: 3 additions & 3 deletions og_marl/tf2/systems/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import time
from typing import Dict, Optional
from typing import Dict, Optional, Tuple, List

import numpy as np
from chex import Numeric
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(
def get_stats(self) -> Dict[str, Numeric]:
return {}

def evaluate(self, num_eval_episodes: int = 4) -> Dict[str, Numeric]:
def evaluate(self, num_eval_episodes: int = 4) -> Tuple[Dict[str, Numeric], list]:
"""Method to evaluate the system online (i.e. in the environment)."""
episode_returns = []
all_frames = []
Expand Down Expand Up @@ -187,7 +187,7 @@ def train_offline(
evaluate_every: int = 1000,
num_eval_episodes: int = 4,
json_writer: Optional[JsonWriter] = None,
) -> None:
) -> List:
"""Method to train the system offline.
WARNING: make sure evaluate_every % log_every == 0 and log_every < evaluate_every,
Expand Down

0 comments on commit cf3441a

Please sign in to comment.