add trainer

parent 9c5417dd51
commit c87fe3c18c

This commit replaces the hand-written training loops in the tests with two reusable trainers (episodic_trainer and step_trainer), splits A2C into separate actor/critic modules, makes every policy's learn() return a dict of named losses, changes the statistics returned by Collector.collect(), and adds a first PPOPolicy.

test/test_a2c.py
@@ -1,6 +1,4 @@
 import gym
-import time
-import tqdm
 import torch
 import argparse
 import numpy as np
@@ -9,12 +7,12 @@ from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import A2CPolicy
 from tianshou.env import SubprocVectorEnv
-from tianshou.utils import tqdm_config, MovAvg
+from tianshou.trainer import episodic_trainer
 from tianshou.data import Collector, ReplayBuffer
 
 
 class Net(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
+    def __init__(self, layer_num, state_shape, device='cpu'):
         super().__init__()
         self.device = device
         self.model = [
@@ -22,18 +20,40 @@ class Net(nn.Module):
             nn.ReLU(inplace=True)]
         for i in range(layer_num):
             self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        self.actor = self.model + [nn.Linear(128, np.prod(action_shape))]
-        self.critic = self.model + [nn.Linear(128, 1)]
-        self.actor = nn.Sequential(*self.actor)
-        self.critic = nn.Sequential(*self.critic)
+        self.model = nn.Sequential(*self.model)
 
-    def forward(self, s, **kwargs):
+    def forward(self, s):
         s = torch.tensor(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
-        logits = self.actor(s)
-        value = self.critic(s)
-        return logits, value, None
+        logits = self.model(s)
+        return logits
+
+
+class Actor(nn.Module):
+    def __init__(self, preprocess_net, action_shape):
+        super().__init__()
+        self.model = nn.Sequential(*[
+            preprocess_net,
+            nn.Linear(128, np.prod(action_shape)),
+        ])
+
+    def forward(self, s, **kwargs):
+        logits = self.model(s)
+        return logits, None
+
+
+class Critic(nn.Module):
+    def __init__(self, preprocess_net):
+        super().__init__()
+        self.model = nn.Sequential(*[
+            preprocess_net,
+            nn.Linear(128, 1),
+        ])
+
+    def forward(self, s):
+        logits = self.model(s)
+        return logits
 
 
 def get_args():
@@ -80,83 +100,45 @@ def test_a2c(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
-    net = net.to(args.device)
-    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
+    net = Net(args.layer_num, args.state_shape, args.device)
+    actor = Actor(net, args.action_shape).to(args.device)
+    critic = Critic(net).to(args.device)
+    optim = torch.optim.Adam(list(
+        actor.parameters()) + list(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = A2CPolicy(
-        net, optim, dist, args.gamma,
-        vf_coef=args.vf_coef,
-        entropy_coef=args.entropy_coef,
-        max_grad_norm=args.max_grad_norm)
+        actor, critic, optim, dist, args.gamma, vf_coef=args.vf_coef,
+        entropy_coef=args.entropy_coef, max_grad_norm=args.max_grad_norm)
     # collector
-    training_collector = Collector(
+    train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs, stat_size=args.test_num)
     # log
-    stat_loss = MovAvg()
-    global_step = 0
     writer = SummaryWriter(args.logdir)
-    best_epoch = -1
-    best_reward = -1e10
-    start_time = time.time()
-    for epoch in range(1, 1 + args.epoch):
-        desc = f'Epoch #{epoch}'
-        # train
-        policy.train()
-        with tqdm.tqdm(
-                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
-            while t.n < t.total:
-                result = training_collector.collect(
-                    n_episode=args.collect_per_step)
-                losses = policy.learn(
-                    training_collector.sample(0), args.batch_size)
-                training_collector.reset_buffer()
-                global_step += len(losses)
-                t.update(len(losses))
-                stat_loss.add(losses)
-                writer.add_scalar(
-                    'reward', result['reward'], global_step=global_step)
-                writer.add_scalar(
-                    'length', result['length'], global_step=global_step)
-                writer.add_scalar(
-                    'loss', stat_loss.get(), global_step=global_step)
-                writer.add_scalar(
-                    'speed', result['speed'], global_step=global_step)
-                t.set_postfix(loss=f'{stat_loss.get():.6f}',
-                              reward=f'{result["reward"]:.6f}',
-                              length=f'{result["length"]:.2f}',
-                              speed=f'{result["speed"]:.2f}')
-            if t.n <= t.total:
-                t.update()
-        # eval
-        test_collector.reset_env()
-        test_collector.reset_buffer()
-        policy.eval()
-        result = test_collector.collect(n_episode=args.test_num)
-        if best_reward < result['reward']:
-            best_reward = result['reward']
-            best_epoch = epoch
-        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
-              f'best_reward: {best_reward:.6f} in #{best_epoch}')
-        if best_reward >= env.spec.reward_threshold:
-            break
-    assert best_reward >= env.spec.reward_threshold
-    training_collector.close()
+
+    def stop_fn(x):
+        return x >= env.spec.reward_threshold
+
+    # trainer
+    train_step, train_episode, test_step, test_episode, best_rew, duration = \
+        episodic_trainer(
+            policy, train_collector, test_collector, args.epoch,
+            args.step_per_epoch, args.collect_per_step, args.test_num,
+            args.batch_size, stop_fn=stop_fn, writer=writer)
+    assert stop_fn(best_rew)
+    train_collector.close()
     test_collector.close()
     if __name__ == '__main__':
-        train_cnt = training_collector.collect_step
-        test_cnt = test_collector.collect_step
-        duration = time.time() - start_time
-        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
-              f'in {duration:.2f}s, '
-              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        print(f'Collect {train_step} frame / {train_episode} episode during '
+              f'training and {test_step} frame / {test_episode} episode during'
+              f' test in {duration:.2f}s, best_reward: {best_rew}, speed: '
+              f'{(train_step + test_step) / duration:.2f}it/s')
         # Let's watch its performance!
        env = gym.make(args.task)
-        test_collector = Collector(policy, env)
-        result = test_collector.collect(n_episode=1, render=1 / 35)
-        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
-        test_collector.close()
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=1 / 35)
+        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
+        collector.close()
 
 
 if __name__ == '__main__':
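For context, a minimal sketch of how the shared-backbone modules introduced in test_a2c.py compose (illustrative only, not part of the commit; it assumes the Net/Actor/Critic classes above and numpy imported as np, with CartPole-like shapes):

    # Net now ends at a 128-dim feature layer; Actor and Critic add the heads.
    net = Net(layer_num=2, state_shape=(4,), device='cpu')
    actor = Actor(net, action_shape=2)
    critic = Critic(net)

    obs = np.zeros((8, 4))       # fake batch of 8 observations
    logits, _ = actor(obs)       # tensor of shape (8, 2): action preferences
    value = critic(obs)          # tensor of shape (8, 1): state-value estimate
    # Both heads wrap the same `net` instance, so the feature extractor is
    # shared between actor and critic.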
test/test_ddpg.py
@@ -1,6 +1,4 @@
 import gym
-import time
-import tqdm
 import torch
 import argparse
 import numpy as np
@@ -8,7 +6,7 @@ from torch import nn
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import DDPGPolicy
-from tianshou.utils import tqdm_config, MovAvg
+from tianshou.trainer import step_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
 
@@ -121,85 +119,39 @@ def test_ddpg(args=get_args()):
         [env.action_space.low[0], env.action_space.high[0]],
         args.tau, args.gamma, args.exploration_noise)
     # collector
-    training_collector = Collector(
+    train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size), 1)
     test_collector = Collector(policy, test_envs, stat_size=args.test_num)
     # log
-    stat_a_loss = MovAvg()
-    stat_c_loss = MovAvg()
-    global_step = 0
     writer = SummaryWriter(args.logdir)
-    best_epoch = -1
-    best_reward = -1e10
-    start_time = time.time()
-    # training_collector.collect(n_step=1000)
-    for epoch in range(1, 1 + args.epoch):
-        desc = f'Epoch #{epoch}'
-        # train
-        policy.train()
-        with tqdm.tqdm(
-                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
-            while t.n < t.total:
-                result = training_collector.collect(
-                    n_step=args.collect_per_step)
-                for i in range(min(
-                        result['n_step'] // args.collect_per_step,
-                        t.total - t.n)):
-                    t.update(1)
-                    global_step += 1
-                    actor_loss, critic_loss = policy.learn(
-                        training_collector.sample(args.batch_size))
-                    policy.sync_weight()
-                    stat_a_loss.add(actor_loss)
-                    stat_c_loss.add(critic_loss)
-                    writer.add_scalar(
-                        'reward', result['reward'], global_step=global_step)
-                    writer.add_scalar(
-                        'length', result['length'], global_step=global_step)
-                    writer.add_scalar(
-                        'actor_loss', stat_a_loss.get(),
-                        global_step=global_step)
-                    writer.add_scalar(
-                        'critic_loss', stat_a_loss.get(),
-                        global_step=global_step)
-                    writer.add_scalar(
-                        'speed', result['speed'], global_step=global_step)
-                    t.set_postfix(actor_loss=f'{stat_a_loss.get():.6f}',
-                                  critic_loss=f'{stat_c_loss.get():.6f}',
-                                  reward=f'{result["reward"]:.6f}',
-                                  length=f'{result["length"]:.2f}',
-                                  speed=f'{result["speed"]:.2f}')
-            if t.n <= t.total:
-                t.update()
-        # eval
-        test_collector.reset_env()
-        test_collector.reset_buffer()
-        policy.eval()
-        result = test_collector.collect(n_episode=args.test_num)
-        if best_reward < result['reward']:
-            best_reward = result['reward']
-            best_epoch = epoch
-        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
-              f'best_reward: {best_reward:.6f} in #{best_epoch}')
-        if args.task == 'Pendulum-v0' and best_reward >= -250:
-            break
-    if args.task == 'Pendulum-v0':
-        assert best_reward >= -250
-    training_collector.close()
+
+    def stop_fn(x):
+        if args.task == 'Pendulum-v0':
+            return x >= -250
+        else:
+            return False
+
+    # trainer
+    train_step, train_episode, test_step, test_episode, best_rew, duration = \
+        step_trainer(
+            policy, train_collector, test_collector, args.epoch,
+            args.step_per_epoch, args.collect_per_step, args.test_num,
+            args.batch_size, stop_fn=stop_fn, writer=writer)
+    if args.task == 'Pendulum-v0':
+        assert stop_fn(best_rew)
+    train_collector.close()
     test_collector.close()
     if __name__ == '__main__':
-        train_cnt = training_collector.collect_step
-        test_cnt = test_collector.collect_step
-        duration = time.time() - start_time
-        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
-              f'in {duration:.2f}s, '
-              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        print(f'Collect {train_step} frame / {train_episode} episode during '
+              f'training and {test_step} frame / {test_episode} episode during'
+              f' test in {duration:.2f}s, best_reward: {best_rew}, speed: '
+              f'{(train_step + test_step) / duration:.2f}it/s')
         # Let's watch its performance!
         env = gym.make(args.task)
-        test_collector = Collector(policy, env)
-        result = test_collector.collect(n_episode=1, render=1 / 35)
-        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
-        test_collector.close()
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=1 / 35)
+        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
+        collector.close()
 
 
 if __name__ == '__main__':
test/test_dqn.py
@@ -1,6 +1,4 @@
 import gym
-import time
-import tqdm
 import torch
 import argparse
 import numpy as np
@@ -9,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import DQNPolicy
 from tianshou.env import SubprocVectorEnv
-from tianshou.utils import tqdm_config, MovAvg
+from tianshou.trainer import step_trainer
 from tianshou.data import Collector, ReplayBuffer
 
 
@@ -80,79 +78,45 @@ def test_dqn(args=get_args()):
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = DQNPolicy(net, optim, args.gamma, args.n_step)
     # collector
-    training_collector = Collector(
+    train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs, stat_size=args.test_num)
-    training_collector.collect(n_step=args.batch_size)
+    train_collector.collect(n_step=args.batch_size)
     # log
-    stat_loss = MovAvg()
-    global_step = 0
     writer = SummaryWriter(args.logdir)
-    best_epoch = -1
-    best_reward = -1e10
-    start_time = time.time()
-    for epoch in range(1, 1 + args.epoch):
-        desc = f"Epoch #{epoch}"
-        # train
-        policy.train()
+
+    def stop_fn(x):
+        return x >= env.spec.reward_threshold
+
+    def train_fn(x):
         policy.sync_weight()
         policy.set_eps(args.eps_train)
-        with tqdm.tqdm(
-                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
-            while t.n < t.total:
-                result = training_collector.collect(
-                    n_step=args.collect_per_step)
-                for i in range(min(
-                        result['n_step'] // args.collect_per_step,
-                        t.total - t.n)):
-                    t.update(1)
-                    global_step += 1
-                    loss = policy.learn(
-                        training_collector.sample(args.batch_size))
-                    stat_loss.add(loss)
-                    writer.add_scalar(
-                        'reward', result['reward'], global_step=global_step)
-                    writer.add_scalar(
-                        'length', result['length'], global_step=global_step)
-                    writer.add_scalar(
-                        'loss', stat_loss.get(), global_step=global_step)
-                    writer.add_scalar(
-                        'speed', result['speed'], global_step=global_step)
-                    t.set_postfix(loss=f'{stat_loss.get():.6f}',
-                                  reward=f'{result["reward"]:.6f}',
-                                  length=f'{result["length"]:.2f}',
-                                  speed=f'{result["speed"]:.2f}')
-            if t.n <= t.total:
-                t.update()
-        # eval
-        test_collector.reset_env()
-        test_collector.reset_buffer()
-        policy.eval()
+
+    def test_fn(x):
         policy.set_eps(args.eps_test)
-        result = test_collector.collect(n_episode=args.test_num)
-        if best_reward < result['reward']:
-            best_reward = result['reward']
-            best_epoch = epoch
-        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
-              f'best_reward: {best_reward:.6f} in #{best_epoch}')
-        if best_reward >= env.spec.reward_threshold:
-            break
-    assert best_reward >= env.spec.reward_threshold
-    training_collector.close()
+
+    # trainer
+    train_step, train_episode, test_step, test_episode, best_rew, duration = \
+        step_trainer(
+            policy, train_collector, test_collector, args.epoch,
+            args.step_per_epoch, args.collect_per_step, args.test_num,
+            args.batch_size, train_fn=train_fn, test_fn=test_fn,
+            stop_fn=stop_fn, writer=writer)
+
+    assert stop_fn(best_rew)
+    train_collector.close()
     test_collector.close()
     if __name__ == '__main__':
-        train_cnt = training_collector.collect_step
-        test_cnt = test_collector.collect_step
-        duration = time.time() - start_time
-        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
-              f'in {duration:.2f}s, '
-              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        print(f'Collect {train_step} frame / {train_episode} episode during '
+              f'training and {test_step} frame / {test_episode} episode during'
+              f' test in {duration:.2f}s, best_reward: {best_rew}, speed: '
+              f'{(train_step + test_step) / duration:.2f}it/s')
         # Let's watch its performance!
         env = gym.make(args.task)
-        test_collector = Collector(policy, env)
-        result = test_collector.collect(n_episode=1, render=1 / 35)
-        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
-        test_collector.close()
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=1 / 35)
+        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
+        collector.close()
 
 
 if __name__ == '__main__':
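The DQN test above illustrates the new per-epoch hooks: train_fn(epoch) re-syncs the target network and sets the training epsilon, test_fn(epoch) sets the evaluation epsilon. Because the hooks receive the epoch number, a schedule can key off it; a hypothetical annealing variant (illustrative only, not part of the commit):

    def train_fn(epoch):
        policy.sync_weight()
        # decay exploration linearly from eps_train towards eps_test
        eps = max(args.eps_test, args.eps_train * (1 - 0.05 * (epoch - 1)))
        policy.set_eps(eps)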
test/test_pg.py
@@ -1,6 +1,5 @@
 import gym
 import time
-import tqdm
 import torch
 import argparse
 import numpy as np
@@ -9,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import PGPolicy
 from tianshou.env import SubprocVectorEnv
-from tianshou.utils import tqdm_config, MovAvg
+from tianshou.trainer import episodic_trainer
 from tianshou.data import Batch, Collector, ReplayBuffer
 
 
@@ -131,73 +130,35 @@ def test_pg(args=get_args()):
     dist = torch.distributions.Categorical
     policy = PGPolicy(net, optim, dist, args.gamma)
     # collector
-    training_collector = Collector(
+    train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs, stat_size=args.test_num)
     # log
-    stat_loss = MovAvg()
-    global_step = 0
     writer = SummaryWriter(args.logdir)
-    best_epoch = -1
-    best_reward = -1e10
-    start_time = time.time()
-    for epoch in range(1, 1 + args.epoch):
-        desc = f"Epoch #{epoch}"
-        # train
-        policy.train()
-        with tqdm.tqdm(
-                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
-            while t.n < t.total:
-                result = training_collector.collect(
-                    n_episode=args.collect_per_step)
-                losses = policy.learn(
-                    training_collector.sample(0), args.batch_size)
-                training_collector.reset_buffer()
-                global_step += len(losses)
-                t.update(len(losses))
-                stat_loss.add(losses)
-                writer.add_scalar(
-                    'reward', result['reward'], global_step=global_step)
-                writer.add_scalar(
-                    'length', result['length'], global_step=global_step)
-                writer.add_scalar(
-                    'loss', stat_loss.get(), global_step=global_step)
-                writer.add_scalar(
-                    'speed', result['speed'], global_step=global_step)
-                t.set_postfix(loss=f'{stat_loss.get():.6f}',
-                              reward=f'{result["reward"]:.6f}',
-                              length=f'{result["length"]:.2f}',
-                              speed=f'{result["speed"]:.2f}')
-            if t.n <= t.total:
-                t.update()
-        # eval
-        test_collector.reset_env()
-        test_collector.reset_buffer()
-        policy.eval()
-        result = test_collector.collect(n_episode=args.test_num)
-        if best_reward < result['reward']:
-            best_reward = result['reward']
-            best_epoch = epoch
-        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
-              f'best_reward: {best_reward:.6f} in #{best_epoch}')
-        if best_reward >= env.spec.reward_threshold:
-            break
-    assert best_reward >= env.spec.reward_threshold
-    training_collector.close()
+
+    def stop_fn(x):
+        return x >= env.spec.reward_threshold
+
+    # trainer
+    train_step, train_episode, test_step, test_episode, best_rew, duration = \
+        episodic_trainer(
+            policy, train_collector, test_collector, args.epoch,
+            args.step_per_epoch, args.collect_per_step, args.test_num,
+            args.batch_size, stop_fn=stop_fn, writer=writer)
+    assert stop_fn(best_rew)
+    train_collector.close()
     test_collector.close()
     if __name__ == '__main__':
-        train_cnt = training_collector.collect_step
-        test_cnt = test_collector.collect_step
-        duration = time.time() - start_time
-        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
-              f'in {duration:.2f}s, '
-              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        print(f'Collect {train_step} frame / {train_episode} episode during '
+              f'training and {test_step} frame / {test_episode} episode during'
+              f' test in {duration:.2f}s, best_reward: {best_rew}, speed: '
+              f'{(train_step + test_step) / duration:.2f}it/s')
         # Let's watch its performance!
         env = gym.make(args.task)
-        test_collector = Collector(policy, env)
-        result = test_collector.collect(n_episode=1, render=1 / 35)
-        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
-        test_collector.close()
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=1 / 35)
+        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
+        collector.close()
 
 
 if __name__ == '__main__':
tianshou/__init__.py
@@ -1,4 +1,5 @@
-from tianshou import data, env, utils, policy, exploration
+from tianshou import data, env, utils, policy, trainer,\
+    exploration
 
 __version__ = '0.2.0'
 __all__ = [
@@ -6,5 +7,6 @@ __all__ = [
     'data',
     'utils',
     'policy',
+    'trainer',
     'exploration',
 ]
tianshou/data/collector.py
@@ -16,6 +16,7 @@ class Collector(object):
         self.env = env
         self.env_num = 1
         self.collect_step = 0
+        self.collect_episode = 0
         self.buffer = buffer
         self.policy = policy
         self.process_fn = policy.process_fn
@@ -39,9 +40,8 @@
         self.reset_buffer()
         # state over batch is either a list, an np.ndarray, or a torch.Tensor
         self.state = None
-        self.stat_reward = MovAvg(stat_size)
-        self.stat_length = MovAvg(stat_size)
-        self.stat_speed = MovAvg(stat_size)
+        self.step_speed = MovAvg(stat_size)
+        self.episode_speed = MovAvg(stat_size)
 
     def reset_buffer(self):
         if self._multi_buf:
@@ -81,11 +81,12 @@
 
     def collect(self, n_step=0, n_episode=0, render=0):
         start_time = time.time()
-        start_step = self.collect_step
         assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
            "One and only one collection number specification permitted!"
         cur_step = 0
         cur_episode = np.zeros(self.env_num) if self._multi_env else 0
+        reward_sum = 0
+        length_sum = 0
         while True:
             if self._multi_env:
                 batch_data = Batch(
@@ -126,20 +127,17 @@
                     elif self._multi_buf:
                         self.buffer[i].add(**data)
                         cur_step += 1
-                        self.collect_step += 1
                     else:
                         self.buffer.add(**data)
                         cur_step += 1
-                        self.collect_step += 1
                     if self._done[i]:
                         cur_episode[i] += 1
-                        self.stat_reward.add(self.reward[i])
-                        self.stat_length.add(self.length[i])
+                        reward_sum += self.reward[i]
+                        length_sum += self.length[i]
                         self.reward[i], self.length[i] = 0, 0
                         if self._cached_buf:
                             self.buffer.update(self._cached_buf[i])
                             cur_step += len(self._cached_buf[i])
-                            self.collect_step += len(self._cached_buf[i])
                            self._cached_buf[i].reset()
                         if isinstance(self.state, list):
                             self.state[i] = None
@@ -158,11 +156,10 @@
                     self._obs, self._act[0], self._rew,
                     self._done, obs_next, self._info)
                 cur_step += 1
-                self.collect_step += 1
                 if self._done:
                     cur_episode += 1
-                    self.stat_reward.add(self.reward)
-                    self.stat_length.add(self.length)
+                    reward_sum += self.reward
+                    length_sum += self.length
                     self.reward, self.length = 0, 0
                     self.state = None
                     self._obs = self.env.reset()
@@ -172,16 +169,20 @@
                     break
                 self._obs = obs_next
         self._obs = obs_next
-        self.stat_speed.add((self.collect_step - start_step) / (
-            time.time() - start_time))
         if self._multi_env:
             cur_episode = sum(cur_episode)
+        duration = time.time() - start_time
+        self.step_speed.add(cur_step / duration)
+        self.episode_speed.add(cur_episode / duration)
+        self.collect_step += cur_step
+        self.collect_episode += cur_episode
         return {
-            'reward': self.stat_reward.get(),
-            'length': self.stat_length.get(),
-            'speed': self.stat_speed.get(),
-            'n_episode': cur_episode,
-            'n_step': cur_step,
+            'n/ep': cur_episode,
+            'n/st': cur_step,
+            'speed/st': self.step_speed.get(),
+            'speed/ep': self.episode_speed.get(),
+            'rew': reward_sum / cur_episode,
+            'len': length_sum / cur_episode,
         }
 
     def sample(self, batch_size):
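With this change Collector.collect() no longer returns moving-average reward/length statistics; it reports what happened during the current call plus per-call speed averages. The returned keys (sketch, illustrative only):

    result = train_collector.collect(n_step=100)
    result['n/ep']      # episodes finished during this call
    result['n/st']      # environment steps taken during this call
    result['speed/st']  # moving average of steps per second
    result['speed/ep']  # moving average of episodes per second
    result['rew']       # mean reward of the episodes finished in this call
    result['len']       # mean length of the episodes finished in this call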
tianshou/policy/__init__.py
@@ -3,11 +3,13 @@ from tianshou.policy.dqn import DQNPolicy
 from tianshou.policy.pg import PGPolicy
 from tianshou.policy.a2c import A2CPolicy
 from tianshou.policy.ddpg import DDPGPolicy
+from tianshou.policy.ppo import PPOPolicy
 
 __all__ = [
     'BasePolicy',
     'DQNPolicy',
     'PGPolicy',
     'A2CPolicy',
-    'DDPGPolicy'
+    'DDPGPolicy',
+    'PPOPolicy',
 ]
tianshou/policy/a2c.py
@@ -9,20 +9,23 @@ from tianshou.policy import PGPolicy
 class A2CPolicy(PGPolicy):
     """docstring for A2CPolicy"""
 
-    def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
+    def __init__(self, actor, critic, optim,
+                 dist_fn=torch.distributions.Categorical,
                  discount_factor=0.99, vf_coef=.5, entropy_coef=.01,
                  max_grad_norm=None):
-        super().__init__(model, optim, dist_fn, discount_factor)
+        super().__init__(None, optim, dist_fn, discount_factor)
+        self.actor = actor
+        self.critic = critic
         self._w_value = vf_coef
         self._w_entropy = entropy_coef
         self._grad_norm = max_grad_norm
 
     def __call__(self, batch, state=None):
-        logits, value, h = self.model(batch.obs, state=state, info=batch.info)
+        logits, h = self.actor(batch.obs, state=state, info=batch.info)
         logits = F.softmax(logits, dim=1)
         dist = self.dist_fn(logits)
         act = dist.sample()
-        return Batch(logits=logits, act=act, state=h, dist=dist, value=value)
+        return Batch(logits=logits, act=act, state=h, dist=dist)
 
     def learn(self, batch, batch_size=None):
         losses = []
@@ -30,7 +33,7 @@ class A2CPolicy(PGPolicy):
             self.optim.zero_grad()
             result = self(b)
             dist = result.dist
-            v = result.value
+            v = self.critic(b.obs)
             a = torch.tensor(b.act, device=dist.logits.device)
             r = torch.tensor(b.returns, device=dist.logits.device)
             actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
@@ -45,4 +48,4 @@
                     self.model.parameters(), max_norm=self._grad_norm)
             self.optim.step()
             losses.append(loss.detach().cpu().numpy())
-        return losses
+        return {'loss': losses}
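A2CPolicy is now built from a separate actor and critic plus a single optimizer covering both parameter sets, mirroring test_a2c.py above (illustrative sketch; the hyper-parameter values are arbitrary):

    actor = Actor(net, action_shape)
    critic = Critic(net)
    optim = torch.optim.Adam(
        list(actor.parameters()) + list(critic.parameters()), lr=1e-3)
    policy = A2CPolicy(actor, critic, optim,
                       dist_fn=torch.distributions.Categorical,
                       discount_factor=0.9, vf_coef=0.5, entropy_coef=0.01,
                       max_grad_norm=None)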
tianshou/policy/base.py
@@ -18,6 +18,7 @@ class BasePolicy(ABC, nn.Module):
 
     @abstractmethod
     def learn(self, batch, batch_size=None):
+        # return a dict which includes loss and its name
         pass
 
     def sync_weight(self):
tianshou/policy/ddpg.py
@@ -28,7 +28,7 @@ class DDPGPolicy(BasePolicy):
         self._tau = tau
         assert 0 < gamma <= 1, 'gamma should in (0, 1]'
         self._gamma = gamma
-        assert 0 <= exploration_noise, 'noise should greater than zero'
+        assert 0 <= exploration_noise, 'noise should not be negative'
         self._eps = exploration_noise
         self._range = action_range
         # self.noise = OUNoise()
@@ -87,5 +87,8 @@
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()
-        return actor_loss.detach().cpu().numpy(),\
-            critic_loss.detach().cpu().numpy()
+        self.sync_weight()
+        return {
+            'loss/actor': actor_loss.detach().cpu().numpy(),
+            'loss/critic': critic_loss.detach().cpu().numpy(),
+        }
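Every policy's learn() now returns a dict of named losses instead of a bare number or tuple, and DDPG performs its own soft target update inside learn(). The trainers below feed each returned key through a MovAvg and write it to TensorBoard under that key. Roughly (illustrative only):

    losses = policy.learn(train_collector.sample(batch_size))
    # DDPG:   {'loss/actor': ..., 'loss/critic': ...}
    # DQN:    {'loss': ...}
    # PG/A2C: {'loss': [...]}   # one value per minibatch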
tianshou/policy/dqn.py
@@ -93,4 +93,4 @@ class DQNPolicy(BasePolicy):
         loss = F.mse_loss(q, r)
         loss.backward()
         self.optim.step()
-        return loss.detach().cpu().numpy()
+        return {'loss': loss.detach().cpu().numpy()}
tianshou/policy/pg.py
@@ -45,7 +45,7 @@ class PGPolicy(BasePolicy):
             loss.backward()
             self.optim.step()
             losses.append(loss.detach().cpu().numpy())
-        return losses
+        return {'loss': losses}
 
     def _vanilla_returns(self, batch):
         returns = batch.rew[:]
tianshou/policy/ppo.py (new file)
@@ -0,0 +1,50 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from tianshou.data import Batch
+from tianshou.policy import PGPolicy
+
+
+class PPOPolicy(PGPolicy):
+    """docstring for PPOPolicy"""
+
+    def __init__(self, actor, actor_optim,
+                 critic, critic_optim,
+                 dist_fn=torch.distributions.Categorical,
+                 discount_factor=0.99, vf_coef=.5, entropy_coef=.01,
+                 eps_clip=None):
+        super().__init__(None, None, dist_fn, discount_factor)
+        self._w_value = vf_coef
+        self._w_entropy = entropy_coef
+        self._eps_clip = eps_clip
+
+    def __call__(self, batch, state=None):
+        logits, value, h = self.model(batch.obs, state=state, info=batch.info)
+        logits = F.softmax(logits, dim=1)
+        dist = self.dist_fn(logits)
+        act = dist.sample()
+        return Batch(logits=logits, act=act, state=h, dist=dist, value=value)
+
+    def learn(self, batch, batch_size=None):
+        losses = []
+        for b in batch.split(batch_size):
+            self.optim.zero_grad()
+            result = self(b)
+            dist = result.dist
+            v = result.value
+            a = torch.tensor(b.act, device=dist.logits.device)
+            r = torch.tensor(b.returns, device=dist.logits.device)
+            actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
+            critic_loss = F.mse_loss(r[:, None], v)
+            entropy_loss = dist.entropy().mean()
+            loss = actor_loss \
+                + self._w_value * critic_loss \
+                - self._w_entropy * entropy_loss
+            loss.backward()
+            if self._grad_norm:
+                nn.utils.clip_grad_norm_(
+                    self.model.parameters(), max_norm=self._grad_norm)
+            self.optim.step()
+            losses.append(loss.detach().cpu().numpy())
+        return losses
tianshou/trainer/__init__.py (new file)
@@ -0,0 +1,7 @@
+from tianshou.trainer.episodic import episodic_trainer
+from tianshou.trainer.step import step_trainer
+
+__all__ = [
+    'episodic_trainer',
+    'step_trainer',
+]
tianshou/trainer/episodic.py (new file)
@@ -0,0 +1,68 @@
+import time
+import tqdm
+
+from tianshou.utils import tqdm_config, MovAvg
+
+
+def episodic_trainer(policy, train_collector, test_collector, max_epoch,
+                     step_per_epoch, collect_per_step, episode_per_test,
+                     batch_size, train_fn=None, test_fn=None, stop_fn=None,
+                     writer=None, verbose=True):
+    global_step = 0
+    best_epoch, best_reward = -1, -1
+    stat = {}
+    start_time = time.time()
+    for epoch in range(1, 1 + max_epoch):
+        # train
+        policy.train()
+        if train_fn:
+            train_fn(epoch)
+        with tqdm.tqdm(
+                total=step_per_epoch, desc=f'Epoch #{epoch}',
+                **tqdm_config) as t:
+            while t.n < t.total:
+                result = train_collector.collect(n_episode=collect_per_step)
+                losses = policy.learn(train_collector.sample(0), batch_size)
+                train_collector.reset_buffer()
+                step = 1
+                data = {}
+                for k in losses.keys():
+                    if isinstance(losses[k], list):
+                        step = max(step, len(losses[k]))
+                global_step += step
+                for k in result.keys():
+                    data[k] = f'{result[k]:.2f}'
+                    if writer:
+                        writer.add_scalar(
+                            k, result[k], global_step=global_step)
+                for k in losses.keys():
+                    if stat.get(k) is None:
+                        stat[k] = MovAvg()
+                    stat[k].add(losses[k])
+                    data[k] = f'{stat[k].get():.6f}'
+                    if writer:
+                        writer.add_scalar(
+                            k, stat[k].get(), global_step=global_step)
+                t.update(step)
+                t.set_postfix(**data)
+            if t.n <= t.total:
+                t.update()
+        # eval
+        test_collector.reset_env()
+        test_collector.reset_buffer()
+        policy.eval()
+        if test_fn:
+            test_fn(epoch)
+        result = test_collector.collect(n_episode=episode_per_test)
+        if best_epoch == -1 or best_reward < result['rew']:
+            best_reward = result['rew']
+            best_epoch = epoch
+        if verbose:
+            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
+                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
+        if stop_fn(best_reward):
+            break
+    duration = time.time() - start_time
+    return train_collector.collect_step, train_collector.collect_episode,\
+        test_collector.collect_step, test_collector.collect_episode,\
+        best_reward, duration
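episodic_trainer drives the on-policy tests (PG, A2C): each inner step collects collect_per_step whole episodes, trains on the entire buffer via sample(0), then clears the buffer. It returns the aggregate step/episode counters of both collectors plus the best test reward and the wall-clock duration. A call sketch (illustrative only; the values are arbitrary, and stop_fn is invoked after every test phase, so a callable should be supplied):

    train_step, train_episode, test_step, test_episode, best_rew, duration = \
        episodic_trainer(policy, train_collector, test_collector,
                         max_epoch=100, step_per_epoch=1000,
                         collect_per_step=10, episode_per_test=100,
                         batch_size=64,
                         stop_fn=lambda rew: rew >= 195,  # e.g. CartPole-v0
                         writer=writer)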
tianshou/trainer/step.py (new file)
@@ -0,0 +1,66 @@
+import time
+import tqdm
+
+from tianshou.utils import tqdm_config, MovAvg
+
+
+def step_trainer(policy, train_collector, test_collector, max_epoch,
+                 step_per_epoch, collect_per_step, episode_per_test,
+                 batch_size, train_fn=None, test_fn=None, stop_fn=None,
+                 writer=None, verbose=True):
+    global_step = 0
+    best_epoch, best_reward = -1, -1
+    stat = {}
+    start_time = time.time()
+    for epoch in range(1, 1 + max_epoch):
+        # train
+        policy.train()
+        if train_fn:
+            train_fn(epoch)
+        with tqdm.tqdm(
+                total=step_per_epoch, desc=f'Epoch #{epoch}',
+                **tqdm_config) as t:
+            while t.n < t.total:
+                result = train_collector.collect(n_step=collect_per_step)
+                for i in range(min(
+                        result['n/st'] // collect_per_step,
+                        t.total - t.n)):
+                    global_step += 1
+                    losses = policy.learn(train_collector.sample(batch_size))
+                    data = {}
+                    for k in result.keys():
+                        data[k] = f'{result[k]:.2f}'
+                        if writer:
+                            writer.add_scalar(
+                                k, result[k], global_step=global_step)
+                    for k in losses.keys():
+                        if stat.get(k) is None:
+                            stat[k] = MovAvg()
+                        stat[k].add(losses[k])
+                        data[k] = f'{stat[k].get():.6f}'
+                        if writer:
+                            writer.add_scalar(
+                                k, stat[k].get(), global_step=global_step)
+                    t.update(1)
+                    t.set_postfix(**data)
+            if t.n <= t.total:
+                t.update()
+        # eval
+        test_collector.reset_env()
+        test_collector.reset_buffer()
+        policy.eval()
+        if test_fn:
+            test_fn(epoch)
+        result = test_collector.collect(n_episode=episode_per_test)
+        if best_epoch == -1 or best_reward < result['rew']:
+            best_reward = result['rew']
+            best_epoch = epoch
+        if verbose:
+            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
+                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
+        if stop_fn(best_reward):
+            break
+    duration = time.time() - start_time
+    return train_collector.collect_step, train_collector.collect_episode,\
+        test_collector.collect_step, test_collector.collect_episode,\
+        best_reward, duration
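step_trainer is the off-policy counterpart (DQN, DDPG): it collects collect_per_step environment steps, keeps the replay buffer, and samples batch_size transitions for each gradient update, with the optional train_fn/test_fn hooks run before each train/eval phase. A call sketch (illustrative only; the values are arbitrary):

    train_step, train_episode, test_step, test_episode, best_rew, duration = \
        step_trainer(policy, train_collector, test_collector,
                     max_epoch=100, step_per_epoch=1000,
                     collect_per_step=10, episode_per_test=100,
                     batch_size=64, train_fn=train_fn, test_fn=test_fn,
                     stop_fn=stop_fn, writer=writer)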