Add pickle support for Batch. Fix VectorEnv. (#67)

* Fix vecenv.

* Add pickle support for Batch class.

* Add Batch pickle Unit Test.

* Fix lint.

* Swap Batch UT.

* Fix lint.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
Alexis DUBURCQ 2020-05-30 15:29:33 +02:00 committed by GitHub
parent dd3e2130bb
commit 529a4cf44c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 5 deletions

View File

@ -1,4 +1,5 @@
import pytest import pytest
import pickle
import torch import torch
import numpy as np import numpy as np
@ -30,6 +31,15 @@ def test_batch_over_batch():
assert batch2[-1].b.b == 0 assert batch2[-1].b.b == 0
def test_batch_pickle():
batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])),
np=np.zeros([3, 4]))
batch_pk = pickle.loads(pickle.dumps(batch))
assert batch.obs.a == batch_pk.obs.a
assert torch.all(batch.obs.c == batch_pk.obs.c)
assert np.all(batch.np == batch_pk.np)
def test_batch_from_to_numpy_without_copy(): def test_batch_from_to_numpy_without_copy():
batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,)))) batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
a_mem_addr_orig = batch["a"].__array_interface__['data'][0] a_mem_addr_orig = batch["a"].__array_interface__['data'][0]

View File

@ -1,8 +1,15 @@
import torch import torch
import warnings
import pprint import pprint
import numpy as np import numpy as np
from typing import Any, List, Union, Iterator, Optional from typing import Any, List, Union, Iterator, Optional
# Disable pickle warning related to torch, since it has been removed
# on torch master branch. See Pull Request #39003 for details:
# https://github.com/pytorch/pytorch/pull/39003
warnings.filterwarnings(
"ignore", message="pickle support for Storage will be removed in 1.5.")
class Batch: class Batch:
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data """Tianshou provides :class:`~tianshou.data.Batch` as the internal data
@ -66,9 +73,13 @@ class Batch:
[11 22] [6 6] [11 22] [6 6]
""" """
def __new__(cls, **kwargs) -> None:
self = super().__new__(cls)
self._meta = {}
return self
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
super().__init__() super().__init__()
self._meta = {}
for k, v in kwargs.items(): for k, v in kwargs.items():
if isinstance(v, (list, np.ndarray)) \ if isinstance(v, (list, np.ndarray)) \
and len(v) > 0 and isinstance(v[0], dict) and k != 'info': and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
@ -86,6 +97,25 @@ class Batch:
else: else:
self.__dict__[k] = kwargs[k] self.__dict__[k] = kwargs[k]
def __getstate__(self):
"""Pickling interface. Only the actual data are serialized
for both efficiency and simplicity.
"""
state = {}
for k in self.keys():
v = self[k]
if isinstance(v, Batch):
v = v.__getstate__()
state[k] = v
return state
def __setstate__(self, state):
"""Unpickling interface. At this point, self is an empty Batch
instance that has not been initialized, so it can safely be
initialized by the pickle state.
"""
self.__init__(**state)
def __getitem__(self, index: Union[str, slice]) -> Union['Batch', dict]: def __getitem__(self, index: Union[str, slice]) -> Union['Batch', dict]:
"""Return self[index].""" """Return self[index]."""
if isinstance(index, str): if isinstance(index, str):

View File

@ -12,7 +12,7 @@ except ImportError:
from tianshou.env.utils import CloudpickleWrapper from tianshou.env.utils import CloudpickleWrapper
class BaseVectorEnv(ABC, gym.Wrapper): class BaseVectorEnv(ABC, gym.Env):
"""Base class for vectorized environments wrapper. Usage: """Base class for vectorized environments wrapper. Usage:
:: ::
@ -79,8 +79,14 @@ class BaseVectorEnv(ABC, gym.Wrapper):
@abstractmethod @abstractmethod
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None: def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
"""Set the seed for all environments. Accept ``None``, an int (which """Set the seed for all environments.
will extend ``i`` to ``[i, i + 1, i + 2, ...]``) or a list.
Accept ``None``, an int (which will extend ``i`` to
``[i, i + 1, i + 2, ...]``) or a list.
:return: The list of seeds used in this env's random number generators.
The first value in the list should be the "main" seed, or the value
which a reproducer should pass to 'seed'.
""" """
pass pass
@ -91,7 +97,11 @@ class BaseVectorEnv(ABC, gym.Wrapper):
@abstractmethod @abstractmethod
def close(self) -> None: def close(self) -> None:
"""Close all of the environments.""" """Close all of the environments.
Environments will automatically close() themselves when garbage
collected or when the program exits.
"""
pass pass