add trainer

Trinkle23897 2020-03-19 17:23:46 +08:00
parent 9c5417dd51
commit c87fe3c18c
16 changed files with 371 additions and 309 deletions

View File

@@ -1,6 +1,4 @@
 import gym
-import time
-import tqdm
 import torch
 import argparse
 import numpy as np
@@ -9,12 +7,12 @@ from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import A2CPolicy
 from tianshou.env import SubprocVectorEnv
-from tianshou.utils import tqdm_config, MovAvg
+from tianshou.trainer import episodic_trainer
 from tianshou.data import Collector, ReplayBuffer


 class Net(nn.Module):
-    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
+    def __init__(self, layer_num, state_shape, device='cpu'):
         super().__init__()
         self.device = device
         self.model = [
@@ -22,18 +20,40 @@ class Net(nn.Module):
             nn.ReLU(inplace=True)]
         for i in range(layer_num):
             self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
-        self.actor = self.model + [nn.Linear(128, np.prod(action_shape))]
-        self.critic = self.model + [nn.Linear(128, 1)]
-        self.actor = nn.Sequential(*self.actor)
-        self.critic = nn.Sequential(*self.critic)
+        self.model = nn.Sequential(*self.model)

-    def forward(self, s, **kwargs):
+    def forward(self, s):
         s = torch.tensor(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         s = s.view(batch, -1)
-        logits = self.actor(s)
-        value = self.critic(s)
-        return logits, value, None
+        logits = self.model(s)
+        return logits
+
+
+class Actor(nn.Module):
+    def __init__(self, preprocess_net, action_shape):
+        super().__init__()
+        self.model = nn.Sequential(*[
+            preprocess_net,
+            nn.Linear(128, np.prod(action_shape)),
+        ])
+
+    def forward(self, s, **kwargs):
+        logits = self.model(s)
+        return logits, None
+
+
+class Critic(nn.Module):
+    def __init__(self, preprocess_net):
+        super().__init__()
+        self.model = nn.Sequential(*[
+            preprocess_net,
+            nn.Linear(128, 1),
+        ])
+
+    def forward(self, s):
+        logits = self.model(s)
+        return logits


 def get_args():
@@ -80,83 +100,45 @@ def test_a2c(args=get_args()):
     train_envs.seed(args.seed)
     test_envs.seed(args.seed)
     # model
-    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
-    net = net.to(args.device)
-    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
+    net = Net(args.layer_num, args.state_shape, args.device)
+    actor = Actor(net, args.action_shape).to(args.device)
+    critic = Critic(net).to(args.device)
+    optim = torch.optim.Adam(list(
+        actor.parameters()) + list(critic.parameters()), lr=args.lr)
     dist = torch.distributions.Categorical
     policy = A2CPolicy(
-        net, optim, dist, args.gamma,
-        vf_coef=args.vf_coef,
-        entropy_coef=args.entropy_coef,
-        max_grad_norm=args.max_grad_norm)
+        actor, critic, optim, dist, args.gamma, vf_coef=args.vf_coef,
+        entropy_coef=args.entropy_coef, max_grad_norm=args.max_grad_norm)
     # collector
-    training_collector = Collector(
+    train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs, stat_size=args.test_num)
     # log
-    stat_loss = MovAvg()
-    global_step = 0
     writer = SummaryWriter(args.logdir)
-    best_epoch = -1
-    best_reward = -1e10
-    start_time = time.time()
-    for epoch in range(1, 1 + args.epoch):
-        desc = f'Epoch #{epoch}'
-        # train
-        policy.train()
-        with tqdm.tqdm(
-                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
-            while t.n < t.total:
-                result = training_collector.collect(
-                    n_episode=args.collect_per_step)
-                losses = policy.learn(
-                    training_collector.sample(0), args.batch_size)
-                training_collector.reset_buffer()
-                global_step += len(losses)
-                t.update(len(losses))
-                stat_loss.add(losses)
-                writer.add_scalar(
-                    'reward', result['reward'], global_step=global_step)
-                writer.add_scalar(
-                    'length', result['length'], global_step=global_step)
-                writer.add_scalar(
-                    'loss', stat_loss.get(), global_step=global_step)
-                writer.add_scalar(
-                    'speed', result['speed'], global_step=global_step)
-                t.set_postfix(loss=f'{stat_loss.get():.6f}',
-                              reward=f'{result["reward"]:.6f}',
-                              length=f'{result["length"]:.2f}',
-                              speed=f'{result["speed"]:.2f}')
-            if t.n <= t.total:
-                t.update()
-        # eval
-        test_collector.reset_env()
-        test_collector.reset_buffer()
-        policy.eval()
-        result = test_collector.collect(n_episode=args.test_num)
-        if best_reward < result['reward']:
-            best_reward = result['reward']
-            best_epoch = epoch
-        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
-              f'best_reward: {best_reward:.6f} in #{best_epoch}')
-        if best_reward >= env.spec.reward_threshold:
-            break
-    assert best_reward >= env.spec.reward_threshold
-    training_collector.close()
+
+    def stop_fn(x):
+        return x >= env.spec.reward_threshold
+
+    # trainer
+    train_step, train_episode, test_step, test_episode, best_rew, duration = \
+        episodic_trainer(
+            policy, train_collector, test_collector, args.epoch,
+            args.step_per_epoch, args.collect_per_step, args.test_num,
+            args.batch_size, stop_fn=stop_fn, writer=writer)
+    assert stop_fn(best_rew)
+    train_collector.close()
     test_collector.close()
     if __name__ == '__main__':
-        train_cnt = training_collector.collect_step
-        test_cnt = test_collector.collect_step
-        duration = time.time() - start_time
-        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
-              f'in {duration:.2f}s, '
-              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        print(f'Collect {train_step} frame / {train_episode} episode during '
+              f'training and {test_step} frame / {test_episode} episode during'
+              f' test in {duration:.2f}s, best_reward: {best_rew}, speed: '
+              f'{(train_step + test_step) / duration:.2f}it/s')
         # Let's watch its performance!
         env = gym.make(args.task)
-        test_collector = Collector(policy, env)
-        result = test_collector.collect(n_episode=1, render=1 / 35)
-        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
-        test_collector.close()
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=1 / 35)
+        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
+        collector.close()


 if __name__ == '__main__':
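
Note on the refactor in this file: `Net` is now only a shared feature trunk, while `Actor` and `Critic` wrap it with their own output heads, so a single optimizer built over both heads' parameter lists also trains the shared hidden layers. A hypothetical sanity check (not part of the commit), assuming the `Net`, `Actor` and `Critic` classes from the diff above are in scope:

# Sketch only: verify the trunk's parameters appear in both heads,
# so one optimizer over actor + critic parameters updates the shared layers.
net = Net(layer_num=1, state_shape=(4,))
actor, critic = Actor(net, action_shape=2), Critic(net)
trunk = {id(p) for p in net.parameters()}
assert trunk <= {id(p) for p in actor.parameters()}
assert trunk <= {id(p) for p in critic.parameters()}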

View File

@@ -1,6 +1,4 @@
 import gym
-import time
-import tqdm
 import torch
 import argparse
 import numpy as np
@@ -8,7 +6,7 @@ from torch import nn
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import DDPGPolicy
-from tianshou.utils import tqdm_config, MovAvg
+from tianshou.trainer import step_trainer
 from tianshou.data import Collector, ReplayBuffer
 from tianshou.env import VectorEnv, SubprocVectorEnv
@@ -121,85 +119,39 @@ def test_ddpg(args=get_args()):
         [env.action_space.low[0], env.action_space.high[0]],
         args.tau, args.gamma, args.exploration_noise)
     # collector
-    training_collector = Collector(
+    train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size), 1)
     test_collector = Collector(policy, test_envs, stat_size=args.test_num)
     # log
-    stat_a_loss = MovAvg()
-    stat_c_loss = MovAvg()
-    global_step = 0
     writer = SummaryWriter(args.logdir)
-    best_epoch = -1
-    best_reward = -1e10
-    start_time = time.time()
-    # training_collector.collect(n_step=1000)
-    for epoch in range(1, 1 + args.epoch):
-        desc = f'Epoch #{epoch}'
-        # train
-        policy.train()
-        with tqdm.tqdm(
-                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
-            while t.n < t.total:
-                result = training_collector.collect(
-                    n_step=args.collect_per_step)
-                for i in range(min(
-                        result['n_step'] // args.collect_per_step,
-                        t.total - t.n)):
-                    t.update(1)
-                    global_step += 1
-                    actor_loss, critic_loss = policy.learn(
-                        training_collector.sample(args.batch_size))
-                    policy.sync_weight()
-                    stat_a_loss.add(actor_loss)
-                    stat_c_loss.add(critic_loss)
-                    writer.add_scalar(
-                        'reward', result['reward'], global_step=global_step)
-                    writer.add_scalar(
-                        'length', result['length'], global_step=global_step)
-                    writer.add_scalar(
-                        'actor_loss', stat_a_loss.get(),
-                        global_step=global_step)
-                    writer.add_scalar(
-                        'critic_loss', stat_a_loss.get(),
-                        global_step=global_step)
-                    writer.add_scalar(
-                        'speed', result['speed'], global_step=global_step)
-                    t.set_postfix(actor_loss=f'{stat_a_loss.get():.6f}',
-                                  critic_loss=f'{stat_c_loss.get():.6f}',
-                                  reward=f'{result["reward"]:.6f}',
-                                  length=f'{result["length"]:.2f}',
-                                  speed=f'{result["speed"]:.2f}')
-            if t.n <= t.total:
-                t.update()
-        # eval
-        test_collector.reset_env()
-        test_collector.reset_buffer()
-        policy.eval()
-        result = test_collector.collect(n_episode=args.test_num)
-        if best_reward < result['reward']:
-            best_reward = result['reward']
-            best_epoch = epoch
-        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
-              f'best_reward: {best_reward:.6f} in #{best_epoch}')
-        if args.task == 'Pendulum-v0' and best_reward >= -250:
-            break
-    if args.task == 'Pendulum-v0':
-        assert best_reward >= -250
-    training_collector.close()
+
+    def stop_fn(x):
+        if args.task == 'Pendulum-v0':
+            return x >= -250
+        else:
+            return False
+
+    # trainer
+    train_step, train_episode, test_step, test_episode, best_rew, duration = \
+        step_trainer(
+            policy, train_collector, test_collector, args.epoch,
+            args.step_per_epoch, args.collect_per_step, args.test_num,
+            args.batch_size, stop_fn=stop_fn, writer=writer)
+    if args.task == 'Pendulum-v0':
+        assert stop_fn(best_rew)
+    train_collector.close()
     test_collector.close()
     if __name__ == '__main__':
-        train_cnt = training_collector.collect_step
-        test_cnt = test_collector.collect_step
-        duration = time.time() - start_time
-        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
-              f'in {duration:.2f}s, '
-              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        print(f'Collect {train_step} frame / {train_episode} episode during '
+              f'training and {test_step} frame / {test_episode} episode during'
+              f' test in {duration:.2f}s, best_reward: {best_rew}, speed: '
+              f'{(train_step + test_step) / duration:.2f}it/s')
         # Let's watch its performance!
         env = gym.make(args.task)
-        test_collector = Collector(policy, env)
-        result = test_collector.collect(n_episode=1, render=1 / 35)
-        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
-        test_collector.close()
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=1 / 35)
+        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
+        collector.close()


 if __name__ == '__main__':

View File

@@ -1,6 +1,4 @@
 import gym
-import time
-import tqdm
 import torch
 import argparse
 import numpy as np
@@ -9,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import DQNPolicy
 from tianshou.env import SubprocVectorEnv
-from tianshou.utils import tqdm_config, MovAvg
+from tianshou.trainer import step_trainer
 from tianshou.data import Collector, ReplayBuffer
@@ -80,79 +78,45 @@ def test_dqn(args=get_args()):
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     policy = DQNPolicy(net, optim, args.gamma, args.n_step)
     # collector
-    training_collector = Collector(
+    train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs, stat_size=args.test_num)
-    training_collector.collect(n_step=args.batch_size)
+    train_collector.collect(n_step=args.batch_size)
     # log
-    stat_loss = MovAvg()
-    global_step = 0
     writer = SummaryWriter(args.logdir)
-    best_epoch = -1
-    best_reward = -1e10
-    start_time = time.time()
-    for epoch in range(1, 1 + args.epoch):
-        desc = f"Epoch #{epoch}"
-        # train
-        policy.train()
-        policy.sync_weight()
-        policy.set_eps(args.eps_train)
-        with tqdm.tqdm(
-                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
-            while t.n < t.total:
-                result = training_collector.collect(
-                    n_step=args.collect_per_step)
-                for i in range(min(
-                        result['n_step'] // args.collect_per_step,
-                        t.total - t.n)):
-                    t.update(1)
-                    global_step += 1
-                    loss = policy.learn(
-                        training_collector.sample(args.batch_size))
-                    stat_loss.add(loss)
-                    writer.add_scalar(
-                        'reward', result['reward'], global_step=global_step)
-                    writer.add_scalar(
-                        'length', result['length'], global_step=global_step)
-                    writer.add_scalar(
-                        'loss', stat_loss.get(), global_step=global_step)
-                    writer.add_scalar(
-                        'speed', result['speed'], global_step=global_step)
-                    t.set_postfix(loss=f'{stat_loss.get():.6f}',
-                                  reward=f'{result["reward"]:.6f}',
-                                  length=f'{result["length"]:.2f}',
-                                  speed=f'{result["speed"]:.2f}')
-            if t.n <= t.total:
-                t.update()
-        # eval
-        test_collector.reset_env()
-        test_collector.reset_buffer()
-        policy.eval()
-        policy.set_eps(args.eps_test)
-        result = test_collector.collect(n_episode=args.test_num)
-        if best_reward < result['reward']:
-            best_reward = result['reward']
-            best_epoch = epoch
-        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
-              f'best_reward: {best_reward:.6f} in #{best_epoch}')
-        if best_reward >= env.spec.reward_threshold:
-            break
-    assert best_reward >= env.spec.reward_threshold
-    training_collector.close()
+
+    def stop_fn(x):
+        return x >= env.spec.reward_threshold
+
+    def train_fn(x):
+        policy.sync_weight()
+        policy.set_eps(args.eps_train)
+
+    def test_fn(x):
+        policy.set_eps(args.eps_test)
+
+    # trainer
+    train_step, train_episode, test_step, test_episode, best_rew, duration = \
+        step_trainer(
+            policy, train_collector, test_collector, args.epoch,
+            args.step_per_epoch, args.collect_per_step, args.test_num,
+            args.batch_size, train_fn=train_fn, test_fn=test_fn,
+            stop_fn=stop_fn, writer=writer)
+    assert stop_fn(best_rew)
+    train_collector.close()
     test_collector.close()
     if __name__ == '__main__':
-        train_cnt = training_collector.collect_step
-        test_cnt = test_collector.collect_step
-        duration = time.time() - start_time
-        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
-              f'in {duration:.2f}s, '
-              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        print(f'Collect {train_step} frame / {train_episode} episode during '
+              f'training and {test_step} frame / {test_episode} episode during'
+              f' test in {duration:.2f}s, best_reward: {best_rew}, speed: '
+              f'{(train_step + test_step) / duration:.2f}it/s')
         # Let's watch its performance!
         env = gym.make(args.task)
-        test_collector = Collector(policy, env)
-        result = test_collector.collect(n_episode=1, render=1 / 35)
-        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
-        test_collector.close()
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=1 / 35)
+        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
+        collector.close()


 if __name__ == '__main__':
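
The `train_fn`/`test_fn` hooks introduced here receive the current epoch, which makes per-epoch schedules easy to express. A hedged sketch (not part of this commit) of a `train_fn` that anneals epsilon instead of keeping it fixed; the 0.5/0.05 endpoints and decay rate are made up for illustration, and `policy` is the DQNPolicy object built in `test_dqn` above:

# Hypothetical epsilon-decay variant of train_fn.
def train_fn(epoch):
    policy.sync_weight()
    eps = max(0.05, 0.5 - 0.045 * (epoch - 1))   # linear anneal, illustrative values
    policy.set_eps(eps)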

View File

@@ -1,6 +1,5 @@
 import gym
 import time
-import tqdm
 import torch
 import argparse
 import numpy as np
@@ -9,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter

 from tianshou.policy import PGPolicy
 from tianshou.env import SubprocVectorEnv
-from tianshou.utils import tqdm_config, MovAvg
+from tianshou.trainer import episodic_trainer
 from tianshou.data import Batch, Collector, ReplayBuffer
@@ -131,73 +130,35 @@ def test_pg(args=get_args()):
     dist = torch.distributions.Categorical
     policy = PGPolicy(net, optim, dist, args.gamma)
     # collector
-    training_collector = Collector(
+    train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
     test_collector = Collector(policy, test_envs, stat_size=args.test_num)
     # log
-    stat_loss = MovAvg()
-    global_step = 0
     writer = SummaryWriter(args.logdir)
-    best_epoch = -1
-    best_reward = -1e10
-    start_time = time.time()
-    for epoch in range(1, 1 + args.epoch):
-        desc = f"Epoch #{epoch}"
-        # train
-        policy.train()
-        with tqdm.tqdm(
-                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
-            while t.n < t.total:
-                result = training_collector.collect(
-                    n_episode=args.collect_per_step)
-                losses = policy.learn(
-                    training_collector.sample(0), args.batch_size)
-                training_collector.reset_buffer()
-                global_step += len(losses)
-                t.update(len(losses))
-                stat_loss.add(losses)
-                writer.add_scalar(
-                    'reward', result['reward'], global_step=global_step)
-                writer.add_scalar(
-                    'length', result['length'], global_step=global_step)
-                writer.add_scalar(
-                    'loss', stat_loss.get(), global_step=global_step)
-                writer.add_scalar(
-                    'speed', result['speed'], global_step=global_step)
-                t.set_postfix(loss=f'{stat_loss.get():.6f}',
-                              reward=f'{result["reward"]:.6f}',
-                              length=f'{result["length"]:.2f}',
-                              speed=f'{result["speed"]:.2f}')
-            if t.n <= t.total:
-                t.update()
-        # eval
-        test_collector.reset_env()
-        test_collector.reset_buffer()
-        policy.eval()
-        result = test_collector.collect(n_episode=args.test_num)
-        if best_reward < result['reward']:
-            best_reward = result['reward']
-            best_epoch = epoch
-        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
-              f'best_reward: {best_reward:.6f} in #{best_epoch}')
-        if best_reward >= env.spec.reward_threshold:
-            break
-    assert best_reward >= env.spec.reward_threshold
-    training_collector.close()
+
+    def stop_fn(x):
+        return x >= env.spec.reward_threshold
+
+    # trainer
+    train_step, train_episode, test_step, test_episode, best_rew, duration = \
+        episodic_trainer(
+            policy, train_collector, test_collector, args.epoch,
+            args.step_per_epoch, args.collect_per_step, args.test_num,
+            args.batch_size, stop_fn=stop_fn, writer=writer)
+    assert stop_fn(best_rew)
+    train_collector.close()
     test_collector.close()
     if __name__ == '__main__':
-        train_cnt = training_collector.collect_step
-        test_cnt = test_collector.collect_step
-        duration = time.time() - start_time
-        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
-              f'in {duration:.2f}s, '
-              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        print(f'Collect {train_step} frame / {train_episode} episode during '
+              f'training and {test_step} frame / {test_episode} episode during'
+              f' test in {duration:.2f}s, best_reward: {best_rew}, speed: '
+              f'{(train_step + test_step) / duration:.2f}it/s')
         # Let's watch its performance!
         env = gym.make(args.task)
-        test_collector = Collector(policy, env)
-        result = test_collector.collect(n_episode=1, render=1 / 35)
-        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
-        test_collector.close()
+        collector = Collector(policy, env)
+        result = collector.collect(n_episode=1, render=1 / 35)
+        print(f'Final reward: {result["rew"]}, length: {result["len"]}')
+        collector.close()


 if __name__ == '__main__':

View File

@@ -1,4 +1,5 @@
-from tianshou import data, env, utils, policy, exploration
+from tianshou import data, env, utils, policy, trainer,\
+    exploration

 __version__ = '0.2.0'
 __all__ = [
@@ -6,5 +7,6 @@ __all__ = [
     'data',
     'utils',
     'policy',
+    'trainer',
     'exploration',
 ]

View File

@@ -16,6 +16,7 @@ class Collector(object):
         self.env = env
         self.env_num = 1
         self.collect_step = 0
+        self.collect_episode = 0
         self.buffer = buffer
         self.policy = policy
         self.process_fn = policy.process_fn
@@ -39,9 +40,8 @@ class Collector(object):
         self.reset_buffer()
         # state over batch is either a list, an np.ndarray, or a torch.Tensor
         self.state = None
-        self.stat_reward = MovAvg(stat_size)
-        self.stat_length = MovAvg(stat_size)
-        self.stat_speed = MovAvg(stat_size)
+        self.step_speed = MovAvg(stat_size)
+        self.episode_speed = MovAvg(stat_size)

     def reset_buffer(self):
         if self._multi_buf:
@@ -81,11 +81,12 @@ class Collector(object):

     def collect(self, n_step=0, n_episode=0, render=0):
         start_time = time.time()
-        start_step = self.collect_step
         assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
             "One and only one collection number specification permitted!"
         cur_step = 0
         cur_episode = np.zeros(self.env_num) if self._multi_env else 0
+        reward_sum = 0
+        length_sum = 0
         while True:
             if self._multi_env:
                 batch_data = Batch(
@@ -126,20 +127,17 @@ class Collector(object):
                     elif self._multi_buf:
                         self.buffer[i].add(**data)
                         cur_step += 1
-                        self.collect_step += 1
                     else:
                         self.buffer.add(**data)
                         cur_step += 1
-                        self.collect_step += 1
                     if self._done[i]:
                         cur_episode[i] += 1
-                        self.stat_reward.add(self.reward[i])
-                        self.stat_length.add(self.length[i])
+                        reward_sum += self.reward[i]
+                        length_sum += self.length[i]
                         self.reward[i], self.length[i] = 0, 0
                         if self._cached_buf:
                             self.buffer.update(self._cached_buf[i])
                             cur_step += len(self._cached_buf[i])
-                            self.collect_step += len(self._cached_buf[i])
                             self._cached_buf[i].reset()
                         if isinstance(self.state, list):
                             self.state[i] = None
@@ -158,11 +156,10 @@ class Collector(object):
                     self._obs, self._act[0], self._rew,
                     self._done, obs_next, self._info)
                 cur_step += 1
-                self.collect_step += 1
                 if self._done:
                     cur_episode += 1
-                    self.stat_reward.add(self.reward)
-                    self.stat_length.add(self.length)
+                    reward_sum += self.reward
+                    length_sum += self.length
                     self.reward, self.length = 0, 0
                     self.state = None
                     self._obs = self.env.reset()
@@ -172,16 +169,20 @@ class Collector(object):
                 break
             self._obs = obs_next
         self._obs = obs_next
-        self.stat_speed.add((self.collect_step - start_step) / (
-            time.time() - start_time))
         if self._multi_env:
             cur_episode = sum(cur_episode)
+        duration = time.time() - start_time
+        self.step_speed.add(cur_step / duration)
+        self.episode_speed.add(cur_episode / duration)
+        self.collect_step += cur_step
+        self.collect_episode += cur_episode
         return {
-            'reward': self.stat_reward.get(),
-            'length': self.stat_length.get(),
-            'speed': self.stat_speed.get(),
-            'n_episode': cur_episode,
-            'n_step': cur_step,
+            'n/ep': cur_episode,
+            'n/st': cur_step,
+            'speed/st': self.step_speed.get(),
+            'speed/ep': self.episode_speed.get(),
+            'rew': reward_sum / cur_episode,
+            'len': length_sum / cur_episode,
         }

     def sample(self, batch_size):
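
With this change, `collect()` reports per-call statistics instead of internal moving averages of reward and length. A rough usage sketch (not part of the commit) of reading the returned dict, assuming `collector` is any Collector instance built as in the test scripts:

# Sketch only: keys as introduced in this diff.
result = collector.collect(n_episode=8)
print(f"collected {result['n/st']} steps over {result['n/ep']} episodes")
print(f"mean episode reward {result['rew']:.2f}, mean length {result['len']:.2f}")
print(f"throughput {result['speed/st']:.1f} step/s, {result['speed/ep']:.2f} episode/s")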

View File

@@ -3,11 +3,13 @@ from tianshou.policy.dqn import DQNPolicy
 from tianshou.policy.pg import PGPolicy
 from tianshou.policy.a2c import A2CPolicy
 from tianshou.policy.ddpg import DDPGPolicy
+from tianshou.policy.ppo import PPOPolicy

 __all__ = [
     'BasePolicy',
     'DQNPolicy',
     'PGPolicy',
     'A2CPolicy',
-    'DDPGPolicy'
+    'DDPGPolicy',
+    'PPOPolicy',
 ]

View File

@@ -9,20 +9,23 @@ from tianshou.policy import PGPolicy
 class A2CPolicy(PGPolicy):
     """docstring for A2CPolicy"""

-    def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
+    def __init__(self, actor, critic, optim,
+                 dist_fn=torch.distributions.Categorical,
                  discount_factor=0.99, vf_coef=.5, entropy_coef=.01,
                  max_grad_norm=None):
-        super().__init__(model, optim, dist_fn, discount_factor)
+        super().__init__(None, optim, dist_fn, discount_factor)
+        self.actor = actor
+        self.critic = critic
         self._w_value = vf_coef
         self._w_entropy = entropy_coef
         self._grad_norm = max_grad_norm

     def __call__(self, batch, state=None):
-        logits, value, h = self.model(batch.obs, state=state, info=batch.info)
+        logits, h = self.actor(batch.obs, state=state, info=batch.info)
         logits = F.softmax(logits, dim=1)
         dist = self.dist_fn(logits)
         act = dist.sample()
-        return Batch(logits=logits, act=act, state=h, dist=dist, value=value)
+        return Batch(logits=logits, act=act, state=h, dist=dist)

     def learn(self, batch, batch_size=None):
         losses = []
@@ -30,7 +33,7 @@ class A2CPolicy(PGPolicy):
             self.optim.zero_grad()
             result = self(b)
             dist = result.dist
-            v = result.value
+            v = self.critic(b.obs)
             a = torch.tensor(b.act, device=dist.logits.device)
             r = torch.tensor(b.returns, device=dist.logits.device)
             actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
@@ -45,4 +48,4 @@ class A2CPolicy(PGPolicy):
                     self.model.parameters(), max_norm=self._grad_norm)
             self.optim.step()
             losses.append(loss.detach().cpu().numpy())
-        return losses
+        return {'loss': losses}

View File

@@ -18,6 +18,7 @@ class BasePolicy(ABC, nn.Module):

     @abstractmethod
    def learn(self, batch, batch_size=None):
+        # return a dict which includes loss and its name
         pass

     def sync_weight(self):
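
The comment added to `BasePolicy.learn` captures the contract the new trainers rely on: every policy returns a dict mapping a name to a scalar or list of losses, and the trainer keys its `MovAvg` statistics and TensorBoard scalars by those names. A hedged sketch of what a custom subclass might return; the class, helper and key names below are hypothetical, but the return shape matches what DQNPolicy, DDPGPolicy and A2CPolicy return in this commit:

class MyPolicy(BasePolicy):               # hypothetical subclass, not in this commit
    def learn(self, batch, batch_size=None):
        loss = self.compute_loss(batch)   # hypothetical helper
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        # each value is fed into a MovAvg keyed by its name and logged under that tag
        return {'loss/total': loss.detach().cpu().numpy()}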

View File

@@ -28,7 +28,7 @@ class DDPGPolicy(BasePolicy):
         self._tau = tau
         assert 0 < gamma <= 1, 'gamma should in (0, 1]'
         self._gamma = gamma
-        assert 0 <= exploration_noise, 'noise should greater than zero'
+        assert 0 <= exploration_noise, 'noise should not be negative'
         self._eps = exploration_noise
         self._range = action_range
         # self.noise = OUNoise()
@@ -87,5 +87,8 @@ class DDPGPolicy(BasePolicy):
         self.actor_optim.zero_grad()
         actor_loss.backward()
         self.actor_optim.step()
-        return actor_loss.detach().cpu().numpy(),\
-            critic_loss.detach().cpu().numpy()
+        self.sync_weight()
+        return {
+            'loss/actor': actor_loss.detach().cpu().numpy(),
+            'loss/critic': critic_loss.detach().cpu().numpy(),
+        }

View File

@@ -93,4 +93,4 @@ class DQNPolicy(BasePolicy):
         loss = F.mse_loss(q, r)
         loss.backward()
         self.optim.step()
-        return loss.detach().cpu().numpy()
+        return {'loss': loss.detach().cpu().numpy()}

View File

@@ -45,7 +45,7 @@ class PGPolicy(BasePolicy):
             loss.backward()
             self.optim.step()
             losses.append(loss.detach().cpu().numpy())
-        return losses
+        return {'loss': losses}

     def _vanilla_returns(self, batch):
         returns = batch.rew[:]

tianshou/policy/ppo.py (new file, 50 lines)
View File

@@ -0,0 +1,50 @@
import torch
from torch import nn
import torch.nn.functional as F

from tianshou.data import Batch
from tianshou.policy import PGPolicy


class PPOPolicy(PGPolicy):
    """docstring for PPOPolicy"""

    def __init__(self, actor, actor_optim,
                 critic, critic_optim,
                 dist_fn=torch.distributions.Categorical,
                 discount_factor=0.99, vf_coef=.5, entropy_coef=.01,
                 eps_clip=None):
        super().__init__(None, None, dist_fn, discount_factor)
        self._w_value = vf_coef
        self._w_entropy = entropy_coef
        self._eps_clip = eps_clip

    def __call__(self, batch, state=None):
        logits, value, h = self.model(batch.obs, state=state, info=batch.info)
        logits = F.softmax(logits, dim=1)
        dist = self.dist_fn(logits)
        act = dist.sample()
        return Batch(logits=logits, act=act, state=h, dist=dist, value=value)

    def learn(self, batch, batch_size=None):
        losses = []
        for b in batch.split(batch_size):
            self.optim.zero_grad()
            result = self(b)
            dist = result.dist
            v = result.value
            a = torch.tensor(b.act, device=dist.logits.device)
            r = torch.tensor(b.returns, device=dist.logits.device)
            actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
            critic_loss = F.mse_loss(r[:, None], v)
            entropy_loss = dist.entropy().mean()
            loss = actor_loss \
                + self._w_value * critic_loss \
                - self._w_entropy * entropy_loss
            loss.backward()
            if self._grad_norm:
                nn.utils.clip_grad_norm_(
                    self.model.parameters(), max_norm=self._grad_norm)
            self.optim.step()
            losses.append(loss.detach().cpu().numpy())
        return losses

View File

@@ -0,0 +1,7 @@
from tianshou.trainer.episodic import episodic_trainer
from tianshou.trainer.step import step_trainer

__all__ = [
    'episodic_trainer',
    'step_trainer',
]
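
The two entry points split along how data is consumed: `episodic_trainer` collects whole episodes and trains on the full on-policy buffer before resetting it (used by the PG and A2C tests above), while `step_trainer` collects a fixed number of environment steps and samples minibatches from the replay buffer (used by the DQN and DDPG tests). A condensed usage sketch, assuming `policy`, the two collectors and a `SummaryWriter` are already built as in the test scripts; the numeric values are illustrative only:

# Sketch of the calling convention shared by both trainers in this commit.
def stop_fn(best_reward):
    return best_reward >= 195        # threshold is illustrative

result = step_trainer(               # or episodic_trainer for on-policy algorithms
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=1000, collect_per_step=10,
    episode_per_test=20, batch_size=64,
    stop_fn=stop_fn, writer=writer)
train_step, train_episode, test_step, test_episode, best_reward, duration = result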

View File

@@ -0,0 +1,68 @@
import time
import tqdm

from tianshou.utils import tqdm_config, MovAvg


def episodic_trainer(policy, train_collector, test_collector, max_epoch,
                     step_per_epoch, collect_per_step, episode_per_test,
                     batch_size, train_fn=None, test_fn=None, stop_fn=None,
                     writer=None, verbose=True):
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(
                total=step_per_epoch, desc=f'Epoch #{epoch}',
                **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_episode=collect_per_step)
                losses = policy.learn(train_collector.sample(0), batch_size)
                train_collector.reset_buffer()
                step = 1
                data = {}
                for k in losses.keys():
                    if isinstance(losses[k], list):
                        step = max(step, len(losses[k]))
                global_step += step
                for k in result.keys():
                    data[k] = f'{result[k]:.2f}'
                    if writer:
                        writer.add_scalar(
                            k, result[k], global_step=global_step)
                for k in losses.keys():
                    if stat.get(k) is None:
                        stat[k] = MovAvg()
                    stat[k].add(losses[k])
                    data[k] = f'{stat[k].get():.6f}'
                    if writer:
                        writer.add_scalar(
                            k, stat[k].get(), global_step=global_step)
                t.update(step)
                t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # eval
        test_collector.reset_env()
        test_collector.reset_buffer()
        policy.eval()
        if test_fn:
            test_fn(epoch)
        result = test_collector.collect(n_episode=episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn(best_reward):
            break
    duration = time.time() - start_time
    return train_collector.collect_step, train_collector.collect_episode,\
        test_collector.collect_step, test_collector.collect_episode,\
        best_reward, duration

tianshou/trainer/step.py (new file, 66 lines)
View File

@@ -0,0 +1,66 @@
import time
import tqdm

from tianshou.utils import tqdm_config, MovAvg


def step_trainer(policy, train_collector, test_collector, max_epoch,
                 step_per_epoch, collect_per_step, episode_per_test,
                 batch_size, train_fn=None, test_fn=None, stop_fn=None,
                 writer=None, verbose=True):
    global_step = 0
    best_epoch, best_reward = -1, -1
    stat = {}
    start_time = time.time()
    for epoch in range(1, 1 + max_epoch):
        # train
        policy.train()
        if train_fn:
            train_fn(epoch)
        with tqdm.tqdm(
                total=step_per_epoch, desc=f'Epoch #{epoch}',
                **tqdm_config) as t:
            while t.n < t.total:
                result = train_collector.collect(n_step=collect_per_step)
                for i in range(min(
                        result['n/st'] // collect_per_step,
                        t.total - t.n)):
                    global_step += 1
                    losses = policy.learn(train_collector.sample(batch_size))
                    data = {}
                    for k in result.keys():
                        data[k] = f'{result[k]:.2f}'
                        if writer:
                            writer.add_scalar(
                                k, result[k], global_step=global_step)
                    for k in losses.keys():
                        if stat.get(k) is None:
                            stat[k] = MovAvg()
                        stat[k].add(losses[k])
                        data[k] = f'{stat[k].get():.6f}'
                        if writer:
                            writer.add_scalar(
                                k, stat[k].get(), global_step=global_step)
                    t.update(1)
                    t.set_postfix(**data)
            if t.n <= t.total:
                t.update()
        # eval
        test_collector.reset_env()
        test_collector.reset_buffer()
        policy.eval()
        if test_fn:
            test_fn(epoch)
        result = test_collector.collect(n_episode=episode_per_test)
        if best_epoch == -1 or best_reward < result['rew']:
            best_reward = result['rew']
            best_epoch = epoch
        if verbose:
            print(f'Epoch #{epoch}: test_reward: {result["rew"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if stop_fn(best_reward):
            break
    duration = time.time() - start_time
    return train_collector.collect_step, train_collector.collect_episode,\
        test_collector.collect_step, test_collector.collect_episode,\
        best_reward, duration