finish dqn

Trinkle23897 2020-03-15 17:41:00 +08:00
parent c804662457
commit 5983c6b33d
8 changed files with 296 additions and 48 deletions

.gitignore (vendored): 1 line changed

@@ -135,3 +135,4 @@ dmypy.json
# customize
flake8.sh
+log/

test/test_dqn.py (new file): 136 lines added

@@ -0,0 +1,136 @@
import gym
import tqdm
import torch
import argparse
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils import tqdm_config, MovAvg
from tianshou.data import Collector, ReplayBuffer


class Net(nn.Module):
def __init__(self, layer_num, state_shape, action_shape, device):
super().__init__()
self.device = device
self.model = [
nn.Linear(np.prod(state_shape), 128),
nn.ReLU(inplace=True)]
for i in range(layer_num):
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
self.model += [nn.Linear(128, np.prod(action_shape))]
self.model = nn.Sequential(*self.model)

    def forward(self, s, **kwargs):
if not isinstance(s, torch.Tensor):
s = torch.Tensor(s)
s = s.to(self.device)
batch = s.shape[0]
q = self.model(s.view(batch, -1))
return q, None


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--n-step', type=int, default=1)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=320)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=3)
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--test-num', type=int, default=20)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_known_args()[0]
return args


def test_dqn(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
reset_after_done=True)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
reset_after_done=False)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
net = net.to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
loss = nn.MSELoss()
policy = DQNPolicy(net, optim, loss, args.gamma, args.n_step)
# collector
training_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(
policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
training_collector.collect(n_step=args.batch_size)
# log
stat_loss = MovAvg()
global_step = 0
writer = SummaryWriter(args.logdir)
best_epoch = -1
best_reward = -1e10
for epoch in range(args.epoch):
desc = f"Epoch #{epoch + 1}"
# train
policy.train()
policy.sync_weight()
policy.set_eps(args.eps_train)
with tqdm.trange(
0, args.step_per_epoch, desc=desc, **tqdm_config) as t:
for _ in t:
training_collector.collect(n_step=args.collect_per_step)
global_step += 1
result = training_collector.stat()
loss = policy.learn(training_collector.sample(args.batch_size))
stat_loss.add(loss)
writer.add_scalar(
'reward', result['reward'], global_step=global_step)
writer.add_scalar(
'length', result['length'], global_step=global_step)
writer.add_scalar(
'loss', stat_loss.get(), global_step=global_step)
t.set_postfix(loss=f'{stat_loss.get():.6f}',
reward=f'{result["reward"]:.6f}',
length=f'{result["length"]:.6f}')
# eval
test_collector.reset_env()
test_collector.reset_buffer()
policy.eval()
policy.set_eps(args.eps_test)
test_collector.collect(n_episode=args.test_num)
result = test_collector.stat()
if best_reward < result['reward']:
best_reward = result['reward']
best_epoch = epoch
print(f'Epoch #{epoch + 1} reward: {result["reward"]:.6f}, '
f'best_reward: {best_reward:.6f} in #{best_epoch}')
if args.task == 'CartPole-v0' and best_reward >= 200:
break
assert best_reward >= 200
return best_reward


if __name__ == '__main__':
test_dqn(get_args())
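As a usage sketch, the `__main__` guard above means the test doubles as a standalone script: running `python test/test_dqn.py --task CartPole-v0 --epoch 10` (assuming the repository root is on `PYTHONPATH`) trains on CartPole directly, and any flag defined in `get_args()` can be overridden on the command line the same way.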


@@ -65,6 +65,16 @@ class ReplayBuffer(object):
            info=self.info[indice]
        ), indice

+    def __getitem__(self, index):
+        return Batch(
+            obs=self.obs[index],
+            act=self.act[index],
+            rew=self.rew[index],
+            done=self.done[index],
+            obs_next=self.obs_next[index],
+            info=self.info[index]
+        )

class PrioritizedReplayBuffer(ReplayBuffer):
    """docstring for PrioritizedReplayBuffer"""


@@ -10,7 +10,7 @@ from tianshou.utils import MovAvg
class Collector(object):
    """docstring for Collector"""

-    def __init__(self, policy, env, buffer, contiguous=True):
+    def __init__(self, policy, env, buffer, stat_size=100):
        super().__init__()
        self.env = env
        self.env_num = 1
@@ -19,27 +19,28 @@ class Collector(object):
        self.process_fn = policy.process_fn
        self._multi_env = isinstance(env, BaseVectorEnv)
        self._multi_buf = False  # buf is a list
-        # need multiple cache buffers only if contiguous in one buffer
+        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
        if self._multi_env:
            self.env_num = len(env)
            if isinstance(self.buffer, list):
                assert len(self.buffer) == self.env_num,\
-                    '# of data buffer does not match the # of input env.'
+                    'The number of data buffer does not match the number of '\
+                    'input env.'
                self._multi_buf = True
-            elif isinstance(self.buffer, ReplayBuffer) and contiguous:
+            elif isinstance(self.buffer, ReplayBuffer):
                self._cached_buf = [
                    deepcopy(buffer) for _ in range(self.env_num)]
            else:
                raise TypeError('The buffer in data collector is invalid!')
        self.reset_env()
-        self.clear_buffer()
+        self.reset_buffer()
-        # state over batch is either a list, an np.ndarray, or torch.Tensor
+        # state over batch is either a list, an np.ndarray, or a torch.Tensor
        self.state = None
-        self.stat_reward = MovAvg()
-        self.stat_length = MovAvg()
+        self.stat_reward = MovAvg(stat_size)
+        self.stat_length = MovAvg(stat_size)

-    def clear_buffer(self):
+    def reset_buffer(self):
        if self._multi_buf:
            for b in self.buffer:
                b.reset()
@@ -57,6 +58,18 @@ class Collector(object):
        for b in self._cached_buf:
            b.reset()

+    def seed(self, seed=None):
+        if hasattr(self.env, 'seed'):
+            self.env.seed(seed)
+
+    def render(self):
+        if hasattr(self.env, 'render'):
+            self.env.render()
+
+    def close(self):
+        if hasattr(self.env, 'close'):
+            self.env.close()

    def _make_batch(data):
        if isinstance(data, np.ndarray):
            return data[None]
@@ -66,9 +79,10 @@ class Collector(object):
    def collect(self, n_step=0, n_episode=0):
        assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
            "One and only one collection number specification permitted!"
-        cur_step, cur_episode = 0, 0
+        cur_step = 0
+        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
        while True:
-            if self.multi_env:
+            if self._multi_env:
                batch_data = Batch(
                    obs=self._obs, act=self._act, rew=self._rew,
                    done=self._done, obs_next=None, info=self._info)
@@ -78,8 +92,9 @@ class Collector(object):
                    act=self._make_batch(self._act),
                    rew=self._make_batch(self._rew),
                    done=self._make_batch(self._done),
-                    obs_next=None, info=self._make_batch(self._info))
-            result = self.policy.act(batch_data, self.state)
+                    obs_next=None,
+                    info=self._make_batch(self._info))
+            result = self.policy(batch_data, self.state)
            self.state = result.state if hasattr(result, 'state') else None
            self._act = result.act
            obs_next, self._rew, self._done, self._info = self.env.step(
@@ -88,6 +103,9 @@ class Collector(object):
            self.reward += self._rew
            if self._multi_env:
                for i in range(self.env_num):
+                    if not self.env.is_reset_after_done()\
+                            and cur_episode[i] > 0:
+                        continue
                    data = {
                        'obs': self._obs[i], 'act': self._act[i],
                        'rew': self._rew[i], 'done': self._done[i],
@@ -101,7 +119,7 @@ class Collector(object):
                        self.buffer.add(**data)
                        cur_step += 1
                    if self._done[i]:
-                        cur_episode += 1
+                        cur_episode[i] += 1
                        self.stat_reward.add(self.reward[i])
                        self.stat_length.add(self.length[i])
                        self.reward[i], self.length[i] = 0, 0
@@ -111,12 +129,12 @@ class Collector(object):
                            self._cached_buf[i].reset()
                        if isinstance(self.state, list):
                            self.state[i] = None
-                        else:
+                        elif self.state is not None:
                            self.state[i] = self.state[i] * 0
                            if isinstance(self.state, torch.Tensor):
-                                # remove ref in torch (?)
+                                # remove ref count in pytorch (?)
                                self.state = self.state.detach()
                if n_episode > 0 and cur_episode.sum() >= n_episode:
                    break
            else:
                self.buffer.add(
@@ -141,13 +159,13 @@ class Collector(object):
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]
                total = sum(lens)
-                ib = np.random.choice(
+                batch_index = np.random.choice(
                    total, batch_size, p=np.array(lens) / total)
            else:
-                ib = np.array([])
+                batch_index = np.array([])
            batch_data = Batch()
            for i, b in enumerate(self.buffer):
-                cur_batch = (ib == i).sum()
+                cur_batch = (batch_index == i).sum()
                if batch_size and cur_batch or batch_size <= 0:
                    batch, indice = b.sample(cur_batch)
                    batch = self.process_fn(batch, b, indice)
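Taken together with the renamed `reset_buffer` and the new `seed`/`render`/`close` pass-throughs, a single-env Collector can be driven roughly as follows (a sketch; the `policy` object is assumed to exist, e.g. the `DQNPolicy` built in test_dqn.py above):

import gym
from tianshou.data import Collector, ReplayBuffer

env = gym.make('CartPole-v0')
collector = Collector(policy, env, ReplayBuffer(20000))  # `policy` assumed trained elsewhere
collector.seed(0)              # forwarded to env.seed if the env has one
collector.collect(n_step=64)   # fill the buffer with 64 transitions
batch = collector.sample(32)   # already passed through policy.process_fn
collector.close()              # forwarded to env.close if the env has one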


@@ -1,6 +1,6 @@
import numpy as np
-from abc import ABC
from collections import deque
+from abc import ABC, abstractmethod
from multiprocessing import Process, Pipe
try:
    import ray
@@ -63,6 +63,32 @@ class BaseVectorEnv(ABC):
        self.env_num = len(env_fns)
        self._reset_after_done = reset_after_done

+    def is_reset_after_done(self):
+        return self._reset_after_done
+
+    def __len__(self):
+        return self.env_num
+
+    @abstractmethod
+    def reset(self):
+        pass
+
+    @abstractmethod
+    def step(self, action):
+        pass
+
+    @abstractmethod
+    def seed(self, seed=None):
+        pass
+
+    @abstractmethod
+    def render(self):
+        pass
+
+    @abstractmethod
+    def close(self):
+        pass

class VectorEnv(BaseVectorEnv):
    """docstring for VectorEnv"""
@@ -71,9 +97,6 @@ class VectorEnv(BaseVectorEnv):
        super().__init__(env_fns, reset_after_done)
        self.envs = [_() for _ in env_fns]

-    def __len__(self):
-        return len(self.envs)
-
    def reset(self):
        return np.stack([e.reset() for e in self.envs])
@@ -148,9 +171,6 @@ class SubprocVectorEnv(BaseVectorEnv):
        for c in self.child_remote:
            c.close()

-    def __len__(self):
-        return self.env_num
-
    def step(self, action):
        assert len(action) == self.env_num
        for p, a in zip(self.parent_remote, action):
@@ -203,9 +223,6 @@ class RayVectorEnv(BaseVectorEnv):
            ray.remote(EnvWrapper).options(num_cpus=0).remote(e())
            for e in env_fns]

-    def __len__(self):
-        return self.env_num
-
    def step(self, action):
        assert len(action) == self.env_num
        result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
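With `__len__` and the new accessors hoisted into `BaseVectorEnv`, every vector-env flavour exposes the same surface. A short sketch of that shared interface, assuming the plain `VectorEnv` defined here is exported from `tianshou.env` alongside the `SubprocVectorEnv` used in the test above:

import gym
import numpy as np
from tianshou.env import VectorEnv

envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)],
                 reset_after_done=True)
print(len(envs))                    # 4, via BaseVectorEnv.__len__
print(envs.is_reset_after_done())   # True
obs = envs.reset()                  # stacked observations, one row per env
obs, rew, done, info = envs.step(np.zeros(4, dtype=int))
envs.close()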


@@ -6,25 +6,21 @@ class BasePolicy(ABC):
    def __init__(self):
        super().__init__()
-        self.model = None

    @abstractmethod
-    def act(self, batch, hidden_state=None):
+    def __call__(self, batch, hidden_state=None):
        # return Batch(policy, action, hidden)
        pass

-    def train(self):
-        pass
-
-    def eval(self):
-        pass
-
-    def reset(self):
+    @abstractmethod
+    def learn(self, batch):
        pass

    def process_fn(self, batch, buffer, indice):
        return batch

-    def sync_weights(self):
+    def sync_weight(self):
        pass

    def exploration(self):


@@ -1,4 +1,5 @@
import torch
+import numpy as np
from torch import nn
from copy import deepcopy
@@ -9,25 +10,86 @@ from tianshou.policy import BasePolicy
class DQNPolicy(BasePolicy, nn.Module):
    """docstring for DQNPolicy"""

-    def __init__(self, model, discount_factor=0.99, estimation_step=1,
+    def __init__(self, model, optim, loss,
+                 discount_factor=0.99,
+                 estimation_step=1,
                 use_target_network=True):
        super().__init__()
        self.model = model
+        self.optim = optim
+        self.loss = loss
+        self.eps = 0
+        assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
        self._gamma = discount_factor
+        assert estimation_step > 0, 'estimation_step should greater than 0'
        self._n_step = estimation_step
        self._target = use_target_network
        if use_target_network:
            self.model_old = deepcopy(self.model)
+            self.model_old.eval()

-    def act(self, batch, hidden_state=None):
-        batch_result = Batch()
-        return batch_result
+    def __call__(self, batch, hidden_state=None,
+                 model='model', input='obs', eps=None):
+        model = getattr(self, model)
+        obs = getattr(batch, input)
+        q, h = model(obs, hidden_state=hidden_state, info=batch.info)
+        act = q.max(dim=1)[1].detach().cpu().numpy()
+        # add eps to act
+        for i in range(len(q)):
+            if np.random.rand() < self.eps:
+                act[i] = np.random.randint(q.shape[1])
+        return Batch(Q=q, act=act, state=h)

-    def sync_weights(self):
-        if self._use_target_network:
-            for old, new in zip(
-                    self.model_old.parameters(), self.model.parameters()):
-                old.data.copy_(new.data)
+    def set_eps(self, eps):
+        self.eps = eps
+
+    def train(self):
+        self.training = True
+        self.model.train()
+
+    def eval(self):
+        self.training = False
+        self.model.eval()
+
+    def sync_weight(self):
+        if self._target:
+            self.model_old.load_state_dict(self.model.state_dict())

    def process_fn(self, batch, buffer, indice):
+        returns = np.zeros_like(indice)
+        gammas = np.zeros_like(indice) + self._n_step
+        for n in range(self._n_step - 1, -1, -1):
+            now = (indice + n) % len(buffer)
+            gammas[buffer.done[now] > 0] = n
+            returns[buffer.done[now] > 0] = 0
+            returns = buffer.rew[now] + self._gamma * returns
+        terminal = (indice + self._n_step - 1) % len(buffer)
+        if self._target:
+            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
+            a = self(buffer[terminal], input='obs_next').act
+            target_q = self(
+                buffer[terminal], model='model_old', input='obs_next').Q
+            if isinstance(target_q, torch.Tensor):
+                target_q = target_q.detach().cpu().numpy()
+            target_q = target_q[np.arange(len(a)), a]
+        else:
+            target_q = self(buffer[terminal], input='obs_next').Q
+            if isinstance(target_q, torch.Tensor):
+                target_q = target_q.detach().cpu().numpy()
+            target_q = target_q.max(axis=1)
+        target_q[gammas != self._n_step] = 0
+        returns += (self._gamma ** gammas) * target_q
+        batch.update(returns=returns)
        return batch
+
+    def learn(self, batch):
+        self.optim.zero_grad()
+        q = self(batch).Q
+        q = q[np.arange(len(q)), batch.act]
+        r = batch.returns
+        if isinstance(r, np.ndarray):
+            r = torch.tensor(r, device=q.device, dtype=q.dtype)
+        loss = self.loss(q, r)
+        loss.backward()
+        self.optim.step()
+        return loss.detach().cpu().numpy()
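What `process_fn` writes into `batch.returns` is an n-step target: rewards are discounted and summed until either `estimation_step` transitions have passed or an episode ends inside the window, and the bootstrap term `gamma**n * Q` is kept only in the former case (with the bootstrap action chosen by the online network and evaluated by `model_old` when a target network is used, per the `# target_Q = Q_old(s_, argmax(Q_new(s_, *)))` comment). A plain-numpy restatement of that scalar quantity, separate from the library code:

import numpy as np

def n_step_return(rews, dones, q_bootstrap, gamma, n):
    """Sum of discounted rewards up to the first terminal inside the
    length-n slice, plus gamma**n * q_bootstrap only if no terminal was hit."""
    ret, discount = 0.0, 1.0
    for k in range(n):
        ret += discount * rews[k]
        if dones[k]:
            return ret           # episode ended: no bootstrap term
        discount *= gamma
    return ret + discount * q_bootstrap

# e.g. gamma=0.9, 3-step, no terminal inside the slice:
print(n_step_return([1., 1., 1.], [False, False, False], 5.0, 0.9, 3))
# 1 + 0.9 + 0.81 + 0.729 * 5 = 6.355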


@@ -21,3 +21,11 @@ class MovAvg(object):
        if len(self.cache) == 0:
            return 0
        return np.mean(self.cache)

+    def mean(self):
+        return self.get()
+
+    def std(self):
+        if len(self.cache) == 0:
+            return 0
+        return np.std(self.cache)
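With `mean` and `std`, `MovAvg` doubles as a small running-statistics helper, e.g. for the loss curve tracked in test_dqn.py above. A quick usage sketch, mirroring how the Collector now constructs it with a window size:

from tianshou.utils import MovAvg

stat = MovAvg(100)          # keep the most recent 100 values
for v in (0.5, 0.7, 0.6):
    stat.add(v)
print(stat.get(), stat.mean(), stat.std())   # get() and mean() agree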