fix some bugs

parent 5983c6b33d
commit cef5de8b83
@@ -1,3 +1,4 @@
+import pytest
 import numpy as np

 from tianshou.data import Batch
@@ -10,6 +11,9 @@ def test_batch():
     batch.append(batch)
     assert batch.obs == [1, 1]
     assert batch.np.shape == (6, 4)
+    assert batch[0].obs == batch[1].obs
+    with pytest.raises(IndexError):
+        batch[2]


 if __name__ == '__main__':
@@ -8,16 +8,23 @@ else: # pytest
 def test_replaybuffer(size=10, bufsize=20):
     env = MyTestEnv(size)
     buf = ReplayBuffer(bufsize)
+    buf2 = ReplayBuffer(bufsize)
     obs = env.reset()
-    action_list = [1] * 5 + [0] * 10 + [1] * 15
+    action_list = [1] * 5 + [0] * 10 + [1] * 10
     for i, a in enumerate(action_list):
         obs_next, rew, done, info = env.step(a)
         buf.add(obs, a, rew, done, obs_next, info)
         obs = obs_next
+        assert len(buf) == min(bufsize, i + 1), print(len(buf), i)
     data, indice = buf.sample(bufsize * 2)
     assert (indice < len(buf)).all()
     assert (data.obs < size).all()
     assert (0 <= data.done).all() and (data.done <= 1).all()
+    assert len(buf) > len(buf2)
+    buf2.update(buf)
+    assert len(buf) == len(buf2)
+    assert buf2[0].obs == buf[5].obs
+    assert buf2[-1].obs == buf[4].obs


 if __name__ == '__main__':
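
Note: the new assertions pin down the ring-buffer arithmetic. 25 transitions go into a buffer of size 20, so the write pointer ends at 25 % 20 == 5, which is also the slot holding the oldest surviving transition; buf2.update(buf) copies in chronological order starting there, hence buf2[0] matches buf[5] and buf2[-1] wraps back to buf[4]. A minimal standalone check of that index arithmetic (plain Python, no tianshou needed):

    bufsize = 20
    n_adds = len([1] * 5 + [0] * 10 + [1] * 10)    # 25 transitions
    oldest = n_adds % bufsize                      # next write slot == oldest entry
    assert oldest == 5
    # chronological copy order starting at the oldest slot
    order = [(oldest + k) % bufsize for k in range(bufsize)]
    assert order[0] == 5 and order[-1] == 4
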
@@ -1,4 +1,5 @@
 import gym
+import time
 import tqdm
 import torch
 import argparse
@@ -36,7 +37,7 @@ class Net(nn.Module):
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--eps-test', type=float, default=0.05)
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
@@ -65,6 +66,7 @@ def test_dqn(args=get_args()):
     train_envs = SubprocVectorEnv(
         [lambda: gym.make(args.task) for _ in range(args.training_num)],
         reset_after_done=True)
+    # test_envs = gym.make(args.task)
     test_envs = SubprocVectorEnv(
         [lambda: gym.make(args.task) for _ in range(args.test_num)],
         reset_after_done=False)
@@ -124,11 +126,26 @@ def test_dqn(args=get_args()):
         if best_reward < result['reward']:
             best_reward = result['reward']
             best_epoch = epoch
-        print(f'Epoch #{epoch + 1} reward: {result["reward"]:.6f}, '
+        print(f'Epoch #{epoch + 1} test_reward: {result["reward"]:.6f}, '
               f'best_reward: {best_reward:.6f} in #{best_epoch}')
         if args.task == 'CartPole-v0' and best_reward >= 200:
             break
     assert best_reward >= 200
+    if __name__ == '__main__':
+        # let's watch its performance!
+        env = gym.make(args.task)
+        obs = env.reset()
+        done = False
+        total = 0
+        while not done:
+            q, _ = net([obs])
+            action = q.max(dim=1)[1]
+            obs, rew, done, info = env.step(action[0].detach().cpu().numpy())
+            total += rew
+            env.render()
+            time.sleep(1 / 100)
+        env.close()
+        print(f'Final test: {total}')
     return best_reward
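
Note: the watch loop added at the end picks actions greedily from the raw Q-values, with no epsilon noise. The argmax step as a minimal sketch (torch only; q stands in for the network output on a batch of one observation):

    import torch

    q = torch.tensor([[0.2, 0.7, 0.1]])   # Q-values for a single observation
    action = q.max(dim=1)[1]              # argmax over the action dimension
    assert action[0].item() == 1
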
@@ -1,5 +1,6 @@
 import gym
 import time
+import pytest
 import numpy as np
 from tianshou.env import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv

@@ -8,24 +9,28 @@ class MyTestEnv(gym.Env):
     def __init__(self, size, sleep=0):
         self.size = size
         self.sleep = sleep
+        self.index = 0
         self.reset()

     def reset(self):
+        self.done = False
         self.index = 0
         return self.index

     def step(self, action):
+        if self.done:
+            raise ValueError('step after done !!!')
         if self.sleep > 0:
             time.sleep(self.sleep)
         if self.index == self.size:
+            self.done = True
             return self.index, 0, True, {}
         if action == 0:
             self.index = max(self.index - 1, 0)
             return self.index, 0, False, {}
         elif action == 1:
             self.index += 1
-            finished = self.index == self.size
-            return self.index, int(finished), finished, {}
+            self.done = self.index == self.size
+            return self.index, int(self.done), self.done, {}


 def test_framestack(k=4, size=10):
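
Note: MyTestEnv now tracks done itself and raises on any further step, which the updated test_framestack and the wrapper changes below rely on. The guard pattern in isolation (a hypothetical ToyEnv mirroring the class above):

    class ToyEnv:
        def __init__(self, size):
            self.size = size
            self.reset()

        def reset(self):
            self.done = False
            self.index = 0
            return self.index

        def step(self, action):
            if self.done:
                raise ValueError('step after done !!!')
            self.index += 1
            self.done = self.index == self.size
            return self.index, int(self.done), self.done, {}

    env = ToyEnv(2)
    env.step(1)
    env.step(1)        # episode finishes here
    try:
        env.step(1)    # a third step raises
    except ValueError:
        pass
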
@@ -47,19 +52,21 @@ def test_framestack(k=4, size=10):
     obs, rew, done, info = fsenv.step(1)
     assert abs(obs - np.array([7, 8, 9, 10])).sum() == 0
     assert (rew, done) == (1, True)
-    obs, rew, done, info = fsenv.step(0)
-    assert abs(obs - np.array([8, 9, 10, 10])).sum() == 0
-    assert (rew, done) == (0, True)
+    with pytest.raises(ValueError):
+        obs, rew, done, info = fsenv.step(0)
+    # assert abs(obs - np.array([8, 9, 10, 10])).sum() == 0
+    # assert (rew, done) == (0, True)
     fsenv.close()


-def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
+def test_vecenv(size=10, num=8, sleep=0.001):
+    verbose = __name__ == '__main__'
     env_fns = [lambda: MyTestEnv(size=size, sleep=sleep) for _ in range(num)]
     venv = [
         VectorEnv(env_fns, reset_after_done=True),
         SubprocVectorEnv(env_fns, reset_after_done=True),
     ]
-    if __name__ == '__main__':
+    if verbose:
         venv.append(RayVectorEnv(env_fns, reset_after_done=True))
     for v in venv:
         v.seed()
@@ -86,6 +93,40 @@ def test_vecenv(verbose=False, size=10, num=8, sleep=0.001):
         v.close()


+def test_vecenv2():
+    verbose = __name__ == '__main__'
+    env_fns = [
+        lambda: MyTestEnv(size=1),
+        lambda: MyTestEnv(size=2),
+        lambda: MyTestEnv(size=3),
+        lambda: MyTestEnv(size=4),
+    ]
+    num = len(env_fns)
+    venv = [
+        VectorEnv(env_fns, reset_after_done=False),
+        SubprocVectorEnv(env_fns, reset_after_done=False),
+    ]
+    if verbose:
+        venv.append(RayVectorEnv(env_fns, reset_after_done=False))
+    for v in venv:
+        v.seed()
+    o = [v.reset() for v in venv]
+    action_list = [1] * 6
+    for i, a in enumerate(action_list):
+        o = [v.step([a] * num) for v in venv]
+        if verbose:
+            print(o[0])
+            print(o[1])
+            print(o[2])
+            print('---')
+        for i in zip(*o):
+            for j in range(1, len(i)):
+                assert (i[0] == i[j]).all()
+    for v in venv:
+        v.close()
+
+
 if __name__ == '__main__':
     test_framestack()
-    test_vecenv(True)
+    test_vecenv()
+    test_vecenv2()
@@ -13,7 +13,7 @@ class Batch(object):
         b = Batch()
         for k in self.__dict__.keys():
             if self.__dict__[k] is not None:
-                b.update(k=self.__dict__[k][index])
+                b.update(**{k: self.__dict__[k][index]})
         return b

     def update(self, **kwargs):
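
Note: this small-looking change is the real bug fix behind the new Batch indexing tests. b.update(k=...) passes a keyword literally named k, so every attribute of the sliced Batch landed under b.k; unpacking a dict forwards the variable's value as the key. The pitfall in plain Python:

    def update(**kwargs):
        return kwargs

    k = 'obs'
    assert update(k=1) == {'k': 1}          # bug: the key is literally 'k'
    assert update(**{k: 1}) == {'obs': 1}   # fix: the key is k's value
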
@@ -27,10 +27,14 @@ class ReplayBuffer(object):
         self.__dict__[name][self._index] = inst

     def update(self, buffer):
-        for i in range(len(buffer)):
+        i = begin = buffer._index % len(buffer)
+        while True:
             self.add(
                 buffer.obs[i], buffer.act[i], buffer.rew[i],
                 buffer.done[i], buffer.obs_next[i], buffer.info[i])
+            i = (i + 1) % len(buffer)
+            if i == begin:
+                break

     def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
         '''
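
Note: the old loop copied slots 0..len-1, i.e. storage order, which is no longer chronological once the write pointer has wrapped. The new loop starts at buffer._index (the oldest entry) and walks the ring exactly once. A standalone sketch, with storage standing in for the buffer's arrays:

    storage = [20, 21, 22, 3, 4]   # slots 0-2 were overwritten; oldest is slot 3
    index = 3                      # next write position == oldest entry
    i = begin = index % len(storage)
    chronological = []
    while True:
        chronological.append(storage[i])
        i = (i + 1) % len(storage)
        if i == begin:
            break
    assert chronological == [3, 4, 20, 21, 22]
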
@@ -1,3 +1,4 @@
+import time
 import torch
 import numpy as np
 from copy import deepcopy
@@ -70,13 +71,13 @@ class Collector(object):
         if hasattr(self.env, 'close'):
             self.env.close()

-    def _make_batch(data):
+    def _make_batch(self, data):
         if isinstance(data, np.ndarray):
             return data[None]
         else:
             return [data]

-    def collect(self, n_step=0, n_episode=0):
+    def collect(self, n_step=0, n_episode=0, render=0):
         assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
             "One and only one collection number specification permitted!"
         cur_step = 0
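
Note: _make_batch was missing self, so calling it through the instance bound the instance to data and raised a TypeError before the method body ever ran. A generic illustration:

    class C:
        def broken(data):          # no 'self': the instance binds to 'data'
            return [data]

        def fixed(self, data):
            return [data]

    c = C()
    try:
        c.broken(1)    # TypeError: takes 1 positional argument but 2 were given
    except TypeError:
        pass
    assert c.fixed(1) == [1]
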
@@ -98,7 +99,10 @@ class Collector(object):
         self.state = result.state if hasattr(result, 'state') else None
         self._act = result.act
         obs_next, self._rew, self._done, self._info = self.env.step(
-            self._act)
+            self._act if self._multi_env else self._act[0])
+        if render > 0:
+            self.env.render()
+            time.sleep(render)
         self.length += 1
         self.reward += self._rew
         if self._multi_env:
@@ -147,6 +151,7 @@ class Collector(object):
                 self.stat_length.add(self.length)
                 self.reward, self.length = 0, 0
                 self.state = None
+                self._obs = self.env.reset()
                 if n_episode > 0 and cur_episode >= n_episode:
                     break
         if n_step > 0 and cur_step >= n_step:
tianshou/env/wrapper.py
@@ -62,6 +62,7 @@ class BaseVectorEnv(ABC):
         self._env_fns = env_fns
         self.env_num = len(env_fns)
         self._reset_after_done = reset_after_done
+        self._done = np.zeros(self.env_num)

     def is_reset_after_done(self):
         return self._reset_after_done
@@ -98,17 +99,26 @@ class VectorEnv(BaseVectorEnv):
         self.envs = [_() for _ in env_fns]

     def reset(self):
-        return np.stack([e.reset() for e in self.envs])
+        self._done = np.zeros(self.env_num)
+        self._obs = np.stack([e.reset() for e in self.envs])
+        return self._obs

     def step(self, action):
         assert len(action) == self.env_num
-        result = zip(*[e.step(a) for e, a in zip(self.envs, action)])
-        obs, rew, done, info = result
-        if self._reset_after_done and sum(done):
-            obs = np.stack(obs)
-            for i in np.where(done)[0]:
-                obs[i] = self.envs[i].reset()
-        return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
+        result = []
+        for i, e in enumerate(self.envs):
+            if not self.is_reset_after_done() and self._done[i]:
+                result.append([
+                    self._obs[i], self._rew[i], self._done[i], self._info[i]])
+            else:
+                result.append(e.step(action[i]))
+        self._obs, self._rew, self._done, self._info = zip(*result)
+        if self.is_reset_after_done() and sum(self._done):
+            self._obs = np.stack(self._obs)
+            for i in np.where(self._done)[0]:
+                self._obs[i] = self.envs[i].reset()
+        return np.stack(self._obs), np.stack(self._rew),\
+            np.stack(self._done), np.stack(self._info)

     def seed(self, seed=None):
         if np.isscalar(seed) or seed is None:
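
Note: with reset_after_done=False a finished sub-environment must stand still while its siblings keep stepping, so step now caches each env's last transition and replays it instead of calling step again (which MyTestEnv would reject). The pattern in isolation (a sketch; envs is any list of gym-style envs, cache holds each env's last (obs, rew, done, info) or None):

    def step_all(envs, actions, cache, reset_after_done):
        result = []
        for i, e in enumerate(envs):
            done = cache[i] is not None and cache[i][2]
            if not reset_after_done and done:
                result.append(cache[i])           # finished env: replay last transition
            else:
                result.append(e.step(actions[i]))
            cache[i] = result[-1]
        return tuple(zip(*result))                # obs, rew, done, info tuples
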
@@ -130,15 +140,18 @@ class VectorEnv(BaseVectorEnv):
 def worker(parent, p, env_fn_wrapper, reset_after_done):
     parent.close()
     env = env_fn_wrapper.data()
+    done = False
     while True:
         cmd, data = p.recv()
         if cmd == 'step':
-            obs, rew, done, info = env.step(data)
+            if reset_after_done or not done:
+                obs, rew, done, info = env.step(data)
             if reset_after_done and done:
                 # s_ is useless when episode finishes
                 obs = env.reset()
             p.send([obs, rew, done, info])
         elif cmd == 'reset':
+            done = False
             p.send(env.reset())
         elif cmd == 'close':
             p.close()
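
Note: the same stand-still rule has to live inside the subprocess, since each worker owns its env. The worker now keeps a done flag across messages: once an episode ends without auto-reset, further 'step' commands skip env.step and resend the last transition (the obs/rew/info locals persist across loop iterations), and 'reset' clears the flag. The non-auto-reset path extracted as a sketch:

    def handle_step(env, action, state):
        # state: {'done': bool, 'last': [obs, rew, done, info]}
        # reset_after_done=False path only
        if not state['done']:
            state['last'] = list(env.step(action))
            state['done'] = state['last'][2]
        return state['last']
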
@@ -225,21 +238,36 @@ class RayVectorEnv(BaseVectorEnv):

     def step(self, action):
         assert len(action) == self.env_num
-        result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
-        obs, rew, done, info = zip(*[ray.get(r) for r in result_obj])
-        if self._reset_after_done and sum(done):
-            obs = np.stack(obs)
-            index = np.where(done)[0]
+        result_obj = []
+        for i, e in enumerate(self.envs):
+            if not self.is_reset_after_done() and self._done[i]:
+                result_obj.append(None)
+            else:
+                result_obj.append(e.step.remote(action[i]))
+        result = []
+        for i, r in enumerate(result_obj):
+            if r is None:
+                result.append([
+                    self._obs[i], self._rew[i], self._done[i], self._info[i]])
+            else:
+                result.append(ray.get(r))
+        self._obs, self._rew, self._done, self._info = zip(*result)
+        if self.is_reset_after_done() and sum(self._done):
+            self._obs = np.stack(self._obs)
+            index = np.where(self._done)[0]
             result_obj = []
             for i in range(len(index)):
                 result_obj.append(self.envs[index[i]].reset.remote())
             for i in range(len(index)):
-                obs[index[i]] = ray.get(result_obj[i])
-        return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
+                self._obs[index[i]] = ray.get(result_obj[i])
+        return np.stack(self._obs), np.stack(self._rew),\
+            np.stack(self._done), np.stack(self._info)

     def reset(self):
+        self._done = np.zeros(self.env_num)
         result_obj = [e.reset.remote() for e in self.envs]
-        return np.stack([ray.get(r) for r in result_obj])
+        self._obs = np.stack([ray.get(r) for r in result_obj])
+        return self._obs

     def seed(self, seed=None):
         if not hasattr(self.envs[0], 'seed'):
@@ -35,8 +35,10 @@ class DQNPolicy(BasePolicy, nn.Module):
         q, h = model(obs, hidden_state=hidden_state, info=batch.info)
         act = q.max(dim=1)[1].detach().cpu().numpy()
         # add eps to act
+        if eps is None:
+            eps = self.eps
         for i in range(len(q)):
-            if np.random.rand() < self.eps:
+            if np.random.rand() < eps:
                 act[i] = np.random.randint(q.shape[1])
         return Batch(Q=q, act=act, state=h)

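
Note: forward previously ignored its eps argument and always explored with self.eps; now a caller can override it per call, in particular eps=0 for fully greedy selection. The selection rule as a standalone sketch (numpy only; names are illustrative):

    import numpy as np

    def select_actions(q_values, stored_eps, eps=None):
        if eps is None:                # fall back to the policy's stored epsilon
            eps = stored_eps
        act = q_values.argmax(axis=1)
        for i in range(len(act)):
            if np.random.rand() < eps:
                act[i] = np.random.randint(q_values.shape[1])
        return act

    q = np.array([[0.1, 0.9], [0.8, 0.2]])
    assert (select_actions(q, stored_eps=0.1, eps=0) == [1, 0]).all()  # greedy
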
@@ -66,7 +68,7 @@ class DQNPolicy(BasePolicy, nn.Module):
         terminal = (indice + self._n_step - 1) % len(buffer)
         if self._target:
             # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
-            a = self(buffer[terminal], input='obs_next').act
+            a = self(buffer[terminal], input='obs_next', eps=0).act
             target_q = self(
                 buffer[terminal], model='model_old', input='obs_next').Q
             if isinstance(target_q, torch.Tensor):
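
Note: passing eps=0 here matters because the double-DQN target Q_old(s', argmax_a Q_new(s', a)) must use the greedy action of the online network; exploration noise in the target action would bias the bootstrap. The target computation in miniature (numpy sketch; q_new/q_old stand in for the two networks' outputs on s'):

    import numpy as np

    q_new = np.array([[0.1, 0.9], [0.8, 0.2]])     # online network Q(s', .)
    q_old = np.array([[0.5, 0.4], [0.3, 0.6]])     # target network Q(s', .)
    a = q_new.argmax(axis=1)                       # greedy w.r.t. online net
    target_q = q_old[np.arange(len(a)), a]         # evaluated by target net
    assert (target_q == np.array([0.4, 0.3])).all()
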