fix historical issues

Trinkle23897 2020-04-26 16:13:51 +08:00
parent 6b96f124ae
commit 959955fa2a
6 changed files with 20 additions and 9 deletions

Changed file 1 of 6: documentation (DQN tutorial example)

@@ -211,7 +211,7 @@ Setup policy and collectors:
 ```python
 policy = ts.policy.DQNPolicy(net, optim, gamma, n_step,
-    use_target_network=True, target_update_freq=target_freq)
+    target_update_freq=target_freq)
 train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(buffer_size))
 test_collector = ts.data.Collector(policy, test_envs)
 ```
@@ -242,7 +242,7 @@ collector.collect(n_episode=1, render=1 / 35)
 collector.close()
 ```
-Look at the result saved in tensorboard: (on bash script)
+Look at the result saved in tensorboard: (with bash script in your terminal)
 ```bash
 tensorboard --logdir log/dqn
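
For readers following the tutorial, this is how the snippet reads after the change, assembled from the two hunks above into one block. It is a sketch: `net`, `optim`, `gamma`, `n_step`, `target_freq`, `buffer_size`, and the vectorized `train_envs`/`test_envs` are assumed to be defined earlier in the tutorial.

```python
import tianshou as ts

# The use_target_network flag is gone; the target network is presumably
# governed by target_update_freq alone after this commit.
policy = ts.policy.DQNPolicy(net, optim, gamma, n_step,
                             target_update_freq=target_freq)
train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(buffer_size))
test_collector = ts.data.Collector(policy, test_envs)

# After training, inspect the logged results with:
#   tensorboard --logdir log/dqn
```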

Changed file 2 of 6: A2C test script

@@ -43,6 +43,7 @@ def get_args():
     parser.add_argument('--ent-coef', type=float, default=0.001)
     parser.add_argument('--max-grad-norm', type=float, default=None)
     parser.add_argument('--gae-lambda', type=float, default=1.)
+    parser.add_argument('--rew-norm', type=bool, default=False)
     args = parser.parse_known_args()[0]
     return args
@@ -74,7 +75,7 @@ def test_a2c(args=get_args()):
     policy = A2CPolicy(
         actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
         vf_coef=args.vf_coef, ent_coef=args.ent_coef,
-        max_grad_norm=args.max_grad_norm)
+        max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))

Changed file 3 of 6: PG test script

@@ -102,6 +102,7 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument('--rew-norm', type=bool, default=True)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -132,7 +133,8 @@ def test_pg(args=get_args()):
     net = net.to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     dist = torch.distributions.Categorical
-    policy = PGPolicy(net, optim, dist, args.gamma)
+    policy = PGPolicy(net, optim, dist, args.gamma,
+                      reward_normalization=args.rew_norm)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))

Changed file 4 of 6: A2CPolicy

@@ -34,7 +34,8 @@ class A2CPolicy(PGPolicy):
     def __init__(self, actor, critic, optim,
                  dist_fn=torch.distributions.Categorical,
                  discount_factor=0.99, vf_coef=.5, ent_coef=.01,
-                 max_grad_norm=None, gae_lambda=0.95, **kwargs):
+                 max_grad_norm=None, gae_lambda=0.95,
+                 reward_normalization=False, **kwargs):
         super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
         self.actor = actor
         self.critic = critic
@@ -44,6 +45,8 @@ class A2CPolicy(PGPolicy):
         self._w_ent = ent_coef
         self._grad_norm = max_grad_norm
         self._batch = 64
+        self._rew_norm = reward_normalization
+        self.__eps = np.finfo(np.float32).eps.item()

     def process_fn(self, batch, buffer, indice):
         if self._lambda in [0, 1]:
@@ -82,6 +85,9 @@ class A2CPolicy(PGPolicy):
     def learn(self, batch, batch_size=None, repeat=1, **kwargs):
         self._batch = batch_size
+        r = batch.returns
+        if self._rew_norm and r.std() > self.__eps:
+            batch.returns = (r - r.mean()) / r.std()
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         for _ in range(repeat):
             for b in batch.split(batch_size):
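
The three lines added to `learn` standardize the batch returns before the actor/critic update, and only when normalization is enabled and the returns are not (near-)constant. Below is a minimal standalone sketch of that guard; the helper name `normalize_returns` is hypothetical and not part of the commit.

```python
import numpy as np

def normalize_returns(returns, eps=np.finfo(np.float32).eps.item()):
    """Zero-mean, unit-std returns; skipped when the std is at machine epsilon."""
    returns = np.asarray(returns, dtype=np.float32)
    if returns.std() > eps:
        return (returns - returns.mean()) / returns.std()
    return returns

print(normalize_returns([1.0, 2.0, 3.0]))  # roughly [-1.2247, 0.0, 1.2247]
print(normalize_returns([5.0, 5.0, 5.0]))  # left unchanged: [5.0, 5.0, 5.0]
```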

Changed file 5 of 6: DQNPolicy (import order only)

@@ -3,8 +3,8 @@ import numpy as np
 from copy import deepcopy
 import torch.nn.functional as F
-from tianshou.data import Batch, PrioritizedReplayBuffer
 from tianshou.policy import BasePolicy
+from tianshou.data import Batch, PrioritizedReplayBuffer


 class DQNPolicy(BasePolicy):

Changed file 6 of 6: PGPolicy

@@ -21,14 +21,15 @@ class PGPolicy(BasePolicy):
     """

     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
-                 discount_factor=0.99, **kwargs):
+                 discount_factor=0.99, reward_normalization=False, **kwargs):
         super().__init__(**kwargs)
         self.model = model
         self.optim = optim
         self.dist_fn = dist_fn
-        self._eps = np.finfo(np.float32).eps.item()
         assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]'
         self._gamma = discount_factor
+        self._rew_norm = reward_normalization
+        self.__eps = np.finfo(np.float32).eps.item()

     def process_fn(self, batch, buffer, indice):
         r"""Compute the discounted returns for each frame:
@@ -71,7 +72,8 @@ class PGPolicy(BasePolicy):
     def learn(self, batch, batch_size=None, repeat=1, **kwargs):
         losses = []
         r = batch.returns
-        batch.returns = (r - r.mean()) / (r.std() + self._eps)
+        if self._rew_norm and r.std() > self.__eps:
+            batch.returns = (r - r.mean()) / r.std()
         for _ in range(repeat):
             for b in batch.split(batch_size):
                 self.optim.zero_grad()
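
Behaviourally, PGPolicy used to normalize returns unconditionally as `(r - r.mean()) / (r.std() + self._eps)`; after this commit normalization is opt-in via `reward_normalization` and is skipped when the standard deviation is at machine epsilon, so a batch of identical returns is left as-is rather than collapsed to zeros. A small sketch of the difference, in plain NumPy outside the class:

```python
import numpy as np

eps = np.finfo(np.float32).eps.item()
r = np.ones(4, dtype=np.float32)                          # a batch of identical returns

old = (r - r.mean()) / (r.std() + eps)                    # old behaviour: always normalize
new = (r - r.mean()) / r.std() if r.std() > eps else r    # new behaviour: guarded, opt-in

print(old)  # [0. 0. 0. 0.]  -- the constant returns are wiped out
print(new)  # [1. 1. 1. 1.]  -- the returns are left untouched
```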