Robust conversion from/to numpy/pytorch (#63)

* Enable converting Batch data back to torch.

* Add a torch converter to the collector.

* Fix

* Move the to_numpy/to_torch conversion helpers into a dedicated utils.py.

* Use to_numpy/to_torch to convert arrays.

* fix lint

* fix

* Add a unit test checking Batch conversion from/to numpy.

* Fix conversion of nested Batch (Batch over Batch).

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
Authored by Alexis DUBURCQ on 2020-05-29 14:45:21 +02:00; committed by GitHub
parent b5093ecb56
commit 8af7196a9a
13 changed files with 119 additions and 61 deletions
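For orientation, here is a rough usage sketch of the conversion API this commit introduces, written against the diffs below (the Batch field names and the float32 dtype are illustrative, not taken from the commit):

import numpy as np
import torch
from tianshou.data import Batch, to_numpy, to_torch

# In-place conversion on a (possibly nested) Batch.
batch = Batch(act=np.zeros((4,)), obs=Batch(pos=np.ones((4, 3))))
batch.to_torch(dtype=torch.float32, device='cpu')   # ndarray -> Tensor
assert isinstance(batch.act, torch.Tensor)
batch.to_numpy()                                     # Tensor -> ndarray
assert isinstance(batch.obs.pos, np.ndarray)

# Standalone helpers accept Tensors, ndarrays, dicts and Batches.
a = to_torch(np.array([1, 2, 3]), dtype=torch.long, device='cpu')
b = to_numpy(a)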

test/base/test_batch.py

@@ -1,4 +1,5 @@
 import pytest
+import torch
 import numpy as np
 from tianshou.data import Batch
@@ -29,6 +30,20 @@ def test_batch_over_batch():
     assert batch2[-1].b.b == 0
 
 
+def test_batch_from_to_numpy_without_copy():
+    batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
+    a_mem_addr_orig = batch["a"].__array_interface__['data'][0]
+    c_mem_addr_orig = batch["b"]["c"].__array_interface__['data'][0]
+    batch.to_torch()
+    assert isinstance(batch["a"], torch.Tensor)
+    assert isinstance(batch["b"]["c"], torch.Tensor)
+    batch.to_numpy()
+    a_mem_addr_new = batch["a"].__array_interface__['data'][0]
+    c_mem_addr_new = batch["b"]["c"].__array_interface__['data'][0]
+    assert a_mem_addr_new == a_mem_addr_orig
+    assert c_mem_addr_new == c_mem_addr_orig
+
+
 if __name__ == '__main__':
     test_batch()
     test_batch_over_batch()
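The address assertions in this test rely on torch.from_numpy and Tensor.numpy sharing the underlying CPU buffer, so the round trip through to_torch/to_numpy does not have to copy data. A minimal standalone check of that property (independent of Batch, assuming CPU tensors):

import numpy as np
import torch

arr = np.ones((1,))
addr = arr.__array_interface__['data'][0]
t = torch.from_numpy(arr)   # shares memory with arr
back = t.cpu().numpy()      # .cpu() is a no-op on a CPU tensor; .numpy() shares memory again
assert back.__array_interface__['data'][0] == addr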

tianshou/data/__init__.py

@@ -1,12 +1,15 @@
 from tianshou.data.batch import Batch
+from tianshou.data.utils import to_numpy, to_torch
 from tianshou.data.buffer import ReplayBuffer, \
     ListReplayBuffer, PrioritizedReplayBuffer
 from tianshou.data.collector import Collector
 
 __all__ = [
     'Batch',
+    'to_numpy',
+    'to_torch',
     'ReplayBuffer',
     'ListReplayBuffer',
     'PrioritizedReplayBuffer',
-    'Collector',
+    'Collector'
 ]

tianshou/data/batch.py

@@ -141,13 +141,31 @@ class Batch:
             return self.__getattr__(k)
         return d
 
-    def to_numpy(self) -> np.ndarray:
+    def to_numpy(self) -> None:
         """Change all torch.Tensor to numpy.ndarray. This is an inplace
         operation.
         """
         for k, v in self.__dict__.items():
             if isinstance(v, torch.Tensor):
                 self.__dict__[k] = v.cpu().numpy()
+            elif isinstance(v, Batch):
+                v.to_numpy()
+
+    def to_torch(self,
+                 dtype: Optional[torch.dtype] = None,
+                 device: Union[str, int] = 'cpu'
+                 ) -> None:
+        """Change all numpy.ndarray to torch.Tensor. This is an inplace
+        operation.
+        """
+        for k, v in self.__dict__.items():
+            if isinstance(v, np.ndarray):
+                v = torch.from_numpy(v).to(device)
+                if dtype is not None:
+                    v = v.type(dtype)
+                self.__dict__[k] = v
+            elif isinstance(v, Batch):
+                v.to_torch()
 
     def append(self, batch: 'Batch') -> None:
         """Append a :class:`~tianshou.data.Batch` object to current batch."""

tianshou/data/collector.py

@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Union, Optional, Callable
 from tianshou.utils import MovAvg
 from tianshou.env import BaseVectorEnv
 from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
 
 
 class Collector(object):
@@ -201,21 +201,6 @@ class Collector(object):
         elif isinstance(self.state, (torch.Tensor, np.ndarray)):
             self.state[id] = 0
 
-    def _to_numpy(self, x: Union[
-            torch.Tensor, dict, Batch, np.ndarray]) -> None:
-        """Return an object without torch.Tensor."""
-        if isinstance(x, torch.Tensor):
-            return x.cpu().numpy()
-        elif isinstance(x, dict):
-            for k in x:
-                if isinstance(x[k], torch.Tensor):
-                    x[k] = x[k].cpu().numpy()
-            return x
-        elif isinstance(x, Batch):
-            x.to_numpy()
-            return x
-        return x
-
     def collect(self,
                 n_step: int = 0,
                 n_episode: Union[int, List[int]] = 0,
@@ -270,9 +255,9 @@
             with torch.no_grad():
                 result = self.policy(batch, self.state)
             self.state = result.get('state', None)
-            self._policy = self._to_numpy(result.policy) \
+            self._policy = to_numpy(result.policy) \
                 if hasattr(result, 'policy') else [{}] * self.env_num
-            self._act = self._to_numpy(result.act)
+            self._act = to_numpy(result.act)
            obs_next, self._rew, self._done, self._info = self.env.step(
                 self._act if self._multi_env else self._act[0])
             if not self._multi_env:

tianshou/data/utils.py (new file, 36 lines)

@@ -0,0 +1,36 @@
+import torch
+import numpy as np
+from typing import Union, Optional
+
+from tianshou.data import Batch
+
+
+def to_numpy(x: Union[
+        torch.Tensor, dict, Batch, np.ndarray]) -> Union[
+            dict, Batch, np.ndarray]:
+    """Return an object without torch.Tensor."""
+    if isinstance(x, torch.Tensor):
+        x = x.detach().cpu().numpy()
+    elif isinstance(x, dict):
+        for k, v in x.items():
+            x[k] = to_numpy(v)
+    elif isinstance(x, Batch):
+        x.to_numpy()
+    return x
+
+
+def to_torch(x: Union[torch.Tensor, dict, Batch, np.ndarray],
+             dtype: Optional[torch.dtype] = None,
+             device: Union[str, int] = 'cpu'
+             ) -> Union[dict, Batch, torch.Tensor]:
+    """Return an object without np.ndarray."""
+    if isinstance(x, np.ndarray):
+        x = torch.from_numpy(x).to(device)
+        if dtype is not None:
+            x = x.type(dtype)
+    elif isinstance(x, dict):
+        for k, v in x.items():
+            x[k] = to_torch(v, dtype, device)
+    elif isinstance(x, Batch):
+        x.to_torch()
+    return x

tianshou/policy/imitation/base.py

@@ -3,7 +3,7 @@ import numpy as np
 import torch.nn.functional as F
 from typing import Dict, Union, Optional
 
-from tianshou.data import Batch
+from tianshou.data import Batch, to_torch
 from tianshou.policy import BasePolicy
@@ -46,11 +46,11 @@ class ImitationPolicy(BasePolicy):
         self.optim.zero_grad()
         if self.mode == 'continuous':
             a = self(batch).act
-            a_ = torch.tensor(batch.act, dtype=torch.float, device=a.device)
+            a_ = to_torch(batch.act, dtype=torch.float, device=a.device)
             loss = F.mse_loss(a, a_)
         elif self.mode == 'discrete':  # classification
             a = self(batch).logits
-            a_ = torch.tensor(batch.act, dtype=torch.long, device=a.device)
+            a_ = to_torch(batch.act, dtype=torch.long, device=a.device)
             loss = F.nll_loss(a, a_)
         loss.backward()
         self.optim.step()
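The same torch.tensor to to_torch substitution is applied throughout the policy files in this commit. A rough illustration of the difference, assuming to_torch as defined in utils.py above: torch.tensor() always copies its input array, while to_torch goes through torch.from_numpy and only copies when a dtype or device change forces it.

import numpy as np
import torch
from tianshou.data import to_torch

act = np.array([0.1, -0.2], dtype=np.float32)
a_copy = torch.tensor(act)                                # always copies the data
a_share = to_torch(act, dtype=torch.float, device='cpu')  # from_numpy, then no-op .to/.type: no copy here
assert a_share.data_ptr() == act.__array_interface__['data'][0]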

tianshou/policy/modelfree/a2c.py

@@ -5,7 +5,7 @@ import torch.nn.functional as F
 from typing import Dict, List, Union, Optional
 
 from tianshou.policy import PGPolicy
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, to_torch, to_numpy
 
 
 class A2CPolicy(PGPolicy):
@@ -64,7 +64,7 @@ class A2CPolicy(PGPolicy):
         v_ = []
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False):
-                v_.append(self.critic(b.obs_next).detach().cpu().numpy())
+                v_.append(to_numpy(self.critic(b.obs_next)))
         v_ = np.concatenate(v_, axis=0)
         return self.compute_episodic_return(
             batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
@@ -106,8 +106,8 @@
                 self.optim.zero_grad()
                 dist = self(b).dist
                 v = self.critic(b.obs)
-                a = torch.tensor(b.act, device=v.device)
-                r = torch.tensor(b.returns, device=v.device)
+                a = to_torch(b.act, device=v.device)
+                r = to_torch(b.returns, device=v.device)
                 a_loss = -(dist.log_prob(a) * (r - v).detach()).mean()
                 vf_loss = F.mse_loss(r[:, None], v)
                 ent_loss = dist.entropy().mean()

tianshou/policy/modelfree/ddpg.py

@@ -6,7 +6,7 @@ from typing import Dict, Tuple, Union, Optional
 from tianshou.policy import BasePolicy
 # from tianshou.exploration import OUNoise
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, to_torch
 
 
 class DDPGPolicy(BasePolicy):
@@ -135,7 +135,7 @@ class DDPGPolicy(BasePolicy):
             eps = self._eps
         if eps > 0:
             # noise = np.random.normal(0, eps, size=logits.shape)
-            # logits += torch.tensor(noise, device=logits.device)
+            # logits += to_torch(noise, device=logits.device)
             # noise = self.noise(logits.shape, eps)
             logits += torch.randn(
                 size=logits.shape, device=logits.device) * eps
@@ -147,10 +147,10 @@
         target_q = self.critic_old(batch.obs_next, self(
             batch, model='actor_old', input='obs_next', eps=0).act)
         dev = target_q.device
-        rew = torch.tensor(batch.rew,
+        rew = to_torch(batch.rew,
                            dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done,
+        done = to_torch(batch.done,
                             dtype=torch.float, device=dev)[:, None]
         target_q = (rew + (1. - done) * self._gamma * target_q)
         current_q = self.critic(batch.obs, batch.act)
         critic_loss = F.mse_loss(current_q, target_q)

tianshou/policy/modelfree/dqn.py

@@ -5,7 +5,8 @@ import torch.nn.functional as F
 from typing import Dict, Union, Optional
 
 from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
+    to_torch, to_numpy
 
 
 class DQNPolicy(BasePolicy):
@@ -96,12 +97,12 @@ class DQNPolicy(BasePolicy):
             target_q = self(
                 terminal_data, model='model_old', input='obs_next').logits
             if isinstance(target_q, torch.Tensor):
-                target_q = target_q.detach().cpu().numpy()
+                target_q = to_numpy(target_q)
             target_q = target_q[np.arange(len(a)), a]
         else:
             target_q = self(terminal_data, input='obs_next').logits
             if isinstance(target_q, torch.Tensor):
-                target_q = target_q.detach().cpu().numpy()
+                target_q = to_numpy(target_q)
             target_q = target_q.max(axis=1)
         target_q[gammas != self._n_step] = 0
         returns += (self._gamma ** gammas) * target_q
@@ -111,11 +112,11 @@
         q = q[np.arange(len(q)), batch.act]
         r = batch.returns
         if isinstance(r, np.ndarray):
-            r = torch.tensor(r, device=q.device, dtype=q.dtype)
+            r = to_torch(r, device=q.device, dtype=q.dtype)
         td = r - q
-        buffer.update_weight(indice, td.detach().cpu().numpy())
-        impt_weight = torch.tensor(batch.impt_weight,
+        buffer.update_weight(indice, to_numpy(td))
+        impt_weight = to_torch(batch.impt_weight,
                                    device=q.device, dtype=torch.float)
         loss = (td.pow(2) * impt_weight).mean()
         if not hasattr(batch, 'loss'):
             batch.loss = loss
@@ -147,7 +148,7 @@
         model = getattr(self, model)
         obs = getattr(batch, input)
         q, h = model(obs, state=state, info=batch.info)
-        act = q.max(dim=1)[1].detach().cpu().numpy()
+        act = to_numpy(q.max(dim=1)[1])
         # add eps to act
         if eps is None:
             eps = self.eps
@@ -168,7 +169,7 @@
         q = q[np.arange(len(q)), batch.act]
         r = batch.returns
         if isinstance(r, np.ndarray):
-            r = torch.tensor(r, device=q.device, dtype=q.dtype)
+            r = to_torch(r, device=q.device, dtype=q.dtype)
         loss = F.mse_loss(q, r)
         loss.backward()
         self.optim.step()

tianshou/policy/modelfree/pg.py

@@ -3,7 +3,7 @@ import numpy as np
 from typing import Dict, List, Union, Optional
 
 from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, to_torch
 
 
 class PGPolicy(BasePolicy):
@@ -88,8 +88,8 @@ class PGPolicy(BasePolicy):
             for b in batch.split(batch_size):
                 self.optim.zero_grad()
                 dist = self(b).dist
-                a = torch.tensor(b.act, device=dist.logits.device)
-                r = torch.tensor(b.returns, device=dist.logits.device)
+                a = to_torch(b.act, device=dist.logits.device)
+                r = to_torch(b.returns, device=dist.logits.device)
                 loss = -(dist.log_prob(a) * r).sum()
                 loss.backward()
                 self.optim.step()

tianshou/policy/modelfree/ppo.py

@@ -4,7 +4,7 @@ from torch import nn
 from typing import Dict, List, Tuple, Union, Optional
 
 from tianshou.policy import PGPolicy
-from tianshou.data import Batch, ReplayBuffer
+from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch
 
 
 class PPOPolicy(PGPolicy):
@@ -88,7 +88,7 @@ class PPOPolicy(PGPolicy):
         with torch.no_grad():
             for b in batch.split(self._batch, shuffle=False):
                 v_.append(self.critic(b.obs_next))
-        v_ = torch.cat(v_, dim=0).cpu().numpy()
+        v_ = to_numpy(torch.cat(v_, dim=0))
         return self.compute_episodic_return(
             batch, v_, gamma=self._gamma, gae_lambda=self._lambda)
@@ -129,12 +129,12 @@
             for b in batch.split(batch_size, shuffle=False):
                 v.append(self.critic(b.obs))
                 old_log_prob.append(self(b).dist.log_prob(
-                    torch.tensor(b.act, device=v[0].device)))
+                    to_torch(b.act, device=v[0].device)))
         batch.v = torch.cat(v, dim=0)  # old value
         dev = batch.v.device
-        batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
+        batch.act = to_torch(batch.act, dtype=torch.float, device=dev)
         batch.logp_old = torch.cat(old_log_prob, dim=0)
-        batch.returns = torch.tensor(
+        batch.returns = to_torch(
             batch.returns, dtype=torch.float, device=dev
         ).reshape(batch.v.shape)
         if self._rew_norm:

tianshou/policy/modelfree/sac.py

@@ -4,7 +4,7 @@ from copy import deepcopy
 import torch.nn.functional as F
 from typing import Dict, Tuple, Union, Optional
 
-from tianshou.data import Batch
+from tianshou.data import Batch, to_torch
 from tianshou.policy import DDPGPolicy
 from tianshou.policy.dist import DiagGaussian
@@ -110,15 +110,15 @@ class SACPolicy(DDPGPolicy):
         obs_next_result = self(batch, input='obs_next')
         a_ = obs_next_result.act
         dev = a_.device
-        batch.act = torch.tensor(batch.act, dtype=torch.float, device=dev)
+        batch.act = to_torch(batch.act, dtype=torch.float, device=dev)
         target_q = torch.min(
             self.critic1_old(batch.obs_next, a_),
             self.critic2_old(batch.obs_next, a_),
         ) - self._alpha * obs_next_result.log_prob
-        rew = torch.tensor(batch.rew,
+        rew = to_torch(batch.rew,
                            dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done,
+        done = to_torch(batch.done,
                             dtype=torch.float, device=dev)[:, None]
         target_q = (rew + (1. - done) * self._gamma * target_q)
         # critic 1
         current_q1 = self.critic1(batch.obs, batch.act)

tianshou/policy/modelfree/td3.py

@@ -3,7 +3,7 @@ from copy import deepcopy
 import torch.nn.functional as F
 from typing import Dict, Tuple, Optional
 
-from tianshou.data import Batch
+from tianshou.data import Batch, to_torch
 from tianshou.policy import DDPGPolicy
@@ -112,10 +112,10 @@ class TD3Policy(DDPGPolicy):
         target_q = torch.min(
             self.critic1_old(batch.obs_next, a_),
             self.critic2_old(batch.obs_next, a_))
-        rew = torch.tensor(batch.rew,
+        rew = to_torch(batch.rew,
                            dtype=torch.float, device=dev)[:, None]
-        done = torch.tensor(batch.done,
+        done = to_torch(batch.done,
                             dtype=torch.float, device=dev)[:, None]
         target_q = (rew + (1. - done) * self._gamma * target_q)
         # critic 1
         current_q1 = self.critic1(batch.obs, batch.act)