finish pg

parent 8b0b970c9b
commit 39de63592f
setup.py
@@ -41,8 +41,8 @@ setup(
        'gym',
        'tqdm',
        'numpy',
        'torch>=1.2.0',  # for supporting tensorboard
        'cloudpickle',
        'tensorboard',
        'torch>=1.2.0',  # for supporting tensorboard
    ],
)

@@ -14,6 +14,9 @@ def test_batch():
    assert batch[0].obs == batch[1].obs
    with pytest.raises(IndexError):
        batch[2]
    batch.obs = np.arange(5)
    for i, b in enumerate(batch.split(1)):
        assert b.obs == batch[i].obs


if __name__ == '__main__':

@@ -26,9 +26,7 @@ class Net(nn.Module):
        self.model = nn.Sequential(*self.model)

    def forward(self, s, **kwargs):
        if not isinstance(s, torch.Tensor):
            s = torch.Tensor(s)
        s = s.to(self.device)
        s = torch.tensor(s, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        q = self.model(s.view(batch, -1))
        return q, None
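A brief note on the replacement above, with an illustrative one-liner (the variable names are made up for this sketch, not taken from the commit): torch.tensor(...) converts the batched observation, moves it to the chosen device, and casts to float32 in a single call, which is what the removed isinstance branch did in three steps.

import numpy as np
import torch

obs_batch = np.random.rand(4, 8)                              # hypothetical batch: 4 observations, 8 features
s = torch.tensor(obs_batch, device='cpu', dtype=torch.float)  # float32 tensor on the target device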
@@ -63,6 +61,7 @@ def test_dqn(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
@@ -118,7 +117,7 @@ def test_dqn(args=get_args()):
                    'speed', result['speed'], global_step=global_step)
                t.set_postfix(loss=f'{stat_loss.get():.6f}',
                              reward=f'{result["reward"]:.6f}',
                              length=f'{result["length"]:.6f}',
                              length=f'{result["length"]:.2f}',
                              speed=f'{result["speed"]:.2f}')
        # eval
        test_collector.reset_env()
@@ -131,9 +130,11 @@ def test_dqn(args=get_args()):
            best_epoch = epoch
        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
              f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if args.task == 'CartPole-v0' and best_reward >= 200:
        if best_reward >= env.spec.reward_threshold:
            break
    assert best_reward >= 200
    assert best_reward >= env.spec.reward_threshold
    training_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        train_cnt = training_collector.collect_step
        test_cnt = test_collector.collect_step
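As a sanity check on the change above (illustrative, not part of the commit): gym stores the solving threshold on the environment spec, so the hard-coded 200 is replaced by a per-task value.

import gym

env = gym.make('CartPole-v0')
print(env.spec.reward_threshold)  # 195.0 for CartPole-v0 in classic gym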
@@ -143,18 +144,10 @@ def test_dqn(args=get_args()):
              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
        # Let's watch its performance!
        env = gym.make(args.task)
        obs = env.reset()
        done = False
        total = 0
        while not done:
            q, _ = net([obs])
            action = q.max(dim=1)[1]
            obs, rew, done, info = env.step(action[0].detach().cpu().numpy())
            total += rew
            env.render()
            time.sleep(1 / 35)
        env.close()
        print(f'Final test: {total}')
        test_collector = Collector(policy, env, ReplayBuffer(1))
        result = test_collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
        test_collector.close()


if __name__ == '__main__':

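For context, a hedged sketch of how the new Collector-based evaluation above can be used on its own; `final_env`, `final_collector` and `stats` are illustrative names, `policy` is assumed to be the trained policy, and render=1 / 35 mirrors the time.sleep(1 / 35) in the removed manual loop (roughly 35 frames per second).

import gym
from tianshou.data import Collector, ReplayBuffer

final_env = gym.make('CartPole-v0')                          # standalone evaluation env
final_collector = Collector(policy, final_env, ReplayBuffer(1))
stats = final_collector.collect(n_episode=1, render=1 / 35)  # render one full episode
print(stats['reward'], stats['length'])
final_collector.close()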
test/test_pg.py (new file)
@@ -0,0 +1,207 @@
import gym
import time
import tqdm
import torch
import argparse
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import PGPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils import tqdm_config, MovAvg
from tianshou.data import Batch, Collector, ReplayBuffer


def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
    returns = np.zeros_like(batch.rew)
    last = 0
    for i in reversed(range(len(batch.rew))):
        returns[i] = batch.rew[i]
        if not batch.done[i]:
            returns[i] += last * gamma
        last = returns[i]
    batch.update(returns=returns)
    return batch


def test_fn(size=2560):
    policy = PGPolicy(
        None, None, None, discount_factor=0.1, normalized_reward=False)
    fn = policy.process_fn
    # fn = compute_return_base
    batch = Batch(
        done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
        rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
    )
    batch = fn(batch, None, None)
    ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
    ans -= ans.mean()
    assert abs(batch.returns - ans).sum() <= 1e-5
    batch = Batch(
        done=np.array([0, 1, 0, 1, 0, 1, 0.]),
        rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
    )
    batch = fn(batch, None, None)
    ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
    ans -= ans.mean()
    assert abs(batch.returns - ans).sum() <= 1e-5
    batch = Batch(
        done=np.array([0, 1, 0, 1, 0, 0, 1.]),
        rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
    )
    batch = fn(batch, None, None)
    ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
    ans -= ans.mean()
    assert abs(batch.returns - ans).sum() <= 1e-5
    if __name__ == '__main__':
        batch = Batch(
            done=np.random.randint(100, size=size) == 0,
            rew=np.random.random(size),
        )
        cnt = 3000
        t = time.time()
        for _ in range(cnt):
            compute_return_base(batch)
        print(f'vanilla: {(time.time() - t) / cnt}')
        t = time.time()
        for _ in range(cnt):
            policy.process_fn(batch, None, None)
        print(f'policy: {(time.time() - t) / cnt}')


class Net(nn.Module):
    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
        super().__init__()
        self.device = device
        self.model = [
            nn.Linear(np.prod(state_shape), 128),
            nn.ReLU(inplace=True)]
        for i in range(layer_num):
            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
        self.model += [nn.Linear(128, np.prod(action_shape))]
        self.model = nn.Sequential(*self.model)

    def forward(self, s, **kwargs):
        s = torch.tensor(s, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        logits = self.model(s.view(batch, -1))
        return logits, None


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=320)
    parser.add_argument('--collect-per-step', type=int, default=5)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--layer-num', type=int, default=3)
    parser.add_argument('--training-num', type=int, default=8)
    parser.add_argument('--test-num', type=int, default=20)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def test_pg(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
    net = net.to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    dist = torch.distributions.Categorical
    policy = PGPolicy(net, optim, dist, args.gamma)
    # collector
    training_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(
        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
    # log
    stat_loss = MovAvg()
    global_step = 0
    writer = SummaryWriter(args.logdir)
    best_epoch = -1
    best_reward = -1e10
    start_time = time.time()
    for epoch in range(1, 1 + args.epoch):
        desc = f"Epoch #{epoch}"
        # train
        policy.train()
        with tqdm.tqdm(
                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
            while t.n < t.total:
                result = training_collector.collect(
                    n_episode=args.collect_per_step)
                losses = policy.learn(
                    training_collector.sample(0), args.batch_size)
                global_step += len(losses)
                t.update(len(losses))
                stat_loss.add(losses)
                writer.add_scalar(
                    'reward', result['reward'], global_step=global_step)
                writer.add_scalar(
                    'length', result['length'], global_step=global_step)
                writer.add_scalar(
                    'loss', stat_loss.get(), global_step=global_step)
                writer.add_scalar(
                    'speed', result['speed'], global_step=global_step)
                t.set_postfix(loss=f'{stat_loss.get():.6f}',
                              reward=f'{result["reward"]:.6f}',
                              length=f'{result["length"]:.2f}',
                              speed=f'{result["speed"]:.2f}')
        # eval
        test_collector.reset_env()
        test_collector.reset_buffer()
        policy.eval()
        result = test_collector.collect(n_episode=args.test_num)
        if best_reward < result['reward']:
            best_reward = result['reward']
            best_epoch = epoch
        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
              f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if best_reward >= env.spec.reward_threshold:
            break
    assert best_reward >= env.spec.reward_threshold
    training_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        train_cnt = training_collector.collect_step
        test_cnt = test_collector.collect_step
        duration = time.time() - start_time
        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
              f'in {duration:.2f}s, '
              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
        # Let's watch its performance!
        env = gym.make(args.task)
        test_collector = Collector(policy, env, ReplayBuffer(1))
        result = test_collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
        test_collector.close()


if __name__ == '__main__':
    # test_fn()
    test_pg()
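One detail worth calling out in the training loop above (a reading of the code, not text from the commit): sample(0) draws the entire buffer, PGPolicy.process_fn turns rewards into mean-centered discounted returns, and learn then splits the result into minibatches of args.batch_size. A condensed, illustrative view of one training step:

result = training_collector.collect(n_episode=args.collect_per_step)  # roll out full episodes
full_batch = training_collector.sample(0)                             # 0 => take the whole buffer
losses = policy.learn(full_batch, args.batch_size)                    # one REINFORCE update per minibatch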
@@ -38,3 +38,14 @@ class Batch(object):
                raise TypeError(
                    'No support for append with type {} in class Batch.'
                    .format(type(batch.__dict__[k])))

    def split(self, size=None):
        length = min([
            len(self.__dict__[k]) for k in self.__dict__.keys()
            if self.__dict__[k] is not None])
        if size is None:
            size = length
        temp = 0
        while temp < length:
            yield self[temp:temp + size]
            temp += size

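A minimal usage sketch of the new Batch.split generator (the values are made up; it assumes Batch slicing behaves as exercised in test_batch above):

import numpy as np
from tianshou.data import Batch

b = Batch(obs=np.arange(6), act=np.zeros(6))
for minibatch in b.split(size=2):  # with size=None it would yield everything in one chunk
    print(minibatch.obs)           # -> [0 1], then [2 3], then [4 5]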
@@ -59,7 +59,10 @@ class ReplayBuffer(object):
        if batch_size > 0:
            indice = np.random.choice(self._size, batch_size)
        else:
            indice = np.arange(self._size)
            indice = np.concatenate([
                np.arange(self._index, self._size),
                np.arange(0, self._index),
            ])
        return Batch(
            obs=self.obs[indice],
            act=self.act[indice],
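A worked example of the new batch_size == 0 branch (hypothetical numbers): with a full ring buffer of _size == 5 whose write index _index == 2, the returned indices now run oldest-to-newest instead of 0..4, which presumably keeps episodes contiguous for PGPolicy.process_fn.

import numpy as np

_index, _size = 2, 5                    # assumed buffer state
indice = np.concatenate([np.arange(_index, _size),
                         np.arange(0, _index)])
print(indice)                           # [2 3 4 0 1]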
tianshou/env/wrapper.py
@@ -121,8 +121,10 @@ class VectorEnv(BaseVectorEnv):
            np.stack(self._done), np.stack(self._info)

    def seed(self, seed=None):
        if np.isscalar(seed) or seed is None:
            seed = [seed for _ in range(self.env_num)]
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        for e, s in zip(self.envs, seed):
            if hasattr(e, 'seed'):
                e.seed(s)
@@ -198,8 +200,10 @@ class SubprocVectorEnv(BaseVectorEnv):
        return np.stack([p.recv() for p in self.parent_remote])

    def seed(self, seed=None):
        if np.isscalar(seed) or seed is None:
            seed = [seed for _ in range(self.env_num)]
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        for p, s in zip(self.parent_remote, seed):
            p.send(['seed', s])
        for p in self.parent_remote:
@@ -272,8 +276,10 @@ class RayVectorEnv(BaseVectorEnv):
    def seed(self, seed=None):
        if not hasattr(self.envs[0], 'seed'):
            return
        if np.isscalar(seed) or seed is None:
            seed = [seed for _ in range(self.env_num)]
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
        for r in result_obj:
            ray.get(r)

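The three seed() rewrites above share the same behaviour; an illustrative summary (train_envs is assumed from the test scripts, where env_num == 8):

train_envs.seed(1626)        # scalar: each sub-env gets a distinct seed 1626 + i (1626..1633)
train_envs.seed([1, 2, 3])   # an explicit list is passed through to the sub-envs unchanged
train_envs.seed()            # None: every sub-env receives seed(None), i.e. its default seeding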
@@ -1,7 +1,9 @@
from tianshou.policy.base import BasePolicy
from tianshou.policy.dqn import DQNPolicy
from tianshou.policy.policy_gradient import PGPolicy

__all__ = [
    'BasePolicy',
    'DQNPolicy',
    'PGPolicy',
]

@@ -8,17 +8,17 @@ class BasePolicy(ABC):
        super().__init__()
        self.model = None

    @abstractmethod
    def __call__(self, batch, hidden_state=None):
        # return Batch(act=np.array(), state=None, ...)
        pass

    @abstractmethod
    def learn(self, batch):
        pass

    def process_fn(self, batch, buffer, indice):
        return batch

    @abstractmethod
    def __call__(self, batch, state=None):
        # return Batch(logits=..., act=np.array(), state=None, ...)
        pass

    @abstractmethod
    def learn(self, batch, batch_size=None):
        pass

    def sync_weight(self):
        pass

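A minimal sketch of a policy written against the updated abstract interface (an illustrative subclass, not part of the commit): it only has to implement __call__(batch, state=None) returning a Batch with an act field, and learn(batch, batch_size=None).

import numpy as np

from tianshou.data import Batch
from tianshou.policy import BasePolicy


class RandomPolicy(BasePolicy):
    """Picks a random action for every observation in the batch (sketch only)."""

    def __init__(self, action_num):
        super().__init__()
        self.action_num = action_num

    def __call__(self, batch, state=None):
        act = np.random.randint(self.action_num, size=len(batch.obs))
        return Batch(logits=None, act=act, state=None)

    def learn(self, batch, batch_size=None):
        return 0.0  # nothing to optimize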
@@ -10,14 +10,14 @@ from tianshou.policy import BasePolicy
class DQNPolicy(BasePolicy, nn.Module):
    """docstring for DQNPolicy"""

    def __init__(self, model, optim, loss,
    def __init__(self, model, optim, loss_fn,
                 discount_factor=0.99,
                 estimation_step=1,
                 use_target_network=True):
        super().__init__()
        self.model = model
        self.optim = optim
        self.loss = loss
        self.loss_fn = loss_fn
        self.eps = 0
        assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
        self._gamma = discount_factor
@@ -28,20 +28,6 @@ class DQNPolicy(BasePolicy, nn.Module):
            self.model_old = deepcopy(self.model)
            self.model_old.eval()

    def __call__(self, batch, hidden_state=None,
                 model='model', input='obs', eps=None):
        model = getattr(self, model)
        obs = getattr(batch, input)
        q, h = model(obs, hidden_state=hidden_state, info=batch.info)
        act = q.max(dim=1)[1].detach().cpu().numpy()
        # add eps to act
        if eps is None:
            eps = self.eps
        for i in range(len(q)):
            if np.random.rand() < eps:
                act[i] = np.random.randint(q.shape[1])
        return Batch(Q=q, act=act, state=h)

    def set_eps(self, eps):
        self.eps = eps

@@ -70,12 +56,12 @@ class DQNPolicy(BasePolicy, nn.Module):
            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
            a = self(buffer[terminal], input='obs_next', eps=0).act
            target_q = self(
                buffer[terminal], model='model_old', input='obs_next').Q
                buffer[terminal], model='model_old', input='obs_next').logits
            if isinstance(target_q, torch.Tensor):
                target_q = target_q.detach().cpu().numpy()
            target_q = target_q[np.arange(len(a)), a]
        else:
            target_q = self(buffer[terminal], input='obs_next').Q
            target_q = self(buffer[terminal], input='obs_next').logits
            if isinstance(target_q, torch.Tensor):
                target_q = target_q.detach().cpu().numpy()
            target_q = target_q.max(axis=1)
@@ -84,14 +70,28 @@ class DQNPolicy(BasePolicy, nn.Module):
        batch.update(returns=returns)
        return batch

    def learn(self, batch):
    def __call__(self, batch, state=None,
                 model='model', input='obs', eps=None):
        model = getattr(self, model)
        obs = getattr(batch, input)
        q, h = model(obs, state=state, info=batch.info)
        act = q.max(dim=1)[1].detach().cpu().numpy()
        # add eps to act
        if eps is None:
            eps = self.eps
        for i in range(len(q)):
            if np.random.rand() < eps:
                act[i] = np.random.randint(q.shape[1])
        return Batch(logits=q, act=act, state=h)

    def learn(self, batch, batch_size=None):
        self.optim.zero_grad()
        q = self(batch).Q
        q = self(batch).logits
        q = q[np.arange(len(q)), batch.act]
        r = batch.returns
        if isinstance(r, np.ndarray):
            r = torch.tensor(r, device=q.device, dtype=q.dtype)
        loss = self.loss(q, r)
        loss = self.loss_fn(q, r)
        loss.backward()
        self.optim.step()
        return loss.detach().cpu().numpy()

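The net effect of the DQNPolicy edits above, shown as a hedged before/after usage sketch (`net`, `optim` and `batch` are assumed to exist as in the test scripts):

import torch
from tianshou.policy import DQNPolicy

policy = DQNPolicy(net, optim, loss_fn=torch.nn.MSELoss(),  # 'loss' argument is now 'loss_fn'
                   discount_factor=0.9)
out = policy(batch, eps=0.1)
out.logits   # Q-values; this field was out.Q before the commit
out.act      # epsilon-greedy actions, unchanged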
tianshou/policy/policy_gradient.py (new file)
@@ -0,0 +1,76 @@
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from torch.distributions import Categorical

from tianshou.data import Batch
from tianshou.policy import BasePolicy


class PGPolicy(BasePolicy, nn.Module):
    """docstring for PGPolicy"""

    def __init__(self, model, optim, dist=Categorical,
                 discount_factor=0.99, normalized_reward=True):
        super().__init__()
        self.model = model
        self.optim = optim
        self.dist = dist
        self._eps = np.finfo(np.float32).eps.item()
        assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
        self._gamma = discount_factor
        self._rew_norm = normalized_reward

    def process_fn(self, batch, buffer, indice):
        batch_size = len(batch.rew)
        returns = self._vanilla_returns(batch, batch_size)
        # returns = self._vectorized_returns(batch, batch_size)
        returns = returns - returns.mean()
        if self._rew_norm:
            returns = returns / (returns.std() + self._eps)
        batch.update(returns=returns)
        return batch

    def __call__(self, batch, state=None):
        logits, h = self.model(batch.obs, state=state, info=batch.info)
        logits = F.softmax(logits, dim=1)
        dist = self.dist(logits)
        act = dist.sample().detach().cpu().numpy()
        return Batch(logits=logits, act=act, state=h, dist=dist)

    def learn(self, batch, batch_size=None):
        losses = []
        for b in batch.split(batch_size):
            self.optim.zero_grad()
            dist = self(b).dist
            a = torch.tensor(b.act, device=dist.logits.device)
            r = torch.tensor(b.returns, device=dist.logits.device)
            loss = -(dist.log_prob(a) * r).sum()
            loss.backward()
            self.optim.step()
            losses.append(loss.detach().cpu().numpy())
        return losses

    def _vanilla_returns(self, batch, batch_size):
        returns = batch.rew[:]
        last = 0
        for i in range(batch_size - 1, -1, -1):
            if not batch.done[i]:
                returns[i] += self._gamma * last
            last = returns[i]
        return returns

    def _vectorized_returns(self, batch, batch_size):
        # according to my tests, it is slower than vanilla
        # import scipy.signal
        convolve = np.convolve
        # convolve = scipy.signal.convolve
        rew = batch.rew[::-1]
        gammas = self._gamma ** np.arange(batch_size)
        c = convolve(rew, gammas)[:batch_size]
        T = np.where(batch.done[::-1])[0]
        d = np.zeros_like(rew)
        d[T] += c[T] - rew[T]
        d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
        return (c - convolve(d, gammas)[:batch_size])[::-1]
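For reference, the recursion in _vanilla_returns can be checked by hand against the first case in test_fn from test/test_pg.py above (gamma = 0.1, done = [1, 0, 0, 1, 0, 1, 0, 1], rew = [0, 1, 2, 3, 4, 5, 6, 7]); process_fn then subtracts the mean, and would also divide by the standard deviation when normalized_reward=True.

# returns[7] = 7                                   (episode ends here, nothing propagated)
# returns[6] = 6 + 0.1 * 7              = 6.7
# returns[5] = 5                                   (done, starts a new backward chain)
# returns[1] = 1 + 0.1 * (2 + 0.1 * 3)  = 1.23
# returns[0] = 0                                   (done at i = 0, so the reward stands alone)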
@@ -11,7 +11,11 @@ class MovAvg(object):
    def add(self, x):
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        if x != np.inf:
        if isinstance(x, list):
            for _ in x:
                if _ != np.inf:
                    self.cache.append(_)
        elif x != np.inf:
            self.cache.append(x)
        if self.size > 0 and len(self.cache) > self.size:
            self.cache = self.cache[-self.size:]
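An illustrative use of the list-aware MovAvg.add above (the numbers are made up; it assumes MovAvg.get() returns the mean of the cached values, as it is used for the loss readout in the training loops):

import numpy as np
from tianshou.utils import MovAvg

stat = MovAvg()
stat.add([0.5, 0.25, np.inf])  # list branch: finite entries are cached, inf is skipped
stat.add(0.75)                 # scalar branch, unchanged behaviour
print(stat.get())              # 0.5, the mean of [0.5, 0.25, 0.75]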