Jun / Sep 25 2019
Chapter 12 Random Walk
using ReinforcementLearning, ReinforcementLearningEnvironments, RLIntro.RandomWalk
using StatsBase, Plots
const N = 21
const true_values = -1:0.1:1
-1.0:0.1:1.0
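The true values lie on a straight line from -1 to 1 across the 21 states. As a sanity check (not part of the original notebook), they can be recovered for the 19 non-terminal states by solving the Bellman system of this chain directly; the sketch below assumes the standard random-walk dynamics (step left or right with probability 0.5, reward -1 for exiting left and +1 for exiting right):

using LinearAlgebra

# v(i) = 0.5*v(i-1) + 0.5*v(i+1) for the 19 non-terminal states,
# with the two exits contributing the terminal rewards to the constant term
P = zeros(19, 19)
r = zeros(19)
for i in 1:19
    i > 1  ? (P[i, i-1] = 0.5) : (r[i] -= 0.5)  # left exit: reward -1 w.p. 0.5
    i < 19 ? (P[i, i+1] = 0.5) : (r[i] += 0.5)  # right exit: reward +1 w.p. 0.5
end
(I - P) \ r  # ≈ collect(-0.9:0.1:0.9), i.e. true_values[2:end-1]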
Base.@kwdef struct RecordRMS <: AbstractHook
    rms::Vector{Float64} = []
end

# after each episode, record the RMS error between the estimated and true
# values of the 19 non-terminal states
(h::RecordRMS)(::PostEpisodeStage, agent, env, obs) =
    push!(h.rms, sqrt(mean((agent.π.learner.approximator.table[2:end-1] - true_values[2:end-1]).^2)))
function create_agent_env(α, λ)
    env = RandomWalkEnv(N=N)
    ns, na = length(observation_space(env)), length(action_space(env))
    agent = Agent(
        π=VBasedPolicy(
            learner=TDλReturnLearner(
                approximator=TabularVApproximator(zeros(ns)),
                γ=1.0,
                α=α,
                λ=λ
            ),
            f = obs -> rand(1:na)  # actions are chosen uniformly at random
        ),
        buffer=episode_RTSA_buffer()
    )
    agent, env
end

function records(α, λ, nruns=10)
    rms = []
    for _ in 1:nruns
        hook = RecordRMS()
        run(create_agent_env(α, λ)..., StopAfterEpisode(10, is_show_progress=false); hook=hook)
        push!(rms, mean(hook.rms))  # average RMS error over the 10 episodes of this run
    end
    mean(rms)  # averaged again over nruns independent runs
end
records (generic function with 2 methods)
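For context, TDλReturnLearner is an offline λ-return method: at the end of each episode, every visited state is updated toward its λ-return. For an episode terminating at time T, the λ-return (Sutton & Barto, 2nd ed., Eq. 12.3) is

$$G_t^\lambda = (1-\lambda) \sum_{n=1}^{T-t-1} \lambda^{n-1} G_{t:t+n} + \lambda^{T-t-1} G_t,$$

where $G_{t:t+n}$ is the n-step return. At λ = 0 this reduces to the one-step TD target; at λ = 1 it is the full Monte Carlo return.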
# step-size ranges shrink as λ grows, since larger λ tolerates only smaller α
As = [0:0.1:1, 0:0.1:1, 0:0.1:1, 0:0.1:1, 0:0.1:1, 0:0.05:0.5, 0:0.02:0.2, 0:0.01:0.1]
Λ = [0.0, 0.4, 0.8, 0.9, 0.95, 0.975, 0.99, 1.0]
p = plot(legend=:topright)
for (A, λ) in zip(As, Λ)
    plot!(p, A, [records(α, λ) for α in A], label="lambda = $λ")
end
ylims!(p, (0.25, 0.55))
p
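This corresponds to Figure 12.3 of the book: the average RMS error over the first 10 episodes, as a function of the step size α, with one curve per λ. As a hypothetical follow-up (not in the original notebook), the best step size for each λ can be read off programmatically; note that records is stochastic, so the result varies between runs:

# pick the α minimizing the average RMS error for each λ
best_α = [(λ = λ, α = A[argmin([records(α, λ) for α in A])]) for (A, λ) in zip(As, Λ)]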