Remove reset_buffer() from reset method (#501)

Authored by Markus28 on 2022-01-13 01:46:28 +01:00; committed by GitHub
parent 3592f45446
commit a2d76d1276
4 changed files with 52 additions and 13 deletions

tianshou/data/buffer/base.py

@@ -96,7 +96,7 @@ class ReplayBuffer:
         """Load replay buffer from HDF5 file."""
         with h5py.File(path, "r") as f:
             buf = cls.__new__(cls)
-            buf.__setstate__(from_hdf5(f, device=device))
+            buf.__setstate__(from_hdf5(f, device=device))  # type: ignore
         return buf
 
     def reset(self, keep_statistics: bool = False) -> None:
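For orientation, the classmethod touched here restores a buffer saved with save_hdf5. A minimal round-trip sketch against the tianshou 0.4.x API (the file name is illustrative, not from this commit):

import numpy as np
from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(size=16)
buf.add(Batch(obs=np.zeros(4), act=0, rew=0.0, done=False,
              obs_next=np.zeros(4), info={}))
buf.save_hdf5("demo_buffer.hdf5")  # counterpart of the patched load_hdf5
restored = ReplayBuffer.load_hdf5("demo_buffer.hdf5")
assert len(restored) == len(buf)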

tianshou/data/collector.py

@@ -46,6 +46,10 @@ class Collector(object):
     Please make sure the given environment has a time limitation if using n_episode
     collect option.
 
+    .. note::
+
+        In past versions of Tianshou, the replay buffer that was passed to `__init__`
+        was automatically reset. This is not done in the current implementation.
     """
 
     def __init__(
@@ -68,7 +72,7 @@ class Collector(object):
         self.preprocess_fn = preprocess_fn
         self._action_space = env.action_space
         # avoid creating attribute outside __init__
-        self.reset()
+        self.reset(False)
 
     def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None:
         """Check if the buffer matches the constraint."""
@@ -94,14 +98,19 @@ class Collector(object):
             )
         self.buffer = buffer
 
-    def reset(self) -> None:
-        """Reset all related variables in the collector."""
+    def reset(self, reset_buffer: bool = True) -> None:
+        """Reset the environment, statistics, current data and possibly replay memory.
+
+        :param bool reset_buffer: if true, reset the replay buffer that is attached
+            to the collector.
+        """
         # use empty Batch for "state" so that self.data supports slicing
         # convert empty Batch to None when passing data to policy
         self.data = Batch(
             obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={}
         )
         self.reset_env()
-        self.reset_buffer()
+        if reset_buffer:
+            self.reset_buffer()
         self.reset_stat()
 
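The user-visible effect of these three hunks: the constructor now calls reset(False), so a pre-filled buffer handed to the Collector survives construction, and reset() only clears the buffer when asked (the default stays True). A runnable sketch against the tianshou 0.4.x API; RandomPolicy is a stand-in defined here for illustration, not a tianshou class:

import gym
import numpy as np
from tianshou.data import Batch, Collector, ReplayBuffer
from tianshou.policy import BasePolicy

class RandomPolicy(BasePolicy):
    """Illustrative policy that samples one random action per environment."""
    def forward(self, batch, state=None, **kwargs):
        return Batch(act=np.array([self.action_space.sample() for _ in batch.obs]))
    def learn(self, batch, **kwargs):
        return {}

env = gym.make("CartPole-v0")
buf = ReplayBuffer(size=100)
policy = RandomPolicy(action_space=env.action_space)
collector = Collector(policy, env, buf)  # __init__ now calls reset(False): buf kept
collector.collect(n_step=10)
collector.reset(reset_buffer=False)      # env/stats/data reset, transitions kept
assert len(buf) > 0
collector.reset()                        # default still clears the buffer
assert len(buf) == 0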

tianshou/utils/net/continuous.py

@@ -48,7 +48,12 @@ class Actor(nn.Module):
         self.preprocess = preprocess_net
         self.output_dim = int(np.prod(action_shape))
         input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
+        self.last = MLP(
+            input_dim,  # type: ignore
+            self.output_dim,
+            hidden_sizes,
+            device=self.device
+        )
         self._max = max_action
 
     def forward(
@@ -96,7 +101,12 @@ class Critic(nn.Module):
         self.preprocess = preprocess_net
         self.output_dim = 1
         input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, 1, hidden_sizes, device=self.device)
+        self.last = MLP(
+            input_dim,  # type: ignore
+            1,
+            hidden_sizes,
+            device=self.device
+        )
 
     def forward(
         self,
@@ -165,11 +175,19 @@ class ActorProb(nn.Module):
         self.device = device
         self.output_dim = int(np.prod(action_shape))
         input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
+        self.mu = MLP(
+            input_dim,  # type: ignore
+            self.output_dim,
+            hidden_sizes,
+            device=self.device
+        )
         self._c_sigma = conditioned_sigma
         if conditioned_sigma:
             self.sigma = MLP(
-                input_dim, self.output_dim, hidden_sizes, device=self.device
+                input_dim,  # type: ignore
+                self.output_dim,
+                hidden_sizes,
+                device=self.device
             )
         else:
             self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
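A note on the recurring `# type: ignore` in these hunks (the same pattern repeats in the discrete nets below): getattr(preprocess_net, "output_dim", preprocess_net_output_dim) is typed from its Optional[int] default, so mypy flags passing the result where MLP expects an int, even though at runtime one of the two sources always supplies the dimension. A minimal sketch of the pattern, outside tianshou:

from typing import Optional

class Head:
    def __init__(self, input_dim: int) -> None:  # stands in for tianshou's MLP
        self.input_dim = input_dim

def build(net: object, fallback: Optional[int]) -> Head:
    # mypy infers an Optional-typed result from the Optional default,
    # hence the ignore; at runtime an int is always present.
    dim = getattr(net, "output_dim", fallback)
    return Head(dim)  # type: ignore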

tianshou/utils/net/discrete.py

@@ -49,7 +49,12 @@ class Actor(nn.Module):
         self.preprocess = preprocess_net
         self.output_dim = int(np.prod(action_shape))
         input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
+        self.last = MLP(
+            input_dim,  # type: ignore
+            self.output_dim,
+            hidden_sizes,
+            device=self.device
+        )
         self.softmax_output = softmax_output
 
     def forward(
@@ -101,7 +106,12 @@ class Critic(nn.Module):
         self.preprocess = preprocess_net
         self.output_dim = last_size
         input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)
+        self.last = MLP(
+            input_dim,  # type: ignore
+            last_size,
+            hidden_sizes,
+            device=self.device
+        )
 
     def forward(
         self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
@@ -183,8 +193,10 @@ class ImplicitQuantileNetwork(Critic):
         self.input_dim = getattr(
             preprocess_net, "output_dim", preprocess_net_output_dim
         )
-        self.embed_model = CosineEmbeddingNetwork(num_cosines,
-                                                  self.input_dim).to(device)
+        self.embed_model = CosineEmbeddingNetwork(
+            num_cosines,
+            self.input_dim  # type: ignore
+        ).to(device)
 
     def forward(  # type: ignore
         self, s: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any