Add pickle support for Batch. Fix VectorEnv. (#67)
* Fix vecenv. * Add pickle support for Batch class. * Add Batch pickle Unit Test. * Fix lint. * Swap Batch UT. * Fix lint. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent
dd3e2130bb
commit
529a4cf44c
@ -1,4 +1,5 @@
|
||||
import pytest
|
||||
import pickle
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
@ -30,6 +31,15 @@ def test_batch_over_batch():
|
||||
assert batch2[-1].b.b == 0
|
||||
|
||||
|
||||
def test_batch_pickle():
|
||||
batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])),
|
||||
np=np.zeros([3, 4]))
|
||||
batch_pk = pickle.loads(pickle.dumps(batch))
|
||||
assert batch.obs.a == batch_pk.obs.a
|
||||
assert torch.all(batch.obs.c == batch_pk.obs.c)
|
||||
assert np.all(batch.np == batch_pk.np)
|
||||
|
||||
|
||||
def test_batch_from_to_numpy_without_copy():
|
||||
batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
|
||||
a_mem_addr_orig = batch["a"].__array_interface__['data'][0]
|
||||
|
@ -1,8 +1,15 @@
|
||||
import torch
|
||||
import warnings
|
||||
import pprint
|
||||
import numpy as np
|
||||
from typing import Any, List, Union, Iterator, Optional
|
||||
|
||||
# Disable pickle warning related to torch, since it has been removed
|
||||
# on torch master branch. See Pull Request #39003 for details:
|
||||
# https://github.com/pytorch/pytorch/pull/39003
|
||||
warnings.filterwarnings(
|
||||
"ignore", message="pickle support for Storage will be removed in 1.5.")
|
||||
|
||||
|
||||
class Batch:
|
||||
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data
|
||||
@ -66,9 +73,13 @@ class Batch:
|
||||
[11 22] [6 6]
|
||||
"""
|
||||
|
||||
def __new__(cls, **kwargs) -> None:
|
||||
self = super().__new__(cls)
|
||||
self._meta = {}
|
||||
return self
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self._meta = {}
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, (list, np.ndarray)) \
|
||||
and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
|
||||
@ -86,6 +97,25 @@ class Batch:
|
||||
else:
|
||||
self.__dict__[k] = kwargs[k]
|
||||
|
||||
def __getstate__(self):
|
||||
"""Pickling interface. Only the actual data are serialized
|
||||
for both efficiency and simplicity.
|
||||
"""
|
||||
state = {}
|
||||
for k in self.keys():
|
||||
v = self[k]
|
||||
if isinstance(v, Batch):
|
||||
v = v.__getstate__()
|
||||
state[k] = v
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
"""Unpickling interface. At this point, self is an empty Batch
|
||||
instance that has not been initialized, so it can safely be
|
||||
initialized by the pickle state.
|
||||
"""
|
||||
self.__init__(**state)
|
||||
|
||||
def __getitem__(self, index: Union[str, slice]) -> Union['Batch', dict]:
|
||||
"""Return self[index]."""
|
||||
if isinstance(index, str):
|
||||
|
18
tianshou/env/vecenv.py
vendored
18
tianshou/env/vecenv.py
vendored
@ -12,7 +12,7 @@ except ImportError:
|
||||
from tianshou.env.utils import CloudpickleWrapper
|
||||
|
||||
|
||||
class BaseVectorEnv(ABC, gym.Wrapper):
|
||||
class BaseVectorEnv(ABC, gym.Env):
|
||||
"""Base class for vectorized environments wrapper. Usage:
|
||||
::
|
||||
|
||||
@ -79,8 +79,14 @@ class BaseVectorEnv(ABC, gym.Wrapper):
|
||||
|
||||
@abstractmethod
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
|
||||
"""Set the seed for all environments. Accept ``None``, an int (which
|
||||
will extend ``i`` to ``[i, i + 1, i + 2, ...]``) or a list.
|
||||
"""Set the seed for all environments.
|
||||
|
||||
Accept ``None``, an int (which will extend ``i`` to
|
||||
``[i, i + 1, i + 2, ...]``) or a list.
|
||||
|
||||
:return: The list of seeds used in this env's random number generators.
|
||||
The first value in the list should be the "main" seed, or the value
|
||||
which a reproducer should pass to 'seed'.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -91,7 +97,11 @@ class BaseVectorEnv(ABC, gym.Wrapper):
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""Close all of the environments."""
|
||||
"""Close all of the environments.
|
||||
|
||||
Environments will automatically close() themselves when garbage
|
||||
collected or when the program exits.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user