fix several bugs: Batch.__getitem__ field naming, ReplayBuffer.update copy order, step-after-done handling in the vectorized envs, explicit eps in DQNPolicy, render support in Collector.collect

Trinkle23897 2020-03-16 11:11:29 +08:00
parent 5983c6b33d
commit cef5de8b83
9 changed files with 144 additions and 36 deletions

View File

@@ -1,3 +1,4 @@
+import pytest
import numpy as np
from tianshou.data import Batch
@@ -10,6 +11,9 @@ def test_batch():
batch.append(batch)
assert batch.obs == [1, 1]
assert batch.np.shape == (6, 4)
+assert batch[0].obs == batch[1].obs
+with pytest.raises(IndexError):
+batch[2]
if __name__ == '__main__':

View File

@@ -8,16 +8,23 @@ else: # pytest
def test_replaybuffer(size=10, bufsize=20):
env = MyTestEnv(size)
buf = ReplayBuffer(bufsize)
+buf2 = ReplayBuffer(bufsize)
obs = env.reset()
-action_list = [1] * 5 + [0] * 10 + [1] * 15
+action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, a in enumerate(action_list):
obs_next, rew, done, info = env.step(a)
buf.add(obs, a, rew, done, obs_next, info)
obs = obs_next
assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
data, indice = buf.sample(bufsize * 2)
assert (indice < len(buf)).all()
assert (data.obs < size).all()
assert (0 <= data.done).all() and (data.done <= 1).all()
+assert len(buf) > len(buf2)
+buf2.update(buf)
+assert len(buf) == len(buf2)
+assert buf2[0].obs == buf[5].obs
+assert buf2[-1].obs == buf[4].obs
if __name__ == '__main__':
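Note on the new assertions above, a hand-checked walkthrough rather than part of the diff: the test now pushes len(action_list) = 5 + 10 + 10 = 25 transitions into a ring buffer of size 20, so the internal write index wraps to 25 % 20 = 5. The oldest surviving transition therefore sits at buf[5] and the newest at buf[4]; since buf2.update(buf) copies oldest-first (see the ReplayBuffer.update hunk below), buf2[0] must equal buf[5] and buf2[-1] must equal buf[4].

# hand-checked index arithmetic (illustrative only):
pushes, bufsize = 25, 20
assert pushes % bufsize == 5  # the write index is the slot of the oldest entry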

View File

@ -1,4 +1,5 @@
import gym
+import time
import tqdm
import torch
import argparse
@@ -36,7 +37,7 @@ class Net(nn.Module):
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
-parser.add_argument('--seed', type=int, default=0)
+parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
@@ -65,6 +66,7 @@ def test_dqn(args=get_args()):
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
reset_after_done=True)
+# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
reset_after_done=False)
@@ -124,11 +126,26 @@ def test_dqn(args=get_args()):
if best_reward < result['reward']:
best_reward = result['reward']
best_epoch = epoch
-print(f'Epoch #{epoch + 1} reward: {result["reward"]:.6f}, '
+print(f'Epoch #{epoch + 1} test_reward: {result["reward"]:.6f}, '
f'best_reward: {best_reward:.6f} in #{best_epoch}')
if args.task == 'CartPole-v0' and best_reward >= 200:
break
assert best_reward >= 200
if __name__ == '__main__':
+# let's watch its performance!
+env = gym.make(args.task)
+obs = env.reset()
+done = False
+total = 0
+while not done:
+q, _ = net([obs])
+action = q.max(dim=1)[1]
+obs, rew, done, info = env.step(action[0].detach().cpu().numpy())
+total += rew
+env.render()
+time.sleep(1 / 100)
+env.close()
+print(f'Final test: {total}')
return best_reward
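Note on the watch loop above: the network expects a batch, so the single observation is wrapped in a list before the forward pass, and the greedy action is the argmax over the action dimension of the returned Q-values, converted back to a plain numpy value before env.step. A minimal torch sketch of that selection (illustrative, not tianshou API):

import torch

q = torch.tensor([[0.1, 0.9]])  # Q-values for one observation, two actions
action = q.max(dim=1)[1]        # argmax over the action dimension, shape (1,)
assert action[0].item() == 1    # greedy choice; no epsilon while watching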

View File

@@ -1,5 +1,6 @@
import gym
+import time
import pytest
import numpy as np
from tianshou.env import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv
@@ -8,24 +9,28 @@ class MyTestEnv(gym.Env):
def __init__(self, size, sleep=0):
self.size = size
self.sleep = sleep
+self.index = 0
self.reset()
def reset(self):
+self.done = False
self.index = 0
return self.index
def step(self, action):
+if self.done:
+raise ValueError('step after done !!!')
if self.sleep > 0:
time.sleep(self.sleep)
if self.index == self.size:
self.done = True
return self.index, 0, True, {}
if action == 0:
self.index = max(self.index - 1, 0)
return self.index, 0, False, {}
elif action == 1:
self.index += 1
-finished = self.index == self.size
-return self.index, int(finished), finished, {}
+self.done = self.index == self.size
+return self.index, int(self.done), self.done, {}
def test_framestack(k=4, size=10):
@@ -47,19 +52,21 @@ def test_framestack(k=4, size=10):
obs, rew, done, info = fsenv.step(1)
assert abs(obs - np.array([7, 8, 9, 10])).sum() == 0
assert (rew, done) == (1, True)
-obs, rew, done, info = fsenv.step(0)
-assert abs(obs - np.array([8, 9, 10, 10])).sum() == 0
-assert (rew, done) == (0, True)
+with pytest.raises(ValueError):
+obs, rew, done, info = fsenv.step(0)
+# assert abs(obs - np.array([8, 9, 10, 10])).sum() == 0
+# assert (rew, done) == (0, True)
fsenv.close()
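The old post-done assertions are kept only as comments because MyTestEnv now raises as soon as step() is called on a finished episode (the guard added at the top of step above). A stripped-down sketch of that guard, using only the standard library and an illustrative one-step toy env:

class GuardedEnv:
    """Toy env: the episode ends after one step; stepping again is an error."""
    def __init__(self):
        self.done = False
    def reset(self):
        self.done = False
        return 0
    def step(self, action):
        if self.done:
            raise ValueError('step after done !!!')
        self.done = True
        return 0, 1, True, {}

env = GuardedEnv()
env.reset()
env.step(0)
try:
    env.step(0)  # must raise, exactly like fsenv.step(0) above
except ValueError:
    pass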
-def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
+def test_vecenv(size=10, num=8, sleep=0.001):
+verbose = __name__ == '__main__'
env_fns = [lambda: MyTestEnv(size=size, sleep=sleep) for _ in range(num)]
venv = [
VectorEnv(env_fns, reset_after_done=True),
SubprocVectorEnv(env_fns, reset_after_done=True),
]
-if __name__ == '__main__':
+if verbose:
venv.append(RayVectorEnv(env_fns, reset_after_done=True))
for v in venv:
v.seed()
@@ -86,6 +93,40 @@ def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
v.close()
+def test_vecenv2():
+verbose = __name__ == '__main__'
+env_fns = [
+lambda: MyTestEnv(size=1),
+lambda: MyTestEnv(size=2),
+lambda: MyTestEnv(size=3),
+lambda: MyTestEnv(size=4),
+]
+num = len(env_fns)
+venv = [
+VectorEnv(env_fns, reset_after_done=False),
+SubprocVectorEnv(env_fns, reset_after_done=False),
+]
+if verbose:
+venv.append(RayVectorEnv(env_fns, reset_after_done=False))
+for v in venv:
+v.seed()
+o = [v.reset() for v in venv]
+action_list = [1] * 6
+for i, a in enumerate(action_list):
+o = [v.step([a] * num) for v in venv]
+if verbose:
+print(o[0])
+print(o[1])
+print(o[2])
+print('---')
+for i in zip(*o):
+for j in range(1, len(i)):
+assert (i[0] == i[j]).all()
+for v in venv:
+v.close()
if __name__ == '__main__':
test_framestack()
-test_vecenv(True)
+test_vecenv()
+test_vecenv2()
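What test_vecenv2 checks, hand-computed for envs of size 1..4 stepped six times with action 1 (a walkthrough, not part of the diff): under reset_after_done=False every backend must freeze a finished env on its terminal transition, and all backends must agree elementwise.

# step 1: obs (1, 1, 1, 1)  rew (1, 0, 0, 0)  done (T, F, F, F)
# step 2: obs (1, 2, 2, 2)  rew (1, 1, 0, 0)  done (T, T, F, F)
# step 3: obs (1, 2, 3, 3)  rew (1, 1, 1, 0)  done (T, T, T, F)
# step 4: obs (1, 2, 3, 4)  rew (1, 1, 1, 1)  done (T, T, T, T)
# steps 5-6: identical to step 4 -- the cached terminal tuples are replayed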

View File

@@ -13,7 +13,7 @@ class Batch(object):
b = Batch()
for k in self.__dict__.keys():
if self.__dict__[k] is not None:
-b.update(k=self.__dict__[k][index])
+b.update(**{k: self.__dict__[k][index]})
return b
def update(self, **kwargs):
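Why the one-line change above matters: b.update(k=self.__dict__[k][index]) passes the literal keyword name k, so every field of the sliced batch landed under a single attribute named k; unpacking a dict forwards the real field name. A plain-Python illustration:

def update(**kwargs):
    return kwargs

field = 'obs'
assert update(field=1) == {'field': 1}     # the bug: literal keyword name
assert update(**{field: 1}) == {'obs': 1}  # the fix: the actual field name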

View File

@@ -27,10 +27,14 @@ class ReplayBuffer(object):
self.__dict__[name][self._index] = inst
def update(self, buffer):
-for i in range(len(buffer)):
+i = begin = buffer._index % len(buffer)
+while True:
self.add(
buffer.obs[i], buffer.act[i], buffer.rew[i],
buffer.done[i], buffer.obs_next[i], buffer.info[i])
+i = (i + 1) % len(buffer)
+if i == begin:
+break
def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
'''
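The rewritten loop starts at buffer._index, which in a full ring buffer is where the oldest transition lives, and walks exactly once around the ring, so transitions are copied in chronological order instead of slot order. A standalone sketch of the same traversal over a plain list (illustrative names, not tianshou API):

def circular_copy(src, write_index):
    """Return the contents of a full ring buffer, oldest entry first."""
    out = []
    i = begin = write_index % len(src)
    while True:
        out.append(src[i])
        i = (i + 1) % len(src)
        if i == begin:
            break
    return out

# a ring of size 5 after 7 pushes: slots 0-1 hold items 5-6, slots 2-4 hold 2-4
assert circular_copy([5, 6, 2, 3, 4], write_index=7 % 5) == [2, 3, 4, 5, 6]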

View File

@@ -1,3 +1,4 @@
+import time
import torch
import numpy as np
from copy import deepcopy
@@ -70,13 +71,13 @@ class Collector(object):
if hasattr(self.env, 'close'):
self.env.close()
-def _make_batch(data):
+def _make_batch(self, data):
if isinstance(data, np.ndarray):
return data[None]
else:
return [data]
-def collect(self, n_step=0, n_episode=0):
+def collect(self, n_step=0, n_episode=0, render=0):
assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
"One and only one collection number specification permitted!"
cur_step = 0
@@ -98,7 +99,10 @@ class Collector(object):
self.state = result.state if hasattr(result, 'state') else None
self._act = result.act
obs_next, self._rew, self._done, self._info = self.env.step(
-self._act)
+self._act if self._multi_env else self._act[0])
+if render > 0:
+self.env.render()
+time.sleep(render)
self.length += 1
self.reward += self._rew
if self._multi_env:
@@ -147,6 +151,7 @@ class Collector(object):
self.stat_length.add(self.length)
self.reward, self.length = 0, 0
self.state = None
+self._obs = self.env.reset()
if n_episode > 0 and cur_episode >= n_episode:
break
if n_step > 0 and cur_step >= n_step:
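Three fixes are bundled in this file: _make_batch was missing its self parameter, collect() gains a render delay so a run can be watched, in single-env mode the batched action is unwrapped (self._act[0]) before stepping, and the env is reset once an episode finishes. A sketch of the batch-wrapping helper, assuming only numpy arrays and scalars pass through it:

import numpy as np

def make_batch(data):
    """Give a single transition a leading batch axis of size one."""
    if isinstance(data, np.ndarray):
        return data[None]  # shape (1, ...) view, no copy
    return [data]

assert make_batch(np.zeros(4)).shape == (1, 4)
assert make_batch(3) == [3]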

View File

@@ -62,6 +62,7 @@ class BaseVectorEnv(ABC):
self._env_fns = env_fns
self.env_num = len(env_fns)
self._reset_after_done = reset_after_done
+self._done = np.zeros(self.env_num)
def is_reset_after_done(self):
return self._reset_after_done
@@ -98,17 +99,26 @@ class VectorEnv(BaseVectorEnv):
self.envs = [_() for _ in env_fns]
def reset(self):
-return np.stack([e.reset() for e in self.envs])
+self._done = np.zeros(self.env_num)
+self._obs = np.stack([e.reset() for e in self.envs])
+return self._obs
def step(self, action):
assert len(action) == self.env_num
-result = zip(*[e.step(a) for e, a in zip(self.envs, action)])
-obs, rew, done, info = result
-if self._reset_after_done and sum(done):
-obs = np.stack(obs)
-for i in np.where(done)[0]:
-obs[i] = self.envs[i].reset()
-return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
+result = []
+for i, e in enumerate(self.envs):
+if not self.is_reset_after_done() and self._done[i]:
+result.append([
+self._obs[i], self._rew[i], self._done[i], self._info[i]])
+else:
+result.append(e.step(action[i]))
+self._obs, self._rew, self._done, self._info = zip(*result)
+if self.is_reset_after_done() and sum(self._done):
+self._obs = np.stack(self._obs)
+for i in np.where(self._done)[0]:
+self._obs[i] = self.envs[i].reset()
+return np.stack(self._obs), np.stack(self._rew),\
+np.stack(self._done), np.stack(self._info)
def seed(self, seed=None):
if np.isscalar(seed) or seed is None:
@@ -130,15 +140,18 @@ class VectorEnv(BaseVectorEnv):
def worker(parent, p, env_fn_wrapper, reset_after_done):
parent.close()
env = env_fn_wrapper.data()
+done = False
while True:
cmd, data = p.recv()
if cmd == 'step':
-obs, rew, done, info = env.step(data)
+if reset_after_done or not done:
+obs, rew, done, info = env.step(data)
if reset_after_done and done:
# s_ is useless when episode finishes
obs = env.reset()
p.send([obs, rew, done, info])
elif cmd == 'reset':
+done = False
p.send(env.reset())
elif cmd == 'close':
p.close()
@@ -225,21 +238,36 @@ class RayVectorEnv(BaseVectorEnv):
def step(self, action):
assert len(action) == self.env_num
-result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
-obs, rew, done, info = zip(*[ray.get(r) for r in result_obj])
-if self._reset_after_done and sum(done):
-obs = np.stack(obs)
-index = np.where(done)[0]
+result_obj = []
+for i, e in enumerate(self.envs):
+if not self.is_reset_after_done() and self._done[i]:
+result_obj.append(None)
+else:
+result_obj.append(e.step.remote(action[i]))
+result = []
+for i, r in enumerate(result_obj):
+if r is None:
+result.append([
+self._obs[i], self._rew[i], self._done[i], self._info[i]])
+else:
+result.append(ray.get(r))
+self._obs, self._rew, self._done, self._info = zip(*result)
+if self.is_reset_after_done() and sum(self._done):
+self._obs = np.stack(self._obs)
+index = np.where(self._done)[0]
result_obj = []
for i in range(len(index)):
result_obj.append(self.envs[index[i]].reset.remote())
for i in range(len(index)):
-obs[index[i]] = ray.get(result_obj[i])
-return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
+self._obs[index[i]] = ray.get(result_obj[i])
+return np.stack(self._obs), np.stack(self._rew),\
+np.stack(self._done), np.stack(self._info)
def reset(self):
+self._done = np.zeros(self.env_num)
result_obj = [e.reset.remote() for e in self.envs]
-return np.stack([ray.get(r) for r in result_obj])
+self._obs = np.stack([ray.get(r) for r in result_obj])
+return self._obs
def seed(self, seed=None):
if not hasattr(self.envs[0], 'seed'):
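The common idea across VectorEnv, the subprocess worker, and RayVectorEnv above: each backend now remembers the last (obs, rew, done, info) per env, and when reset_after_done is off a finished env is no longer stepped; its cached terminal tuple is replayed instead. A single-env sketch of that freeze logic (an illustrative wrapper, not the tianshou classes):

class FreezeOnDone:
    def __init__(self, env, reset_after_done=False):
        self.env = env
        self.reset_after_done = reset_after_done
        self.last = None  # cached (obs, rew, done, info)

    def reset(self):
        self.last = None
        return self.env.reset()

    def step(self, action):
        if not self.reset_after_done and self.last is not None and self.last[2]:
            return self.last  # episode over: replay the terminal transition
        self.last = self.env.step(action)
        return self.last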

View File

@@ -35,8 +35,10 @@ class DQNPolicy(BasePolicy, nn.Module):
q, h = model(obs, hidden_state=hidden_state, info=batch.info)
act = q.max(dim=1)[1].detach().cpu().numpy()
# add eps to act
+if eps is None:
+eps = self.eps
for i in range(len(q)):
-if np.random.rand() < self.eps:
+if np.random.rand() < eps:
act[i] = np.random.randint(q.shape[1])
return Batch(Q=q, act=act, state=h)
@@ -66,7 +68,7 @@ class DQNPolicy(BasePolicy, nn.Module):
terminal = (indice + self._n_step - 1) % len(buffer)
if self._target:
# target_Q = Q_old(s_, argmax(Q_new(s_, *)))
-a = self(buffer[terminal], input='obs_next').act
+a = self(buffer[terminal], input='obs_next', eps=0).act
target_q = self(
buffer[terminal], model='model_old', input='obs_next').Q
if isinstance(target_q, torch.Tensor):
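The new eps argument lets callers override exploration per call; passing eps=0 when choosing the action for the double-DQN target keeps exploration noise out of the bootstrap, i.e. target = r + gamma * (1 - done) * Q_old(s', argmax_a Q_new(s', a)). A torch sketch of that target, assuming (batch, action)-shaped Q tensors (illustrative, not the tianshou implementation):

import torch

def double_dqn_target(q_new, q_old, rew, done, gamma=0.99):
    a_star = q_new.max(dim=1)[1]                       # greedy action, eps = 0
    q_eval = q_old[torch.arange(len(a_star)), a_star]  # Q_old(s', a*)
    return rew + gamma * (1 - done) * q_eval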