add rllib result and fix pep8
This commit is contained in:
parent 77068af526
commit c42990c725

README.md
@@ -37,7 +37,7 @@ pip3 install tianshou
 
 The tutorials and API documentation are hosted on [https://tianshou.readthedocs.io](https://tianshou.readthedocs.io). It is currently under construction.
 
-The example scripts are under [test/discrete](/test/discrete) (CartPole) and [test/continuous](/test/continuous) (Pendulum).
+The example scripts are under the [test/](/test/) and [examples/](/examples/) folders.
 
 ## Why Tianshou?
 
@@ -49,22 +49,26 @@ Tianshou is a lightweight but high-speed reinforcement learning platform. For ex
 
 We select several well-known (>1k stars) reinforcement learning platforms for comparison. Here are the benchmark results across these platforms on toy scenarios:
 
-| RL Platform | [Tianshou](https://github.com/thu-ml/tianshou) | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
-| ---------------- | --- | --- | --- | --- | --- |
-| GitHub Stars | [stars](https://github.com/thu-ml/tianshou/stargazers) | [stars](https://github.com/openai/baselines/stargazers) | [stars](https://github.com/ray-project/ray/stargazers) | [stars](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [stars](https://github.com/astooke/rlpyt/stargazers) |
-| Algo \ Task | PyTorch | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
-| PG - CartPole | 9.03±4.18s | None | | None | |
-| DQN - CartPole | 20.94±11.38s | 1046.34±291.27s | | 175.55±53.81s | |
-| A2C - CartPole | 11.72±3.85s | *(~1612s) | | Runtime Error | |
-| PPO - CartPole | 35.25±16.47s | *(~1179s) | | 29.16±15.46s | |
-| DDPG - Pendulum | 46.95±24.31s | *(>1h) | | 652.83±471.28s | 172.18±62.48s |
-| TD3 - Pendulum | 48.39±7.22s | None | | 619.33±324.97s | 210.31±76.30s |
-| SAC - Pendulum | 38.92±2.09s | None | | 808.21±405.70s | 295.92±140.85s |
+| RL Platform | [Tianshou](https://github.com/thu-ml/tianshou) | [Baselines](https://github.com/openai/baselines) | [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [PyTorch DRL](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch) | [rlpyt](https://github.com/astooke/rlpyt) |
+| --------------- | --- | --- | --- | --- | --- |
+| GitHub Stars | [stars](https://github.com/thu-ml/tianshou/stargazers) | [stars](https://github.com/openai/baselines/stargazers) | [stars](https://github.com/ray-project/ray/stargazers) | [stars](https://github.com/p-christ/Deep-Reinforcement-Learning-Algorithms-with-PyTorch/stargazers) | [stars](https://github.com/astooke/rlpyt/stargazers) |
+| Algo - Task | PyTorch | TensorFlow | TF/PyTorch | PyTorch | PyTorch |
+| PG - CartPole | 9.03±4.18s | None | 15.77±6.28s | None | |
+| DQN - CartPole | 10.61±5.51s | 1046.34±291.27s | 40.16±12.79s | 175.55±53.81s | |
+| A2C - CartPole | 11.72±3.85s | *(~1612s) | 46.15±6.64s | Runtime Error | |
+| PPO - CartPole | 35.25±16.47s | *(~1179s) | 62.21±13.31s (APPO) | 29.16±15.46s | |
+| DDPG - Pendulum | 46.95±24.31s | *(>1h) | 377.99±13.79s | 652.83±471.28s | 172.18±62.48s |
+| TD3 - Pendulum | 48.39±7.22s | None | 620.83±248.43s | 619.33±324.97s | 210.31±76.30s |
+| SAC - Pendulum | 38.92±2.09s | None | 92.68±4.48s | 808.21±405.70s | 295.92±140.85s |
 
 *: Could not reach the target reward threshold within 1e6 steps in any of 10 runs. The total runtime is shown in the brackets.
 
 All of the platforms were tested with 10 different seeds; trials that failed to train were discarded. The reward threshold is 195.0 for CartPole and -250.0 for Pendulum, measured over the mean return of 100 consecutive episodes.
 
+Tianshou's and RLlib's configurations are very similar: both use multiple workers for sampling (sketched below). Indeed, both RLlib and rlpyt are excellent reinforcement learning platforms :)
+
+We will add results for Atari Pong / MuJoCo soon.
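A minimal sketch of what "multiple workers for sampling" looks like on the Tianshou side, using the `SubprocVectorEnv` wrapper that the scripts in this commit rely on; the CartPole task and the worker count of 8 are illustrative choices, not taken from the benchmark configuration:

```python
import gym

from tianshou.env import SubprocVectorEnv

# Eight CartPole instances, each running in its own subprocess; a collector
# then steps all of them in parallel while gathering transitions.
train_envs = SubprocVectorEnv(
    [lambda: gym.make('CartPole-v0') for _ in range(8)])
train_envs.seed(0)
```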
 
 ### Reproducible
 
 Tianshou has unit tests. Unlike other platforms, **the unit tests include the full agent training procedure for all of the implemented algorithms**. A test fails whenever an algorithm cannot train an agent to perform well enough within a limited number of epochs on the toy scenarios, which secures the reproducibility of our platform. A minimal sketch of the idea follows.
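The sketch below assumes a `policy` (e.g. a `DQNPolicy`) and vectorized `train_envs` / `test_envs` have already been built, as in the test scripts; the helper name `run_training_test` is hypothetical, the numeric arguments are illustrative, and the parameter meanings in the comments are our reading of the trainer call pattern visible in this commit:

```python
from tianshou.data import Collector, ReplayBuffer
from tianshou.trainer import offpolicy_trainer


def run_training_test(policy, train_envs, test_envs):
    train_collector = Collector(policy, train_envs, buffer=ReplayBuffer(20000))
    test_collector = Collector(policy, test_envs)
    train_collector.collect(n_step=64)  # warm up the replay buffer

    def stop_fn(mean_reward):
        # CartPole's reward threshold from the benchmark description above.
        return mean_reward >= 195.0

    result = offpolicy_trainer(
        policy, train_collector, test_collector,
        10,    # max_epoch
        1000,  # step_per_epoch
        10,    # collect_per_step
        100,   # episode_per_test
        64,    # batch_size
        stop_fn=stop_fn)
    # The unit test fails if training did not reach the threshold.
    assert stop_fn(result['best_reward'])
    train_collector.close()
    test_collector.close()
```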
@@ -10,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
 
-if __name__ == '__main__':
-    from continuous_net import Actor, Critic
-else:  # pytest
-    from test.continuous.net import Actor, Critic
+from continuous_net import Actor, Critic
 
 
 def get_args():

@@ -10,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
 
-if __name__ == '__main__':
-    from continuous_net import ActorProb, Critic
-else:  # pytest
-    from test.continuous.net import ActorProb, Critic
+from continuous_net import ActorProb, Critic
 
 
 def get_args():

@@ -10,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
 
-if __name__ == '__main__':
-    from continuous_net import Actor, Critic
-else:  # pytest
-    from test.continuous.net import Actor, Critic
+from continuous_net import Actor, Critic
 
 
 def get_args():

@@ -10,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
 
-if __name__ == '__main__':
-    from continuous_net import Actor, Critic
-else:  # pytest
-    from test.continuous.net import Actor, Critic
+from continuous_net import Actor, Critic
 
 
 def get_args():
@@ -1,4 +1,3 @@
-import gym
 import torch
 import pprint
 import argparse

@@ -11,10 +10,7 @@ from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env.atari import create_atari_environment
 
-if __name__ == '__main__':
-    from discrete_net import Net, Actor, Critic
-else:  # pytest
-    from test.discrete.net import Net, Actor, Critic
+from discrete_net import Net, Actor, Critic
 
 
 def get_args():

@@ -48,17 +44,20 @@ def get_args():
 
 
 def test_a2c(args=get_args()):
-    env = create_atari_environment(args.task, max_episode_steps=args.max_episode_steps)
+    env = create_atari_environment(
+        args.task, max_episode_steps=args.max_episode_steps)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.env.action_space.shape or env.env.action_space.n
     # train_envs = gym.make(args.task)
     train_envs = SubprocVectorEnv(
-        [lambda: create_atari_environment(args.task, max_episode_steps=args.max_episode_steps) for _ in
-         range(args.training_num)])
+        [lambda: create_atari_environment(
+            args.task, max_episode_steps=args.max_episode_steps)
+         for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
     test_envs = SubprocVectorEnv(
-        [lambda: create_atari_environment(args.task, max_episode_steps=args.max_episode_steps) for _ in
-         range(args.test_num)])
+        [lambda: create_atari_environment(
+            args.task, max_episode_steps=args.max_episode_steps)
+         for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)

@@ -91,7 +90,8 @@ def test_a2c(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
+        task=args.task)
     train_collector.close()
     test_collector.close()
     if __name__ == '__main__':
@@ -1,4 +1,3 @@
-import gym
 import torch
 import pprint
 import argparse

@@ -11,10 +10,7 @@ from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env.atari import create_atari_environment
 
-if __name__ == '__main__':
-    from discrete_net import DQN
-else:  # pytest
-    from test.discrete.net import DQN
+from discrete_net import DQN
 
 
 def get_args():

@@ -49,18 +45,22 @@ def test_dqn(args=get_args()):
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.env.action_space.shape or env.env.action_space.n
     # train_envs = gym.make(args.task)
-    train_envs = SubprocVectorEnv(
-        [lambda: create_atari_environment(args.task) for _ in range(args.training_num)])
+    train_envs = SubprocVectorEnv([
+        lambda: create_atari_environment(args.task)
+        for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
-    test_envs = SubprocVectorEnv(
-        [lambda: create_atari_environment(args.task) for _ in range(args.test_num)])
+    test_envs = SubprocVectorEnv([
+        lambda: create_atari_environment(
+            args.task) for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = DQN(args.state_shape[0], args.state_shape[1], args.action_shape, args.device)
+    net = DQN(
+        args.state_shape[0], args.state_shape[1],
+        args.action_shape, args.device)
     net = net.to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = DQNPolicy(
@@ -1,4 +1,3 @@
-import gym
 import torch
 import pprint
 import argparse

@@ -11,10 +10,7 @@ from tianshou.trainer import onpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env.atari import create_atari_environment
 
-if __name__ == '__main__':
-    from discrete_net import Net, Actor, Critic
-else:  # pytest
-    from test.discrete.net import Net, Actor, Critic
+from discrete_net import Net, Actor, Critic
 
 
 def get_args():

@@ -48,17 +44,18 @@ def get_args():
 
 
 def test_ppo(args=get_args()):
-    env = create_atari_environment(args.task, max_episode_steps=args.max_episode_steps)
+    env = create_atari_environment(
+        args.task, max_episode_steps=args.max_episode_steps)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space().shape or env.action_space().n
     # train_envs = gym.make(args.task)
-    train_envs = SubprocVectorEnv(
-        [lambda: create_atari_environment(args.task, max_episode_steps=args.max_episode_steps) for _ in
-         range(args.training_num)])
+    train_envs = SubprocVectorEnv([lambda: create_atari_environment(
+        args.task, max_episode_steps=args.max_episode_steps)
+        for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)
-    test_envs = SubprocVectorEnv(
-        [lambda: create_atari_environment(args.task, max_episode_steps=args.max_episode_steps) for _ in
-         range(args.test_num)])
+    test_envs = SubprocVectorEnv([lambda: create_atari_environment(
+        args.task, max_episode_steps=args.max_episode_steps)
+        for _ in range(args.test_num)])
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)

@@ -95,7 +92,8 @@ def test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
+        task=args.task)
     train_collector.close()
     test_collector.close()
     if __name__ == '__main__':
setup.py

@@ -55,7 +55,7 @@ setup(
         ],
         'atari': [
             'atari_py',
-            'cv2'
+            'cv2',
         ],
         'mujoco': [
             'mujoco_py',
@@ -97,7 +97,8 @@ def _test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
+        task=args.task)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -51,7 +51,6 @@ class Critic(nn.Module):
 
 
 class DQN(nn.Module):
-
     def __init__(self, h, w, action_shape, device='cpu'):
         super(DQN, self).__init__()
         self.device = device

@@ -73,7 +72,7 @@ class DQN(nn.Module):
 
     def forward(self, x, state=None, info={}):
         if not isinstance(x, torch.Tensor):
-            s = torch.tensor(x, device=self.device, dtype=torch.float)
+            x = torch.tensor(x, device=self.device, dtype=torch.float)
         x = F.relu(self.bn1(self.conv1(x)))
         x = F.relu(self.bn2(self.conv2(x)))
         x = F.relu(self.bn3(self.conv3(x)))
@@ -33,7 +33,6 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
-
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')

@@ -84,7 +83,8 @@ def test_a2c(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
+        task=args.task)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -25,7 +25,7 @@ def get_args():
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.9)
-    parser.add_argument('--n-step', type=int, default=1)
+    parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument('--target-update-freq', type=int, default=320)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=1000)

@@ -72,7 +72,6 @@ def test_dqn(args=get_args()):
     test_collector = Collector(policy, test_envs)
     # policy.set_eps(1)
     train_collector.collect(n_step=args.batch_size)
-    print(len(train_collector.buffer))
     # log
     writer = SummaryWriter(args.logdir + '/' + 'ppo')
@@ -131,7 +131,8 @@ def test_pg(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
+        task=args.task)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()

@@ -88,7 +88,8 @@ def test_ppo(args=get_args()):
     result = onpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
-        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer, task=args.task)
+        args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer,
+        task=args.task)
     assert stop_fn(result['best_reward'])
     train_collector.close()
     test_collector.close()
@@ -35,9 +35,10 @@ class Batch(object):
         elif isinstance(batch.__dict__[k], list):
             self.__dict__[k] += batch.__dict__[k]
         else:
-            raise TypeError(
-                'No support for append with type {} in class Batch.'
-                .format(type(batch.__dict__[k])))
+            s = 'No support for append with type ' \
+                + str(type(batch.__dict__[k])) \
+                + ' in class Batch.'
+            raise TypeError(s)
 
     def split(self, size=None, permute=True):
         length = min([
@@ -19,7 +19,7 @@ class Collector(object):
         self.collect_episode = 0
         self.collect_time = 0
         if buffer is None:
-            self.buffer = ReplayBuffer(20000)
+            self.buffer = ReplayBuffer(100)
         else:
             self.buffer = buffer
         self.policy = policy

@@ -100,7 +100,8 @@ class Collector(object):
         while True:
             if warning_count >= 100000:
                 warnings.warn(
-                    'There are already many steps in an episode. You should add a time limitation to your environment!',
+                    'There are already many steps in an episode. '
+                    'You should add a time limitation to your environment!',
                     Warning)
             if self._multi_env:
                 batch_data = Batch(
@@ -13,8 +13,8 @@ class OUNoise(object):
     def __call__(self, size, mu=.1):
         if self.x is None or self.x.shape != size:
             self.x = 0
-        self.x = self.x + self.alpha * (mu - self.x) + \
-            self.beta * np.random.normal(size=size)
+        r = self.beta * np.random.normal(size=size)
+        self.x = self.x + self.alpha * (mu - self.x) + r
         return self.x
 
     def reset(self):
@@ -34,9 +34,6 @@ class PGPolicy(BasePolicy):
 
     def learn(self, batch, batch_size=None, repeat=1):
         losses = []
-
-        batch.returns = (batch.returns - batch.returns.mean()) \
-            / (batch.returns.std() + self._eps)
+        r = batch.returns
+        batch.returns = (r - r.mean()) / (r.std() + self._eps)
         for _ in range(repeat):
@@ -58,9 +58,6 @@ class PPOPolicy(PGPolicy):
 
     def learn(self, batch, batch_size=None, repeat=1):
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
-
-        batch.returns = (batch.returns - batch.returns.mean()) \
-            / (batch.returns.std() + self._eps)
+        r = batch.returns
+        batch.returns = (r - r.mean()) / (r.std() + self._eps)
         batch.act = torch.tensor(batch.act)
@@ -47,7 +47,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     data[k] = f'{result[k]:.2f}'
                     if writer:
                         writer.add_scalar(
-                            k + '_' + task, result[k], global_step=global_step)
+                            k + '_' + task if task else k,
+                            result[k], global_step=global_step)
                 for k in losses.keys():
                     if stat.get(k) is None:
                         stat[k] = MovAvg()

@@ -55,7 +56,8 @@ def offpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     data[k] = f'{stat[k].get():.6f}'
                     if writer:
                         writer.add_scalar(
-                            k + '_' + task, stat[k].get(), global_step=global_step)
+                            k + '_' + task if task else k,
+                            stat[k].get(), global_step=global_step)
                 t.update(1)
                 t.set_postfix(**data)
             if t.n <= t.total:
@@ -52,7 +52,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     data[k] = f'{result[k]:.2f}'
                     if writer:
                         writer.add_scalar(
-                            k + '_' + task, result[k], global_step=global_step)
+                            k + '_' + task if task else k,
+                            result[k], global_step=global_step)
                 for k in losses.keys():
                     if stat.get(k) is None:
                         stat[k] = MovAvg()

@@ -60,7 +61,8 @@ def onpolicy_trainer(policy, train_collector, test_collector, max_epoch,
                     data[k] = f'{stat[k].get():.6f}'
                     if writer and global_step:
                         writer.add_scalar(
-                            k + '_' + task, stat[k].get(), global_step=global_step)
+                            k + '_' + task if task else k,
+                            stat[k].get(), global_step=global_step)
                 t.update(step)
                 t.set_postfix(**data)
             if t.n <= t.total: