fix some bugs

This commit is contained in:
Trinkle23897 2020-03-16 11:11:29 +08:00
parent 5983c6b33d
commit cef5de8b83
9 changed files with 144 additions and 36 deletions

View File

@ -1,3 +1,4 @@
import pytest
import numpy as np import numpy as np
from tianshou.data import Batch from tianshou.data import Batch
@ -10,6 +11,9 @@ def test_batch():
batch.append(batch) batch.append(batch)
assert batch.obs == [1, 1] assert batch.obs == [1, 1]
assert batch.np.shape == (6, 4) assert batch.np.shape == (6, 4)
assert batch[0].obs == batch[1].obs
with pytest.raises(IndexError):
batch[2]
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -8,16 +8,23 @@ else: # pytest
def test_replaybuffer(size=10, bufsize=20): def test_replaybuffer(size=10, bufsize=20):
env = MyTestEnv(size) env = MyTestEnv(size)
buf = ReplayBuffer(bufsize) buf = ReplayBuffer(bufsize)
buf2 = ReplayBuffer(bufsize)
obs = env.reset() obs = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 15 action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, a in enumerate(action_list): for i, a in enumerate(action_list):
obs_next, rew, done, info = env.step(a) obs_next, rew, done, info = env.step(a)
buf.add(obs, a, rew, done, obs_next, info) buf.add(obs, a, rew, done, obs_next, info)
obs = obs_next
assert len(buf) == min(bufsize, i + 1), print(len(buf), i) assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
data, indice = buf.sample(bufsize * 2) data, indice = buf.sample(bufsize * 2)
assert (indice < len(buf)).all() assert (indice < len(buf)).all()
assert (data.obs < size).all() assert (data.obs < size).all()
assert (0 <= data.done).all() and (data.done <= 1).all() assert (0 <= data.done).all() and (data.done <= 1).all()
assert len(buf) > len(buf2)
buf2.update(buf)
assert len(buf) == len(buf2)
assert buf2[0].obs == buf[5].obs
assert buf2[-1].obs == buf[4].obs
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,4 +1,5 @@
import gym import gym
import time
import tqdm import tqdm
import torch import torch
import argparse import argparse
@ -36,7 +37,7 @@ class Net(nn.Module):
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=0) parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--eps-test', type=float, default=0.05) parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1) parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000) parser.add_argument('--buffer-size', type=int, default=20000)
@ -65,6 +66,7 @@ def test_dqn(args=get_args()):
train_envs = SubprocVectorEnv( train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)], [lambda: gym.make(args.task) for _ in range(args.training_num)],
reset_after_done=True) reset_after_done=True)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv( test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)], [lambda: gym.make(args.task) for _ in range(args.test_num)],
reset_after_done=False) reset_after_done=False)
@ -124,11 +126,26 @@ def test_dqn(args=get_args()):
if best_reward < result['reward']: if best_reward < result['reward']:
best_reward = result['reward'] best_reward = result['reward']
best_epoch = epoch best_epoch = epoch
print(f'Epoch #{epoch + 1} reward: {result["reward"]:.6f}, ' print(f'Epoch #{epoch + 1} test_reward: {result["reward"]:.6f}, '
f'best_reward: {best_reward:.6f} in #{best_epoch}') f'best_reward: {best_reward:.6f} in #{best_epoch}')
if args.task == 'CartPole-v0' and best_reward >= 200: if args.task == 'CartPole-v0' and best_reward >= 200:
break break
assert best_reward >= 200 assert best_reward >= 200
if __name__ == '__main__':
# let's watch its performance!
env = gym.make(args.task)
obs = env.reset()
done = False
total = 0
while not done:
q, _ = net([obs])
action = q.max(dim=1)[1]
obs, rew, done, info = env.step(action[0].detach().cpu().numpy())
total += rew
env.render()
time.sleep(1 / 100)
env.close()
print(f'Final test: {total}')
return best_reward return best_reward

View File

