# Jun / Sep 25 2019
# Chapter 06: Random Walk (Sutton & Barto, Example 6.2 / Figure 6.2)
using ReinforcementLearning, ReinforcementLearningEnvironments
using RLIntro, RLIntro.RandomWalk
using StatsBase
using Plots
# Analytic state values of the 5 non-terminal random-walk states: k/6 for k = 1..5.
true_values = collect(1:5) ./ 6
# 5-element Array{Float64,1}:
#  0.166667
#  0.333333
#  0.5
#  0.666667
#  0.833333
# Random-walk environment with reward 0 at the left terminal state
# (right terminal gives +1). The two statements were fused on one line
# with no separator, which is a syntax error — split onto separate lines.
env = RandomWalkEnv(leftreward=0.)
# 7 states (5 interior + 2 terminal) and 2 actions (left/right).
n_states, n_actions = length(observation_space(env)), length(action_space(env))
# (7, 2)
"""
    RecordRMS()

Hook that, after every episode, records the root-mean-square error between
the agent's current value estimates for the 5 interior states and
`true_values`.
"""
struct RecordRMS <: AbstractHook
    rms::Vector{Float64}
    RecordRMS() = new(Float64[])
end

function (hook::RecordRMS)(::PostEpisodeStage, agent, env, obs)
    # Indices 2:end-1 drop the two terminal states of the tabular approximator.
    estimates = agent.π.learner.approximator.table[2:end - 1]
    push!(hook.rms, sqrt(mean((estimates - true_values) .^ 2)))
end
"""
    create_TD_agent(α)

Build a TD(0) prediction agent with step size `α`: a tabular value
approximator initialized to 0.5 for every state, driven by a uniformly
random behavior policy.
"""
function create_TD_agent(α)
    learner = TDLearner(
        approximator = TabularVApproximator(fill(0.5, n_states)),
        optimizer = Descent(α),
        method = :SRS,
    )
    policy = VBasedPolicy(
        learner = learner,
        f = TabularRandomPolicy(fill(1 / n_actions, n_states, n_actions)),
    )
    return Agent(policy, episode_RTSA_buffer())
end

"""
    create_MC_agent(α)

Build an every-visit constant-α Monte Carlo prediction agent with the same
initialization and random behavior policy as [`create_TD_agent`](@ref).
"""
function create_MC_agent(α)
    learner = MonteCarloLearner(
        approximator = TabularVApproximator(fill(0.5, n_states)),
        α = α,
        kind = EVERY_VISIT,
    )
    policy = VBasedPolicy(
        learner = learner,
        f = TabularRandomPolicy(fill(1 / n_actions, n_states, n_actions)),
    )
    return Agent(policy, episode_RTSA_buffer())
end
# create_MC_agent (generic function with 1 method)
# Left panel of Figure 6.2: TD(0) value estimates (α = 0.1) after 1, 10
# and 100 episodes, compared against the analytic true values.
p = plot(; legend = :bottomright)
for n_episodes in [1, 10, 100]
    agent = create_TD_agent(0.1)
    run(agent, env, StopAfterEpisode(n_episodes))
    # Drop the terminal states when plotting the learned values.
    plot!(p, agent.π.learner.approximator.table[2:end - 1],
          label = "episode = $n_episodes")
end
plot!(p, true_values, label = "true value")
p
# Right panel of Figure 6.2: RMS error per episode, averaged over 100
# independent runs, for several TD and MC step sizes.
# Fix: the accumulators were untyped `[]` (Vector{Any}); use a concrete
# element type so the container is type-stable.
p = plot()
for α in [0.05, 0.1, 0.15]
    rms = Vector{Vector{Float64}}()
    for _ in 1:100
        agent = create_TD_agent(α)
        hook = RecordRMS()
        run(agent, env, StopAfterEpisode(100); hook = hook)
        push!(rms, hook.rms)
    end
    # `mean` of a vector of equal-length curves gives the element-wise mean curve.
    plot!(p, mean(rms), label = "TD alpha=$α", linestyle = :dashdot)
end
for α in [0.01, 0.02, 0.03, 0.04]
    rms = Vector{Vector{Float64}}()
    for _ in 1:100
        agent = create_MC_agent(α)
        hook = RecordRMS()
        run(agent, env, StopAfterEpisode(100); hook = hook)
        push!(rms, hook.rms)
    end
    plot!(p, mean(rms), label = "MC alpha=$α")
end
p
# Head-to-head comparison of TD(0) and MC at the same step size α = 0.1,
# RMS error averaged over 100 runs of 100 episodes each.
# Fix: the accumulators were untyped `[]` (Vector{Any}); use a concrete
# element type so the container is type-stable.
p = plot()
rms = Vector{Vector{Float64}}()
for _ in 1:100
    agent = create_TD_agent(0.1)
    hook = RecordRMS()
    run(agent, env, StopAfterEpisode(100); hook = hook)
    push!(rms, hook.rms)
end
plot!(p, mean(rms), label = "TD alpha=0.1", linestyle = :dashdot)
rms = Vector{Vector{Float64}}()
for _ in 1:100
    agent = create_MC_agent(0.1)
    hook = RecordRMS()
    run(agent, env, StopAfterEpisode(100); hook = hook)
    push!(rms, hook.rms)
end
plot!(p, mean(rms), label = "MC alpha=0.1")
p