Add pickle support for Batch. Fix VectorEnv. (#67)

* Fix vecenv.

* Add pickle support for Batch class.

* Add Batch pickle Unit Test.

* Fix lint.

* Swap Batch UT.

* Fix lint.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
Alexis DUBURCQ 2020-05-30 15:29:33 +02:00 committed by GitHub
parent dd3e2130bb
commit 529a4cf44c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 5 deletions

View File

@ -1,4 +1,5 @@
import pytest
import pickle
import torch
import numpy as np
@ -30,6 +31,15 @@ def test_batch_over_batch():
assert batch2[-1].b.b == 0
def test_batch_pickle():
batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])),
np=np.zeros([3, 4]))
batch_pk = pickle.loads(pickle.dumps(batch))
assert batch.obs.a == batch_pk.obs.a
assert torch.all(batch.obs.c == batch_pk.obs.c)
assert np.all(batch.np == batch_pk.np)
def test_batch_from_to_numpy_without_copy():
batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
a_mem_addr_orig = batch["a"].__array_interface__['data'][0]

View File

@ -1,8 +1,15 @@
import torch
import warnings
import pprint
import numpy as np
from typing import Any, List, Union, Iterator, Optional
# Disable pickle warning related to torch, since it has been removed
# on torch master branch. See Pull Request #39003 for details:
# https://github.com/pytorch/pytorch/pull/39003
warnings.filterwarnings(
"ignore", message="pickle support for Storage will be removed in 1.5.")
class Batch:
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data
@ -66,9 +73,13 @@ class Batch:
[11 22] [6 6]
"""
def __new__(cls, **kwargs) -> None:
self = super().__new__(cls)
self._meta = {}
return self
def __init__(self, **kwargs) -> None:
super().__init__()
self._meta = {}
for k, v in kwargs.items():
if isinstance(v, (list, np.ndarray)) \
and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
@ -86,6 +97,25 @@ class Batch:
else:
self.__dict__[k] = kwargs[k]
def __getstate__(self):
"""Pickling interface. Only the actual data are serialized
for both efficiency and simplicity.
"""
state = {}
for k in self.keys():
v = self[k]
if isinstance(v, Batch):
v = v.__getstate__()
state[k] = v
return state
def __setstate__(self, state):
"""Unpickling interface. At this point, self is an empty Batch
instance that has not been initialized, so it can safely be
initialized by the pickle state.
"""
self.__init__(**state)
def __getitem__(self, index: Union[str, slice]) -> Union['Batch', dict]:
"""Return self[index]."""
if isinstance(index, str):

View File

@ -12,7 +12,7 @@ except ImportError:
from tianshou.env.utils import CloudpickleWrapper
class BaseVectorEnv(ABC, gym.Wrapper):
class BaseVectorEnv(ABC, gym.Env):
"""Base class for vectorized environments wrapper. Usage:
::
@ -79,8 +79,14 @@ class BaseVectorEnv(ABC, gym.Wrapper):
@abstractmethod
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
"""Set the seed for all environments. Accept ``None``, an int (which
will extend ``i`` to ``[i, i + 1, i + 2, ...]``) or a list.
"""Set the seed for all environments.
Accept ``None``, an int (which will extend ``i`` to
``[i, i + 1, i + 2, ...]``) or a list.
:return: The list of seeds used in this env's random number generators.
The first value in the list should be the "main" seed, or the value
which a reproducer should pass to 'seed'.
"""
pass
@ -91,7 +97,11 @@ class BaseVectorEnv(ABC, gym.Wrapper):
@abstractmethod
def close(self) -> None:
"""Close all of the environments."""
"""Close all of the environments.
Environments will automatically close() themselves when garbage
collected or when the program exits.
"""
pass