fix collector

parent e95218e295
commit fdc969b830

test/base/test_collector.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+import numpy as np
+from tianshou.policy import BasePolicy
+from tianshou.env import SubprocVectorEnv
+from tianshou.data import Collector, Batch
+
+if __name__ == '__main__':
+    from env import MyTestEnv
+else:  # pytest
+    from test.base.env import MyTestEnv
+
+
+class MyPolicy(BasePolicy):
+    """docstring for MyPolicy"""
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, batch, state=None):
+        return Batch(act=np.ones(batch.obs.shape[0]))
+
+    def learn(self):
+        pass
+
+
+def equal(a, b):
+    return abs(np.array(a) - np.array(b)).sum() < 1e-6
+
+
+def test_collector():
+    env_fns = [
+        lambda: MyTestEnv(size=2, sleep=0),
+        lambda: MyTestEnv(size=3, sleep=0),
+        lambda: MyTestEnv(size=4, sleep=0),
+        lambda: MyTestEnv(size=5, sleep=0),
+    ]
+    venv = SubprocVectorEnv(env_fns)
+    policy = MyPolicy()
+    env = env_fns[0]()
+    c0 = Collector(policy, env)
+    c0.collect(n_step=3)
+    assert equal(c0.buffer.obs[:3], [0, 1, 0])
+    assert equal(c0.buffer.obs_next[:3], [1, 2, 1])
+    c0.collect(n_episode=3)
+    assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
+    assert equal(c0.buffer.obs_next[:8], [1, 2, 1, 2, 1, 2, 1, 2])
+    c1 = Collector(policy, venv)
+    c1.collect(n_step=6)
+    assert equal(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
+    assert equal(c1.buffer.obs_next[:11], [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
+    c1.collect(n_episode=2)
+    assert equal(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
+    assert equal(c1.buffer.obs_next[11:21], [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
+    c2 = Collector(policy, venv)
+    c2.collect(n_episode=[1, 2, 2, 2])
+    assert equal(c2.buffer.obs_next[:26], [
+        1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
+        1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
+    c2.reset_env()
+    c2.collect(n_episode=[2, 2, 2, 2])
+    assert equal(c2.buffer.obs_next[26:54], [
+        1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5,
+        1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
+
+
+if __name__ == '__main__':
+    test_collector()
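For orientation, a minimal usage sketch (not part of the commit) of the Collector API that the new test exercises; MyTestEnv and MyPolicy are the helpers defined above, and the buffer size of 100 is an arbitrary illustrative choice:

from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv

venv = VectorEnv([lambda: MyTestEnv(size=3, sleep=0) for _ in range(2)])
collector = Collector(MyPolicy(), venv, ReplayBuffer(100))
stats = collector.collect(n_episode=[1, 2])  # per-env episode budget, as c2 does above
print(stats['n/ep'], stats['n/st'], stats['rew'])
collector.close()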
@@ -37,20 +37,34 @@ def test_framestack(k=4, size=10):
 
 def test_vecenv(size=10, num=8, sleep=0.001):
     verbose = __name__ == '__main__'
-    env_fns = [lambda: MyTestEnv(size=size, sleep=sleep) for _ in range(num)]
+    env_fns = [
+        lambda: MyTestEnv(size=size, sleep=sleep),
+        lambda: MyTestEnv(size=size + 1, sleep=sleep),
+        lambda: MyTestEnv(size=size + 2, sleep=sleep),
+        lambda: MyTestEnv(size=size + 3, sleep=sleep),
+        lambda: MyTestEnv(size=size + 4, sleep=sleep),
+        lambda: MyTestEnv(size=size + 5, sleep=sleep),
+        lambda: MyTestEnv(size=size + 6, sleep=sleep),
+        lambda: MyTestEnv(size=size + 7, sleep=sleep),
+    ]
     venv = [
-        VectorEnv(env_fns, reset_after_done=True),
-        SubprocVectorEnv(env_fns, reset_after_done=True),
+        VectorEnv(env_fns),
+        SubprocVectorEnv(env_fns),
     ]
     if verbose:
-        venv.append(RayVectorEnv(env_fns, reset_after_done=True))
+        venv.append(RayVectorEnv(env_fns))
     for v in venv:
         v.seed()
-    action_list = [1] * 5 + [0] * 10 + [1] * 15
+    action_list = [1] * 5 + [0] * 10 + [1] * 20
     if not verbose:
         o = [v.reset() for v in venv]
         for i, a in enumerate(action_list):
-            o = [v.step([a] * num) for v in venv]
+            o = []
+            for v in venv:
+                A, B, C, D = v.step([a] * num)
+                if sum(C):
+                    A = v.reset(np.where(C)[0])
+                o.append([A, B, C, D])
             for i in zip(*o):
                 for j in range(1, len(i)):
                     assert (i[0] == i[j]).all()
@@ -60,7 +74,9 @@ def test_vecenv(size=10, num=8, sleep=0.001):
             t[i] = time.time()
             e.reset()
             for a in action_list:
-                e.step([a] * num)
+                done = e.step([a] * num)[2]
+                if sum(done) > 0:
+                    e.reset(np.where(done)[0])
             t[i] = time.time() - t[i]
         print(f'VectorEnv: {t[0]:.6f}s')
         print(f'SubprocVectorEnv: {t[1]:.6f}s')
@@ -69,40 +85,6 @@ def test_vecenv(size=10, num=8, sleep=0.001):
         v.close()
 
 
-def test_vecenv2():
-    verbose = __name__ == '__main__'
-    env_fns = [
-        lambda: MyTestEnv(size=1),
-        lambda: MyTestEnv(size=2),
-        lambda: MyTestEnv(size=3),
-        lambda: MyTestEnv(size=4),
-    ]
-    num = len(env_fns)
-    venv = [
-        VectorEnv(env_fns, reset_after_done=False),
-        SubprocVectorEnv(env_fns, reset_after_done=False),
-    ]
-    if verbose:
-        venv.append(RayVectorEnv(env_fns, reset_after_done=False))
-    for v in venv:
-        v.seed()
-    o = [v.reset() for v in venv]
-    action_list = [1] * 6
-    for i, a in enumerate(action_list):
-        o = [v.step([a] * num) for v in venv]
-        if verbose:
-            print(o[0])
-            print(o[1])
-            print(o[2])
-            print('---')
-        for i in zip(*o):
-            for j in range(1, len(i)):
-                assert (i[0] == i[j]).all()
-    for v in venv:
-        v.close()
-
-
 
 
 if __name__ == '__main__':
     test_framestack()
     test_vecenv()
-    test_vecenv2()
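A small sketch (not part of the commit) of the calling pattern the updated test_vecenv now follows, with env_fns and action_list as defined above: since reset_after_done is gone, a finished sub-environment stays terminal until the caller resets it through the new reset(id=...) overload.

v = VectorEnv(env_fns)
obs = v.reset()
for a in action_list:
    obs, rew, done, info = v.step([a] * len(env_fns))
    if sum(done):
        # reset only the finished slots; the others keep running
        v.reset(np.where(done)[0])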
@@ -42,6 +42,7 @@ class ActorProb(nn.Module):
         self._max = max_action
 
     def forward(self, s, **kwargs):
+        if not isinstance(s, torch.Tensor):
             s = torch.tensor(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
@@ -64,8 +65,9 @@ class Critic(nn.Module):
         self.model = nn.Sequential(*self.model)
 
     def forward(self, s, a=None):
+        if not isinstance(s, torch.Tensor):
             s = torch.tensor(s, device=self.device, dtype=torch.float)
-        if isinstance(a, np.ndarray):
+        if a is not None and not isinstance(a, torch.Tensor):
             a = torch.tensor(a, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
@@ -50,12 +50,10 @@ def test_ddpg(args=get_args()):
     args.max_action = env.action_space.high[0]
     # train_envs = gym.make(args.task)
     train_envs = VectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.training_num)],
-        reset_after_done=True)
+        [lambda: gym.make(args.task) for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
     test_envs = SubprocVectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.test_num)],
-        reset_after_done=False)
+        [lambda: gym.make(args.task) for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
@@ -75,10 +73,10 @@ def test_ddpg(args=get_args()):
         actor, actor_optim, critic, critic_optim,
         args.tau, args.gamma, args.exploration_noise,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=True)
+        reward_normalization=True, ignore_done=True)
     # collector
     train_collector = Collector(
-        policy, train_envs, ReplayBuffer(args.buffer_size), 1)
+        policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
     # log
     writer = SummaryWriter(args.logdir)
@@ -91,7 +89,6 @@ def test_ddpg(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, stop_fn=stop_fn, writer=writer)
-    if args.task == 'Pendulum-v0':
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -21,12 +21,12 @@ def get_args():
     parser.add_argument('--task', type=str, default='Pendulum-v0')
     parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
-    parser.add_argument('--repeat-per-collect', type=int, default=1)
+    parser.add_argument('--repeat-per-collect', type=int, default=2)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=1)
     parser.add_argument('--training-num', type=int, default=16)
@@ -54,12 +54,10 @@ def _test_ppo(args=get_args()):
     args.max_action = env.action_space.high[0]
     # train_envs = gym.make(args.task)
     train_envs = SubprocVectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.training_num)],
-        reset_after_done=True)
+        [lambda: gym.make(args.task) for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
     test_envs = SubprocVectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.test_num)],
-        reset_after_done=False)
+        [lambda: gym.make(args.task) for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
@@ -85,7 +83,8 @@ def _test_ppo(args=get_args()):
         action_range=[env.action_space.low[0], env.action_space.high[0]])
     # collector
     train_collector = Collector(
-        policy, train_envs, ReplayBuffer(args.buffer_size))
+        policy, train_envs, ReplayBuffer(args.buffer_size),
+        remove_done_flag=True)
     test_collector = Collector(policy, test_envs)
     train_collector.collect(n_step=args.step_per_epoch)
     # log
@@ -50,12 +50,10 @@ def test_sac(args=get_args()):
     args.max_action = env.action_space.high[0]
     # train_envs = gym.make(args.task)
     train_envs = VectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.training_num)],
-        reset_after_done=True)
+        [lambda: gym.make(args.task) for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
     test_envs = SubprocVectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.test_num)],
-        reset_after_done=False)
+        [lambda: gym.make(args.task) for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
@@ -79,12 +77,12 @@ def test_sac(args=get_args()):
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         args.tau, args.gamma, args.alpha,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=True)
+        reward_normalization=True, ignore_done=True)
     # collector
     train_collector = Collector(
-        policy, train_envs, ReplayBuffer(args.buffer_size), 1)
+        policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
-    train_collector.collect(n_step=args.buffer_size)
+    # train_collector.collect(n_step=args.buffer_size)
     # log
     writer = SummaryWriter(args.logdir)
@@ -96,7 +94,6 @@ def test_sac(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, stop_fn=stop_fn, writer=writer)
-    if args.task == 'Pendulum-v0':
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -53,12 +53,10 @@ def test_td3(args=get_args()):
     args.max_action = env.action_space.high[0]
     # train_envs = gym.make(args.task)
     train_envs = VectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.training_num)],
-        reset_after_done=True)
+        [lambda: gym.make(args.task) for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
     test_envs = SubprocVectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.test_num)],
-        reset_after_done=False)
+        [lambda: gym.make(args.task) for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
@@ -83,12 +81,12 @@ def test_td3(args=get_args()):
         args.tau, args.gamma, args.exploration_noise, args.policy_noise,
         args.update_actor_freq, args.noise_clip,
         [env.action_space.low[0], env.action_space.high[0]],
-        reward_normalization=True)
+        reward_normalization=True, ignore_done=True)
     # collector
     train_collector = Collector(
-        policy, train_envs, ReplayBuffer(args.buffer_size), 1)
+        policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
-    train_collector.collect(n_step=args.buffer_size)
+    # train_collector.collect(n_step=args.buffer_size)
     # log
     writer = SummaryWriter(args.logdir)
@@ -100,7 +98,6 @@ def test_td3(args=get_args()):
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
         args.batch_size, stop_fn=stop_fn, writer=writer)
-    if args.task == 'Pendulum-v0':
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -18,6 +18,7 @@ class Net(nn.Module):
         self.model = nn.Sequential(*self.model)
 
     def forward(self, s, state=None, info={}):
+        if not isinstance(s, torch.Tensor):
             s = torch.tensor(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
@@ -24,7 +24,7 @@ def get_args():
     parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=100)
-    parser.add_argument('--step-per-epoch', type=int, default=320)
+    parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--repeat-per-collect', type=int, default=1)
     parser.add_argument('--batch-size', type=int, default=64)
@@ -49,12 +49,10 @@ def test_a2c(args=get_args()):
     args.action_shape = env.action_space.shape or env.action_space.n
     # train_envs = gym.make(args.task)
     train_envs = SubprocVectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.training_num)],
-        reset_after_done=True)
+        [lambda: gym.make(args.task) for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
     test_envs = SubprocVectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.test_num)],
-        reset_after_done=False)
+        [lambda: gym.make(args.task) for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
@@ -23,11 +23,12 @@ def get_args():
     parser.add_argument('--eps-test', type=float, default=0.05)
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=3e-4)
+    parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--n-step', type=int, default=1)
+    parser.add_argument('--target-update-freq', type=int, default=320)
     parser.add_argument('--epoch', type=int, default=100)
-    parser.add_argument('--step-per-epoch', type=int, default=320)
+    parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=3)
@@ -47,12 +48,10 @@ def test_dqn(args=get_args()):
     args.action_shape = env.action_space.shape or env.action_space.n
     # train_envs = gym.make(args.task)
     train_envs = SubprocVectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.training_num)],
-        reset_after_done=True)
+        [lambda: gym.make(args.task) for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
     test_envs = SubprocVectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.test_num)],
-        reset_after_done=False)
+        [lambda: gym.make(args.task) for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
@@ -62,12 +61,17 @@ def test_dqn(args=get_args()):
     net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
     net = net.to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
-    policy = DQNPolicy(net, optim, args.gamma, args.n_step)
+    policy = DQNPolicy(
+        net, optim, args.gamma, args.n_step,
+        use_target_network=args.target_update_freq > 0,
+        target_update_freq=args.target_update_freq)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs)
-    train_collector.collect(n_step=args.buffer_size)
+    # policy.set_eps(1)
+    train_collector.collect(n_step=args.batch_size)
+    print(len(train_collector.buffer))
     # log
     writer = SummaryWriter(args.logdir)
@@ -75,7 +79,6 @@ def test_dqn(args=get_args()):
         return x >= env.spec.reward_threshold
 
     def train_fn(x):
-        policy.sync_weight()
         policy.set_eps(args.eps_train)
 
     def test_fn(x):
@@ -99,12 +99,10 @@ def test_pg(args=get_args()):
     args.action_shape = env.action_space.shape or env.action_space.n
     # train_envs = gym.make(args.task)
     train_envs = SubprocVectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.training_num)],
-        reset_after_done=True)
+        [lambda: gym.make(args.task) for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
     test_envs = SubprocVectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.test_num)],
-        reset_after_done=False)
+        [lambda: gym.make(args.task) for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
@@ -19,9 +19,9 @@ else:  # pytest
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='CartPole-v0')
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=3e-3)
+    parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)
@@ -50,12 +50,10 @@ def test_ppo(args=get_args()):
     args.action_shape = env.action_space.shape or env.action_space.n
     # train_envs = gym.make(args.task)
     train_envs = SubprocVectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.training_num)],
-        reset_after_done=True)
+        [lambda: gym.make(args.task) for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
     test_envs = SubprocVectorEnv(
-        [lambda: gym.make(args.task) for _ in range(args.test_num)],
-        reset_after_done=False)
+        [lambda: gym.make(args.task) for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
@@ -11,13 +11,16 @@ from tianshou.utils import MovAvg
 class Collector(object):
     """docstring for Collector"""
 
-    def __init__(self, policy, env, buffer=ReplayBuffer(20000), stat_size=100):
+    def __init__(self, policy, env, buffer=None, stat_size=100):
         super().__init__()
         self.env = env
         self.env_num = 1
         self.collect_step = 0
         self.collect_episode = 0
         self.collect_time = 0
-        self.buffer = buffer
+        if buffer is None:
+            self.buffer = ReplayBuffer(20000)
+        else:
+            self.buffer = buffer
         self.policy = policy
         self.process_fn = policy.process_fn
@@ -34,7 +37,7 @@ class Collector(object):
             self._multi_buf = True
         elif isinstance(self.buffer, ReplayBuffer):
             self._cached_buf = [
-                deepcopy(buffer) for _ in range(self.env_num)]
+                deepcopy(self.buffer) for _ in range(self.env_num)]
         else:
             raise TypeError('The buffer in data collector is invalid!')
         self.reset_env()
@@ -64,11 +67,11 @@ class Collector(object):
 
     def seed(self, seed=None):
         if hasattr(self.env, 'seed'):
-            self.env.seed(seed)
+            return self.env.seed(seed)
 
     def render(self, **kwargs):
         if hasattr(self.env, 'render'):
-            self.env.render(**kwargs)
+            return self.env.render(**kwargs)
 
     def close(self):
         if hasattr(self.env, 'close'):
@@ -78,11 +81,13 @@ class Collector(object):
         if isinstance(data, np.ndarray):
             return data[None]
         else:
-            return [data]
+            return np.array([data])
 
     def collect(self, n_step=0, n_episode=0, render=0):
+        if not self._multi_env:
+            n_episode = np.sum(n_episode)
         start_time = time.time()
-        assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
+        assert sum([(n_step != 0), (n_episode != 0)]) == 1,\
             "One and only one collection number specification permitted!"
         cur_step = 0
         cur_episode = np.zeros(self.env_num) if self._multi_env else 0
@@ -105,8 +110,10 @@ class Collector(object):
             self.state = result.state if hasattr(result, 'state') else None
             if isinstance(result.act, torch.Tensor):
                 self._act = result.act.detach().cpu().numpy()
-            else:
+            elif not isinstance(self._act, np.ndarray):
                 self._act = np.array(result.act)
+            else:
+                self._act = result.act
             obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act if self._multi_env else self._act[0])
             if render > 0:
@@ -116,9 +123,6 @@ class Collector(object):
             self.reward += self._rew
             if self._multi_env:
                 for i in range(self.env_num):
-                    if not self.env.is_reset_after_done()\
-                            and cur_episode[i] > 0:
-                        continue
                     data = {
                         'obs': self._obs[i], 'act': self._act[i],
                         'rew': self._rew[i], 'done': self._done[i],
@@ -132,13 +136,16 @@ class Collector(object):
                         self.buffer.add(**data)
                         cur_step += 1
                     if self._done[i]:
+                        if n_step != 0 or np.isscalar(n_episode) or \
+                                cur_episode[i] < n_episode[i]:
                             cur_episode[i] += 1
                             reward_sum += self.reward[i]
                             length_sum += self.length[i]
+                            if self._cached_buf:
+                                cur_step += len(self._cached_buf[i])
+                                self.buffer.update(self._cached_buf[i])
                         self.reward[i], self.length[i] = 0, 0
                         if self._cached_buf:
-                            self.buffer.update(self._cached_buf[i])
-                            cur_step += len(self._cached_buf[i])
                             self._cached_buf[i].reset()
                         if isinstance(self.state, list):
                             self.state[i] = None
@@ -150,7 +157,13 @@ class Collector(object):
                 if isinstance(self.state, torch.Tensor):
                     # remove ref count in pytorch (?)
                     self.state = self.state.detach()
-                if n_episode > 0 and cur_episode.sum() >= n_episode:
+                if sum(self._done):
+                    obs_next = self.env.reset(np.where(self._done)[0])
+                if n_episode != 0:
+                    if isinstance(n_episode, list) and \
+                            (cur_episode >= np.array(n_episode)).all() or \
+                            np.isscalar(n_episode) and \
+                            cur_episode.sum() >= n_episode:
                         break
             else:
                 self.buffer.add(
@@ -163,10 +176,10 @@ class Collector(object):
                 length_sum += self.length
                 self.reward, self.length = 0, 0
                 self.state = None
-                self._obs = self.env.reset()
-            if n_episode > 0 and cur_episode >= n_episode:
+                obs_next = self.env.reset()
+            if n_episode != 0 and cur_episode >= n_episode:
                 break
-            if n_step > 0 and cur_step >= n_step:
+            if n_step != 0 and cur_step >= n_step:
                 break
             self._obs = obs_next
         self._obs = obs_next
@@ -178,13 +191,17 @@ class Collector(object):
         self.collect_step += cur_step
         self.collect_episode += cur_episode
         self.collect_time += duration
+        if isinstance(n_episode, list):
+            n_episode = np.sum(n_episode)
+        else:
+            n_episode = max(cur_episode, 1)
         return {
             'n/ep': cur_episode,
             'n/st': cur_step,
             'v/st': self.step_speed.get(),
             'v/ep': self.episode_speed.get(),
-            'rew': reward_sum / cur_episode,
-            'len': length_sum / cur_episode,
+            'rew': reward_sum / n_episode,
+            'len': length_sum / n_episode,
         }
 
     def sample(self, batch_size):
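For readers of the diff above, a sketch of how the reworked collect() is intended to be called (collector as built in the test scripts; the numbers are arbitrary): exactly one of n_step and n_episode may be non-zero, and with a vector env n_episode may also be a per-environment list.

collector.collect(n_step=1000)             # a fixed number of transitions
collector.collect(n_episode=4)             # episodes pooled over all sub-envs
collector.collect(n_episode=[1, 2, 2, 2])  # per-env episode budget
batch = collector.sample(64)               # draw a training batch from the buffer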
tianshou/env/common.py (4 changed lines)
@@ -14,11 +14,11 @@ class EnvWrapper(object):
 
     def seed(self, seed=None):
         if hasattr(self.env, 'seed'):
-            self.env.seed(seed)
+            return self.env.seed(seed)
 
     def render(self, **kwargs):
         if hasattr(self.env, 'render'):
-            self.env.render(**kwargs)
+            return self.env.render(**kwargs)
 
     def close(self):
         self.env.close()
tianshou/env/vecenv.py (141 changed lines)
@@ -10,14 +10,9 @@ from tianshou.env import EnvWrapper, CloudpickleWrapper
 
 
 class BaseVectorEnv(ABC):
-    def __init__(self, env_fns, reset_after_done):
+    def __init__(self, env_fns):
         self._env_fns = env_fns
         self.env_num = len(env_fns)
-        self._reset_after_done = reset_after_done
-        self._done = np.zeros(self.env_num)
-
-    def is_reset_after_done(self):
-        return self._reset_after_done
 
     def __len__(self):
         return self.env_num
@@ -46,67 +41,62 @@ class BaseVectorEnv(ABC):
 class VectorEnv(BaseVectorEnv):
     """docstring for VectorEnv"""
 
-    def __init__(self, env_fns, reset_after_done=False):
-        super().__init__(env_fns, reset_after_done)
+    def __init__(self, env_fns):
+        super().__init__(env_fns)
         self.envs = [_() for _ in env_fns]
 
-    def reset(self):
-        self._done = np.zeros(self.env_num)
+    def reset(self, id=None):
+        if id is None:
             self._obs = np.stack([e.reset() for e in self.envs])
+        else:
+            if np.isscalar(id):
+                id = [id]
+            for i in id:
+                self._obs[i] = self.envs[i].reset()
         return self._obs
 
     def step(self, action):
         assert len(action) == self.env_num
-        result = []
-        for i, e in enumerate(self.envs):
-            if not self.is_reset_after_done() and self._done[i]:
-                result.append([
-                    self._obs[i], self._rew[i], self._done[i], self._info[i]])
-            else:
-                result.append(e.step(action[i]))
+        result = [e.step(a) for e, a in zip(self.envs, action)]
         self._obs, self._rew, self._done, self._info = zip(*result)
-        if self.is_reset_after_done() and sum(self._done):
-            self._obs = np.stack(self._obs)
-            for i in np.where(self._done)[0]:
-                self._obs[i] = self.envs[i].reset()
-        return np.stack(self._obs), np.stack(self._rew),\
-            np.stack(self._done), np.stack(self._info)
+        self._obs = np.stack(self._obs)
+        self._rew = np.stack(self._rew)
+        self._done = np.stack(self._done)
+        self._info = np.stack(self._info)
+        return self._obs, self._rew, self._done, self._info
 
     def seed(self, seed=None):
         if np.isscalar(seed):
             seed = [seed + _ for _ in range(self.env_num)]
         elif seed is None:
             seed = [seed] * self.env_num
+        result = []
         for e, s in zip(self.envs, seed):
             if hasattr(e, 'seed'):
-                e.seed(s)
+                result.append(e.seed(s))
+        return result
 
     def render(self, **kwargs):
+        result = []
         for e in self.envs:
             if hasattr(e, 'render'):
-                e.render(**kwargs)
+                result.append(e.render(**kwargs))
+        return result
 
     def close(self):
         for e in self.envs:
             e.close()
 
 
-def worker(parent, p, env_fn_wrapper, reset_after_done):
+def worker(parent, p, env_fn_wrapper):
     parent.close()
     env = env_fn_wrapper.data()
-    done = False
     try:
         while True:
             cmd, data = p.recv()
             if cmd == 'step':
-                if reset_after_done or not done:
-                    obs, rew, done, info = env.step(data)
-                if reset_after_done and done:
-                    # s_ is useless when episode finishes
-                    obs = env.reset()
-                p.send([obs, rew, done, info])
+                p.send(env.step(data))
             elif cmd == 'reset':
-                done = False
                 p.send(env.reset())
             elif cmd == 'close':
                 p.close()
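A sketch (not from the commit) of the invariant the rewritten VectorEnv.step() now provides: every element of the returned tuple is a stacked array of length env_num, and no sub-environment is reset implicitly. gym and the CartPole factory are used purely for illustration.

import gym
import numpy as np
from tianshou.env import VectorEnv

envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
obs = envs.reset()
obs, rew, done, info = envs.step(np.zeros(4, dtype=int))
assert obs.shape[0] == rew.shape[0] == done.shape[0] == len(envs) == 4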
@@ -125,15 +115,14 @@ def worker(parent, p, env_fn_wrapper, reset_after_done):
 class SubprocVectorEnv(BaseVectorEnv):
     """docstring for SubProcVectorEnv"""
 
-    def __init__(self, env_fns, reset_after_done=False):
-        super().__init__(env_fns, reset_after_done)
+    def __init__(self, env_fns):
+        super().__init__(env_fns)
         self.closed = False
         self.parent_remote, self.child_remote = \
             zip(*[Pipe() for _ in range(self.env_num)])
         self.processes = [
             Process(target=worker, args=(
-                parent, child,
-                CloudpickleWrapper(env_fn), reset_after_done), daemon=True)
+                parent, child, CloudpickleWrapper(env_fn)), daemon=True)
             for (parent, child, env_fn) in zip(
                 self.parent_remote, self.child_remote, env_fns)
         ]
@@ -147,13 +136,27 @@ class SubprocVectorEnv(BaseVectorEnv):
         for p, a in zip(self.parent_remote, action):
             p.send(['step', a])
         result = [p.recv() for p in self.parent_remote]
-        obs, rew, done, info = zip(*result)
-        return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
+        self._obs, self._rew, self._done, self._info = zip(*result)
+        self._obs = np.stack(self._obs)
+        self._rew = np.stack(self._rew)
+        self._done = np.stack(self._done)
+        self._info = np.stack(self._info)
+        return self._obs, self._rew, self._done, self._info
 
-    def reset(self):
+    def reset(self, id=None):
+        if id is None:
             for p in self.parent_remote:
                 p.send(['reset', None])
-        return np.stack([p.recv() for p in self.parent_remote])
+            self._obs = np.stack([p.recv() for p in self.parent_remote])
+            return self._obs
+        else:
+            if np.isscalar(id):
+                id = [id]
+            for i in id:
+                self.parent_remote[i].send(['reset', None])
+            for i in id:
+                self._obs[i] = self.parent_remote[i].recv()
+            return self._obs
 
     def seed(self, seed=None):
         if np.isscalar(seed):
@@ -162,14 +165,12 @@ class SubprocVectorEnv(BaseVectorEnv):
             seed = [seed] * self.env_num
         for p, s in zip(self.parent_remote, seed):
             p.send(['seed', s])
-        for p in self.parent_remote:
-            p.recv()
+        return [p.recv() for p in self.parent_remote]
 
     def render(self, **kwargs):
         for p in self.parent_remote:
             p.send(['render', kwargs])
-        for p in self.parent_remote:
-            p.recv()
+        return [p.recv() for p in self.parent_remote]
 
     def close(self):
         if self.closed:
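seed() and render() on the vector envs (and on EnvWrapper above) now return what the wrapped environments return instead of discarding it; a hedged sketch using the envs object from the previous example, with the actual values depending entirely on the underlying gym envs:

seeds = envs.seed(0)    # e.g. [[0], [1], [2], [3]] for four gym sub-envs
frames = envs.render()  # list with one render() result per sub-env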
@@ -184,8 +185,8 @@ class SubprocVectorEnv(BaseVectorEnv):
 class RayVectorEnv(BaseVectorEnv):
     """docstring for RayVectorEnv"""
 
-    def __init__(self, env_fns, reset_after_done=False):
-        super().__init__(env_fns, reset_after_done)
+    def __init__(self, env_fns):
+        super().__init__(env_fns)
         try:
             if not ray.is_initialized():
                 ray.init()
@@ -198,35 +199,27 @@ class RayVectorEnv(BaseVectorEnv):
 
     def step(self, action):
         assert len(action) == self.env_num
-        result_obj = []
-        for i, e in enumerate(self.envs):
-            if not self.is_reset_after_done() and self._done[i]:
-                result_obj.append(None)
-            else:
-                result_obj.append(e.step.remote(action[i]))
-        result = []
-        for i, r in enumerate(result_obj):
-            if r is None:
-                result.append([
-                    self._obs[i], self._rew[i], self._done[i], self._info[i]])
-            else:
-                result.append(ray.get(r))
+        result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
+        result = [ray.get(r) for r in result_obj]
         self._obs, self._rew, self._done, self._info = zip(*result)
-        if self.is_reset_after_done() and sum(self._done):
-            self._obs = np.stack(self._obs)
-            index = np.where(self._done)[0]
-            result_obj = []
-            for i in range(len(index)):
-                result_obj.append(self.envs[index[i]].reset.remote())
-            for i in range(len(index)):
-                self._obs[index[i]] = ray.get(result_obj[i])
-        return np.stack(self._obs), np.stack(self._rew),\
-            np.stack(self._done), np.stack(self._info)
+        self._obs = np.stack(self._obs)
+        self._rew = np.stack(self._rew)
+        self._done = np.stack(self._done)
+        self._info = np.stack(self._info)
+        return self._obs, self._rew, self._done, self._info
 
-    def reset(self):
-        self._done = np.zeros(self.env_num)
-        result_obj = [e.reset.remote() for e in self.envs]
-        self._obs = np.stack([ray.get(r) for r in result_obj])
+    def reset(self, id=None):
+        if id is None:
+            result_obj = [e.reset.remote() for e in self.envs]
+            self._obs = np.stack([ray.get(r) for r in result_obj])
+        else:
+            result_obj = []
+            if np.isscalar(id):
+                id = [id]
+            for i in id:
+                result_obj.append(self.envs[i].reset.remote())
+            for _, i in enumerate(id):
+                self._obs[i] = ray.get(result_obj[_])
         return self._obs
 
     def seed(self, seed=None):
@@ -237,15 +230,13 @@ class RayVectorEnv(BaseVectorEnv):
         elif seed is None:
             seed = [seed] * self.env_num
         result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
-        for r in result_obj:
-            ray.get(r)
+        return [ray.get(r) for r in result_obj]
 
     def render(self, **kwargs):
         if not hasattr(self.envs[0], 'render'):
             return
         result_obj = [e.render.remote(**kwargs) for e in self.envs]
-        for r in result_obj:
-            ray.get(r)
+        return [ray.get(r) for r in result_obj]
 
     def close(self):
         result_obj = [e.close.remote() for e in self.envs]
@@ -13,7 +13,8 @@ class DDPGPolicy(BasePolicy):
 
     def __init__(self, actor, actor_optim, critic, critic_optim,
                  tau=0.005, gamma=0.99, exploration_noise=0.1,
-                 action_range=None, reward_normalization=True):
+                 action_range=None, reward_normalization=False,
+                 ignore_done=False):
         super().__init__()
         if actor is not None:
             self.actor, self.actor_old = actor, deepcopy(actor)
@@ -35,6 +36,7 @@ class DDPGPolicy(BasePolicy):
         self._action_scale = (action_range[1] - action_range[0]) / 2
         # it is only a little difference to use rand_normal
         # self.noise = OUNoise()
+        self._rm_done = ignore_done
         self._rew_norm = reward_normalization
         self.__eps = np.finfo(np.float32).eps.item()
 
@@ -60,8 +62,12 @@ class DDPGPolicy(BasePolicy):
 
     def process_fn(self, batch, buffer, indice):
         if self._rew_norm:
-            self._rew_mean = buffer.rew.mean()
-            self._rew_std = buffer.rew.std()
+            bfr = buffer.rew[:len(buffer)]
+            mean, std = bfr.mean(), bfr.std()
+            if std > self.__eps:
+                batch.rew = (batch.rew - mean) / std
+        if self._rm_done:
+            batch.done = batch.done * 0.
         return batch
 
     def __call__(self, batch, state=None,
@@ -72,10 +78,10 @@ class DDPGPolicy(BasePolicy):
         logits += self._action_bias
         if eps is None:
             eps = self._eps
-        # noise = np.random.normal(0, eps, size=logits.shape)
-        # noise = self.noise(logits.shape, eps)
-        # logits += torch.tensor(noise, device=logits.device)
         if eps > 0:
+            # noise = np.random.normal(0, eps, size=logits.shape)
+            # logits += torch.tensor(noise, device=logits.device)
+            # noise = self.noise(logits.shape, eps)
             logits += torch.randn(
                 size=logits.shape, device=logits.device) * eps
         logits = logits.clamp(self._range[0], self._range[1])
@@ -86,10 +92,8 @@ class DDPGPolicy(BasePolicy):
             batch, model='actor_old', input='obs_next', eps=0).act)
         dev = target_q.device
         rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        if self._rew_norm:
-            rew = (rew - self._rew_mean) / (self._rew_std + self.__eps)
         done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = rew + ((1. - done) * self._gamma * target_q).detach()
+        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
         current_q = self.critic(batch.obs, batch.act)
         critic_loss = F.mse_loss(current_q, target_q)
         self.critic_optim.zero_grad()
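A sketch of how the two new DDPGPolicy flags are wired in the Pendulum scripts above (keyword values shown are the constructor defaults plus the settings the tests use): reward_normalization standardizes batch.rew against the replay buffer statistics inside process_fn, and ignore_done zeroes batch.done so the bootstrap term (1 - done) * gamma * Q' is never cut off at episode ends.

policy = DDPGPolicy(
    actor, actor_optim, critic, critic_optim,
    tau=0.005, gamma=0.99, exploration_noise=0.1,
    action_range=[env.action_space.low[0], env.action_space.high[0]],
    reward_normalization=True, ignore_done=True)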
@@ -10,10 +10,9 @@ from tianshou.policy import BasePolicy
 class DQNPolicy(BasePolicy):
     """docstring for DQNPolicy"""
 
-    def __init__(self, model, optim,
-                 discount_factor=0.99,
-                 estimation_step=1,
-                 use_target_network=True):
+    def __init__(self, model, optim, discount_factor=0.99,
+                 estimation_step=1, use_target_network=True,
+                 target_update_freq=300):
         super().__init__()
         self.model = model
         self.optim = optim
@@ -23,6 +22,8 @@ class DQNPolicy(BasePolicy):
         assert estimation_step > 0, 'estimation_step should greater than 0'
         self._n_step = estimation_step
         self._target = use_target_network
+        self._freq = target_update_freq
+        self._cnt = 0
         if use_target_network:
             self.model_old = deepcopy(self.model)
             self.model_old.eval()
@@ -39,7 +40,6 @@ class DQNPolicy(BasePolicy):
         self.model.eval()
 
     def sync_weight(self):
-        if self._target:
         self.model_old.load_state_dict(self.model.state_dict())
 
     def process_fn(self, batch, buffer, indice):
@@ -84,6 +84,8 @@ class DQNPolicy(BasePolicy):
         return Batch(logits=q, act=act, state=h)
 
     def learn(self, batch, batch_size=None, repeat=1):
+        if self._target and self._cnt % self._freq == 0:
+            self.sync_weight()
         self.optim.zero_grad()
         q = self(batch).logits
         q = q[np.arange(len(q)), batch.act]
@@ -93,4 +95,5 @@ class DQNPolicy(BasePolicy):
         loss = F.mse_loss(q, r)
         loss.backward()
         self.optim.step()
+        self._cnt += 1
         return {'loss': loss.detach().cpu().numpy()}
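With target_update_freq, the copy of the online network into model_old now happens inside learn() — every target_update_freq gradient steps — instead of being triggered from the trainer's train_fn. A sketch mirroring the test_dqn setup above (net, optim and train_collector as constructed there):

policy = DQNPolicy(net, optim, discount_factor=0.9, estimation_step=1,
                   use_target_network=True, target_update_freq=320)
for update in range(1000):
    losses = policy.learn(train_collector.sample(64))  # syncs model_old every 320 calls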
@@ -12,9 +12,10 @@ class SACPolicy(DDPGPolicy):
 
     def __init__(self, actor, actor_optim, critic1, critic1_optim,
                  critic2, critic2_optim, tau=0.005, gamma=0.99,
-                 alpha=0.2, action_range=None, reward_normalization=True):
+                 alpha=0.2, action_range=None, reward_normalization=False,
+                 ignore_done=False):
         super().__init__(None, None, None, None, tau, gamma, 0,
-                         action_range, reward_normalization)
+                         action_range, reward_normalization, ignore_done)
         self.actor, self.actor_optim = actor, actor_optim
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
         self.critic1_old.eval()
@@ -70,10 +71,8 @@ class SACPolicy(DDPGPolicy):
             self.critic2_old(batch.obs_next, a_),
         ) - self._alpha * obs_next_result.log_prob
         rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        if self._rew_norm:
-            rew = (rew - self._rew_mean) / (self._rew_std + self.__eps)
         done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = rew + ((1. - done) * self._gamma * target_q).detach()
+        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
         obs_result = self(batch)
         a = obs_result.act
         current_q1, current_q1a = self.critic1(
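Same change as in DDPGPolicy above: reward statistics are applied once in process_fn, so SACPolicy.learn() builds its soft Bellman target directly from the processed batch, roughly as below (placeholder names, not the exact variables in the file):

q_next = torch.min(q1_old_next, q2_old_next) - alpha * log_prob_next
target_q = (rew + (1. - done) * gamma * q_next).detach()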
@@ -1,5 +1,4 @@
 import torch
-import numpy as np
 from copy import deepcopy
 import torch.nn.functional as F
 
@@ -12,10 +11,11 @@ class TD3Policy(DDPGPolicy):
     def __init__(self, actor, actor_optim, critic1, critic1_optim,
                  critic2, critic2_optim, tau=0.005, gamma=0.99,
                  exploration_noise=0.1, policy_noise=0.2, update_actor_freq=2,
-                 noise_clip=0.5, action_range=None, reward_normalization=True):
-        super().__init__(actor, actor_optim, None, None,
-                         tau, gamma, exploration_noise, action_range,
-                         reward_normalization)
+                 noise_clip=0.5, action_range=None,
+                 reward_normalization=False, ignore_done=False):
+        super().__init__(actor, actor_optim, None, None, tau, gamma,
+                         exploration_noise, action_range, reward_normalization,
+                         ignore_done)
         self.critic1, self.critic1_old = critic1, deepcopy(critic1)
         self.critic1_old.eval()
         self.critic1_optim = critic1_optim
@@ -27,7 +27,6 @@ class TD3Policy(DDPGPolicy):
         self._noise_clip = noise_clip
         self._cnt = 0
         self._last = 0
-        self.__eps = np.finfo(np.float32).eps.item()
 
     def train(self):
         self.training = True
@@ -63,10 +62,8 @@ class TD3Policy(DDPGPolicy):
             self.critic1_old(batch.obs_next, a_),
             self.critic2_old(batch.obs_next, a_))
         rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
-        if self._rew_norm:
-            rew = (rew - self._rew_mean) / (self._rew_std + self.__eps)
         done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
-        target_q = rew + ((1. - done) * self._gamma * target_q).detach()
+        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
         # critic 1
         current_q1 = self.critic1(batch.obs, batch.act)
         critic1_loss = F.mse_loss(current_q1, target_q)
@@ -5,15 +5,6 @@ from tianshou.utils import tqdm_config, MovAvg
 from tianshou.trainer import test_episode, gather_info
 
 
-def test(policy, collector, test_fn, epoch, n_episode):
-    collector.reset_env()
-    collector.reset_buffer()
-    policy.eval()
-    if test_fn:
-        test_fn(epoch)
-    return collector.collect(n_episode=n_episode)
-
-
 def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                       step_per_epoch, collect_per_step, episode_per_test,
                       batch_size, train_fn=None, test_fn=None, stop_fn=None,
@@ -49,8 +40,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                 if train_fn:
                     train_fn(epoch)
                 for i in range(min(
-                        result['n/st'] // collect_per_step,
-                        t.total - t.n)):
+                        result['n/st'] // collect_per_step, t.total - t.n)):
                     global_step += 1
                     losses = policy.learn(train_collector.sample(batch_size))
                     for k in result.keys():
@@ -7,7 +7,7 @@ def test_episode(policy, collector, test_fn, epoch, n_episode):
     policy.eval()
     if test_fn:
         test_fn(epoch)
-    return collector.collect(n_episode=n_episode)
+    return collector.collect(n_episode=[1] * n_episode)
 
 
 def gather_info(start_time, train_c, test_c, best_reward):
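The [1] * n_episode argument means the evaluation collector now runs exactly one episode in each of n_episode test slots instead of a pooled episode count, so the reported 'rew' and 'len' average over exactly n_episode episodes. A sketch of a call (policy and test_collector as in the scripts above):

result = test_episode(policy, test_collector, None, epoch=0, n_episode=args.test_num)
print(result['rew'], result['len'])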