Robust conversion from/to numpy/pytorch (#63)

* Enable converting Batch data back to torch.

* Add torch converter to collector.

* Fix

* Move the to_numpy/to_torch conversion helpers into a dedicated utils.py.

* Use to_numpy/to_torch to convert arrays.

* fix lint

* fix

* Add unit test to check Batch from/to numpy.

* Fix Batch over Batch.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
Alexis DUBURCQ 2020-05-29 14:45:21 +02:00 committed by GitHub
parent b5093ecb56
commit 8af7196a9a
13 changed files with 119 additions and 61 deletions
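
In short, this commit replaces the ad-hoc torch.tensor(...) / .detach().cpu().numpy() calls scattered across the collector and the policies with two shared helpers, to_torch and to_numpy, plus matching in-place methods on Batch. A minimal sketch of the two intended call sites (the device name below is a placeholder, not taken from the diff):

import numpy as np
import torch
from tianshou.data import Batch, to_numpy, to_torch

# Policy side: move replay-buffer arrays onto the network's device.
batch = Batch(act=np.array([0.5, -0.2]), rew=np.array([1.0, 0.0]))
dev = 'cpu'  # placeholder for the actor/critic device
act = to_torch(batch.act, dtype=torch.float, device=dev)
rew = to_torch(batch.rew, dtype=torch.float, device=dev)[:, None]

# Collector side: strip tensors from the policy output before env.step.
result = Batch(act=torch.tensor([1, 0]))
env_act = to_numpy(result.act)
assert isinstance(env_act, np.ndarray)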

View File

@@ -1,4 +1,5 @@
import pytest
import torch
import numpy as np
from tianshou.data import Batch
@@ -29,6 +30,20 @@ def test_batch_over_batch():
assert batch2[-1].b.b == 0
def test_batch_from_to_numpy_without_copy():
batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
a_mem_addr_orig = batch["a"].__array_interface__['data'][0]
c_mem_addr_orig = batch["b"]["c"].__array_interface__['data'][0]
batch.to_torch()
assert isinstance(batch["a"], torch.Tensor)
assert isinstance(batch["b"]["c"], torch.Tensor)
batch.to_numpy()
a_mem_addr_new = batch["a"].__array_interface__['data'][0]
c_mem_addr_new = batch["b"]["c"].__array_interface__['data'][0]
assert a_mem_addr_new == a_mem_addr_orig
assert c_mem_addr_new == c_mem_addr_orig
if __name__ == '__main__':
test_batch()
test_batch_over_batch()

View File

@@ -1,12 +1,15 @@
from tianshou.data.batch import Batch
from tianshou.data.utils import to_numpy, to_torch
from tianshou.data.buffer import ReplayBuffer, \
ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.collector import Collector
__all__ = [
'Batch',
'to_numpy',
'to_torch',
'ReplayBuffer',
'ListReplayBuffer',
'PrioritizedReplayBuffer',
'Collector',
'Collector'
]

View File

@@ -141,13 +141,31 @@ class Batch:
return self.__getattr__(k)
return d
def to_numpy(self) -> np.ndarray:
def to_numpy(self) -> None:
"""Change all torch.Tensor to numpy.ndarray. This is an inplace
operation.
"""
for k, v in self.__dict__.items():
if isinstance(v, torch.Tensor):
self.__dict__[k] = v.cpu().numpy()
elif isinstance(v, Batch):
v.to_numpy()
def to_torch(self,
dtype: Optional[torch.dtype] = None,
device: Union[str, int] = 'cpu'
) -> None:
"""Change all numpy.ndarray to torch.Tensor. This is an inplace
operation.
"""
for k, v in self.__dict__.items():
if isinstance(v, np.ndarray):
v = torch.from_numpy(v).to(device)
if dtype is not None:
v = v.type(dtype)
self.__dict__[k] = v
elif isinstance(v, Batch):
v.to_torch()
def append(self, batch: 'Batch') -> None:
"""Append a :class:`~tianshou.data.Batch` object to current batch."""

View File

@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Union, Optional, Callable
from tianshou.utils import MovAvg
from tianshou.env import BaseVectorEnv
from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
class Collector(object):
@@ -201,21 +201,6 @@ class Collector(object):
elif isinstance(self.state, (torch.Tensor, np.ndarray)):
self.state[id] = 0
def _to_numpy(self, x: Union[
torch.Tensor, dict, Batch, np.ndarray]) -> None:
"""Return an object without torch.Tensor."""
if isinstance(x, torch.Tensor):
return x.cpu().numpy()
elif isinstance(x, dict):
for k in x:
if isinstance(x[k], torch.Tensor):
x[k] = x[k].cpu().numpy()
return x
elif isinstance(x, Batch):
x.to_numpy()
return x
return x
def collect(self,
n_step: int = 0,
n_episode: Union[int, List[int]] = 0,
@@ -270,9 +255,9 @@
with torch.no_grad():
result = self.policy(batch, self.state)
self.state = result.get('state', None)
self._policy = self._to_numpy(result.policy) \
self._policy = to_numpy(result.policy) \
if hasattr(result, 'policy') else [{}] * self.env_num
self._act = self._to_numpy(result.act)
self._act = to_numpy(result.act)
obs_next, self._rew, self._done, self._info = self.env.step(
self._act if self._multi_env else self._act[0])
if not self._multi_env:

tianshou/data/utils.py (new file, 36 lines)
View File

@@ -0,0 +1,36 @@
import torch
import numpy as np
from typing import Union, Optional
from tianshou.data import Batch
def to_numpy(x: Union[
torch.Tensor, dict, Batch, np.ndarray]) -> Union[
dict, Batch, np.ndarray]:
"""Return an object without torch.Tensor."""
if isinstance(x, torch.Tensor):
x = x.detach().cpu().numpy()
elif isinstance(x, dict):
for k, v in x.items():
x[k] = to_numpy(v)
elif isinstance(x, Batch):
x.to_numpy()
return x
def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
dtype: Optional[torch.dtype] = None,
device: Union[str, int] = 'cpu'
) -> Union[dict, Batch, torch.Tensor]:
"""Return an object without np.ndarray."""
if isinstance(x, np.ndarray):
x = torch.from_numpy(x).to(device)
if dtype is not None:
x = x.type(dtype)
elif isinstance(x, dict):
for k, v in x.items():
x[k] = to_torch(v, dtype, device)
elif isinstance(x, Batch):
x.to_torch()
return x
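
A short sketch of how the new helpers behave on plain containers, assuming only the behaviour defined above (tensors and arrays are converted, dicts and Batch objects are converted recursively and in place, anything else passes through unchanged):

import numpy as np
import torch
from tianshou.data import Batch, to_numpy, to_torch

data = {'act': np.array([0, 1, 1]), 'extra': Batch(rew=np.zeros(3))}
data = to_torch(data, dtype=torch.float, device='cpu')
assert isinstance(data['act'], torch.Tensor)          # dtype applied to the array
assert isinstance(data['extra'].rew, torch.Tensor)    # nested Batch converted too

back = to_numpy(data['act'])                           # single Tensor -> ndarray
assert isinstance(back, np.ndarray)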

View File

@@ -3,7 +3,7 @@ import numpy as np
import torch.nn.functional as F
from typing import Dict, Union, Optional
from tianshou.data import Batch
from tianshou.data import Batch, to_torch
from tianshou.policy import BasePolicy
@@ -46,11 +46,11 @@ class ImitationPolicy(BasePolicy):
self.optim.zero_grad()
if self.mode == 'continuous':
a = self(batch).act
a_ = torch.tensor(batch.act, dtype=torch.float, device=a.device)
a_ = to_torch(batch.act, dtype=torch.float, device=a.device)
loss = F.mse_loss(a, a_)
elif self.mode == 'discrete': # classification
a = self(batch).logits
a_ = torch.tensor(batch.act, dtype=torch.long, device=a.device)
a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
loss = F.nll_loss(a, a_)
loss.backward()
self.optim.step()

View File

@@ -5,7 +5,7 @@ import torch.nn.functional as F
from typing import Dict, List, Union, Optional
from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer
from tianshou.data import Batch, ReplayBuffer, to_torch, to_numpy
class A2CPolicy(PGPolicy):
@@ -64,7 +64,7 @@ class A2CPolicy(PGPolicy):
v_ = []
with torch.no_grad():
for b in batch.split(self._batch, shuffle=False):
v_.append(self.critic(b.obs_next).detach().cpu().numpy())
v_.append(to_numpy(self.critic(b.obs_next)))
v_ = np.concatenate(v_, axis=0)
return self.compute_episodic_return(
batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
@@ -106,8 +106,8 @@
self.optim.zero_grad()
dist = self(b).dist
v = self.critic(b.obs)
a = torch.tensor(b.act, device=v.device)
r = torch.tensor(b.returns, device=v.device)
a = to_torch(b.act, device=v.device)
r = to_torch(b.returns, device=v.device)
a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
vf_loss = F.mse_loss(r[:, None], v)
ent_loss = dist.entropy().mean()

View File

@@ -6,7 +6,7 @@ from typing import Dict, Tuple, Union, Optional
from tianshou.policy import BasePolicy
# from tianshou.exploration import OUNoise
from tianshou.data import Batch, ReplayBuffer
from tianshou.data import Batch, ReplayBuffer, to_torch
class DDPGPolicy(BasePolicy):
@@ -135,7 +135,7 @@ class DDPGPolicy(BasePolicy):
eps = self._eps
if eps > 0:
# noise = np.random.normal(0, eps, size=logits.shape)
# logits += torch.tensor(noise, device=logits.device)
# logits += to_torch(noise, device=logits.device)
# noise = self.noise(logits.shape, eps)
logits += torch.randn(
size=logits.shape, device=logits.device) * eps
@@ -147,10 +147,10 @@
target_q = self.critic_old(batch.obs_next, self(
batch, model='actor_old', input='obs_next', eps=0).act)
dev = target_q.device
rew = torch.tensor(batch.rew,
dtype=torch.float, device=dev)[:, None]
done = torch.tensor(batch.done,
dtype=torch.float, device=dev)[:, None]
rew = to_torch(batch.rew,
dtype=torch.float, device=dev)[:, None]
done = to_torch(batch.done,
dtype=torch.float, device=dev)[:, None]
target_q = (rew + (1. - done) * self._gamma * target_q)
current_q = self.critic(batch.obs, batch.act)
critic_loss = F.mse_loss(current_q, target_q)

View File

@@ -5,7 +5,8 @@ import torch.nn.functional as F
from typing import Dict, Union, Optional
from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer
from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
to_torch, to_numpy
class DQNPolicy(BasePolicy):
@@ -96,12 +97,12 @@ class DQNPolicy(BasePolicy):
target_q = self(
terminal_data, model='model_old', input='obs_next').logits
if isinstance(target_q, torch.Tensor):
target_q = target_q.detach().cpu().numpy()
target_q = to_numpy(target_q)
target_q = target_q[np.arange(len(a)), a]
else:
target_q = self(terminal_data, input='obs_next').logits
if isinstance(target_q, torch.Tensor):
target_q = target_q.detach().cpu().numpy()
target_q = to_numpy(target_q)
target_q = target_q.max(axis=1)
target_q[gammas != self._n_step] = 0
returns += (self._gamma ** gammas) * target_q
@@ -111,11 +112,11 @@
q = q[np.arange(len(q)), batch.act]
r = batch.returns
if isinstance(r, np.ndarray):
r = torch.tensor(r, device=q.device, dtype=q.dtype)
r = to_torch(r, device=q.device, dtype=q.dtype)
td = r - q
buffer.update_weight(indice, td.detach().cpu().numpy())
impt_weight = torch.tensor(batch.impt_weight,
device=q.device, dtype=torch.float)
buffer.update_weight(indice, to_numpy(td))
impt_weight = to_torch(batch.impt_weight,
device=q.device, dtype=torch.float)
loss = (td.pow(2) * impt_weight).mean()
if not hasattr(batch, 'loss'):
batch.loss = loss
@@ -147,7 +148,7 @@
model = getattr(self, model)
obs = getattr(batch, input)
q, h = model(obs, state=state, info=batch.info)
act = q.max(dim=1)[1].detach().cpu().numpy()
act = to_numpy(q.max(dim=1)[1])
# add eps to act
if eps is None:
eps = self.eps
@@ -168,7 +169,7 @@
q = q[np.arange(len(q)), batch.act]
r = batch.returns
if isinstance(r, np.ndarray):
r = torch.tensor(r, device=q.device, dtype=q.dtype)
r = to_torch(r, device=q.device, dtype=q.dtype)
loss = F.mse_loss(q, r)
loss.backward()
self.optim.step()

View File

@@ -3,7 +3,7 @@ import numpy as np
from typing import Dict, List, Union, Optional
from tianshou.policy import BasePolicy
from tianshou.data import Batch, ReplayBuffer
from tianshou.data import Batch, ReplayBuffer, to_torch
class PGPolicy(BasePolicy):
@@ -88,8 +88,8 @@ class PGPolicy(BasePolicy):
for b in batch.split(batch_size):
self.optim.zero_grad()
dist = self(b).dist
a = torch.tensor(b.act, device=dist.logits.device)
r = torch.tensor(b.returns, device=dist.logits.device)
a = to_torch(b.act, device=dist.logits.device)
r = to_torch(b.returns, device=dist.logits.device)
loss = -(dist.log_prob(a) * r).sum()
loss.backward()
self.optim.step()

View File

@@ -4,7 +4,7 @@ from torch import nn
from typing import Dict, List, Tuple, Union, Optional
from tianshou.policy import PGPolicy
from tianshou.data import Batch, ReplayBuffer
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
class PPOPolicy(PGPolicy):
@@ -88,7 +88,7 @@ class PPOPolicy(PGPolicy):
with torch.no_grad():
for b in batch.split(self._batch, shuffle=False):
v_.append(self.critic(b.obs_next))
v_ = torch.cat(v_, dim=0).cpu().numpy()
v_ = to_numpy(torch.cat(v_, dim=0))
return self.compute_episodic_return(
batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
@@ -129,12 +129,12 @@
for b in batch.split(batch_size, shuffle=False):
v.append(self.critic(b.obs))
old_log_prob.append(self(b).dist.log_prob(
torch.tensor(b.act, device=v[0].device)))
to_torch(b.act, device=v[0].device)))
batch.v = torch.cat(v, dim=0) # old value
dev = batch.v.device
batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
batch.act = to_torch(batch.act, dtype=torch.float, device=dev)
batch.logp_old = torch.cat(old_log_prob, dim=0)
batch.returns = torch.tensor(
batch.returns = to_torch(
batch.returns, dtype=torch.float, device=dev
).reshape(batch.v.shape)
if self._rew_norm:

View File

@@ -4,7 +4,7 @@ from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Tuple, Union, Optional
from tianshou.data import Batch
from tianshou.data import Batch, to_torch
from tianshou.policy import DDPGPolicy
from tianshou.policy.dist import DiagGaussian
@@ -110,15 +110,15 @@ class SACPolicy(DDPGPolicy):
obs_next_result = self(batch, input='obs_next')
a_ = obs_next_result.act
dev = a_.device
batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
batch.act = to_torch(batch.act, dtype=torch.float, device=dev)
target_q = torch.min(
self.critic1_old(batch.obs_next, a_),
self.critic2_old(batch.obs_next, a_),
) - self._alpha * obs_next_result.log_prob
rew = torch.tensor(batch.rew,
dtype=torch.float, device=dev)[:, None]
done = torch.tensor(batch.done,
dtype=torch.float, device=dev)[:, None]
rew = to_torch(batch.rew,
dtype=torch.float, device=dev)[:, None]
done = to_torch(batch.done,
dtype=torch.float, device=dev)[:, None]
target_q = (rew + (1. - done) * self._gamma * target_q)
# critic 1
current_q1 = self.critic1(batch.obs, batch.act)

View File

@@ -3,7 +3,7 @@ from copy import deepcopy
import torch.nn.functional as F
from typing import Dict, Tuple, Optional
from tianshou.data import Batch
from tianshou.data import Batch, to_torch
from tianshou.policy import DDPGPolicy
@@ -112,10 +112,10 @@ class TD3Policy(DDPGPolicy):
target_q = torch.min(
self.critic1_old(batch.obs_next, a_),
self.critic2_old(batch.obs_next, a_))
rew = torch.tensor(batch.rew,
dtype=torch.float, device=dev)[:, None]
done = torch.tensor(batch.done,
dtype=torch.float, device=dev)[:, None]
rew = to_torch(batch.rew,
dtype=torch.float, device=dev)[:, None]
done = to_torch(batch.done,
dtype=torch.float, device=dev)[:, None]
target_q = (rew + (1. - done) * self._gamma * target_q)
# critic 1
current_q1 = self.critic1(batch.obs, batch.act)