finish dqn

Trinkle23897 2020-03-15 17:41:00 +08:00
parent c804662457
commit 5983c6b33d
8 changed files with 296 additions and 48 deletions

.gitignore (1 addition)

@@ -135,3 +135,4 @@ dmypy.json
# customize
flake8.sh
log/

test/test_dqn.py (new file, 136 lines)

@@ -0,0 +1,136 @@
import gym
import tqdm
import torch
import argparse
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils import tqdm_config, MovAvg
from tianshou.data import Collector, ReplayBuffer

class Net(nn.Module):
def __init__(self, layer_num, state_shape, action_shape, device):
super().__init__()
self.device = device
self.model = [
nn.Linear(np.prod(state_shape), 128),
nn.ReLU(inplace=True)]
for i in range(layer_num):
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
self.model += [nn.Linear(128, np.prod(action_shape))]
self.model = nn.Sequential(*self.model)
def forward(self, s, **kwargs):
if not isinstance(s, torch.Tensor):
s = torch.Tensor(s)
s = s.to(self.device)
batch = s.shape[0]
q = self.model(s.view(batch, -1))
return q, None

def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--n-step', type=int, default=1)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=320)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=3)
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=20)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_known_args()[0]
return args

def test_dqn(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
reset_after_done=True)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
reset_after_done=False)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
net = net.to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
loss = nn.MSELoss()
policy = DQNPolicy(net, optim, loss, args.gamma, args.n_step)
# collector
training_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(
policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
training_collector.collect(n_step=args.batch_size)
# log
stat_loss = MovAvg()
global_step = 0
writer = SummaryWriter(args.logdir)
best_epoch = -1
best_reward = -1e10
for epoch in range(args.epoch):
desc = f"Epoch #{epoch + 1}"
# train
policy.train()
policy.sync_weight()
policy.set_eps(args.eps_train)
with tqdm.trange(
0, args.step_per_epoch, desc=desc, **tqdm_config) as t:
for _ in t:
training_collector.collect(n_step=args.collect_per_step)
global_step += 1
result = training_collector.stat()
loss = policy.learn(training_collector.sample(args.batch_size))
stat_loss.add(loss)
writer.add_scalar(
'reward', result['reward'], global_step=global_step)
writer.add_scalar(
'length', result['length'], global_step=global_step)
writer.add_scalar(
'loss', stat_loss.get(), global_step=global_step)
t.set_postfix(loss=f'{stat_loss.get():.6f}',
reward=f'{result["reward"]:.6f}',
length=f'{result["length"]:.6f}')
# eval
test_collector.reset_env()
test_collector.reset_buffer()
policy.eval()
policy.set_eps(args.eps_test)
test_collector.collect(n_episode=args.test_num)
result = test_collector.stat()
if best_reward < result['reward']:
best_reward = result['reward']
best_epoch = epoch
print(f'Epoch #{epoch + 1} reward: {result["reward"]:.6f}, '
f'best_reward: {best_reward:.6f} in #{best_epoch}')
if args.task == 'CartPole-v0' and best_reward >= 200:
break
assert best_reward >= 200
return best_reward

if __name__ == '__main__':
test_dqn(get_args())
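
A quick way to exercise the script above outside of pytest is to override the parsed arguments; the shorter schedule below is purely illustrative (a minimal sketch reusing get_args and test_dqn from this file):

# Illustrative smoke run; the epoch/device values are arbitrary, not the defaults.
args = get_args()
args.epoch = 10        # shorter than the default 100 epochs
args.device = 'cpu'    # assumption: no GPU available
test_dqn(args)         # note: asserts best_reward >= 200 on CartPole-v0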

@@ -65,6 +65,16 @@ class ReplayBuffer(object):
info=self.info[indice]
), indice
def __getitem__(self, index):
return Batch(
obs=self.obs[index],
act=self.act[index],
rew=self.rew[index],
done=self.done[index],
obs_next=self.obs_next[index],
info=self.info[index]
)
class PrioritizedReplayBuffer(ReplayBuffer):
"""docstring for PrioritizedReplayBuffer"""

@@ -10,7 +10,7 @@ from tianshou.utils import MovAvg
class Collector(object):
"""docstring for Collector"""
def __init__(self, policy, env, buffer, contiguous=True):
def __init__(self, policy, env, buffer, stat_size=100):
super().__init__()
self.env = env
self.env_num = 1
@@ -19,27 +19,28 @@
self.process_fn = policy.process_fn
self._multi_env = isinstance(env, BaseVectorEnv)
self._multi_buf = False # buf is a list
# need multiple cache buffers only if contiguous in one buffer
# need multiple cache buffers only if storing in one buffer
self._cached_buf = []
if self._multi_env:
self.env_num = len(env)
if isinstance(self.buffer, list):
assert len(self.buffer) == self.env_num,\
'# of data buffer does not match the # of input env.'
'The number of data buffer does not match the number of '\
'input env.'
self._multi_buf = True
elif isinstance(self.buffer, ReplayBuffer) and contiguous:
elif isinstance(self.buffer, ReplayBuffer):
self._cached_buf = [
deepcopy(buffer) for _ in range(self.env_num)]
else:
raise TypeError('The buffer in data collector is invalid!')
self.reset_env()
self.clear_buffer()
# state over batch is either a list, an np.ndarray, or torch.Tensor
self.reset_buffer()
# state over batch is either a list, an np.ndarray, or a torch.Tensor
self.state = None
self.stat_reward = MovAvg()
self.stat_length = MovAvg()
self.stat_reward = MovAvg(stat_size)
self.stat_length = MovAvg(stat_size)
def clear_buffer(self):
def reset_buffer(self):
if self._multi_buf:
for b in self.buffer:
b.reset()
@@ -57,6 +58,18 @@ class Collector(object):
for b in self._cached_buf:
b.reset()
def seed(self, seed=None):
if hasattr(self.env, 'seed'):
self.env.seed(seed)
def render(self):
if hasattr(self.env, 'render'):
self.env.render()
def close(self):
if hasattr(self.env, 'close'):
self.env.close()
def _make_batch(data):
if isinstance(data, np.ndarray):
return data[None]
@@ -66,9 +79,10 @@
def collect(self, n_step=0, n_episode=0):
assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
"One and only one collection number specification permitted!"
cur_step, cur_episode = 0, 0
cur_step = 0
cur_episode = np.zeros(self.env_num) if self._multi_env else 0
while True:
if self.multi_env:
if self._multi_env:
batch_data = Batch(
obs=self._obs, act=self._act, rew=self._rew,
done=self._done, obs_next=None, info=self._info)
@@ -78,8 +92,9 @@
act=self._make_batch(self._act),
rew=self._make_batch(self._rew),
done=self._make_batch(self._done),
obs_next=None, info=self._make_batch(self._info))
result = self.policy.act(batch_data, self.state)
obs_next=None,
info=self._make_batch(self._info))
result = self.policy(batch_data, self.state)
self.state = result.state if hasattr(result, 'state') else None
self._act = result.act
obs_next, self._rew, self._done, self._info = self.env.step(
@@ -88,6 +103,9 @@
self.reward += self._rew
if self._multi_env:
for i in range(self.env_num):
if not self.env.is_reset_after_done()\
and cur_episode[i] > 0:
continue
data = {
'obs': self._obs[i], 'act': self._act[i],
'rew': self._rew[i], 'done': self._done[i],
@@ -101,7 +119,7 @@
self.buffer.add(**data)
cur_step += 1
if self._done[i]:
cur_episode += 1
cur_episode[i] += 1
self.stat_reward.add(self.reward[i])
self.stat_length.add(self.length[i])
self.reward[i], self.length[i] = 0, 0
@@ -111,12 +129,12 @@
self._cached_buf[i].reset()
if isinstance(self.state, list):
self.state[i] = None
else:
elif self.state is not None:
self.state[i] = self.state[i] * 0
if isinstance(self.state, torch.Tensor):
# remove ref in torch (?)
# remove ref count in pytorch (?)
self.state = self.state.detach()
if n_episode > 0 and cur_episode >= n_episode:
if n_episode > 0 and cur_episode.sum() >= n_episode:
break
else:
self.buffer.add(
@@ -141,13 +159,13 @@
if batch_size > 0:
lens = [len(b) for b in self.buffer]
total = sum(lens)
ib = np.random.choice(
batch_index = np.random.choice(
total, batch_size, p=np.array(lens) / total)
else:
ib = np.array([])
batch_index = np.array([])
batch_data = Batch()
for i, b in enumerate(self.buffer):
cur_batch = (ib == i).sum()
cur_batch = (batch_index == i).sum()
if batch_size and cur_batch or batch_size <= 0:
batch, indice = b.sample(cur_batch)
batch = self.process_fn(batch, b, indice)
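
Taken together, the Collector is driven the same way the new test drives it; a condensed sketch, assuming policy, train_envs and the buffer are built as in test/test_dqn.py:

collector = Collector(policy, train_envs, ReplayBuffer(20000))
collector.collect(n_step=64)      # gather 64 environment steps into the buffer
batch = collector.sample(64)      # draw a training batch for policy.learn
result = collector.stat()         # moving-average 'reward' and 'length'
collector.reset_buffer()          # drop stored transitions
collector.close()                 # close the wrapped env(s)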

@@ -1,6 +1,6 @@
import numpy as np
from abc import ABC
from collections import deque
from abc import ABC, abstractmethod
from multiprocessing import Process, Pipe
try:
import ray
@@ -63,6 +63,32 @@ class BaseVectorEnv(ABC):
self.env_num = len(env_fns)
self._reset_after_done = reset_after_done
def is_reset_after_done(self):
return self._reset_after_done
def __len__(self):
return self.env_num
@abstractmethod
def reset(self):
pass
@abstractmethod
def step(self, action):
pass
@abstractmethod
def seed(self, seed=None):
pass
@abstractmethod
def render(self):
pass
@abstractmethod
def close(self):
pass
class VectorEnv(BaseVectorEnv):
"""docstring for VectorEnv"""
@@ -71,9 +97,6 @@ class VectorEnv(BaseVectorEnv):
super().__init__(env_fns, reset_after_done)
self.envs = [_() for _ in env_fns]
def __len__(self):
return len(self.envs)
def reset(self):
return np.stack([e.reset() for e in self.envs])
@@ -148,9 +171,6 @@ class SubprocVectorEnv(BaseVectorEnv):
for c in self.child_remote:
c.close()
def __len__(self):
return self.env_num
def step(self, action):
assert len(action) == self.env_num
for p, a in zip(self.parent_remote, action):
@@ -203,9 +223,6 @@ class RayVectorEnv(BaseVectorEnv):
ray.remote(EnvWrapper).options(num_cpus=0).remote(e())
for e in env_fns]
def __len__(self):
return self.env_num
def step(self, action):
assert len(action) == self.env_num
result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
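
A condensed sketch of the vectorized-env interface defined above (construction mirrors test/test_dqn.py; the action values are arbitrary):

import gym
from tianshou.env import SubprocVectorEnv

envs = SubprocVectorEnv(
    [lambda: gym.make('CartPole-v0') for _ in range(4)],
    reset_after_done=True)        # keep collecting across episode boundaries
assert len(envs) == 4             # __len__ reports the number of workers
envs.seed(0)
obs = envs.reset()                # one observation row per worker
obs, rew, done, info = envs.step([0, 1, 0, 1])  # one action per env; length must match
envs.close()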

@@ -6,25 +6,21 @@ class BasePolicy(ABC):
def __init__(self):
super().__init__()
self.model = None
@abstractmethod
def act(self, batch, hidden_state=None):
def __call__(self, batch, hidden_state=None):
# return Batch(policy, action, hidden)
pass
def train(self):
pass
def eval(self):
pass
def reset(self):
@abstractmethod
def learn(self, batch):
pass
def process_fn(self, batch, buffer, indice):
return batch
def sync_weights(self):
def sync_weight(self):
pass
def exploration(self):
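
For orientation, the reshaped BasePolicy contract (abstract __call__ and learn, overridable process_fn and sync_weight) can be satisfied by something as small as the purely illustrative random policy below (not part of this commit):

import numpy as np
from tianshou.data import Batch
from tianshou.policy import BasePolicy

class RandomPolicy(BasePolicy):
    def __init__(self, action_num):
        super().__init__()
        self.action_num = action_num

    def __call__(self, batch, hidden_state=None):
        # return a Batch with one action per observation; the hidden state stays None
        act = np.random.randint(self.action_num, size=len(batch.obs))
        return Batch(act=act, state=None)

    def learn(self, batch):
        return 0.0    # nothing to optimize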

@@ -1,4 +1,5 @@
import torch
import numpy as np
from torch import nn
from copy import deepcopy
@@ -9,25 +10,86 @@ from tianshou.policy import BasePolicy
class DQNPolicy(BasePolicy, nn.Module):
"""docstring for DQNPolicy"""
def __init__(self, model, discount_factor=0.99, estimation_step=1,
def __init__(self, model, optim, loss,
discount_factor=0.99,
estimation_step=1,
use_target_network=True):
super().__init__()
self.model = model
self.optim = optim
self.loss = loss
self.eps = 0
assert 0 <= discount_factor <= 1, 'discount_factor should be in [0, 1]'
self._gamma = discount_factor
assert estimation_step > 0, 'estimation_step should be greater than 0'
self._n_step = estimation_step
self._target = use_target_network
if use_target_network:
self.model_old = deepcopy(self.model)
self.model_old.eval()
def act(self, batch, hidden_state=None):
batch_result = Batch()
return batch_result
def __call__(self, batch, hidden_state=None,
model='model', input='obs', eps=None):
model = getattr(self, model)
obs = getattr(batch, input)
q, h = model(obs, hidden_state=hidden_state, info=batch.info)
act = q.max(dim=1)[1].detach().cpu().numpy()
# add eps to act
for i in range(len(q)):
if np.random.rand() < self.eps:
act[i] = np.random.randint(q.shape[1])
return Batch(Q=q, act=act, state=h)
def sync_weights(self):
if self._use_target_network:
for old, new in zip(
self.model_old.parameters(), self.model.parameters()):
old.data.copy_(new.data)
def set_eps(self, eps):
self.eps = eps
def train(self):
self.training = True
self.model.train()
def eval(self):
self.training = False
self.model.eval()
def sync_weight(self):
if self._target:
self.model_old.load_state_dict(self.model.state_dict())
def process_fn(self, batch, buffer, indice):
returns = np.zeros_like(indice)
gammas = np.zeros_like(indice) + self._n_step
for n in range(self._n_step - 1, -1, -1):
now = (indice + n) % len(buffer)
gammas[buffer.done[now] > 0] = n
returns[buffer.done[now] > 0] = 0
returns = buffer.rew[now] + self._gamma * returns
terminal = (indice + self._n_step - 1) % len(buffer)
if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
a = self(buffer[terminal], input='obs_next').act
target_q = self(
buffer[terminal], model='model_old', input='obs_next').Q
if isinstance(target_q, torch.Tensor):
target_q = target_q.detach().cpu().numpy()
target_q = target_q[np.arange(len(a)), a]
else:
target_q = self(buffer[terminal], input='obs_next').Q
if isinstance(target_q, torch.Tensor):
target_q = target_q.detach().cpu().numpy()
target_q = target_q.max(axis=1)
target_q[gammas != self._n_step] = 0
returns += (self._gamma ** gammas) * target_q
batch.update(returns=returns)
return batch
def learn(self, batch):
self.optim.zero_grad()
q = self(batch).Q
q = q[np.arange(len(q)), batch.act]
r = batch.returns
if isinstance(r, np.ndarray):
r = torch.tensor(r, device=q.device, dtype=q.dtype)
loss = self.loss(q, r)
loss.backward()
self.optim.step()
return loss.detach().cpu().numpy()
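
For reference, the returns built in process_fn are the usual n-step target,
G_t = r_t + γ·r_{t+1} + … + γ^(n-1)·r_{t+n-1} + γ^n · max_a Q_old(s_{t+n}, a),
with the bootstrap term dropped whenever done fires inside the window, and with Q_old evaluated at argmax_a Q_new(s_{t+n}, ·) when the target network is enabled (as the comment in process_fn notes). A standalone sketch of the same recursion with made-up numbers and no buffer wrap-around:

gamma, n_step = 0.9, 3
rew = [1.0, 1.0, 1.0]     # r_t, r_{t+1}, r_{t+2}
q_bootstrap = 5.0         # max_a Q_old(s_{t+n}, a) from the target network

ret = 0.0
for r in reversed(rew):   # same backward recursion as the loop in process_fn
    ret = r + gamma * ret
ret += (gamma ** n_step) * q_bootstrap
print(ret)                # ≈ 6.355 = 1 + 0.9 + 0.81 + 0.729 * 5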

@@ -21,3 +21,11 @@ class MovAvg(object):
if len(self.cache) == 0:
return 0
return np.mean(self.cache)
def mean(self):
return self.get()
def std(self):
if len(self.cache) == 0:
return 0
return np.std(self.cache)
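
A tiny usage sketch of the extended MovAvg (the window size is passed positionally, as the Collector now does with stat_size):

from tianshou.utils import MovAvg

stat = MovAvg(100)
for x in [1.0, 2.0, 3.0]:
    stat.add(x)
print(stat.get())    # moving average of the cached values -> 2.0
print(stat.mean())   # alias for get()
print(stat.std())    # standard deviation of the cache, ≈ 0.816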