Skip to content

Commit

Permalink
live view into game with pygame. expanded frames, and click for pause…
Browse files Browse the repository at this point in the history
… in script
  • Loading branch information
syrkis committed Jun 23, 2024
1 parent d9e3a4d commit e9239b4
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 45 deletions.
24 changes: 7 additions & 17 deletions ludens.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,9 @@
"outputs": [],
"source": [
"# types\n",
"Seed = jnp.ndarray\n",
"State = jaxmarl.environments.smax.smax_env.State\n",
"DictType = Dict[str, jnp.ndarray]\n",
"Obs = Reward = Done = Action = DictType\n",
"StateSeq = List[Tuple[Seed, State, Action]]"
"Obs = Reward = Done = Action = Dict[str, jnp.ndarray]\n",
"StateSeq = List[Tuple[jnp.ndarray, State, Action]]"
]
},
{
Expand Down Expand Up @@ -154,20 +152,12 @@
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"ename": "SyntaxError",
"evalue": "invalid syntax. Perhaps you forgot a comma? (3167373708.py, line 12)",
"output_type": "error",
"traceback": [
"\u001b[0;36m Cell \u001b[0;32mIn[21], line 12\u001b[0;36m\u001b[0m\n\u001b[0;31m clock=pygame.time.Clock()\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax. Perhaps you forgot a comma?\n"
]
}
],
"outputs": [],
"source": [
"# pygame setup\n",
"pygame.init()\n",
"screen = pygame.display.set_mode((720, 720))\n",
"# enable retina display\n",
"screen = pygame.display.set_mode((1000, 1000))\n",
"render = partial(render_fn, screen)\n",
"rng, key = random.split(random.PRNGKey(0))\n",
"obs, state = env.reset(key)\n",
Expand All @@ -176,7 +166,7 @@
" env=env,\n",
" rng=rng,\n",
" state_seq=[], # [(key, state, action)]\n",
" clock=pygame.time.Clock()\n",
" clock=pygame.time.Clock(),\n",
" state=state,\n",
" obs=obs,\n",
")\n",
Expand All @@ -185,7 +175,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
Expand Down
137 changes: 109 additions & 28 deletions parabellum/run.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,125 @@
# run.py
# parabellum run functions for running interactive environments
# parabellum run game live
# by: Noah Syrkis

# imports
import jax.numpy as jnp
import jax
import time
# Noah Syrkis
import pygame
from jax import random
from functools import partial
import darkdetect
import jax.numpy as jnp
from chex import dataclass
import jaxmarl
from typing import Tuple, List, Dict, Optional
import parabellum as pb


def plot_frame(env, screen, state):
positions = state.unit_positions / env.map_width * 640
for position in positions:
pygame.draw.circle(screen, (255, 0, 0), position.tolist(), 5)
# constants
fg = (255, 255, 255) if darkdetect.isDark() else (0, 0, 0)
bg = (0, 0, 0) if darkdetect.isDark() else (255, 255, 255)


# functions
def run_fn():
from parabellum import Parabellum, scenarios
# types
State = jaxmarl.environments.smax.smax_env.State
Obs = Reward = Done = Action = Dict[str, jnp.ndarray]
StateSeq = List[Tuple[jnp.ndarray, State, Action]]

scenario = scenarios["default"]
env = Parabellum(scenario=scenario, map_width=32, map_height=32)
rng, key = jax.random.split(jax.random.PRNGKey(0))
obs, state = env.reset(key)
pygame.init()

screen = pygame.display.set_mode((640, 640))
@dataclass
class Control:
running: bool = True
paused: bool = False
click: Optional[Tuple[int, int]] = None


@dataclass
class Game:
clock: pygame.time.Clock
state: State
obs: Dict
state_seq: StateSeq
control: Control
env: pb.Parabellum
rng: random.PRNGKey


def handle_event(event, control_state):
"""Handle pygame events."""
if event.type == pygame.QUIT:
control_state.running = False
if event.type == pygame.MOUSEBUTTONDOWN:
pos = pygame.mouse.get_pos()
control_state.click = pos
if event.type == pygame.MOUSEBUTTONUP:
control_state.click = None
if event.type == pygame.KEYDOWN: # any key press pauses
control_state.paused = not control_state.paused
return control_state


for i in range(10):
# take random actions and show the environment
actions = {a: env.action_space(a).sample(rng) for a in env.agents}
obs, state, _, _, _ = env.step(rng, state, actions)
plot_frame(env, screen, state)
def control_fn(game):
"""Handle pygame events."""
for event in pygame.event.get():
game.control = handle_event(event, game.control)
return game


def render_fn(screen, game):
"""Render the game."""
if len(game.state_seq) < 3:
return game
for rng, state, action in env.expand_state_seq(game.state_seq[-2:])[-8:]:
screen.fill(bg)
if game.control.click is not None:
pygame.draw.circle(screen, "red", game.control.click, 10)
unit_positions = state.unit_positions
for pos in unit_positions:
pos = (pos / env.map_width * 800).tolist()
pygame.draw.circle(screen, fg, pos, 5)
pygame.display.flip()
pygame.time.wait(1000)
time.sleep(0.1)
game.clock.tick(24) # limits FPS to 24
return game

# exit loop
pygame.quit()

def step_fn(game):
"""Step in parabellum."""
rng, act_rng, step_key = random.split(game.rng, 3)
act_key = random.split(act_rng, env.num_agents)
action = {
a: env.action_space(a).sample(act_key[i]) for i, a in enumerate(env.agents)
}
state_seq_entry = (step_key, game.state, action)
# append state_seq_entry to state_seq
game.state_seq.append(state_seq_entry)
obs, state, reward, done, info = env.step(step_key, game.state, action)
game.state = state
game.obs = obs
game.rng = rng
return game


# state
if __name__ == "__main__":
run_fn()
env = pb.Parabellum(pb.scenarios["default"])
pygame.init()
screen = pygame.display.set_mode((1000, 1000))
render = partial(render_fn, screen)
rng, key = random.split(random.PRNGKey(0))
obs, state = env.reset(key)
kwargs = dict(
control=Control(),
env=env,
rng=rng,
state_seq=[], # [(key, state, action)]
clock=pygame.time.Clock(),
state=state,
obs=obs,
)
game = Game(**kwargs)

while game.control.running:
game = control_fn(game)
game = game if game.control.paused else step_fn(game)
game = game if game.control.paused else render(game)

pygame.quit()

0 comments on commit e9239b4

Please sign in to comment.