-
Notifications
You must be signed in to change notification settings - Fork 28
/
Copy pathgridworld.jl
124 lines (93 loc) · 4.42 KB
/
gridworld.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Grid-world position: a 2-element static vector of integers `(x, y)`.
# The value `GWPos(-1,-1)` is used elsewhere in this file as the terminal state.
const GWPos = SVector{2,Int}
"""
    SimpleGridWorld(; kwargs...)

Create a simple grid world MDP. Options are specified with keyword arguments.

# States and Actions
The states are represented by 2-element static vectors of integers. Typically any Julia `AbstractVector`, e.g. `[x,y]`, can also be used for arguments. Actions are the symbols `:up`, `:left`, `:down`, and `:right`.

# Keyword Arguments
- `size::Tuple{Int, Int}`: Number of cells in the x and y direction [default: `(10,10)`]
- `rewards::Dict`: Dictionary mapping cells to the reward in that cell, e.g. `Dict([1,2]=>10.0)`. The reward for unlisted cells is 0.0. [default: `Dict(GWPos(4,3)=>-10.0, GWPos(4,6)=>-5.0, GWPos(9,3)=>10.0, GWPos(8,8)=>3.0)`]
- `terminate_from::Set`: Set of cells from which the problem will terminate. Note that these states are not themselves terminal, but from these states, the next transition will be to a terminal state. [default: `Set(keys(rewards))`]
- `tprob::Float64`: Probability of a successful transition in the direction specified by the action. The remaining probability is divided equally between the other neighbors. [default: `0.7`]
- `discount::Float64`: Discount factor [default: `0.95`]
"""
Base.@kwdef struct SimpleGridWorld <: MDP{GWPos, Symbol}
    # Number of cells along x and y.
    size::Tuple{Int, Int} = (10, 10)

    # Reward attached to individual cells; cells not listed here yield 0.0.
    rewards::Dict{GWPos, Float64} = Dict(
        GWPos(4, 3) => -10.0,
        GWPos(4, 6) => -5.0,
        GWPos(9, 3) => 10.0,
        GWPos(8, 8) => 3.0,
    )

    # Cells whose next transition leads to the terminal state.
    # Defaults to every cell that carries a reward.
    terminate_from::Set{GWPos} = Set(keys(rewards))

    # Probability that a move goes in the direction requested by the action.
    tprob::Float64 = 0.7

    # Discount factor.
    discount::Float64 = 0.95
end
# States
# Enumerate the state space: every grid cell (x varying fastest), followed by
# the terminal state `GWPos(-1,-1)` as the last element.
function POMDPs.states(mdp::SimpleGridWorld)
    nx, ny = mdp.size
    grid = GWPos[GWPos(x, y) for x in 1:nx, y in 1:ny]
    cells = vec(grid)
    return push!(cells, GWPos(-1, -1))
end
# Map a state to its integer index.
# States with all-positive coordinates are indexed by their linear position in
# a grid of `mdp.size`; any other state maps to the slot just past the grid
# (`prod(mdp.size) + 1`).
function POMDPs.stateindex(mdp::SimpleGridWorld, s::AbstractVector{Int})
    if !all(c -> c > 0, s)
        return prod(mdp.size) + 1 # TODO: Change
    end
    return LinearIndices(mdp.size)[s...]
end
"""
    GWUniform(size)

Distribution object over the cells of a grid with `size = (nx, ny)` cells.
It is the type returned by `POMDPs.initialstate` for a `SimpleGridWorld`.
"""
struct GWUniform
    size::Tuple{Int, Int}
end
# Sample a grid cell from `d`: x is drawn first, then y, each from its own
# range `1:size[i]`.
function Base.rand(rng::AbstractRNG, d::GWUniform)
    nx, ny = d.size
    x = rand(rng, 1:nx)
    y = rand(rng, 1:ny)
    return GWPos(x, y)
end
# Probability mass that the uniform grid distribution `d` assigns to cell `s`.
#
# Returns `1/prod(d.size)` when `s` lies inside the grid and `0.0` otherwise
# (e.g. for the terminal state `GWPos(-1,-1)`).
#
# Each coordinate is checked against its own dimension. The previous check
# compared only `s[1]` against both dimensions and ignored `s[2]`, which gave
# wrong answers on non-square grids and for cells with an out-of-range y.
function POMDPs.pdf(d::GWUniform, s::GWPos)
    nx, ny = d.size
    if 1 <= s[1] <= nx && 1 <= s[2] <= ny
        return 1/prod(d.size)
    else
        return 0.0
    end
end
# Lazily iterate over every cell of the grid described by `d`.
function POMDPs.support(d::GWUniform)
    nx, ny = d.size
    return (GWPos(x, y) for x in 1:nx, y in 1:ny)
