fix bug in test

parent 39de63592f
commit fd621971e5
@@ -39,7 +39,7 @@ def get_args():
     parser.add_argument('--eps-test', type=float, default=0.05)
     parser.add_argument('--eps-train', type=float, default=0.1)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--n-step', type=int, default=1)
     parser.add_argument('--epoch', type=int, default=100)
@@ -48,7 +48,7 @@ def get_args():
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=3)
     parser.add_argument('--training-num', type=int, default=8)
-    parser.add_argument('--test-num', type=int, default=20)
+    parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument(
         '--device', type=str,
@@ -99,13 +99,18 @@ def test_dqn(args=get_args()):
         policy.train()
         policy.sync_weight()
         policy.set_eps(args.eps_train)
-        with tqdm.trange(
-                0, args.step_per_epoch, desc=desc, **tqdm_config) as t:
-            for _ in t:
+        with tqdm.tqdm(
+                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
+            while t.n < t.total:
                 result = training_collector.collect(
                     n_step=args.collect_per_step)
-                global_step += 1
-                loss = policy.learn(training_collector.sample(args.batch_size))
-                stat_loss.add(loss)
-                writer.add_scalar(
-                    'reward', result['reward'], global_step=global_step)
+                for i in range(min(
+                        result['n_step'] // args.collect_per_step,
+                        t.total - t.n)):
+                    t.update(1)
+                    global_step += 1
+                    loss = policy.learn(
+                        training_collector.sample(args.batch_size))
+                    stat_loss.add(loss)
+                    writer.add_scalar(
+                        'reward', result['reward'], global_step=global_step)
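For orientation, here is a minimal standalone sketch of the loop structure the new code uses: the progress bar advances by one unit per `collect_per_step` environment steps actually collected, and one gradient update is made per unit. The `collect` and `learn` stand-ins below are hypothetical placeholders, not Tianshou APIs.

    import tqdm

    step_per_epoch, collect_per_step, global_step = 320, 10, 0

    def collect(n_step):
        # placeholder for training_collector.collect; a vectorized collector
        # may return more steps than requested
        return {'n_step': n_step}

    def learn():
        # placeholder for policy.learn on a sampled batch
        return 0.0

    with tqdm.tqdm(total=step_per_epoch) as t:
        while t.n < t.total:
            result = collect(n_step=collect_per_step)
            # one progress unit and one update per collected chunk,
            # capped by what is left on the bar
            for _ in range(min(result['n_step'] // collect_per_step,
                               t.total - t.n)):
                t.update(1)
                global_step += 1
                loss = learn()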
@@ -144,7 +149,7 @@ def test_dqn(args=get_args()):
           f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
     # Let's watch its performance!
     env = gym.make(args.task)
-    test_collector = Collector(policy, env, ReplayBuffer(1))
+    test_collector = Collector(policy, env)
     result = test_collector.collect(n_episode=1, render=1 / 35)
     print(f'Final reward: {result["reward"]}, length: {result["length"]}')
     test_collector.close()
@@ -94,15 +94,15 @@ def get_args():
     parser.add_argument('--task', type=str, default='CartPole-v0')
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=320)
-    parser.add_argument('--collect-per-step', type=int, default=5)
+    parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=3)
     parser.add_argument('--training-num', type=int, default=8)
-    parser.add_argument('--test-num', type=int, default=20)
+    parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument(
         '--device', type=str,
@@ -157,6 +157,7 @@ def test_pg(args=get_args()):
                     n_episode=args.collect_per_step)
                 losses = policy.learn(
                     training_collector.sample(0), args.batch_size)
+                training_collector.reset_buffer()
                 global_step += len(losses)
                 t.update(len(losses))
                 stat_loss.add(losses)
@@ -196,7 +197,7 @@ def test_pg(args=get_args()):
           f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
     # Let's watch its performance!
     env = gym.make(args.task)
-    test_collector = Collector(policy, env, ReplayBuffer(1))
+    test_collector = Collector(policy, env)
     result = test_collector.collect(n_episode=1, render=1 / 35)
     print(f'Final reward: {result["reward"]}, length: {result["length"]}')
     test_collector.close()
@@ -11,7 +11,7 @@ from tianshou.utils import MovAvg
 class Collector(object):
     """docstring for Collector"""

-    def __init__(self, policy, env, buffer, stat_size=100):
+    def __init__(self, policy, env, buffer=ReplayBuffer(20000), stat_size=100):
         super().__init__()
         self.env = env
         self.env_num = 1
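One caveat about the new signature: Python evaluates default argument values once, at function definition time, so every Collector constructed without an explicit buffer ends up sharing the same ReplayBuffer instance. A minimal sketch of that general Python behaviour, using hypothetical stand-in classes rather than Tianshou code:

    class Buf:
        def __init__(self):
            self.data = []

    class Holder:
        # the default Buf() is created once, when the def line is evaluated
        def __init__(self, buf=Buf()):
            self.buf = buf

    a = Holder()
    b = Holder()
    a.buf.data.append(1)
    print(b.buf.data)  # [1] -- both holders share the single default Buf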
@@ -168,10 +168,14 @@ class Collector(object):
             self._obs = obs_next
         self.stat_speed.add((self.collect_step - start_step) / (
             time.time() - start_time))
+        if self._multi_env:
+            cur_episode = sum(cur_episode)
         return {
             'reward': self.stat_reward.get(),
             'length': self.stat_length.get(),
             'speed': self.stat_speed.get(),
+            'n_episode': cur_episode,
+            'n_step': cur_step,
         }

     def sample(self, batch_size):
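The two new keys are what let the DQN test above budget its updates on `result['n_step']`. A rough illustration of consuming the richer return value; the numbers are invented:

    # hypothetical result, shaped like Collector.collect's return value
    result = {'reward': 178.5, 'length': 180.0, 'speed': 1420.0,
              'n_episode': 3, 'n_step': 96}

    collect_per_step = 10
    # mirror of the DQN test loop: one gradient update per collected chunk
    n_updates = result['n_step'] // collect_per_step
    print(f"{result['n_step']} steps over {result['n_episode']} episode(s) "
          f"-> {n_updates} updates")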
@@ -2,7 +2,6 @@ import torch
 import numpy as np
 from torch import nn
 import torch.nn.functional as F
-from torch.distributions import Categorical

 from tianshou.data import Batch
 from tianshou.policy import BasePolicy
@@ -11,12 +10,12 @@ from tianshou.policy import BasePolicy
 class PGPolicy(BasePolicy, nn.Module):
     """docstring for PGPolicy"""

-    def __init__(self, model, optim, dist=Categorical,
+    def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
                  discount_factor=0.99, normalized_reward=True):
         super().__init__()
         self.model = model
         self.optim = optim
-        self.dist = dist
+        self.dist_fn = dist_fn
         self._eps = np.finfo(np.float32).eps.item()
         assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
         self._gamma = discount_factor
@@ -35,7 +34,7 @@ class PGPolicy(BasePolicy, nn.Module):
     def __call__(self, batch, state=None):
         logits, h = self.model(batch.obs, state=state, info=batch.info)
         logits = F.softmax(logits, dim=1)
-        dist = self.dist(logits)
+        dist = self.dist_fn(logits)
         act = dist.sample().detach().cpu().numpy()
         return Batch(logits=logits, act=act, state=h, dist=dist)

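As a usage sketch, `dist_fn` is just a callable that turns the softmaxed action probabilities into a torch distribution; with the default `torch.distributions.Categorical` the call path in `__call__` looks roughly like this (the tensors and shapes below are illustrative, not taken from the commit):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 3)                 # 2 observations, 3 actions (made up)
    probs = F.softmax(logits, dim=1)           # as in PGPolicy.__call__
    dist_fn = torch.distributions.Categorical  # the new default dist_fn
    dist = dist_fn(probs)                      # distribution over the 3 actions
    act = dist.sample()                        # shape (2,), one action per observation
    log_prob = dist.log_prob(act)              # what a policy-gradient loss would use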