@ -1,5 +1,6 @@
import gym import gym
import time import time
import pytest
import numpy as np import numpy as np
from tianshou.env import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv from tianshou.env import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv
@ -8,24 +9,28 @@ class MyTestEnv(gym.Env):
def __init__(self, size, sleep=0): def __init__(self, size, sleep=0):
self.size = size self.size = size
self.sleep = sleep self.sleep = sleep
self.index = 0 self.reset()
def reset(self): def reset(self):
self.done = False
self.index = 0 self.index = 0
return self.index return self.index
def step(self, action): def step(self, action):
if self.done:
raise ValueError('step after done !!!')
if self.sleep > 0: if self.sleep > 0:
time.sleep(self.sleep) time.sleep(self.sleep)
if self.index == self.size: if self.index == self.size:
self.done = True
return self.index, 0, True, {} return self.index, 0, True, {}
if action == 0: if action == 0:
self.index = max(self.index - 1, 0) self.index = max(self.index - 1, 0)
return self.index, 0, False, {} return self.index, 0, False, {}
elif action == 1: elif action == 1:
self.index += 1 self.index += 1
finished = self.index == self.size self.done = self.index == self.size
return self.index, int(finished), finished, {} return self.index, int(self.done), self.done, {}
def test_framestack(k=4, size=10): def test_framestack(k=4, size=10):
@ -47,19 +52,21 @@ def test_framestack(k=4, size=10):
obs, rew, done, info = fsenv.step(1) obs, rew, done, info = fsenv.step(1)
assert abs(obs - np.array([7, 8, 9, 10])).sum() == 0 assert abs(obs - np.array([7, 8, 9, 10])).sum() == 0
assert (rew, done) == (1, True) assert (rew, done) == (1, True)
obs, rew, done, info = fsenv.step(0) with pytest.raises(ValueError):
assert abs(obs - np.array([8, 9, 10, 10])).sum() == 0 obs, rew, done, info = fsenv.step(0)
assert (rew, done) == (0, True) # assert abs(obs - np.array([8, 9, 10, 10])).sum() == 0
# assert (rew, done) == (0, True)
fsenv.close() fsenv.close()
def test_vecenv(verbose=False, size=10, num=8, sleep=0.001): def test_vecenv(size=10, num=8, sleep=0.001):
verbose = __name__ == '__main__'
env_fns = [lambda: MyTestEnv(size=size, sleep=sleep) for _ in range(num)] env_fns = [lambda: MyTestEnv(size=size, sleep=sleep) for _ in range(num)]
venv = [ venv = [
VectorEnv(env_fns, reset_after_done=True), VectorEnv(env_fns, reset_after_done=True),
SubprocVectorEnv(env_fns, reset_after_done=True), SubprocVectorEnv(env_fns, reset_after_done=True),
] ]
if __name__ == '__main__': if verbose:
venv.append(RayVectorEnv(env_fns, reset_after_done=True)) venv.append(RayVectorEnv(env_fns, reset_after_done=True))
for v in venv: for v in venv:
v.seed() v.seed()
@ -86,6 +93,40 @@ def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
v.close() v.close()
def test_vecenv2():
verbose = __name__ == '__main__'
env_fns = [
lambda: MyTestEnv(size=1),
lambda: MyTestEnv(size=2),
lambda: MyTestEnv(size=3),
lambda: MyTestEnv(size=4),
]
num = len(env_fns)
venv = [
VectorEnv(env_fns, reset_after_done=False),
SubprocVectorEnv(env_fns, reset_after_done=False),
]
if verbose:
venv.append(RayVectorEnv(env_fns, reset_after_done=False))
for v in venv:
v.seed()
o = [v.reset() for v in venv]
action_list = [1] * 6
for i, a in enumerate(action_list):
o = [v.step([a] * num) for v in venv]
if verbose:
print(o[0])
print(o[1])
print(o[2])
print('---')
for i in zip(*o):
for j in range(1, len(i)):
assert (i[0] == i[j]).all()
for v in venv:
v.close()
if __name__ == '__main__': if __name__ == '__main__':
test_framestack() test_framestack()
test_vecenv(True) test_vecenv()
test_vecenv2()

View File

@ -13,7 +13,7 @@ class Batch(object):
b = Batch() b = Batch()
for k in self.__dict__.keys(): for k in self.__dict__.keys():
if self.__dict__[k] is not None: if self.__dict__[k] is not None:
b.update(k=self.__dict__[k][index]) b.update(**{k: self.__dict__[k][index]})
return b return b
def update(self, **kwargs): def update(self, **kwargs):

View File

