From 8af7196a9a00f53cee4f904019ee895bf0944bff Mon Sep 17 00:00:00 2001 From: Alexis DUBURCQ Date: Fri, 29 May 2020 14:45:21 +0200 Subject: [PATCH] Robust conversion from/to numpy/pytorch (#63) * Enable to convert Batch data back to torch. * Add torch converter to collector. * Fix * Move to_numpy/to_torch convert in dedicated utils.py. * Use to_numpy/to_torch to convert arrays. * fix lint * fix * Add unit test to check Batch from/to numpy. * Fix Batch over Batch. Co-authored-by: Alexis Duburcq --- test/base/test_batch.py | 15 +++++++++++++ tianshou/data/__init__.py | 5 ++++- tianshou/data/batch.py | 20 ++++++++++++++++- tianshou/data/collector.py | 21 +++--------------- tianshou/data/utils.py | 36 +++++++++++++++++++++++++++++++ tianshou/policy/imitation/base.py | 6 +++--- tianshou/policy/modelfree/a2c.py | 8 +++---- tianshou/policy/modelfree/ddpg.py | 12 +++++------ tianshou/policy/modelfree/dqn.py | 19 ++++++++-------- tianshou/policy/modelfree/pg.py | 6 +++--- tianshou/policy/modelfree/ppo.py | 10 ++++----- tianshou/policy/modelfree/sac.py | 12 +++++------ tianshou/policy/modelfree/td3.py | 10 ++++----- 13 files changed, 119 insertions(+), 61 deletions(-) create mode 100644 tianshou/data/utils.py diff --git a/test/base/test_batch.py b/test/base/test_batch.py index 7929387..d169085 100644 --- a/test/base/test_batch.py +++ b/test/base/test_batch.py @@ -1,4 +1,5 @@ import pytest +import torch import numpy as np from tianshou.data import Batch @@ -29,6 +30,20 @@ def test_batch_over_batch(): assert batch2[-1].b.b == 0 +def test_batch_from_to_numpy_without_copy(): + batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,)))) + a_mem_addr_orig = batch["a"].__array_interface__['data'][0] + c_mem_addr_orig = batch["b"]["c"].__array_interface__['data'][0] + batch.to_torch() + assert isinstance(batch["a"], torch.Tensor) + assert isinstance(batch["b"]["c"], torch.Tensor) + batch.to_numpy() + a_mem_addr_new = batch["a"].__array_interface__['data'][0] + c_mem_addr_new = batch["b"]["c"].__array_interface__['data'][0] + assert a_mem_addr_new == a_mem_addr_orig + assert c_mem_addr_new == c_mem_addr_orig + + if __name__ == '__main__': test_batch() test_batch_over_batch() diff --git a/tianshou/data/__init__.py b/tianshou/data/__init__.py index b37c027..fd57d87 100644 --- a/tianshou/data/__init__.py +++ b/tianshou/data/__init__.py @@ -1,12 +1,15 @@ from tianshou.data.batch import Batch +from tianshou.data.utils import to_numpy, to_torch from tianshou.data.buffer import ReplayBuffer, \ ListReplayBuffer, PrioritizedReplayBuffer from tianshou.data.collector import Collector __all__ = [ 'Batch', + 'to_numpy', + 'to_torch', 'ReplayBuffer', 'ListReplayBuffer', 'PrioritizedReplayBuffer', - 'Collector', + 'Collector' ] diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 4ab5bf0..526cea0 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -141,13 +141,31 @@ class Batch: return self.__getattr__(k) return d - def to_numpy(self) -> np.ndarray: + def to_numpy(self) -> None: """Change all torch.Tensor to numpy.ndarray. This is an inplace operation. """ for k, v in self.__dict__.items(): if isinstance(v, torch.Tensor): self.__dict__[k] = v.cpu().numpy() + elif isinstance(v, Batch): + v.to_numpy() + + def to_torch(self, + dtype: Optional[torch.dtype] = None, + device: Union[str, int] = 'cpu' + ) -> None: + """Change all numpy.ndarray to torch.Tensor. This is an inplace + operation. 
+ """ + for k, v in self.__dict__.items(): + if isinstance(v, np.ndarray): + v = torch.from_numpy(v).to(device) + if dtype is not None: + v = v.type(dtype) + self.__dict__[k] = v + elif isinstance(v, Batch): + v.to_torch() def append(self, batch: 'Batch') -> None: """Append a :class:`~tianshou.data.Batch` object to current batch.""" diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 657e3ae..837f4fa 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -8,7 +8,7 @@ from typing import Any, Dict, List, Union, Optional, Callable from tianshou.utils import MovAvg from tianshou.env import BaseVectorEnv from tianshou.policy import BasePolicy -from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer +from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy class Collector(object): @@ -201,21 +201,6 @@ class Collector(object): elif isinstance(self.state, (torch.Tensor, np.ndarray)): self.state[id] = 0 - def _to_numpy(self, x: Union[ - torch.Tensor, dict, Batch, np.ndarray]) -> None: - """Return an object without torch.Tensor.""" - if isinstance(x, torch.Tensor): - return x.cpu().numpy() - elif isinstance(x, dict): - for k in x: - if isinstance(x[k], torch.Tensor): - x[k] = x[k].cpu().numpy() - return x - elif isinstance(x, Batch): - x.to_numpy() - return x - return x - def collect(self, n_step: int = 0, n_episode: Union[int, List[int]] = 0, @@ -270,9 +255,9 @@ class Collector(object): with torch.no_grad(): result = self.policy(batch, self.state) self.state = result.get('state', None) - self._policy = self._to_numpy(result.policy) \ + self._policy = to_numpy(result.policy) \ if hasattr(result, 'policy') else [{}] * self.env_num - self._act = self._to_numpy(result.act) + self._act = to_numpy(result.act) obs_next, self._rew, self._done, self._info = self.env.step( self._act if self._multi_env else self._act[0]) if not self._multi_env: diff --git a/tianshou/data/utils.py b/tianshou/data/utils.py new file mode 100644 index 0000000..8494074 --- /dev/null +++ b/tianshou/data/utils.py @@ -0,0 +1,36 @@ +import torch +import numpy as np +from typing import Union, Optional + +from tianshou.data import Batch + + +def to_numpy(x: Union[ + torch.Tensor, dict, Batch, np.ndarray]) -> Union[ + dict, Batch, np.ndarray]: + """Return an object without torch.Tensor.""" + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + elif isinstance(x, dict): + for k, v in x.items(): + x[k] = to_numpy(v) + elif isinstance(x, Batch): + x.to_numpy() + return x + + +def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray], + dtype: Optional[torch.dtype] = None, + device: Union[str, int] = 'cpu' + ) -> Union[dict, Batch, torch.Tensor]: + """Return an object without np.ndarray.""" + if isinstance(x, np.ndarray): + x = torch.from_numpy(x).to(device) + if dtype is not None: + x = x.type(dtype) + elif isinstance(x, dict): + for k, v in x.items(): + x[k] = to_torch(v, dtype, device) + elif isinstance(x, Batch): + x.to_torch() + return x diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 64c3c7e..aeb0eb3 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -3,7 +3,7 @@ import numpy as np import torch.nn.functional as F from typing import Dict, Union, Optional -from tianshou.data import Batch +from tianshou.data import Batch, to_torch from tianshou.policy import BasePolicy @@ -46,11 +46,11 @@ class ImitationPolicy(BasePolicy): self.optim.zero_grad() if self.mode == 'continuous': a 
= self(batch).act - a_ = torch.tensor(batch.act, dtype=torch.float, device=a.device) + a_ = to_torch(batch.act, dtype=torch.float, device=a.device) loss = F.mse_loss(a, a_) elif self.mode == 'discrete': # classification a = self(batch).logits - a_ = torch.tensor(batch.act, dtype=torch.long, device=a.device) + a_ = to_torch(batch.act, dtype=torch.long, device=a.device) loss = F.nll_loss(a, a_) loss.backward() self.optim.step() diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 8219fef..bb40a25 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from typing import Dict, List, Union, Optional from tianshou.policy import PGPolicy -from tianshou.data import Batch, ReplayBuffer +from tianshou.data import Batch, ReplayBuffer, to_torch, to_numpy class A2CPolicy(PGPolicy): @@ -64,7 +64,7 @@ class A2CPolicy(PGPolicy): v_ = [] with torch.no_grad(): for b in batch.split(self._batch, shuffle=False): - v_.append(self.critic(b.obs_next).detach().cpu().numpy()) + v_.append(to_numpy(self.critic(b.obs_next))) v_ = np.concatenate(v_, axis=0) return self.compute_episodic_return( batch, v_, gamma=self._gamma, gae_lambda=self._lambda) @@ -106,8 +106,8 @@ class A2CPolicy(PGPolicy): self.optim.zero_grad() dist = self(b).dist v = self.critic(b.obs) - a = torch.tensor(b.act, device=v.device) - r = torch.tensor(b.returns, device=v.device) + a = to_torch(b.act, device=v.device) + r = to_torch(b.returns, device=v.device) a_loss = -(dist.log_prob(a) * (r - v).detach()).mean() vf_loss = F.mse_loss(r[:, None], v) ent_loss = dist.entropy().mean() diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index c4cd465..b211b44 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -6,7 +6,7 @@ from typing import Dict, Tuple, Union, Optional from tianshou.policy import BasePolicy # from tianshou.exploration import OUNoise -from tianshou.data import Batch, ReplayBuffer +from tianshou.data import Batch, ReplayBuffer, to_torch class DDPGPolicy(BasePolicy): @@ -135,7 +135,7 @@ class DDPGPolicy(BasePolicy): eps = self._eps if eps > 0: # noise = np.random.normal(0, eps, size=logits.shape) - # logits += torch.tensor(noise, device=logits.device) + # logits += to_torch(noise, device=logits.device) # noise = self.noise(logits.shape, eps) logits += torch.randn( size=logits.shape, device=logits.device) * eps @@ -147,10 +147,10 @@ class DDPGPolicy(BasePolicy): target_q = self.critic_old(batch.obs_next, self( batch, model='actor_old', input='obs_next', eps=0).act) dev = target_q.device - rew = torch.tensor(batch.rew, - dtype=torch.float, device=dev)[:, None] - done = torch.tensor(batch.done, - dtype=torch.float, device=dev)[:, None] + rew = to_torch(batch.rew, + dtype=torch.float, device=dev)[:, None] + done = to_torch(batch.done, + dtype=torch.float, device=dev)[:, None] target_q = (rew + (1. 
- done) * self._gamma * target_q) current_q = self.critic(batch.obs, batch.act) critic_loss = F.mse_loss(current_q, target_q) diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 8f87d10..71c9866 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -5,7 +5,8 @@ import torch.nn.functional as F from typing import Dict, Union, Optional from tianshou.policy import BasePolicy -from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer +from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \ + to_torch, to_numpy class DQNPolicy(BasePolicy): @@ -96,12 +97,12 @@ class DQNPolicy(BasePolicy): target_q = self( terminal_data, model='model_old', input='obs_next').logits if isinstance(target_q, torch.Tensor): - target_q = target_q.detach().cpu().numpy() + target_q = to_numpy(target_q) target_q = target_q[np.arange(len(a)), a] else: target_q = self(terminal_data, input='obs_next').logits if isinstance(target_q, torch.Tensor): - target_q = target_q.detach().cpu().numpy() + target_q = to_numpy(target_q) target_q = target_q.max(axis=1) target_q[gammas != self._n_step] = 0 returns += (self._gamma ** gammas) * target_q @@ -111,11 +112,11 @@ class DQNPolicy(BasePolicy): q = q[np.arange(len(q)), batch.act] r = batch.returns if isinstance(r, np.ndarray): - r = torch.tensor(r, device=q.device, dtype=q.dtype) + r = to_torch(r, device=q.device, dtype=q.dtype) td = r - q - buffer.update_weight(indice, td.detach().cpu().numpy()) - impt_weight = torch.tensor(batch.impt_weight, - device=q.device, dtype=torch.float) + buffer.update_weight(indice, to_numpy(td)) + impt_weight = to_torch(batch.impt_weight, + device=q.device, dtype=torch.float) loss = (td.pow(2) * impt_weight).mean() if not hasattr(batch, 'loss'): batch.loss = loss @@ -147,7 +148,7 @@ class DQNPolicy(BasePolicy): model = getattr(self, model) obs = getattr(batch, input) q, h = model(obs, state=state, info=batch.info) - act = q.max(dim=1)[1].detach().cpu().numpy() + act = to_numpy(q.max(dim=1)[1]) # add eps to act if eps is None: eps = self.eps @@ -168,7 +169,7 @@ class DQNPolicy(BasePolicy): q = q[np.arange(len(q)), batch.act] r = batch.returns if isinstance(r, np.ndarray): - r = torch.tensor(r, device=q.device, dtype=q.dtype) + r = to_torch(r, device=q.device, dtype=q.dtype) loss = F.mse_loss(q, r) loss.backward() self.optim.step() diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index fd05adb..bd018ec 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -3,7 +3,7 @@ import numpy as np from typing import Dict, List, Union, Optional from tianshou.policy import BasePolicy -from tianshou.data import Batch, ReplayBuffer +from tianshou.data import Batch, ReplayBuffer, to_torch class PGPolicy(BasePolicy): @@ -88,8 +88,8 @@ class PGPolicy(BasePolicy): for b in batch.split(batch_size): self.optim.zero_grad() dist = self(b).dist - a = torch.tensor(b.act, device=dist.logits.device) - r = torch.tensor(b.returns, device=dist.logits.device) + a = to_torch(b.act, device=dist.logits.device) + r = to_torch(b.returns, device=dist.logits.device) loss = -(dist.log_prob(a) * r).sum() loss.backward() self.optim.step() diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 8631dc9..7c1d538 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -4,7 +4,7 @@ from torch import nn from typing import Dict, List, Tuple, Union, Optional from tianshou.policy 
import PGPolicy -from tianshou.data import Batch, ReplayBuffer +from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch class PPOPolicy(PGPolicy): @@ -88,7 +88,7 @@ class PPOPolicy(PGPolicy): with torch.no_grad(): for b in batch.split(self._batch, shuffle=False): v_.append(self.critic(b.obs_next)) - v_ = torch.cat(v_, dim=0).cpu().numpy() + v_ = to_numpy(torch.cat(v_, dim=0)) return self.compute_episodic_return( batch, v_, gamma=self._gamma, gae_lambda=self._lambda) @@ -129,12 +129,12 @@ class PPOPolicy(PGPolicy): for b in batch.split(batch_size, shuffle=False): v.append(self.critic(b.obs)) old_log_prob.append(self(b).dist.log_prob( - torch.tensor(b.act, device=v[0].device))) + to_torch(b.act, device=v[0].device))) batch.v = torch.cat(v, dim=0) # old value dev = batch.v.device - batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev) + batch.act = to_torch(batch.act, dtype=torch.float, device=dev) batch.logp_old = torch.cat(old_log_prob, dim=0) - batch.returns = torch.tensor( + batch.returns = to_torch( batch.returns, dtype=torch.float, device=dev ).reshape(batch.v.shape) if self._rew_norm: diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index e409438..bbf258c 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -4,7 +4,7 @@ from copy import deepcopy import torch.nn.functional as F from typing import Dict, Tuple, Union, Optional -from tianshou.data import Batch +from tianshou.data import Batch, to_torch from tianshou.policy import DDPGPolicy from tianshou.policy.dist import DiagGaussian @@ -110,15 +110,15 @@ class SACPolicy(DDPGPolicy): obs_next_result = self(batch, input='obs_next') a_ = obs_next_result.act dev = a_.device - batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev) + batch.act = to_torch(batch.act, dtype=torch.float, device=dev) target_q = torch.min( self.critic1_old(batch.obs_next, a_), self.critic2_old(batch.obs_next, a_), ) - self._alpha * obs_next_result.log_prob - rew = torch.tensor(batch.rew, - dtype=torch.float, device=dev)[:, None] - done = torch.tensor(batch.done, - dtype=torch.float, device=dev)[:, None] + rew = to_torch(batch.rew, + dtype=torch.float, device=dev)[:, None] + done = to_torch(batch.done, + dtype=torch.float, device=dev)[:, None] target_q = (rew + (1. - done) * self._gamma * target_q) # critic 1 current_q1 = self.critic1(batch.obs, batch.act) diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index ee6519a..9f0b3cd 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -3,7 +3,7 @@ from copy import deepcopy import torch.nn.functional as F from typing import Dict, Tuple, Optional -from tianshou.data import Batch +from tianshou.data import Batch, to_torch from tianshou.policy import DDPGPolicy @@ -112,10 +112,10 @@ class TD3Policy(DDPGPolicy): target_q = torch.min( self.critic1_old(batch.obs_next, a_), self.critic2_old(batch.obs_next, a_)) - rew = torch.tensor(batch.rew, - dtype=torch.float, device=dev)[:, None] - done = torch.tensor(batch.done, - dtype=torch.float, device=dev)[:, None] + rew = to_torch(batch.rew, + dtype=torch.float, device=dev)[:, None] + done = to_torch(batch.done, + dtype=torch.float, device=dev)[:, None] target_q = (rew + (1. - done) * self._gamma * target_q) # critic 1 current_q1 = self.critic1(batch.obs, batch.act)
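
Usage note (illustrative sketch, not part of the diff): the snippet below shows how the converters introduced above are intended to be used, assuming the `Batch.to_torch`/`Batch.to_numpy` methods and the standalone `to_numpy`/`to_torch` helpers exactly as defined in tianshou/data/batch.py and tianshou/data/utils.py in this patch.

    import numpy as np
    import torch

    from tianshou.data import Batch, to_numpy, to_torch

    # Nested Batch of numpy arrays, mirroring the new unit test.
    batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))

    # In-place, recursive conversion of every array to torch.Tensor ...
    batch.to_torch()
    assert isinstance(batch.a, torch.Tensor)
    assert isinstance(batch.b.c, torch.Tensor)

    # ... and back to numpy, again in place and without copying the buffers
    # (the new test checks that the underlying data pointers are unchanged).
    batch.to_numpy()
    assert isinstance(batch.a, np.ndarray)

    # The standalone helpers accept tensors, ndarrays, dicts and Batch objects.
    act = to_numpy(torch.zeros(3))                   # -> np.ndarray
    ret = to_torch(np.arange(5), dtype=torch.float)  # -> float tensor on CPU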