# Chapter 01: Tic Tac Toe
```julia
using ReinforcementLearning, ReinforcementLearningEnvironments, RLIntro
using RLIntro.TicTacToe

env = TicTacToeEnv()
nstates, nactions = length(observation_space(env)), length(action_space(env))
```
If you are curious why there are 5478 states, you may see the discussions here; the short sketch below also confirms the count by brute-force enumeration.
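As a sanity check, here is a small self-contained sketch (independent of the packages above; all names in it are ad hoc) that counts the legal board positions by enumerating every game from the empty board. Play stops as soon as one side completes a line, so positions beyond a win are never generated:

```julia
# Count reachable tic-tac-toe positions by brute force.
# A board is a length-9 tuple: 0 = empty, 1 = X, 2 = O; X moves first.
function line_winner(b)
    lines = ((1,2,3), (4,5,6), (7,8,9),
             (1,4,7), (2,5,8), (3,6,9),
             (1,5,9), (3,5,7))
    for (i, j, k) in lines
        b[i] != 0 && b[i] == b[j] == b[k] && return b[i]
    end
    0
end

function count_states()
    seen = Set{NTuple{9,Int}}()
    function visit(b, player)
        b in seen && return           # already counted via another move order
        push!(seen, b)
        (line_winner(b) != 0 || all(!=(0), b)) && return  # terminal: stop here
        for i in 1:9
            b[i] == 0 && visit(Base.setindex(b, player, i), 3 - player)
        end
    end
    visit(ntuple(_ -> 0, 9), 1)
    length(seen)
end

count_states()  # should print 5478
```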
```julia
observe(env)
```
Now we'll use a Monte Carlo based method to estimate the value of each state for each player. The idea is simple: if we have an accurate value estimate for every state reachable from the current observation, then we can just choose the action that leads to the state with the maximum estimated value, as the toy snippet below illustrates.
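As a toy illustration of that greedy rule (the names `greedy_action` and `next_state` are hypothetical, not part of the packages used here): given a value table `V` and a function mapping a state and an action to the successor state, the choice is a single `argmax` over successor values:

```julia
# Hypothetical sketch of greedy action selection over successor values.
greedy_action(V, s, legal_actions, next_state) =
    argmax(a -> V[next_state(s, a)], legal_actions)
```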
Let's create a value approximator first (here we use the TabularVApproximator defined in ReinforcementLearning.jl):
```julia
V1 = TabularVApproximator(nstates)
V2 = TabularVApproximator(nstates)
```
As you can see, by default all the estimations are initialized with 0.0. Usually that won't be a problem, but here we can initialize them with a better starting point: for each state, we check whether it is a terminal state and set the initial estimation accordingly (1.0 for a win, 0.0 for a loss, and 0.5 for a draw or a non-terminal state).
```julia
function init_V!(V, role)
    for i in 1:length(V.table)
        s = TicTacToe.ID2STATE[i]
        isdone, winner = TicTacToe.STATES_INFO[s]
        if isdone
            if winner === nothing
                V.table[i] = 0.5   # draw
            elseif winner === role
                V.table[i] = 1.    # win
            else
                V.table[i] = 0.    # loss
            end
        else
            V.table[i] = 0.5       # non-terminal: neutral initial guess
        end
    end
    V
end
```
```julia
init_V!(V1, TicTacToe.offensive)
init_V!(V2, TicTacToe.defensive)
```
Then we construct a MonteCarloLearner for each player. Here the MonteCarloLearner is just a wrapper around the approximator: at the end of each episode it moves the estimate of every visited state a fraction α toward the episode's return (see the sketch after the next code block).
```julia
learner_1 = MonteCarloLearner(;approximator=V1, α=0.1, kind=EVERY_VISIT)
learner_2 = MonteCarloLearner(;approximator=V2, α=0.1, kind=EVERY_VISIT)
```
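The constant-α, every-visit update such a learner applies at the end of an episode can be sketched like this (a simplified illustration, not the library's internal code):

```julia
# Every-visit Monte Carlo with a constant step size: after an episode
# ends with return G, move each visited state's estimate toward G.
function mc_update!(table, visited_state_ids, G, α)
    for s in visited_state_ids
        table[s] += α * (G - table[s])
    end
    table
end
```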
Finally, we will create an Agent for each player. To create such an agent, we need to provide a learner and a policy. We already have the learners above. Now let's create a policy.
A policy is a mapping from states to actions. Since we already have the value estimations of states, a simple policy is to check the estimations of the possible next states and select an action that leads to the best one. To keep exploring, the policy below also picks a random legal action with a small probability ϵ.
```julia
function create_policy(V, role, ϵ=0.01)
    obs -> begin
        legal_actions, state = findall(get_legal_actions(obs)), get_state(obs)
        next_states = TicTacToe.get_next_states(TicTacToe.ID2STATE[state], role, legal_actions)
        next_state_estimations = [V(TicTacToe.STATE2ID[ns]) for ns in next_states]
        max_val, inds = findallmax(next_state_estimations)
        # ϵ-greedy: mostly pick (a random one of) the best actions,
        # occasionally pick a random legal action
        rand() < ϵ ? rand(legal_actions) : legal_actions[rand(inds)]
    end
end
```
```julia
π_1 = create_policy(V1, TicTacToe.offensive)
π_2 = create_policy(V2, TicTacToe.defensive)
```
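If you want a quick sanity check (this snippet is an assumption about the interface, based on how `observe` is used later in this chapter), the policy can be called directly on an observation and returns an action index:

```julia
reset!(env)
π_1(observe(env, TicTacToe.offensive))  # some action index in 1:9
```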
```julia
agent_1 = Agent(
    VBasedPolicy(learner_1, π_1),
    episode_RTSA_buffer();
    role=TicTacToe.offensive);
agent_2 = Agent(
    VBasedPolicy(learner_2, π_2),
    episode_RTSA_buffer();
    role=TicTacToe.defensive);
```
Now let the two agents learn by playing against each other for one million steps:

```julia
run((agent_1, agent_2), env, StopAfterStep(1000000; is_show_progress=false))
```
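Under the hood, `run` drives a self-play loop in which the two agents alternate moves and each finished episode triggers the Monte Carlo updates. A heavily simplified, hypothetical sketch of its shape:

```julia
# Hypothetical, heavily simplified self-play loop; the real `run`
# also manages the buffers and triggers learning at episode ends.
function selfplay_sketch!(agents, env, nsteps)
    reset!(env)
    for _ in 1:nsteps
        for agent in agents
            obs = observe(env, agent.role)
            get_terminal(obs) && (reset!(env); break)
            env(agent(obs))  # agents are callable: observation -> action
        end
    end
end
```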
After training, we create two greedy evaluation agents (with ϵ = 0.0) that reuse the trained learners:

```julia
eval_agent_1 = Agent(
    VBasedPolicy(agent_1.π.learner, create_policy(agent_1.π.learner, agent_1.role, 0.0)),
    similar(agent_1.buffer);
    role=agent_1.role);
eval_agent_2 = Agent(
    VBasedPolicy(agent_2.π.learner, create_policy(agent_2.π.learner, agent_2.role, 0.0)),
    similar(agent_2.buffer);
    role=agent_2.role);
```
Let's check the estimated values of the states reachable from the empty board:

```julia
reset!(env)
[eval_agent_1.π.learner.approximator(TicTacToe.STATE2ID[s]) for s in TicTacToe.get_next_states(env.board, eval_agent_1.role)]
```
Now it's your turn to play this game!
```julia
function read_action_from_stdin()
    print("Your input:")
    input = parse(Int, readline())
    !in(input, 1:9) && error("invalid input!")
    input
end

function play()
    env = TicTacToeEnv()
    println("""You play first!
             1 4 7
             2 5 8
             3 6 9""")
    while true
        # the human plays the offensive role
        action = read_action_from_stdin()
        env(action)
        println(env)
        obs = observe(env, TicTacToe.offensive)
        if get_terminal(obs)
            if get_reward(obs) == 0.5
                println("Tie!")
            elseif get_reward(obs) == 1.0
                println("You win!")
            else
                println("Invalid input!")
            end
            break
        end

        # the trained agent plays the defensive role
        env(eval_agent_2(observe(env)))
        println(env)
        obs = observe(env, TicTacToe.defensive)
        if get_terminal(obs)
            if get_reward(obs) == 0.5
                println("Tie!")
            elseif get_reward(obs) == 1.0
                println("You lose!")
            else
                println("You win!")
            end
            break
        end
    end
end
```
```julia
# Uncomment the following line to play against the trained agent:
# play()
```