# %%
# Notebook-style DQN / Double DQN training script for CartPole-v1
# (PyTorch + Gymnasium), with live reward/loss plots via plotly FigureWidgets.
from collections import deque
from itertools import count
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import gymnasium as gym

random.seed(114514)

torch.set_float32_matmul_precision('high')
torch.set_default_device("cuda:1")  # pin all new tensors and modules to the second GPU


class ReplayMemory:
    def __init__(self, cap) -> None:
        self.memory = deque(maxlen=cap)

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
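
# Note on ReplayMemory above: it is a plain uniform-sampling buffer. Each entry is a
# (state, action, next_state, reward, terminated) tuple, the deque evicts the oldest
# transition once `cap` is exceeded, and sample() draws a batch without replacement.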


def DQN(n_observations, n_actions):
    return nn.Sequential(
        nn.Linear(n_observations, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, n_actions),
    )
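
# For CartPole-v1 this is a 4 -> 256 -> 256 -> 256 -> 2 MLP: the observation is
# (cart position, cart velocity, pole angle, pole angular velocity) and the two
# outputs are the Q-values of pushing the cart left or right.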


# Epsilon-greedy exploration schedule: decays exponentially from EPS_START to EPS_END.
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000


def get_threshold(steps):
    return EPS_END + (EPS_START - EPS_END) * math.exp(-1 * steps / EPS_DECAY)
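
# Rough values of the schedule: ~0.90 at step 0, ~0.36 at step 1000, ~0.06 at
# step 5000, approaching 0.05 as steps grow.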


BATCH_SIZE = 256
GAMMA = 0.99

TAU = 0.005  # soft-update rate for the target network
LR = 1e-4

env = gym.make("CartPole-v1").unwrapped
n_actions = env.action_space.n
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions)
target_net = DQN(n_observations, n_actions)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
# optimizer = optim.SGD(policy_net.parameters(), lr=LR)

memory = ReplayMemory(10000)

steps_done = 0

ths_list = []  # epsilon threshold at the end of each episode, for plotting


def soft_update(dst: nn.Module, src: nn.Module, tau):
    dst_states = dst.state_dict()
    src_states = src.state_dict()
    new_states = {k: src_states[k] * tau + dst_states[k] * (1 - tau) for k in src_states}
    dst.load_state_dict(new_states)
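
# Polyak averaging: every parameter (and buffer) of `dst` moves toward `src` as
# dst <- tau * src + (1 - tau) * dst, so the target network trails the policy
# network smoothly instead of being copied wholesale.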


def select_action(state: torch.Tensor):
    global steps_done
    threshold = get_threshold(steps_done)
    steps_done += 1
    if len(memory) < BATCH_SIZE:
        # Warm-up: act fully at random (and hold the decay clock at zero) until the
        # replay buffer can supply a first batch.
        steps_done = 0
        threshold = 1
    if random.random() > threshold:
        with torch.no_grad():
            # Greedy action from the policy network, shaped (1, 1) for gather().
            return policy_net(state).argmax(1).view(1, 1)
    else:
        return torch.tensor(env.action_space.sample()).view(1, 1)
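
# Usage sketch (illustrative): `state` is expected as a (1, n_observations) float
# tensor on the default device, and the returned (1, 1) long tensor can be fed both
# to env.step(action.item()) and, batched, to Q(s).gather(1, action).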


class Break(Exception):
    # Helper for stopping a notebook run quietly: IPython calls _render_traceback_
    # when showing an exception, and this empty implementation suppresses the usual
    # traceback. It is not raised anywhere below.
    def _render_traceback_(self):
        pass


def optimize_model(ddqn: bool):
    if len(memory) < BATCH_SIZE:
        return 0
    transitions = memory.sample(BATCH_SIZE)
    state, action, next_state, reward, terminated = list(zip(*transitions))
    state_batch = torch.cat(state)
    action_batch = torch.cat(action)
    next_state_batch = torch.cat(next_state)
    reward_batch = torch.cat(reward)
    # Q(s, a) for the actions actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    with torch.no_grad():
        if ddqn:
            # Double DQN: pick the next action with policy_net, evaluate it with target_net.
            next_max_actions = policy_net(next_state_batch).max(1)[1]
            next_state_values = target_net(next_state_batch).gather(1, next_max_actions.unsqueeze(1)).squeeze(1)
        else:
            next_state_values = target_net(next_state_batch).max(1)[0]
        # No bootstrapping from terminal states: zero only the next-state value where
        # the episode terminated, so the reward itself still enters the target.
        next_state_values[torch.tensor(terminated, dtype=torch.bool)] = 0
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()
    return loss.item()
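
# Target being regressed against (Huber loss): y = r + gamma * max_a Q_target(s', a)
# for non-terminal transitions and y = r for terminal ones; with ddqn=True the max is
# replaced by Q_target(s', argmax_a Q_policy(s', a)).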


num_episodes = 1000  # not used below; the loop runs a fixed 600 episodes
reward_hist = []

from tqdm import tqdm

import plotly.graph_objects as go

# Live-updating charts: one FigureWidget for episode length, one for mean loss.
fig_reward = go.FigureWidget(go.Scatter())
fig_loss = go.FigureWidget(go.Scatter())
from IPython import display

display.display(fig_reward)
display.display(fig_loss)
import numpy as np

ln = fig_reward.data[0]
ln_loss = fig_loss.data[0]
loss_hst = [0]
first_clean = True

for eps in tqdm(range(600)):
    state, info = env.reset()
    state = torch.tensor(state).unsqueeze(0)
    cur_epoch_losses = []
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor(reward).unsqueeze(0)
        next_state = torch.tensor(observation).unsqueeze(0)
        memory.memory.append((state, action, next_state, reward, terminated))
        # pprint((state, action, next_state, reward))
        # loss_hst.append(optimize_model())
        cur_epoch_losses.append(optimize_model(False))
        # One-time buffer reset the first time it grows past BATCH_SIZE, discarding
        # the initial transitions collected under the fully random warm-up policy.
        if first_clean and len(memory) > BATCH_SIZE:
            memory.memory.clear()
            first_clean = False
        # if cur_epoch_losses[-1] > 0.01:
        # if eps > 100 and np.mean(loss_hst[-5:]) < 0.01 and first_clean:
        #     memory.memory.clear()
        #     first_clean = False
        #     print(f"cleared at epoch {eps}")

        # Soft-update the target network roughly every 10 environment steps.
        if steps_done % 10 == 1:
            soft_update(target_net, policy_net, TAU * 10)
        done = terminated or truncated
        if done:
            ths_list.append(get_threshold(steps_done))
            reward_hist.append(t)  # episode length; a proxy for return since CartPole pays +1 per step
            loss_hst.append(np.mean(cur_epoch_losses))
            with fig_reward.batch_update():
                ln.y = reward_hist
            with fig_loss.batch_update():
                ln_loss.y = loss_hst
            break

        state = next_state

# import plotly.express as px

# px.line(loss_hst)
# ax = plt.axes()
# ax2 = ax.twinx()
# ax.plot(reward_hist)
# ax2.plot(ths_list)
# ax2.set_ylim(0, 1)
# plt.show()

# %%
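
# Illustrative addition (not in the original script): a quick greedy-rollout check of
# the trained policy_net. The helper name `evaluate_policy` and the 500-step cap are
# arbitrary choices; it reuses `env`, `policy_net`, `torch`, and `count` from above.
def evaluate_policy(n_episodes: int = 10) -> float:
    lengths = []
    for _ in range(n_episodes):
        obs, _ = env.reset()
        state = torch.tensor(obs).unsqueeze(0)
        for t in count():
            with torch.no_grad():
                action = policy_net(state).argmax(1).item()
            obs, reward, terminated, truncated, _ = env.step(action)
            state = torch.tensor(obs).unsqueeze(0)
            # The unwrapped env has no TimeLimit wrapper, so cap the rollout ourselves.
            if terminated or truncated or t >= 500:
                lengths.append(t + 1)
                break
    return sum(lengths) / len(lengths)


# Example: evaluate_policy() returns the mean episode length over the greedy rollouts.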