Add pickle support for Batch. Fix VectorEnv. (#67)
* Fix vecenv. * Add pickle support for Batch class. * Add Batch pickle Unit Test. * Fix lint. * Swap Batch UT. * Fix lint. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
parent
dd3e2130bb
commit
529a4cf44c
@ -1,4 +1,5 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
import pickle
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -30,6 +31,15 @@ def test_batch_over_batch():
|
|||||||
assert batch2[-1].b.b == 0
|
assert batch2[-1].b.b == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_batch_pickle():
|
||||||
|
batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])),
|
||||||
|
np=np.zeros([3, 4]))
|
||||||
|
batch_pk = pickle.loads(pickle.dumps(batch))
|
||||||
|
assert batch.obs.a == batch_pk.obs.a
|
||||||
|
assert torch.all(batch.obs.c == batch_pk.obs.c)
|
||||||
|
assert np.all(batch.np == batch_pk.np)
|
||||||
|
|
||||||
|
|
||||||
def test_batch_from_to_numpy_without_copy():
|
def test_batch_from_to_numpy_without_copy():
|
||||||
batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
|
batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
|
||||||
a_mem_addr_orig = batch["a"].__array_interface__['data'][0]
|
a_mem_addr_orig = batch["a"].__array_interface__['data'][0]
|
||||||
|
@ -1,8 +1,15 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import warnings
|
||||||
import pprint
|
import pprint
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Any, List, Union, Iterator, Optional
|
from typing import Any, List, Union, Iterator, Optional
|
||||||
|
|
||||||
|
# Disable pickle warning related to torch, since it has been removed
|
||||||
|
# on torch master branch. See Pull Request #39003 for details:
|
||||||
|
# https://github.com/pytorch/pytorch/pull/39003
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore", message="pickle support for Storage will be removed in 1.5.")
|
||||||
|
|
||||||
|
|
||||||
class Batch:
|
class Batch:
|
||||||
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data
|
"""Tianshou provides :class:`~tianshou.data.Batch` as the internal data
|
||||||
@ -66,9 +73,13 @@ class Batch:
|
|||||||
[11 22] [6 6]
|
[11 22] [6 6]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __new__(cls, **kwargs) -> None:
|
||||||
|
self = super().__new__(cls)
|
||||||
|
self._meta = {}
|
||||||
|
return self
|
||||||
|
|
||||||
def __init__(self, **kwargs) -> None:
|
def __init__(self, **kwargs) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._meta = {}
|
|
||||||
for k, v in kwargs.items():
|
for k, v in kwargs.items():
|
||||||
if isinstance(v, (list, np.ndarray)) \
|
if isinstance(v, (list, np.ndarray)) \
|
||||||
and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
|
and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
|
||||||
@ -86,6 +97,25 @@ class Batch:
|
|||||||
else:
|
else:
|
||||||
self.__dict__[k] = kwargs[k]
|
self.__dict__[k] = kwargs[k]
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
"""Pickling interface. Only the actual data are serialized
|
||||||
|
for both efficiency and simplicity.
|
||||||
|
"""
|
||||||
|
state = {}
|
||||||
|
for k in self.keys():
|
||||||
|
v = self[k]
|
||||||
|
if isinstance(v, Batch):
|
||||||
|
v = v.__getstate__()
|
||||||
|
state[k] = v
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
"""Unpickling interface. At this point, self is an empty Batch
|
||||||
|
instance that has not been initialized, so it can safely be
|
||||||
|
initialized by the pickle state.
|
||||||
|
"""
|
||||||
|
self.__init__(**state)
|
||||||
|
|
||||||
def __getitem__(self, index: Union[str, slice]) -> Union['Batch', dict]:
|
def __getitem__(self, index: Union[str, slice]) -> Union['Batch', dict]:
|
||||||
"""Return self[index]."""
|
"""Return self[index]."""
|
||||||
if isinstance(index, str):
|
if isinstance(index, str):
|
||||||
|
18
tianshou/env/vecenv.py
vendored
18
tianshou/env/vecenv.py
vendored
@ -12,7 +12,7 @@ except ImportError:
|
|||||||
from tianshou.env.utils import CloudpickleWrapper
|
from tianshou.env.utils import CloudpickleWrapper
|
||||||
|
|
||||||
|
|
||||||
class BaseVectorEnv(ABC, gym.Wrapper):
|
class BaseVectorEnv(ABC, gym.Env):
|
||||||
"""Base class for vectorized environments wrapper. Usage:
|
"""Base class for vectorized environments wrapper. Usage:
|
||||||
::
|
::
|
||||||
|
|
||||||
@ -79,8 +79,14 @@ class BaseVectorEnv(ABC, gym.Wrapper):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
|
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
|
||||||
"""Set the seed for all environments. Accept ``None``, an int (which
|
"""Set the seed for all environments.
|
||||||
will extend ``i`` to ``[i, i + 1, i + 2, ...]``) or a list.
|
|
||||||
|
Accept ``None``, an int (which will extend ``i`` to
|
||||||
|
``[i, i + 1, i + 2, ...]``) or a list.
|
||||||
|
|
||||||
|
:return: The list of seeds used in this env's random number generators.
|
||||||
|
The first value in the list should be the "main" seed, or the value
|
||||||
|
which a reproducer should pass to 'seed'.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -91,7 +97,11 @@ class BaseVectorEnv(ABC, gym.Wrapper):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
"""Close all of the environments."""
|
"""Close all of the environments.
|
||||||
|
|
||||||
|
Environments will automatically close() themselves when garbage
|
||||||
|
collected or when the program exits.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user