ddpg

parent 6e563fe61a
commit 64bab0b6a0

4  .github/workflows/pytest.yml  (vendored)
@@ -33,8 +33,8 @@ jobs:
         # stop the build if there are Python syntax errors or undefined names
         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        flake8 . --count --exit-zero --max-complexity=20 --max-line-length=79 --statistics
+        flake8 . --count --exit-zero --max-complexity=30 --max-line-length=79 --statistics
     - name: Test with pytest
       run: |
         pip install pytest pytest-cov
-        pytest --cov tianshou
+        pytest --cov tianshou -s
3  setup.py

@@ -37,12 +37,11 @@ setup(
                                     'examples', 'examples.*',
                                     'docs', 'docs.*']),
     install_requires=[
         # 'ray',
         'gym',
         'tqdm',
         'numpy',
         'cloudpickle',
         'tensorboard',
-        'torch>=1.2.0', # for supporting tensorboard
+        'torch>=1.4.0',
     ],
 )
@@ -41,14 +41,14 @@ def get_args():
     parser.add_argument('--task', type=str, default='CartPole-v0')
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=320)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=2)
-    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--training-num', type=int, default=32)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument(
@@ -57,6 +57,7 @@ def get_args():
     # a2c special
     parser.add_argument('--vf-coef', type=float, default=0.5)
     parser.add_argument('--entropy-coef', type=float, default=0.001)
+    parser.add_argument('--max-grad-norm', type=float, default=None)
     args = parser.parse_known_args()[0]
     return args
@@ -86,12 +87,12 @@ def test_a2c(args=get_args()):
     policy = A2CPolicy(
         net, optim, dist, args.gamma,
         vf_coef=args.vf_coef,
-        entropy_coef=args.entropy_coef)
+        entropy_coef=args.entropy_coef,
+        max_grad_norm=args.max_grad_norm)
     # collector
     training_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(
-        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
+    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
     # log
     stat_loss = MovAvg()
     global_step = 0
@@ -126,6 +127,8 @@ def test_a2c(args=get_args()):
                                   reward=f'{result["reward"]:.6f}',
                                   length=f'{result["length"]:.2f}',
                                   speed=f'{result["speed"]:.2f}')
+            if t.n <= t.total:
+                t.update()
         # eval
         test_collector.reset_env()
         test_collector.reset_buffer()
206  test/test_ddpg.py  (new file)
@@ -0,0 +1,206 @@
import gym
import time
import tqdm
import torch
import argparse
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DDPGPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils import tqdm_config, MovAvg
from tianshou.data import Collector, ReplayBuffer


class Actor(nn.Module):
    def __init__(self, layer_num, state_shape, action_shape,
                 max_action, device='cpu'):
        super().__init__()
        self.device = device
        self.model = [
            nn.Linear(np.prod(state_shape), 128),
            nn.ReLU(inplace=True)]
        for i in range(layer_num):
            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
        self.model += [nn.Linear(128, np.prod(action_shape))]
        self.model = nn.Sequential(*self.model)
        self._max = max_action

    def forward(self, s, **kwargs):
        s = torch.tensor(s, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        s = s.view(batch, -1)
        logits = self.model(s)
        logits = self._max * torch.tanh(logits)
        return logits, None


class Critic(nn.Module):
    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
        super().__init__()
        self.device = device
        self.model = [
            nn.Linear(np.prod(state_shape) + np.prod(action_shape), 128),
            nn.ReLU(inplace=True)]
        for i in range(layer_num):
            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
        self.model += [nn.Linear(128, 1)]
        self.model = nn.Sequential(*self.model)

    def forward(self, s, a):
        s = torch.tensor(s, device=self.device, dtype=torch.float)
        if isinstance(a, np.ndarray):
            a = torch.tensor(a, device=self.device, dtype=torch.float)
        batch = s.shape[0]
        s = s.view(batch, -1)
        a = a.view(batch, -1)
        logits = self.model(torch.cat([s, a], dim=1))
        return logits


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='Pendulum-v0')
    parser.add_argument('--seed', type=int, default=1626)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--actor-lr', type=float, default=1e-4)
    parser.add_argument('--actor-wd', type=float, default=0)
    parser.add_argument('--critic-lr', type=float, default=1e-3)
    parser.add_argument('--critic-wd', type=float, default=1e-2)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--tau', type=float, default=0.005)
    parser.add_argument('--exploration-noise', type=float, default=0.1)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step-per-epoch', type=int, default=2400)
    parser.add_argument('--collect-per-step', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--layer-num', type=int, default=1)
    parser.add_argument('--training-num', type=int, default=1)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_known_args()[0]
    return args


def test_ddpg(args=get_args()):
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    actor = Actor(
        args.layer_num, args.state_shape, args.action_shape,
        args.max_action, args.device
    ).to(args.device)
    actor_optim = torch.optim.Adam(
        actor.parameters(), lr=args.actor_lr, weight_decay=args.actor_wd)
    critic = Critic(
        args.layer_num, args.state_shape, args.action_shape, args.device
    ).to(args.device)
    critic_optim = torch.optim.Adam(
        critic.parameters(), lr=args.critic_lr, weight_decay=args.critic_wd)
    policy = DDPGPolicy(
        actor, actor_optim, critic, critic_optim,
        [env.action_space.low[0], env.action_space.high[0]],
        args.tau, args.gamma, args.exploration_noise)
    # collector
    training_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size), 1)
    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
    # log
    stat_a_loss = MovAvg()
    stat_c_loss = MovAvg()
    global_step = 0
    writer = SummaryWriter(args.logdir)
    best_epoch = -1
    best_reward = -1e10
    start_time = time.time()
    # training_collector.collect(n_step=1000)
    for epoch in range(1, 1 + args.epoch):
        desc = f'Epoch #{epoch}'
        # train
        policy.train()
        with tqdm.tqdm(
                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
            while t.n < t.total:
                result = training_collector.collect(
                    n_step=args.collect_per_step)
                for i in range(min(
                        result['n_step'] // args.collect_per_step,
                        t.total - t.n)):
                    t.update(1)
                    global_step += 1
                    actor_loss, critic_loss = policy.learn(
                        training_collector.sample(args.batch_size))
                    policy.sync_weight()
                    stat_a_loss.add(actor_loss)
                    stat_c_loss.add(critic_loss)
                    writer.add_scalar(
                        'reward', result['reward'], global_step=global_step)
                    writer.add_scalar(
                        'length', result['length'], global_step=global_step)
                    writer.add_scalar(
                        'actor_loss', stat_a_loss.get(),
                        global_step=global_step)
                    writer.add_scalar(
                        'critic_loss', stat_c_loss.get(),
                        global_step=global_step)
                    writer.add_scalar(
                        'speed', result['speed'], global_step=global_step)
                    t.set_postfix(actor_loss=f'{stat_a_loss.get():.6f}',
                                  critic_loss=f'{stat_c_loss.get():.6f}',
                                  reward=f'{result["reward"]:.6f}',
                                  length=f'{result["length"]:.2f}',
                                  speed=f'{result["speed"]:.2f}')
            if t.n <= t.total:
                t.update()
        # eval
        test_collector.reset_env()
        test_collector.reset_buffer()
        policy.eval()
        result = test_collector.collect(n_episode=args.test_num)
        if best_reward < result['reward']:
            best_reward = result['reward']
            best_epoch = epoch
        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
              f'best_reward: {best_reward:.6f} in #{best_epoch}')
        if args.task == 'Pendulum-v0' and best_reward >= -250:
            break
    if args.task == 'Pendulum-v0':
        assert best_reward >= -250
    training_collector.close()
    test_collector.close()
    if __name__ == '__main__':
        train_cnt = training_collector.collect_step
        test_cnt = test_collector.collect_step
        duration = time.time() - start_time
        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
              f'in {duration:.2f}s, '
              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
        # Let's watch its performance!
        env = gym.make(args.task)
        test_collector = Collector(policy, env)
        result = test_collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
        test_collector.close()


if __name__ == '__main__':
    test_ddpg()
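The Actor above keeps its output inside the environment's action bounds by squashing with tanh and rescaling by max_action. A minimal standalone sketch of just that squashing step, assuming Pendulum-v0's action range of [-2, 2]; the tensor names are illustrative and not part of the commit:

import torch

max_action = 2.0                     # Pendulum-v0 actions lie in [-2, 2]
raw = torch.randn(4, 1)              # unbounded network output, batch of 4
act = max_action * torch.tanh(raw)   # squashed into [-max_action, max_action]
assert act.abs().max() <= max_action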
@@ -78,13 +78,11 @@ def test_dqn(args=get_args()):
     net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
     net = net.to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
-    loss = nn.MSELoss()
-    policy = DQNPolicy(net, optim, loss, args.gamma, args.n_step)
+    policy = DQNPolicy(net, optim, args.gamma, args.n_step)
     # collector
     training_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(
-        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
+    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
     training_collector.collect(n_step=args.batch_size)
     # log
     stat_loss = MovAvg()
@@ -124,6 +122,8 @@ def test_dqn(args=get_args()):
                                   reward=f'{result["reward"]:.6f}',
                                   length=f'{result["length"]:.2f}',
                                   speed=f'{result["speed"]:.2f}')
+            if t.n <= t.total:
+                t.update()
         # eval
         test_collector.reset_env()
         test_collector.reset_buffer()
@@ -133,8 +133,7 @@ def test_pg(args=get_args()):
     # collector
     training_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(
-        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
+    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
     # log
     stat_loss = MovAvg()
     global_step = 0
@@ -169,6 +168,8 @@ def test_pg(args=get_args()):
                                   reward=f'{result["reward"]:.6f}',
                                   length=f'{result["length"]:.2f}',
                                   speed=f'{result["speed"]:.2f}')
+            if t.n <= t.total:
+                t.update()
         # eval
         test_collector.reset_env()
         test_collector.reset_buffer()
@@ -1,9 +1,10 @@
-from tianshou import data, env, utils, policy
+from tianshou import data, env, utils, policy, exploration

 __version__ = '0.2.0'
 __all__ = [
     'env',
     'data',
     'utils',
-    'policy'
+    'policy',
+    'exploration',
 ]
@@ -10,6 +10,10 @@ class ReplayBuffer(object):
         self._maxsize = size
         self.reset()

+    def __del__(self):
+        for k in list(self.__dict__.keys()):
+            del self.__dict__[k]
+
     def __len__(self):
         return self._size

@@ -24,6 +28,9 @@ class ReplayBuffer(object):
                     [{} for _ in range(self._maxsize)])
             else: # assume `inst` is a number
                 self.__dict__[name] = np.zeros([self._maxsize])
+        if isinstance(inst, np.ndarray) and \
+                self.__dict__[name].shape[1:] != inst.shape:
+            self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
         self.__dict__[name][self._index] = inst

     def update(self, buffer):
@@ -20,7 +20,7 @@ class Collector(object):
         self.policy = policy
         self.process_fn = policy.process_fn
         self._multi_env = isinstance(env, BaseVectorEnv)
-        self._multi_buf = False # buf is a list
+        self._multi_buf = False # True if buf is a list
         # need multiple cache buffers only if storing in one buffer
         self._cached_buf = []
         if self._multi_env:
@@ -65,9 +65,9 @@ class Collector(object):
         if hasattr(self.env, 'seed'):
             self.env.seed(seed)

-    def render(self):
+    def render(self, **kwargs):
         if hasattr(self.env, 'render'):
-            self.env.render()
+            self.env.render(**kwargs)

     def close(self):
         if hasattr(self.env, 'close'):
@@ -101,7 +101,10 @@ class Collector(object):
                 info=self._make_batch(self._info))
             result = self.policy(batch_data, self.state)
             self.state = result.state if hasattr(result, 'state') else None
-            self._act = result.act
+            if isinstance(result.act, torch.Tensor):
+                self._act = result.act.detach().cpu().numpy()
+            else:
+                self._act = np.array(result.act)
             obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act if self._multi_env else self._act[0])
             if render > 0:
@@ -141,7 +144,10 @@ class Collector(object):
                     if isinstance(self.state, list):
                         self.state[i] = None
                     elif self.state is not None:
-                        self.state[i] = self.state[i] * 0
+                        if isinstance(self.state[i], dict):
+                            self.state[i] = {}
+                        else:
+                            self.state[i] = self.state[i] * 0
                         if isinstance(self.state, torch.Tensor):
                             # remove ref count in pytorch (?)
                             self.state = self.state.detach()
22  tianshou/env/wrapper.py  (vendored)
@@ -24,9 +24,9 @@ class EnvWrapper(object):
         if hasattr(self.env, 'seed'):
             self.env.seed(seed)

-    def render(self):
+    def render(self, **kwargs):
         if hasattr(self.env, 'render'):
-            self.env.render()
+            self.env.render(**kwargs)

     def close(self):
         self.env.close()
@@ -83,7 +83,7 @@ class BaseVectorEnv(ABC):
         pass

     @abstractmethod
-    def render(self):
+    def render(self, **kwargs):
         pass

     @abstractmethod
@@ -129,10 +129,10 @@ class VectorEnv(BaseVectorEnv):
             if hasattr(e, 'seed'):
                 e.seed(s)

-    def render(self):
+    def render(self, **kwargs):
         for e in self.envs:
             if hasattr(e, 'render'):
-                e.render()
+                e.render(**kwargs)

     def close(self):
         for e in self.envs:
@@ -160,7 +160,7 @@ def worker(parent, p, env_fn_wrapper, reset_after_done):
             p.close()
             break
         elif cmd == 'render':
-            p.send(env.render() if hasattr(env, 'render') else None)
+            p.send(env.render(**data) if hasattr(env, 'render') else None)
         elif cmd == 'seed':
             p.send(env.seed(data) if hasattr(env, 'seed') else None)
         else:
@@ -213,9 +213,9 @@ class SubprocVectorEnv(BaseVectorEnv):
         for p in self.parent_remote:
             p.recv()

-    def render(self):
+    def render(self, **kwargs):
         for p in self.parent_remote:
-            p.send(['render', None])
+            p.send(['render', kwargs])
         for p in self.parent_remote:
             p.recv()
@@ -239,7 +239,7 @@ class RayVectorEnv(BaseVectorEnv):
             ray.init()
         except NameError:
             raise ImportError(
-                'Please install ray to support VectorEnv: pip3 install ray -U')
+                'Please install ray to support RayVectorEnv: pip3 install ray')
         self.envs = [
             ray.remote(EnvWrapper).options(num_cpus=0).remote(e())
             for e in env_fns]
@@ -288,10 +288,10 @@ class RayVectorEnv(BaseVectorEnv):
         for r in result_obj:
             ray.get(r)

-    def render(self):
+    def render(self, **kwargs):
         if not hasattr(self.envs[0], 'render'):
             return
-        result_obj = [e.render.remote() for e in self.envs]
+        result_obj = [e.render.remote(**kwargs) for e in self.envs]
         for r in result_obj:
             ray.get(r)
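With **kwargs now threaded through every render wrapper, per-call options can reach the underlying gym environments (the subprocess worker forwards the received dict as keyword arguments). A minimal sketch of the intent, assuming a Pendulum-v0 setup like the test above; mode is gym's usual render argument, not something this commit defines:

import gym

from tianshou.env import SubprocVectorEnv

envs = SubprocVectorEnv(
    [lambda: gym.make('Pendulum-v0') for _ in range(2)],
    reset_after_done=True)
envs.reset()
# forwarded to each worker's env.render(mode='rgb_array'); note that this
# version discards the workers' return values, so the call is fire-and-forget
envs.render(mode='rgb_array')
envs.close()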
5  tianshou/exploration/__init__.py  (new file)
@@ -0,0 +1,5 @@
from tianshou.exploration.random import OUNoise

__all__ = [
    'OUNoise',
]
21  tianshou/exploration/random.py  (new file)
@@ -0,0 +1,21 @@
import numpy as np


class OUNoise(object):
    """docstring for OUNoise"""

    def __init__(self, sigma=0.3, theta=0.15, dt=1e-2, x0=None):
        self.alpha = theta * dt
        self.beta = sigma * np.sqrt(dt)
        self.x0 = x0
        self.reset()

    def __call__(self, size, mu=.1):
        if self.x is None or self.x.shape != size:
            self.x = 0
        self.x = self.x + self.alpha * (mu - self.x) + \
            self.beta * np.random.normal(size=size)
        return self.x

    def reset(self):
        self.x = None
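For reference, a minimal sketch of how this OUNoise class could be driven (it is not yet wired into DDPGPolicy below, where it stays commented out): each call advances the internal state by alpha * (mu - x) + beta * N(0, 1), so successive samples are temporally correlated. The array shapes here are illustrative, not part of the commit:

import numpy as np

from tianshou.exploration import OUNoise

noise = OUNoise(sigma=0.3, theta=0.15, dt=1e-2)
act = np.zeros((4, 1))                      # e.g. 4 parallel envs, 1-dim action
for _ in range(5):
    noisy_act = act + noise(size=act.shape)  # correlated exploration noise
noise.reset()                                # start a fresh noise sequence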
@@ -1,11 +1,13 @@
 from tianshou.policy.base import BasePolicy
 from tianshou.policy.dqn import DQNPolicy
-from tianshou.policy.policy_gradient import PGPolicy
+from tianshou.policy.pg import PGPolicy
 from tianshou.policy.a2c import A2CPolicy
+from tianshou.policy.ddpg import DDPGPolicy

 __all__ = [
     'BasePolicy',
     'DQNPolicy',
     'PGPolicy',
     'A2CPolicy',
+    'DDPGPolicy'
 ]
@@ -1,4 +1,5 @@
 import torch
+from torch import nn
 import torch.nn.functional as F

 from tianshou.data import Batch
@@ -9,16 +10,18 @@ class A2CPolicy(PGPolicy):
     """docstring for A2CPolicy"""

     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
-                 discount_factor=0.99, vf_coef=.5, entropy_coef=.01):
+                 discount_factor=0.99, vf_coef=.5, entropy_coef=.01,
+                 max_grad_norm=None):
         super().__init__(model, optim, dist_fn, discount_factor)
         self._w_value = vf_coef
         self._w_entropy = entropy_coef
+        self._grad_norm = max_grad_norm

     def __call__(self, batch, state=None):
         logits, value, h = self.model(batch.obs, state=state, info=batch.info)
         logits = F.softmax(logits, dim=1)
         dist = self.dist_fn(logits)
-        act = dist.sample().detach().cpu().numpy()
+        act = dist.sample()
         return Batch(logits=logits, act=act, state=h, dist=dist, value=value)

     def learn(self, batch, batch_size=None):
@@ -31,12 +34,15 @@ class A2CPolicy(PGPolicy):
             a = torch.tensor(b.act, device=dist.logits.device)
             r = torch.tensor(b.returns, device=dist.logits.device)
             actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
-            critic_loss = (r - v).pow(2).mean()
+            critic_loss = F.mse_loss(r[:, None], v)
             entropy_loss = dist.entropy().mean()
             loss = actor_loss \
                 + self._w_value * critic_loss \
                 - self._w_entropy * entropy_loss
             loss.backward()
+            if self._grad_norm:
+                nn.utils.clip_grad_norm_(
+                    self.model.parameters(), max_norm=self._grad_norm)
             self.optim.step()
             losses.append(loss.detach().cpu().numpy())
         return losses
@@ -1,19 +1,19 @@
+from torch import nn
 from abc import ABC, abstractmethod


-class BasePolicy(ABC):
+class BasePolicy(ABC, nn.Module):
     """docstring for BasePolicy"""

     def __init__(self):
         super().__init__()
-        self.model = None

     def process_fn(self, batch, buffer, indice):
         return batch

     @abstractmethod
     def __call__(self, batch, state=None):
-        # return Batch(logits=..., act=np.array(), state=None, ...)
+        # return Batch(logits=..., act=..., state=None, ...)
         pass

     @abstractmethod
91  tianshou/policy/ddpg.py  (new file)
@@ -0,0 +1,91 @@
import torch
from copy import deepcopy
import torch.nn.functional as F

from tianshou.data import Batch
from tianshou.policy import BasePolicy
# from tianshou.exploration import OUNoise


class DDPGPolicy(BasePolicy):
    """docstring for DDPGPolicy"""

    def __init__(self, actor, actor_optim,
                 critic, critic_optim, action_range,
                 tau=0.005, gamma=0.99, exploration_noise=0.1):
        super().__init__()
        self.actor = actor
        self.actor_old = deepcopy(actor)
        self.actor_old.load_state_dict(self.actor.state_dict())
        self.actor_old.eval()
        self.actor_optim = actor_optim
        self.critic = critic
        self.critic_old = deepcopy(critic)
        self.critic_old.load_state_dict(self.critic.state_dict())
        self.critic_old.eval()
        self.critic_optim = critic_optim
        assert 0 < tau <= 1, 'tau should in (0, 1]'
        self._tau = tau
        assert 0 < gamma <= 1, 'gamma should in (0, 1]'
        self._gamma = gamma
        assert 0 <= exploration_noise, 'noise should greater than zero'
        self._eps = exploration_noise
        self._range = action_range
        # self.noise = OUNoise()

    def set_eps(self, eps):
        self._eps = eps

    def train(self):
        self.training = True
        self.actor.train()
        self.critic.train()

    def eval(self):
        self.training = False
        self.actor.eval()
        self.critic.eval()

    def sync_weight(self):
        for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
            o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
        for o, n in zip(
                self.critic_old.parameters(), self.critic.parameters()):
            o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)

    def process_fn(self, batch, buffer, indice):
        return batch

    def __call__(self, batch, state=None,
                 model='actor', input='obs', eps=None):
        model = getattr(self, model)
        obs = getattr(batch, input)
        logits, h = model(obs, state=state, info=batch.info)
        # noise = np.random.normal(0, self._eps, size=logits.shape)
        logits += torch.randn(
            size=logits.shape, device=logits.device) * self._eps
        # noise = self.noise(logits.shape, self._eps)
        # logits += torch.tensor(noise, device=logits.device)
        logits = logits.clamp(self._range[0], self._range[1])
        return Batch(act=logits, state=h)

    def learn(self, batch, batch_size=None):
        target_q = self.critic_old(
            batch.obs_next, self.actor_old(batch.obs_next, state=None)[0])
        dev = target_q.device
        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)
        done = torch.tensor(batch.done, dtype=torch.float, device=dev)
        target_q = rew[:, None] + ((
            1. - done[:, None]) * self._gamma * target_q).detach()
        current_q = self.critic(batch.obs, batch.act)
        critic_loss = F.mse_loss(current_q, target_q)
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()
        actor_loss = -self.critic(
            batch.obs, self.actor(batch.obs, state=None)[0]).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()
        return actor_loss.detach().cpu().numpy(),\
            critic_loss.detach().cpu().numpy()
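sync_weight above is a soft (Polyak) target update: after each learning step the target networks track the online networks as theta_target <- (1 - tau) * theta_target + tau * theta. A standalone sketch of the same idea; the soft_update helper is illustrative, not part of the commit:

import torch
from torch import nn


def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    # Polyak-average the source parameters into the target, in place
    with torch.no_grad():
        for t, s in zip(target.parameters(), source.parameters()):
            t.data.copy_(t.data * (1.0 - tau) + s.data * tau)


online = nn.Linear(3, 1)
target = nn.Linear(3, 1)
target.load_state_dict(online.state_dict())
soft_update(target, online, tau=0.005)   # target drifts slowly toward online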
@@ -1,25 +1,24 @@
 import torch
 import numpy as np
-from torch import nn
 from copy import deepcopy
+import torch.nn.functional as F

 from tianshou.data import Batch
 from tianshou.policy import BasePolicy


-class DQNPolicy(BasePolicy, nn.Module):
+class DQNPolicy(BasePolicy):
     """docstring for DQNPolicy"""

-    def __init__(self, model, optim, loss_fn,
+    def __init__(self, model, optim,
                  discount_factor=0.99,
                  estimation_step=1,
                  use_target_network=True):
         super().__init__()
         self.model = model
         self.optim = optim
-        self.loss_fn = loss_fn
         self.eps = 0
-        assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
+        assert 0 < discount_factor <= 1, 'discount_factor should in (0, 1]'
         self._gamma = discount_factor
         assert estimation_step > 0, 'estimation_step should greater than 0'
         self._n_step = estimation_step
@@ -91,7 +90,7 @@ class DQNPolicy(BasePolicy, nn.Module):
         r = batch.returns
         if isinstance(r, np.ndarray):
             r = torch.tensor(r, device=q.device, dtype=q.dtype)
-        loss = self.loss_fn(q, r)
+        loss = F.mse_loss(q, r)
         loss.backward()
         self.optim.step()
         return loss.detach().cpu().numpy()
@@ -1,13 +1,12 @@
 import torch
 import numpy as np
-from torch import nn
 import torch.nn.functional as F

 from tianshou.data import Batch
 from tianshou.policy import BasePolicy


-class PGPolicy(BasePolicy, nn.Module):
+class PGPolicy(BasePolicy):
     """docstring for PGPolicy"""

     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
@@ -17,7 +16,7 @@ class PGPolicy(BasePolicy, nn.Module):
         self.optim = optim
         self.dist_fn = dist_fn
         self._eps = np.finfo(np.float32).eps.item()
-        assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
+        assert 0 < discount_factor <= 1, 'discount_factor should in (0, 1]'
         self._gamma = discount_factor

     def process_fn(self, batch, buffer, indice):
@@ -30,7 +29,7 @@ class PGPolicy(BasePolicy, nn.Module):
         logits, h = self.model(batch.obs, state=state, info=batch.info)
         logits = F.softmax(logits, dim=1)
         dist = self.dist_fn(logits)
-        act = dist.sample().detach().cpu().numpy()
+        act = dist.sample()
         return Batch(logits=logits, act=act, state=h, dist=dist)

     def learn(self, batch, batch_size=None):