diff --git a/test/base/test_batch.py b/test/base/test_batch.py
index d169085..eafd7b3 100644
--- a/test/base/test_batch.py
+++ b/test/base/test_batch.py
@@ -1,4 +1,5 @@
 import pytest
+import pickle
 import torch
 import numpy as np
 
@@ -30,6 +31,15 @@ def test_batch_over_batch():
     assert batch2[-1].b.b == 0
 
 
+def test_batch_pickle():
+    batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])),
+                  np=np.zeros([3, 4]))
+    batch_pk = pickle.loads(pickle.dumps(batch))
+    assert batch.obs.a == batch_pk.obs.a
+    assert torch.all(batch.obs.c == batch_pk.obs.c)
+    assert np.all(batch.np == batch_pk.np)
+
+
 def test_batch_from_to_numpy_without_copy():
     batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
     a_mem_addr_orig = batch["a"].__array_interface__['data'][0]
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 526cea0..e0129f7 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -1,8 +1,15 @@
 import torch
+import warnings
 import pprint
 import numpy as np
 from typing import Any, List, Union, Iterator, Optional
 
+# Disable pickle warning related to torch, since it has been removed
+# on torch master branch. See Pull Request #39003 for details:
+# https://github.com/pytorch/pytorch/pull/39003
+warnings.filterwarnings(
+    "ignore", message="pickle support for Storage will be removed in 1.5.")
+
 
 class Batch:
     """Tianshou provides :class:`~tianshou.data.Batch` as the internal data
@@ -66,9 +73,13 @@ class Batch:
         [11 22] [6 6]
     """
 
+    def __new__(cls, **kwargs) -> 'Batch':
+        self = super().__new__(cls)
+        self._meta = {}
+        return self
+
     def __init__(self, **kwargs) -> None:
         super().__init__()
-        self._meta = {}
         for k, v in kwargs.items():
             if isinstance(v, (list, np.ndarray)) \
                     and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
@@ -86,6 +97,25 @@ class Batch:
             else:
                 self.__dict__[k] = kwargs[k]
 
+    def __getstate__(self):
+        """Pickling interface. Only the actual data are serialized
+        for both efficiency and simplicity.
+        """
+        state = {}
+        for k in self.keys():
+            v = self[k]
+            if isinstance(v, Batch):
+                v = v.__getstate__()
+            state[k] = v
+        return state
+
+    def __setstate__(self, state):
+        """Unpickling interface. At this point, self is an empty Batch
+        instance that has not been initialized, so it can safely be
+        initialized by the pickle state.
+        """
+        self.__init__(**state)
+
     def __getitem__(self, index: Union[str, slice]) -> Union['Batch', dict]:
         """Return self[index]."""
         if isinstance(index, str):
diff --git a/tianshou/env/vecenv.py b/tianshou/env/vecenv.py
index 6145cc7..f2ea4b7 100644
--- a/tianshou/env/vecenv.py
+++ b/tianshou/env/vecenv.py
@@ -12,7 +12,7 @@ except ImportError:
 from tianshou.env.utils import CloudpickleWrapper
 
 
-class BaseVectorEnv(ABC, gym.Wrapper):
+class BaseVectorEnv(ABC, gym.Env):
     """Base class for vectorized environments wrapper. Usage:
     ::
 
@@ -79,8 +79,14 @@ class BaseVectorEnv(ABC, gym.Wrapper):
 
     @abstractmethod
     def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
-        """Set the seed for all environments. Accept ``None``, an int (which
-        will extend ``i`` to ``[i, i + 1, i + 2, ...]``) or a list.
+        """Set the seed for all environments.
+
+        Accept ``None``, an int (which will extend ``i`` to
+        ``[i, i + 1, i + 2, ...]``) or a list.
+
+        :return: The list of seeds used in this env's random number generators.
+        The first value in the list should be the "main" seed, or the value
+        which a reproducer should pass to 'seed'.
         """
         pass
 
@@ -91,7 +97,11 @@ class BaseVectorEnv(ABC, gym.Wrapper):
 
     @abstractmethod
     def close(self) -> None:
-        """Close all of the environments."""
+        """Close all of the environments.
+
+        Environments will automatically close() themselves when garbage
+        collected or when the program exits.
+        """
         pass
 
 