@ -27,10 +27,14 @@ class ReplayBuffer(object):
self.__dict__[name][self._index] = inst self.__dict__[name][self._index] = inst
def update(self, buffer): def update(self, buffer):
for i in range(len(buffer)): i = begin = buffer._index % len(buffer)
while True:
self.add( self.add(
buffer.obs[i], buffer.act[i], buffer.rew[i], buffer.obs[i], buffer.act[i], buffer.rew[i],
buffer.done[i], buffer.obs_next[i], buffer.info[i]) buffer.done[i], buffer.obs_next[i], buffer.info[i])
i = (i + 1) % len(buffer)
if i == begin:
break
def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None): def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
''' '''

View File

@ -1,3 +1,4 @@
import time
import torch import torch
import numpy as np import numpy as np
from copy import deepcopy from copy import deepcopy
@ -70,13 +71,13 @@ class Collector(object):
if hasattr(self.env, 'close'): if hasattr(self.env, 'close'):
self.env.close() self.env.close()
def _make_batch(data): def _make_batch(self, data):
if isinstance(data, np.ndarray): if isinstance(data, np.ndarray):
return data[None] return data[None]
else: else:
return [data] return [data]
def collect(self, n_step=0, n_episode=0): def collect(self, n_step=0, n_episode=0, render=0):
assert sum([(n_step > 0), (n_episode > 0)]) == 1,\ assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
"One and only one collection number specification permitted!" "One and only one collection number specification permitted!"
cur_step = 0 cur_step = 0
@ -98,7 +99,10 @@ class Collector(object):
self.state = result.state if hasattr(result, 'state') else None self.state = result.state if hasattr(result, 'state') else None
self._act = result.act self._act = result.act
obs_next, self._rew, self._done, self._info = self.env.step( obs_next, self._rew, self._done, self._info = self.env.step(
self._act) self._act if self._multi_env else self._act[0])
if render > 0:
self.env.render()
time.sleep(render)
self.length += 1 self.length += 1
self.reward += self._rew self.reward += self._rew
if self._multi_env: if self._multi_env:
@ -147,6 +151,7 @@ class Collector(object):
self.stat_length.add(self.length) self.stat_length.add(self.length)
self.reward, self.length = 0, 0 self.reward, self.length = 0, 0
self.state = None self.state = None
self._obs = self.env.reset()
if n_episode > 0 and cur_episode >= n_episode: if n_episode > 0 and cur_episode >= n_episode:
break break
if n_step > 0 and cur_step >= n_step: if n_step > 0 and cur_step >= n_step:

View File

