add trainer
parent 9c5417dd51
commit c87fe3c18c
138 test/test_a2c.py
@@ -1,6 +1,4 @@
import gym
import time
import tqdm
import torch
import argparse
import numpy as np
@@ -9,12 +7,12 @@ from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import A2CPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import episodic_trainer
from tianshou.data import Collector, ReplayBuffer


class Net(nn.Module):
    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
    def __init__(self, layer_num, state_shape, device='cpu'):
        super().__init__()
        self.device = device
        self.model = [
@@ -22,18 +20,40 @@ class Net(nn.Module):
            nn.ReLU(inplace=True)]
        for i in range(layer_num):
            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
        self.actor = self.model + [nn.Linear(128, np.prod(action_shape))]
        self.critic = self.model + [nn.Linear(128, 1)]
        self.actor = nn.Sequential(*self.actor)
        self.critic = nn.Sequential(*self.critic)
        self.model = nn.Sequential(*self.model)

    def forward(self, s, **kwargs):
    def forward(self, s):
        s = torch.tensor(s, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        s = s.view(batch, -1)
        logits = self.actor(s)
        value = self.critic(s)
        return logits, value, None
        logits = self.model(s)
        return logits


class Actor(nn.Module):
    def __init__(self, preprocess_net, action_shape):
        super().__init__()
        self.model = nn.Sequential(*[
            preprocess_net,
            nn.Linear(128, np.prod(action_shape)),
        ])

    def forward(self, s, **kwargs):
        logits = self.model(s)
        return logits, None


class Critic(nn.Module):
    def __init__(self, preprocess_net):
        super().__init__()
        self.model = nn.Sequential(*[
            preprocess_net,
            nn.Linear(128, 1),
        ])

    def forward(self, s):
        logits = self.model(s)
        return logits


def get_args():
@@ -80,83 +100,45 @@ def test_a2c(args=get_args()):
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
    net = net.to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    net = Net(args.layer_num, args.state_shape, args.device)
    actor = Actor(net, args.action_shape).to(args.device)
    critic = Critic(net).to(args.device)
    optim = torch.optim.Adam(list(
        actor.parameters()) + list(critic.parameters()), lr=args.lr)
    dist = torch.distributions.Categorical
    policy = A2CPolicy(
        net, optim, dist, args.gamma,
        vf_coef=args.vf_coef,
        entropy_coef=args.entropy_coef,
        max_grad_norm=args.max_grad_norm)
        actor, critic, optim, dist, args.gamma, vf_coef=args.vf_coef,
        entropy_coef=args.entropy_coef, max_grad_norm=args.max_grad_norm)
    # collector
    training_collector = Collector(
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
    # log
    stat_loss = MovAvg()
    global_step = 0
    writer = SummaryWriter(args.logdir)
    best_epoch = -1
    best_reward = -1e10
    start_time = time.time()
    for epoch in range(1, 1 + args.epoch):
        desc = f'Epoch #{epoch}'
        # train
        policy.train()
        with tqdm.tqdm(
                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
            while t.n < t.total:
                result = training_collector.collect(
                    n_episode=args.collect_per_step)
                losses = policy.learn(
                    training_collector.sample(0), args.batch_size)
                training_collector.reset_buffer()
                global_step += len(losses)
                t.update(len(losses))
                stat_loss.add(losses)
                writer.add_scalar(
                    'reward', result['reward'], global_step=global_step)
                writer.add_scalar(
                    'length', result['length'], global_step=global_step)
                writer.add_scalar(
                    'loss', stat_loss.get(), global_step=global_step)
                writer.add_scalar(
                    'speed', result['speed'], global_step=global_step)
                t.set_postfix(loss=f'{stat_loss.get():.6f}',
                              reward=f'{result["reward"]:.6f}',
                              length=f'{result["length"]:.2f}',
                              speed=f'{result["speed"]:.2f}')
            if t.n <= t.total:
                t.update()
        # eval
        test_collector.reset_env()
        test_collector.reset_buffer()
        policy.eval()
        result = test_collector.collect(n_episode=args.test_num)
        if best_reward < result['reward']:
            best_reward = result['reward']
            best_epoch = epoch
        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
              f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if best_reward >= env.spec.reward_threshold:
            break
    assert best_reward >= env.spec.reward_threshold
    training_collector.close()

    def stop_fn(x):
        return x >= env.spec.reward_threshold

    # trainer
    train_step, train_episode, test_step, test_episode, best_rew, duration = \
        episodic_trainer(
            policy, train_collector, test_collector, args.epoch,
            args.step_per_epoch, args.collect_per_step, args.test_num,
            args.batch_size, stop_fn=stop_fn, writer=writer)
    assert stop_fn(best_rew)
    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        train_cnt = training_collector.collect_step
        test_cnt = test_collector.collect_step
        duration = time.time() - start_time
        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
              f'in {duration:.2f}s, '
              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
        print(f'Collect {train_step} frame / {train_episode} episode during '
              f'training and {test_step} frame / {test_episode} episode during'
              f' test in {duration:.2f}s, best_reward: {best_rew}, speed: '
              f'{(train_step + test_step) / duration:.2f}it/s')
        # Let's watch its performance!
        env = gym.make(args.task)
        test_collector = Collector(policy, env)
        result = test_collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
        test_collector.close()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()


if __name__ == '__main__':
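Because the old hand-written loop and the new trainer-based code are interleaved in the hunks above, here is a compact sketch of what test_a2c looks like after this commit. It only restates lines from the hunks and assumes the unchanged parts of the file (get_args(), and env / train_envs / test_envs construction); it is a reading aid, not additional code from the commit.

# Sketch of the new flow (assumes args, env, train_envs, test_envs come from the
# unchanged parts of test_a2c.py).
net = Net(args.layer_num, args.state_shape, args.device)
actor = Actor(net, args.action_shape).to(args.device)
critic = Critic(net).to(args.device)
optim = torch.optim.Adam(
    list(actor.parameters()) + list(critic.parameters()), lr=args.lr)
policy = A2CPolicy(
    actor, critic, optim, torch.distributions.Categorical, args.gamma,
    vf_coef=args.vf_coef, entropy_coef=args.entropy_coef,
    max_grad_norm=args.max_grad_norm)
train_collector = Collector(policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs, stat_size=args.test_num)
writer = SummaryWriter(args.logdir)

def stop_fn(x):
    # stop once the best test reward reaches the environment's threshold
    return x >= env.spec.reward_threshold

# episodic_trainer now owns the epoch / tqdm / TensorBoard loop that the test
# used to implement by hand
train_step, train_episode, test_step, test_episode, best_rew, duration = \
    episodic_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, stop_fn=stop_fn, writer=writer)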
test/test_ddpg.py
@@ -1,6 +1,4 @@
import gym
import time
import tqdm
import torch
import argparse
import numpy as np
@@ -8,7 +6,7 @@ from torch import nn
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DDPGPolicy
from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import step_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import VectorEnv, SubprocVectorEnv

@@ -121,85 +119,39 @@ def test_ddpg(args=get_args()):
        [env.action_space.low[0], env.action_space.high[0]],
        args.tau, args.gamma, args.exploration_noise)
    # collector
    training_collector = Collector(
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size), 1)
    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
    # log
    stat_a_loss = MovAvg()
    stat_c_loss = MovAvg()
    global_step = 0
    writer = SummaryWriter(args.logdir)
    best_epoch = -1
    best_reward = -1e10
    start_time = time.time()
    # training_collector.collect(n_step=1000)
    for epoch in range(1, 1 + args.epoch):
        desc = f'Epoch #{epoch}'
        # train
        policy.train()
        with tqdm.tqdm(
                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
            while t.n < t.total:
                result = training_collector.collect(
                    n_step=args.collect_per_step)
                for i in range(min(
                        result['n_step'] // args.collect_per_step,
                        t.total - t.n)):
                    t.update(1)
                    global_step += 1
                    actor_loss, critic_loss = policy.learn(
                        training_collector.sample(args.batch_size))
                    policy.sync_weight()
                    stat_a_loss.add(actor_loss)
                    stat_c_loss.add(critic_loss)
                    writer.add_scalar(
                        'reward', result['reward'], global_step=global_step)
                    writer.add_scalar(
                        'length', result['length'], global_step=global_step)
                    writer.add_scalar(
                        'actor_loss', stat_a_loss.get(),
                        global_step=global_step)
                    writer.add_scalar(
                        'critic_loss', stat_a_loss.get(),
                        global_step=global_step)
                    writer.add_scalar(
                        'speed', result['speed'], global_step=global_step)
                    t.set_postfix(actor_loss=f'{stat_a_loss.get():.6f}',
                                  critic_loss=f'{stat_c_loss.get():.6f}',
                                  reward=f'{result["reward"]:.6f}',
                                  length=f'{result["length"]:.2f}',
                                  speed=f'{result["speed"]:.2f}')
            if t.n <= t.total:
                t.update()
        # eval
        test_collector.reset_env()
        test_collector.reset_buffer()
        policy.eval()
        result = test_collector.collect(n_episode=args.test_num)
        if best_reward < result['reward']:
            best_reward = result['reward']
            best_epoch = epoch
        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
              f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if args.task == 'Pendulum-v0' and best_reward >= -250:
            break

    def stop_fn(x):
        if args.task == 'Pendulum-v0':
            return x >= -250
        else:
            return False

    # trainer
    train_step, train_episode, test_step, test_episode, best_rew, duration = \
        step_trainer(
            policy, train_collector, test_collector, args.epoch,
            args.step_per_epoch, args.collect_per_step, args.test_num,
            args.batch_size, stop_fn=stop_fn, writer=writer)
    if args.task == 'Pendulum-v0':
        assert best_reward >= -250
    training_collector.close()
    assert stop_fn(best_rew)
    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        train_cnt = training_collector.collect_step
        test_cnt = test_collector.collect_step
        duration = time.time() - start_time
        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
              f'in {duration:.2f}s, '
              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
        print(f'Collect {train_step} frame / {train_episode} episode during '
              f'training and {test_step} frame / {test_episode} episode during'
              f' test in {duration:.2f}s, best_reward: {best_rew}, speed: '
              f'{(train_step + test_step) / duration:.2f}it/s')
        # Let's watch its performance!
        env = gym.make(args.task)
        test_collector = Collector(policy, env)
        result = test_collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
        test_collector.close()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()


if __name__ == '__main__':
test/test_dqn.py
@@ -1,6 +1,4 @@
import gym
import time
import tqdm
import torch
import argparse
import numpy as np
@@ -9,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import step_trainer
from tianshou.data import Collector, ReplayBuffer


@@ -80,79 +78,45 @@ def test_dqn(args=get_args()):
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = DQNPolicy(net, optim, args.gamma, args.n_step)
    # collector
    training_collector = Collector(
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
    training_collector.collect(n_step=args.batch_size)
    train_collector.collect(n_step=args.batch_size)
    # log
    stat_loss = MovAvg()
    global_step = 0
    writer = SummaryWriter(args.logdir)
    best_epoch = -1
    best_reward = -1e10
    start_time = time.time()
    for epoch in range(1, 1 + args.epoch):
        desc = f"Epoch #{epoch}"
        # train
        policy.train()

    def stop_fn(x):
        return x >= env.spec.reward_threshold

    def train_fn(x):
        policy.sync_weight()
        policy.set_eps(args.eps_train)
        with tqdm.tqdm(
                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
            while t.n < t.total:
                result = training_collector.collect(
                    n_step=args.collect_per_step)
                for i in range(min(
                        result['n_step'] // args.collect_per_step,
                        t.total - t.n)):
                    t.update(1)
                    global_step += 1
                    loss = policy.learn(
                        training_collector.sample(args.batch_size))
                    stat_loss.add(loss)
                    writer.add_scalar(
                        'reward', result['reward'], global_step=global_step)
                    writer.add_scalar(
                        'length', result['length'], global_step=global_step)
                    writer.add_scalar(
                        'loss', stat_loss.get(), global_step=global_step)
                    writer.add_scalar(
                        'speed', result['speed'], global_step=global_step)
                    t.set_postfix(loss=f'{stat_loss.get():.6f}',
                                  reward=f'{result["reward"]:.6f}',
                                  length=f'{result["length"]:.2f}',
                                  speed=f'{result["speed"]:.2f}')
            if t.n <= t.total:
                t.update()
        # eval
        test_collector.reset_env()
        test_collector.reset_buffer()
        policy.eval()

    def test_fn(x):
        policy.set_eps(args.eps_test)
        result = test_collector.collect(n_episode=args.test_num)
        if best_reward < result['reward']:
            best_reward = result['reward']
            best_epoch = epoch
        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
              f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if best_reward >= env.spec.reward_threshold:
            break
    assert best_reward >= env.spec.reward_threshold
    training_collector.close()

    # trainer
    train_step, train_episode, test_step, test_episode, best_rew, duration = \
        step_trainer(
            policy, train_collector, test_collector, args.epoch,
            args.step_per_epoch, args.collect_per_step, args.test_num,
            args.batch_size, train_fn=train_fn, test_fn=test_fn,
            stop_fn=stop_fn, writer=writer)

    assert stop_fn(best_rew)
    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        train_cnt = training_collector.collect_step
        test_cnt = test_collector.collect_step
        duration = time.time() - start_time
        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
              f'in {duration:.2f}s, '
              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
        print(f'Collect {train_step} frame / {train_episode} episode during '
              f'training and {test_step} frame / {test_episode} episode during'
              f' test in {duration:.2f}s, best_reward: {best_rew}, speed: '
              f'{(train_step + test_step) / duration:.2f}it/s')
        # Let's watch its performance!
        env = gym.make(args.task)
        test_collector = Collector(policy, env)
        result = test_collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
        test_collector.close()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()


if __name__ == '__main__':
test/test_pg.py
@@ -1,6 +1,5 @@
import gym
import time
import tqdm
import torch
import argparse
import numpy as np
@@ -9,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import PGPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils import tqdm_config, MovAvg
from tianshou.trainer import episodic_trainer
from tianshou.data import Batch, Collector, ReplayBuffer


@@ -131,73 +130,35 @@ def test_pg(args=get_args()):
    dist = torch.distributions.Categorical
    policy = PGPolicy(net, optim, dist, args.gamma)
    # collector
    training_collector = Collector(
    train_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
    # log
    stat_loss = MovAvg()
    global_step = 0
    writer = SummaryWriter(args.logdir)
    best_epoch = -1
    best_reward = -1e10
    start_time = time.time()
    for epoch in range(1, 1 + args.epoch):
        desc = f"Epoch #{epoch}"
        # train
        policy.train()
        with tqdm.tqdm(
                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
            while t.n < t.total:
                result = training_collector.collect(
                    n_episode=args.collect_per_step)
                losses = policy.learn(
                    training_collector.sample(0), args.batch_size)
                training_collector.reset_buffer()
                global_step += len(losses)
                t.update(len(losses))
                stat_loss.add(losses)
                writer.add_scalar(
                    'reward', result['reward'], global_step=global_step)
                writer.add_scalar(
                    'length', result['length'], global_step=global_step)
                writer.add_scalar(
                    'loss', stat_loss.get(), global_step=global_step)
                writer.add_scalar(
                    'speed', result['speed'], global_step=global_step)
                t.set_postfix(loss=f'{stat_loss.get():.6f}',
                              reward=f'{result["reward"]:.6f}',
                              length=f'{result["length"]:.2f}',
                              speed=f'{result["speed"]:.2f}')
            if t.n <= t.total:
                t.update()
        # eval
        test_collector.reset_env()
        test_collector.reset_buffer()
        policy.eval()
        result = test_collector.collect(n_episode=args.test_num)
        if best_reward < result['reward']:
            best_reward = result['reward']
            best_epoch = epoch
        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
              f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if best_reward >= env.spec.reward_threshold:
            break
    assert best_reward >= env.spec.reward_threshold
    training_collector.close()

    def stop_fn(x):
        return x >= env.spec.reward_threshold

    # trainer
    train_step, train_episode, test_step, test_episode, best_rew, duration = \
        episodic_trainer(
            policy, train_collector, test_collector, args.epoch,
            args.step_per_epoch, args.collect_per_step, args.test_num,
            args.batch_size, stop_fn=stop_fn, writer=writer)
    assert stop_fn(best_rew)
    train_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        train_cnt = training_collector.collect_step
        test_cnt = test_collector.collect_step
        duration = time.time() - start_time
        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
              f'in {duration:.2f}s, '
              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
        print(f'Collect {train_step} frame / {train_episode} episode during '
              f'training and {test_step} frame / {test_episode} episode during'
              f' test in {duration:.2f}s, best_reward: {best_rew}, speed: '
              f'{(train_step + test_step) / duration:.2f}it/s')
        # Let's watch its performance!
        env = gym.make(args.task)
        test_collector = Collector(policy, env)
        result = test_collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
        test_collector.close()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
        collector.close()


if __name__ == '__main__':
tianshou/__init__.py
@@ -1,4 +1,5 @@
from tianshou import data, env, utils, policy, exploration
from tianshou import data, env, utils, policy, trainer,\
    exploration

__version__ = '0.2.0'
__all__ = [
@@ -6,5 +7,6 @@ __all__ = [
    'data',
    'utils',
    'policy',
    'trainer',
    'exploration',
]
tianshou/data/collector.py
@@ -16,6 +16,7 @@ class Collector(object):
        self.env = env
        self.env_num = 1
        self.collect_step = 0
        self.collect_episode = 0
        self.buffer = buffer
        self.policy = policy
        self.process_fn = policy.process_fn
@@ -39,9 +40,8 @@ class Collector(object):
        self.reset_buffer()
        # state over batch is either a list, an np.ndarray, or a torch.Tensor
        self.state = None
        self.stat_reward = MovAvg(stat_size)
        self.stat_length = MovAvg(stat_size)
        self.stat_speed = MovAvg(stat_size)
        self.step_speed = MovAvg(stat_size)
        self.episode_speed = MovAvg(stat_size)

    def reset_buffer(self):
        if self._multi_buf:
@@ -81,11 +81,12 @@ class Collector(object):

    def collect(self, n_step=0, n_episode=0, render=0):
        start_time = time.time()
        start_step = self.collect_step
        assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
            "One and only one collection number specification permitted!"
        cur_step = 0
        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
        reward_sum = 0
        length_sum = 0
        while True:
            if self._multi_env:
                batch_data = Batch(
@@ -126,20 +127,17 @@ class Collector(object):
                    elif self._multi_buf:
                        self.buffer[i].add(**data)
                        cur_step += 1
                        self.collect_step += 1
                    else:
                        self.buffer.add(**data)
                        cur_step += 1
                        self.collect_step += 1
                    if self._done[i]:
                        cur_episode[i] += 1
                        self.stat_reward.add(self.reward[i])
                        self.stat_length.add(self.length[i])
                        reward_sum += self.reward[i]
                        length_sum += self.length[i]
                        self.reward[i], self.length[i] = 0, 0
                        if self._cached_buf:
                            self.buffer.update(self._cached_buf[i])
                            cur_step += len(self._cached_buf[i])
                            self.collect_step += len(self._cached_buf[i])
                            self._cached_buf[i].reset()
                        if isinstance(self.state, list):
                            self.state[i] = None
@@ -158,11 +156,10 @@ class Collector(object):
                    self._obs, self._act[0], self._rew,
                    self._done, obs_next, self._info)
                cur_step += 1
                self.collect_step += 1
                if self._done:
                    cur_episode += 1
                    self.stat_reward.add(self.reward)
                    self.stat_length.add(self.length)
                    reward_sum += self.reward
                    length_sum += self.length
                    self.reward, self.length = 0, 0
                    self.state = None
                    self._obs = self.env.reset()
@@ -172,16 +169,20 @@ class Collector(object):
                break
            self._obs = obs_next
        self._obs = obs_next
        self.stat_speed.add((self.collect_step - start_step) / (
            time.time() - start_time))
        if self._multi_env:
            cur_episode = sum(cur_episode)
        duration = time.time() - start_time
        self.step_speed.add(cur_step / duration)
        self.episode_speed.add(cur_episode / duration)
        self.collect_step += cur_step
        self.collect_episode += cur_episode
        return {
            'reward': self.stat_reward.get(),
            'length': self.stat_length.get(),
            'speed': self.stat_speed.get(),
            'n_episode': cur_episode,
            'n_step': cur_step,
            'n/ep': cur_episode,
            'n/st': cur_step,
            'speed/st': self.step_speed.get(),
            'speed/ep': self.episode_speed.get(),
            'rew': reward_sum / cur_episode,
            'len': length_sum / cur_episode,
        }

    def sample(self, batch_size):
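The keys returned by Collector.collect() change in the hunk above: the moving-average 'reward'/'length'/'speed' entries are replaced by per-call counters and per-call means. A rough sketch of the consumer side, with the key meanings spelled out; the loop shape follows tianshou/trainer/step.py later in this commit, and collect_per_step / batch_size stand in for the trainer arguments:

result = train_collector.collect(n_step=collect_per_step)
# result['n/st']     environment steps taken during this call
# result['n/ep']     episodes finished during this call
# result['rew']      mean reward of the episodes finished during this call
# result['len']      mean length of those episodes
# result['speed/st'], result['speed/ep']   moving-average collection speeds
for _ in range(result['n/st'] // collect_per_step):
    losses = policy.learn(train_collector.sample(batch_size))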
tianshou/policy/__init__.py
@@ -3,11 +3,13 @@ from tianshou.policy.dqn import DQNPolicy
from tianshou.policy.pg import PGPolicy
from tianshou.policy.a2c import A2CPolicy
from tianshou.policy.ddpg import DDPGPolicy
from tianshou.policy.ppo import PPOPolicy

__all__ = [
    'BasePolicy',
    'DQNPolicy',
    'PGPolicy',
    'A2CPolicy',
    'DDPGPolicy'
    'DDPGPolicy',
    'PPOPolicy',
]
tianshou/policy/a2c.py
@@ -9,20 +9,23 @@ from tianshou.policy import PGPolicy
class A2CPolicy(PGPolicy):
    """docstring for A2CPolicy"""

    def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
    def __init__(self, actor, critic, optim,
                 dist_fn=torch.distributions.Categorical,
                 discount_factor=0.99, vf_coef=.5, entropy_coef=.01,
                 max_grad_norm=None):
        super().__init__(model, optim, dist_fn, discount_factor)
        super().__init__(None, optim, dist_fn, discount_factor)
        self.actor = actor
        self.critic = critic
        self._w_value = vf_coef
        self._w_entropy = entropy_coef
        self._grad_norm = max_grad_norm

    def __call__(self, batch, state=None):
        logits, value, h = self.model(batch.obs, state=state, info=batch.info)
        logits, h = self.actor(batch.obs, state=state, info=batch.info)
        logits = F.softmax(logits, dim=1)
        dist = self.dist_fn(logits)
        act = dist.sample()
        return Batch(logits=logits, act=act, state=h, dist=dist, value=value)
        return Batch(logits=logits, act=act, state=h, dist=dist)

    def learn(self, batch, batch_size=None):
        losses = []
@@ -30,7 +33,7 @@ class A2CPolicy(PGPolicy):
            self.optim.zero_grad()
            result = self(b)
            dist = result.dist
            v = result.value
            v = self.critic(b.obs)
            a = torch.tensor(b.act, device=dist.logits.device)
            r = torch.tensor(b.returns, device=dist.logits.device)
            actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
@@ -45,4 +48,4 @@ class A2CPolicy(PGPolicy):
                    self.model.parameters(), max_norm=self._grad_norm)
            self.optim.step()
            losses.append(loss.detach().cpu().numpy())
        return losses
        return {'loss': losses}
tianshou/policy/base.py
@@ -18,6 +18,7 @@ class BasePolicy(ABC, nn.Module):

    @abstractmethod
    def learn(self, batch, batch_size=None):
        # return a dict which includes loss and its name
        pass

    def sync_weight(self):
tianshou/policy/ddpg.py
@@ -28,7 +28,7 @@ class DDPGPolicy(BasePolicy):
        self._tau = tau
        assert 0 < gamma <= 1, 'gamma should in (0, 1]'
        self._gamma = gamma
        assert 0 <= exploration_noise, 'noise should greater than zero'
        assert 0 <= exploration_noise, 'noise should not be negative'
        self._eps = exploration_noise
        self._range = action_range
        # self.noise = OUNoise()
@@ -87,5 +87,8 @@ class DDPGPolicy(BasePolicy):
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()
        return actor_loss.detach().cpu().numpy(),\
            critic_loss.detach().cpu().numpy()
        self.sync_weight()
        return {
            'loss/actor': actor_loss.detach().cpu().numpy(),
            'loss/critic': critic_loss.detach().cpu().numpy(),
        }
tianshou/policy/dqn.py
@@ -93,4 +93,4 @@ class DQNPolicy(BasePolicy):
        loss = F.mse_loss(q, r)
        loss.backward()
        self.optim.step()
        return loss.detach().cpu().numpy()
        return {'loss': loss.detach().cpu().numpy()}
tianshou/policy/pg.py
@@ -45,7 +45,7 @@ class PGPolicy(BasePolicy):
            loss.backward()
            self.optim.step()
            losses.append(loss.detach().cpu().numpy())
        return losses
        return {'loss': losses}

    def _vanilla_returns(self, batch):
        returns = batch.rew[:]
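With the policy hunks above, every learn() now returns a dict of named losses ({'loss': ...} for DQN/PG/A2C, {'loss/actor': ..., 'loss/critic': ...} for DDPG) instead of a bare array or tuple. That is what lets the new trainers log losses generically. A minimal sketch of that consumer loop, mirroring the trainer code added below (here stat is assumed to be a dict of MovAvg instances and writer a SummaryWriter):

losses = policy.learn(train_collector.sample(batch_size))
for k in losses.keys():
    if stat.get(k) is None:
        stat[k] = MovAvg()      # one moving average per named loss
    stat[k].add(losses[k])
    if writer:
        writer.add_scalar(k, stat[k].get(), global_step=global_step)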
50 tianshou/policy/ppo.py (new file)
@@ -0,0 +1,50 @@
import torch
from torch import nn
import torch.nn.functional as F

from tianshou.data import Batch
from tianshou.policy import PGPolicy


class PPOPolicy(PGPolicy):
    """docstring for PPOPolicy"""

    def __init__(self, actor, actor_optim,
                 critic, critic_optim,
                 dist_fn=torch.distributions.Categorical,
                 discount_factor=0.99, vf_coef=.5, entropy_coef=.01,
                 eps_clip=None):
        super().__init__(None, None, dist_fn, discount_factor)
        self._w_value = vf_coef
        self._w_entropy = entropy_coef
        self._eps_clip = eps_clip

    def __call__(self, batch, state=None):
        logits, value, h = self.model(batch.obs, state=state, info=batch.info)
        logits = F.softmax(logits, dim=1)
        dist = self.dist_fn(logits)
        act = dist.sample()
        return Batch(logits=logits, act=act, state=h, dist=dist, value=value)

    def learn(self, batch, batch_size=None):
        losses = []
        for b in batch.split(batch_size):
            self.optim.zero_grad()
            result = self(b)
            dist = result.dist
            v = result.value
            a = torch.tensor(b.act, device=dist.logits.device)
            r = torch.tensor(b.returns, device=dist.logits.device)
            actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
            critic_loss = F.mse_loss(r[:, None], v)
            entropy_loss = dist.entropy().mean()
            loss = actor_loss \
                + self._w_value * critic_loss \
                - self._w_entropy * entropy_loss
            loss.backward()
            if self._grad_norm:
                nn.utils.clip_grad_norm_(
                    self.model.parameters(), max_norm=self._grad_norm)
            self.optim.step()
            losses.append(loss.detach().cpu().numpy())
        return losses
7 tianshou/trainer/__init__.py (new file)
@@ -0,0 +1,7 @@
from tianshou.trainer.episodic import episodic_trainer
from tianshou.trainer.step import step_trainer

__all__ = [
    'episodic_trainer',
    'step_trainer',
]
68 tianshou/trainer/episodic.py (new file)
@@ -0,0 +1,68 @@
import time
import tqdm

from tianshou.utils import tqdm_config, MovAvg


def episodic_trainer(policy, train_collector, test_collector, max_epoch,
                     step_per_epoch, collect_per_step, episode_per_test,
                     batch_size, train_fn=None, test_fn=None, stop_fn=None,
                     writer=None, verbose=True):
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(
                total=step_per_epoch, desc=f'Epoch #{epoch}',
                **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_episode=collect_per_step)
                losses = policy.learn(train_collector.sample(0), batch_size)
                train_collector.reset_buffer()
                step = 1
                data = {}
                for k in losses.keys():
                    if isinstance(losses[k], list):
                        step = max(step, len(losses[k]))
                global_step += step
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer:
                        writer.add_scalar(
                            k, result[k], global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer:
                        writer.add_scalar(
                            k, stat[k].get(), global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # eval
        test_collector.reset_env()
        test_collector.reset_buffer()
        policy.eval()
        if test_fn:
            test_fn(epoch)
        result = test_collector.collect(n_episode=episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn(best_reward):
            break
    duration = time.time() - start_time
    return train_collector.collect_step, train_collector.collect_episode,\
        test_collector.collect_step, test_collector.collect_episode,\
        best_reward, duration
66 tianshou/trainer/step.py (new file)
@@ -0,0 +1,66 @@
import time
import tqdm

from tianshou.utils import tqdm_config, MovAvg


def step_trainer(policy, train_collector, test_collector, max_epoch,
                 step_per_epoch, collect_per_step, episode_per_test,
                 batch_size, train_fn=None, test_fn=None, stop_fn=None,
                 writer=None, verbose=True):
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(
                total=step_per_epoch, desc=f'Epoch #{epoch}',
                **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_step=collect_per_step)
                for i in range(min(
                        result['n/st'] // collect_per_step,
                        t.total - t.n)):
                    global_step += 1
                    losses = policy.learn(train_collector.sample(batch_size))
                    data = {}
                    for k in result.keys():
                        data[k] = f'{result[k]:.2f}'
                        if writer:
                            writer.add_scalar(
                                k, result[k], global_step=global_step)
                    for k in losses.keys():
                        if stat.get(k) is None:
                            stat[k] = MovAvg()
                        stat[k].add(losses[k])
                        data[k] = f'{stat[k].get():.6f}'
                        if writer:
                            writer.add_scalar(
                                k, stat[k].get(), global_step=global_step)
                    t.update(1)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # eval
        test_collector.reset_env()
        test_collector.reset_buffer()
        policy.eval()
        if test_fn:
            test_fn(epoch)
        result = test_collector.collect(n_episode=episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn(best_reward):
            break
    duration = time.time() - start_time
    return train_collector.collect_step, train_collector.collect_episode,\
        test_collector.collect_step, test_collector.collect_episode,\
        best_reward, duration
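Both trainers share the same signature and callback hooks: train_fn(epoch) runs before collecting in each epoch, test_fn(epoch) runs before each evaluation pass, and stop_fn(best_reward) ends training early (the trainers call it unconditionally, so the tests above always pass one). A minimal sketch of the hooks, reusing the epsilon scheduling from test_dqn.py above; args, policy, env, the collectors, and writer are assumed to come from the surrounding test:

def train_fn(epoch):
    # epsilon-greedy exploration during training (from test_dqn.py)
    policy.set_eps(args.eps_train)

def test_fn(epoch):
    # near-greedy behaviour during evaluation
    policy.set_eps(args.eps_test)

def stop_fn(best_reward):
    # both trainers break out of the epoch loop once this returns True
    return best_reward >= env.spec.reward_threshold

train_step, train_episode, test_step, test_episode, best_rew, duration = \
    step_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, writer=writer)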