Add pickle support for Batch. Fix VectorEnv. (#67)
* Fix vecenv. * Add pickle support for Batch class. * Add Batch pickle Unit Test. * Fix lint. * Swap Batch UT. * Fix lint. Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
This commit is contained in:
		
							parent
							
								
									dd3e2130bb
								
							
						
					
					
						commit
						529a4cf44c
					
				@ -1,4 +1,5 @@
 | 
				
			|||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
 | 
					import pickle
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -30,6 +31,15 @@ def test_batch_over_batch():
 | 
				
			|||||||
    assert batch2[-1].b.b == 0
 | 
					    assert batch2[-1].b.b == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_batch_pickle():
 | 
				
			||||||
 | 
					    batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])),
 | 
				
			||||||
 | 
					                  np=np.zeros([3, 4]))
 | 
				
			||||||
 | 
					    batch_pk = pickle.loads(pickle.dumps(batch))
 | 
				
			||||||
 | 
					    assert batch.obs.a == batch_pk.obs.a
 | 
				
			||||||
 | 
					    assert torch.all(batch.obs.c == batch_pk.obs.c)
 | 
				
			||||||
 | 
					    assert np.all(batch.np == batch_pk.np)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_batch_from_to_numpy_without_copy():
 | 
					def test_batch_from_to_numpy_without_copy():
 | 
				
			||||||
    batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
 | 
					    batch = Batch(a=np.ones((1,)), b=Batch(c=np.ones((1,))))
 | 
				
			||||||
    a_mem_addr_orig = batch["a"].__array_interface__['data'][0]
 | 
					    a_mem_addr_orig = batch["a"].__array_interface__['data'][0]
 | 
				
			||||||
 | 
				
			|||||||
@ -1,8 +1,15 @@
 | 
				
			|||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					import warnings
 | 
				
			||||||
import pprint
 | 
					import pprint
 | 
				
			||||||
import numpy as np
 | 
					import numpy as np
 | 
				
			||||||
from typing import Any, List, Union, Iterator, Optional
 | 
					from typing import Any, List, Union, Iterator, Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Disable pickle warning related to torch, since it has been removed
 | 
				
			||||||
 | 
					# on torch master branch. See Pull Request #39003 for details:
 | 
				
			||||||
 | 
					# https://github.com/pytorch/pytorch/pull/39003
 | 
				
			||||||
 | 
					warnings.filterwarnings(
 | 
				
			||||||
 | 
					    "ignore", message="pickle support for Storage will be removed in 1.5.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Batch:
 | 
					class Batch:
 | 
				
			||||||
    """Tianshou provides :class:`~tianshou.data.Batch` as the internal data
 | 
					    """Tianshou provides :class:`~tianshou.data.Batch` as the internal data
 | 
				
			||||||
@ -66,9 +73,13 @@ class Batch:
 | 
				
			|||||||
        [11 22] [6 6]
 | 
					        [11 22] [6 6]
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __new__(cls, **kwargs) -> None:
 | 
				
			||||||
 | 
					        self = super().__new__(cls)
 | 
				
			||||||
 | 
					        self._meta = {}
 | 
				
			||||||
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, **kwargs) -> None:
 | 
					    def __init__(self, **kwargs) -> None:
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
        self._meta = {}
 | 
					 | 
				
			||||||
        for k, v in kwargs.items():
 | 
					        for k, v in kwargs.items():
 | 
				
			||||||
            if isinstance(v, (list, np.ndarray)) \
 | 
					            if isinstance(v, (list, np.ndarray)) \
 | 
				
			||||||
                    and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
 | 
					                    and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
 | 
				
			||||||
@ -86,6 +97,25 @@ class Batch:
 | 
				
			|||||||
            else:
 | 
					            else:
 | 
				
			||||||
                self.__dict__[k] = kwargs[k]
 | 
					                self.__dict__[k] = kwargs[k]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __getstate__(self):
 | 
				
			||||||
 | 
					        """Pickling interface. Only the actual data are serialized
 | 
				
			||||||
 | 
					        for both efficiency and simplicity.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        state = {}
 | 
				
			||||||
 | 
					        for k in self.keys():
 | 
				
			||||||
 | 
					            v = self[k]
 | 
				
			||||||
 | 
					            if isinstance(v, Batch):
 | 
				
			||||||
 | 
					                v = v.__getstate__()
 | 
				
			||||||
 | 
					            state[k] = v
 | 
				
			||||||
 | 
					        return state
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __setstate__(self, state):
 | 
				
			||||||
 | 
					        """Unpickling interface. At this point, self is an empty Batch
 | 
				
			||||||
 | 
					        instance that has not been initialized, so it can safely be
 | 
				
			||||||
 | 
					        initialized by the pickle state.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        self.__init__(**state)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __getitem__(self, index: Union[str, slice]) -> Union['Batch', dict]:
 | 
					    def __getitem__(self, index: Union[str, slice]) -> Union['Batch', dict]:
 | 
				
			||||||
        """Return self[index]."""
 | 
					        """Return self[index]."""
 | 
				
			||||||
        if isinstance(index, str):
 | 
					        if isinstance(index, str):
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										18
									
								
								tianshou/env/vecenv.py
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										18
									
								
								tianshou/env/vecenv.py
									
									
									
									
										vendored
									
									
								
							@ -12,7 +12,7 @@ except ImportError:
 | 
				
			|||||||
from tianshou.env.utils import CloudpickleWrapper
 | 
					from tianshou.env.utils import CloudpickleWrapper
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class BaseVectorEnv(ABC, gym.Wrapper):
 | 
					class BaseVectorEnv(ABC, gym.Env):
 | 
				
			||||||
    """Base class for vectorized environments wrapper. Usage:
 | 
					    """Base class for vectorized environments wrapper. Usage:
 | 
				
			||||||
    ::
 | 
					    ::
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -79,8 +79,14 @@ class BaseVectorEnv(ABC, gym.Wrapper):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @abstractmethod
 | 
					    @abstractmethod
 | 
				
			||||||
    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
 | 
					    def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
 | 
				
			||||||
        """Set the seed for all environments. Accept ``None``, an int (which
 | 
					        """Set the seed for all environments.
 | 
				
			||||||
        will extend ``i`` to ``[i, i + 1, i + 2, ...]``) or a list.
 | 
					
 | 
				
			||||||
 | 
					        Accept ``None``, an int (which will extend ``i`` to
 | 
				
			||||||
 | 
					        ``[i, i + 1, i + 2, ...]``) or a list.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        :return: The list of seeds used in this env's random number generators.
 | 
				
			||||||
 | 
					        The first value in the list should be the "main" seed, or the value
 | 
				
			||||||
 | 
					        which a reproducer should pass to 'seed'.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        pass
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -91,7 +97,11 @@ class BaseVectorEnv(ABC, gym.Wrapper):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @abstractmethod
 | 
					    @abstractmethod
 | 
				
			||||||
    def close(self) -> None:
 | 
					    def close(self) -> None:
 | 
				
			||||||
        """Close all of the environments."""
 | 
					        """Close all of the environments.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Environments will automatically close() themselves when garbage
 | 
				
			||||||
 | 
					        collected or when the program exits.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
        pass
 | 
					        pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user