fix bug in test

Trinkle23897 2020-03-17 15:16:30 +08:00
parent 39de63592f
commit fd621971e5
4 changed files with 39 additions and 30 deletions

View File

@@ -39,7 +39,7 @@ def get_args():
     parser.add_argument('--eps-test', type=float, default=0.05)
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--n-step', type=int, default=1)
     parser.add_argument('--epoch', type=int, default=100)
@@ -48,7 +48,7 @@ def get_args():
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=3)
     parser.add_argument('--training-num', type=int, default=8)
-    parser.add_argument('--test-num', type=int, default=20)
+    parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument(
         '--device', type=str,
@@ -99,26 +99,31 @@ def test_dqn(args=get_args()):
         policy.train()
         policy.sync_weight()
         policy.set_eps(args.eps_train)
-        with tqdm.trange(
-                0, args.step_per_epoch, desc=desc, **tqdm_config) as t:
-            for _ in t:
+        with tqdm.tqdm(
+                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
+            while t.n < t.total:
                 result = training_collector.collect(
                     n_step=args.collect_per_step)
-                global_step += 1
-                loss = policy.learn(training_collector.sample(args.batch_size))
-                stat_loss.add(loss)
-                writer.add_scalar(
-                    'reward', result['reward'], global_step=global_step)
-                writer.add_scalar(
-                    'length', result['length'], global_step=global_step)
-                writer.add_scalar(
-                    'loss', stat_loss.get(), global_step=global_step)
-                writer.add_scalar(
-                    'speed', result['speed'], global_step=global_step)
-                t.set_postfix(loss=f'{stat_loss.get():.6f}',
-                              reward=f'{result["reward"]:.6f}',
-                              length=f'{result["length"]:.2f}',
-                              speed=f'{result["speed"]:.2f}')
+                for i in range(min(
+                        result['n_step'] // args.collect_per_step,
+                        t.total - t.n)):
+                    t.update(1)
+                    global_step += 1
+                    loss = policy.learn(
+                        training_collector.sample(args.batch_size))
+                    stat_loss.add(loss)
+                    writer.add_scalar(
+                        'reward', result['reward'], global_step=global_step)
+                    writer.add_scalar(
+                        'length', result['length'], global_step=global_step)
+                    writer.add_scalar(
+                        'loss', stat_loss.get(), global_step=global_step)
+                    writer.add_scalar(
+                        'speed', result['speed'], global_step=global_step)
+                    t.set_postfix(loss=f'{stat_loss.get():.6f}',
+                                  reward=f'{result["reward"]:.6f}',
+                                  length=f'{result["length"]:.2f}',
+                                  speed=f'{result["speed"]:.2f}')
         # eval
         test_collector.reset_env()
         test_collector.reset_buffer()
@@ -144,7 +149,7 @@ def test_dqn(args=get_args()):
               f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
     # Let's watch its performance!
     env = gym.make(args.task)
-    test_collector = Collector(policy, env, ReplayBuffer(1))
+    test_collector = Collector(policy, env)
     result = test_collector.collect(n_episode=1, render=1 / 35)
     print(f'Final reward: {result["reward"]}, length: {result["length"]}')
     test_collector.close()
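Note on the loop change above: the old tqdm.trange version advanced the progress bar once per collect() call, so an epoch ended after step_per_epoch calls rather than step_per_epoch collected steps. The rewrite credits the bar by result['n_step'] // args.collect_per_step and runs one gradient update per increment. A minimal, self-contained sketch of that accounting (fake_collect is a stand-in for Collector.collect, not Tianshou's API):

import random
import tqdm

step_per_epoch, collect_per_step = 1000, 10

def fake_collect(n_step):
    # stand-in for training_collector.collect(): with parallel envs the number
    # of transitions actually gathered may exceed the requested n_step
    return {'n_step': n_step + random.randint(0, 7)}

global_step = 0
with tqdm.tqdm(total=step_per_epoch, desc='Epoch #1') as t:
    while t.n < t.total:
        result = fake_collect(collect_per_step)
        # credit the bar once per collect_per_step gathered transitions,
        # capped so it never runs past the epoch budget
        for _ in range(min(result['n_step'] // collect_per_step,
                           t.total - t.n)):
            t.update(1)
            global_step += 1
            # one policy.learn(...) gradient update would run here
print('finished epoch after', global_step, 'updates')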

View File

@@ -94,15 +94,15 @@ def get_args():
     parser.add_argument('--task', type=str, default='CartPole-v0')
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=320)
-    parser.add_argument('--collect-per-step', type=int, default=5)
+    parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=3)
     parser.add_argument('--training-num', type=int, default=8)
-    parser.add_argument('--test-num', type=int, default=20)
+    parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument(
         '--device', type=str,
@@ -157,6 +157,7 @@ def test_pg(args=get_args()):
                     n_episode=args.collect_per_step)
                 losses = policy.learn(
                     training_collector.sample(0), args.batch_size)
+                training_collector.reset_buffer()
                 global_step += len(losses)
                 t.update(len(losses))
                 stat_loss.add(losses)
@@ -196,7 +197,7 @@ def test_pg(args=get_args()):
               f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
     # Let's watch its performance!
     env = gym.make(args.task)
-    test_collector = Collector(policy, env, ReplayBuffer(1))
+    test_collector = Collector(policy, env)
     result = test_collector.collect(n_episode=1, render=1 / 35)
     print(f'Final reward: {result["reward"]}, length: {result["length"]}')
     test_collector.close()
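Note on training_collector.reset_buffer(): clearing the buffer right after policy.learn(...) keeps the policy-gradient update on-policy, so each gradient step only ever sees data gathered by the current policy. A rough, self-contained sketch of that collect/learn/reset cycle (the Buffer class is a toy stand-in, not Tianshou's ReplayBuffer):

class Buffer:
    # toy stand-in for the training collector's buffer
    def __init__(self):
        self.data = []

    def add(self, transition):
        self.data.append(transition)

    def sample(self, batch_size):
        # batch_size == 0 mirrors sample(0) above: return everything collected
        return list(self.data) if batch_size == 0 else self.data[:batch_size]

    def reset(self):
        self.data = []

buf = Buffer()
for update in range(3):
    for step in range(5):
        buf.add(('obs', 'act', 'rew'))   # pretend transitions from the env
    batch = buf.sample(0)                # learn on all freshly collected data
    # ... the policy-gradient step on batch would go here ...
    buf.reset()                          # drop the now-stale on-policy data
    print(f'update {update}: trained on {len(batch)} transitions')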

View File

@@ -11,7 +11,7 @@ from tianshou.utils import MovAvg
 class Collector(object):
     """docstring for Collector"""
 
-    def __init__(self, policy, env, buffer, stat_size=100):
+    def __init__(self, policy, env, buffer=ReplayBuffer(20000), stat_size=100):
         super().__init__()
         self.env = env
         self.env_num = 1
@@ -168,10 +168,14 @@ class Collector(object):
         self._obs = obs_next
         self.stat_speed.add((self.collect_step - start_step) / (
             time.time() - start_time))
+        if self._multi_env:
+            cur_episode = sum(cur_episode)
         return {
             'reward': self.stat_reward.get(),
             'length': self.stat_length.get(),
             'speed': self.stat_speed.get(),
+            'n_episode': cur_episode,
+            'n_step': cur_step,
         }
 
     def sample(self, batch_size):
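Note on the enriched return value: collect() now also reports how much work it actually did (n_episode, n_step), and with several parallel environments the per-env episode counters are collapsed into a single total; the DQN loop above uses n_step to decide how far to advance its progress bar. A small sketch of that bookkeeping, with made-up numbers in place of the collector's internals:

# per-env counts of episodes finished in one collect() call (made-up values)
cur_episode = [2, 0, 1, 3]
cur_step = 64
multi_env = len(cur_episode) > 1

if multi_env:
    # report one total instead of a per-env list, as in the diff above
    cur_episode = sum(cur_episode)

result = {
    'reward': 195.0,   # placeholder moving-average statistics
    'length': 180.0,
    'speed': 1200.0,
    'n_episode': cur_episode,
    'n_step': cur_step,
}
print(result['n_episode'], result['n_step'])  # -> 6 64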

View File

@@ -2,7 +2,6 @@ import torch
 import numpy as np
 from torch import nn
 import torch.nn.functional as F
-from torch.distributions import Categorical
 
 from tianshou.data import Batch
 from tianshou.policy import BasePolicy
@@ -11,12 +10,12 @@ from tianshou.policy import BasePolicy
 class PGPolicy(BasePolicy, nn.Module):
     """docstring for PGPolicy"""
 
-    def __init__(self, model, optim, dist=Categorical,
+    def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
                  discount_factor=0.99, normalized_reward=True):
         super().__init__()
         self.model = model
         self.optim = optim
-        self.dist = dist
+        self.dist_fn = dist_fn
         self._eps = np.finfo(np.float32).eps.item()
         assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
         self._gamma = discount_factor
@@ -35,7 +34,7 @@ class PGPolicy(BasePolicy, nn.Module):
     def __call__(self, batch, state=None):
         logits, h = self.model(batch.obs, state=state, info=batch.info)
         logits = F.softmax(logits, dim=1)
-        dist = self.dist(logits)
+        dist = self.dist_fn(logits)
         act = dist.sample().detach().cpu().numpy()
         return Batch(logits=logits, act=act, state=h, dist=dist)
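Note on dist -> dist_fn: the renamed argument is a callable that builds a distribution from the (already softmaxed) model output, so Categorical can be swapped for any constructor with the same call shape. A quick sketch of how the stored callable is used, with a hand-written tensor standing in for the network output:

import torch
import torch.nn.functional as F

dist_fn = torch.distributions.Categorical   # any distribution constructor works

logits = torch.tensor([[1.0, 2.0, 0.5]])    # fake network output, batch of one
probs = F.softmax(logits, dim=1)            # __call__ softmaxes before dist_fn
dist = dist_fn(probs)                       # a Categorical over three actions
act = dist.sample().detach().cpu().numpy()
print(act, dist.log_prob(torch.as_tensor(act)))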