Remove reset_buffer() from reset method (#501)
commit a2d76d1276 (parent 3592f45446)
@@ -96,7 +96,7 @@ class ReplayBuffer:
         """Load replay buffer from HDF5 file."""
         with h5py.File(path, "r") as f:
             buf = cls.__new__(cls)
-            buf.__setstate__(from_hdf5(f, device=device))
+            buf.__setstate__(from_hdf5(f, device=device))  # type: ignore
         return buf
 
     def reset(self, keep_statistics: bool = False) -> None:

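For context, the patched classmethod round-trips a buffer through HDF5. A minimal usage sketch (save_hdf5 is the documented counterpart of load_hdf5; the stored transition values are illustrative):

    import numpy as np
    from tianshou.data import Batch, ReplayBuffer

    buf = ReplayBuffer(size=10)
    buf.add(Batch(obs=np.zeros(4), act=0, rew=1.0, done=False,
                  obs_next=np.ones(4), info={}))
    buf.save_hdf5("buffer.h5")                      # write to disk
    restored = ReplayBuffer.load_hdf5("buffer.h5")  # classmethod patched above
    assert len(restored) == len(buf)
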
@@ -46,6 +46,10 @@ class Collector(object):
     Please make sure the given environment has a time limitation if using n_episode
     collect option.
 
+    .. note::
+
+        In past versions of Tianshou, the replay buffer that was passed to `__init__`
+        was automatically reset. This is not done in the current implementation.
     """
 
     def __init__(

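Since the constructor no longer wipes the buffer, code that relied on the old auto-reset must clear it explicitly. A hedged sketch (policy, env and buf stand for instances created elsewhere):

    from tianshou.data import Collector

    collector = Collector(policy, env, buf)  # buf keeps its old contents now
    buf.reset()                              # opt back into a clean buffer
    # equivalently: collector.reset_buffer()
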
@@ -68,7 +72,7 @@ class Collector(object):
         self.preprocess_fn = preprocess_fn
         self._action_space = env.action_space
         # avoid creating attribute outside __init__
-        self.reset()
+        self.reset(False)
 
     def _assign_buffer(self, buffer: Optional[ReplayBuffer]) -> None:
         """Check if the buffer matches the constraint."""

@@ -94,14 +98,19 @@ class Collector(object):
         )
         self.buffer = buffer
 
-    def reset(self) -> None:
-        """Reset all related variables in the collector."""
+    def reset(self, reset_buffer: bool = True) -> None:
+        """Reset the environment, statistics, current data and possibly replay memory.
+
+        :param bool reset_buffer: if true, reset the replay buffer that is attached
+            to the collector.
+        """
         # use empty Batch for "state" so that self.data supports slicing
         # convert empty Batch to None when passing data to policy
         self.data = Batch(
             obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={}
         )
         self.reset_env()
-        self.reset_buffer()
+        if reset_buffer:
+            self.reset_buffer()
         self.reset_stat()
 

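The new flag makes the trade-off explicit at each call site. A usage sketch (assuming a collector already set up with a policy, vectorized envs, and a buffer named buf):

    collector.collect(n_step=100)        # gather transitions into buf
    collector.reset(reset_buffer=False)  # restart envs and stats, keep data
    assert len(buf) > 0                  # collected transitions survive
    collector.reset()                    # the default still clears the buffer
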
@@ -48,7 +48,12 @@ class Actor(nn.Module):
         self.preprocess = preprocess_net
         self.output_dim = int(np.prod(action_shape))
         input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
+        self.last = MLP(
+            input_dim,  # type: ignore
+            self.output_dim,
+            hidden_sizes,
+            device=self.device
+        )
         self._max = max_action
 
     def forward(

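The recurring `# type: ignore` in this and the following hunks exists because the getattr fallback is typed as Optional: a static checker cannot prove input_dim is a concrete int when it reaches MLP. A toy sketch of the pattern (the names here are illustrative, not Tianshou's):

    from typing import Optional

    class Preprocess:
        output_dim: int = 128

    def resolve_input_dim(net: object, fallback: Optional[int]) -> int:
        # getattr(...) is typed Any and the fallback may be None, so a
        # static checker flags later uses that require a concrete int
        dim = getattr(net, "output_dim", fallback)
        assert dim is not None  # explicit narrowing instead of ignoring
        return int(dim)

    print(resolve_input_dim(Preprocess(), None))  # -> 128
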
@@ -96,7 +101,12 @@ class Critic(nn.Module):
         self.preprocess = preprocess_net
         self.output_dim = 1
         input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, 1, hidden_sizes, device=self.device)
+        self.last = MLP(
+            input_dim,  # type: ignore
+            1,
+            hidden_sizes,
+            device=self.device
+        )
 
     def forward(
         self,

@@ -165,11 +175,19 @@ class ActorProb(nn.Module):
         self.device = device
         self.output_dim = int(np.prod(action_shape))
         input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
+        self.mu = MLP(
+            input_dim,  # type: ignore
+            self.output_dim,
+            hidden_sizes,
+            device=self.device
+        )
         self._c_sigma = conditioned_sigma
         if conditioned_sigma:
             self.sigma = MLP(
-                input_dim, self.output_dim, hidden_sizes, device=self.device
+                input_dim,  # type: ignore
+                self.output_dim,
+                hidden_sizes,
+                device=self.device
             )
         else:
             self.sigma_param = nn.Parameter(torch.zeros(self.output_dim, 1))

@@ -49,7 +49,12 @@ class Actor(nn.Module):
         self.preprocess = preprocess_net
         self.output_dim = int(np.prod(action_shape))
         input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
+        self.last = MLP(
+            input_dim,  # type: ignore
+            self.output_dim,
+            hidden_sizes,
+            device=self.device
+        )
         self.softmax_output = softmax_output
 
     def forward(

@@ -101,7 +106,12 @@ class Critic(nn.Module):
         self.preprocess = preprocess_net
         self.output_dim = last_size
         input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
-        self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)
+        self.last = MLP(
+            input_dim,  # type: ignore
+            last_size,
+            hidden_sizes,
+            device=self.device
+        )
 
     def forward(
         self, s: Union[np.ndarray, torch.Tensor], **kwargs: Any

@@ -183,8 +193,10 @@ class ImplicitQuantileNetwork(Critic):
         self.input_dim = getattr(
             preprocess_net, "output_dim", preprocess_net_output_dim
         )
-        self.embed_model = CosineEmbeddingNetwork(num_cosines,
-                                                  self.input_dim).to(device)
+        self.embed_model = CosineEmbeddingNetwork(
+            num_cosines,
+            self.input_dim  # type: ignore
+        ).to(device)
 
     def forward(  # type: ignore
         self, s: Union[np.ndarray, torch.Tensor], sample_size: int, **kwargs: Any