Skip to content

Commit

Permalink
add blackjack (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
Trinkle23897 authored Dec 28, 2021
1 parent db536e6 commit 90c265f
Show file tree
Hide file tree
Showing 10 changed files with 232 additions and 2 deletions.
22 changes: 22 additions & 0 deletions docs/api/toy_text.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,25 @@ The board is a 4x12 matrix, with (using NumPy matrix indexing):
Each time step incurs -1 reward, and stepping into the cliff incurs -100
reward and a reset to the start. An episode terminates when the agent reaches
the goal.


Blackjack-v1
------------

`gym Blackjack-v1 source code
<https://github.com/openai/gym/blob/master/gym/envs/toy_text/blackjack.py>`_

Blackjack is a card game where the goal is to obtain cards that sum to as near
as possible to 21 without going over. They're playing against a fixed dealer.
Face cards (Jack, Queen, King) have point value 10. Aces can either count as
11 or 1, and it's called 'usable' at 11.

This game is placed with an infinite deck (or with replacement). The game
starts with dealer having one face up and one face down card, while player
having two face up cards. (Virtually for all Blackjack games today). The player
can request additional cards (hit=1) until they decide to stop (stick=0) or
exceed 21 (bust). After the player sticks, the dealer reveals their facedown
card, and draws until their sum is 17 or greater. If the dealer goes bust the
player wins. If neither player nor dealer busts, the outcome (win, lose, draw)
is decided by whose sum is closer to 21. The reward for winning is +1, drawing
is 0, and losing is -1.
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ frictionless
effector
walkable
NChain
facedown
2 changes: 1 addition & 1 deletion envpool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
register,
)

__version__ = "0.4.2"
__version__ = "0.4.3"
__all__ = [
"register",
"make",
Expand Down
2 changes: 1 addition & 1 deletion envpool/make_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_make_classic_and_toytext(self) -> None:
]
toytext = [
"Catch-v0", "FrozenLake-v1", "FrozenLake8x8-v1", "Taxi-v3", "NChain-v0",
"CliffWalking-v0"
"CliffWalking-v0", "Blackjack-v1"
]
for task_id in classic + toytext:
envpool.make_spec(task_id)
Expand Down
9 changes: 9 additions & 0 deletions envpool/toy_text/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,21 @@ cc_library(
],
)

cc_library(
name = "blackjack",
hdrs = ["blackjack.h"],
deps = [
"//envpool/core:async_envpool",
],
)

