fix collector
parent e95218e295
commit fdc969b830
65  test/base/test_collector.py  Normal file
@@ -0,0 +1,65 @@
import numpy as np
from tianshou.policy import BasePolicy
from tianshou.env import SubprocVectorEnv
from tianshou.data import Collector, Batch

if __name__ == '__main__':
    from env import MyTestEnv
else: # pytest
    from test.base.env import MyTestEnv


class MyPolicy(BasePolicy):
    """docstring for MyPolicy"""
    def __init__(self):
        super().__init__()

    def __call__(self, batch, state=None):
        return Batch(act=np.ones(batch.obs.shape[0]))

    def learn(self):
        pass


def equal(a, b):
    return abs(np.array(a) - np.array(b)).sum() < 1e-6


def test_collector():
    env_fns = [
        lambda: MyTestEnv(size=2, sleep=0),
        lambda: MyTestEnv(size=3, sleep=0),
        lambda: MyTestEnv(size=4, sleep=0),
        lambda: MyTestEnv(size=5, sleep=0),
    ]
    venv = SubprocVectorEnv(env_fns)
    policy = MyPolicy()
    env = env_fns[0]()
    c0 = Collector(policy, env)
    c0.collect(n_step=3)
    assert equal(c0.buffer.obs[:3], [0, 1, 0])
    assert equal(c0.buffer.obs_next[:3], [1, 2, 1])
    c0.collect(n_episode=3)
    assert equal(c0.buffer.obs[:8], [0, 1, 0, 1, 0, 1, 0, 1])
    assert equal(c0.buffer.obs_next[:8], [1, 2, 1, 2, 1, 2, 1, 2])
    c1 = Collector(policy, venv)
    c1.collect(n_step=6)
    assert equal(c1.buffer.obs[:11], [0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3])
    assert equal(c1.buffer.obs_next[:11], [1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4])
    c1.collect(n_episode=2)
    assert equal(c1.buffer.obs[11:21], [0, 1, 2, 3, 4, 0, 1, 0, 1, 2])
    assert equal(c1.buffer.obs_next[11:21], [1, 2, 3, 4, 5, 1, 2, 1, 2, 3])
    c2 = Collector(policy, venv)
    c2.collect(n_episode=[1, 2, 2, 2])
    assert equal(c2.buffer.obs_next[:26], [
        1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5,
        1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
    c2.reset_env()
    c2.collect(n_episode=[2, 2, 2, 2])
    assert equal(c2.buffer.obs_next[26:54], [
        1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4, 1, 2, 3, 4, 5,
        1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])


if __name__ == '__main__':
    test_collector()
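The expected obs/obs_next sequences in this test depend on the counting behaviour of MyTestEnv from test/base/env.py, which is not part of this diff. A minimal sketch of the behaviour the assertions assume (this reconstruction is an assumption, not the repository file):

import time


class MyTestEnv:
    """Assumed behaviour: the observation is a counter starting at 0;
    action 1 moves it forward, action 0 moves it back, and the episode
    terminates (with reward 1) once the counter reaches `size`."""

    def __init__(self, size, sleep=0):
        self.size, self.sleep = size, sleep
        self.reset()

    def seed(self, seed=0):
        return [seed]

    def reset(self):
        self.done, self.index = False, 0
        return self.index

    def step(self, action):
        assert not self.done, 'step() called after the episode finished'
        if self.sleep > 0:
            time.sleep(self.sleep)
        self.index = max(self.index - 1, 0) if action == 0 else self.index + 1
        self.done = self.index == self.size
        return self.index, int(self.done), self.done, {}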
@@ -37,20 +37,34 @@ def test_framestack(k=4, size=10):

def test_vecenv(size=10, num=8, sleep=0.001):
    verbose = __name__ == '__main__'
    env_fns = [lambda: MyTestEnv(size=size, sleep=sleep) for _ in range(num)]
    env_fns = [
        lambda: MyTestEnv(size=size, sleep=sleep),
        lambda: MyTestEnv(size=size + 1, sleep=sleep),
        lambda: MyTestEnv(size=size + 2, sleep=sleep),
        lambda: MyTestEnv(size=size + 3, sleep=sleep),
        lambda: MyTestEnv(size=size + 4, sleep=sleep),
        lambda: MyTestEnv(size=size + 5, sleep=sleep),
        lambda: MyTestEnv(size=size + 6, sleep=sleep),
        lambda: MyTestEnv(size=size + 7, sleep=sleep),
    ]
    venv = [
        VectorEnv(env_fns, reset_after_done=True),
        SubprocVectorEnv(env_fns, reset_after_done=True),
        VectorEnv(env_fns),
        SubprocVectorEnv(env_fns),
    ]
    if verbose:
        venv.append(RayVectorEnv(env_fns, reset_after_done=True))
        venv.append(RayVectorEnv(env_fns))
    for v in venv:
        v.seed()
    action_list = [1] * 5 + [0] * 10 + [1] * 15
    action_list = [1] * 5 + [0] * 10 + [1] * 20
    if not verbose:
        o = [v.reset() for v in venv]
        for i, a in enumerate(action_list):
            o = [v.step([a] * num) for v in venv]
            o = []
            for v in venv:
                A, B, C, D = v.step([a] * num)
                if sum(C):
                    A = v.reset(np.where(C)[0])
                o.append([A, B, C, D])
            for i in zip(*o):
                for j in range(1, len(i)):
                    assert (i[0] == i[j]).all()

@@ -60,7 +74,9 @@ def test_vecenv(size=10, num=8, sleep=0.001):
            t[i] = time.time()
            e.reset()
            for a in action_list:
                e.step([a] * num)
                done = e.step([a] * num)[2]
                if sum(done) > 0:
                    e.reset(np.where(done)[0])
            t[i] = time.time() - t[i]
        print(f'VectorEnv: {t[0]:.6f}s')
        print(f'SubprocVectorEnv: {t[1]:.6f}s')

@@ -69,40 +85,6 @@ def test_vecenv(size=10, num=8, sleep=0.001):
        v.close()


def test_vecenv2():
    verbose = __name__ == '__main__'
    env_fns = [
        lambda: MyTestEnv(size=1),
        lambda: MyTestEnv(size=2),
        lambda: MyTestEnv(size=3),
        lambda: MyTestEnv(size=4),
    ]
    num = len(env_fns)
    venv = [
        VectorEnv(env_fns, reset_after_done=False),
        SubprocVectorEnv(env_fns, reset_after_done=False),
    ]
    if verbose:
        venv.append(RayVectorEnv(env_fns, reset_after_done=False))
    for v in venv:
        v.seed()
    o = [v.reset() for v in venv]
    action_list = [1] * 6
    for i, a in enumerate(action_list):
        o = [v.step([a] * num) for v in venv]
        if verbose:
            print(o[0])
            print(o[1])
            print(o[2])
            print('---')
        for i in zip(*o):
            for j in range(1, len(i)):
                assert (i[0] == i[j]).all()
    for v in venv:
        v.close()


if __name__ == '__main__':
    test_framestack()
    test_vecenv()
    test_vecenv2()
@@ -42,6 +42,7 @@ class ActorProb(nn.Module):
        self._max = max_action

    def forward(self, s, **kwargs):
        if not isinstance(s, torch.Tensor):
            s = torch.tensor(s, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        s = s.view(batch, -1)

@@ -64,8 +65,9 @@ class Critic(nn.Module):
        self.model = nn.Sequential(*self.model)

    def forward(self, s, a=None):
        if not isinstance(s, torch.Tensor):
            s = torch.tensor(s, device=self.device, dtype=torch.float)
        if isinstance(a, np.ndarray):
        if a is not None and not isinstance(a, torch.Tensor):
            a = torch.tensor(a, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        s = s.view(batch, -1)
@@ -50,12 +50,10 @@ def test_ddpg(args=get_args()):
    args.max_action = env.action_space.high[0]
    # train_envs = gym.make(args.task)
    train_envs = VectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

@@ -75,10 +73,10 @@ def test_ddpg(args=get_args()):
        actor, actor_optim, critic, critic_optim,
        args.tau, args.gamma, args.exploration_noise,
        [env.action_space.low[0], env.action_space.high[0]],
        reward_normalization=True)
        reward_normalization=True, ignore_done=True)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size), 1)
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs)
    # log
    writer = SummaryWriter(args.logdir)

@@ -91,7 +89,6 @@ def test_ddpg(args=get_args()):
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, stop_fn=stop_fn, writer=writer)
    if args.task == 'Pendulum-v0':
        assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
@@ -21,12 +21,12 @@ def get_args():
    parser.add_argument('--task', type=str, default='Pendulum-v0')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=1000)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--repeat-per-collect', type=int, default=1)
    parser.add_argument('--repeat-per-collect', type=int, default=2)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--layer-num', type=int, default=1)
    parser.add_argument('--training-num', type=int, default=16)

@@ -54,12 +54,10 @@ def _test_ppo(args=get_args()):
    args.max_action = env.action_space.high[0]
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

@@ -85,7 +83,8 @@ def _test_ppo(args=get_args()):
        action_range=[env.action_space.low[0], env.action_space.high[0]])
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
        policy, train_envs, ReplayBuffer(args.buffer_size),
        remove_done_flag=True)
    test_collector = Collector(policy, test_envs)
    train_collector.collect(n_step=args.step_per_epoch)
    # log
@@ -50,12 +50,10 @@ def test_sac(args=get_args()):
    args.max_action = env.action_space.high[0]
    # train_envs = gym.make(args.task)
    train_envs = VectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

@@ -79,12 +77,12 @@ def test_sac(args=get_args()):
        actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
        args.tau, args.gamma, args.alpha,
        [env.action_space.low[0], env.action_space.high[0]],
        reward_normalization=True)
        reward_normalization=True, ignore_done=True)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size), 1)
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs)
    train_collector.collect(n_step=args.buffer_size)
    # train_collector.collect(n_step=args.buffer_size)
    # log
    writer = SummaryWriter(args.logdir)

@@ -96,7 +94,6 @@ def test_sac(args=get_args()):
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, stop_fn=stop_fn, writer=writer)
    if args.task == 'Pendulum-v0':
        assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
@@ -53,12 +53,10 @@ def test_td3(args=get_args()):
    args.max_action = env.action_space.high[0]
    # train_envs = gym.make(args.task)
    train_envs = VectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

@@ -83,12 +81,12 @@ def test_td3(args=get_args()):
        args.tau, args.gamma, args.exploration_noise, args.policy_noise,
        args.update_actor_freq, args.noise_clip,
        [env.action_space.low[0], env.action_space.high[0]],
        reward_normalization=True)
        reward_normalization=True, ignore_done=True)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size), 1)
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs)
    train_collector.collect(n_step=args.buffer_size)
    # train_collector.collect(n_step=args.buffer_size)
    # log
    writer = SummaryWriter(args.logdir)

@@ -100,7 +98,6 @@ def test_td3(args=get_args()):
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, stop_fn=stop_fn, writer=writer)
    if args.task == 'Pendulum-v0':
        assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
@@ -18,6 +18,7 @@ class Net(nn.Module):
        self.model = nn.Sequential(*self.model)

    def forward(self, s, state=None, info={}):
        if not isinstance(s, torch.Tensor):
            s = torch.tensor(s, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        s = s.view(batch, -1)

@@ -24,7 +24,7 @@ def get_args():
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=320)
    parser.add_argument('--step-per-epoch', type=int, default=1000)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--repeat-per-collect', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=64)

@@ -49,12 +49,10 @@ def test_a2c(args=get_args()):
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
@@ -23,11 +23,12 @@ def get_args():
    parser.add_argument('--eps-test', type=float, default=0.05)
    parser.add_argument('--eps-train', type=float, default=0.1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--n-step', type=int, default=1)
    parser.add_argument('--target-update-freq', type=int, default=320)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=320)
    parser.add_argument('--step-per-epoch', type=int, default=1000)
    parser.add_argument('--collect-per-step', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--layer-num', type=int, default=3)

@@ -47,12 +48,10 @@ def test_dqn(args=get_args()):
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

@@ -62,12 +61,17 @@ def test_dqn(args=get_args()):
    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
    net = net.to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = DQNPolicy(net, optim, args.gamma, args.n_step)
    policy = DQNPolicy(
        net, optim, args.gamma, args.n_step,
        use_target_network=args.target_update_freq > 0,
        target_update_freq=args.target_update_freq)
    # collector
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs)
    train_collector.collect(n_step=args.buffer_size)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size)
    print(len(train_collector.buffer))
    # log
    writer = SummaryWriter(args.logdir)

@@ -75,7 +79,6 @@ def test_dqn(args=get_args()):
        return x >= env.spec.reward_threshold

    def train_fn(x):
        policy.sync_weight()
        policy.set_eps(args.eps_train)

    def test_fn(x):
@@ -99,12 +99,10 @@ def test_pg(args=get_args()):
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

@@ -19,9 +19,9 @@ else: # pytest
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=3e-3)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=1000)

@@ -50,12 +50,10 @@ def test_ppo(args=get_args()):
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
@@ -11,13 +11,16 @@ from tianshou.utils import MovAvg
class Collector(object):
    """docstring for Collector"""

    def __init__(self, policy, env, buffer=ReplayBuffer(20000), stat_size=100):
    def __init__(self, policy, env, buffer=None, stat_size=100):
        super().__init__()
        self.env = env
        self.env_num = 1
        self.collect_step = 0
        self.collect_episode = 0
        self.collect_time = 0
        if buffer is None:
            self.buffer = ReplayBuffer(20000)
        else:
            self.buffer = buffer
        self.policy = policy
        self.process_fn = policy.process_fn
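The signature change above also removes a classic Python pitfall: a default of buffer=ReplayBuffer(20000) is evaluated once at import time, so every Collector created without an explicit buffer would silently share the same replay buffer. The buffer=None idiom builds a fresh buffer per instance. A generic illustration of the pitfall (names here are illustrative only):

class Bad:
    def __init__(self, items=[]):   # default list is created once and shared by all instances
        self.items = items

a, b = Bad(), Bad()
a.items.append(1)
print(b.items)                      # [1] -- b sees a's data

class Good:
    def __init__(self, items=None):
        self.items = items if items is not None else []   # fresh list per instance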
@@ -34,7 +37,7 @@ class Collector(object):
            self._multi_buf = True
        elif isinstance(self.buffer, ReplayBuffer):
            self._cached_buf = [
                deepcopy(buffer) for _ in range(self.env_num)]
                deepcopy(self.buffer) for _ in range(self.env_num)]
        else:
            raise TypeError('The buffer in data collector is invalid!')
        self.reset_env()

@@ -64,11 +67,11 @@ class Collector(object):

    def seed(self, seed=None):
        if hasattr(self.env, 'seed'):
            self.env.seed(seed)
            return self.env.seed(seed)

    def render(self, **kwargs):
        if hasattr(self.env, 'render'):
            self.env.render(**kwargs)
            return self.env.render(**kwargs)

    def close(self):
        if hasattr(self.env, 'close'):

@@ -78,11 +81,13 @@ class Collector(object):
        if isinstance(data, np.ndarray):
            return data[None]
        else:
            return [data]
            return np.array([data])

    def collect(self, n_step=0, n_episode=0, render=0):
        if not self._multi_env:
            n_episode = np.sum(n_episode)
        start_time = time.time()
        assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
        assert sum([(n_step != 0), (n_episode != 0)]) == 1,\
            "One and only one collection number specification permitted!"
        cur_step = 0
        cur_episode = np.zeros(self.env_num) if self._multi_env else 0

@@ -105,8 +110,10 @@ class Collector(object):
            self.state = result.state if hasattr(result, 'state') else None
            if isinstance(result.act, torch.Tensor):
                self._act = result.act.detach().cpu().numpy()
            else:
            elif not isinstance(self._act, np.ndarray):
                self._act = np.array(result.act)
            else:
                self._act = result.act
            obs_next, self._rew, self._done, self._info = self.env.step(
                self._act if self._multi_env else self._act[0])
            if render > 0:

@@ -116,9 +123,6 @@ class Collector(object):
            self.reward += self._rew
            if self._multi_env:
                for i in range(self.env_num):
                    if not self.env.is_reset_after_done()\
                            and cur_episode[i] > 0:
                        continue
                    data = {
                        'obs': self._obs[i], 'act': self._act[i],
                        'rew': self._rew[i], 'done': self._done[i],

@@ -132,13 +136,16 @@ class Collector(object):
                        self.buffer.add(**data)
                        cur_step += 1
                    if self._done[i]:
                        if n_step != 0 or np.isscalar(n_episode) or \
                                cur_episode[i] < n_episode[i]:
                            cur_episode[i] += 1
                            reward_sum += self.reward[i]
                            length_sum += self.length[i]
                            if self._cached_buf:
                                cur_step += len(self._cached_buf[i])
                                self.buffer.update(self._cached_buf[i])
                        self.reward[i], self.length[i] = 0, 0
                        if self._cached_buf:
                            self.buffer.update(self._cached_buf[i])
                            cur_step += len(self._cached_buf[i])
                            self._cached_buf[i].reset()
                        if isinstance(self.state, list):
                            self.state[i] = None

@@ -150,7 +157,13 @@ class Collector(object):
                        if isinstance(self.state, torch.Tensor):
                            # remove ref count in pytorch (?)
                            self.state = self.state.detach()
                if n_episode > 0 and cur_episode.sum() >= n_episode:
                if sum(self._done):
                    obs_next = self.env.reset(np.where(self._done)[0])
                if n_episode != 0:
                    if isinstance(n_episode, list) and \
                            (cur_episode >= np.array(n_episode)).all() or \
                            np.isscalar(n_episode) and \
                            cur_episode.sum() >= n_episode:
                        break
            else:
                self.buffer.add(

@@ -163,10 +176,10 @@ class Collector(object):
                    length_sum += self.length
                    self.reward, self.length = 0, 0
                    self.state = None
                    self._obs = self.env.reset()
                    if n_episode > 0 and cur_episode >= n_episode:
                    obs_next = self.env.reset()
                    if n_episode != 0 and cur_episode >= n_episode:
                        break
            if n_step > 0 and cur_step >= n_step:
            if n_step != 0 and cur_step >= n_step:
                break
            self._obs = obs_next
        self._obs = obs_next

@@ -178,13 +191,17 @@ class Collector(object):
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        self.collect_time += duration
        if isinstance(n_episode, list):
            n_episode = np.sum(n_episode)
        else:
            n_episode = max(cur_episode, 1)
        return {
            'n/ep': cur_episode,
            'n/st': cur_step,
            'v/st': self.step_speed.get(),
            'v/ep': self.episode_speed.get(),
            'rew': reward_sum / cur_episode,
            'len': length_sum / cur_episode,
            'rew': reward_sum / n_episode,
            'len': length_sum / n_episode,
        }

    def sample(self, batch_size):
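Taken together, these changes give the calling pattern exercised by test_collector; a short usage sketch built on the MyTestEnv and MyPolicy assumed earlier (the comments describe behaviour implied by this diff, not documented guarantees):

policy = MyPolicy()
venv = SubprocVectorEnv([lambda: MyTestEnv(size=s) for s in [2, 3, 4, 5]])
c = Collector(policy, venv)                 # buffer defaults to a fresh ReplayBuffer(20000)

stats = c.collect(n_episode=[1, 2, 2, 2])   # per-env episode quotas
print(stats['n/ep'], stats['n/st'])         # episodes finished per env, total steps stored
print(stats['rew'], stats['len'])           # now averaged over np.sum([1, 2, 2, 2]) episodes
batch = c.sample(64)                        # draw a training batch from the shared buffer
c.close()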
4  tianshou/env/common.py  vendored
@@ -14,11 +14,11 @@ class EnvWrapper(object):

    def seed(self, seed=None):
        if hasattr(self.env, 'seed'):
            self.env.seed(seed)
            return self.env.seed(seed)

    def render(self, **kwargs):
        if hasattr(self.env, 'render'):
            self.env.render(**kwargs)
            return self.env.render(**kwargs)

    def close(self):
        self.env.close()
141  tianshou/env/vecenv.py  vendored
@@ -10,14 +10,9 @@ from tianshou.env import EnvWrapper, CloudpickleWrapper


class BaseVectorEnv(ABC):
    def __init__(self, env_fns, reset_after_done):
    def __init__(self, env_fns):
        self._env_fns = env_fns
        self.env_num = len(env_fns)
        self._reset_after_done = reset_after_done
        self._done = np.zeros(self.env_num)

    def is_reset_after_done(self):
        return self._reset_after_done

    def __len__(self):
        return self.env_num
@@ -46,67 +41,62 @@
class VectorEnv(BaseVectorEnv):
    """docstring for VectorEnv"""

    def __init__(self, env_fns, reset_after_done=False):
        super().__init__(env_fns, reset_after_done)
    def __init__(self, env_fns):
        super().__init__(env_fns)
        self.envs = [_() for _ in env_fns]

    def reset(self):
        self._done = np.zeros(self.env_num)
    def reset(self, id=None):
        if id is None:
            self._obs = np.stack([e.reset() for e in self.envs])
        else:
            if np.isscalar(id):
                id = [id]
            for i in id:
                self._obs[i] = self.envs[i].reset()
        return self._obs

    def step(self, action):
        assert len(action) == self.env_num
        result = []
        for i, e in enumerate(self.envs):
            if not self.is_reset_after_done() and self._done[i]:
                result.append([
                    self._obs[i], self._rew[i], self._done[i], self._info[i]])
            else:
                result.append(e.step(action[i]))
        result = [e.step(a) for e, a in zip(self.envs, action)]
        self._obs, self._rew, self._done, self._info = zip(*result)
        if self.is_reset_after_done() and sum(self._done):
            self._obs = np.stack(self._obs)
            for i in np.where(self._done)[0]:
                self._obs[i] = self.envs[i].reset()
        return np.stack(self._obs), np.stack(self._rew),\
            np.stack(self._done), np.stack(self._info)
        self._rew = np.stack(self._rew)
        self._done = np.stack(self._done)
        self._info = np.stack(self._info)
        return self._obs, self._rew, self._done, self._info

    def seed(self, seed=None):
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        result = []
        for e, s in zip(self.envs, seed):
            if hasattr(e, 'seed'):
                e.seed(s)
                result.append(e.seed(s))
        return result

    def render(self, **kwargs):
        result = []
        for e in self.envs:
            if hasattr(e, 'render'):
                e.render(**kwargs)
                result.append(e.render(**kwargs))
        return result

    def close(self):
        for e in self.envs:
            e.close()


def worker(parent, p, env_fn_wrapper, reset_after_done):
def worker(parent, p, env_fn_wrapper):
    parent.close()
    env = env_fn_wrapper.data()
    done = False
    try:
        while True:
            cmd, data = p.recv()
            if cmd == 'step':
                if reset_after_done or not done:
                    obs, rew, done, info = env.step(data)
                if reset_after_done and done:
                    # s_ is useless when episode finishes
                    obs = env.reset()
                p.send([obs, rew, done, info])
                p.send(env.step(data))
            elif cmd == 'reset':
                done = False
                p.send(env.reset())
            elif cmd == 'close':
                p.close()
@@ -125,15 +115,14 @@ def worker(parent, p, env_fn_wrapper, reset_after_done):
class SubprocVectorEnv(BaseVectorEnv):
    """docstring for SubProcVectorEnv"""

    def __init__(self, env_fns, reset_after_done=False):
        super().__init__(env_fns, reset_after_done)
    def __init__(self, env_fns):
        super().__init__(env_fns)
        self.closed = False
        self.parent_remote, self.child_remote = \
            zip(*[Pipe() for _ in range(self.env_num)])
        self.processes = [
            Process(target=worker, args=(
                parent, child,
                CloudpickleWrapper(env_fn), reset_after_done), daemon=True)
                parent, child, CloudpickleWrapper(env_fn)), daemon=True)
            for (parent, child, env_fn) in zip(
                self.parent_remote, self.child_remote, env_fns)
        ]

@@ -147,13 +136,27 @@ class SubprocVectorEnv(BaseVectorEnv):
        for p, a in zip(self.parent_remote, action):
            p.send(['step', a])
        result = [p.recv() for p in self.parent_remote]
        obs, rew, done, info = zip(*result)
        return np.stack(obs), np.stack(rew), np.stack(done), np.stack(info)
        self._obs, self._rew, self._done, self._info = zip(*result)
        self._obs = np.stack(self._obs)
        self._rew = np.stack(self._rew)
        self._done = np.stack(self._done)
        self._info = np.stack(self._info)
        return self._obs, self._rew, self._done, self._info

    def reset(self):
    def reset(self, id=None):
        if id is None:
            for p in self.parent_remote:
                p.send(['reset', None])
            return np.stack([p.recv() for p in self.parent_remote])
            self._obs = np.stack([p.recv() for p in self.parent_remote])
            return self._obs
        else:
            if np.isscalar(id):
                id = [id]
            for i in id:
                self.parent_remote[i].send(['reset', None])
            for i in id:
                self._obs[i] = self.parent_remote[i].recv()
            return self._obs

    def seed(self, seed=None):
        if np.isscalar(seed):
@@ -162,14 +165,12 @@ class SubprocVectorEnv(BaseVectorEnv):
            seed = [seed] * self.env_num
        for p, s in zip(self.parent_remote, seed):
            p.send(['seed', s])
        for p in self.parent_remote:
            p.recv()
        return [p.recv() for p in self.parent_remote]

    def render(self, **kwargs):
        for p in self.parent_remote:
            p.send(['render', kwargs])
        for p in self.parent_remote:
            p.recv()
        return [p.recv() for p in self.parent_remote]

    def close(self):
        if self.closed:
@@ -184,8 +185,8 @@ class SubprocVectorEnv(BaseVectorEnv):
class RayVectorEnv(BaseVectorEnv):
    """docstring for RayVectorEnv"""

    def __init__(self, env_fns, reset_after_done=False):
        super().__init__(env_fns, reset_after_done)
    def __init__(self, env_fns):
        super().__init__(env_fns)
        try:
            if not ray.is_initialized():
                ray.init()

@@ -198,35 +199,27 @@ class RayVectorEnv(BaseVectorEnv):

    def step(self, action):
        assert len(action) == self.env_num
        result_obj = []
        for i, e in enumerate(self.envs):
            if not self.is_reset_after_done() and self._done[i]:
                result_obj.append(None)
            else:
                result_obj.append(e.step.remote(action[i]))
        result = []
        for i, r in enumerate(result_obj):
            if r is None:
                result.append([
                    self._obs[i], self._rew[i], self._done[i], self._info[i]])
            else:
                result.append(ray.get(r))
        result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]
        result = [ray.get(r) for r in result_obj]
        self._obs, self._rew, self._done, self._info = zip(*result)
        if self.is_reset_after_done() and sum(self._done):
            self._obs = np.stack(self._obs)
            index = np.where(self._done)[0]
            result_obj = []
            for i in range(len(index)):
                result_obj.append(self.envs[index[i]].reset.remote())
            for i in range(len(index)):
                self._obs[index[i]] = ray.get(result_obj[i])
        return np.stack(self._obs), np.stack(self._rew),\
            np.stack(self._done), np.stack(self._info)
        self._rew = np.stack(self._rew)
        self._done = np.stack(self._done)
        self._info = np.stack(self._info)
        return self._obs, self._rew, self._done, self._info

    def reset(self):
        self._done = np.zeros(self.env_num)
    def reset(self, id=None):
        if id is None:
            result_obj = [e.reset.remote() for e in self.envs]
            self._obs = np.stack([ray.get(r) for r in result_obj])
        else:
            result_obj = []
            if np.isscalar(id):
                id = [id]
            for i in id:
                result_obj.append(self.envs[i].reset.remote())
            for _, i in enumerate(id):
                self._obs[i] = ray.get(result_obj[_])
        return self._obs

    def seed(self, seed=None):

@@ -237,15 +230,13 @@ class RayVectorEnv(BaseVectorEnv):
        elif seed is None:
            seed = [seed] * self.env_num
        result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
        for r in result_obj:
            ray.get(r)
        return [ray.get(r) for r in result_obj]

    def render(self, **kwargs):
        if not hasattr(self.envs[0], 'render'):
            return
        result_obj = [e.render.remote(**kwargs) for e in self.envs]
        for r in result_obj:
            ray.get(r)
        return [ray.get(r) for r in result_obj]

    def close(self):
        result_obj = [e.close.remote() for e in self.envs]
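With reset_after_done removed, none of the vector env classes resets a finished sub-environment on its own any more; the caller passes the finished indices to reset(id=...) instead, exactly as the updated test_vecenv does. A minimal sketch using the MyTestEnv assumed earlier:

import numpy as np

venv = SubprocVectorEnv([lambda: MyTestEnv(size=3) for _ in range(4)])
obs = venv.reset()
for _ in range(10):
    obs, rew, done, info = venv.step(np.ones(len(venv), dtype=int))
    if sum(done):
        # finished sub-envs stay done until exactly those indices are reset
        obs = venv.reset(np.where(done)[0])
venv.close()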
@@ -13,7 +13,8 @@ class DDPGPolicy(BasePolicy):

    def __init__(self, actor, actor_optim, critic, critic_optim,
                 tau=0.005, gamma=0.99, exploration_noise=0.1,
                 action_range=None, reward_normalization=True):
                 action_range=None, reward_normalization=False,
                 ignore_done=False):
        super().__init__()
        if actor is not None:
            self.actor, self.actor_old = actor, deepcopy(actor)

@@ -35,6 +36,7 @@ class DDPGPolicy(BasePolicy):
        self._action_scale = (action_range[1] - action_range[0]) / 2
        # it is only a little difference to use rand_normal
        # self.noise = OUNoise()
        self._rm_done = ignore_done
        self._rew_norm = reward_normalization
        self.__eps = np.finfo(np.float32).eps.item()

@@ -60,8 +62,12 @@ class DDPGPolicy(BasePolicy):

    def process_fn(self, batch, buffer, indice):
        if self._rew_norm:
            self._rew_mean = buffer.rew.mean()
            self._rew_std = buffer.rew.std()
            bfr = buffer.rew[:len(buffer)]
            mean, std = bfr.mean(), bfr.std()
            if std > self.__eps:
                batch.rew = (batch.rew - mean) / std
        if self._rm_done:
            batch.done = batch.done * 0.
        return batch

    def __call__(self, batch, state=None,
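The reworked process_fn now applies both new switches to the sampled batch before the Bellman target is built: reward_normalization standardises rewards by the statistics of the filled part of the buffer (skipping the division when the std is near zero), and ignore_done zeroes the done flags so bootstrapping is never cut off at time-limit terminations such as Pendulum-v0's. A small numpy sketch of the intended effect (illustrative values; the policy itself normalises with whole-buffer statistics rather than the batch used here):

import numpy as np

gamma, eps = 0.99, np.finfo(np.float32).eps.item()
rew = np.array([1.0, -2.0, 0.5, 3.0])            # rewards of a sampled batch
done = np.array([0.0, 1.0, 0.0, 1.0])            # done flags of the batch
target_q_next = np.array([0.2, 0.4, 0.1, 0.3])   # Q'(s', a') from the target networks

std = rew.std()                                   # reward_normalization=True
rew = (rew - rew.mean()) / std if std > eps else rew

done = done * 0.                                  # ignore_done=True: never stop bootstrapping

target_q = rew + (1. - done) * gamma * target_q_next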
@@ -72,10 +78,10 @@ class DDPGPolicy(BasePolicy):
        logits += self._action_bias
        if eps is None:
            eps = self._eps
        # noise = np.random.normal(0, eps, size=logits.shape)
        # noise = self.noise(logits.shape, eps)
        # logits += torch.tensor(noise, device=logits.device)
        if eps > 0:
            # noise = np.random.normal(0, eps, size=logits.shape)
            # logits += torch.tensor(noise, device=logits.device)
            # noise = self.noise(logits.shape, eps)
            logits += torch.randn(
                size=logits.shape, device=logits.device) * eps
        logits = logits.clamp(self._range[0], self._range[1])

@@ -86,10 +92,8 @@ class DDPGPolicy(BasePolicy):
            batch, model='actor_old', input='obs_next', eps=0).act)
        dev = target_q.device
        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
        if self._rew_norm:
            rew = (rew - self._rew_mean) / (self._rew_std + self.__eps)
        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
        target_q = rew + ((1. - done) * self._gamma * target_q).detach()
        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
        current_q = self.critic(batch.obs, batch.act)
        critic_loss = F.mse_loss(current_q, target_q)
        self.critic_optim.zero_grad()
@@ -10,10 +10,9 @@ from tianshou.policy import BasePolicy
class DQNPolicy(BasePolicy):
    """docstring for DQNPolicy"""

    def __init__(self, model, optim,
                 discount_factor=0.99,
                 estimation_step=1,
                 use_target_network=True):
    def __init__(self, model, optim, discount_factor=0.99,
                 estimation_step=1, use_target_network=True,
                 target_update_freq=300):
        super().__init__()
        self.model = model
        self.optim = optim

@@ -23,6 +22,8 @@ class DQNPolicy(BasePolicy):
        assert estimation_step > 0, 'estimation_step should greater than 0'
        self._n_step = estimation_step
        self._target = use_target_network
        self._freq = target_update_freq
        self._cnt = 0
        if use_target_network:
            self.model_old = deepcopy(self.model)
            self.model_old.eval()

@@ -39,7 +40,6 @@ class DQNPolicy(BasePolicy):
        self.model.eval()

    def sync_weight(self):
        if self._target:
            self.model_old.load_state_dict(self.model.state_dict())

    def process_fn(self, batch, buffer, indice):

@@ -84,6 +84,8 @@ class DQNPolicy(BasePolicy):
        return Batch(logits=q, act=act, state=h)

    def learn(self, batch, batch_size=None, repeat=1):
        if self._target and self._cnt % self._freq == 0:
            self.sync_weight()
        self.optim.zero_grad()
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]

@@ -93,4 +95,5 @@ class DQNPolicy(BasePolicy):
        loss = F.mse_loss(q, r)
        loss.backward()
        self.optim.step()
        self._cnt += 1
        return {'loss': loss.detach().cpu().numpy()}
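learn() above now synchronises the target network inside the policy every target_update_freq gradient steps, instead of relying on the trainer's train_fn calling sync_weight(). A generic sketch of that counter-based hard-update pattern (toy modules, not the library API):

import torch.nn as nn
from copy import deepcopy

net = nn.Linear(4, 2)        # stand-in for the online Q-network
net_old = deepcopy(net)      # frozen target network
freq = 300                   # sync period, analogous to target_update_freq

for cnt in range(1000):      # one iteration corresponds to one learn() call
    if cnt % freq == 0:
        net_old.load_state_dict(net.state_dict())   # hard target update
    # ... compute TD targets with net_old and take an optimizer step on net ...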
@@ -12,9 +12,10 @@ class SACPolicy(DDPGPolicy):

    def __init__(self, actor, actor_optim, critic1, critic1_optim,
                 critic2, critic2_optim, tau=0.005, gamma=0.99,
                 alpha=0.2, action_range=None, reward_normalization=True):
                 alpha=0.2, action_range=None, reward_normalization=False,
                 ignore_done=False):
        super().__init__(None, None, None, None, tau, gamma, 0,
                         action_range, reward_normalization)
                         action_range, reward_normalization, ignore_done)
        self.actor, self.actor_optim = actor, actor_optim
        self.critic1, self.critic1_old = critic1, deepcopy(critic1)
        self.critic1_old.eval()

@@ -70,10 +71,8 @@ class SACPolicy(DDPGPolicy):
            self.critic2_old(batch.obs_next, a_),
        ) - self._alpha * obs_next_result.log_prob
        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
        if self._rew_norm:
            rew = (rew - self._rew_mean) / (self._rew_std + self.__eps)
        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
        target_q = rew + ((1. - done) * self._gamma * target_q).detach()
        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
        obs_result = self(batch)
        a = obs_result.act
        current_q1, current_q1a = self.critic1(
@@ -1,5 +1,4 @@
import torch
import numpy as np
from copy import deepcopy
import torch.nn.functional as F

@@ -12,10 +11,11 @@ class TD3Policy(DDPGPolicy):
    def __init__(self, actor, actor_optim, critic1, critic1_optim,
                 critic2, critic2_optim, tau=0.005, gamma=0.99,
                 exploration_noise=0.1, policy_noise=0.2, update_actor_freq=2,
                 noise_clip=0.5, action_range=None, reward_normalization=True):
        super().__init__(actor, actor_optim, None, None,
                         tau, gamma, exploration_noise, action_range,
                         reward_normalization)
                 noise_clip=0.5, action_range=None,
                 reward_normalization=False, ignore_done=False):
        super().__init__(actor, actor_optim, None, None, tau, gamma,
                         exploration_noise, action_range, reward_normalization,
                         ignore_done)
        self.critic1, self.critic1_old = critic1, deepcopy(critic1)
        self.critic1_old.eval()
        self.critic1_optim = critic1_optim

@@ -27,7 +27,6 @@ class TD3Policy(DDPGPolicy):
        self._noise_clip = noise_clip
        self._cnt = 0
        self._last = 0
        self.__eps = np.finfo(np.float32).eps.item()

    def train(self):
        self.training = True

@@ -63,10 +62,8 @@ class TD3Policy(DDPGPolicy):
            self.critic1_old(batch.obs_next, a_),
            self.critic2_old(batch.obs_next, a_))
        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)[:, None]
        if self._rew_norm:
            rew = (rew - self._rew_mean) / (self._rew_std + self.__eps)
        done = torch.tensor(batch.done, dtype=torch.float, device=dev)[:, None]
        target_q = rew + ((1. - done) * self._gamma * target_q).detach()
        target_q = (rew + (1. - done) * self._gamma * target_q).detach()
        # critic 1
        current_q1 = self.critic1(batch.obs, batch.act)
        critic1_loss = F.mse_loss(current_q1, target_q)
@@ -5,15 +5,6 @@ from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import test_episode, gather_info


def test(policy, collector, test_fn, epoch, n_episode):
    collector.reset_env()
    collector.reset_buffer()
    policy.eval()
    if test_fn:
        test_fn(epoch)
    return collector.collect(n_episode=n_episode)


def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                      step_per_epoch, collect_per_step, episode_per_test,
                      batch_size, train_fn=None, test_fn=None, stop_fn=None,

@@ -49,8 +40,7 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
            if train_fn:
                train_fn(epoch)
            for i in range(min(
                    result['n/st'] // collect_per_step,
                    t.total - t.n)):
                    result['n/st'] // collect_per_step, t.total - t.n)):
                global_step += 1
                losses = policy.learn(train_collector.sample(batch_size))
                for k in result.keys():

@@ -7,7 +7,7 @@ def test_episode(policy, collector, test_fn, epoch, n_episode):
    policy.eval()
    if test_fn:
        test_fn(epoch)
    return collector.collect(n_episode=n_episode)
    return collector.collect(n_episode=[1] * n_episode)


def gather_info(start_time, train_c, test_c, best_reward):