Jun / Sep 25 2019

Chapter 11 Baird's Counterexample
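This example reproduces Baird's counterexample from Chapter 11 of Sutton & Barto's Reinforcement Learning: An Introduction (2nd ed.): a seven-state, two-action MDP with zero reward everywhere, on which semi-gradient off-policy TD(0) with linear function approximation diverges. The weight vector is recorded after every step and its components are plotted at the end.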

using ReinforcementLearning, ReinforcementLearningEnvironments, RLIntro.BairdCounter
using Flux: Descent  # optimizer used by the TDLearner below (skip if already re-exported)
using Plots          # for plotting the recorded weights at the end
env = BairdCounterEnv()
# BairdCounterEnv(1, DiscreteSpace{Int64}(1, 7, 7), DiscreteSpace{Int64}(1, 2, 2))
# Hook that stores a snapshot of the approximator's weights after every step
Base.@kwdef struct RecordWeights <: AbstractHook
    weights::Vector{Vector{Float64}} = Vector{Vector{Float64}}()
end

(h::RecordWeights)(::PostActStage, agent, env, action_obs) =
    push!(h.weights, agent.π.π_target.learner.approximator.weights |> deepcopy)
# Preprocessor that maps a state index to its row in the feature matrix.
# (The default refers to the global `features` defined below; @kwdef defaults are
# evaluated at construction time, so `features` only has to exist when
# StateMapping() is actually called.)
Base.@kwdef struct StateMapping <: AbstractPreprocessor
    mapping::Array{Int,2}=features
end

# look up the feature vector of state s
(p::StateMapping)(s::Int) = @view(p.mapping[s, :])
env = BairdCounterEnv()
ns = length(observation_space(env))  # 7 states
na = length(action_space(env))       # 2 actions
# following the book's setup: 8 weights, all initialised to 1 except w₇, which starts at 10
init_weights = ones(Float64, 8)
init_weights[7] = 10

# feature matrix of Baird's counterexample:
# states 1–6 are represented as 2wᵢ + w₈, state 7 as w₇ + 2w₈
features = zeros(ns, length(init_weights))
for i in 1:6
    features[i, i] = 2
    features[i, 8] = 1
end
features[7, 7] = 1
features[7, 8] = 2
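As a quick sanity check (an addition, not part of the original script): the linear approximator's estimate of a state's value is the dot product of its feature row with the weight vector, so the initial values are 3 for states 1–6 and 12 for state 7.

features * init_weights   # -> [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 12.0]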

# behavior policy: action 1 with probability 6/7, action 2 with probability 1/7
π_b = obs -> rand() < 6/7 ? 1 : 2
# target policy: always action 2
π_t = obs -> 2
prob_b = [6/7, 1/7]
prob_t = [0., 1.]
# action probabilities of both policies, used by OffPolicy for importance sampling
# (RL here is the ReinforcementLearning module)
RL.get_prob(f::typeof(π_b), s, a) = prob_b[a]
RL.get_prob(f::typeof(π_t), s, a) = prob_t[a]
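Since the target policy always picks action 2 while the behavior policy picks it only 1/7 of the time, the importance-sampling ratio is 0 for action 1 and 7 for action 2 (a quick check added for illustration):

prob_t ./ prob_b   # -> [0.0, 7.0]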

agent = Agent(
    π=OffPolicy(
        # target policy: the policy whose value function is being learned
        π_target=VBasedPolicy(
            learner=TDLearner(
                approximator=LinearVApproximator(init_weights),  # v̂(s) = dot(x(s), w)
                γ=0.99,
                optimizer=Descent(0.01),
                n=0,  # one-step TD
                method=:SRS
            ),
            f=π_t
        ),
        # behavior policy: the policy that actually selects actions
        π_behavior=π_b
    ),
    buffer=episode_RTSA_buffer(state_eltype=Any)
)

# wrap the environment so raw state indices are mapped to feature vectors
env = WrappedEnv(
    env=BairdCounterEnv(),
    preprocessor=StateMapping()
)
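For illustration (not in the original script), the preprocessor simply looks up a state's feature row, so the agent always sees an 8-dimensional feature vector instead of the raw state index:

StateMapping()(7)   # -> the feature row of state 7: [0, 0, 0, 0, 0, 0, 1, 2]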
hook = RecordWeights()
# run for 1000 steps, recording the weights after each step
run(agent, env, StopAfterStep(1000); hook=hook)
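To see the divergence numerically before plotting (an extra inspection step, not part of the original script):

hook.weights[end]   # the weight vector after 1000 steps; it keeps diverging rather than converging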

# plot how each weight component evolves; they diverge instead of converging
p = plot(legend=:topleft)
for i in 1:length(init_weights)
    plot!(p, [w[i] for w in hook.weights], label="w$i")
end
p
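Each curve tracks one component of the weight vector over the 1000 steps. As discussed in Chapter 11, the weights diverge: combining function approximation, bootstrapping, and off-policy training (the "deadly triad") makes semi-gradient TD(0) unstable on this problem.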