@ -62,6 +62,7 @@ class BaseVectorEnv(ABC):
self._env_fns = env_fns self._env_fns = env_fns
self.env_num = len(env_fns) self.env_num = len(env_fns)
self._reset_after_done = reset_after_done self._reset_after_done = reset_after_done
self._done = np.zeros(self.env_num)
def is_reset_after_done(self): def is_reset_after_done(self):
return self._reset_after_done return self._reset_after_done
@ -98,17 +99,26 @@ class VectorEnv(BaseVectorEnv):
self.envs = [_() for _ in env_fns] self.envs = [_() for _ in env_fns]
def reset(self): def reset(self):
return np.stack([e.reset() for e in self.envs]) self._done = np.zeros(self.env_num)
self._obs = np.stack([e.reset() for e in self.envs])
return self._obs
def step(self, action): def step(self, action):
assert len(action) == self.env_num assert len(action) == self.env_num
result = zip(*[e.step(a) for e, a in zip(self.envs, action)]) result = []
obs, rew, done, info = result for i, e in enumerate(self.envs):
if self._reset_after_done and sum(done): if not self.is_reset_after_done() and self._done[i]:
obs = np.stack(obs) result.append([
for i in np.where(done)[0]: self._obs[i], self._rew[i], self._done[i], self._info[i]])
obs[i] = self.envs[i].reset() else:
return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info) result.append(e.step(action[i]))
self._obs, self._rew, self._done, self._info = zip(*result)
if self.is_reset_after_done() and sum(self._done):
self._obs = np.stack(self._obs)
for i in np.where(self._done)[0]:
self._obs[i] = self.envs[i].reset()
return np.stack(self._obs), np.stack(self._rew),\
np.stack(self._done), np.stack(self._info)
def seed(self, seed=None): def seed(self, seed=None):
if np.isscalar(seed) or seed is None: if np.isscalar(seed) or seed is None:
@ -130,15 +140,18 @@ class VectorEnv(BaseVectorEnv):
def worker(parent, p, env_fn_wrapper, reset_after_done): def worker(parent, p, env_fn_wrapper, reset_after_done):
parent.close() parent.close()
env = env_fn_wrapper.data() env = env_fn_wrapper.data()
done = False
while True: while True:
cmd, data = p.recv() cmd, data = p.recv()
if cmd == 'step': if cmd == 'step':
obs, rew, done, info = env.step(data) if reset_after_done or not done:
obs, rew, done, info = env.step(data)
if reset_after_done and done: if reset_after_done and done:
# s_ is useless when episode finishes # s_ is useless when episode finishes
obs = env.reset() obs = env.reset()
p.send([obs, rew, done, info]) p.send([obs, rew, done, info])
elif cmd == 'reset': elif cmd == 'reset':
done = False
p.send(env.reset()) p.send(env.reset())
elif cmd == 'close': elif cmd == 'close':
p.close() p.close()
@ -225,21 +238,36 @@ class RayVectorEnv(BaseVectorEnv):
def step(self, action): def step(self, action):
assert len(action) == self.env_num assert len(action) == self.env_num
result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)] result_obj = []
obs, rew, done, info = zip(*[ray.get(r) for r in result_obj]) for i, e in enumerate(self.envs):
if self._reset_after_done and sum(done): if not self.is_reset_after_done() and self._done[i]:
obs = np.stack(obs) result_obj.append(None)
index = np.where(done)[0] else:
result_obj.append(e.step.remote(action[i]))
result = []
for i, r in enumerate(result_obj):
if r is None:
result.append([
self._obs[i], self._rew[i], self._done[i], self._info[i]])
else:
result.append(ray.get(r))
self._obs, self._rew, self._done, self._info = zip(*result)
if self.is_reset_after_done() and sum(self._done):
self._obs = np.stack(self._obs)
index = np.where(self._done)[0]
result_obj = [] result_obj = []
for i in range(len(index)): for i in range(len(index)):
result_obj.append(self.envs[index[i]].reset.remote()) result_obj.append(self.envs[index[i]].reset.remote())
for i in range(len(index)): for i in range(len(index)):
obs[index[i]] = ray.get(result_obj[i]) self._obs[index[i]] = ray.get(result_obj[i])
return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info) return np.stack(self._obs), np.stack(self._rew),\
np.stack(self._done), np.stack(self._info)
def reset(self): def reset(self):
self._done = np.zeros(self.env_num)
result_obj = [e.reset.remote() for e in self.envs] result_obj = [e.reset.remote() for e in self.envs]
return np.stack([ray.get(r) for r in result_obj]) self._obs = np.stack([ray.get(r) for r in result_obj])
return self._obs
def seed(self, seed=None): def seed(self, seed=None):
if not hasattr(self.envs[0], 'seed'): if not hasattr(self.envs[0], 'seed'):

View File

@ -35,8 +35,10 @@ class DQNPolicy(BasePolicy, nn.Module):
q, h = model(obs, hidden_state=hidden_state, info=batch.info) q, h = model(obs, hidden_state=hidden_state, info=batch.info)
act = q.max(dim=1)[1].detach().cpu().numpy() act = q.max(dim=1)[1].detach().cpu().numpy()
# add eps to act # add eps to act
if eps is None:
eps = self.eps
for i in range(len(q)): for i in range(len(q)):
if np.random.rand() < self.eps: if np.random.rand() < eps:
act[i] = np.random.randint(q.shape[1]) act[i] = np.random.randint(q.shape[1])
return Batch(Q=q, act=act, state=h) return Batch(Q=q, act=act, state=h)
@ -66,7 +68,7 @@ class DQNPolicy(BasePolicy, nn.Module):
terminal = (indice + self._n_step - 1) % len(buffer) terminal = (indice + self._n_step - 1) % len(buffer)
if self._target: if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *))) # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
a = self(buffer[terminal], input='obs_next').act a = self(buffer[terminal], input='obs_next', eps=0).act
target_q = self( target_q = self(
buffer[terminal], model='model_old', input='obs_next').Q buffer[terminal], model='model_old', input='obs_next').Q
if isinstance(target_q, torch.Tensor): if isinstance(target_q, torch.Tensor):