fix collector

Trinkle23897 2020-03-25 14:08:28 +08:00
parent e95218e295
commit fdc969b830
21 changed files with 288 additions and 250 deletions

View File

@ -0,0 +1,65 @@
import numpy as np
from tianshou.policy import BasePolicy
from tianshou.env import SubprocVectorEnv
from tianshou.data import Collector, Batch
if __name__ == '__main__':
from env import MyTestEnv
else: # pytest
from test.base.env import MyTestEnv
class MyPolicy(BasePolicy):
"""docstring for MyPolicy"""
def __init__(self):
super().__init__()
def __call__(self, batch, state=None):
return Batch(act=np.ones(batch.obs.shape[0]))
def learn(self):
pass
def equal(a, b):
return abs(np.array(a) - np.array(b)).sum() < 1e-6
def test_collector():
env_fns = [
lambda: MyTestEnv(size=2, sleep=0),
lambda: MyTestEnv(size=3, sleep=0),
lambda: MyTestEnv(size=4, sleep=0),
lambda: MyTestEnv(size=5, sleep=0),
]
venv = SubprocVectorEnv(env_fns)
policy = MyPolicy()
env = env_fns[0]()
c0 = Collector(policy, env)
c0.collect(n_step=3)
assert equal(c0.buffer.obs[:3], [0, 1, 0])
assert equal(c0.buffer.obs_next[:3], [1, 2, 1])
c0.collect(n_episode=3)
assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
assert equal(c0.buffer.obs_next[:8], [1, 2, 1, 2, 1, 2, 1, 2])
c1 = Collector(policy, venv)
c1.collect(n_step=6)
assert equal(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
assert equal(c1.buffer.obs_next[:11], [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
c1.collect(n_episode=2)
assert equal(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
assert equal(c1.buffer.obs_next[11:21], [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
c2 = Collector(policy, venv)
c2.collect(n_episode=[1, 2, 2, 2])
assert equal(c2.buffer.obs_next[:26], [
1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
c2.reset_env()
c2.collect(n_episode=[2, 2, 2, 2])
assert equal(c2.buffer.obs_next[26:54], [
1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5,
1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
if __name__ == '__main__':
test_collector()
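The c2 assertions above hinge on the order in which finished episodes are flushed from the per-environment cached buffers into the shared buffer: environments complete in order of their size, and env 0 contributes only one episode because of n_episode=[1, 2, 2, 2]. A short sketch, assuming MyTestEnv emits obs_next 1..size per episode as the earlier c0/c1 assertions show, that reproduces the 26 expected values:

sizes = [2, 3, 4, 5]
# (env index, episode number) in the order the episodes complete
finish_order = [(0, 1), (1, 1), (2, 1), (3, 1), (1, 2), (2, 2), (3, 2)]
expected = [x for i, _ in finish_order for x in range(1, sizes[i] + 1)]
assert expected == [
    1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
    1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]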

View File

@ -37,20 +37,34 @@ def test_framestack(k=4, size=10):
def test_vecenv(size=10, num=8, sleep=0.001):
verbose = __name__ == '__main__'
env_fns = [lambda: MyTestEnv(size=size, sleep=sleep) for _ in range(num)]
env_fns = [
lambda: MyTestEnv(size=size, sleep=sleep),
lambda: MyTestEnv(size=size + 1, sleep=sleep),
lambda: MyTestEnv(size=size + 2, sleep=sleep),
lambda: MyTestEnv(size=size + 3, sleep=sleep),
lambda: MyTestEnv(size=size + 4, sleep=sleep),
lambda: MyTestEnv(size=size + 5, sleep=sleep),
lambda: MyTestEnv(size=size + 6, sleep=sleep),
lambda: MyTestEnv(size=size + 7, sleep=sleep),
]
venv = [
VectorEnv(env_fns, reset_after_done=True),
SubprocVectorEnv(env_fns, reset_after_done=True),
VectorEnv(env_fns),
SubprocVectorEnv(env_fns),
]
if verbose:
venv.append(RayVectorEnv(env_fns, reset_after_done=True))
venv.append(RayVectorEnv(env_fns))
for v in venv:
v.seed()
action_list = [1] * 5 + [0] * 10 + [1] * 15
action_list = [1] * 5 + [0] * 10 + [1] * 20
if not verbose:
o = [v.reset() for v in venv]
for i, a in enumerate(action_list):
o = [v.step([a] * num) for v in venv]
o = []
for v in venv:
A, B, C, D = v.step([a] * num)
if sum(C):
A = v.reset(np.where(C)[0])
o.append([A, B, C, D])
for i in zip(*o):
for j in range(1, len(i)):
assert (i[0] == i[j]).all()
@ -60,7 +74,9 @@ def test_vecenv(size=10, num=8, sleep=0.001):
t[i] = time.time()
e.reset()
for a in action_list:
e.step([a] * num)
done = e.step([a] * num)[2]
if sum(done) > 0:
e.reset(np.where(done)[0])
t[i] = time.time() - t[i]
print(f'VectorEnv: {t[0]:.6f}s')
print(f'SubprocVectorEnv: {t[1]:.6f}s')
@ -69,40 +85,6 @@ def test_vecenv(size=10, num=8, sleep=0.001):
v.close()
def test_vecenv2():
verbose = __name__ == '__main__'
env_fns = [
lambda: MyTestEnv(size=1),
lambda: MyTestEnv(size=2),
lambda: MyTestEnv(size=3),
lambda: MyTestEnv(size=4),
]
num = len(env_fns)
venv = [
VectorEnv(env_fns, reset_after_done=False),
SubprocVectorEnv(env_fns, reset_after_done=False),
]
if verbose:
venv.append(RayVectorEnv(env_fns, reset_after_done=False))
for v in venv:
v.seed()
o = [v.reset() for v in venv]
action_list = [1] * 6
for i, a in enumerate(action_list):
o = [v.step([a] * num) for v in venv]
if verbose:
print(o[0])
print(o[1])
print(o[2])
print('---')
for i in zip(*o):
for j in range(1, len(i)):
assert (i[0] == i[j]).all()
for v in venv:
v.close()
if __name__ == '__main__':
test_framestack()
test_vecenv()
test_vecenv2()

View File

@ -42,7 +42,8 @@ class ActorProb(nn.Module):
self._max = max_action
def forward(self, s, **kwargs):
s = torch.tensor(s, device=self.device, dtype=torch.float)
if not isinstance(s, torch.Tensor):
s = torch.tensor(s, device=self.device, dtype=torch.float)
batch = s.shape[0]
s = s.view(batch, -1)
logits = self.model(s)
@ -64,8 +65,9 @@ class Critic(nn.Module):
self.model = nn.Sequential(*self.model)
def forward(self, s, a=None):
s = torch.tensor(s, device=self.device, dtype=torch.float)
if isinstance(a, np.ndarray):
if not isinstance(s, torch.Tensor):
s = torch.tensor(s, device=self.device, dtype=torch.float)
if a is not None and not isinstance(a, torch.Tensor):
a = torch.tensor(a, device=self.device, dtype=torch.float)
batch = s.shape[0]
s = s.view(batch, -1)

View File

@ -50,12 +50,10 @@ def test_ddpg(args=get_args()):
args.max_action = env.action_space.high[0]
# train_envs = gym.make(args.task)
train_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
reset_after_done=True)
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
reset_after_done=False)
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@ -75,10 +73,10 @@ def test_ddpg(args=get_args()):
actor, actor_optim, critic, critic_optim,
args.tau, args.gamma, args.exploration_noise,
[env.action_space.low[0], env.action_space.high[0]],
reward_normalization=True)
reward_normalization=True, ignore_done=True)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size), 1)
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
# log
writer = SummaryWriter(args.logdir)
@ -91,8 +89,7 @@ def test_ddpg(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer)
if args.task == 'Pendulum-v0':
assert stop_fn(result['best_reward'])
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()
if __name__ == '__main__':

View File

@ -21,12 +21,12 @@ def get_args():
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--repeat-per-collect', type=int, default=1)
parser.add_argument('--repeat-per-collect', type=int, default=2)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--training-num', type=int, default=16)
@ -54,12 +54,10 @@ def _test_ppo(args=get_args()):
args.max_action = env.action_space.high[0]
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
reset_after_done=True)
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
reset_after_done=False)
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@ -85,7 +83,8 @@ def _test_ppo(args=get_args()):
action_range=[env.action_space.low[0], env.action_space.high[0]])
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
policy, train_envs, ReplayBuffer(args.buffer_size),
remove_done_flag=True)
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.step_per_epoch)
# log

View File

@ -50,12 +50,10 @@ def test_sac(args=get_args()):
args.max_action = env.action_space.high[0]
# train_envs = gym.make(args.task)
train_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
reset_after_done=True)
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
reset_after_done=False)
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@ -79,12 +77,12 @@ def test_sac(args=get_args()):
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
args.tau, args.gamma, args.alpha,
[env.action_space.low[0], env.action_space.high[0]],
reward_normalization=True)
reward_normalization=True, ignore_done=True)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size), 1)
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.buffer_size)
# train_collector.collect(n_step=args.buffer_size)
# log
writer = SummaryWriter(args.logdir)
@ -96,8 +94,7 @@ def test_sac(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer)
if args.task == 'Pendulum-v0':
assert stop_fn(result['best_reward'])
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()
if __name__ == '__main__':

View File

@ -53,12 +53,10 @@ def test_td3(args=get_args()):
args.max_action = env.action_space.high[0]
# train_envs = gym.make(args.task)
train_envs = VectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
reset_after_done=True)
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
reset_after_done=False)
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@ -83,12 +81,12 @@ def test_td3(args=get_args()):
args.tau, args.gamma, args.exploration_noise, args.policy_noise,
args.update_actor_freq, args.noise_clip,
[env.action_space.low[0], env.action_space.high[0]],
reward_normalization=True)
reward_normalization=True, ignore_done=True)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size), 1)
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.buffer_size)
# train_collector.collect(n_step=args.buffer_size)
# log
writer = SummaryWriter(args.logdir)
@ -100,8 +98,7 @@ def test_td3(args=get_args()):
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, stop_fn=stop_fn, writer=writer)
if args.task == 'Pendulum-v0':
assert stop_fn(result['best_reward'])
assert stop_fn(result['best_reward'])
train_collector.close()
test_collector.close()
if __name__ == '__main__':

View File

@ -18,7 +18,8 @@ class Net(nn.Module):
self.model = nn.Sequential(*self.model)
def forward(self, s, state=None, info={}):
s = torch.tensor(s, device=self.device, dtype=torch.float)
if not isinstance(s, torch.Tensor):
s = torch.tensor(s, device=self.device, dtype=torch.float)
batch = s.shape[0]
s = s.view(batch, -1)
logits = self.model(s)

View File

@ -24,7 +24,7 @@ def get_args():
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=320)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--repeat-per-collect', type=int, default=1)
parser.add_argument('--batch-size', type=int, default=64)
@ -49,12 +49,10 @@ def test_a2c(args=get_args()):
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
reset_after_done=True)
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
reset_after_done=False)
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)

View File

@ -23,11 +23,12 @@ def get_args():
parser.add_argument('--eps-test', type=float, default=0.05)
parser.add_argument('--eps-train', type=float, default=0.1)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--n-step', type=int, default=1)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=320)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=3)
@ -47,12 +48,10 @@ def test_dqn(args=get_args()):
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
reset_after_done=True)
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
reset_after_done=False)
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
@ -62,12 +61,17 @@ def test_dqn(args=get_args()):
net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
net = net.to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(net, optim, args.gamma, args.n_step)
policy = DQNPolicy(
net, optim, args.gamma, args.n_step,
use_target_network=args.target_update_freq > 0,
target_update_freq=args.target_update_freq)
# collector
train_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.buffer_size)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size)
print(len(train_collector.buffer))
# log
writer = SummaryWriter(args.logdir)
@ -75,7 +79,6 @@ def test_dqn(args=get_args()):
return x >= env.spec.reward_threshold
def train_fn(x):
policy.sync_weight()
policy.set_eps(args.eps_train)
def test_fn(x):

View File

@ -99,12 +99,10 @@ def test_pg(args=get_args()):
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
reset_after_done=True)
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
reset_after_done=False)
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)

View File

@ -19,9 +19,9 @@ else: # pytest
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=3e-3)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=1000)
@ -50,12 +50,10 @@ def test_ppo(args=get_args()):
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
reset_after_done=True)
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
reset_after_done=False)
[lambda: gym.make(args.task) for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)

View File

@ -11,14 +11,17 @@ from tianshou.utils import MovAvg
class Collector(object):
"""docstring for Collector"""
def __init__(self, policy, env, buffer=ReplayBuffer(20000), stat_size=100):
def __init__(self, policy, env, buffer=None, stat_size=100):
super().__init__()
self.env = env
self.env_num = 1
self.collect_step = 0
self.collect_episode = 0
self.collect_time = 0
self.buffer = buffer
if buffer is None:
self.buffer = ReplayBuffer(20000)
else:
self.buffer = buffer
self.policy = policy
self.process_fn = policy.process_fn
self._multi_env = isinstance(env, BaseVectorEnv)
@ -34,7 +37,7 @@ class Collector(object):
self._multi_buf = True
elif isinstance(self.buffer, ReplayBuffer):
self._cached_buf = [
deepcopy(buffer) for _ in range(self.env_num)]
deepcopy(self.buffer) for _ in range(self.env_num)]
else:
raise TypeError('The buffer in data collector is invalid!')
self.reset_env()
@ -64,11 +67,11 @@ class Collector(object):
def seed(self, seed=None):
if hasattr(self.env, 'seed'):
self.env.seed(seed)
return self.env.seed(seed)
def render(self, **kwargs):
if hasattr(self.env, 'render'):
self.env.render(**kwargs)
return self.env.render(**kwargs)
def close(self):
if hasattr(self.env, 'close'):
@ -78,11 +81,13 @@ class Collector(object):
if isinstance(data, np.ndarray):
return data[None]
else:
return [data]
return np.array([data])
def collect(self, n_step=0, n_episode=0, render=0):
if not self._multi_env:
n_episode = np.sum(n_episode)
start_time = time.time()
assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
assert sum([(n_step != 0), (n_episode != 0)]) == 1,\
"One and only one collection number specification permitted!"
cur_step = 0
cur_episode = np.zeros(self.env_num) if self._multi_env else 0
@ -105,8 +110,10 @@ class Collector(object):
self.state = result.state if hasattr(result, 'state') else None
if isinstance(result.act, torch.Tensor):
self._act = result.act.detach().cpu().numpy()
else:
elif not isinstance(self._act, np.ndarray):
self._act = np.array(result.act)
else:
self._act = result.act
obs_next, self._rew, self._done, self._info = self.env.step(
self._act if self._multi_env else self._act[0])
if render > 0:
@ -116,9 +123,6 @@ class Collector(object):
self.reward += self._rew
if self._multi_env:
for i in range(self.env_num):
if not self.env.is_reset_after_done()\
and cur_episode[i] > 0:
continue
data = {
'obs': self._obs[i], 'act': self._act[i],
'rew': self._rew[i], 'done': self._done[i],
@ -132,13 +136,16 @@ class Collector(object):
self.buffer.add(**data)
cur_step += 1
if self._done[i]:
cur_episode[i] += 1
reward_sum += self.reward[i]
length_sum += self.length[i]
if n_step != 0 or np.isscalar(n_episode) or \
cur_episode[i] < n_episode[i]:
cur_episode[i] += 1
reward_sum += self.reward[i]
length_sum += self.length[i]
if self._cached_buf:
cur_step += len(self._cached_buf[i])
self.buffer.update(self._cached_buf[i])
self.reward[i], self.length[i] = 0, 0
if self._cached_buf:
self.buffer.update(self._cached_buf[i])
cur_step += len(self._cached_buf[i])
self._cached_buf[i].reset()
if isinstance(self.state, list):
self.state[i] = None
@ -150,8 +157,14 @@ class Collector(object):
if isinstance(self.state, torch.Tensor):
# remove ref count in pytorch (?)
self.state = self.state.detach()
if n_episode > 0 and cur_episode.sum() >= n_episode:
break
if sum(self._done):
obs_next = self.env.reset(np.where(self._done)[0])
if n_episode != 0:
if isinstance(n_episode, list) and \
(cur_episode >= np.array(n_episode)).all() or \
np.isscalar(n_episode) and \
cur_episode.sum() >= n_episode:
break
else:
self.buffer.add(
self._obs, self._act[0], self._rew,
@ -163,10 +176,10 @@ class Collector(object):
length_sum += self.length
self.reward, self.length = 0, 0
self.state = None
self._obs = self.env.reset()
if n_episode > 0 and cur_episode >= n_episode:
obs_next = self.env.reset()
if n_episode != 0 and cur_episode >= n_episode:
break
if n_step > 0 and cur_step >= n_step:
if n_step != 0 and cur_step >= n_step:
break
self._obs = obs_next
self._obs = obs_next
@ -178,13 +191,17 @@ class Collector(object):
self.collect_step += cur_step
self.collect_episode += cur_episode
self.collect_time += duration
if isinstance(n_episode, list):
n_episode = np.sum(n_episode)
else:
n_episode = max(cur_episode, 1)
return {
'n/ep': cur_episode,
'n/st': cur_step,
'v/st': self.step_speed.get(),
'v/ep': self.episode_speed.get(),
'rew': reward_sum / cur_episode,
'len': length_sum / cur_episode,
'rew': reward_sum / n_episode,
'len': length_sum / n_episode,
}
def sample(self, batch_size):
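A minimal usage sketch of the refactored Collector interface above. The CartPole environment and the trivial always-act-1 policy (mirroring MyPolicy from the tests) are illustrative choices, not part of this commit:

import gym
import numpy as np
from tianshou.data import Batch, Collector, ReplayBuffer
from tianshou.env import VectorEnv
from tianshou.policy import BasePolicy

class ConstantPolicy(BasePolicy):
    # always pick action 1, like MyPolicy in the tests above
    def __call__(self, batch, state=None):
        return Batch(act=np.ones(batch.obs.shape[0], dtype=int))
    def learn(self):
        pass

train_envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
collector = Collector(ConstantPolicy(), train_envs, ReplayBuffer(20000))
stats = collector.collect(n_step=100)   # or n_episode=4, or n_episode=[1, 1, 1, 1]
print(stats['n/st'], stats['n/ep'])     # transitions / episodes just collected
print(stats['rew'], stats['len'])       # mean reward and length of finished episodes
batch = collector.sample(64)            # draw a training batch from the buffer
collector.close()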

View File

@ -14,11 +14,11 @@ class EnvWrapper(object):
def seed(self, seed=None):
if hasattr(self.env, 'seed'):
self.env.seed(seed)
return self.env.seed(seed)
def render(self, **kwargs):
if hasattr(self.env, 'render'):
self.env.render(**kwargs)
return self.env.render(**kwargs)
def close(self):
self.env.close()

tianshou/env/vecenv.py (vendored, 155 changed lines)
View File

@ -10,14 +10,9 @@ from tianshou.env import EnvWrapper, CloudpickleWrapper
class BaseVectorEnv(ABC):
def __init__(self, env_fns, reset_after_done):
def __init__(self, env_fns):
self._env_fns = env_fns
self.env_num = len(env_fns)
self._reset_after_done = reset_after_done
self._done = np.zeros(self.env_num)
def is_reset_after_done(self):
return self._reset_after_done
def __len__(self):
return self.env_num
@ -46,67 +41,62 @@ class BaseVectorEnv(ABC):
class VectorEnv(BaseVectorEnv):
"""docstring for VectorEnv"""
def __init__(self, env_fns, reset_after_done=False):
super().__init__(env_fns, reset_after_done)
def __init__(self, env_fns):
super().__init__(env_fns)
self.envs = [_() for _ in env_fns]
def reset(self):
self._done = np.zeros(self.env_num)
self._obs = np.stack([e.reset() for e in self.envs])
def reset(self, id=None):
if id is None:
self._obs = np.stack([e.reset() for e in self.envs])
else:
if np.isscalar(id):
id = [id]
for i in id:
self._obs[i] = self.envs[i].reset()
return self._obs
def step(self, action):
assert len(action) == self.env_num
result = []
for i, e in enumerate(self.envs):
if not self.is_reset_after_done() and self._done[i]:
result.append([
self._obs[i], self._rew[i], self._done[i], self._info[i]])
else:
result.append(e.step(action[i]))
result = [e.step(a) for e, a in zip(self.envs, action)]
self._obs, self._rew, self._done, self._info = zip(*result)
if self.is_reset_after_done() and sum(self._done):
self._obs = np.stack(self._obs)
for i in np.where(self._done)[0]:
self._obs[i] = self.envs[i].reset()
return np.stack(self._obs), np.stack(self._rew),\
np.stack(self._done), np.stack(self._info)
self._obs = np.stack(self._obs)
self._rew = np.stack(self._rew)
self._done = np.stack(self._done)
self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info
def seed(self, seed=None):
if np.isscalar(seed):
seed = [seed + _ for _ in range(self.env_num)]
elif seed is None:
seed = [seed] * self.env_num
result = []
for e, s in zip(self.envs, seed):
if hasattr(e, 'seed'):
e.seed(s)
result.append(e.seed(s))
return result
def render(self, **kwargs):
result = []
for e in self.envs:
if hasattr(e, 'render'):
e.render(**kwargs)
result.append(e.render(**kwargs))
return result
def close(self):
for e in self.envs:
e.close()
def worker(parent, p, env_fn_wrapper, reset_after_done):
def worker(parent, p, env_fn_wrapper):
parent.close()
env = env_fn_wrapper.data()
done = False
try:
while True:
cmd, data = p.recv()
if cmd == 'step':
if reset_after_done or not done:
obs, rew, done, info = env.step(data)
if reset_after_done and done:
# s_ is useless when episode finishes
obs = env.reset()
p.send([obs, rew, done, info])
p.send(env.step(data))
elif cmd == 'reset':
done = False
p.send(env.reset())
elif cmd == 'close':
p.close()
@ -125,15 +115,14 @@ def worker(parent, p, env_fn_wrapper, reset_after_done):
class SubprocVectorEnv(BaseVectorEnv):
"""docstring for SubProcVectorEnv"""
def __init__(self, env_fns, reset_after_done=False):
super().__init__(env_fns, reset_after_done)
def __init__(self, env_fns):
super().__init__(env_fns)
self.closed = False
self.parent_remote, self.child_remote = \
zip(*[Pipe() for _ in range(self.env_num)])
self.processes = [
Process(target=worker, args=(
parent, child,
CloudpickleWrapper(env_fn), reset_after_done), daemon=True)
parent, child, CloudpickleWrapper(env_fn)), daemon=True)
for (parent, child, env_fn) in zip(
self.parent_remote, self.child_remote, env_fns)
]
@ -147,13 +136,27 @@ class SubprocVectorEnv(BaseVectorEnv):
for p, a in zip(self.parent_remote, action):
p.send(['step', a])
result = [p.recv() for p in self.parent_remote]
obs, rew, done, info = zip(*result)
return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
self._obs, self._rew, self._done, self._info = zip(*result)
self._obs = np.stack(self._obs)
self._rew = np.stack(self._rew)
self._done = np.stack(self._done)
self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info
def reset(self):
for p in self.parent_remote:
p.send(['reset', None])
return np.stack([p.recv() for p in self.parent_remote])
def reset(self, id=None):
if id is None:
for p in self.parent_remote:
p.send(['reset', None])
self._obs = np.stack([p.recv() for p in self.parent_remote])
return self._obs
else:
if np.isscalar(id):
id = [id]
for i in id:
self.parent_remote[i].send(['reset', None])
for i in id:
self._obs[i] = self.parent_remote[i].recv()
return self._obs
def seed(self, seed=None):
if np.isscalar(seed):
@ -162,14 +165,12 @@ class SubprocVectorEnv(BaseVectorEnv):
seed = [seed] * self.env_num
for p, s in zip(self.parent_remote, seed):
p.send(['seed', s])
for p in self.parent_remote:
p.recv()
return [p.recv() for p in self.parent_remote]
def render(self, **kwargs):
for p in self.parent_remote:
p.send(['render', kwargs])
for p in self.parent_remote:
p.recv()
return [p.recv() for p in self.parent_remote]
def close(self):
if self.closed:
@ -184,8 +185,8 @@ class SubprocVectorEnv(BaseVectorEnv):
class RayVectorEnv(BaseVectorEnv):
"""docstring for RayVectorEnv"""
def __init__(self, env_fns, reset_after_done=False):
super().__init__(env_fns, reset_after_done)
def __init__(self, env_fns):
super().__init__(env_fns)
try:
if not ray.is_initialized():
ray.init()
@ -198,35 +199,27 @@ class RayVectorEnv(BaseVectorEnv):
def step(self, action):
assert len(action) == self.env_num
result_obj = []
for i, e in enumerate(self.envs):
if not self.is_reset_after_done() and self._done[i]:
result_obj.append(None)
else:
result_obj.append(e.step.remote(action[i]))
result = []
for i, r in enumerate(result_obj):
if r is None:
result.append([
self._obs[i], self._rew[i], self._done[i], self._info[i]])
else:
result.append(ray.get(r))
result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
result = [ray.get(r) for r in result_obj]
self._obs, self._rew, self._done, self._info = zip(*result)
if self.is_reset_after_done() and sum(self._done):
self._obs = np.stack(self._obs)
index = np.where(self._done)[0]
result_obj = []
for i in range(len(index)):
result_obj.append(self.envs[index[i]].reset.remote())
for i in range(len(index)):
self._obs[index[i]] = ray.get(result_obj[i])
return np.stack(self._obs), np.stack(self._rew),\
np.stack(self._done), np.stack(self._info)
self._obs = np.stack(self._obs)
self._rew = np.stack(self._rew)
self._done = np.stack(self._done)
self._info = np.stack(self._info)
return self._obs, self._rew, self._done, self._info
def reset(self):
self._done = np.zeros(self.env_num)
result_obj = [e.reset.remote() for e in self.envs]
self._obs = np.stack([ray.get(r) for r in result_obj])
def reset(self, id=None):
if id is None:
result_obj = [e.reset.remote() for e in self.envs]
self._obs = np.stack([ray.get(r) for r in result_obj])
else:
result_obj = []
if np.isscalar(id):
id = [id]
for i in id:
result_obj.append(self.envs[i].reset.remote())
for _, i in enumerate(id):
self._obs[i] = ray.get(result_obj[_])
return self._obs
def seed(self, seed=None):
@ -237,15 +230,13 @@ class RayVectorEnv(BaseVectorEnv):
elif seed is None:
seed = [seed] * self.env_num
result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
for r in result_obj:
ray.get(r)
return [ray.get(r) for r in result_obj]
def render(self, **kwargs):
if not hasattr(self.envs[0], 'render'):
return
result_obj = [e.render.remote(**kwargs) for e in self.envs]
for r in result_obj:
ray.get(r)
return [ray.get(r) for r in result_obj]
def close(self):
result_obj = [e.close.remote() for e in self.envs]
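With reset_after_done gone, resetting finished sub-environments is now the caller's job, using the reset(id=...) overload added above (the updated tests and the Collector both follow this pattern). A minimal stepping-loop sketch; the CartPole environment and random actions are illustrative:

import gym
import numpy as np
from tianshou.env import SubprocVectorEnv

if __name__ == '__main__':
    num = 4
    venv = SubprocVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(num)])
    obs = venv.reset()
    for _ in range(100):
        obs, rew, done, info = venv.step(np.random.randint(2, size=num))
        if done.any():
            # reset only the finished sub-environments; the returned array
            # holds fresh observations at exactly those indices
            obs = venv.reset(np.where(done)[0])
    venv.close()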

View File

@ -13,7 +13,8 @@ class DDPGPolicy(BasePolicy):
def __init__(self, actor, actor_optim, critic, critic_optim,
tau=0.005, gamma=0.99, exploration_noise=0.1,
action_range=None, reward_normalization=True):
action_range=None, reward_normalization=False,
ignore_done=False):
super().__init__()
if actor is not None:
self.actor, self.actor_old = actor, deepcopy(actor)
@ -35,6 +36,7 @@ class DDPGPolicy(BasePolicy):
self._action_scale = (action_range[1] - action_range[0]) / 2
# using rand_normal instead would make only a small difference
# self.noise = OUNoise()
self._rm_done = ignore_done
self._rew_norm = reward_normalization
self.__eps = np.finfo(np.float32).eps.item()
@ -60,8 +62,12 @@ class DDPGPolicy(BasePolicy):
def process_fn(self, batch, buffer, indice):
if self._rew_norm:
self._rew_mean = buffer.rew.mean()
self._rew_std = buffer.rew.std()
bfr = buffer.rew[:len(buffer)]
mean, std = bfr.mean(), bfr.std()
if std > self.__eps:
batch.rew = (batch.rew - mean) / std
if self._rm_done:
batch.done = batch.done * 0.
return batch
def __call__(self, batch, state=None,
@ -72,10 +78,10 @@ class DDPGPolicy(BasePolicy):
logits += self._action_bias
if eps is None:
eps = self._eps
# noise = np.random.normal(0, eps, size=logits.shape)
# noise = self.noise(logits.shape, eps)
# logits += torch.tensor(noise, device=logits.device)
if eps > 0:
# noise = np.random.normal(0, eps, size=logits.shape)
# logits += torch.tensor(noise, device=logits.device)
# noise = self.noise(logits.shape, eps)
logits += torch.randn(
size=logits.shape, device=logits.device) * eps
logits = logits.clamp(self._range[0], self._range[1])
@ -86,10 +92,8 @@ class DDPGPolicy(BasePolicy):
batch, model='actor_old', input='obs_next', eps=0).act)
dev = target_q.device
rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
if self._rew_norm:
rew = (rew - self._rew_mean) / (self._rew_std + self.__eps)
done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
target_q = rew + ((1. - done) * self._gamma * target_q).detach()
target_q = (rew + (1. - done) * self._gamma * target_q).detach()
current_q = self.critic(batch.obs, batch.act)
critic_loss = F.mse_loss(current_q, target_q)
self.critic_optim.zero_grad()
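The reworked process_fn now normalizes batch rewards with the statistics of the filled portion of the buffer and, with ignore_done=True, zeroes the done flag so the Bellman target is always rew + gamma * Q_old(s', a'), a common choice for time-limit-only tasks such as Pendulum; this matches the updated DDPG/SAC/TD3 tests passing ignore_done=True. A standalone sketch of the same two transformations (the helper name is hypothetical, not tianshou code):

import numpy as np

def preprocess(batch_rew, batch_done, buffer_rew, rew_norm=True, ignore_done=True):
    # mirrors DDPGPolicy.process_fn above: normalize rewards with the
    # mean/std of the valid buffer slice, optionally drop the done flag
    eps = np.finfo(np.float32).eps.item()
    if rew_norm:
        mean, std = buffer_rew.mean(), buffer_rew.std()
        if std > eps:
            batch_rew = (batch_rew - mean) / std
    if ignore_done:
        batch_done = batch_done * 0.
    return batch_rew, batch_done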

View File

@ -10,10 +10,9 @@ from tianshou.policy import BasePolicy
class DQNPolicy(BasePolicy):
"""docstring for DQNPolicy"""
def __init__(self, model, optim,
discount_factor=0.99,
estimation_step=1,
use_target_network=True):
def __init__(self, model, optim, discount_factor=0.99,
estimation_step=1, use_target_network=True,
target_update_freq=300):
super().__init__()
self.model = model
self.optim = optim
@ -23,6 +22,8 @@ class DQNPolicy(BasePolicy):
assert estimation_step > 0, 'estimation_step should be greater than 0'
self._n_step = estimation_step
self._target = use_target_network
self._freq = target_update_freq
self._cnt = 0
if use_target_network:
self.model_old = deepcopy(self.model)
self.model_old.eval()
@ -39,8 +40,7 @@ class DQNPolicy(BasePolicy):
self.model.eval()
def sync_weight(self):
if self._target:
self.model_old.load_state_dict(self.model.state_dict())
self.model_old.load_state_dict(self.model.state_dict())
def process_fn(self, batch, buffer, indice):
returns = np.zeros_like(indice)
@ -84,6 +84,8 @@ class DQNPolicy(BasePolicy):
return Batch(logits=q, act=act, state=h)
def learn(self, batch, batch_size=None, repeat=1):
if self._target and self._cnt % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()
q = self(batch).logits
q = q[np.arange(len(q)), batch.act]
@ -93,4 +95,5 @@ class DQNPolicy(BasePolicy):
loss = F.mse_loss(q, r)
loss.backward()
self.optim.step()
self._cnt += 1
return {'loss': loss.detach().cpu().numpy()}
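DQNPolicy now takes a target_update_freq and refreshes the target network inside learn() every _freq updates, which is why the test_dqn hunk above drops policy.sync_weight() from train_fn. A hypothetical standalone illustration of that schedule (not tianshou code), assuming any torch module:

import copy
import torch.nn as nn

class PeriodicTargetNet:
    def __init__(self, model: nn.Module, freq: int = 300):
        self.model, self.freq, self.cnt = model, freq, 0
        self.model_old = copy.deepcopy(model)
        self.model_old.eval()

    def maybe_sync(self):
        # called once per gradient update, like the check at the top of learn()
        if self.cnt % self.freq == 0:
            self.model_old.load_state_dict(self.model.state_dict())
        self.cnt += 1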

View File

@ -12,9 +12,10 @@ class SACPolicy(DDPGPolicy):
def __init__(self, actor, actor_optim, critic1, critic1_optim,
critic2, critic2_optim, tau=0.005, gamma=0.99,
alpha=0.2, action_range=None, reward_normalization=True):
alpha=0.2, action_range=None, reward_normalization=False,
ignore_done=False):
super().__init__(None, None, None, None, tau, gamma, 0,
action_range, reward_normalization)
action_range, reward_normalization, ignore_done)
self.actor, self.actor_optim = actor, actor_optim
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
self.critic1_old.eval()
@ -70,10 +71,8 @@ class SACPolicy(DDPGPolicy):
self.critic2_old(batch.obs_next, a_),
) - self._alpha * obs_next_result.log_prob
rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
if self._rew_norm:
rew = (rew - self._rew_mean) / (self._rew_std + self.__eps)
done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
target_q = rew + ((1. - done) * self._gamma * target_q).detach()
target_q = (rew + (1. - done) * self._gamma * target_q).detach()
obs_result = self(batch)
a = obs_result.act
current_q1, current_q1a = self.critic1(

View File

@ -1,5 +1,4 @@
import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F
@ -12,10 +11,11 @@ class TD3Policy(DDPGPolicy):
def __init__(self, actor, actor_optim, critic1, critic1_optim,
critic2, critic2_optim, tau=0.005, gamma=0.99,
exploration_noise=0.1, policy_noise=0.2, update_actor_freq=2,
noise_clip=0.5, action_range=None, reward_normalization=True):
super().__init__(actor, actor_optim, None, None,
tau, gamma, exploration_noise, action_range,
reward_normalization)
noise_clip=0.5, action_range=None,
reward_normalization=False, ignore_done=False):
super().__init__(actor, actor_optim, None, None, tau, gamma,
exploration_noise, action_range, reward_normalization,
ignore_done)
self.critic1, self.critic1_old = critic1, deepcopy(critic1)
self.critic1_old.eval()
self.critic1_optim = critic1_optim
@ -27,7 +27,6 @@ class TD3Policy(DDPGPolicy):
self._noise_clip = noise_clip
self._cnt = 0
self._last = 0
self.__eps = np.finfo(np.float32).eps.item()
def train(self):
self.training = True
@ -63,10 +62,8 @@ class TD3Policy(DDPGPolicy):
self.critic1_old(batch.obs_next, a_),
self.critic2_old(batch.obs_next, a_))
rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
if self._rew_norm:
rew = (rew - self._rew_mean) / (self._rew_std + self.__eps)
done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
target_q = rew + ((1. - done) * self._gamma * target_q).detach()
target_q = (rew + (1. - done) * self._gamma * target_q).detach()
# critic 1
current_q1 = self.critic1(batch.obs, batch.act)
critic1_loss = F.mse_loss(current_q1, target_q)

View File

@ -5,15 +5,6 @@ from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import test_episode, gather_info
def test(policy, collector, test_fn, epoch, n_episode):
collector.reset_env()
collector.reset_buffer()
policy.eval()
if test_fn:
test_fn(epoch)
return collector.collect(n_episode=n_episode)
def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
step_per_epoch, collect_per_step, episode_per_test,
batch_size, train_fn=None, test_fn=None, stop_fn=None,
@ -49,8 +40,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
if train_fn:
train_fn(epoch)
for i in range(min(
result['n/st'] // collect_per_step,
t.total - t.n)):
result['n/st'] // collect_per_step, t.total - t.n)):
global_step += 1
losses = policy.learn(train_collector.sample(batch_size))
for k in result.keys():
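For reference, a worked example of the update-count expression in the hunk above (the numbers are illustrative):

n_st = 160                      # transitions returned by the last collect, result['n/st']
collect_per_step = 10
remaining = 1000 - 840          # t.total - t.n: steps left in this epoch's progress bar
updates = min(n_st // collect_per_step, remaining)
assert updates == 16            # one update per collect_per_step transitions, capped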

View File

@ -7,7 +7,7 @@ def test_episode(policy, collector, test_fn, epoch, n_episode):
policy.eval()
if test_fn:
test_fn(epoch)
return collector.collect(n_episode=n_episode)
return collector.collect(n_episode=[1] * n_episode)
def gather_info(start_time, train_c, test_c, best_reward):