Jun / Sep 25 2019
Chapter05 Black Jack (Fig_5_3)
using ReinforcementLearning, ReinforcementLearningEnvironments
using RLIntro, RLIntro.BlackJack
# Blackjack environment pinned to a fixed pair of initial hands.
# NOTE(review): the first tuple entry is presumably the player's hand and the
# second the dealer's — confirm against the BlackJack.Hands definition.
env = BlackJackEnv(;
    init = (
        BlackJack.Hands(13, [1, 2], true),
        BlackJack.Hands(2, [2], false)
    )
)

# Sizes of the observation space and the action space of `env`.
ns, na = length.((observation_space(env), action_space(env)))
(220, 2)
# State reported by the environment's current observation; used below as the
# single state whose value estimate is tracked.
init_state = env |> observe |> get_state
27
# Reference value that the learner's estimate at `init_state` is compared
# against when the squared error is recorded.
# NOTE(review): presumably the true value of that state under the target
# policy — confirm the source of this number.
const GOLD_VAL = -0.27726
-0.27726
# Index of the `:stick` action inside BlackJack.ACTIONS. The trailing `[]`
# throws unless exactly one entry matches.
stick_action = findall(==(:stick), BlackJack.ACTIONS)[]

# Deterministic target policy as an action-index table shaped like
# BlackJack.INDS: action 1 everywhere, `stick_action` for indices 10:11 of the
# second dimension.
# NOTE(review): action 1 is presumably `:hit` and 10:11 presumably the player
# sums on which the policy sticks — confirm against BlackJack.ACTIONS / INDS.
table = fill(1, size(BlackJack.INDS)...)
table[:, 10:11, :] .= stick_action
# Flatten to one entry per table cell so it can be indexed by state.
table = vec(table);

# Behaviour policy: uniform over actions for every table entry. Using `1 / na`
# instead of a literal 0.5 keeps each row summing to one for any action count
# (for na == 2 the value is exactly 0.5, as before).
π_behavior = TabularRandomPolicy(fill(1 / na, length(table), na))
TabularRandomPolicy([0.5 0.5; 0.5 0.5; … ; 0.5 0.5; 0.5 0.5])
"""
    StoreMSE(mse::Vector{Float64})
    StoreMSE()

Hook that accumulates squared errors between `GOLD_VAL` and the target
policy's value estimate at `init_state`. One entry is appended to `mse` each
time the hook is called with a `PostEpisodeStage`.
"""
struct StoreMSE <: AbstractHook
    mse::Vector{Float64}
end

# Start from a concretely typed empty vector instead of converting an untyped `[]`.
StoreMSE() = StoreMSE(Float64[])

# Called with `PostEpisodeStage` (presumably once per finished episode by `run`):
# evaluate the target learner's approximator at `init_state` and record the
# squared deviation from `GOLD_VAL`. Returns the updated `mse` vector.
function (f::StoreMSE)(::PostEpisodeStage, agent, env, obs)
    estimate = agent.π.π_target.learner.approximator(init_state)
    return push!(f.mse, (GOLD_VAL - estimate)^2)
end
"""
    mse_of_ordinary_sampling(; n_episodes = 10000)

Train a fresh off-policy first-visit Monte Carlo agent that uses ordinary
importance sampling on the global `env`, and return the squared-error trace
collected by a `StoreMSE` hook.

# Keywords
- `n_episodes`: number of episodes passed to `StopAfterEpisode`
  (defaults to the previously hard-coded 10000).

# Returns
The hook's `mse` vector (`Vector{Float64}`), one entry per hook invocation.
"""
function mse_of_ordinary_sampling(; n_episodes = 10000)
    learner = MonteCarloLearner(
        approximator = TabularVApproximator(ns),
        kind = FIRST_VISIT,
        sampling = ORDINARY_IMPORTANCE_SAMPLING
    )
    # Target policy: value-based wrapper whose actions come from the global `table`.
    π_target = VBasedPolicy(
        learner = learner,
        f = TabularDeterministicPolicy(table = table, nactions = na)
    )
    agent = Agent(
        π = OffPolicy(π_target, π_behavior),
        buffer = episode_RTSA_buffer()
    )
    hook = StoreMSE()
    run(agent, env, StopAfterEpisode(n_episodes, is_show_progress = false); hook = hook)
    return hook.mse
end
mse_of_ordinary_sampling (generic function with 1 method)
"""
    mse_of_weighted_sampling(; n_episodes = 10000)

Train a fresh off-policy first-visit Monte Carlo agent that uses weighted
importance sampling on the global `env`, and return the squared-error trace
collected by a `StoreMSE` hook.

# Keywords
- `n_episodes`: number of episodes passed to `StopAfterEpisode`
  (defaults to the previously hard-coded 10000).

# Returns
The hook's `mse` vector (`Vector{Float64}`), one entry per hook invocation.
"""
function mse_of_weighted_sampling(; n_episodes = 10000)
    learner = MonteCarloLearner(
        approximator = TabularVApproximator(ns),
        kind = FIRST_VISIT,
        sampling = WEIGHTED_IMPORTANCE_SAMPLING,
        # A pair of running sums is supplied as the learner's `returns` state.
        returns = (CachedSum(), CachedSum())
    )
    # Target policy: value-based wrapper whose actions come from the global `table`.
    π_target = VBasedPolicy(
        learner = learner,
        f = TabularDeterministicPolicy(table = table, nactions = na)
    )
    agent = Agent(
        π = OffPolicy(π_target, π_behavior),
        buffer = episode_RTSA_buffer()
    )
    hook = StoreMSE()
    run(agent, env, StopAfterEpisode(n_episodes, is_show_progress = false); hook = hook)
    return hook.mse
end
mse_of_weighted_sampling (generic function with 1 method)
using Plots, StatsBase

# Average the squared-error curves over 100 calls of each estimator and draw
# both on one figure with a logarithmic episode axis.
plot(
    mean(mse_of_ordinary_sampling() for _ in 1:100);
    xscale = :log10,
    label = "Ordinary Importance Sampling"
)
plot!(
    mean(mse_of_weighted_sampling() for _ in 1:100);
    xscale = :log10,
    label = "Weighted Importance Sampling"
)