# %%
from collections import deque
from itertools import count
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gymnasium as gym
random.seed(114514)
torch.set_float32_matmul_precision('high')
torch.set_default_device("cuda:1")
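
# Replay buffer: a bounded deque of (state, action, next_state, reward, terminated)
# tuples; uniform random sampling decorrelates the training minibatches.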
class ReplayMemory:
    def __init__(self, cap) -> None:
        self.memory = deque(maxlen=cap)

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
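
# Q-network: a small MLP mapping an observation vector to one Q-value per action.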
def DQN(n_observations, n_actions):
    return nn.Sequential(
        nn.Linear(n_observations, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, n_actions),
    )
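
# Epsilon-greedy schedule: the exploration probability decays exponentially from
# EPS_START to EPS_END with time constant EPS_DECAY (in environment steps):
#   eps(t) = EPS_END + (EPS_START - EPS_END) * exp(-t / EPS_DECAY)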
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
def get_threshold(steps):
    return EPS_END + (EPS_START - EPS_END) * math.exp(-1 * steps / EPS_DECAY)
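
# Training hyperparameters: minibatch size, discount factor GAMMA, soft-update
# coefficient TAU for the target network, and the optimizer learning rate.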
BATCH_SIZE = 256
GAMMA = 0.99
TAU = 0.005
LR = 1e-4
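
# CartPole-v1 with .unwrapped, i.e. without the TimeLimit wrapper, so episodes
# end only when the pole falls or the cart leaves the track.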
env = gym.make("CartPole-v1").unwrapped
n_actions = env.action_space.n
state, info = env.reset()
n_observations = len(state)
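
# Two networks: policy_net is optimized directly; target_net provides the
# bootstrap targets and tracks policy_net through soft updates.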
policy_net = DQN(n_observations, n_actions)
target_net = DQN(n_observations, n_actions)
target_net.load_state_dict(policy_net.state_dict())
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
# optimizer = optim.SGD(policy_net.parameters(), lr=LR)
memory = ReplayMemory(10000)
steps_done = 0
ths_list = []
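
# Polyak averaging of the target network:
#   theta_target <- tau * theta_policy + (1 - tau) * theta_target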
def soft_update(dst: nn.Module, src: nn.Module, tau):
    dst_states = dst.state_dict()
    src_states = src.state_dict()
    new_states = {k: src_states[k] * tau + dst_states[k] * (1 - tau) for k in src_states}
    dst.load_state_dict(new_states)
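
# Epsilon-greedy action selection. While the buffer holds fewer than BATCH_SIZE
# transitions, steps_done is held at 0 and the threshold is pinned to 1, so the
# policy stays fully random and epsilon only starts decaying once learning can begin.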
def select_action(state: torch.Tensor):
    global steps_done
    threshold = get_threshold(steps_done)
    steps_done += 1
    if len(memory) < BATCH_SIZE:
        steps_done = 0
        threshold = 1
    if random.random() > threshold:
        with torch.no_grad():
            return policy_net(state).argmax(1).view(1, 1)
    else:
        return torch.tensor(env.action_space.sample()).view(1, 1)
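
# Helper exception that renders no traceback in IPython; it can be raised to
# break out of the training loop cleanly.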
class Break(Exception):
    def _render_traceback_(self):
        pass
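
# One optimization step: sample a minibatch, build the TD target
# r + GAMMA * max_a Q_target(s', a) (or the Double DQN variant when ddqn=True),
# and minimize the Huber loss with gradient-value clipping.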
def optimize_model(ddqn: bool):
    if len(memory) < BATCH_SIZE:
        return 0
    transitions = memory.sample(BATCH_SIZE)
    state, action, next_state, reward, terminated = list(zip(*transitions))
    state_batch = torch.cat(state)
    action_batch = torch.cat(action)
    next_state_batch = torch.cat(next_state)
    reward_batch = torch.cat(reward)
    # Q(s, a) for the actions that were actually taken
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    with torch.no_grad():
        if ddqn:
            # Double DQN: select the next action with policy_net, evaluate it with target_net
            next_max_actions = policy_net(next_state_batch).max(1)[1]
            next_state_values = target_net(next_state_batch).gather(1, next_max_actions.unsqueeze(1)).squeeze(1)
        else:
            next_state_values = target_net(next_state_batch).max(1)[0]
        # no bootstrapping from terminal states: their target reduces to the immediate reward
        next_state_values[torch.tensor(terminated)] = 0
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()
    return loss.item()
num_episodes = 600
reward_hist = []

from tqdm import tqdm
import plotly.graph_objects as go
import numpy as np
from IPython import display

fig_reward = go.FigureWidget(go.Scatter())
fig_loss = go.FigureWidget(go.Scatter())
display.display(fig_reward)
display.display(fig_loss)

ln = fig_reward.data[0]
ln_loss = fig_loss.data[0]
loss_hst = [0]
first_clean = True
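
# Training loop: run episodes, store transitions, optimize after every step,
# soft-update the target network every 10 steps, and stream episode length and
# mean loss to the live plots.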
for eps in tqdm(range(num_episodes)):
    state, info = env.reset()
    state = torch.tensor(state).unsqueeze(0)
    cur_epoch_losses = []
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor(reward).unsqueeze(0)
        next_state = torch.tensor(observation).unsqueeze(0)
        memory.memory.append((state, action, next_state, reward, terminated))
        # pprint((state, action, next_state, reward))
        # loss_hst.append(optimize_model())
        cur_epoch_losses.append(optimize_model(False))
        # clear the buffer once after the first full batch, discarding the purely
        # random warm-up transitions
        if first_clean and len(memory) > BATCH_SIZE:
            memory.memory.clear()
            first_clean = False
        # if cur_epoch_losses[-1] > 0.01:
        # if eps > 100 and np.mean(loss_hst[-5:]) < 0.01 and first_clean:
        #     memory.memory.clear()
        #     first_clean = False
        #     print(f"cleared at epoch {eps}")
        if steps_done % 10 == 1:
            # soft-update the target network every 10 steps (tau scaled up to compensate)
            soft_update(target_net, policy_net, TAU * 10)
        done = terminated or truncated
        if done:
            ths_list.append(get_threshold(steps_done))
            reward_hist.append(t)
            loss_hst.append(np.mean(cur_epoch_losses))
            with fig_reward.batch_update():
                ln.y = reward_hist
            with fig_loss.batch_update():
                ln_loss.y = loss_hst
            break
        state = next_state
# import plotly.express as px
# px.line(loss_hst)
# ax = plt.axes()
# ax2 = ax.twinx()
# ax.plot(reward_hist)
# ax2.plot(ths_list)
# ax2.set_ylim(0, 1)
# plt.show()
# %%