change the step in trainer (#235)
This PR separates `global_step` into `env_step` and `gradient_step`: from now on, data from the collecting stage is logged under `env_step`, and data from the updating stage is logged under `gradient_step`. Other changes:
- add `rew_std` and `best_result` to the monitor
- make the actor network unbounded (`unbounded=True`) in `test/continuous/test_sac_with_il.py` and `examples/box2d/bipedal_hardcore_sac.py`
- bump the ray dependency to `ray>=1.0.0`, since ray-project/ray#10134 has been resolved
parent 710966eda7
commit 5ed6c1c7aa
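To make the new bookkeeping concrete, here is a condensed, illustrative sketch of how the two counters are intended to advance. It mirrors the off-policy trainer diff below but is not the actual code; `policy`, `train_collector`, `writer`, and the loop bounds are assumed to be set up as in a typical Tianshou training script:

    env_step, gradient_step = 0, 0
    for epoch in range(1, max_epoch + 1):
        # collecting stage: advance env_step by the number of environment frames
        result = train_collector.collect(n_step=collect_per_step)
        env_step += int(result["n/st"])
        writer.add_scalar("train/rew", result["rew"], global_step=env_step)
        # updating stage: advance gradient_step once per policy update
        for _ in range(update_per_step):
            gradient_step += 1
            losses = policy.update(batch_size, train_collector.buffer)
            for k, v in losses.items():
                writer.add_scalar(k, v, global_step=gradient_step)

Collect-time statistics are thus plotted against environment frames, while loss curves are plotted against the number of gradient updates.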
@@ -3,7 +3,7 @@ Cheat Sheet

 This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios.

-By the way, some of these issues can be resolved by using a ``gym.wrapper``. It could be a universal solution in the policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn`.
+By the way, some of these issues can be resolved by using a ``gym.Wrapper``. It could be a universal solution in the policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn`.

 .. _network_api:

@@ -163,7 +163,7 @@ First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`:

 Then, change the network to recurrent-style, for example, :class:`~tianshou.utils.net.common.Recurrent`, :class:`~tianshou.utils.net.continuous.RecurrentActorProb` and :class:`~tianshou.utils.net.continuous.RecurrentCritic`.

-The above code supports only stacked-observation. If you want to use stacked-action (for Q(stacked-s, stacked-a)), stacked-reward, or other stacked variables, you can add a ``gym.wrapper`` to modify the state representation. For example, if we add a wrapper that map [s, a] pair to a new state:
+The above code supports only stacked-observation. If you want to use stacked-action (for Q(stacked-s, stacked-a)), stacked-reward, or other stacked variables, you can add a ``gym.Wrapper`` to modify the state representation. For example, if we add a wrapper that map [s, a] pair to a new state:

 - Before: (s, a, s', r, d) stored in replay buffer, and get stacked s;
 - After applying wrapper: ([s, a], a, [s', a'], r, d) stored in replay buffer, and get both stacked s and a.
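As an aside, a minimal sketch of the kind of gym.Wrapper the cheat sheet refers to — one that maps the [s, a] pair to a new state — might look like the following. This is a hypothetical example (the class name, the zero-action padding on reset, and the assumption of a continuous Box action space are mine), not code from this commit:

    import gym
    import numpy as np

    class StateActionWrapper(gym.Wrapper):
        """Replace each observation s with the pair [s, a_prev]."""

        def reset(self, **kwargs):
            obs = self.env.reset(**kwargs)
            # no action has been taken yet, so pad with zeros
            zero_act = np.zeros(self.env.action_space.shape, dtype=np.float32)
            return np.concatenate(
                [np.asarray(obs, dtype=np.float32).ravel(), zero_act])

        def step(self, action):
            obs, rew, done, info = self.env.step(action)
            new_obs = np.concatenate(
                [np.asarray(obs, dtype=np.float32).ravel(),
                 np.asarray(action, dtype=np.float32).ravel()])
            return new_obs, rew, done, info

With such a wrapper, the tuple stored in the replay buffer becomes ([s, a], a, [s', a'], r, d), so frame stacking yields stacked observations and stacked actions at once.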
@@ -93,8 +93,7 @@ def test_sac_bipedal(args=get_args()):
     # model
     net_a = Net(args.layer_num, args.state_shape, device=args.device)
     actor = ActorProb(
-        net_a, args.action_shape,
-        args.max_action, args.device
+        net_a, args.action_shape, args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
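A note on the unbounded=True change above, hedged since it is my reading of the Tianshou source rather than something stated in this PR: by default ActorProb bounds its output mean with tanh scaled by max_action, while SACPolicy applies its own tanh squashing (with the matching log-prob correction), so the action would effectively be squashed twice. Passing unbounded=True leaves the network's mean raw and lets the policy do the single squash. The argument order below is taken from the diff itself:

    # default construction: the network itself bounds the mean (tanh * max_action)
    actor = ActorProb(net_a, args.action_shape, args.max_action, args.device)

    # construction used in this PR for SAC: raw, unbounded mean;
    # squashing is left entirely to SACPolicy
    actor = ActorProb(net_a, args.action_shape, args.max_action, args.device,
                      unbounded=True)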
setup.py
@@ -60,7 +60,7 @@ setup(
             "flake8",
             "pytest",
             "pytest-cov",
-            "ray>=0.8.0",
+            "ray>=1.0.0",
             "mypy",
             "pydocstyle",
             "doc8",
@@ -111,13 +111,12 @@ def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
             t = time.time() - t
             ids = Batch(info).env_id
             print(ids, t)
-            if cls != RayVectorEnv:  # ray-project/ray#10134
-                if not (len(ids) == len(res) and np.allclose(sorted(ids), res)
-                        and (t < timeout) == (len(res) == num - 1)):
-                    pass_check = 0
-                    break
+            if not (len(ids) == len(res) and np.allclose(sorted(ids), res)
+                    and (t < timeout) == (len(res) == num - 1)):
+                pass_check = 0
+                break
         total_pass += pass_check
-    assert total_pass >= 1  # should be modified when ray>=0.9.0 release
+    assert total_pass >= 2


 def test_vecenv(size=10, num=8, sleep=0.001):
@@ -67,8 +67,7 @@ def test_sac_with_il(args=get_args()):
     # model
     net = Net(args.layer_num, args.state_shape, device=args.device)
     actor = ActorProb(
-        net, args.action_shape,
-        args.max_action, args.device
+        net, args.action_shape, args.max_action, args.device, unbounded=True
     ).to(args.device)
     actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
     net = Net(args.layer_num, args.state_shape,
@@ -118,9 +117,10 @@ def test_sac_with_il(args=get_args()):
     policy.eval()
     if args.task == 'Pendulum-v0':
         env.spec.reward_threshold = -300  # lower the goal
-    net = Actor(Net(1, args.state_shape, device=args.device),
-                args.action_shape, args.max_action, args.device
-                ).to(args.device)
+    net = Actor(
+        Net(1, args.state_shape, device=args.device),
+        args.action_shape, args.max_action, args.device
+    ).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
     il_policy = ImitationPolicy(net, optim, mode='continuous')
     il_test_collector = Collector(
@@ -214,7 +214,7 @@ class Collector(object):
         # number of episodes, these envs will be recorded in this list, and
         # they will not be stepped.
         finished_env_ids = []
-        reward_total = 0.0
+        rewards = []
         whole_data = Batch()
         if isinstance(n_episode, list):
             assert len(n_episode) == self.get_env_num()
@@ -311,7 +311,8 @@ class Collector(object):
                     if not (isinstance(n_episode, list)
                             and episode_count[i] >= n_episode[i]):
                         episode_count[i] += 1
-                        reward_total += np.sum(self._cached_buf[i].rew, axis=0)
+                        rewards.append(self._rew_metric(
+                            np.sum(self._cached_buf[i].rew, axis=0)))
                         step_count += len(self._cached_buf[i])
                         if self.buffer is not None:
                             self.buffer.update(self._cached_buf[i])
@@ -361,16 +362,13 @@ class Collector(object):
         self.collect_step += step_count
         self.collect_episode += episode_count
         self.collect_time += duration
-        # average reward across the number of episodes
-        reward_avg = reward_total / episode_count
-        if np.asanyarray(reward_avg).size > 1:  # non-scalar reward_avg
-            reward_avg = self._rew_metric(reward_avg)  # type: ignore
         return {
             "n/ep": episode_count,
             "n/st": step_count,
             "v/st": step_count / duration,
             "v/ep": episode_count / duration,
-            "rew": reward_avg,
+            "rew": np.mean(rewards),
+            "rew_std": np.std(rewards),
             "len": step_count / episode_count,
         }
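The collector now keeps one metric-reduced return per finished episode instead of a running total, which is what makes rew_std possible. A tiny illustration with made-up episode returns:

    import numpy as np

    rewards = [200.0, 180.0, 220.0]        # per-episode returns (illustrative)
    stats = {
        "rew": np.mean(rewards),           # 200.0
        "rew_std": np.std(rewards),        # ~16.33
    }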
@@ -71,8 +71,8 @@ def offpolicy_trainer(

     :return: See :func:`~tianshou.trainer.gather_info`.
     """
-    global_step = 0
-    best_epoch, best_reward = -1, -1.0
+    env_step, gradient_step = 0, 0
+    best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
     stat: Dict[str, MovAvg] = {}
     start_time = time.time()
     train_collector.reset_stat()
@@ -86,13 +86,26 @@ def offpolicy_trainer(
         ) as t:
             while t.n < t.total:
                 if train_fn:
-                    train_fn(epoch, global_step)
+                    train_fn(epoch, env_step)
                 result = train_collector.collect(n_step=collect_per_step)
-                data = {}
+                env_step += int(result["n/st"])
+                data = {
+                    "env_step": str(env_step),
+                    "rew": f"{result['rew']:.2f}",
+                    "len": str(int(result["len"])),
+                    "n/ep": str(int(result["n/ep"])),
+                    "n/st": str(int(result["n/st"])),
+                    "v/ep": f"{result['v/ep']:.2f}",
+                    "v/st": f"{result['v/st']:.2f}",
+                }
+                if writer and env_step % log_interval == 0:
+                    for k in result.keys():
+                        writer.add_scalar(
+                            "train/" + k, result[k], global_step=env_step)
                 if test_in_train and stop_fn and stop_fn(result["rew"]):
                     test_result = test_episode(
                         policy, test_collector, test_fn,
-                        epoch, episode_per_test, writer, global_step)
+                        epoch, episode_per_test, writer, env_step)
                     if stop_fn(test_result["rew"]):
                         if save_fn:
                             save_fn(policy)
@@ -101,42 +114,38 @@ def offpolicy_trainer(
                         t.set_postfix(**data)
                         return gather_info(
                             start_time, train_collector, test_collector,
-                            test_result["rew"])
+                            test_result["rew"], test_result["rew_std"])
                     else:
                         policy.train()
                 for i in range(update_per_step * min(
                         result["n/st"] // collect_per_step, t.total - t.n)):
-                    global_step += collect_per_step
+                    gradient_step += 1
                     losses = policy.update(batch_size, train_collector.buffer)
-                    for k in result.keys():
-                        data[k] = f"{result[k]:.2f}"
-                        if writer and global_step % log_interval == 0:
-                            writer.add_scalar("train/" + k, result[k],
-                                              global_step=global_step)
                     for k in losses.keys():
                         if stat.get(k) is None:
                             stat[k] = MovAvg()
                         stat[k].add(losses[k])
                         data[k] = f"{stat[k].get():.6f}"
-                        if writer and global_step % log_interval == 0:
+                        if writer and gradient_step % log_interval == 0:
                             writer.add_scalar(
-                                k, stat[k].get(), global_step=global_step)
+                                k, stat[k].get(), global_step=gradient_step)
                     t.update(1)
                     t.set_postfix(**data)
             if t.n <= t.total:
                 t.update()
         # test
         result = test_episode(policy, test_collector, test_fn, epoch,
-                              episode_per_test, writer, global_step)
+                              episode_per_test, writer, env_step)
         if best_epoch == -1 or best_reward < result["rew"]:
-            best_reward = result["rew"]
+            best_reward, best_reward_std = result["rew"], result["rew_std"]
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
-            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, "
-                  f"best_reward: {best_reward:.6f} in #{best_epoch}")
+            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± "
+                  f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± "
+                  f"{best_reward_std:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
-    return gather_info(
-        start_time, train_collector, test_collector, best_reward)
+    return gather_info(start_time, train_collector, test_collector,
+                       best_reward, best_reward_std)
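Note that train_fn now receives the environment step count (env_step) rather than the old mixed global_step. A hedged sketch of a user-side train_fn that exploits this, assuming a DQN-style policy exposing set_eps and an illustrative 100k-step schedule:

    def train_fn(epoch, env_step):
        # decay exploration noise linearly over the first 100k environment steps
        eps = max(0.1, 1.0 - env_step / 100_000)
        policy.set_eps(eps)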
@@ -71,8 +71,8 @@ def onpolicy_trainer(

     :return: See :func:`~tianshou.trainer.gather_info`.
     """
-    global_step = 0
-    best_epoch, best_reward = -1, -1.0
+    env_step, gradient_step = 0, 0
+    best_epoch, best_reward, best_reward_std = -1, -1.0, 0.0
     stat: Dict[str, MovAvg] = {}
     start_time = time.time()
     train_collector.reset_stat()
@@ -86,13 +86,26 @@ def onpolicy_trainer(
         ) as t:
             while t.n < t.total:
                 if train_fn:
-                    train_fn(epoch, global_step)
+                    train_fn(epoch, env_step)
                 result = train_collector.collect(n_episode=collect_per_step)
-                data = {}
+                env_step += int(result["n/st"])
+                data = {
+                    "env_step": str(env_step),
+                    "rew": f"{result['rew']:.2f}",
+                    "len": str(int(result["len"])),
+                    "n/ep": str(int(result["n/ep"])),
+                    "n/st": str(int(result["n/st"])),
+                    "v/ep": f"{result['v/ep']:.2f}",
+                    "v/st": f"{result['v/st']:.2f}",
+                }
+                if writer and env_step % log_interval == 0:
+                    for k in result.keys():
+                        writer.add_scalar(
+                            "train/" + k, result[k], global_step=env_step)
                 if test_in_train and stop_fn and stop_fn(result["rew"]):
                     test_result = test_episode(
                         policy, test_collector, test_fn,
-                        epoch, episode_per_test, writer, global_step)
+                        epoch, episode_per_test, writer, env_step)
                     if stop_fn(test_result["rew"]):
                         if save_fn:
                             save_fn(policy)
@@ -101,47 +114,41 @@ def onpolicy_trainer(
                         t.set_postfix(**data)
                         return gather_info(
                             start_time, train_collector, test_collector,
-                            test_result["rew"])
+                            test_result["rew"], test_result["rew_std"])
                     else:
                         policy.train()
                 losses = policy.update(
                     0, train_collector.buffer,
                     batch_size=batch_size, repeat=repeat_per_collect)
                 train_collector.reset_buffer()
-                step = 1
-                for v in losses.values():
-                    if isinstance(v, list):
-                        step = max(step, len(v))
-                global_step += step * collect_per_step
-                for k in result.keys():
-                    data[k] = f"{result[k]:.2f}"
-                    if writer and global_step % log_interval == 0:
-                        writer.add_scalar(
-                            "train/" + k, result[k], global_step=global_step)
+                step = max([1] + [
+                    len(v) for v in losses.values() if isinstance(v, list)])
+                gradient_step += step
                 for k in losses.keys():
                     if stat.get(k) is None:
                         stat[k] = MovAvg()
                     stat[k].add(losses[k])
                     data[k] = f"{stat[k].get():.6f}"
-                    if writer and global_step % log_interval == 0:
+                    if writer and gradient_step % log_interval == 0:
                         writer.add_scalar(
-                            k, stat[k].get(), global_step=global_step)
+                            k, stat[k].get(), global_step=gradient_step)
                 t.update(step)
                 t.set_postfix(**data)
             if t.n <= t.total:
                 t.update()
         # test
         result = test_episode(policy, test_collector, test_fn, epoch,
-                              episode_per_test, writer, global_step)
+                              episode_per_test, writer, env_step)
         if best_epoch == -1 or best_reward < result["rew"]:
-            best_reward = result["rew"]
+            best_reward, best_reward_std = result["rew"], result["rew_std"]
            best_epoch = epoch
            if save_fn:
                save_fn(policy)
        if verbose:
-            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f}, "
-                  f"best_reward: {best_reward:.6f} in #{best_epoch}")
+            print(f"Epoch #{epoch}: test_reward: {result['rew']:.6f} ± "
+                  f"{result['rew_std']:.6f}, best_reward: {best_reward:.6f} ± "
+                  f"{best_reward_std:.6f} in #{best_epoch}")
        if stop_fn and stop_fn(best_reward):
            break
-    return gather_info(
-        start_time, train_collector, test_collector, best_reward)
+    return gather_info(start_time, train_collector, test_collector,
+                       best_reward, best_reward_std)
@@ -39,6 +39,7 @@ def gather_info(
     train_c: Collector,
     test_c: Collector,
     best_reward: float,
+    best_reward_std: float,
 ) -> Dict[str, Union[float, str]]:
     """A simple wrapper of gathering information from collectors.

@@ -72,5 +73,6 @@ def gather_info(
         "test_time": f"{test_c.collect_time:.2f}s",
         "test_speed": f"{test_speed:.2f} step/s",
         "best_reward": best_reward,
+        "best_result": f"{best_reward:.2f} ± {best_reward_std:.2f}",
         "duration": f"{duration:.2f}s",
     }
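For reference, a small sketch of how the extended gather_info is called and what the new field looks like; the numbers are purely illustrative:

    info = gather_info(start_time, train_collector, test_collector,
                       best_reward, best_reward_std)
    print(info["best_reward"])   # e.g. 245.67
    print(info["best_result"])   # e.g. "245.67 ± 12.34"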