fix bug in test
parent 39de63592f
commit fd621971e5

@@ -39,7 +39,7 @@ def get_args():
     parser.add_argument('--eps-test', type=float, default=0.05)
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--n-step', type=int, default=1)
     parser.add_argument('--epoch', type=int, default=100)
@@ -48,7 +48,7 @@ def get_args():
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=3)
     parser.add_argument('--training-num', type=int, default=8)
-    parser.add_argument('--test-num', type=int, default=20)
+    parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument(
         '--device', type=str,
@@ -99,26 +99,31 @@ def test_dqn(args=get_args()):
         policy.train()
         policy.sync_weight()
         policy.set_eps(args.eps_train)
-        with tqdm.trange(
-                0, args.step_per_epoch, desc=desc, **tqdm_config) as t:
-            for _ in t:
+        with tqdm.tqdm(
+                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
+            while t.n < t.total:
                 result = training_collector.collect(
                     n_step=args.collect_per_step)
-                global_step += 1
-                loss = policy.learn(training_collector.sample(args.batch_size))
-                stat_loss.add(loss)
-                writer.add_scalar(
-                    'reward', result['reward'], global_step=global_step)
-                writer.add_scalar(
-                    'length', result['length'], global_step=global_step)
-                writer.add_scalar(
-                    'loss', stat_loss.get(), global_step=global_step)
-                writer.add_scalar(
-                    'speed', result['speed'], global_step=global_step)
-                t.set_postfix(loss=f'{stat_loss.get():.6f}',
-                              reward=f'{result["reward"]:.6f}',
-                              length=f'{result["length"]:.2f}',
-                              speed=f'{result["speed"]:.2f}')
+                for i in range(min(
+                        result['n_step'] // args.collect_per_step,
+                        t.total - t.n)):
+                    t.update(1)
+                    global_step += 1
+                    loss = policy.learn(
+                        training_collector.sample(args.batch_size))
+                    stat_loss.add(loss)
+                    writer.add_scalar(
+                        'reward', result['reward'], global_step=global_step)
+                    writer.add_scalar(
+                        'length', result['length'], global_step=global_step)
+                    writer.add_scalar(
+                        'loss', stat_loss.get(), global_step=global_step)
+                    writer.add_scalar(
+                        'speed', result['speed'], global_step=global_step)
+                    t.set_postfix(loss=f'{stat_loss.get():.6f}',
+                                  reward=f'{result["reward"]:.6f}',
+                                  length=f'{result["length"]:.2f}',
+                                  speed=f'{result["speed"]:.2f}')
         # eval
         test_collector.reset_env()
         test_collector.reset_buffer()
@@ -144,7 +149,7 @@ def test_dqn(args=get_args()):
          f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
     # Let's watch its performance!
     env = gym.make(args.task)
-    test_collector = Collector(policy, env, ReplayBuffer(1))
+    test_collector = Collector(policy, env)
     result = test_collector.collect(n_episode=1, render=1 / 35)
     print(f'Final reward: {result["reward"]}, length: {result["length"]}')
     test_collector.close()

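The reworked training loop above paces the progress bar by environment steps instead of loop iterations: it keeps collecting until t.n reaches t.total and credits the bar once per collect-per-step chunk of reported steps, running one policy.learn() update per credited step. A minimal, self-contained sketch of that bookkeeping, assuming only tqdm is installed (the constants stand in for the parsed arguments):

import tqdm

step_per_epoch = 320     # stands in for args.step_per_epoch
collect_per_step = 10    # stands in for args.collect_per_step

with tqdm.tqdm(total=step_per_epoch, desc='Epoch #1') as t:
    while t.n < t.total:
        # pretend the collector just reported this many environment steps
        n_step = 25
        # credit at most the remaining budget so the bar never overshoots;
        # in the test, one policy.learn() update runs per t.update(1)
        for _ in range(min(n_step // collect_per_step, t.total - t.n)):
            t.update(1)
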
@@ -94,15 +94,15 @@ def get_args():
     parser.add_argument('--task', type=str, default='CartPole-v0')
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=320)
-    parser.add_argument('--collect-per-step', type=int, default=5)
+    parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=3)
     parser.add_argument('--training-num', type=int, default=8)
-    parser.add_argument('--test-num', type=int, default=20)
+    parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument(
         '--device', type=str,
@@ -157,6 +157,7 @@ def test_pg(args=get_args()):
                     n_episode=args.collect_per_step)
                 losses = policy.learn(
                     training_collector.sample(0), args.batch_size)
+                training_collector.reset_buffer()
                 global_step += len(losses)
                 t.update(len(losses))
                 stat_loss.add(losses)
@@ -196,7 +197,7 @@ def test_pg(args=get_args()):
          f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
     # Let's watch its performance!
     env = gym.make(args.task)
-    test_collector = Collector(policy, env, ReplayBuffer(1))
+    test_collector = Collector(policy, env)
     result = test_collector.collect(n_episode=1, render=1 / 35)
     print(f'Final reward: {result["reward"]}, length: {result["length"]}')
     test_collector.close()

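In the test_pg loop above, training_collector.reset_buffer() is now called right after the policy-gradient update, so the next round of collection starts from empty storage. A toy version of that collect / learn / reset cycle, with a plain Python list standing in for the collector's buffer (the helpers below are illustrative stand-ins, not tianshou's API):

# Illustrative only: a list plays the role of the on-policy buffer.
buffer = []

def collect(n_episode):
    # pretend every episode yields 10 transitions
    buffer.extend({'reward': 1.0} for _ in range(10 * n_episode))

def learn(data, batch_size=64):
    # a real policy would compute one policy-gradient loss per minibatch
    return [0.0] * max(1, len(data) // batch_size)

for epoch in range(3):
    collect(n_episode=10)
    losses = learn(buffer)    # update only on freshly collected data
    buffer.clear()            # mirrors training_collector.reset_buffer()
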
@@ -11,7 +11,7 @@ from tianshou.utils import MovAvg
 class Collector(object):
     """docstring for Collector"""

-    def __init__(self, policy, env, buffer, stat_size=100):
+    def __init__(self, policy, env, buffer=ReplayBuffer(20000), stat_size=100):
         super().__init__()
         self.env = env
         self.env_num = 1
@@ -168,10 +168,14 @@ class Collector(object):
             self._obs = obs_next
+        self.stat_speed.add((self.collect_step - start_step) / (
+            time.time() - start_time))
         if self._multi_env:
             cur_episode = sum(cur_episode)
         return {
             'reward': self.stat_reward.get(),
             'length': self.stat_length.get(),
+            'speed': self.stat_speed.get(),
             'n_episode': cur_episode,
+            'n_step': cur_step,
         }

     def sample(self, batch_size):

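With the signature change above, a Collector can be constructed without an explicit buffer, which is exactly how both tests now build their evaluation collector (Collector(policy, env)). One property of this form worth keeping in mind: Python evaluates default arguments once, when the def statement runs, so every Collector that relies on the default would share a single ReplayBuffer(20000) instance. A small stand-alone demonstration of that language behaviour with made-up classes (not tianshou code):

# Default arguments are evaluated once, at definition time, not per call.
class Box:
    def __init__(self, size):
        self.size = size
        self.items = []

class Store:
    # shaped like Collector(policy, env, buffer=ReplayBuffer(20000), ...)
    def __init__(self, buffer=Box(20000)):
        self.buffer = buffer

a, b = Store(), Store()
print(a.buffer is b.buffer)   # True: both instances share one default Box
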
@@ -2,7 +2,6 @@ import torch
 import numpy as np
 from torch import nn
 import torch.nn.functional as F
-from torch.distributions import Categorical

 from tianshou.data import Batch
 from tianshou.policy import BasePolicy
@@ -11,12 +10,12 @@ from tianshou.policy import BasePolicy
 class PGPolicy(BasePolicy, nn.Module):
     """docstring for PGPolicy"""

-    def __init__(self, model, optim, dist=Categorical,
+    def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
                  discount_factor=0.99, normalized_reward=True):
         super().__init__()
         self.model = model
         self.optim = optim
-        self.dist = dist
+        self.dist_fn = dist_fn
         self._eps = np.finfo(np.float32).eps.item()
         assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
         self._gamma = discount_factor
@@ -35,7 +34,7 @@ class PGPolicy(BasePolicy, nn.Module):
     def __call__(self, batch, state=None):
         logits, h = self.model(batch.obs, state=state, info=batch.info)
         logits = F.softmax(logits, dim=1)
-        dist = self.dist(logits)
+        dist = self.dist_fn(logits)
         act = dist.sample().detach().cpu().numpy()
         return Batch(logits=logits, act=act, state=h, dist=dist)
