# %%
from time import sleep
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
torch.set_float32_matmul_precision('high')
torch.set_default_device("cuda:1")
random.seed(114514)
env = gym.make("CartPole-v1")
# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display
plt.ion()
# if gpu is to be used
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'term'))
class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
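# Illustrative usage sketch (these names are placeholders and are not used by the
# training loop below): the buffer stores whole Transition tuples and sampling
# draws a uniform mini-batch.
#   buf = ReplayMemory(1000)
#   buf.push(state, action, next_state, reward, terminated)
#   batch = buf.sample(32)   # -> list of 32 Transition namedtuples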
# class DQN(nn.Module):
#     def __init__(self, n_observations, n_actions):
#         super(DQN, self).__init__()
#         self.layer1 = nn.Linear(n_observations, 128)
#         self.layer2 = nn.Linear(128, 128)
#         self.layer3 = nn.Linear(128, n_actions)
#
#     # Called with either one element to determine next action, or a batch
#     # during optimization. Returns tensor([[left0exp,right0exp]...]).
#     def forward(self, x):
#         x = F.relu(self.layer1(x))
#         x = F.relu(self.layer2(x))
#         return self.layer3(x)
def soft_update(dst: nn.Module, src: nn.Module, tau):
    """Polyak-average src into dst: dst <- tau * src + (1 - tau) * dst."""
    dst_states = dst.state_dict()
    src_states = src.state_dict()
    new_states = {
        k: src_states[k] * tau + dst_states[k] * (1 - tau)
        for k in src_states
    }
    dst.load_state_dict(new_states)
def DQN(n_observations, n_actions):
    return nn.Sequential(
        nn.Linear(n_observations, 128),
        nn.ReLU(),
        nn.Linear(128, 128),
        nn.ReLU(),
        nn.Linear(128, n_actions),
    )
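# Sanity note: for CartPole-v1, n_observations is 4 and n_actions is 2, so
# DQN(4, 2)(torch.zeros(1, 4)) yields a (1, 2) tensor of Q-values, one per action.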
# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the AdamW optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4
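# Hedged sketch (eps_at is introduced here for illustration only and is not
# called by the training loop): the epsilon-greedy schedule that select_action()
# below applies, pulled out so the decay is easy to inspect in isolation.
def eps_at(step: int) -> float:
    """Exploration rate after `step` environment steps."""
    return EPS_END + (EPS_START - EPS_END) * math.exp(-step / EPS_DECAY)
# e.g. eps_at(0) ≈ 0.90, eps_at(1000) ≈ 0.36, eps_at(5000) ≈ 0.056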
# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)
policy_net = DQN(n_observations, n_actions)
target_net = DQN(n_observations, n_actions)
target_net.load_state_dict(policy_net.state_dict())
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
# optimizer = optim.SGD(policy_net.parameters(), lr=LR)
memory = ReplayMemory(10000)
steps_done = 0
class Break(Exception):
    """Debug helper: raising this stops a cell without printing a full traceback."""
    def _render_traceback_(self):
        print("Break")
def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], dtype=torch.long)
episode_durations = []
def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    # Move to CPU before converting for matplotlib (the default device is CUDA)
    plt.plot(durations_t.cpu().numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.cpu().numpy())
    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())
first_run = True
def seq_equal(seq1, seq2):
    if len(seq1) != len(seq2):
        return False
    for c1, c2 in zip(seq1, seq2):
        if c1 != c2:
            return False
    return True
def optimize_model():
    global first_run
    if len(memory) < BATCH_SIZE:
        return 0
    transitions = memory.sample(BATCH_SIZE)
    if False:
        memory.memory.clear()
        first_run = False
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))
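    # Illustrative example of the transpose (placeholder values):
    #   zip(*[Transition(s1, a1, ns1, r1, t1), Transition(s2, a2, ns2, r2, t2)])
    # gives (s1, s2), (a1, a2), ..., so batch.state is a tuple of all BATCH_SIZE
    # state tensors, ready for torch.cat below.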
    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(
        map(lambda s: s is not None, batch.next_state)),
        dtype=torch.bool)
    # non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    # print(f"{non_final_next_states.shape=}")
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    # print(batch.next_state)
    termflt = [x is None for x in batch.next_state]
    # termflt = batch.term
    # Consistency check: the stored `term` flags must match "next_state is None".
    if not seq_equal(termflt, batch.term):
        raise ValueError("Bad!")
    # print(env.observation_space.sample())
    # Substitute a dummy zero state for terminal transitions so torch.cat works.
    batch_next_state = [(x if x is not None else torch.zeros(1, n_observations))
                        for x in batch.next_state]
    # print(termflt)
    nxtstate_batch = torch.cat(batch_next_state)
    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)
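    # Shape check (assuming CartPole-v1): state_batch is (BATCH_SIZE, 4), the
    # network output is (BATCH_SIZE, n_actions), and gather(1, action_batch)
    # keeps only the Q-value of the action actually taken -> (BATCH_SIZE, 1).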
    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    # print(f"{ nxtstate_batch=}")
    # print(f"{ nxtstate_batch.shape=}")
    # print(f"{non_final_next_states=}")
    # next_state_values = torch.zeros(BATCH_SIZE, )
    # with torch.no_grad():
    #     next_state_values[non_final_mask] = target_net(
    #         non_final_next_states).max(1)[0]
    with torch.no_grad():
        next_state_values = target_net(nxtstate_batch).max(1)[0]
    # print(next_state_values)
    # Zero the bootstrapped value for terminal transitions. batch.term is a plain
    # tuple of Python bools, so convert it to a boolean tensor first; indexing
    # with the bare tuple would be treated as multi-dimensional indexing.
    next_state_values[torch.tensor(batch.term, dtype=torch.bool)] = 0
    # print(next_state_values)
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
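    # This is the one-step Bellman target r + GAMMA * max_a' Q_target(s', a');
    # terminal transitions contribute only their immediate reward because their
    # bootstrapped value was zeroed out above.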
    # Compute Huber loss
    # criterion = nn.SmoothL1Loss()
    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))
    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()
    return loss.item()
import plotly.graph_objects as go
fig_reward = go.FigureWidget(go.Scatter())
fig_loss = go.FigureWidget(go.Scatter())
from IPython import display
display.display(fig_reward)
display.display(fig_loss)
from pprint import pprint
ln_reward = fig_reward.data[0]
ln_loss = fig_loss.data[0]
loss_hst = []
if torch.cuda.is_available():
    num_episodes = 100
else:
    num_episodes = 500
from tqdm.auto import tqdm
for i_episode in range(num_episodes):
    # Initialize the environment and get its state
    state, info = env.reset()
    state = torch.tensor(
        state,
        dtype=torch.float32,
    ).unsqueeze(0)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward])
        done = terminated or truncated
        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(
                observation,
                dtype=torch.float32,
            ).unsqueeze(0)
        # Store the transition in memory
        memory.push(state, action, next_state, reward, terminated)
        # Move to the next state
        state = next_state
        # Perform one step of the optimization (on the policy network)
        opt_res = optimize_model()
        loss_hst.append(opt_res)
        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 − τ)θ′
        # target_net_state_dict = target_net.state_dict()
        # policy_net_state_dict = policy_net.state_dict()
        # new_state_dict = {
        #     key: policy_net_state_dict[key] * TAU + target_net_state_dict[key] * (1 - TAU)
        #     for key in policy_net_state_dict
        # }
        # target_net.load_state_dict(new_state_dict)
        soft_update(target_net, policy_net, TAU)
        if done:
            episode_durations.append(t + 1)
            # plot_durations()
            with fig_reward.batch_update():
                ln_reward.y = episode_durations
            with fig_loss.batch_update():
                ln_loss.y = loss_hst
            # tqdm.write(f"opt={opt_res} rew={t + 1}")
            break
print('Complete')
# plot_durations(show_result=True)
plt.ioff()
plt.show()
# %%