Yet another 3 fixes (#160)

1. DQN learn should keep eps=0 (see the usage sketch below)

2. Add a warning about env.seed in VecEnv

3. Fix #162 (multi-dimensional actions)
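The point of the first fix is that the exploration rate configured for data collection should not influence the forward pass made inside ``learn()``. A rough sketch of the pattern this protects, assuming the usual training-loop objects (``policy``, ``collector``, ``buffer``) and arbitrary hyper-parameters:

    policy.set_eps(0.1)                    # epsilon-greedy exploration while collecting
    collector.collect(n_step=100)
    batch, indice = buffer.sample(64)
    batch = policy.process_fn(batch, buffer, indice)  # compute the (n-step) returns
    policy.learn(batch)                    # learn() now forwards with eps=0., so the
                                           # exploration rate set above cannot leak in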
This commit is contained in:
n+e 2020-07-24 17:38:12 +08:00 committed by GitHub
parent bfeffe1f97
commit 38a95c19da
12 changed files with 67 additions and 14 deletions

View File

@@ -48,6 +48,17 @@ where ``env_fns`` is a list of callable env constructors. The above code can be written
env_fns = [lambda x=i: MyTestEnv(size=x) for i in [2, 3, 4, 5]]
venv = SubprocVectorEnv(env_fns)
.. warning::
If you use your own environment, please make sure the ``seed`` method is set up properly, e.g.,
::
def seed(self, seed):
np.random.seed(seed)
Otherwise, the outputs of these envs may be identical to each other.
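For example, with the ``venv`` created above you can give each worker its own seed (a sketch; the seed values are arbitrary):
::
    venv.seed([0, 1, 2, 3])  # one seed per sub-environment
    obs = venv.reset()       # each worker now has an independent random stream
If the wrapped environment has no working ``seed`` method, the sub-processes may all inherit the same NumPy random state, so the calls above cannot make their outputs differ.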
.. _preprocess_fn:
Handle Batched Data Stream in Collector

View File

@@ -40,6 +40,17 @@ Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_envs``.
For the demonstration, here we use the second block of code.
.. warning::
If you use your own environment, please make sure the ``seed`` method is set up properly, e.g.,
::
def seed(self, seed):
np.random.seed(seed)
Otherwise, the outputs of these envs may be identical to each other.
.. _build_the_network:
Build the Network

View File

@@ -1,20 +1,30 @@
import gym
import time
from gym.spaces.discrete import Discrete
import numpy as np
from gym.spaces import Discrete, MultiDiscrete, Box
class MyTestEnv(gym.Env):
"""This is a "going right" task. The task is to go right ``size`` steps.
"""
def __init__(self, size, sleep=0, dict_state=False, ma_rew=0):
def __init__(self, size, sleep=0, dict_state=False, ma_rew=0,
multidiscrete_action=False):
self.size = size
self.sleep = sleep
self.dict_state = dict_state
self.ma_rew = ma_rew
self.action_space = Discrete(2)
self._md_action = multidiscrete_action
self.observation_space = Box(shape=(1, ), low=0, high=size - 1)
if multidiscrete_action:
self.action_space = MultiDiscrete([2, 2])
else:
self.action_space = Discrete(2)
self.reset()
def seed(self, seed=0):
np.random.seed(seed)
def reset(self, state=0):
self.done = False
self.index = state
@@ -29,9 +39,12 @@ class MyTestEnv(gym.Env):
def _get_dict_state(self):
"""Generate a dict_state if dict_state is True."""
return {'index': self.index} if self.dict_state else self.index
return {'index': self.index, 'rand': np.random.rand()} \
if self.dict_state else self.index
def step(self, action):
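# a MultiDiscrete action arrives as a small array; this toy env only uses its first component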
if self._md_action:
action = action[0]
if self.done:
raise ValueError('step after done !!!')
if self.sleep > 0:

View File

@@ -1,5 +1,5 @@
import torch
import copy
import torch
import pickle
import pytest
import numpy as np

View File

@@ -113,8 +113,12 @@ def test_collector_with_dict_state():
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True)
for i in [2, 3, 4, 5]]
envs = VectorEnv(env_fns)
envs.seed(666)
obs = envs.reset()
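# with distinct per-env seeds, the random field of each observation should differ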
assert not np.isclose(obs[0]['rand'], obs[1]['rand'])
c1 = Collector(policy, envs, ReplayBuffer(size=100),
Logger.single_preprocess_fn)
c1.seed(0)
c1.collect(n_step=10)
c1.collect(n_episode=[2, 1, 1, 2])
batch = c1.sample(10)

View File

@@ -86,7 +86,7 @@ def get_args():
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=1000)

View File

@@ -27,6 +27,17 @@ class BaseVectorEnv(ABC, gym.Env):
obs, rew, done, info = envs.step([1] * 8) # step synchronously
envs.render() # render all environments
envs.close() # close all environments
.. warning::
If you use your own environment, please make sure the ``seed`` method
is set up properly, e.g.,
::
def seed(self, seed):
np.random.seed(seed)
Otherwise, the outputs of these envs may be identical to each other.
"""
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:

View File

@@ -108,8 +108,9 @@ class A2CPolicy(PGPolicy):
v = self.critic(b.obs).flatten()
a = to_torch_as(b.act, v)
r = to_torch_as(b.returns, v)
a_loss = -(dist.log_prob(a).flatten() * (r - v).detach()
).mean()
log_prob = dist.log_prob(a).reshape(
r.shape[0], -1).transpose(0, 1)
a_loss = -(log_prob * (r - v).detach()).mean()
vf_loss = F.mse_loss(r, v)
ent_loss = dist.entropy().mean()
loss = a_loss + self._w_vf * vf_loss - self._w_ent * ent_loss
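The reshape and transpose are what make the actor loss valid for multi-dimensional actions (fix #162): ``dist.log_prob(a)`` may come back with a trailing per-dimension axis instead of being flat. A rough shape check, assuming a batch of 4 and a 2-dimensional action:

    import torch
    log_prob = torch.randn(4, 2)                  # (batch, action_dims), e.g. from a MultiDiscrete dist
    adv = torch.randn(4)                          # (batch,), plays the role of (r - v)
    lp = log_prob.reshape(4, -1).transpose(0, 1)  # -> (action_dims, batch)
    loss = -(lp * adv).mean()                     # (2, 4) * (4,) broadcasts to (2, 4)
    # the old log_prob.flatten() would give shape (8,), which cannot broadcast with (4,)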

View File

@@ -157,7 +157,7 @@ class DQNPolicy(BasePolicy):
if self._target and self._cnt % self._freq == 0:
self.sync_weight()
self.optim.zero_grad()
q = self(batch).logits
q = self(batch, eps=0.).logits
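# select Q(s, a) for the action actually taken in each transition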
q = q[np.arange(len(q)), batch.act]
r = to_torch_as(batch.returns, q).flatten()
if hasattr(batch, 'update_weight'):

View File

@@ -90,7 +90,9 @@ class PGPolicy(BasePolicy):
dist = self(b).dist
a = to_torch_as(b.act, dist.logits)
r = to_torch_as(b.returns, dist.logits)
loss = -(dist.log_prob(a).flatten() * r).sum()
log_prob = dist.log_prob(a).reshape(
r.shape[0], -1).transpose(0, 1)
loss = -(log_prob * r).mean()
loss.backward()
self.optim.step()
losses.append(loss.item())

View File

@@ -132,7 +132,7 @@ class PPOPolicy(PGPolicy):
to_torch_as(b.act, v[0])))
batch.v = torch.cat(v, dim=0).flatten() # old value
batch.act = to_torch_as(batch.act, v[0])
batch.logp_old = torch.cat(old_log_prob, dim=0).flatten()
batch.logp_old = torch.cat(old_log_prob, dim=0)
batch.returns = to_torch_as(batch.returns, v[0])
if self._rew_norm:
mean, std = batch.returns.mean(), batch.returns.std()
@@ -147,8 +147,8 @@
for b in batch.split(batch_size):
dist = self(b).dist
value = self.critic(b.obs).flatten()
ratio = (dist.log_prob(b.act).flatten()
- b.logp_old).exp().float()
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
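# (batch, action_dims) -> (action_dims, batch), so the ratio broadcasts against adv of shape (batch,)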
ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
surr1 = ratio * b.adv
surr2 = ratio.clamp(1. - self._eps_clip,
1. + self._eps_clip) * b.adv

View File

@@ -36,7 +36,7 @@ class Net(nn.Module):
def forward(self, s, state=None, info={}):
"""s -> flatten -> logits"""
s = to_torch(s, device=self.device, dtype=torch.float32)
s = s.flatten(1)
s = s.reshape(s.size(0), -1)
logits = self.model(s)
return logits, state
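A small check of the replacement (illustrative shapes only): ``s.reshape(s.size(0), -1)`` matches ``s.flatten(1)`` whenever a batch axis is present, and it also tolerates a 1-D input, where ``flatten(1)`` raises a dimension error.

    import torch
    s = torch.ones(8, 3, 4)
    assert s.reshape(s.size(0), -1).shape == (8, 12)    # same result as s.flatten(1)
    s1 = torch.ones(8)                                  # degenerate 1-D input
    assert s1.reshape(s1.size(0), -1).shape == (8, 1)   # s1.flatten(1) would raise here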