fix historical issues
parent 6b96f124ae
commit 959955fa2a
@@ -211,7 +211,7 @@ Setup policy and collectors:
 
 ```python
 policy = ts.policy.DQNPolicy(net, optim, gamma, n_step,
-    use_target_network=True, target_update_freq=target_freq)
+    target_update_freq=target_freq)
 train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(buffer_size))
 test_collector = ts.data.Collector(policy, test_envs)
 ```
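As a quick migration sketch (not part of the diff): assuming the new signature shown above, where dropping `use_target_network` leaves `target_update_freq` alone to decide whether a target network is kept, an updated call could look like the following. The toy `Net` and the hyperparameter values are placeholders, not the tutorial's own.

```python
import torch
import tianshou as ts
from torch import nn


class Net(nn.Module):
    """Tiny stand-in Q-network; the real tutorial defines its own."""

    def __init__(self, state_dim=4, action_dim=2):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(state_dim, 128), nn.ReLU(), nn.Linear(128, action_dim))

    def forward(self, obs, state=None, info={}):
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float)
        return self.model(obs), state


net = Net()
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
gamma, n_step, target_freq = 0.9, 3, 320  # placeholder hyperparameters

# Old call (before this commit) also passed use_target_network=True.
# New call: the flag is gone; only target_update_freq is passed.
policy = ts.policy.DQNPolicy(net, optim, gamma, n_step,
                             target_update_freq=target_freq)
```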
@@ -242,7 +242,7 @@ collector.collect(n_episode=1, render=1 / 35)
 collector.close()
 ```
 
-Look at the result saved in tensorboard: (on bash script)
+Look at the result saved in tensorboard: (with bash script in your terminal)
 
 ```bash
 tensorboard --logdir log/dqn
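Side note, not from the diff: `tensorboard --logdir log/dqn` only has data to display if the training run wrote event files into that directory. A minimal sketch of that side using the standard PyTorch writer (the scalar values are placeholders, not Tianshou output):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('log/dqn')  # same directory tensorboard is pointed at
for step, reward in enumerate([10.0, 25.0, 40.0]):  # placeholder rewards
    writer.add_scalar('train/reward', reward, global_step=step)
writer.close()
```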
@@ -43,6 +43,7 @@ def get_args():
     parser.add_argument('--ent-coef', type=float, default=0.001)
     parser.add_argument('--max-grad-norm', type=float, default=None)
     parser.add_argument('--gae-lambda', type=float, default=1.)
+    parser.add_argument('--rew-norm', type=bool, default=False)
     args = parser.parse_known_args()[0]
     return args
 
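One caveat about the added line (an observation, not a change made by this commit): argparse's `type=bool` applies Python's `bool()` to the raw string, so any non-empty value, including the literal `False`, parses as `True`. A small sketch of that behavior with the same argument name:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--rew-norm', type=bool, default=False)

print(parser.parse_args([]).rew_norm)                       # False (default)
print(parser.parse_args(['--rew-norm', 'False']).rew_norm)  # True: bool('False')
print(parser.parse_args(['--rew-norm', '']).rew_norm)       # False: bool('')
```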
@@ -74,7 +75,7 @@ def test_a2c(args=get_args()):
     policy = A2CPolicy(
         actor, critic, optim, dist, args.gamma, gae_lambda=args.gae_lambda,
         vf_coef=args.vf_coef, ent_coef=args.ent_coef,
-        max_grad_norm=args.max_grad_norm)
+        max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
@@ -102,6 +102,7 @@ def get_args():
     parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
+    parser.add_argument('--rew-norm', type=bool, default=True)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
@@ -132,7 +133,8 @@ def test_pg(args=get_args()):
     net = net.to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
     dist = torch.distributions.Categorical
-    policy = PGPolicy(net, optim, dist, args.gamma)
+    policy = PGPolicy(net, optim, dist, args.gamma,
+                      reward_normalization=args.rew_norm)
     # collector
     train_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
@@ -34,7 +34,8 @@ class A2CPolicy(PGPolicy):
     def __init__(self, actor, critic, optim,
                  dist_fn=torch.distributions.Categorical,
                  discount_factor=0.99, vf_coef=.5, ent_coef=.01,
-                 max_grad_norm=None, gae_lambda=0.95, **kwargs):
+                 max_grad_norm=None, gae_lambda=0.95,
+                 reward_normalization=False, **kwargs):
         super().__init__(None, optim, dist_fn, discount_factor, **kwargs)
         self.actor = actor
         self.critic = critic
@@ -44,6 +45,8 @@ class A2CPolicy(PGPolicy):
         self._w_ent = ent_coef
         self._grad_norm = max_grad_norm
         self._batch = 64
+        self._rew_norm = reward_normalization
+        self.__eps = np.finfo(np.float32).eps.item()
 
     def process_fn(self, batch, buffer, indice):
         if self._lambda in [0, 1]:
@@ -82,6 +85,9 @@ class A2CPolicy(PGPolicy):
 
     def learn(self, batch, batch_size=None, repeat=1, **kwargs):
         self._batch = batch_size
+        r = batch.returns
+        if self._rew_norm and r.std() > self.__eps:
+            batch.returns = (r - r.mean()) / r.std()
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
         for _ in range(repeat):
             for b in batch.split(batch_size):
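The three added lines standardize the return column once, before the repeat/minibatch loop, and only when normalization is enabled and the spread is above float32 epsilon. The same computation as a standalone sketch (the function name and toy values are mine, not from the diff):

```python
import numpy as np


def normalize_returns(r, eps=np.finfo(np.float32).eps.item()):
    """Zero-mean, unit-std returns; left untouched when std is ~0."""
    if r.std() > eps:
        return (r - r.mean()) / r.std()
    return r


print(normalize_returns(np.array([1.0, 2.0, 4.0])))  # approx [-1.07, -0.27, 1.34]
print(normalize_returns(np.array([3.0, 3.0, 3.0])))  # unchanged: std is zero
```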
@@ -3,8 +3,8 @@ import numpy as np
 from copy import deepcopy
 import torch.nn.functional as F
 
-from tianshou.data import Batch, PrioritizedReplayBuffer
 from tianshou.policy import BasePolicy
+from tianshou.data import Batch, PrioritizedReplayBuffer
 
 
 class DQNPolicy(BasePolicy):
@@ -21,14 +21,15 @@ class PGPolicy(BasePolicy):
     """
 
     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
-                 discount_factor=0.99, **kwargs):
+                 discount_factor=0.99, reward_normalization=False, **kwargs):
         super().__init__(**kwargs)
         self.model = model
         self.optim = optim
         self.dist_fn = dist_fn
-        self._eps = np.finfo(np.float32).eps.item()
         assert 0 <= discount_factor <= 1, 'discount factor should in [0, 1]'
         self._gamma = discount_factor
+        self._rew_norm = reward_normalization
+        self.__eps = np.finfo(np.float32).eps.item()
 
     def process_fn(self, batch, buffer, indice):
         r"""Compute the discounted returns for each frame:
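A small Python detail behind the double underscore (not something the diff states explicitly): `self.__eps` is name-mangled per class, so the value stored here lives under `_PGPolicy__eps`, and the `A2CPolicy` subclass has to set its own `_A2CPolicy__eps`, which is exactly what the earlier hunk does. The class names below are illustrative stand-ins:

```python
import numpy as np


class Parent:
    def __init__(self):
        self.__eps = np.finfo(np.float32).eps.item()  # stored as _Parent__eps


class Child(Parent):
    def __init__(self):
        super().__init__()
        self.__eps = np.finfo(np.float32).eps.item()  # stored as _Child__eps


print(sorted(k for k in vars(Child()) if 'eps' in k))
# ['_Child__eps', '_Parent__eps']
```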
@@ -71,7 +72,8 @@ class PGPolicy(BasePolicy):
     def learn(self, batch, batch_size=None, repeat=1, **kwargs):
         losses = []
         r = batch.returns
-        batch.returns = (r - r.mean()) / (r.std() + self._eps)
+        if self._rew_norm and r.std() > self.__eps:
+            batch.returns = (r - r.mean()) / r.std()
         for _ in range(repeat):
             for b in batch.split(batch_size):
                 self.optim.zero_grad()
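This hunk changes behavior as well as configurability: the old line always rescaled returns (padding the denominator with epsilon), while the new one leaves them untouched unless normalization is requested and the spread is non-trivial. A small numeric sketch of the difference on a constant-return batch (plain NumPy standing in for `batch.returns`):

```python
import numpy as np

eps = np.finfo(np.float32).eps.item()
r = np.array([1.0, 1.0, 1.0])  # constant returns, std() == 0

old = (r - r.mean()) / (r.std() + eps)                  # [0., 0., 0.]
new = (r - r.mean()) / r.std() if r.std() > eps else r  # kept as [1., 1., 1.]
print(old, new)
```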
|
Loading…
x
Reference in New Issue
Block a user