finish pg

parent 8b0b970c9b
commit 39de63592f

setup.py (2 changed lines)
@@ -41,8 +41,8 @@ setup(
         'gym',
         'tqdm',
         'numpy',
-        'torch>=1.2.0',  # for supporting tensorboard
         'cloudpickle',
         'tensorboard',
+        'torch>=1.2.0',  # for supporting tensorboard
     ],
 )
@@ -14,6 +14,9 @@ def test_batch():
     assert batch[0].obs == batch[1].obs
     with pytest.raises(IndexError):
         batch[2]
+    batch.obs = np.arange(5)
+    for i, b in enumerate(batch.split(1)):
+        assert b.obs == batch[i].obs


 if __name__ == '__main__':
@@ -26,9 +26,7 @@ class Net(nn.Module):
         self.model = nn.Sequential(*self.model)

     def forward(self, s, **kwargs):
-        if not isinstance(s, torch.Tensor):
-            s = torch.Tensor(s)
-        s = s.to(self.device)
+        s = torch.tensor(s, device=self.device, dtype=torch.float)
         batch = s.shape[0]
         q = self.model(s.view(batch, -1))
         return q, None
@@ -63,6 +61,7 @@ def test_dqn(args=get_args()):
     env = gym.make(args.task)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
+    # train_envs = gym.make(args.task)
     train_envs = SubprocVectorEnv(
         [lambda: gym.make(args.task) for _ in range(args.training_num)],
         reset_after_done=True)
@@ -118,7 +117,7 @@ def test_dqn(args=get_args()):
                     'speed', result['speed'], global_step=global_step)
                 t.set_postfix(loss=f'{stat_loss.get():.6f}',
                               reward=f'{result["reward"]:.6f}',
-                              length=f'{result["length"]:.6f}',
+                              length=f'{result["length"]:.2f}',
                               speed=f'{result["speed"]:.2f}')
         # eval
         test_collector.reset_env()
@@ -131,9 +130,11 @@ def test_dqn(args=get_args()):
             best_epoch = epoch
         print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
              f'best_reward: {best_reward:.6f} in #{best_epoch}')
-        if args.task == 'CartPole-v0' and best_reward >= 200:
+        if best_reward >= env.spec.reward_threshold:
             break
-    assert best_reward >= 200
+    assert best_reward >= env.spec.reward_threshold
+    training_collector.close()
+    test_collector.close()
     if __name__ == '__main__':
         train_cnt = training_collector.collect_step
         test_cnt = test_collector.collect_step
@@ -143,18 +144,10 @@ def test_dqn(args=get_args()):
               f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
         # Let's watch its performance!
         env = gym.make(args.task)
-        obs = env.reset()
-        done = False
-        total = 0
-        while not done:
-            q, _ = net([obs])
-            action = q.max(dim=1)[1]
-            obs, rew, done, info = env.step(action[0].detach().cpu().numpy())
-            total += rew
-            env.render()
-            time.sleep(1 / 35)
-        env.close()
-        print(f'Final test: {total}')
+        test_collector = Collector(policy, env, ReplayBuffer(1))
+        result = test_collector.collect(n_episode=1, render=1 / 35)
+        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
+        test_collector.close()


 if __name__ == '__main__':
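For context on the changed stopping criterion: the test now reads the target score from the environment spec instead of hard-coding 200. A minimal sketch of what that value is for the default task (195.0 is the threshold gym registers for CartPole-v0):

import gym

env = gym.make('CartPole-v0')
# gym registers CartPole-v0 with reward_threshold=195.0, so checking
# best_reward >= env.spec.reward_threshold is slightly looser than the
# old hard-coded best_reward >= 200.
print(env.spec.reward_threshold)  # 195.0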
test/test_pg.py (new file, 207 lines)
@@ -0,0 +1,207 @@
import gym
import time
import tqdm
import torch
import argparse
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import PGPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils import tqdm_config, MovAvg
from tianshou.data import Batch, Collector, ReplayBuffer


def compute_return_base(batch, aa=None, bb=None, gamma=0.1):
    returns = np.zeros_like(batch.rew)
    last = 0
    for i in reversed(range(len(batch.rew))):
        returns[i] = batch.rew[i]
        if not batch.done[i]:
            returns[i] += last * gamma
        last = returns[i]
    batch.update(returns=returns)
    return batch


def test_fn(size=2560):
    policy = PGPolicy(
        None, None, None, discount_factor=0.1, normalized_reward=False)
    fn = policy.process_fn
    # fn = compute_return_base
    batch = Batch(
        done=np.array([1, 0, 0, 1, 0, 1, 0, 1.]),
        rew=np.array([0, 1, 2, 3, 4, 5, 6, 7.]),
    )
    batch = fn(batch, None, None)
    ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
    ans -= ans.mean()
    assert abs(batch.returns - ans).sum() <= 1e-5
    batch = Batch(
        done=np.array([0, 1, 0, 1, 0, 1, 0.]),
        rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
    )
    batch = fn(batch, None, None)
    ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5])
    ans -= ans.mean()
    assert abs(batch.returns - ans).sum() <= 1e-5
    batch = Batch(
        done=np.array([0, 1, 0, 1, 0, 0, 1.]),
        rew=np.array([7, 6, 1, 2, 3, 4, 5.]),
    )
    batch = fn(batch, None, None)
    ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5])
    ans -= ans.mean()
    assert abs(batch.returns - ans).sum() <= 1e-5
    if __name__ == '__main__':
        batch = Batch(
            done=np.random.randint(100, size=size) == 0,
            rew=np.random.random(size),
        )
        cnt = 3000
        t = time.time()
        for _ in range(cnt):
            compute_return_base(batch)
        print(f'vanilla: {(time.time() - t) / cnt}')
        t = time.time()
        for _ in range(cnt):
            policy.process_fn(batch, None, None)
        print(f'policy: {(time.time() - t) / cnt}')


class Net(nn.Module):
    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
        super().__init__()
        self.device = device
        self.model = [
            nn.Linear(np.prod(state_shape), 128),
            nn.ReLU(inplace=True)]
        for i in range(layer_num):
            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
        self.model += [nn.Linear(128, np.prod(action_shape))]
        self.model = nn.Sequential(*self.model)

    def forward(self, s, **kwargs):
        s = torch.tensor(s, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        logits = self.model(s.view(batch, -1))
        return logits, None


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=320)
    parser.add_argument('--collect-per-step', type=int, default=5)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--layer-num', type=int, default=3)
    parser.add_argument('--training-num', type=int, default=8)
    parser.add_argument('--test-num', type=int, default=20)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def test_pg(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
    net = net.to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    dist = torch.distributions.Categorical
    policy = PGPolicy(net, optim, dist, args.gamma)
    # collector
    training_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(
        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
    # log
    stat_loss = MovAvg()
    global_step = 0
    writer = SummaryWriter(args.logdir)
    best_epoch = -1
    best_reward = -1e10
    start_time = time.time()
    for epoch in range(1, 1 + args.epoch):
        desc = f"Epoch #{epoch}"
        # train
        policy.train()
        with tqdm.tqdm(
                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
            while t.n < t.total:
                result = training_collector.collect(
                    n_episode=args.collect_per_step)
                losses = policy.learn(
                    training_collector.sample(0), args.batch_size)
                global_step += len(losses)
                t.update(len(losses))
                stat_loss.add(losses)
                writer.add_scalar(
                    'reward', result['reward'], global_step=global_step)
                writer.add_scalar(
                    'length', result['length'], global_step=global_step)
                writer.add_scalar(
                    'loss', stat_loss.get(), global_step=global_step)
                writer.add_scalar(
                    'speed', result['speed'], global_step=global_step)
                t.set_postfix(loss=f'{stat_loss.get():.6f}',
                              reward=f'{result["reward"]:.6f}',
                              length=f'{result["length"]:.2f}',
                              speed=f'{result["speed"]:.2f}')
        # eval
        test_collector.reset_env()
        test_collector.reset_buffer()
        policy.eval()
        result = test_collector.collect(n_episode=args.test_num)
        if best_reward < result['reward']:
            best_reward = result['reward']
            best_epoch = epoch
        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
              f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if best_reward >= env.spec.reward_threshold:
            break
    assert best_reward >= env.spec.reward_threshold
    training_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        train_cnt = training_collector.collect_step
        test_cnt = test_collector.collect_step
        duration = time.time() - start_time
        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
              f'in {duration:.2f}s, '
              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
        # Let's watch its performance!
        env = gym.make(args.task)
        test_collector = Collector(policy, env, ReplayBuffer(1))
        result = test_collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
        test_collector.close()


if __name__ == '__main__':
    # test_fn()
    test_pg()
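The expected arrays in test_fn come straight from the discounted-return recursion with gamma = 0.1, before the mean subtraction that process_fn applies. A short worked sketch of the first test case:

import numpy as np

# rew = [0, 1, 2, 3, 4, 5, 6, 7], done = [1, 0, 0, 1, 0, 1, 0, 1], gamma = 0.1
# Walking backwards, G[i] = rew[i] + gamma * G[i + 1] unless done[i] ends the episode:
#   G[7] = 7
#   G[6] = 6 + 0.1 * 7   = 6.7
#   G[5] = 5
#   G[4] = 4 + 0.1 * 5   = 4.5
#   G[3] = 3
#   G[2] = 2 + 0.1 * 3   = 2.3
#   G[1] = 1 + 0.1 * 2.3 = 1.23
#   G[0] = 0
ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7])
ans -= ans.mean()  # process_fn also subtracts the mean before the assert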
@@ -38,3 +38,14 @@ class Batch(object):
                 raise TypeError(
                     'No support for append with type {} in class Batch.'
                     .format(type(batch.__dict__[k])))
+
+    def split(self, size=None):
+        length = min([
+            len(self.__dict__[k]) for k in self.__dict__.keys()
+            if self.__dict__[k] is not None])
+        if size is None:
+            size = length
+        temp = 0
+        while temp < length:
+            yield self[temp:temp + size]
+            temp += size
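A small usage sketch of the new Batch.split generator, mirroring the assertion added in test_batch: it yields consecutive sub-batches of at most size items, and size=None returns the whole batch in one piece.

import numpy as np
from tianshou.data import Batch

batch = Batch(obs=np.arange(5))
# split(1) yields five one-item sub-batches; split(2) would yield
# chunks of 2, 2 and 1.
for i, b in enumerate(batch.split(1)):
    assert b.obs == batch[i].obs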
@@ -59,7 +59,10 @@ class ReplayBuffer(object):
         if batch_size > 0:
             indice = np.random.choice(self._size, batch_size)
         else:
-            indice = np.arange(self._size)
+            indice = np.concatenate([
+                np.arange(self._index, self._size),
+                np.arange(0, self._index),
+            ])
         return Batch(
             obs=self.obs[indice],
             act=self.act[indice],
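The new index order matters once the ring buffer has wrapped: reading [_index, _size) followed by [0, _index) returns the stored frames in chronological order instead of interleaving new and old data. A standalone sketch of the indexing (values are illustrative):

import numpy as np

_size, _index = 5, 2  # capacity 5, next write position 2 -> slot 2 is oldest
indice = np.concatenate([
    np.arange(_index, _size),
    np.arange(0, _index),
])
print(indice)  # [2 3 4 0 1]: oldest ... newest, unlike np.arange(_size)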
tianshou/env/wrapper.py (vendored, 18 changed lines)
@@ -121,8 +121,10 @@ class VectorEnv(BaseVectorEnv):
             np.stack(self._done), np.stack(self._info)

     def seed(self, seed=None):
-        if np.isscalar(seed) or seed is None:
-            seed = [seed for _ in range(self.env_num)]
+        if np.isscalar(seed):
+            seed = [seed + _ for _ in range(self.env_num)]
+        elif seed is None:
+            seed = [seed] * self.env_num
         for e, s in zip(self.envs, seed):
             if hasattr(e, 'seed'):
                 e.seed(s)
@@ -198,8 +200,10 @@ class SubprocVectorEnv(BaseVectorEnv):
         return np.stack([p.recv() for p in self.parent_remote])

     def seed(self, seed=None):
-        if np.isscalar(seed) or seed is None:
-            seed = [seed for _ in range(self.env_num)]
+        if np.isscalar(seed):
+            seed = [seed + _ for _ in range(self.env_num)]
+        elif seed is None:
+            seed = [seed] * self.env_num
         for p, s in zip(self.parent_remote, seed):
             p.send(['seed', s])
         for p in self.parent_remote:
@@ -272,8 +276,10 @@ class RayVectorEnv(BaseVectorEnv):
     def seed(self, seed=None):
         if not hasattr(self.envs[0], 'seed'):
             return
-        if np.isscalar(seed) or seed is None:
-            seed = [seed for _ in range(self.env_num)]
+        if np.isscalar(seed):
+            seed = [seed + _ for _ in range(self.env_num)]
+        elif seed is None:
+            seed = [seed] * self.env_num
         result_obj = [e.seed.remote(s) for e, s in zip(self.envs, seed)]
         for r in result_obj:
             ray.get(r)
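A small sketch of the new seeding behaviour shared by all three vector envs: a scalar seed is offset per worker so the parallel environments do not all replay identical episodes, while None is still broadcast unchanged.

import numpy as np

seed, env_num = 1626, 4  # illustrative values
if np.isscalar(seed):
    seed = [seed + i for i in range(env_num)]
elif seed is None:
    seed = [seed] * env_num
print(seed)  # [1626, 1627, 1628, 1629]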
@@ -1,7 +1,9 @@
 from tianshou.policy.base import BasePolicy
 from tianshou.policy.dqn import DQNPolicy
+from tianshou.policy.policy_gradient import PGPolicy

 __all__ = [
     'BasePolicy',
     'DQNPolicy',
+    'PGPolicy',
 ]
@@ -8,17 +8,17 @@ class BasePolicy(ABC):
         super().__init__()
         self.model = None

-    @abstractmethod
-    def __call__(self, batch, hidden_state=None):
-        # return Batch(act=np.array(), state=None, ...)
-        pass
-
-    @abstractmethod
-    def learn(self, batch):
-        pass
-
     def process_fn(self, batch, buffer, indice):
         return batch

+    @abstractmethod
+    def __call__(self, batch, state=None):
+        # return Batch(logits=..., act=np.array(), state=None, ...)
+        pass
+
+    @abstractmethod
+    def learn(self, batch, batch_size=None):
+        pass
+
     def sync_weight(self):
         pass
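With the reordered interface, a concrete policy only has to supply __call__(batch, state=None) and learn(batch, batch_size=None). A minimal illustrative subclass (RandomPolicy and its uniform action choice are assumptions for this sketch, not part of the commit):

import numpy as np

from tianshou.data import Batch
from tianshou.policy import BasePolicy


class RandomPolicy(BasePolicy):
    """Illustrative only: samples uniformly among n discrete actions."""

    def __init__(self, n):
        super().__init__()
        self.n = n

    def __call__(self, batch, state=None):
        act = np.random.randint(self.n, size=len(batch.obs))
        return Batch(logits=None, act=act, state=None)

    def learn(self, batch, batch_size=None):
        return []  # nothing to optimize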
@@ -10,14 +10,14 @@ from tianshou.policy import BasePolicy
 class DQNPolicy(BasePolicy, nn.Module):
     """docstring for DQNPolicy"""

-    def __init__(self, model, optim, loss,
+    def __init__(self, model, optim, loss_fn,
                  discount_factor=0.99,
                  estimation_step=1,
                  use_target_network=True):
         super().__init__()
         self.model = model
         self.optim = optim
-        self.loss = loss
+        self.loss_fn = loss_fn
         self.eps = 0
         assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
         self._gamma = discount_factor
@@ -28,20 +28,6 @@ class DQNPolicy(BasePolicy, nn.Module):
             self.model_old = deepcopy(self.model)
             self.model_old.eval()

-    def __call__(self, batch, hidden_state=None,
-                 model='model', input='obs', eps=None):
-        model = getattr(self, model)
-        obs = getattr(batch, input)
-        q, h = model(obs, hidden_state=hidden_state, info=batch.info)
-        act = q.max(dim=1)[1].detach().cpu().numpy()
-        # add eps to act
-        if eps is None:
-            eps = self.eps
-        for i in range(len(q)):
-            if np.random.rand() < eps:
-                act[i] = np.random.randint(q.shape[1])
-        return Batch(Q=q, act=act, state=h)
-
     def set_eps(self, eps):
         self.eps = eps

@@ -70,12 +56,12 @@ class DQNPolicy(BasePolicy, nn.Module):
             # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
             a = self(buffer[terminal], input='obs_next', eps=0).act
             target_q = self(
-                buffer[terminal], model='model_old', input='obs_next').Q
+                buffer[terminal], model='model_old', input='obs_next').logits
             if isinstance(target_q, torch.Tensor):
                 target_q = target_q.detach().cpu().numpy()
             target_q = target_q[np.arange(len(a)), a]
         else:
-            target_q = self(buffer[terminal], input='obs_next').Q
+            target_q = self(buffer[terminal], input='obs_next').logits
             if isinstance(target_q, torch.Tensor):
                 target_q = target_q.detach().cpu().numpy()
             target_q = target_q.max(axis=1)
@@ -84,14 +70,28 @@ class DQNPolicy(BasePolicy, nn.Module):
         batch.update(returns=returns)
         return batch

-    def learn(self, batch):
+    def __call__(self, batch, state=None,
+                 model='model', input='obs', eps=None):
+        model = getattr(self, model)
+        obs = getattr(batch, input)
+        q, h = model(obs, state=state, info=batch.info)
+        act = q.max(dim=1)[1].detach().cpu().numpy()
+        # add eps to act
+        if eps is None:
+            eps = self.eps
+        for i in range(len(q)):
+            if np.random.rand() < eps:
+                act[i] = np.random.randint(q.shape[1])
+        return Batch(logits=q, act=act, state=h)
+
+    def learn(self, batch, batch_size=None):
         self.optim.zero_grad()
-        q = self(batch).Q
+        q = self(batch).logits
         q = q[np.arange(len(q)), batch.act]
         r = batch.returns
         if isinstance(r, np.ndarray):
             r = torch.tensor(r, device=q.device, dtype=q.dtype)
-        loss = self.loss(q, r)
+        loss = self.loss_fn(q, r)
         loss.backward()
         self.optim.step()
         return loss.detach().cpu().numpy()
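A short construction sketch under the renamed DQN interface (TinyNet and the hyperparameters are assumptions for illustration): the loss argument is now loss_fn, the result of a forward pass exposes Q-values as .logits, and learn accepts an (unused) batch_size.

import torch
import torch.nn.functional as F
from torch import nn

from tianshou.policy import DQNPolicy


class TinyNet(nn.Module):
    # Follows the convention of the test networks:
    # forward(obs, **kwargs) -> (q_values, state)
    def __init__(self, state_dim=4, action_dim=2):
        super().__init__()
        self.fc = nn.Linear(state_dim, action_dim)

    def forward(self, s, **kwargs):
        s = torch.tensor(s, dtype=torch.float)
        return self.fc(s), None


net = TinyNet()
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
policy = DQNPolicy(net, optim, loss_fn=F.mse_loss, discount_factor=0.9)
policy.set_eps(0.1)  # epsilon-greedy exploration is applied inside __call__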
tianshou/policy/policy_gradient.py (new file, 76 lines)
@@ -0,0 +1,76 @@
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from torch.distributions import Categorical

from tianshou.data import Batch
from tianshou.policy import BasePolicy


class PGPolicy(BasePolicy, nn.Module):
    """docstring for PGPolicy"""

    def __init__(self, model, optim, dist=Categorical,
                 discount_factor=0.99, normalized_reward=True):
        super().__init__()
        self.model = model
        self.optim = optim
        self.dist = dist
        self._eps = np.finfo(np.float32).eps.item()
        assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
        self._gamma = discount_factor
        self._rew_norm = normalized_reward

    def process_fn(self, batch, buffer, indice):
        batch_size = len(batch.rew)
        returns = self._vanilla_returns(batch, batch_size)
        # returns = self._vectorized_returns(batch, batch_size)
        returns = returns - returns.mean()
        if self._rew_norm:
            returns = returns / (returns.std() + self._eps)
        batch.update(returns=returns)
        return batch

    def __call__(self, batch, state=None):
        logits, h = self.model(batch.obs, state=state, info=batch.info)
        logits = F.softmax(logits, dim=1)
        dist = self.dist(logits)
        act = dist.sample().detach().cpu().numpy()
        return Batch(logits=logits, act=act, state=h, dist=dist)

    def learn(self, batch, batch_size=None):
        losses = []
        for b in batch.split(batch_size):
            self.optim.zero_grad()
            dist = self(b).dist
            a = torch.tensor(b.act, device=dist.logits.device)
            r = torch.tensor(b.returns, device=dist.logits.device)
            loss = -(dist.log_prob(a) * r).sum()
            loss.backward()
            self.optim.step()
            losses.append(loss.detach().cpu().numpy())
        return losses

    def _vanilla_returns(self, batch, batch_size):
        returns = batch.rew[:]
        last = 0
        for i in range(batch_size - 1, -1, -1):
            if not batch.done[i]:
                returns[i] += self._gamma * last
            last = returns[i]
        return returns

    def _vectorized_returns(self, batch, batch_size):
        # according to my tests, it is slower than vanilla
        # import scipy.signal
        convolve = np.convolve
        # convolve = scipy.signal.convolve
        rew = batch.rew[::-1]
        gammas = self._gamma ** np.arange(batch_size)
        c = convolve(rew, gammas)[:batch_size]
        T = np.where(batch.done[::-1])[0]
        d = np.zeros_like(rew)
        d[T] += c[T] - rew[T]
        d[T[1:]] -= d[T[:-1]] * self._gamma ** np.diff(T)
        return (c - convolve(d, gammas)[:batch_size])[::-1]
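PGPolicy.learn is plain REINFORCE: it minimizes the negative sum of log pi(a|s) weighted by the processed returns. A self-contained sketch of that objective on dummy values (theta stands in for the network output and is not part of this commit):

import torch

theta = torch.zeros(3, 2, requires_grad=True)  # 3 transitions, 2 actions
dist = torch.distributions.Categorical(logits=theta)
act = dist.sample()
returns = torch.tensor([0.5, -0.2, 1.1])  # stand-in for process_fn output
loss = -(dist.log_prob(act) * returns).sum()
loss.backward()
print(theta.grad)  # gradient of the REINFORCE objective w.r.t. the logits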
@@ -11,7 +11,11 @@ class MovAvg(object):
     def add(self, x):
         if isinstance(x, torch.Tensor):
             x = x.detach().cpu().numpy()
-        if x != np.inf:
+        if isinstance(x, list):
+            for _ in x:
+                if _ != np.inf:
+                    self.cache.append(_)
+        elif x != np.inf:
             self.cache.append(x)
         if self.size > 0 and len(self.cache) > self.size:
             self.cache = self.cache[-self.size:]
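MovAvg.add now also accepts a list of losses, which is what PGPolicy.learn returns, appending the finite entries one by one. A small usage sketch (assuming the default window is large enough to keep all values):

from tianshou.utils import MovAvg

stat = MovAvg()
stat.add(0.5)         # scalars still work
stat.add([0.4, 0.3])  # lists are appended element by element
print(stat.get())     # mean of the cached values -> 0.4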