mirror of
https://github.com/zhigang1992/AI_for_video_games_demo.git
synced 2026-01-12 08:53:51 +08:00
78 lines
2.5 KiB
Python
78 lines
2.5 KiB
Python
"""
Solving FrozenLake8x8 environment using Policy iteration.

Author : Moustafa Alzantot (malzantot@ucla.edu)
"""
|
|
import numpy as np
|
|
import gym
|
|
from gym import wrappers
|
|
|
|
|
|
def run_episode(env, policy, gamma = 1.0, render = False):
    """ Play a single episode under ``policy`` and return its discounted reward.

    Args:
        env: environment exposing ``reset()``, ``step(action)`` and, when
            ``render`` is true, ``render()``. ``step`` must return a 4-tuple
            whose first three items are the next observation, the reward and
            the episode-finished flag.
        policy: indexable by observation; each entry is converted with
            ``int()`` before being passed to ``env.step``.
        gamma: discount factor; the reward at step ``t`` is weighted by
            ``gamma ** t``.
        render: when true, ``env.render()`` is called before every step.

    Returns:
        The sum of the discounted rewards collected until the episode ends.
    """
    state = env.reset()
    episode_return = 0
    t = 0
    finished = False
    while not finished:
        if render:
            env.render()
        action = int(policy[state])
        state, reward, finished, _ = env.step(action)
        episode_return += (gamma ** t) * reward
        t += 1
    return episode_return
|
|
|
|
|
|
def evaluate_policy(env, policy, gamma = 1.0, n = 100):
    """ Estimate how good ``policy`` is by averaging over ``n`` episodes.

    Each episode is played with ``run_episode`` (rendering disabled) using the
    given discount factor ``gamma``.

    Returns:
        The mean (``np.mean``) of the ``n`` episode returns.
    """
    return np.mean([
        run_episode(env, policy, gamma=gamma, render=False)
        for _episode in range(n)
    ])
|
|
|
|
def extract_policy(v, gamma = 1.0, env = None):
    """ Extract the greedy policy for a given value-function.

    For every state ``s`` and action ``a`` the action value is computed as
    ``sum(p * (r + gamma * v[s_]))`` over the entries ``(p, s_, r, _)`` of
    ``env.P[s][a]`` (the fourth field is ignored), and the action with the
    largest value is selected (ties resolve to the lowest action index, as
    with ``np.argmax``).

    Args:
        v: value function, indexable by state.
        gamma: discount factor applied to the successor state's value.
        env: environment exposing ``nS``, ``nA`` and ``P``. Optional for
            backward compatibility: when omitted, the module-level ``env``
            is used, which is what this function always did before.

    Returns:
        Integer array of length ``env.nS`` holding the chosen action index
        for each state.

    Raises:
        NameError: if ``env`` is omitted and no module-level ``env`` exists.
    """
    if env is None:
        # Older call sites pass only (v, gamma) and rely on a global `env`.
        if 'env' not in globals():
            raise NameError("name 'env' is not defined")
        env = globals()['env']
    # Actions are indices, so store them as integers rather than floats.
    policy = np.zeros(env.nS, dtype=int)
    for s in range(env.nS):
        q_sa = np.zeros(env.nA)
        for a in range(env.nA):
            q_sa[a] = sum(p * (r + gamma * v[s_]) for p, s_, r, _ in env.P[s][a])
        policy[s] = np.argmax(q_sa)
    return policy
|
|
|
|
def compute_policy_v(env, policy, gamma=1.0):
    """ Iteratively evaluate the value-function under policy.

    Starting from an all-zero value function, every sweep recomputes each
    state's value from the previous sweep's values as
    ``sum(p * (r + gamma * prev_v[s_]))`` over the entries ``(p, s_, r, _)``
    of ``env.P[s][policy[s]]``. Sweeps stop once the summed absolute change
    between two consecutive sweeps is at most ``1e-10``.

    Alternatively, we could formulate a set of linear equations in terms of
    v[s] and solve them to find the value function.

    Args:
        env: environment exposing ``nS`` and the transition table ``P``.
        policy: indexable by state; each entry is the action taken there.
        gamma: discount factor.

    Returns:
        Array of length ``env.nS`` with the value of each state.
    """
    v = np.zeros(env.nS)
    eps = 1e-10
    # NOTE: there is no iteration cap; termination relies on the sweeps
    # converging for the given env/policy/gamma.
    while True:
        prev_v = np.copy(v)
        for s in range(env.nS):
            # Policy entries may arrive as floats (e.g. from a float-typed
            # array); cast so the lookup also works when env.P[s] is a
            # sequence rather than a dict.
            policy_a = int(policy[s])
            v[s] = sum(p * (r + gamma * prev_v[s_]) for p, s_, r, _ in env.P[s][policy_a])
        if np.sum(np.fabs(prev_v - v)) <= eps:
            # value converged
            break
    return v
|
|
|
|
def policy_iteration(env, gamma = 1.0):
    """ Policy-Iteration algorithm.

    Starts from a random policy and alternates between evaluating the
    current policy (``compute_policy_v``) and deriving a new policy from
    that value function (``extract_policy``). The loop stops as soon as the
    new policy equals the current one, or after 200000 iterations.

    Args:
        env: environment exposing ``nS`` and ``nA`` (plus whatever
            ``compute_policy_v`` needs).
        gamma: discount factor used for both evaluation and improvement.
            Previously this argument was overwritten with 1.0 inside the
            function; it is now honoured.

    Returns:
        The final policy: an array with one action per state.
    """
    policy = np.random.choice(env.nA, size=(env.nS))  # initialize a random policy
    max_iterations = 200000
    for i in range(max_iterations):
        old_policy_v = compute_policy_v(env, policy, gamma)
        # NOTE: extract_policy is not handed `env` here; it looks the
        # environment up on its own.
        new_policy = extract_policy(old_policy_v, gamma)
        if (np.all(policy == new_policy)):
            print ('Policy-Iteration converged at step %d.' %(i+1))
            break
        policy = new_policy
    return policy
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: solve FrozenLake8x8 with policy iteration and
    # report the average undiscounted score of the resulting policy.
    # `env` has to stay a module-level name: extract_policy() reads it as a
    # global when no environment is passed to it.
    env = gym.make('FrozenLake8x8-v0').unwrapped
    best_policy = policy_iteration(env, gamma=1.0)
    mean_score = evaluate_policy(env, best_policy, gamma=1.0)
    # evaluate_policy already returns a mean; np.mean of that scalar is a no-op.
    print('Average scores = ', np.mean(mean_score))
|