Remove reset_buffer() from reset method (#501)

This commit is contained in:
Markus28 2022-01-13 01:46:28 +01:00 committed by GitHub
parent 3592f45446
commit a2d76d1276
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 52 additions and 13 deletions

View File

@ -96,7 +96,7 @@ class ReplayBuffer:
"""Load replay buffer from HDF5 file."""
with h5py.File(path, "r") as f:
buf = cls.__new__(cls)
buf.__setstate__(from_hdf5(f, device=device))
buf.__setstate__(from_hdf5(f, device=device)) # type: ignore
return buf
def reset(self, keep_statistics: bool = False) -> None:

View File

@ -46,6 +46,10 @@ class Collector(object):
Please make sure the given environment has a time limitation if using n_episode
collect option.
.. note::
In past versions of Tianshou, the replay buffer that was passed to `__init__`
was automatically reset. This is not done in the current implementation.
"""
def __init__(
@ -68,7 +72,7 @@ class Collector(object):
self.preprocess_fn = preprocess_fn
self._action_space = env.action_space
# avoid creating attribute outside __init__
self.reset()
self.reset(False)
def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None:
"""Check if the buffer matches the constraint."""
@ -94,15 +98,20 @@ class Collector(object):
)
self.buffer = buffer
def reset(self) -> None:
"""Reset all related variables in the collector."""
def reset(self, reset_buffer: bool = True) -> None:
"""Reset the environment, statistics, current data and possibly replay memory.
:param bool reset_buffer: if true, reset the replay buffer that is attached
to the collector.
"""
# use empty Batch for "state" so that self.data supports slicing
# convert empty Batch to None when passing data to policy
self.data = Batch(
obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={}
)
self.reset_env()
self.reset_buffer()
if reset_buffer:
self.reset_buffer()
self.reset_stat()
def reset_stat(self) -> None:

View File

@ -48,7 +48,12 @@ class Actor(nn.Module):
self.preprocess = preprocess_net
self.output_dim = int(np.prod(action_shape))
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
self.last = MLP(
input_dim, # type: ignore
self.output_dim,
hidden_sizes,
device=self.device
)
self._max = max_action
def forward(
@ -96,7 +101,12 @@ class Critic(nn.Module):
self.preprocess = preprocess_net
self.output_dim = 1
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
self.last = MLP(input_dim, 1, hidden_sizes, device=self.device)
self.last = MLP(
input_dim, # type: ignore
1,
hidden_sizes,
device=self.device
)
def forward(
self,
@ -165,11 +175,19 @@ class ActorProb(nn.Module):
self.device = device
self.output_dim = int(np.prod(action_shape))
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
self.mu = MLP(
input_dim, # type: ignore
self.output_dim,
hidden_sizes,
device=self.device
)
self._c_sigma = conditioned_sigma
if conditioned_sigma:
self.sigma = MLP(
input_dim, self.output_dim, hidden_sizes, device=self.device
input_dim, # type: ignore
self.output_dim,
hidden_sizes,
device=self.device
)
else:
self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))

View File

@ -49,7 +49,12 @@ class Actor(nn.Module):
self.preprocess = preprocess_net
self.output_dim = int(np.prod(action_shape))
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
self.last = MLP(
input_dim, # type: ignore
self.output_dim,
hidden_sizes,
device=self.device
)
self.softmax_output = softmax_output
def forward(
@ -101,7 +106,12 @@ class Critic(nn.Module):
self.preprocess = preprocess_net
self.output_dim = last_size
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)
self.last = MLP(
input_dim, # type: ignore
last_size,
hidden_sizes,
device=self.device
)
def forward(
self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
@ -183,8 +193,10 @@ class ImplicitQuantileNetwork(Critic):
self.input_dim = getattr(
preprocess_net, "output_dim", preprocess_net_output_dim
)
self.embed_model = CosineEmbeddingNetwork(num_cosines,
self.input_dim).to(device)
self.embed_model = CosineEmbeddingNetwork(
num_cosines,
self.input_dim # type: ignore
).to(device)
def forward( # type: ignore
self, s: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any