pybind_extension(
name = "toy_text_envpool",
srcs = [
"toy_text.cc",
],
deps = [
":blackjack",
":catch",
":cliffwalking",
":frozen_lake",
Expand Down
9 changes: 9 additions & 0 deletions envpool/toy_text/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from envpool.python.api import py_env

from .toy_text_envpool import (
_BlackjackEnvPool,
_BlackjackEnvSpec,
_CatchEnvPool,
_CatchEnvSpec,
_CliffWalkingEnvPool,
Expand Down Expand Up @@ -46,6 +48,10 @@
_CliffWalkingEnvSpec, _CliffWalkingEnvPool
)

BlackjackEnvSpec, BlackjackDMEnvPool, BlackjackGymEnvPool = py_env(
_BlackjackEnvSpec, _BlackjackEnvPool
)

__all__ = [
"CatchEnvSpec",
"CatchDMEnvPool",
Expand All @@ -62,4 +68,7 @@
"CliffWalkingEnvSpec",
"CliffWalkingDMEnvPool",
"CliffWalkingGymEnvPool",
"BlackjackEnvSpec",
"BlackjackDMEnvPool",
"BlackjackGymEnvPool",
]
153 changes: 153 additions & 0 deletions envpool/toy_text/blackjack.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// https://github.com/openai/gym/blob/master/gym/envs/toy_text/blackjack.py

#ifndef ENVPOOL_TOY_TEXT_BLACKJACK_H_
#define ENVPOOL_TOY_TEXT_BLACKJACK_H_

#include <algorithm>
#include <cmath>
#include <random>
#include <string>
#include <vector>

#include "envpool/core/async_envpool.h"
#include "envpool/core/env.h"

namespace toy_text {

class BlackjackEnvFns {
public:
static decltype(auto) DefaultConfig() {
return MakeDict("natural"_.bind(false), "sab"_.bind(true));
}
template <typename Config>
static decltype(auto) StateSpec(const Config& conf) {
return MakeDict("obs"_.bind(Spec<int>({3}, {0, 31})));
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
return MakeDict("action"_.bind(Spec<int>({-1}, {0, 1})));
}
};

typedef class EnvSpec<BlackjackEnvFns> BlackjackEnvSpec;

class BlackjackEnv : public Env<BlackjackEnvSpec> {
protected:
bool natural_, sab_;
std::vector<int> player_, dealer_;
std::uniform_int_distribution<> dist_;
bool done_;

public:
BlackjackEnv(const Spec& spec, int env_id)
: Env<BlackjackEnvSpec>(spec, env_id),
natural_(spec.config["natural"_]),
sab_(spec.config["sab"_]),
dist_(1, 13),
done_(true) {}

bool IsDone() override { return done_; }

void Reset() override {
player_.clear();
player_.push_back(DrawCard());
player_.push_back(DrawCard());
dealer_.clear();
dealer_.push_back(DrawCard());
dealer_.push_back(DrawCard());
done_ = false;
State state = Allocate();
WriteObs(state, 0.0f);
}

void Step(const Action& action) override {
int act = action["action"_];
float reward = 0.0f;
if (act) { // hit: add a card to players hand and return
player_.push_back(DrawCard());
if (IsBust(player_)) {
done_ = true;
reward = -1.0f;
}
} else { // stick: play out the dealers hand, and score
done_ = true;
while (SumHand(dealer_) < 17) {
dealer_.push_back(DrawCard());
}
int player_score = Score(player_);
int dealer_score = Score(dealer_);
reward = (player_score > dealer_score ? 1.0f : 0.0f) -
(player_score < dealer_score ? 1.0f : 0.0f);
if (sab_ && IsNatural(player_) && !IsNatural(dealer_)) {
reward = 1.0f;
} else if (!sab_ && natural_ && IsNatural(player_) && reward == 1.0f) {
reward = 1.5f;
}
}
State state = Allocate();
WriteObs(state, reward);
}

private:
void WriteObs(State& state, float reward) { // NOLINT
state["obs"_][0] = SumHand(player_);
state["obs"_][1] = dealer_[0];
state["obs"_][2] = UsableAce(player_);
state["reward"_] = reward;
}

int DrawCard() { return std::min(10, dist_(gen_)); }

int UsableAce(const std::vector<int>& hand) {
for (auto i : hand) {
if (i == 1) {
return 1;
}
}
return 0;
}

int SumHand(const std::vector<int>& hand) {
int sum = 0;
for (auto i : hand) {
sum += i;
}
if (UsableAce(hand) && sum + 10 <= 21) {
return sum + 10;
}
return sum;
}

bool IsBust(const std::vector<int>& hand) { return SumHand(hand) > 21; }

int Score(const std::vector<int>& hand) {
int result = SumHand(hand);
return result > 21 ? 0 : result;
}

bool IsNatural(const std::vector<int>& hand) {
return hand.size() == 2 &&
((hand[0] == 1 && hand[1] == 10) || (hand[0] == 10 && hand[1] == 1));
}
};

typedef AsyncEnvPool<BlackjackEnv> BlackjackEnvPool;

} // namespace toy_text

#endif // ENVPOOL_TOY_TEXT_BLACKJACK_H_
10 changes: 10 additions & 0 deletions envpool/toy_text/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,13 @@
dm_cls="CliffWalkingDMEnvPool",
gym_cls="CliffWalkingGymEnvPool",
)

register(
task_id="Blackjack-v1",
import_path="envpool.toy_text",
spec_cls="BlackjackEnvSpec",
dm_cls="BlackjackDMEnvPool",
gym_cls="BlackjackGymEnvPool",
sab=True,
natural=False,
)
5 changes: 5 additions & 0 deletions envpool/toy_text/toy_text.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "envpool/core/py_envpool.h"
#include "envpool/toy_text/blackjack.h"
#include "envpool/toy_text/catch.h"
#include "envpool/toy_text/cliffwalking.h"
#include "envpool/toy_text/frozen_lake.h"
Expand All @@ -34,10 +35,14 @@ typedef PyEnvPool<toy_text::NChainEnvPool> NChainEnvPool;
typedef PyEnvSpec<toy_text::CliffWalkingEnvSpec> CliffWalkingEnvSpec;
typedef PyEnvPool<toy_text::CliffWalkingEnvPool> CliffWalkingEnvPool;

typedef PyEnvSpec<toy_text::BlackjackEnvSpec> BlackjackEnvSpec;
typedef PyEnvPool<toy_text::BlackjackEnvPool> BlackjackEnvPool;

PYBIND11_MODULE(toy_text_envpool, m) {
REGISTER(m, CatchEnvSpec, CatchEnvPool)
REGISTER(m, FrozenLakeEnvSpec, FrozenLakeEnvPool)
REGISTER(m, TaxiEnvSpec, TaxiEnvPool)
REGISTER(m, NChainEnvSpec, NChainEnvPool)
REGISTER(m, CliffWalkingEnvSpec, CliffWalkingEnvPool)
REGISTER(m, BlackjackEnvSpec, BlackjackEnvPool)
}
21 changes: 21 additions & 0 deletions envpool/toy_text/toy_text_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from dm_env import TimeStep

from envpool.toy_text import (
BlackjackEnvSpec,
BlackjackGymEnvPool,
CatchDMEnvPool,
CatchEnvSpec,
CatchGymEnvPool,
Expand Down Expand Up @@ -229,6 +231,25 @@ def test_cliffwalking(self) -> None:
if ref_done:
break

def test_blackjack(self) -> None:
np.random.seed(0)
num_envs = 100
spec = BlackjackEnvSpec(BlackjackEnvSpec.gen_config(num_envs=num_envs))
env = BlackjackGymEnvPool(spec)
assert isinstance(env.observation_space, gym.spaces.Box)
assert env.observation_space.shape == (3,)
assert isinstance(env.action_space, gym.spaces.Discrete)
assert env.action_space.n == 2
reward, rewards = np.zeros(num_envs), []
for _ in range(1000):
obs, rew, done, info = env.step(np.random.randint(2, size=(num_envs,)))
reward += rew
if np.any(done):
rewards += reward[done].tolist()
reward[done] = 0
assert abs(np.mean(rewards) + 0.395) < 0.05
assert abs(np.std(rewards) - 0.89) < 0.05


if __name__ == "__main__":
absltest.main()

0 comments on commit 90c265f

Please sign in to comment.