Skip to content

Commit

Permalink
Merge pull request #1275 from nathanlct:fix_pttt_dh3_info_state
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 667950162
Change-Id: Ie258bbd85194496c58bd8804bb44ba5d72489f0f
  • Loading branch information
lanctot committed Aug 27, 2024
2 parents 42ff9ba + 5598fe1 commit 77a03df
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 83 deletions.
48 changes: 35 additions & 13 deletions open_spiel/games/dark_hex/dark_hex.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,18 @@

#include "open_spiel/games/dark_hex/dark_hex.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "open_spiel/abseil-cpp/absl/strings/str_cat.h"
#include "open_spiel/abseil-cpp/absl/types/span.h"
#include "open_spiel/games/hex/hex.h"
#include "open_spiel/game_parameters.h"
#include "open_spiel/observer.h"
#include "open_spiel/spiel.h"
#include "open_spiel/spiel_utils.h"

namespace open_spiel {
Expand All @@ -31,7 +37,6 @@ using hex::kCellStates;
using hex::CellState;
using hex::kMinValueCellState;

using hex::PlayerToState;
using hex::StateToString;

// Game Facts
Expand Down Expand Up @@ -107,11 +112,18 @@ DarkHexState::DarkHexState(std::shared_ptr<const Game> game, int num_cols,
game_version_(game_version),
num_cols_(num_cols),
num_rows_(num_rows),
num_cells_(num_cols * num_rows),
bits_per_action_(num_cells_ + 1),
longest_sequence_(num_cells_ * 2 - 1) {
num_cells_(num_cols * num_rows) {
black_view_.resize(num_cols * num_rows, CellState::kEmpty);
white_view_.resize(num_cols * num_rows, CellState::kEmpty);
if (obs_type == ObservationType::kRevealNothing) {
bits_per_action_ = num_cells_;
longest_sequence_ = num_cells_;
} else {
SPIEL_CHECK_EQ(obs_type_, ObservationType::kRevealNumTurns);
// Reserve slot 0 for the player and slot num_cells_ + 1 as "I don't know."
bits_per_action_ = num_cells_ + 2;
longest_sequence_ = num_cells_ * 2 - 1;
}
}

void DarkHexState::DoApplyAction(Action move) {
Expand Down Expand Up @@ -218,7 +230,7 @@ void DarkHexState::InformationStateTensor(Player player,
const auto& player_view = (player == 0 ? black_view_ : white_view_);

SPIEL_CHECK_EQ(values.size(), num_cells_ * kCellStates +
longest_sequence_ * (1 + bits_per_action_));
longest_sequence_ * bits_per_action_);
std::fill(values.begin(), values.end(), 0.);
for (int cell = 0; cell < num_cells_; ++cell) {
values[cell * kCellStates +
Expand All @@ -230,18 +242,26 @@ void DarkHexState::InformationStateTensor(Player player,
for (const auto& player_with_action : action_sequence_) {
if (player_with_action.first == player) {
// Always include the observing player's actions.
values[offset] = player_with_action.first;
values[offset + 1 + player_with_action.second] = 1.0;
if (obs_type_ == ObservationType::kRevealNumTurns) {
values[offset] = player_with_action.first; // Player 0 or 1
values[offset + 1 + player_with_action.second] = 1.0;
} else {
// Here we don't need to encode the player since we won't see opponent
// moves.
SPIEL_CHECK_EQ(obs_type_, ObservationType::kRevealNothing);
values[offset + player_with_action.second] = 1.0;
}
offset += bits_per_action_;
} else if (obs_type_ == ObservationType::kRevealNumTurns) {
// If the number of turns are revealed, then each of the other player's
// actions will show up as unknowns. Here, num_cells_ is used to
// encode "unknown".
values[offset] = player_with_action.first;
values[offset + 1 + num_cells_] = 1.0;
offset += bits_per_action_;
} else {
SPIEL_CHECK_EQ(obs_type_, ObservationType::kRevealNothing);
}
offset += (1 + bits_per_action_);
}
}

Expand Down Expand Up @@ -290,14 +310,17 @@ DarkHexGame::DarkHexGame(const GameParameters& params, GameType game_type)
ParameterValue<int>("num_cols", ParameterValue<int>("board_size"))),
num_rows_(
ParameterValue<int>("num_rows", ParameterValue<int>("board_size"))),
num_cells_(num_cols_ * num_rows_),
bits_per_action_(num_cells_ + 1),
longest_sequence_(num_cells_ * 2 - 1) {
num_cells_(num_cols_ * num_rows_) {
std::string obs_type = ParameterValue<std::string>("obstype");
if (obs_type == "reveal-nothing") {
obs_type_ = ObservationType::kRevealNothing;
bits_per_action_ = num_cells_;
longest_sequence_ = num_cells_;
} else if (obs_type == "reveal-numturns") {
obs_type_ = ObservationType::kRevealNumTurns;
// Reserve slot 0 for the player and slot num_cells_ + 1 as "I don't know."
bits_per_action_ = num_cells_ + 2;
longest_sequence_ = num_cells_ * 2 - 1;
} else {
SpielFatalError(absl::StrCat("Unrecognized observation type: ", obs_type));
}
Expand All @@ -313,8 +336,7 @@ DarkHexGame::DarkHexGame(const GameParameters& params, GameType game_type)
}

std::vector<int> DarkHexGame::InformationStateTensorShape() const {
return {num_cells_ * kCellStates +
longest_sequence_ * (1 + bits_per_action_)};
return {num_cells_ * kCellStates + longest_sequence_ * bits_per_action_};
}

std::vector<int> DarkHexGame::ObservationTensorShape() const {
Expand Down
8 changes: 4 additions & 4 deletions open_spiel/games/dark_hex/dark_hex.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ class DarkHexState : public State {
const int num_cols_; // x
const int num_rows_; // y
const int num_cells_;
const int bits_per_action_;
const int longest_sequence_;
int bits_per_action_;
int longest_sequence_;

// Change this to _history on base class
std::vector<std::pair<int, Action>> action_sequence_;
Expand Down Expand Up @@ -166,8 +166,8 @@ class DarkHexGame : public Game {
const int num_cols_;
const int num_rows_;
const int num_cells_;
const int bits_per_action_;
const int longest_sequence_;
int bits_per_action_;
int longest_sequence_;
};

class ImperfectRecallDarkHexState : public DarkHexState {
Expand Down
41 changes: 32 additions & 9 deletions open_spiel/games/phantom_ttt/phantom_ttt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,16 @@ PhantomTTTState::PhantomTTTState(std::shared_ptr<const Game> game,
: State(game), state_(game), obs_type_(obs_type) {
std::fill(begin(x_view_), end(x_view_), CellState::kEmpty);
std::fill(begin(o_view_), end(o_view_), CellState::kEmpty);
if (obs_type_ == ObservationType::kRevealNumTurns) {
// Reserve 0 for the player and 10 as "I don't know."
bits_per_action_ = kNumCells + 2;
// Longest sequence is 17 moves, e.g. 00112233445566778
longest_sequence_ = 2 * kNumCells - 1;
} else {
SPIEL_CHECK_EQ(obs_type_, ObservationType::kRevealNothing);
bits_per_action_ = kNumCells;
longest_sequence_ = kNumCells;
}
}

void PhantomTTTState::DoApplyAction(Action move) {
Expand Down Expand Up @@ -193,7 +203,7 @@ void PhantomTTTState::InformationStateTensor(Player player,
// which may contain action value 10 to represent "I don't know."
const auto& player_view = player == 0 ? x_view_ : o_view_;
SPIEL_CHECK_EQ(values.size(), kNumCells * kCellStates +
kLongestSequence * (1 + kBitsPerAction));
longest_sequence_ * bits_per_action_);
std::fill(values.begin(), values.end(), 0.);
for (int cell = 0; cell < kNumCells; ++cell) {
values[kNumCells * static_cast<int>(player_view[cell]) + cell] = 1.0;
Expand All @@ -206,19 +216,26 @@ void PhantomTTTState::InformationStateTensor(Player player,
for (const auto& player_with_action : action_sequence_) {
if (player_with_action.first == player) {
// Always include the observing player's actions.
values[offset] = player_with_action.first; // Player 0 or 1
values[offset + 1 + player_with_action.second] = 1.0;
if (obs_type_ == ObservationType::kRevealNumTurns) {
values[offset] = player_with_action.first; // Player 0 or 1
values[offset + 1 + player_with_action.second] = 1.0;
} else {
// Here we don't need to encode the player since we won't see opponent
// moves.
SPIEL_CHECK_EQ(obs_type_, ObservationType::kRevealNothing);
values[offset + player_with_action.second] = 1.0;
}
offset += bits_per_action_;
} else if (obs_type_ == ObservationType::kRevealNumTurns) {
// If the number of turns are revealed, then each of the other player's
// actions will show up as unknowns.
values[offset] = player_with_action.first;
values[offset + 1 + 10] = 1.0; // I don't know.
values[offset + 1 + kNumCells] = 1.0; // I don't know.
offset += bits_per_action_;
} else {
// Do not reveal anything about the number of actions taken by opponent.
SPIEL_CHECK_EQ(obs_type_, ObservationType::kRevealNothing);
}

offset += (1 + kBitsPerAction);
}
}

Expand Down Expand Up @@ -283,25 +300,31 @@ PhantomTTTGame::PhantomTTTGame(const GameParameters& params, GameType game_type)
std::string obs_type = ParameterValue<std::string>("obstype");
if (obs_type == "reveal-nothing") {
obs_type_ = ObservationType::kRevealNothing;
bits_per_action_ = kNumCells;
longest_sequence_ = kNumCells;
} else if (obs_type == "reveal-numturns") {
obs_type_ = ObservationType::kRevealNumTurns;
// Reserve 0 for the player and 10 as "I don't know."
bits_per_action_ = kNumCells + 2;
// Longest sequence is 17 moves, e.g. 00112233445566778
longest_sequence_ = 2 * kNumCells - 1;
} else {
SpielFatalError(absl::StrCat("Unrecognized observation type: ", obs_type));
}
}

std::vector<int> PhantomTTTGame::InformationStateTensorShape() const {
// Enc
return {1, kNumCells * kCellStates + kLongestSequence * (1 + kBitsPerAction)};
return {1, kNumCells * kCellStates + longest_sequence_ * bits_per_action_};
}

std::vector<int> PhantomTTTGame::ObservationTensorShape() const {
if (obs_type_ == ObservationType::kRevealNothing) {
return {kNumCells * kCellStates};
} else if (obs_type_ == ObservationType::kRevealNumTurns) {
return {kNumCells * kCellStates + kLongestSequence};
return {kNumCells * kCellStates + longest_sequence_};
} else {
SpielFatalError("Uknown observation type");
SpielFatalError("Unknown observation type");
}
}

Expand Down
11 changes: 6 additions & 5 deletions open_spiel/games/phantom_ttt/phantom_ttt.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,6 @@ namespace phantom_ttt {

inline constexpr const char* kDefaultObsType = "reveal-nothing";

// Longest sequence is 17 moves, e.g. 00112233445566778
inline constexpr int kLongestSequence = 2 * tic_tac_toe::kNumCells - 1;
inline constexpr int kBitsPerAction = 10; // Reserve 9 as "I don't know."

enum class ObservationType {
kRevealNothing,
kRevealNumTurns,
Expand Down Expand Up @@ -88,6 +84,9 @@ class PhantomTTTState : public State {

tic_tac_toe::TicTacToeState state_;
ObservationType obs_type_;
int bits_per_action_;
int longest_sequence_;

// TODO(author2): Use the base class history_ instead.
std::vector<std::pair<int, Action>> action_sequence_;
std::array<tic_tac_toe::CellState, tic_tac_toe::kNumCells> x_view_;
Expand Down Expand Up @@ -119,13 +118,15 @@ class PhantomTTTGame : public Game {
// These will depend on the obstype parameter.
std::vector<int> InformationStateTensorShape() const override;
std::vector<int> ObservationTensorShape() const override;
int MaxGameLength() const override { return kLongestSequence; }
int MaxGameLength() const override { return tic_tac_toe::kNumCells * 2 - 1; }

ObservationType obs_type() const { return obs_type_; }

private:
std::shared_ptr<const tic_tac_toe::TicTacToeGame> game_;
ObservationType obs_type_;
int bits_per_action_;
int longest_sequence_;
};

// Implements the FOE abstraction from Lanctot et al. '12
Expand Down
Loading

0 comments on commit 77a03df

Please sign in to comment.