Skip to content

Commit

Permalink
Added example of how to compute the mean episode return of a datset.
Browse files Browse the repository at this point in the history
  • Loading branch information
jcformanek committed Jan 4, 2024
1 parent 1bbee26 commit e752264
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
17 changes: 16 additions & 1 deletion examples/download_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,20 @@
# limitations under the License.

from og_marl.offline_dataset import download_and_unzip_dataset
from og_marl.environments import smacv1
from og_marl.offline_dataset import OfflineMARLDataset

download_and_unzip_dataset("voltage_control", "case33_3min_final", dataset_base_dir="datasets")
# Comment this out if you already downloaded the dataset
download_and_unzip_dataset("smac_v1", "3m", dataset_base_dir="datasets")

# Compute mean episode return of Good dataset

env = smacv1.SMACv1("3m") # Change SMAC Scenario Here
dataset = OfflineMARLDataset(env, f"datasets/smac_v1/3m/Good")

sample_cnt =0
tot_returns = 0
for sample in dataset._tf_dataset:
sample_cnt+=1
tot_returns += sample["episode_return"].numpy()
print("Mean Episode return:", tot_returns / sample_cnt)
3 changes: 1 addition & 2 deletions og_marl/offline_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
import os
import requests

Sample = namedtuple('Sample', ['observations', 'actions', 'rewards', 'done', 'episode_return', 'legal_actions', 'env_state', 'zero_padding_mask'])

def get_schema_dtypes(environment):
act_type = list(environment.action_spaces.values())[0].dtype
schema = {}
Expand Down Expand Up @@ -99,6 +97,7 @@ def _decode_fn(self, record_bytes):

sample["mask"] = example["zero_padding_mask"]
sample["state"] = example["env_state"]
sample["episode_return"] = example["episode_return"]

return sample

Expand Down

0 comments on commit e752264

Please sign in to comment.