diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index eb0b796..b0d72a2 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -33,8 +33,8 @@ jobs:
         # stop the build if there are Python syntax errors or undefined names
         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        flake8 . --count --exit-zero --max-complexity=20 --max-line-length=79 --statistics
+        flake8 . --count --exit-zero --max-complexity=30 --max-line-length=79 --statistics
     - name: Test with pytest
       run: |
         pip install pytest pytest-cov
-        pytest --cov tianshou
+        pytest --cov tianshou -s
diff --git a/setup.py b/setup.py
index 2a73f9d..0e2b66d 100644
--- a/setup.py
+++ b/setup.py
@@ -37,12 +37,11 @@ setup(
                                     'examples', 'examples.*',
                                     'docs', 'docs.*']),
     install_requires=[
-        # 'ray',
         'gym',
         'tqdm',
         'numpy',
         'cloudpickle',
         'tensorboard',
-        'torch>=1.2.0',  # for supporting tensorboard
+        'torch>=1.4.0',
     ],
 )
diff --git a/test/test_a2c.py b/test/test_a2c.py
index 4033c4b..7cb2358 100644
--- a/test/test_a2c.py
+++ b/test/test_a2c.py
@@ -41,14 +41,14 @@ def get_args():
     parser.add_argument('--task', type=str, default='CartPole-v0')
     parser.add_argument('--seed', type=int, default=1626)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--lr', type=float, default=3e-4)
     parser.add_argument('--gamma', type=float, default=0.9)
     parser.add_argument('--epoch', type=int, default=100)
     parser.add_argument('--step-per-epoch', type=int, default=320)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=64)
     parser.add_argument('--layer-num', type=int, default=2)
-    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--training-num', type=int, default=32)
     parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument(
@@ -57,6 +57,7 @@ def get_args():
     # a2c special
     parser.add_argument('--vf-coef', type=float, default=0.5)
     parser.add_argument('--entropy-coef', type=float, default=0.001)
+    parser.add_argument('--max-grad-norm', type=float, default=None)
     args = parser.parse_known_args()[0]
     return args
 
@@ -86,12 +87,12 @@ def test_a2c(args=get_args()):
     policy = A2CPolicy(
         net, optim, dist, args.gamma, vf_coef=args.vf_coef,
-        entropy_coef=args.entropy_coef)
+        entropy_coef=args.entropy_coef,
+        max_grad_norm=args.max_grad_norm)
     # collector
     training_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(
-        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
+    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
     # log
     stat_loss = MovAvg()
     global_step = 0
@@ -126,6 +127,8 @@ def test_a2c(args=get_args()):
                               reward=f'{result["reward"]:.6f}',
                               length=f'{result["length"]:.2f}',
                               speed=f'{result["speed"]:.2f}')
+                if t.n <= t.total:
+                    t.update()
         # eval
         test_collector.reset_env()
         test_collector.reset_buffer()
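Note: test_a2c.py now exposes --max-grad-norm and forwards it to A2CPolicy. A minimal, self-contained sketch of the gradient clipping this enables, assuming only the standard torch API; the Linear model, the loss, and the 0.5 value are placeholders (the test default is None, i.e. clipping disabled):

import torch
from torch import nn

model = nn.Linear(4, 2)            # stand-in for the actor-critic network
optim = torch.optim.Adam(model.parameters(), lr=3e-4)
max_grad_norm = 0.5                # hypothetical value; None means no clipping

loss = model(torch.randn(8, 4)).pow(2).mean()
optim.zero_grad()
loss.backward()
if max_grad_norm:                  # skipped entirely when the flag stays None
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
optim.step()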
diff --git a/test/test_ddpg.py b/test/test_ddpg.py
new file mode 100644
index 0000000..3dcc962
--- /dev/null
+++ b/test/test_ddpg.py
@@ -0,0 +1,206 @@
+import gym
+import time
+import tqdm
+import torch
+import argparse
+import numpy as np
+from torch import nn
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.policy import DDPGPolicy
+from tianshou.env import SubprocVectorEnv
+from tianshou.utils import tqdm_config, MovAvg
+from tianshou.data import Collector, ReplayBuffer
+
+
+class Actor(nn.Module):
+    def __init__(self, layer_num, state_shape, action_shape,
+                 max_action, device='cpu'):
+        super().__init__()
+        self.device = device
+        self.model = [
+            nn.Linear(np.prod(state_shape), 128),
+            nn.ReLU(inplace=True)]
+        for i in range(layer_num):
+            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
+        self.model += [nn.Linear(128, np.prod(action_shape))]
+        self.model = nn.Sequential(*self.model)
+        self._max = max_action
+
+    def forward(self, s, **kwargs):
+        s = torch.tensor(s, device=self.device, dtype=torch.float)
+        batch = s.shape[0]
+        s = s.view(batch, -1)
+        logits = self.model(s)
+        logits = self._max * torch.tanh(logits)
+        return logits, None
+
+
+class Critic(nn.Module):
+    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
+        super().__init__()
+        self.device = device
+        self.model = [
+            nn.Linear(np.prod(state_shape) + np.prod(action_shape), 128),
+            nn.ReLU(inplace=True)]
+        for i in range(layer_num):
+            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
+        self.model += [nn.Linear(128, 1)]
+        self.model = nn.Sequential(*self.model)
+
+    def forward(self, s, a):
+        s = torch.tensor(s, device=self.device, dtype=torch.float)
+        if isinstance(a, np.ndarray):
+            a = torch.tensor(a, device=self.device, dtype=torch.float)
+        batch = s.shape[0]
+        s = s.view(batch, -1)
+        a = a.view(batch, -1)
+        logits = self.model(torch.cat([s, a], dim=1))
+        return logits
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='Pendulum-v0')
+    parser.add_argument('--seed', type=int, default=1626)
+    parser.add_argument('--buffer-size', type=int, default=20000)
+    parser.add_argument('--actor-lr', type=float, default=1e-4)
+    parser.add_argument('--actor-wd', type=float, default=0)
+    parser.add_argument('--critic-lr', type=float, default=1e-3)
+    parser.add_argument('--critic-wd', type=float, default=1e-2)
+    parser.add_argument('--gamma', type=float, default=0.99)
+    parser.add_argument('--tau', type=float, default=0.005)
+    parser.add_argument('--exploration-noise', type=float, default=0.1)
+    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--step-per-epoch', type=int, default=2400)
+    parser.add_argument('--collect-per-step', type=int, default=1)
+    parser.add_argument('--batch-size', type=int, default=128)
+    parser.add_argument('--layer-num', type=int, default=1)
+    parser.add_argument('--training-num', type=int, default=1)
+    parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument(
+        '--device', type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu')
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_ddpg(args=get_args()):
+    env = gym.make(args.task)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    args.max_action = env.action_space.high[0]
+    # train_envs = gym.make(args.task)
+    train_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.training_num)],
+        reset_after_done=True)
+    # test_envs = gym.make(args.task)
+    test_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)],
+        reset_after_done=False)
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    actor = Actor(
+        args.layer_num, args.state_shape, args.action_shape,
+        args.max_action, args.device
+    ).to(args.device)
+    actor_optim = torch.optim.Adam(
+        actor.parameters(), lr=args.actor_lr, weight_decay=args.actor_wd)
+    critic = Critic(
+        args.layer_num, args.state_shape, args.action_shape, args.device
+    ).to(args.device)
+    critic_optim = torch.optim.Adam(
+        critic.parameters(), lr=args.critic_lr, weight_decay=args.critic_wd)
+    policy = DDPGPolicy(
+        actor, actor_optim, critic, critic_optim,
+        [env.action_space.low[0], env.action_space.high[0]],
+        args.tau, args.gamma, args.exploration_noise)
+    # collector
+    training_collector = Collector(
+        policy, train_envs, ReplayBuffer(args.buffer_size), 1)
+    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
+    # log
+    stat_a_loss = MovAvg()
+    stat_c_loss = MovAvg()
+    global_step = 0
+    writer = SummaryWriter(args.logdir)
+    best_epoch = -1
+    best_reward = -1e10
+    start_time = time.time()
+    # training_collector.collect(n_step=1000)
+    for epoch in range(1, 1 + args.epoch):
+        desc = f'Epoch #{epoch}'
+        # train
+        policy.train()
+        with tqdm.tqdm(
+                total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
+            while t.n < t.total:
+                result = training_collector.collect(
+                    n_step=args.collect_per_step)
+                for i in range(min(
+                        result['n_step'] // args.collect_per_step,
+                        t.total - t.n)):
+                    t.update(1)
+                    global_step += 1
+                    actor_loss, critic_loss = policy.learn(
+                        training_collector.sample(args.batch_size))
+                    policy.sync_weight()
+                    stat_a_loss.add(actor_loss)
+                    stat_c_loss.add(critic_loss)
+                    writer.add_scalar(
+                        'reward', result['reward'], global_step=global_step)
+                    writer.add_scalar(
+                        'length', result['length'], global_step=global_step)
+                    writer.add_scalar(
+                        'actor_loss', stat_a_loss.get(),
+                        global_step=global_step)
+                    writer.add_scalar(
+                        'critic_loss', stat_c_loss.get(),
+                        global_step=global_step)
+                    writer.add_scalar(
+                        'speed', result['speed'], global_step=global_step)
+                    t.set_postfix(actor_loss=f'{stat_a_loss.get():.6f}',
+                                  critic_loss=f'{stat_c_loss.get():.6f}',
+                                  reward=f'{result["reward"]:.6f}',
+                                  length=f'{result["length"]:.2f}',
+                                  speed=f'{result["speed"]:.2f}')
+                if t.n <= t.total:
+                    t.update()
+        # eval
+        test_collector.reset_env()
+        test_collector.reset_buffer()
+        policy.eval()
+        result = test_collector.collect(n_episode=args.test_num)
+        if best_reward < result['reward']:
+            best_reward = result['reward']
+            best_epoch = epoch
+        print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
+              f'best_reward: {best_reward:.6f} in #{best_epoch}')
+        if args.task == 'Pendulum-v0' and best_reward >= -250:
+            break
+    if args.task == 'Pendulum-v0':
+        assert best_reward >= -250
+    training_collector.close()
+    test_collector.close()
+    if __name__ == '__main__':
+        train_cnt = training_collector.collect_step
+        test_cnt = test_collector.collect_step
+        duration = time.time() - start_time
+        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
+              f'in {duration:.2f}s, '
+              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
+        # Let's watch its performance!
+        env = gym.make(args.task)
+        test_collector = Collector(policy, env)
+        result = test_collector.collect(n_episode=1, render=1 / 35)
+        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
+        test_collector.close()
+
+
+if __name__ == '__main__':
+    test_ddpg()
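Note: the epoch loops in these tests share the same tqdm bookkeeping: tick the bar once per gradient step, then flush it at the end of the epoch (the "if t.n <= t.total: t.update()" lines added in each test). A standalone sketch of that pattern with dummy numbers, just to make the control flow explicit:

import tqdm

step_per_epoch = 10
with tqdm.tqdm(total=step_per_epoch, desc='Epoch #1') as t:
    while t.n < t.total:
        n_step = 3                      # pretend each collect() returned 3 env steps
        for _ in range(min(n_step, t.total - t.n)):
            t.update(1)                 # one tick per policy.learn() call
    if t.n <= t.total:
        t.update()                      # final flush, mirroring the lines added in this diff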
diff --git a/test/test_dqn.py b/test/test_dqn.py
index c3fbbc7..e2450d6 100644
--- a/test/test_dqn.py
+++ b/test/test_dqn.py
@@ -78,13 +78,11 @@ def test_dqn(args=get_args()):
     net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
     net = net.to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.lr)
-    loss = nn.MSELoss()
-    policy = DQNPolicy(net, optim, loss, args.gamma, args.n_step)
+    policy = DQNPolicy(net, optim, args.gamma, args.n_step)
     # collector
     training_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(
-        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
+    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
     training_collector.collect(n_step=args.batch_size)
     # log
     stat_loss = MovAvg()
@@ -124,6 +122,8 @@ def test_dqn(args=get_args()):
                               reward=f'{result["reward"]:.6f}',
                               length=f'{result["length"]:.2f}',
                               speed=f'{result["speed"]:.2f}')
+                if t.n <= t.total:
+                    t.update()
         # eval
         test_collector.reset_env()
         test_collector.reset_buffer()
diff --git a/test/test_pg.py b/test/test_pg.py
index 93c8603..e2f45a6 100644
--- a/test/test_pg.py
+++ b/test/test_pg.py
@@ -133,8 +133,7 @@ def test_pg(args=get_args()):
     # collector
     training_collector = Collector(
         policy, train_envs, ReplayBuffer(args.buffer_size))
-    test_collector = Collector(
-        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
+    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
     # log
     stat_loss = MovAvg()
     global_step = 0
@@ -169,6 +168,8 @@ def test_pg(args=get_args()):
                               reward=f'{result["reward"]:.6f}',
                               length=f'{result["length"]:.2f}',
                               speed=f'{result["speed"]:.2f}')
+                if t.n <= t.total:
+                    t.update()
         # eval
         test_collector.reset_env()
         test_collector.reset_buffer()
diff --git a/tianshou/__init__.py b/tianshou/__init__.py
index 38edeca..5d71a93 100644
--- a/tianshou/__init__.py
+++ b/tianshou/__init__.py
@@ -1,9 +1,10 @@
-from tianshou import data, env, utils, policy
+from tianshou import data, env, utils, policy, exploration
 
 __version__ = '0.2.0'
 __all__ = [
     'env',
     'data',
     'utils',
-    'policy'
+    'policy',
+    'exploration',
 ]
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index c015e69..ccb0e92 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -10,6 +10,10 @@ class ReplayBuffer(object):
         self._maxsize = size
         self.reset()
 
+    def __del__(self):
+        for k in list(self.__dict__.keys()):
+            del self.__dict__[k]
+
     def __len__(self):
         return self._size
 
@@ -24,6 +28,9 @@ class ReplayBuffer(object):
                     [{} for _ in range(self._maxsize)])
             else:  # assume `inst` is a number
                 self.__dict__[name] = np.zeros([self._maxsize])
+        if isinstance(inst, np.ndarray) and \
+                self.__dict__[name].shape[1:] != inst.shape:
+            self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
         self.__dict__[name][self._index] = inst
 
     def update(self, buffer):
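Note: the new branch in the buffer's add path above re-allocates a key's storage whenever an ndarray arrives whose shape does not match the pre-allocated array. A standalone numpy sketch of that logic (the (4,) observation is just an illustrative value):

import numpy as np

maxsize, index = 10, 0
storage = np.zeros([maxsize])                 # default flat allocation
inst = np.zeros((4,))                         # e.g. a vector observation
if isinstance(inst, np.ndarray) and storage.shape[1:] != inst.shape:
    storage = np.zeros([maxsize, *inst.shape])
storage[index] = inst
print(storage.shape)                          # (10, 4)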
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index d7ac138..938769a 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -20,7 +20,7 @@ class Collector(object):
         self.policy = policy
         self.process_fn = policy.process_fn
         self._multi_env = isinstance(env, BaseVectorEnv)
-        self._multi_buf = False  # buf is a list
+        self._multi_buf = False  # True if buf is a list
         # need multiple cache buffers only if storing in one buffer
         self._cached_buf = []
         if self._multi_env:
@@ -65,9 +65,9 @@ class Collector(object):
         if hasattr(self.env, 'seed'):
             self.env.seed(seed)
 
-    def render(self):
+    def render(self, **kwargs):
         if hasattr(self.env, 'render'):
-            self.env.render()
+            self.env.render(**kwargs)
 
     def close(self):
         if hasattr(self.env, 'close'):
@@ -101,7 +101,10 @@ class Collector(object):
                 info=self._make_batch(self._info))
             result = self.policy(batch_data, self.state)
             self.state = result.state if hasattr(result, 'state') else None
-            self._act = result.act
+            if isinstance(result.act, torch.Tensor):
+                self._act = result.act.detach().cpu().numpy()
+            else:
+                self._act = np.array(result.act)
             obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act if self._multi_env else self._act[0])
             if render > 0:
@@ -141,7 +144,10 @@ class Collector(object):
                     if isinstance(self.state, list):
                         self.state[i] = None
                     elif self.state is not None:
-                        self.state[i] = self.state[i] * 0
+                        if isinstance(self.state[i], dict):
+                            self.state[i] = {}
+                        else:
+                            self.state[i] = self.state[i] * 0
                         if isinstance(self.state, torch.Tensor):
                             # remove ref count in pytorch (?)
                             self.state = self.state.detach()
diff --git a/tianshou/env/wrapper.py b/tianshou/env/wrapper.py
index f67be4f..6de6801 100644
--- a/tianshou/env/wrapper.py
+++ b/tianshou/env/wrapper.py
@@ -24,9 +24,9 @@ class EnvWrapper(object):
         if hasattr(self.env, 'seed'):
             self.env.seed(seed)
 
-    def render(self):
+    def render(self, **kwargs):
         if hasattr(self.env, 'render'):
-            self.env.render()
+            self.env.render(**kwargs)
 
     def close(self):
         self.env.close()
@@ -83,7 +83,7 @@ class BaseVectorEnv(ABC):
         pass
 
     @abstractmethod
-    def render(self):
+    def render(self, **kwargs):
         pass
 
     @abstractmethod
@@ -129,10 +129,10 @@ class VectorEnv(BaseVectorEnv):
             if hasattr(e, 'seed'):
                 e.seed(s)
 
-    def render(self):
+    def render(self, **kwargs):
         for e in self.envs:
             if hasattr(e, 'render'):
-                e.render()
+                e.render(**kwargs)
 
     def close(self):
         for e in self.envs:
@@ -160,7 +160,7 @@ def worker(parent, p, env_fn_wrapper, reset_after_done):
             p.close()
             break
         elif cmd == 'render':
-            p.send(env.render() if hasattr(env, 'render') else None)
+            p.send(env.render(**data) if hasattr(env, 'render') else None)
         elif cmd == 'seed':
             p.send(env.seed(data) if hasattr(env, 'seed') else None)
         else:
@@ -213,9 +213,9 @@ class SubprocVectorEnv(BaseVectorEnv):
         for p in self.parent_remote:
             p.recv()
 
-    def render(self):
+    def render(self, **kwargs):
         for p in self.parent_remote:
-            p.send(['render', None])
+            p.send(['render', kwargs])
         for p in self.parent_remote:
             p.recv()
 
@@ -239,7 +239,7 @@ class RayVectorEnv(BaseVectorEnv):
                 ray.init()
         except NameError:
             raise ImportError(
-                'Please install ray to support VectorEnv: pip3 install ray -U')
+                'Please install ray to support RayVectorEnv: pip3 install ray')
         self.envs = [
             ray.remote(EnvWrapper).options(num_cpus=0).remote(e())
             for e in env_fns]
@@ -288,10 +288,10 @@ class RayVectorEnv(BaseVectorEnv):
         for r in result_obj:
             ray.get(r)
 
-    def render(self):
+    def render(self, **kwargs):
         if not hasattr(self.envs[0], 'render'):
             return
-        result_obj = [e.render.remote() for e in self.envs]
+        result_obj = [e.render.remote(**kwargs) for e in self.envs]
         for r in result_obj:
             ray.get(r)
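Note: the exploration module added below implements a discretized Ornstein-Uhlenbeck process with alpha = theta * dt and beta = sigma * sqrt(dt). A standalone sketch of the same recursion, using the defaults OUNoise declares:

import numpy as np

theta, sigma, dt, mu = 0.15, 0.3, 1e-2, 0.1    # OUNoise defaults below
alpha, beta = theta * dt, sigma * np.sqrt(dt)
x = np.zeros(3)                                # one noise value per action dimension
for _ in range(5):
    x = x + alpha * (mu - x) + beta * np.random.normal(size=x.shape)
print(x)                                       # temporally correlated exploration noise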
diff --git a/tianshou/exploration/__init__.py b/tianshou/exploration/__init__.py
new file mode 100644
index 0000000..220913e
--- /dev/null
+++ b/tianshou/exploration/__init__.py
@@ -0,0 +1,5 @@
+from tianshou.exploration.random import OUNoise
+
+__all__ = [
+    'OUNoise',
+]
diff --git a/tianshou/exploration/random.py b/tianshou/exploration/random.py
new file mode 100644
index 0000000..011afbe
--- /dev/null
+++ b/tianshou/exploration/random.py
@@ -0,0 +1,21 @@
+import numpy as np
+
+
+class OUNoise(object):
+    """docstring for OUNoise"""
+
+    def __init__(self, sigma=0.3, theta=0.15, dt=1e-2, x0=None):
+        self.alpha = theta * dt
+        self.beta = sigma * np.sqrt(dt)
+        self.x0 = x0
+        self.reset()
+
+    def __call__(self, size, mu=.1):
+        if self.x is None or self.x.shape != size:
+            self.x = 0
+        self.x = self.x + self.alpha * (mu - self.x) + \
+            self.beta * np.random.normal(size=size)
+        return self.x
+
+    def reset(self):
+        self.x = None
diff --git a/tianshou/policy/__init__.py b/tianshou/policy/__init__.py
index 4301913..aba6a73 100644
--- a/tianshou/policy/__init__.py
+++ b/tianshou/policy/__init__.py
@@ -1,11 +1,13 @@
 from tianshou.policy.base import BasePolicy
 from tianshou.policy.dqn import DQNPolicy
-from tianshou.policy.policy_gradient import PGPolicy
+from tianshou.policy.pg import PGPolicy
 from tianshou.policy.a2c import A2CPolicy
+from tianshou.policy.ddpg import DDPGPolicy
 
 __all__ = [
     'BasePolicy',
     'DQNPolicy',
     'PGPolicy',
     'A2CPolicy',
+    'DDPGPolicy'
 ]
diff --git a/tianshou/policy/a2c.py b/tianshou/policy/a2c.py
index f546a62..1dd1cc8 100644
--- a/tianshou/policy/a2c.py
+++ b/tianshou/policy/a2c.py
@@ -1,4 +1,5 @@
 import torch
+from torch import nn
 import torch.nn.functional as F
 
 from tianshou.data import Batch
@@ -9,16 +10,18 @@ class A2CPolicy(PGPolicy):
     """docstring for A2CPolicy"""
 
     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
-                 discount_factor=0.99, vf_coef=.5, entropy_coef=.01):
+                 discount_factor=0.99, vf_coef=.5, entropy_coef=.01,
+                 max_grad_norm=None):
         super().__init__(model, optim, dist_fn, discount_factor)
         self._w_value = vf_coef
         self._w_entropy = entropy_coef
+        self._grad_norm = max_grad_norm
 
     def __call__(self, batch, state=None):
         logits, value, h = self.model(batch.obs, state=state, info=batch.info)
         logits = F.softmax(logits, dim=1)
         dist = self.dist_fn(logits)
-        act = dist.sample().detach().cpu().numpy()
+        act = dist.sample()
         return Batch(logits=logits, act=act, state=h, dist=dist, value=value)
 
     def learn(self, batch, batch_size=None):
@@ -31,12 +34,15 @@ class A2CPolicy(PGPolicy):
             a = torch.tensor(b.act, device=dist.logits.device)
             r = torch.tensor(b.returns, device=dist.logits.device)
             actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
-            critic_loss = (r - v).pow(2).mean()
+            critic_loss = F.mse_loss(r[:, None], v)
            entropy_loss = dist.entropy().mean()
             loss = actor_loss \
                 + self._w_value * critic_loss \
                 - self._w_entropy * entropy_loss
             loss.backward()
+            if self._grad_norm:
+                nn.utils.clip_grad_norm_(
+                    self.model.parameters(), max_norm=self._grad_norm)
             self.optim.step()
             losses.append(loss.detach().cpu().numpy())
         return losses
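Note: a self-contained sketch of the loss composition A2CPolicy.learn builds above, with synthetic tensors standing in for a real batch; the 0.5 and 0.001 coefficients are the vf-coef / entropy-coef defaults from test_a2c.py, everything else is made up for illustration:

import torch
import torch.nn.functional as F

dist = torch.distributions.Categorical(probs=torch.tensor([[0.7, 0.3], [0.4, 0.6]]))
a = torch.tensor([0, 1])                           # actions taken in the batch
r = torch.tensor([1.0, 0.5])                       # discounted returns
v = torch.tensor([0.8, 0.4], requires_grad=True)   # critic value estimates

actor_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
critic_loss = F.mse_loss(v, r)
entropy_loss = dist.entropy().mean()
loss = actor_loss + 0.5 * critic_loss - 0.001 * entropy_loss
loss.backward()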
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index ee96a50..f4f4f83 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -1,19 +1,19 @@
+from torch import nn
 from abc import ABC, abstractmethod
 
 
-class BasePolicy(ABC):
+class BasePolicy(ABC, nn.Module):
     """docstring for BasePolicy"""
 
     def __init__(self):
         super().__init__()
-        self.model = None
 
     def process_fn(self, batch, buffer, indice):
         return batch
 
     @abstractmethod
     def __call__(self, batch, state=None):
-        # return Batch(logits=..., act=np.array(), state=None, ...)
+        # return Batch(logits=..., act=..., state=None, ...)
         pass
 
     @abstractmethod
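Note: DDPGPolicy below keeps frozen target copies of the actor and critic and blends them toward the live networks in sync_weight. A standalone sketch of that soft (Polyak) update, using a throwaway Linear layer and the tau=0.005 default from the test:

from copy import deepcopy
from torch import nn

net = nn.Linear(3, 1)                          # stand-in for the actor or critic
net_old = deepcopy(net)                        # frozen target copy
tau = 0.005
for o, n in zip(net_old.parameters(), net.parameters()):
    o.data.copy_(o.data * (1 - tau) + n.data * tau)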
diff --git a/tianshou/policy/ddpg.py b/tianshou/policy/ddpg.py
new file mode 100644
index 0000000..63b32db
--- /dev/null
+++ b/tianshou/policy/ddpg.py
@@ -0,0 +1,91 @@
+import torch
+from copy import deepcopy
+import torch.nn.functional as F
+
+from tianshou.data import Batch
+from tianshou.policy import BasePolicy
+# from tianshou.exploration import OUNoise
+
+
+class DDPGPolicy(BasePolicy):
+    """docstring for DDPGPolicy"""
+
+    def __init__(self, actor, actor_optim,
+                 critic, critic_optim, action_range,
+                 tau=0.005, gamma=0.99, exploration_noise=0.1):
+        super().__init__()
+        self.actor = actor
+        self.actor_old = deepcopy(actor)
+        self.actor_old.load_state_dict(self.actor.state_dict())
+        self.actor_old.eval()
+        self.actor_optim = actor_optim
+        self.critic = critic
+        self.critic_old = deepcopy(critic)
+        self.critic_old.load_state_dict(self.critic.state_dict())
+        self.critic_old.eval()
+        self.critic_optim = critic_optim
+        assert 0 < tau <= 1, 'tau should in (0, 1]'
+        self._tau = tau
+        assert 0 < gamma <= 1, 'gamma should in (0, 1]'
+        self._gamma = gamma
+        assert 0 <= exploration_noise, 'noise should greater than zero'
+        self._eps = exploration_noise
+        self._range = action_range
+        # self.noise = OUNoise()
+
+    def set_eps(self, eps):
+        self._eps = eps
+
+    def train(self):
+        self.training = True
+        self.actor.train()
+        self.critic.train()
+
+    def eval(self):
+        self.training = False
+        self.actor.eval()
+        self.critic.eval()
+
+    def sync_weight(self):
+        for o, n in zip(self.actor_old.parameters(), self.actor.parameters()):
+            o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
+        for o, n in zip(
+                self.critic_old.parameters(), self.critic.parameters()):
+            o.data.copy_(o.data * (1 - self._tau) + n.data * self._tau)
+
+    def process_fn(self, batch, buffer, indice):
+        return batch
+
+    def __call__(self, batch, state=None,
+                 model='actor', input='obs', eps=None):
+        model = getattr(self, model)
+        obs = getattr(batch, input)
+        logits, h = model(obs, state=state, info=batch.info)
+        # noise = np.random.normal(0, self._eps, size=logits.shape)
+        logits += torch.randn(
+            size=logits.shape, device=logits.device) * self._eps
+        # noise = self.noise(logits.shape, self._eps)
+        # logits += torch.tensor(noise, device=logits.device)
+        logits = logits.clamp(self._range[0], self._range[1])
+        return Batch(act=logits, state=h)
+
+    def learn(self, batch, batch_size=None):
+        target_q = self.critic_old(
+            batch.obs_next, self.actor_old(batch.obs_next, state=None)[0])
+        dev = target_q.device
+        rew = torch.tensor(batch.rew, dtype=torch.float, device=dev)
+        done = torch.tensor(batch.done, dtype=torch.float, device=dev)
+        target_q = rew[:, None] + ((
+            1. - done[:, None]) * self._gamma * target_q).detach()
+        current_q = self.critic(batch.obs, batch.act)
+        critic_loss = F.mse_loss(current_q, target_q)
+        self.critic_optim.zero_grad()
+        critic_loss.backward()
+        self.critic_optim.step()
+        actor_loss = -self.critic(
+            batch.obs, self.actor(batch.obs, state=None)[0]).mean()
+        self.actor_optim.zero_grad()
+        actor_loss.backward()
+        self.actor_optim.step()
+        return actor_loss.detach().cpu().numpy(),\
+            critic_loss.detach().cpu().numpy()
diff --git a/tianshou/policy/dqn.py b/tianshou/policy/dqn.py
index 8d99767..a4dec21 100644
--- a/tianshou/policy/dqn.py
+++ b/tianshou/policy/dqn.py
@@ -1,25 +1,24 @@
 import torch
 import numpy as np
-from torch import nn
 from copy import deepcopy
+import torch.nn.functional as F
 
 from tianshou.data import Batch
 from tianshou.policy import BasePolicy
 
 
-class DQNPolicy(BasePolicy, nn.Module):
+class DQNPolicy(BasePolicy):
     """docstring for DQNPolicy"""
 
-    def __init__(self, model, optim, loss_fn,
+    def __init__(self, model, optim,
                  discount_factor=0.99,
                  estimation_step=1,
                  use_target_network=True):
         super().__init__()
         self.model = model
         self.optim = optim
-        self.loss_fn = loss_fn
         self.eps = 0
-        assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
+        assert 0 < discount_factor <= 1, 'discount_factor should in (0, 1]'
         self._gamma = discount_factor
         assert estimation_step > 0, 'estimation_step should greater than 0'
         self._n_step = estimation_step
@@ -91,7 +90,7 @@ class DQNPolicy(BasePolicy, nn.Module):
         r = batch.returns
         if isinstance(r, np.ndarray):
             r = torch.tensor(r, device=q.device, dtype=q.dtype)
-        loss = self.loss_fn(q, r)
+        loss = F.mse_loss(q, r)
         loss.backward()
         self.optim.step()
         return loss.detach().cpu().numpy()
diff --git a/tianshou/policy/policy_gradient.py b/tianshou/policy/pg.py
similarity index 92%
rename from tianshou/policy/policy_gradient.py
rename to tianshou/policy/pg.py
index 8b698bd..430442b 100644
--- a/tianshou/policy/policy_gradient.py
+++ b/tianshou/policy/pg.py
@@ -1,13 +1,12 @@
 import torch
 import numpy as np
-from torch import nn
 import torch.nn.functional as F
 
 from tianshou.data import Batch
 from tianshou.policy import BasePolicy
 
 
-class PGPolicy(BasePolicy, nn.Module):
+class PGPolicy(BasePolicy):
     """docstring for PGPolicy"""
 
     def __init__(self, model, optim, dist_fn=torch.distributions.Categorical,
@@ -17,7 +16,7 @@ class PGPolicy(BasePolicy, nn.Module):
         self.optim = optim
         self.dist_fn = dist_fn
         self._eps = np.finfo(np.float32).eps.item()
-        assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
+        assert 0 < discount_factor <= 1, 'discount_factor should in (0, 1]'
         self._gamma = discount_factor
 
     def process_fn(self, batch, buffer, indice):
@@ -30,7 +29,7 @@ class PGPolicy(BasePolicy, nn.Module):
         logits, h = self.model(batch.obs, state=state, info=batch.info)
         logits = F.softmax(logits, dim=1)
         dist = self.dist_fn(logits)
-        act = dist.sample().detach().cpu().numpy()
+        act = dist.sample()
         return Batch(logits=logits, act=act, state=h, dist=dist)
 
     def learn(self, batch, batch_size=None):
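Note: a numeric sketch of the one-step TD target that DDPGPolicy.learn assembles above, rew + (1 - done) * gamma * Q_target(s', actor_target(s')), with hand-picked values so the shapes and the terminal-state masking are easy to check:

import torch

rew = torch.tensor([1.0, 0.0])
done = torch.tensor([0.0, 1.0])
target_q_next = torch.tensor([[2.0], [3.0]])   # critic_old output, shape (batch, 1)
gamma = 0.99
target_q = rew[:, None] + ((1. - done[:, None]) * gamma * target_q_next).detach()
print(target_q)                                # tensor([[2.9800], [0.0000]])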