Remove reset_buffer() from reset method (#501)
This commit is contained in:
parent
3592f45446
commit
a2d76d1276
@ -96,7 +96,7 @@ class ReplayBuffer:
|
||||
"""Load replay buffer from HDF5 file."""
|
||||
with h5py.File(path, "r") as f:
|
||||
buf = cls.__new__(cls)
|
||||
buf.__setstate__(from_hdf5(f, device=device))
|
||||
buf.__setstate__(from_hdf5(f, device=device)) # type: ignore
|
||||
return buf
|
||||
|
||||
def reset(self, keep_statistics: bool = False) -> None:
|
||||
|
@ -46,6 +46,10 @@ class Collector(object):
|
||||
|
||||
Please make sure the given environment has a time limitation if using n_episode
|
||||
collect option.
|
||||
|
||||
.. note::
|
||||
In past versions of Tianshou, the replay buffer that was passed to `__init__`
|
||||
was automatically reset. This is not done in the current implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -68,7 +72,7 @@ class Collector(object):
|
||||
self.preprocess_fn = preprocess_fn
|
||||
self._action_space = env.action_space
|
||||
# avoid creating attribute outside __init__
|
||||
self.reset()
|
||||
self.reset(False)
|
||||
|
||||
def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None:
|
||||
"""Check if the buffer matches the constraint."""
|
||||
@ -94,15 +98,20 @@ class Collector(object):
|
||||
)
|
||||
self.buffer = buffer
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all related variables in the collector."""
|
||||
def reset(self, reset_buffer: bool = True) -> None:
|
||||
"""Reset the environment, statistics, current data and possibly replay memory.
|
||||
|
||||
:param bool reset_buffer: if true, reset the replay buffer that is attached
|
||||
to the collector.
|
||||
"""
|
||||
# use empty Batch for "state" so that self.data supports slicing
|
||||
# convert empty Batch to None when passing data to policy
|
||||
self.data = Batch(
|
||||
obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={}
|
||||
)
|
||||
self.reset_env()
|
||||
self.reset_buffer()
|
||||
if reset_buffer:
|
||||
self.reset_buffer()
|
||||
self.reset_stat()
|
||||
|
||||
def reset_stat(self) -> None:
|
||||
|
@ -48,7 +48,12 @@ class Actor(nn.Module):
|
||||
self.preprocess = preprocess_net
|
||||
self.output_dim = int(np.prod(action_shape))
|
||||
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
|
||||
self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
|
||||
self.last = MLP(
|
||||
input_dim, # type: ignore
|
||||
self.output_dim,
|
||||
hidden_sizes,
|
||||
device=self.device
|
||||
)
|
||||
self._max = max_action
|
||||
|
||||
def forward(
|
||||
@ -96,7 +101,12 @@ class Critic(nn.Module):
|
||||
self.preprocess = preprocess_net
|
||||
self.output_dim = 1
|
||||
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
|
||||
self.last = MLP(input_dim, 1, hidden_sizes, device=self.device)
|
||||
self.last = MLP(
|
||||
input_dim, # type: ignore
|
||||
1,
|
||||
hidden_sizes,
|
||||
device=self.device
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -165,11 +175,19 @@ class ActorProb(nn.Module):
|
||||
self.device = device
|
||||
self.output_dim = int(np.prod(action_shape))
|
||||
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
|
||||
self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
|
||||
self.mu = MLP(
|
||||
input_dim, # type: ignore
|
||||
self.output_dim,
|
||||
hidden_sizes,
|
||||
device=self.device
|
||||
)
|
||||
self._c_sigma = conditioned_sigma
|
||||
if conditioned_sigma:
|
||||
self.sigma = MLP(
|
||||
input_dim, self.output_dim, hidden_sizes, device=self.device
|
||||
input_dim, # type: ignore
|
||||
self.output_dim,
|
||||
hidden_sizes,
|
||||
device=self.device
|
||||
)
|
||||
else:
|
||||
self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))
|
||||
|
@ -49,7 +49,12 @@ class Actor(nn.Module):
|
||||
self.preprocess = preprocess_net
|
||||
self.output_dim = int(np.prod(action_shape))
|
||||
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
|
||||
self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
|
||||
self.last = MLP(
|
||||
input_dim, # type: ignore
|
||||
self.output_dim,
|
||||
hidden_sizes,
|
||||
device=self.device
|
||||
)
|
||||
self.softmax_output = softmax_output
|
||||
|
||||
def forward(
|
||||
@ -101,7 +106,12 @@ class Critic(nn.Module):
|
||||
self.preprocess = preprocess_net
|
||||
self.output_dim = last_size
|
||||
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
|
||||
self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)
|
||||
self.last = MLP(
|
||||
input_dim, # type: ignore
|
||||
last_size,
|
||||
hidden_sizes,
|
||||
device=self.device
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any
|
||||
@ -183,8 +193,10 @@ class ImplicitQuantileNetwork(Critic):
|
||||
self.input_dim = getattr(
|
||||
preprocess_net, "output_dim", preprocess_net_output_dim
|
||||
)
|
||||
self.embed_model = CosineEmbeddingNetwork(num_cosines,
|
||||
self.input_dim).to(device)
|
||||
self.embed_model = CosineEmbeddingNetwork(
|
||||
num_cosines,
|
||||
self.input_dim # type: ignore
|
||||
).to(device)
|
||||
|
||||
def forward( # type: ignore
|
||||
self, s: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any
|
||||
|
Loading…
x
Reference in New Issue
Block a user