Trinkle23897 2020-03-18 21:45:41 +08:00
parent 6e563fe61a
commit 64bab0b6a0
18 changed files with 396 additions and 50 deletions

View File

@ -33,8 +33,8 @@ jobs:
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=20 --max-line-length=79 --statistics
flake8 . --count --exit-zero --max-complexity=30 --max-line-length=79 --statistics
- name: Test with pytest
run: |
pip install pytest pytest-cov
pytest --cov tianshou
pytest --cov tianshou -s

View File

@ -37,12 +37,11 @@ setup(
'examples', 'examples.*',
'docs', 'docs.*']),
install_requires=[
# 'ray',
'gym',
'tqdm',
'numpy',
'cloudpickle',
'tensorboard',
'torch>=1.2.0', # for supporting tensorboard
'torch>=1.4.0',
],
)

View File

@ -41,14 +41,14 @@ def get_args():
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--gamma', type=float, default=0.9)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=320)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--layer-num', type=int, default=2)
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--training-num', type=int, default=32)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument(
@ -57,6 +57,7 @@ def get_args():
# a2c special
parser.add_argument('--vf-coef', type=float, default=0.5)
parser.add_argument('--entropy-coef', type=float, default=0.001)
parser.add_argument('--max-grad-norm', type=float, default=None)
args = parser.parse_known_args()[0]
return args
@ -86,12 +87,12 @@ def test_a2c(args=get_args()):
policy = A2CPolicy(
net, optim, dist, args.gamma,
vf_coef=args.vf_coef,
entropy_coef=args.entropy_coef)
entropy_coef=args.entropy_coef,
max_grad_norm=args.max_grad_norm)
# collector
training_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(
policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
test_collector = Collector(policy, test_envs, stat_size=args.test_num)
# log
stat_loss = MovAvg()
global_step = 0
@ -126,6 +127,8 @@ def test_a2c(args=get_args()):
reward=f'{result["reward"]:.6f}',
length=f'{result["length"]:.2f}',
speed=f'{result["speed"]:.2f}')
if t.n <= t.total:
t.update()
# eval
test_collector.reset_env()
test_collector.reset_buffer()

test/test_ddpg.py (new file, 206 lines added)
View File

@ -0,0 +1,206 @@
import gym
import time
import tqdm
import torch
import argparse
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import DDPGPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils import tqdm_config, MovAvg
from tianshou.data import Collector, ReplayBuffer
class Actor(nn.Module):
def __init__(self, layer_num, state_shape, action_shape,
max_action, device='cpu'):
super().__init__()
self.device = device
self.model = [
nn.Linear(np.prod(state_shape), 128),
nn.ReLU(inplace=True)]
for i in range(layer_num):
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
self.model += [nn.Linear(128, np.prod(action_shape))]
self.model = nn.Sequential(*self.model)
self._max = max_action
def forward(self, s, **kwargs):
s = torch.tensor(s, device=self.device, dtype=torch.float)
batch = s.shape[0]
s = s.view(batch, -1)
logits = self.model(s)
logits = self._max * torch.tanh(logits)
return logits, None
class Critic(nn.Module):
def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
super().__init__()
self.device = device
self.model = [
nn.Linear(np.prod(state_shape) + np.prod(action_shape), 128),
nn.ReLU(inplace=True)]
for i in range(layer_num):
self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
self.model += [nn.Linear(128, 1)]
self.model = nn.Sequential(*self.model)
def forward(self, s, a):
s = torch.tensor(s, device=self.device, dtype=torch.float)
if isinstance(a, np.ndarray):
a = torch.tensor(a, device=self.device, dtype=torch.float)
batch = s.shape[0]
s = s.view(batch, -1)
a = a.view(batch, -1)
logits = self.model(torch.cat([s, a], dim=1))
return logits
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Pendulum-v0')
parser.add_argument('--seed', type=int, default=1626)
parser.add_argument('--buffer-size', type=int, default=20000)
parser.add_argument('--actor-lr', type=float, default=1e-4)
parser.add_argument('--actor-wd', type=float, default=0)
parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument('--critic-wd', type=float, default=1e-2)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--tau', type=float, default=0.005)
parser.add_argument('--exploration-noise', type=float, default=0.1)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=2400)
parser.add_argument('--collect-per-step', type=int, default=1)
parser.add_argument('--batch-size', type=int, default=128)
parser.add_argument('--layer-num', type=int, default=1)
parser.add_argument('--training-num', type=int, default=1)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
args = parser.parse_known_args()[0]
return args
def test_ddpg(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
reset_after_done=True)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
reset_after_done=False)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
actor = Actor(
args.layer_num, args.state_shape, args.action_shape,
args.max_action, args.device
).to(args.device)
actor_optim = torch.optim.Adam(
actor.parameters(), lr=args.actor_lr, weight_decay=args.actor_wd)
critic = Critic(
args.layer_num, args.state_shape, args.action_shape, args.device
).to(args.device)
critic_optim = torch.optim.Adam(
critic.parameters(), lr=args.critic_lr, weight_decay=args.critic_wd)
policy = DDPGPolicy(
actor, actor_optim, critic, critic_optim,
[env.action_space.low[0], env.action_space.high[0]],
args.tau, args.gamma, args.exploration_noise)
# collector
training_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size), 1)
test_collector = Collector(policy, test_envs, stat_size=args.test_num)
# log
stat_a_loss = MovAvg()
stat_c_loss = MovAvg()
global_step = 0
writer = SummaryWriter(args.logdir)
best_epoch = -1
best_reward = -1e10
start_time = time.time()
# training_collector.collect(n_step=1000)
for epoch in range(1, 1 + args.epoch):
desc = f'Epoch #{epoch}'
# train
policy.train()
with tqdm.tqdm(
total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
while t.n < t.total:
result = training_collector.collect(
n_step=args.collect_per_step)
for i in range(min(
result['n_step'] // args.collect_per_step,
t.total - t.n)):
t.update(1)
global_step += 1
actor_loss, critic_loss = policy.learn(
training_collector.sample(args.batch_size))
policy.sync_weight()
stat_a_loss.add(actor_loss)
stat_c_loss.add(critic_loss)
writer.add_scalar(
'reward', result['reward'], global_step=global_step)
writer.add_scalar(
'length', result['length'], global_step=global_step)
writer.add_scalar(
'actor_loss', stat_a_loss.get(),
global_step=global_step)
writer.add_scalar(
'critic_loss', stat_c_loss.get(),
global_step=global_step)
writer.add_scalar(
'speed', result['speed'], global_step=global_step)
t.set_postfix(actor_loss=f'{stat_a_loss.get():.6f}',
critic_loss=f'{stat_c_loss.get():.6f}',
reward=f'{result["reward"]:.6f}',
length=f'{result["length"]:.2f}',
speed=f'{result["speed"]:.2f}')
if t.n <= t.total:
t.update()
# eval
test_collector.reset_env()
test_collector.reset_buffer()
policy.eval()
result = test_collector.collect(n_episode=args.test_num)
if best_reward < result['reward']:
best_reward = result['reward']
best_epoch = epoch
print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
f'best_reward: {best_reward:.6f} in #{best_epoch}')
if args.task == 'Pendulum-v0' and best_reward >= -250:
break
if args.task == 'Pendulum-v0':
assert best_reward >= -250
training_collector.close()
test_collector.close()
if __name__ == '__main__':
train_cnt = training_collector.collect_step
test_cnt = test_collector.collect_step
duration = time.time() - start_time
print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
f'in {duration:.2f}s, '
f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
# Let's watch its performance!
env = gym.make(args.task)
test_collector = Collector(policy, env)
result = test_collector.collect(n_episode=1, render=1 / 35)
print(f'Final reward: {result["reward"]}, length: {result["length"]}')
test_collector.close()
if __name__ == '__main__':
test_ddpg()

View File

@ -78,13 +78,11 @@ def test_dqn(args=get_args()):
net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
net = net.to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
loss = nn.MSELoss()
policy = DQNPolicy(net, optim, loss, args.gamma, args.n_step)
policy = DQNPolicy(net, optim, args.gamma, args.n_step)
# collector
training_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(
policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
test_collector = Collector(policy, test_envs, stat_size=args.test_num)
training_collector.collect(n_step=args.batch_size)
# log
stat_loss = MovAvg()
@ -124,6 +122,8 @@ def test_dqn(args=get_args()):
reward=f'{result["reward"]:.6f}',
length=f'{result["length"]:.2f}',
speed=f'{result["speed"]:.2f}')
if t.n <= t.total:
t.update()
# eval
test_collector.reset_env()
test_collector.reset_buffer()

View File

@ -133,8 +133,7 @@ def test_pg(args=get_args()):
# collector
training_collector = Collector(
policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(
policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
test_collector = Collector(policy, test_envs, stat_size=args.test_num)
# log
stat_loss = MovAvg()
global_step = 0
@ -169,6 +168,8 @@ def test_pg(args=get_args()):
reward=f'{result["reward"]:.6f}',
length=f'{result["length"]:.2f}',
speed=f'{result["speed"]:.2f}')
if t.n <= t.total:
t.update()
# eval
test_collector.reset_env()
test_collector.reset_buffer()

View File

@ -1,9 +1,10 @@
from tianshou import data, env, utils, policy
from tianshou import data, env, utils, policy, exploration
__version__ = '0.2.0'
__all__ = [
'env',
'data',
'utils',
'policy'
'policy',
'exploration',
]

View File

@ -10,6 +10,10 @@ class ReplayBuffer(object):
self._maxsize = size
self.reset()
def __del__(self):
for k in list(self.__dict__.keys()):
del self.__dict__[k]
def __len__(self):
return self._size
@ -24,6 +28,9 @@ class ReplayBuffer(object):
[{} for _ in range(self._maxsize)])
else: # assume `inst` is a number
self.__dict__[name] = np.zeros([self._maxsize])
if isinstance(inst, np.ndarray) and \
self.__dict__[name].shape[1:] != inst.shape:
self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
self.__dict__[name][self._index] = inst
def update(self, buffer):
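The ReplayBuffer hunk above adds a shape check so that array-valued entries are stored with a per-slot shape matching the sample, instead of the default flat zeros array. A standalone sketch of that allocation logic, with toy variable names rather than the buffer's internals:
import numpy as np
maxsize = 4
storage = np.zeros([maxsize])            # default flat allocation
inst = np.array([0.1, 0.2, 0.3])         # e.g. a 3-dimensional observation
if isinstance(inst, np.ndarray) and storage.shape[1:] != inst.shape:
    storage = np.zeros([maxsize, *inst.shape])   # re-allocate with per-slot shape
storage[0] = inst                        # now stores without a broadcasting error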

View File

@ -20,7 +20,7 @@ class Collector(object):
self.policy = policy
self.process_fn = policy.process_fn
self._multi_env = isinstance(env, BaseVectorEnv)
self._multi_buf = False # buf is a list
self._multi_buf = False # True if buf is a list
# need multiple cache buffers only if storing in one buffer
self._cached_buf = []
if self._multi_env:
@ -65,9 +65,9 @@ class Collector(object):
if hasattr(self.env, 'seed'):
self.env.seed(seed)
def render(self):
def render(self, **kwargs):
if hasattr(self.env, 'render'):
self.env.render()
self.env.render(**kwargs)
def close(self):
if hasattr(self.env, 'close'):
@ -101,7 +101,10 @@ class Collector(object):
info=self._make_batch(self._info))
result = self.policy(batch_data, self.state)
self.state = result.state if hasattr(result, 'state') else None
self._act = result.act
if isinstance(result.act, torch.Tensor):
self._act = result.act.detach().cpu().numpy()
else:
self._act = np.array(result.act)
obs_next, self._rew, self._done, self._info = self.env.step(
self._act if self._multi_env else self._act[0])
if render > 0:
@ -141,7 +144,10 @@ class Collector(object):
if isinstance(self.state, list):
self.state[i] = None
elif self.state is not None:
self.state[i] = self.state[i] * 0
if isinstance(self.state[i], dict):
self.state[i] = {}
else:
self.state[i] = self.state[i] * 0
if isinstance(self.state, torch.Tensor):
# remove ref count in pytorch (?)
self.state = self.state.detach()
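Besides forwarding render kwargs, the Collector now converts whatever action the policy returns into a numpy array before stepping the vectorized environment, since PGPolicy and A2CPolicy below now return raw torch samples. A minimal sketch of that conversion with a toy tensor:
import numpy as np
import torch
act = torch.tensor([1, 0, 2])            # e.g. dist.sample() from the policy
if isinstance(act, torch.Tensor):
    act = act.detach().cpu().numpy()     # drop the autograd graph, move to CPU
else:
    act = np.array(act)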

View File

@ -24,9 +24,9 @@ class EnvWrapper(object):
if hasattr(self.env, 'seed'):
self.env.seed(seed)
def render(self):
def render(self, **kwargs):
if hasattr(self.env, 'render'):
self.env.render()
self.env.render(**kwargs)
def close(self):
self.env.close()
@ -83,7 +83,7 @@ class BaseVectorEnv(ABC):
pass
@abstractmethod
def render(self):
def render(self, **kwargs):
pass
@abstractmethod
@ -129,10 +129,10 @@ class VectorEnv(BaseVectorEnv):
if hasattr(e, 'seed'):
e.seed(s)
def render(self):
def render(self, **kwargs):
for e in self.envs:
if hasattr(e, 'render'):
e.render()
e.render(**kwargs)
def close(self):
for e in self.envs:
@ -160,7 +160,7 @@ def worker(parent, p, env_fn_wrapper, reset_after_done):
p.close()
break
elif cmd == 'render':
p.send(env.render() if hasattr(env, 'render') else None)
p.send(env.render(**data) if hasattr(env, 'render') else None)
elif cmd == 'seed':
p.send(env.seed(data) if hasattr(env, 'seed') else None)
else:
@ -213,9 +213,9 @@ class SubprocVectorEnv(BaseVectorEnv):
for p in self.parent_remote:
p.recv()
def render(self):
def render(self, **kwargs):
for p in self.parent_remote:
p.send(['render', None])
p.send(['render', kwargs])
for p in self.parent_remote:
p.recv()
@ -239,7 +239,7 @@ class RayVectorEnv(BaseVectorEnv):
ray.init()
except NameError:
raise ImportError(
'Please install ray to support VectorEnv: pip3 install ray -U')
'Please install ray to support RayVectorEnv: pip3 install ray')
self.envs = [
ray.remote(EnvWrapper).options(num_cpus=0).remote(e())
for e in env_fns]
@ -288,10 +288,10 @@ class RayVectorEnv(BaseVectorEnv):
for r in result_obj:
ray.get(r)
def render(self):
def render(self, **kwargs):
if not hasattr(self.envs[0], 'render'):
return
result_obj = [e.render.remote() for e in self.envs]
result_obj = [e.render.remote(**kwargs) for e in self.envs]
for r in result_obj:
ray.get(r)

View File

@ -0,0 +1,5 @@
from tianshou.exploration.random import OUNoise
__all__ = [
'OUNoise',
]

View File

@ -0,0 +1,21 @@
import numpy as np
class OUNoise(object):
"""docstring for OUNoise"""
def __init__(self, sigma=0.3, theta=0.15, dt=1e-2, x0=None):
self.alpha = theta * dt
self.beta = sigma * np.sqrt(dt)
self.x0 = x0
self.reset()
def __call__(self, size, mu=.1):
if self.x is None or self.x.shape != size:
self.x = 0
self.x = self.x + self.alpha * (mu - self.x) + \
self.beta * np.random.normal(size=size)
return self.x
def reset(self):
self.x = None
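A short usage sketch for the OUNoise helper above (it assumes the tianshou package from this commit is importable). Each call advances the discretized Ornstein-Uhlenbeck recursion x <- x + theta*dt*(mu - x) + sigma*sqrt(dt)*N(0, 1), so successive samples are temporally correlated:
from tianshou.exploration import OUNoise
noise = OUNoise(sigma=0.3, theta=0.15, dt=1e-2)
for _ in range(3):
    print(noise((2,), mu=0.))            # one correlated sample per action dimension
noise.reset()                            # start a fresh correlated sequence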

View File

@ -1,11 +1,13 @@
from tianshou.policy.base import BasePolicy
from tianshou.policy.dqn import DQNPolicy
from tianshou.policy.policy_gradient import PGPolicy
from tianshou.policy.pg import PGPolicy
from tianshou.policy.a2c import A2CPolicy
from tianshou.policy.ddpg import DDPGPolicy
__all__ = [
'BasePolicy',
'DQNPolicy',
'PGPolicy',
'A2CPolicy',
'DDPGPolicy'
]

View File

@ -1,4 +1,5 @@
import torch
from torch import nn
import torch.nn.functional as F
from tianshou.data import Batch
@ -9,16 +10,18 @@ class A2CPolicy(PGPolicy):
"""docstring for A2CPolicy"""
def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
discount_factor=0.99, vf_coef=.5, entropy_coef=.01):
discount_factor=0.99, vf_coef=.5, entropy_coef=.01,
max_grad_norm=None):
super().__init__(model, optim, dist_fn, discount_factor)
self._w_value = vf_coef
self._w_entropy = entropy_coef
self._grad_norm = max_grad_norm
def __call__(self, batch, state=None):
logits, value, h = self.model(batch.obs, state=state, info=batch.info)
logits = F.softmax(logits, dim=1)
dist = self.dist_fn(logits)
act = dist.sample().detach().cpu().numpy()
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist, value=value)
def learn(self, batch, batch_size=None):
@ -31,12 +34,15 @@ class A2CPolicy(PGPolicy):
a = torch.tensor(b.act, device=dist.logits.device)
r = torch.tensor(b.returns, device=dist.logits.device)
actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
critic_loss = (r - v).pow(2).mean()
critic_loss = F.mse_loss(r[:, None], v)
entropy_loss = dist.entropy().mean()
loss = actor_loss \
+ self._w_value * critic_loss \
- self._w_entropy * entropy_loss
loss.backward()
if self._grad_norm:
nn.utils.clip_grad_norm_(
self.model.parameters(), max_norm=self._grad_norm)
self.optim.step()
losses.append(loss.detach().cpu().numpy())
return losses
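The A2CPolicy.learn hunk above switches the critic to F.mse_loss and adds optional gradient clipping controlled by max_grad_norm. A self-contained sketch of the combined objective with stand-in networks and returns (toy tensors, not the library API):
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Categorical
vf_coef, ent_coef, max_grad_norm = 0.5, 0.001, 0.5
model = nn.Linear(4, 3)                            # stand-in policy trunk
probs = F.softmax(model(torch.randn(8, 4)), dim=1)
dist = Categorical(probs)
act = dist.sample()
v = torch.randn(8, 1, requires_grad=True)          # stand-in value predictions
r = torch.randn(8)                                 # stand-in returns
actor_loss = -(dist.log_prob(act) * (r - v.squeeze(-1)).detach()).mean()
critic_loss = F.mse_loss(r[:, None], v)
entropy_loss = dist.entropy().mean()
loss = actor_loss + vf_coef * critic_loss - ent_coef * entropy_loss
loss.backward()
if max_grad_norm:                                  # new in this commit: clip before stepping
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)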

View File

@ -1,19 +1,19 @@
from torch import nn
from abc import ABC, abstractmethod
class BasePolicy(ABC):
class BasePolicy(ABC, nn.Module):
"""docstring for BasePolicy"""
def __init__(self):
super().__init__()
self.model = None
def process_fn(self, batch, buffer, indice):
return batch
@abstractmethod
def __call__(self, batch, state=None):
# return Batch(logits=..., act=np.array(), state=None, ...)
# return Batch(logits=..., act=..., state=None, ...)
pass
@abstractmethod
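Making BasePolicy inherit from nn.Module gives every policy parameters(), state_dict() and train()/eval() for free, which is why DQNPolicy and PGPolicy drop their explicit nn.Module base further below. A toy illustration with a hypothetical subclass (not part of the commit):
import torch
from torch import nn
class ToyPolicy(nn.Module):                        # stand-in for a BasePolicy subclass
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(4, 2)
policy = ToyPolicy()
optim = torch.optim.Adam(policy.parameters(), lr=1e-3)  # parameters() comes from nn.Module
policy.eval()                                           # so do train()/eval() and state_dict()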

tianshou/policy/ddpg.py (new file, 91 lines added)
View File

@ -0,0 +1,91 @@
import torch
from copy import deepcopy
import torch.nn.functional as F
from tianshou.data import Batch
from tianshou.policy import BasePolicy
# from tianshou.exploration import OUNoise
class DDPGPolicy(BasePolicy):
"""docstring for DDPGPolicy"""
def __init__(self, actor, actor_optim,
critic, critic_optim, action_range,
tau=0.005, gamma=0.99, exploration_noise=0.1):
super().__init__()
self.actor = actor
self.actor_old = deepcopy(actor)
self.actor_old.load_state_dict(self.actor.state_dict())
self.actor_old.eval()
self.actor_optim = actor_optim
self.critic = critic
self.critic_old = deepcopy(critic)
self.critic_old.load_state_dict(self.critic.state_dict())
self.critic_old.eval()
self.critic_optim = critic_optim
assert 0 < tau <= 1, 'tau should be in (0, 1]'
self._tau = tau
assert 0 < gamma <= 1, 'gamma should be in (0, 1]'
self._gamma = gamma
assert 0 <= exploration_noise, 'exploration noise should be non-negative'
self._eps = exploration_noise
self._range = action_range
# self.noise = OUNoise()
def set_eps(self, eps):
self._eps = eps
def train(self):
self.training = True
self.actor.train()
self.critic.train()
def eval(self):
self.training = False
self.actor.eval()
self.critic.eval()
def sync_weight(self):
for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
for o, n in zip(
self.critic_old.parameters(), self.critic.parameters()):
o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
def process_fn(self, batch, buffer, indice):
return batch
def __call__(self, batch, state=None,
model='actor', input='obs', eps=None):
model = getattr(self, model)
obs = getattr(batch, input)
logits, h = model(obs, state=state, info=batch.info)
# noise = np.random.normal(0, self._eps, size=logits.shape)
logits += torch.randn(
size=logits.shape, device=logits.device) * self._eps
# noise = self.noise(logits.shape, self._eps)
# logits += torch.tensor(noise, device=logits.device)
logits = logits.clamp(self._range[0], self._range[1])
return Batch(act=logits, state=h)
def learn(self, batch, batch_size=None):
target_q = self.critic_old(
batch.obs_next, self.actor_old(batch.obs_next, state=None)[0])
dev = target_q.device
rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)
done = torch.tensor(batch.done, dtype=torch.float, device=dev)
target_q = rew[:, None] + ((
1. - done[:, None]) * self._gamma * target_q).detach()
current_q = self.critic(batch.obs, batch.act)
critic_loss = F.mse_loss(current_q, target_q)
self.critic_optim.zero_grad()
critic_loss.backward()
self.critic_optim.step()
actor_loss = -self.critic(
batch.obs, self.actor(batch.obs, state=None)[0]).mean()
self.actor_optim.zero_grad()
actor_loss.backward()
self.actor_optim.step()
return actor_loss.detach().cpu().numpy(),\
critic_loss.detach().cpu().numpy()
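Two pieces of the new DDPGPolicy are worth spelling out: sync_weight performs a Polyak soft update, theta_old <- (1 - tau) * theta_old + tau * theta, and learn regresses the critic onto the one-step TD target y = r + gamma * (1 - done) * Q_old(s', actor_old(s')). A minimal standalone sketch with toy tensors (not the library API):
import torch
from torch import nn
tau, gamma = 0.005, 0.99
online, target = nn.Linear(4, 1), nn.Linear(4, 1)
target.load_state_dict(online.state_dict())
# soft update: theta_target <- (1 - tau) * theta_target + tau * theta_online
with torch.no_grad():
    for t_p, o_p in zip(target.parameters(), online.parameters()):
        t_p.copy_(t_p * (1 - tau) + o_p * tau)
# TD target: y = r + gamma * (1 - done) * Q_target(s', mu_target(s'))
rew = torch.tensor([1.0, 0.5])
done = torch.tensor([0.0, 1.0])
q_next = torch.tensor([[2.0], [3.0]])              # stand-in for critic_old(s', actor_old(s'))
y = rew[:, None] + (1.0 - done[:, None]) * gamma * q_next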

View File

@ -1,25 +1,24 @@
import torch
import numpy as np
from torch import nn
from copy import deepcopy
import torch.nn.functional as F
from tianshou.data import Batch
from tianshou.policy import BasePolicy
class DQNPolicy(BasePolicy, nn.Module):
class DQNPolicy(BasePolicy):
"""docstring for DQNPolicy"""
def __init__(self, model, optim, loss_fn,
def __init__(self, model, optim,
discount_factor=0.99,
estimation_step=1,
use_target_network=True):
super().__init__()
self.model = model
self.optim = optim
self.loss_fn = loss_fn
self.eps = 0
assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
assert 0 < discount_factor <= 1, 'discount_factor should be in (0, 1]'
self._gamma = discount_factor
assert estimation_step > 0, 'estimation_step should greater than 0'
self._n_step = estimation_step
@ -91,7 +90,7 @@ class DQNPolicy(BasePolicy, nn.Module):
r = batch.returns
if isinstance(r, np.ndarray):
r = torch.tensor(r, device=q.device, dtype=q.dtype)
loss = self.loss_fn(q, r)
loss = F.mse_loss(q, r)
loss.backward()
self.optim.step()
return loss.detach().cpu().numpy()

View File

@ -1,13 +1,12 @@
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from tianshou.data import Batch
from tianshou.policy import BasePolicy
class PGPolicy(BasePolicy, nn.Module):
class PGPolicy(BasePolicy):
"""docstring for PGPolicy"""
def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
@ -17,7 +16,7 @@ class PGPolicy(BasePolicy, nn.Module):
self.optim = optim
self.dist_fn = dist_fn
self._eps = np.finfo(np.float32).eps.item()
assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
assert 0 < discount_factor <= 1, 'discount_factor should be in (0, 1]'
self._gamma = discount_factor
def process_fn(self, batch, buffer, indice):
@ -30,7 +29,7 @@ class PGPolicy(BasePolicy, nn.Module):
logits, h = self.model(batch.obs, state=state, info=batch.info)
logits = F.softmax(logits, dim=1)
dist = self.dist_fn(logits)
act = dist.sample().detach().cpu().numpy()
act = dist.sample()
return Batch(logits=logits, act=act, state=h, dist=dist)
def learn(self, batch, batch_size=None):