end
# The initial state distribution is a `GWUniform` over the model's grid.
function POMDPs.initialstate(mdp::SimpleGridWorld)
    return GWUniform(mdp.size)
end
# Actions
# The four movement actions, as a tuple of symbols.
function POMDPs.actions(mdp::SimpleGridWorld)
    return (:up, :down, :left, :right)
end
# Pick one element of a tuple of symbols uniformly at random.
# NOTE(review): this extends `Base.rand` for a type this file does not own
# (type piracy); the original author added it because sampling from a tuple
# did not work out of the box.
function Base.rand(rng::AbstractRNG, t::NTuple{L,Symbol}) where L
    idx = rand(rng, 1:L)
    return t[idx]
end
# Displacement added to a position for each action: `:up`/`:down` change y by
# +1/-1, `:left`/`:right` change x by -1/+1.
const dir = Dict(:up=>GWPos(0,1), :down=>GWPos(0,-1), :left=>GWPos(-1,0), :right=>GWPos(1,0))
# Integer index of each action symbol (used by `actionindex` and `convert_a`).
const aind = Dict(:up=>1, :down=>2, :left=>3, :right=>4)
# Look up the integer index of action `a` in the `aind` table.
function POMDPs.actionindex(mdp::SimpleGridWorld, a::Symbol)
    return aind[a]
end
# Transitions
# A state is terminal when any of its coordinates is negative
# (e.g. the terminal state `GWPos(-1,-1)`).
function POMDPs.isterminal(m::SimpleGridWorld, s::AbstractVector{Int})
    return any(c -> c < 0, s)
end
# Transition distribution for taking action `a` in state `s`.
#
# - From a cell in `mdp.terminate_from`, or from a terminal state, the next
#   state is deterministically the terminal state `GWPos(-1,-1)`.
# - Otherwise the result is a `SparseCat` over five slots: slot 1 is "stay in
#   `s`", slots 2..5 are the neighbours in the order of `actions(mdp)`.
#   The neighbour in direction `a` gets probability `mdp.tprob`; the remaining
#   mass is split equally among the other directions. Mass for a neighbour
#   that lies outside the grid is moved to slot 1 (the agent bounces back).
function POMDPs.transition(mdp::SimpleGridWorld, s::AbstractVector{Int}, a::Symbol)
    absorbing = s in mdp.terminate_from || isterminal(mdp, s)
    absorbing && return Deterministic(GWPos(-1,-1))

    destinations = MVector{length(actions(mdp))+1, GWPos}(undef)
    probs = @MVector(zeros(length(actions(mdp))+1))
    destinations[1] = s # slot 1: remain in place

    for (i, act) in enumerate(actions(mdp))
        # intended direction gets tprob; every other direction shares the rest
        prob = act == a ? mdp.tprob : (1.0 - mdp.tprob)/(length(actions(mdp)) - 1)
        dest = s + dir[act]

        if inbounds(mdp, dest)
            destinations[i+1] = dest
            probs[i+1] += prob
        else
            # hit an edge and come back: the slot keeps probability zero, but
            # it must still hold a valid state
            destinations[i+1] = GWPos(-1, -1)
            probs[1] += prob
        end
    end

    next_states = convert(SVector, destinations)
    weights = convert(SVector, probs)
    return SparseCat(next_states, weights)
end
# Return `true` when `s` lies inside the grid, i.e. `1 <= s[i] <= m.size[i]`
# for both coordinates.
function inbounds(m::SimpleGridWorld, s::AbstractVector{Int})
    x_ok = 1 <= s[1] <= m.size[1]
    x_ok || return false
    return 1 <= s[2] <= m.size[2]
end
# Rewards
# Reward of cell `s`: the entry in `mdp.rewards`, or 0.0 for cells not listed.
function POMDPs.reward(mdp::SimpleGridWorld, s::AbstractVector{Int})
    return get(mdp.rewards, s, 0.0)
end
# The reward does not depend on the action; delegate to the state-only method.
function POMDPs.reward(mdp::SimpleGridWorld, s::AbstractVector{Int}, a::Symbol)
    return reward(mdp, s)
end
# Discount
# Discount factor stored on the model.
function POMDPs.discount(mdp::SimpleGridWorld)
    return mdp.discount
end
# Conversion
# Encode action `a` as a one-element array of type `V` holding its index
# from the `aind` table.
function POMDPs.convert_a(::Type{V}, a::Symbol, m::SimpleGridWorld) where {V<:AbstractArray}
    index = aind[a]
    return convert(V, [index])
end
# Decode an action from an array: the first element, converted to `Int`,
# is used as an index into `actions(m)`.
function POMDPs.convert_a(::Type{Symbol}, vec::V, m::SimpleGridWorld) where {V<:AbstractArray}
    index = convert(Int, first(vec))
    return actions(m)[index]
end