# %%
from collections import deque
from itertools import count
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gymnasium as gym
import plotly.graph_objects as go
from IPython import display
from tqdm import tqdm

random.seed(114514)
torch.set_float32_matmul_precision('high')
torch.set_default_device("cuda:1")


class ReplayMemory:
    """Fixed-size FIFO buffer of transitions."""

    def __init__(self, cap) -> None:
        self.memory = deque(maxlen=cap)

    def push(self, transition):
        self.memory.append(transition)

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


def DQN(n_observations, n_actions):
    # Simple MLP mapping an observation to one Q-value per action.
    return nn.Sequential(
        nn.Linear(n_observations, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, 256),
        nn.ReLU(),
        nn.Linear(256, n_actions),
    )


# Exploration schedule: epsilon decays exponentially from EPS_START to EPS_END.
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000


def get_threshold(steps):
    return EPS_END + (EPS_START - EPS_END) * math.exp(-1 * steps / EPS_DECAY)


BATCH_SIZE = 256
GAMMA = 0.99
TAU = 0.005
LR = 1e-4

env = gym.make("CartPole-v1")
n_actions = env.action_space.n
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions)
target_net = DQN(n_observations, n_actions)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
# optimizer = optim.SGD(policy_net.parameters(), lr=LR)
memory = ReplayMemory(10000)

steps_done = 0
ths_list = []  # epsilon threshold recorded at the end of each episode


def soft_update(dst: nn.Module, src: nn.Module, tau):
    # Polyak averaging: dst <- tau * src + (1 - tau) * dst
    dst_states = dst.state_dict()
    src_states = src.state_dict()
    new_states = {k: src_states[k] * tau + dst_states[k] * (1 - tau) for k in src_states}
    dst.load_state_dict(new_states)


def select_action(state: torch.Tensor):
    # Epsilon-greedy action selection. While the replay buffer holds fewer than
    # BATCH_SIZE transitions, act fully at random and keep the decay clock at 0.
    global steps_done
    threshold = get_threshold(steps_done)
    steps_done += 1
    if len(memory) < BATCH_SIZE:
        steps_done = 0
        threshold = 1
    if random.random() > threshold:
        with torch.no_grad():
            return policy_net(state).argmax(1).view(1, 1)
    else:
        return torch.tensor(env.action_space.sample()).view(1, 1)


class Break(Exception):
    # Raising this in a notebook cell stops a loop without printing a traceback.
    def _render_traceback_(self):
        pass


def optimize_model(ddqn: bool):
    if len(memory) < BATCH_SIZE:
        return 0
    transitions = memory.sample(BATCH_SIZE)
    state, action, next_state, reward, terminated = list(zip(*transitions))
    state_batch = torch.cat(state)
    action_batch = torch.cat(action)
    next_state_batch = torch.cat(next_state)
    reward_batch = torch.cat(reward)

    # Q(s, a) for the actions that were actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    with torch.no_grad():
        if ddqn:
            # Double DQN: choose the next action with policy_net, evaluate it with target_net.
            next_max_actions = policy_net(next_state_batch).max(1)[1]
            next_state_values = target_net(next_state_batch).gather(1, next_max_actions.unsqueeze(1)).squeeze(1)
        else:
            next_state_values = target_net(next_state_batch).max(1)[0]
        # Terminal transitions do not bootstrap; their target is the reward alone.
        next_state_values[torch.tensor(terminated)] = 0.0

    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()
    return loss.item()


num_episodes = 1000
reward_hist = []

# Live-updating plots of episode length and mean loss per episode.
fig_reward = go.FigureWidget(go.Scatter())
fig_loss = go.FigureWidget(go.Scatter())
display.display(fig_reward)
display.display(fig_loss)
ln = fig_reward.data[0]
ln_loss = fig_loss.data[0]

loss_hst = [0]
first_clean = True

for eps in tqdm(range(600)):
    state, info = env.reset()
    state = torch.tensor(state).unsqueeze(0)
    cur_epoch_losses = []
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor(reward).unsqueeze(0)
        next_state = torch.tensor(observation).unsqueeze(0)
        memory.push((state, action, next_state, reward, terminated))
        # pprint((state, action, next_state, reward))

        # loss_hst.append(optimize_model())
        cur_epoch_losses.append(optimize_model(False))

        # Once the buffer first exceeds BATCH_SIZE, throw away the purely random
        # warm-up transitions and let it refill with fresher experience.
        if first_clean and len(memory) > BATCH_SIZE:
            memory.memory.clear()
            first_clean = False
        # if cur_epoch_losses[-1] > 0.01:
        # if eps > 100 and np.mean(loss_hst[-5:]) < 0.01 and first_clean:
        #     memory.memory.clear()
        #     first_clean = False
        #     print(f"cleared at epoch {eps}")

        # Soft-update the target network every 10 environment steps.
        if steps_done % 10 == 1:
            soft_update(target_net, policy_net, TAU * 10)

        done = terminated or truncated
        if done:
            ths_list.append(get_threshold(steps_done))
            reward_hist.append(t)
            loss_hst.append(np.mean(cur_epoch_losses))
            with fig_reward.batch_update():
                ln.y = reward_hist
            with fig_loss.batch_update():
                ln_loss.y = loss_hst
            break
        state = next_state

# import plotly.express as px
# px.line(loss_hst)

# ax = plt.axes()
# ax2 = ax.twinx()
# ax.plot(reward_hist)
# ax2.plot(ths_list)
# ax2.set_ylim(0, 1)
# plt.show()
# %%
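# %%
# Optional sanity check (a minimal sketch, not part of the original script): roll out
# the trained policy_net greedily for a few episodes and print the episode lengths.
# Assumes the training cell above has finished; uses only names already defined here.
with torch.no_grad():
    for _ in range(5):
        obs, _ = env.reset()
        obs = torch.tensor(obs).unsqueeze(0)
        for t in count():
            act = policy_net(obs).argmax(1).item()  # greedy action, no exploration
            obs, _, terminated, truncated, _ = env.step(act)
            obs = torch.tensor(obs).unsqueeze(0)
            if terminated or truncated:
                print(f"greedy episode length: {t + 1}")
                break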