finish dqn

parent c804662457
commit 5983c6b33d

.gitignore (1 line changed)
@@ -135,3 +135,4 @@ dmypy.json

# customize
flake8.sh
+log/
test/test_dqn.py (new file, 136 lines)
@@ -0,0 +1,136 @@
import gym
import tqdm
import torch
import argparse
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils import tqdm_config, MovAvg
from tianshou.data import Collector, ReplayBuffer


class Net(nn.Module):
    def __init__(self, layer_num, state_shape, action_shape, device):
        super().__init__()
        self.device = device
        self.model = [
            nn.Linear(np.prod(state_shape), 128),
            nn.ReLU(inplace=True)]
        for i in range(layer_num):
            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
        self.model += [nn.Linear(128, np.prod(action_shape))]
        self.model = nn.Sequential(*self.model)

    def forward(self, s, **kwargs):
        if not isinstance(s, torch.Tensor):
            s = torch.Tensor(s)
        s = s.to(self.device)
        batch = s.shape[0]
        q = self.model(s.view(batch, -1))
        return q, None


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--n-step', type=int, default=1)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=320)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--layer-num', type=int, default=3)
    parser.add_argument('--training-num', type=int, default=8)
    parser.add_argument('--test-num', type=int, default=20)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def test_dqn(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
    net = net.to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    loss = nn.MSELoss()
    policy = DQNPolicy(net, optim, loss, args.gamma, args.n_step)
    # collector
    training_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(
        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
    training_collector.collect(n_step=args.batch_size)
    # log
    stat_loss = MovAvg()
    global_step = 0
    writer = SummaryWriter(args.logdir)
    best_epoch = -1
    best_reward = -1e10
    for epoch in range(args.epoch):
        desc = f"Epoch #{epoch + 1}"
        # train
        policy.train()
        policy.sync_weight()
        policy.set_eps(args.eps_train)
        with tqdm.trange(
                0, args.step_per_epoch, desc=desc, **tqdm_config) as t:
            for _ in t:
                training_collector.collect(n_step=args.collect_per_step)
                global_step += 1
                result = training_collector.stat()
                loss = policy.learn(training_collector.sample(args.batch_size))
                stat_loss.add(loss)
                writer.add_scalar(
                    'reward', result['reward'], global_step=global_step)
                writer.add_scalar(
                    'length', result['length'], global_step=global_step)
                writer.add_scalar(
                    'loss', stat_loss.get(), global_step=global_step)
                t.set_postfix(loss=f'{stat_loss.get():.6f}',
                              reward=f'{result["reward"]:.6f}',
                              length=f'{result["length"]:.6f}')
        # eval
        test_collector.reset_env()
        test_collector.reset_buffer()
        policy.eval()
        policy.set_eps(args.eps_test)
        test_collector.collect(n_episode=args.test_num)
        result = test_collector.stat()
        if best_reward < result['reward']:
            best_reward = result['reward']
            best_epoch = epoch
        print(f'Epoch #{epoch + 1} reward: {result["reward"]:.6f}, '
              f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if args.task == 'CartPole-v0' and best_reward >= 200:
            break
    assert best_reward >= 200
    return best_reward


if __name__ == '__main__':
    test_dqn(get_args())
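The test doubles as a small training script through its argparse defaults. A minimal usage sketch, not part of the commit, assuming the test/ directory is on the import path; the overridden attribute values below are hypothetical:

from test_dqn import get_args, test_dqn  # assumes test/ is on sys.path

args = get_args()          # argparse defaults above: CartPole-v0, 100 epochs, ...
args.seed = 1626           # hypothetical override
args.logdir = 'log/dqn'    # hypothetical override; SummaryWriter output directory
best_reward = test_dqn(args)
print(f'best reward: {best_reward:.2f}')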
@@ -65,6 +65,16 @@ class ReplayBuffer(object):
            info=self.info[indice]
        ), indice

+    def __getitem__(self, index):
+        return Batch(
+            obs=self.obs[index],
+            act=self.act[index],
+            rew=self.rew[index],
+            done=self.done[index],
+            obs_next=self.obs_next[index],
+            info=self.info[index]
+        )


class PrioritizedReplayBuffer(ReplayBuffer):
    """docstring for PrioritizedReplayBuffer"""
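A small sketch of what the new index access enables; `DQNPolicy.process_fn` below relies on it via `buffer[terminal]`. The `add()` keyword names are an assumption based on how the Collector fills the buffer elsewhere in this commit:

import numpy as np
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(20)
# keyword names assumed from the Collector's buffer.add(**data) call
buf.add(obs=np.array([0.]), act=0, rew=1., done=False,
        obs_next=np.array([1.]), info={})
transition = buf[0]          # new __getitem__: a Batch for a single transition
print(transition.obs, transition.rew, transition.done)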
@@ -10,7 +10,7 @@ from tianshou.utils import MovAvg
class Collector(object):
    """docstring for Collector"""

-    def __init__(self, policy, env, buffer, contiguous=True):
+    def __init__(self, policy, env, buffer, stat_size=100):
        super().__init__()
        self.env = env
        self.env_num = 1
@@ -19,27 +19,28 @@ class Collector(object):
        self.process_fn = policy.process_fn
        self._multi_env = isinstance(env, BaseVectorEnv)
        self._multi_buf = False  # buf is a list
-        # need multiple cache buffers only if contiguous in one buffer
+        # need multiple cache buffers only if storing in one buffer
        self._cached_buf = []
        if self._multi_env:
            self.env_num = len(env)
            if isinstance(self.buffer, list):
                assert len(self.buffer) == self.env_num,\
-                    '# of data buffer does not match the # of input env.'
+                    'The number of data buffer does not match the number of '\
+                    'input env.'
                self._multi_buf = True
-            elif isinstance(self.buffer, ReplayBuffer) and contiguous:
+            elif isinstance(self.buffer, ReplayBuffer):
                self._cached_buf = [
                    deepcopy(buffer) for _ in range(self.env_num)]
            else:
                raise TypeError('The buffer in data collector is invalid!')
        self.reset_env()
-        self.clear_buffer()
-        # state over batch is either a list, an np.ndarray, or torch.Tensor
+        self.reset_buffer()
+        # state over batch is either a list, an np.ndarray, or a torch.Tensor
        self.state = None
-        self.stat_reward = MovAvg()
-        self.stat_length = MovAvg()
+        self.stat_reward = MovAvg(stat_size)
+        self.stat_length = MovAvg(stat_size)

-    def clear_buffer(self):
+    def reset_buffer(self):
        if self._multi_buf:
            for b in self.buffer:
                b.reset()
@@ -57,6 +58,18 @@ class Collector(object):
        for b in self._cached_buf:
            b.reset()

+    def seed(self, seed=None):
+        if hasattr(self.env, 'seed'):
+            self.env.seed(seed)

+    def render(self):
+        if hasattr(self.env, 'render'):
+            self.env.render()

+    def close(self):
+        if hasattr(self.env, 'close'):
+            self.env.close()

    def _make_batch(data):
        if isinstance(data, np.ndarray):
            return data[None]
@@ -66,9 +79,10 @@ class Collector(object):
    def collect(self, n_step=0, n_episode=0):
        assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
            "One and only one collection number specification permitted!"
-        cur_step, cur_episode = 0, 0
+        cur_step = 0
+        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
        while True:
-            if self.multi_env:
+            if self._multi_env:
                batch_data = Batch(
                    obs=self._obs, act=self._act, rew=self._rew,
                    done=self._done, obs_next=None, info=self._info)
@@ -78,8 +92,9 @@ class Collector(object):
                    act=self._make_batch(self._act),
                    rew=self._make_batch(self._rew),
                    done=self._make_batch(self._done),
-                    obs_next=None, info=self._make_batch(self._info))
-            result = self.policy.act(batch_data, self.state)
+                    obs_next=None,
+                    info=self._make_batch(self._info))
+            result = self.policy(batch_data, self.state)
            self.state = result.state if hasattr(result, 'state') else None
            self._act = result.act
            obs_next, self._rew, self._done, self._info = self.env.step(
@@ -88,6 +103,9 @@ class Collector(object):
            self.reward += self._rew
            if self._multi_env:
                for i in range(self.env_num):
+                    if not self.env.is_reset_after_done()\
+                            and cur_episode[i] > 0:
+                        continue
                    data = {
                        'obs': self._obs[i], 'act': self._act[i],
                        'rew': self._rew[i], 'done': self._done[i],
@@ -101,7 +119,7 @@ class Collector(object):
                        self.buffer.add(**data)
                        cur_step += 1
                    if self._done[i]:
-                        cur_episode += 1
+                        cur_episode[i] += 1
                        self.stat_reward.add(self.reward[i])
                        self.stat_length.add(self.length[i])
                        self.reward[i], self.length[i] = 0, 0
@@ -111,12 +129,12 @@ class Collector(object):
                            self._cached_buf[i].reset()
                        if isinstance(self.state, list):
                            self.state[i] = None
-                        else:
+                        elif self.state is not None:
                            self.state[i] = self.state[i] * 0
                            if isinstance(self.state, torch.Tensor):
-                                # remove ref in torch (?)
+                                # remove ref count in pytorch (?)
                                self.state = self.state.detach()
-                if n_episode > 0 and cur_episode >= n_episode:
+                if n_episode > 0 and cur_episode.sum() >= n_episode:
                    break
            else:
                self.buffer.add(
@@ -141,13 +159,13 @@ class Collector(object):
            if batch_size > 0:
                lens = [len(b) for b in self.buffer]
                total = sum(lens)
-                ib = np.random.choice(
+                batch_index = np.random.choice(
                    total, batch_size, p=np.array(lens) / total)
            else:
-                ib = np.array([])
+                batch_index = np.array([])
            batch_data = Batch()
            for i, b in enumerate(self.buffer):
-                cur_batch = (ib == i).sum()
+                cur_batch = (batch_index == i).sum()
                if batch_size and cur_batch or batch_size <= 0:
                    batch, indice = b.sample(cur_batch)
                    batch = self.process_fn(batch, b, indice)
tianshou/env/wrapper.py (37 lines changed)
@@ -1,6 +1,6 @@
import numpy as np
-from abc import ABC
from collections import deque
+from abc import ABC, abstractmethod
from multiprocessing import Process, Pipe
try:
    import ray
@@ -63,6 +63,32 @@ class BaseVectorEnv(ABC):
        self.env_num = len(env_fns)
        self._reset_after_done = reset_after_done

+    def is_reset_after_done(self):
+        return self._reset_after_done

+    def __len__(self):
+        return self.env_num

+    @abstractmethod
+    def reset(self):
+        pass

+    @abstractmethod
+    def step(self, action):
+        pass

+    @abstractmethod
+    def seed(self, seed=None):
+        pass

+    @abstractmethod
+    def render(self):
+        pass

+    @abstractmethod
+    def close(self):
+        pass


class VectorEnv(BaseVectorEnv):
    """docstring for VectorEnv"""
@@ -71,9 +97,6 @@ class VectorEnv(BaseVectorEnv):
        super().__init__(env_fns, reset_after_done)
        self.envs = [_() for _ in env_fns]

-    def __len__(self):
-        return len(self.envs)

    def reset(self):
        return np.stack([e.reset() for e in self.envs])

@@ -148,9 +171,6 @@ class SubprocVectorEnv(BaseVectorEnv):
        for c in self.child_remote:
            c.close()

-    def __len__(self):
-        return self.env_num

    def step(self, action):
        assert len(action) == self.env_num
        for p, a in zip(self.parent_remote, action):
@@ -203,9 +223,6 @@ class RayVectorEnv(BaseVectorEnv):
            ray.remote(EnvWrapper).options(num_cpus=0).remote(e())
            for e in env_fns]

-    def __len__(self):
-        return self.env_num

    def step(self, action):
        assert len(action) == self.env_num
        result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
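With `__len__` hoisted into the base class and the per-env operations now declared `@abstractmethod`, a custom vector env only has to provide the five methods below. A minimal sketch of a conforming subclass, an assumption rather than code from this commit (it assumes `BaseVectorEnv` is exported from tianshou.env and omits the automatic reset-after-done handling that `VectorEnv`/`SubprocVectorEnv` perform):

import numpy as np
from tianshou.env import BaseVectorEnv

class ListVectorEnv(BaseVectorEnv):
    """Hypothetical minimal subclass: runs the envs sequentially in-process."""

    def __init__(self, env_fns, reset_after_done=False):
        super().__init__(env_fns, reset_after_done)  # sets env_num, _reset_after_done
        self.envs = [fn() for fn in env_fns]

    def reset(self):
        return np.stack([e.reset() for e in self.envs])

    def step(self, action):
        obs, rew, done, info = zip(*[e.step(a) for e, a in zip(self.envs, action)])
        return np.stack(obs), np.stack(rew), np.stack(done), list(info)

    def seed(self, seed=None):
        for e in self.envs:
            if hasattr(e, 'seed'):
                e.seed(seed)

    def render(self):
        for e in self.envs:
            e.render()

    def close(self):
        for e in self.envs:
            e.close()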
@@ -6,25 +6,21 @@ class BasePolicy(ABC):

    def __init__(self):
        super().__init__()
        self.model = None

    @abstractmethod
-    def act(self, batch, hidden_state=None):
+    def __call__(self, batch, hidden_state=None):
        # return Batch(policy, action, hidden)
        pass

    def train(self):
        pass

    def eval(self):
        pass

-    def reset(self):
+    @abstractmethod
+    def learn(self, batch):
        pass

    def process_fn(self, batch, buffer, indice):
        return batch

-    def sync_weights(self):
+    def sync_weight(self):
        pass

    def exploration(self):
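The Collector above calls `self.policy(batch_data, self.state)` and reads `result.act` (and `result.state` if present), while the training loop calls `policy.learn(...)`; those two entry points are what the base class now marks abstract. A throwaway sketch, not part of the commit, of the smallest conforming policy (it assumes `Batch` is importable from tianshou.data, as the buffer and collector code suggests):

import numpy as np
from tianshou.data import Batch
from tianshou.policy import BasePolicy

class RandomPolicy(BasePolicy):
    """Hypothetical example: uniform random discrete actions."""

    def __init__(self, action_num):
        super().__init__()
        self.action_num = action_num

    def __call__(self, batch, hidden_state=None):
        act = np.random.randint(self.action_num, size=len(batch.obs))
        return Batch(act=act, state=None)  # the Collector reads .act and .state

    def learn(self, batch):
        return 0.0  # nothing to optimize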
@@ -1,4 +1,5 @@
import torch
+import numpy as np
from torch import nn
from copy import deepcopy

@@ -9,25 +10,86 @@ from tianshou.policy import BasePolicy
class DQNPolicy(BasePolicy, nn.Module):
    """docstring for DQNPolicy"""

-    def __init__(self, model, discount_factor=0.99, estimation_step=1,
+    def __init__(self, model, optim, loss,
+                 discount_factor=0.99,
+                 estimation_step=1,
                 use_target_network=True):
        super().__init__()
        self.model = model
+        self.optim = optim
+        self.loss = loss
+        self.eps = 0
        assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
        self._gamma = discount_factor
        assert estimation_step > 0, 'estimation_step should greater than 0'
        self._n_step = estimation_step
        self._target = use_target_network
        if use_target_network:
            self.model_old = deepcopy(self.model)
            self.model_old.eval()

-    def act(self, batch, hidden_state=None):
-        batch_result = Batch()
-        return batch_result
+    def __call__(self, batch, hidden_state=None,
+                 model='model', input='obs', eps=None):
+        model = getattr(self, model)
+        obs = getattr(batch, input)
+        q, h = model(obs, hidden_state=hidden_state, info=batch.info)
+        act = q.max(dim=1)[1].detach().cpu().numpy()
+        # add eps to act
+        for i in range(len(q)):
+            if np.random.rand() < self.eps:
+                act[i] = np.random.randint(q.shape[1])
+        return Batch(Q=q, act=act, state=h)

-    def sync_weights(self):
-        if self._use_target_network:
-            for old, new in zip(
-                    self.model_old.parameters(), self.model.parameters()):
-                old.data.copy_(new.data)
+    def set_eps(self, eps):
+        self.eps = eps

+    def train(self):
+        self.training = True
+        self.model.train()

+    def eval(self):
+        self.training = False
+        self.model.eval()

+    def sync_weight(self):
+        if self._target:
+            self.model_old.load_state_dict(self.model.state_dict())

+    def process_fn(self, batch, buffer, indice):
+        returns = np.zeros_like(indice)
+        gammas = np.zeros_like(indice) + self._n_step
+        for n in range(self._n_step - 1, -1, -1):
+            now = (indice + n) % len(buffer)
+            gammas[buffer.done[now] > 0] = n
+            returns[buffer.done[now] > 0] = 0
+            returns = buffer.rew[now] + self._gamma * returns
+        terminal = (indice + self._n_step - 1) % len(buffer)
+        if self._target:
+            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
+            a = self(buffer[terminal], input='obs_next').act
+            target_q = self(
+                buffer[terminal], model='model_old', input='obs_next').Q
+            if isinstance(target_q, torch.Tensor):
+                target_q = target_q.detach().cpu().numpy()
+            target_q = target_q[np.arange(len(a)), a]
+        else:
+            target_q = self(buffer[terminal], input='obs_next').Q
+            if isinstance(target_q, torch.Tensor):
+                target_q = target_q.detach().cpu().numpy()
+            target_q = target_q.max(axis=1)
+        target_q[gammas != self._n_step] = 0
+        returns += (self._gamma ** gammas) * target_q
+        batch.update(returns=returns)
+        return batch

+    def learn(self, batch):
+        self.optim.zero_grad()
+        q = self(batch).Q
+        q = q[np.arange(len(q)), batch.act]
+        r = batch.returns
+        if isinstance(r, np.ndarray):
+            r = torch.tensor(r, device=q.device, dtype=q.dtype)
+        loss = self.loss(q, r)
+        loss.backward()
+        self.optim.step()
+        return loss.detach().cpu().numpy()
@@ -21,3 +21,11 @@ class MovAvg(object):
        if len(self.cache) == 0:
            return 0
        return np.mean(self.cache)

+    def mean(self):
+        return self.get()

+    def std(self):
+        if len(self.cache) == 0:
+            return 0
+        return np.std(self.cache)