Support batch_size=None and use it in various scripts (#993)
Closes #986
This commit is contained in: parent f134bc20b5 · commit 8d3d1f164b
@@ -34,7 +34,7 @@ def get_args():
     parser.add_argument("--step-per-collect", type=int, default=80)
     parser.add_argument("--repeat-per-collect", type=int, default=1)
     # batch-size >> step-per-collect means calculating all data in one singe forward.
-    parser.add_argument("--batch-size", type=int, default=99999)
+    parser.add_argument("--batch-size", type=int, default=None)
    parser.add_argument("--training-num", type=int, default=16)
    parser.add_argument("--test-num", type=int, default=10)
    # a2c special
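The same default change is repeated across the example scripts below. As a quick standalone illustration (plain argparse, not taken from the repository), type=int with default=None simply yields None when the flag is omitted and still parses an integer when it is passed; the trainer then reads None as "use the whole buffer per gradient step":

import argparse

# Standalone sketch (not repo code): default=None makes the int-typed flag optional.
parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, default=None)

print(parser.parse_args([]).batch_size)                        # None
print(parser.parse_args(["--batch-size", "512"]).batch_size)   # 512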
@@ -30,7 +30,7 @@ def main(
     step_per_epoch: int = 30000,
     step_per_collect: int = 80,
     repeat_per_collect: int = 1,
-    batch_size: int = 99999,
+    batch_size: int | None = None,
     training_num: int = 16,
     test_num: int = 10,
     rew_norm: bool = True,
@@ -39,7 +39,7 @@ def get_args():
     parser.add_argument("--step-per-collect", type=int, default=1024)
     parser.add_argument("--repeat-per-collect", type=int, default=1)
     # batch-size >> step-per-collect means calculating all data in one singe forward.
-    parser.add_argument("--batch-size", type=int, default=99999)
+    parser.add_argument("--batch-size", type=int, default=None)
     parser.add_argument("--training-num", type=int, default=16)
     parser.add_argument("--test-num", type=int, default=10)
     # npg special
@@ -32,7 +32,7 @@ def main(
     step_per_epoch: int = 30000,
     step_per_collect: int = 1024,
     repeat_per_collect: int = 1,
-    batch_size: int = 99999,
+    batch_size: int | None = None,
     training_num: int = 16,
     test_num: int = 10,
     rew_norm: bool = True,
@@ -34,7 +34,7 @@ def get_args():
     parser.add_argument("--step-per-collect", type=int, default=2048)
     parser.add_argument("--repeat-per-collect", type=int, default=1)
     # batch-size >> step-per-collect means calculating all data in one singe forward.
-    parser.add_argument("--batch-size", type=int, default=99999)
+    parser.add_argument("--batch-size", type=int, default=None)
     parser.add_argument("--training-num", type=int, default=64)
     parser.add_argument("--test-num", type=int, default=10)
     # reinforce special
@@ -29,7 +29,7 @@ def main(
     step_per_epoch: int = 30000,
     step_per_collect: int = 2048,
     repeat_per_collect: int = 1,
-    batch_size: int = 99999,
+    batch_size: int | None = None,
     training_num: int = 64,
     test_num: int = 10,
     rew_norm: bool = True,
@@ -39,7 +39,7 @@ def get_args():
     parser.add_argument("--step-per-collect", type=int, default=1024)
     parser.add_argument("--repeat-per-collect", type=int, default=1)
     # batch-size >> step-per-collect means calculating all data in one singe forward.
-    parser.add_argument("--batch-size", type=int, default=99999)
+    parser.add_argument("--batch-size", type=int, default=None)
     parser.add_argument("--training-num", type=int, default=16)
     parser.add_argument("--test-num", type=int, default=10)
     # trpo special
@@ -32,7 +32,7 @@ def main(
     step_per_epoch: int = 30000,
     step_per_collect: int = 1024,
     repeat_per_collect: int = 1,
-    batch_size: int = 99999,
+    batch_size: int | None = None,
     training_num: int = 16,
     test_num: int = 10,
     rew_norm: bool = True,
@@ -290,20 +290,32 @@ class ReplayBuffer:
         self._meta[ptr] = batch
         return ptr, ep_rew, ep_len, ep_idx

-    def sample_indices(self, batch_size: int) -> np.ndarray:
+    def sample_indices(self, batch_size: int | None) -> np.ndarray:
         """Get a random sample of index with size = batch_size.

         Return all available indices in the buffer if batch_size is 0; return an empty
         numpy array if batch_size < 0 or no available index can be sampled.
+
+        :param batch_size: the number of indices to be sampled. If None, it will be set
+            to the length of the buffer (i.e. return all available indices in a
+            random order).
         """
+        if batch_size is None:
+            batch_size = len(self)
         if self.stack_num == 1 or not self._sample_avail:  # most often case
             if batch_size > 0:
                 return np.random.choice(self._size, batch_size)
+            # TODO: is this behavior really desired?
             if batch_size == 0:  # construct current available indices
                 return np.concatenate([np.arange(self._index, self._size), np.arange(self._index)])
             return np.array([], int)
+        # TODO: raise error on negative batch_size instead?
         if batch_size < 0:
             return np.array([], int)
+        # TODO: simplify this code - shouldn't have such a large if-else
+        # with many returns for handling different stack nums.
+        # It is also not clear whether this is really necessary - frame stacking usually is handled
+        # by environment wrappers (e.g. FrameStack) and not by the replay buffer.
         all_indices = prev_indices = np.concatenate(
             [np.arange(self._index, self._size), np.arange(self._index)],
         )
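The non-stacked branch now dispatches on None, zero, positive and negative batch sizes. A standalone NumPy sketch of that dispatch (it deliberately ignores the circular self._index bookkeeping and is not the tianshou implementation):

import numpy as np

def sample_indices_sketch(buffer_len: int, batch_size: int | None) -> np.ndarray:
    # None -> draw buffer_len indices (with replacement, as np.random.choice does by default)
    if batch_size is None:
        batch_size = buffer_len
    if batch_size > 0:
        return np.random.choice(buffer_len, batch_size)
    if batch_size == 0:
        # all currently available indices, here simply in insertion order
        return np.arange(buffer_len)
    # negative batch_size: empty result (the TODO above asks whether to raise instead)
    return np.array([], int)

print(sample_indices_sketch(5, None))  # five random indices, e.g. [3 0 4 4 1]
print(sample_indices_sketch(5, 0))     # [0 1 2 3 4]
print(sample_indices_sketch(5, -1))    # []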
@@ -314,7 +326,7 @@ class ReplayBuffer:
             return np.random.choice(all_indices, batch_size)
         return all_indices

-    def sample(self, batch_size: int) -> tuple[RolloutBatchProtocol, np.ndarray]:
+    def sample(self, batch_size: int | None) -> tuple[RolloutBatchProtocol, np.ndarray]:
         """Get a random sample from buffer with size = batch_size.

         Return all the data in the buffer if batch_size is 0.
@@ -84,7 +84,7 @@ class HERReplayBuffer(ReplayBuffer):
         self._restore_cache()
         return super().add(batch, buffer_ids)

-    def sample_indices(self, batch_size: int) -> np.ndarray:
+    def sample_indices(self, batch_size: int | None) -> np.ndarray:
         """Get a random sample of index with size = batch_size.

         Return all available indices in the buffer if batch_size is 0; return an \
@@ -169,8 +169,10 @@ class ReplayBufferManager(ReplayBuffer):
         self._meta[ptrs] = batch
         return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs)

-    def sample_indices(self, batch_size: int) -> np.ndarray:
-        if batch_size < 0:
+    def sample_indices(self, batch_size: int | None) -> np.ndarray:
+        # TODO: simplify this code
+        if batch_size is not None and batch_size < 0:
+            # TODO: raise error instead?
             return np.array([], int)
         if self._sample_avail and self.stack_num > 1:
             all_indices = np.concatenate(
@@ -181,8 +183,10 @@ class ReplayBufferManager(ReplayBuffer):
             )
             if batch_size == 0:
                 return all_indices
+            if batch_size is None:
+                batch_size = len(all_indices)
             return np.random.choice(all_indices, batch_size)
-        if batch_size == 0:  # get all available indices
+        if batch_size == 0 or batch_size is None:  # get all available indices
             sample_num = np.zeros(self.buffer_num, int)
         else:
             buffer_idx = np.random.choice(
@@ -58,8 +58,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         self.init_weight(ptr)
         return ptr, ep_rew, ep_len, ep_idx

-    def sample_indices(self, batch_size: int) -> np.ndarray:
-        if batch_size > 0 and len(self) > 0:
+    def sample_indices(self, batch_size: int | None) -> np.ndarray:
+        if batch_size is not None and batch_size > 0 and len(self) > 0:
             scalar = np.random.rand(batch_size) * self.weight.reduce()
             return self.weight.get_prefix_sum_idx(scalar)  # type: ignore
         return super().sample_indices(batch_size)
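One reason the extra batch_size is not None guard is needed here: the prioritized branch passes batch_size straight to np.random.rand, which does not accept None as a size, so a None batch size has to fall through to the parent class, which converts it to the buffer length. A standalone sketch (not repo code):

import numpy as np

# np.random.rand(batch_size) rejects None, so None must take the super() path,
# where ReplayBuffer.sample_indices replaces it with len(self).
try:
    np.random.rand(None)
except TypeError as exc:
    print(f"TypeError: {exc}")

print(np.random.rand(3))  # the guarded path: three uniform samples in [0, 1)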
@@ -370,7 +370,7 @@ class BasePolicy(ABC, nn.Module):

     def update(
         self,
-        sample_size: int,
+        sample_size: int | None,
         buffer: ReplayBuffer | None,
         **kwargs: Any,
     ) -> dict[str, Any]:
@@ -382,7 +382,9 @@ class BasePolicy(ABC, nn.Module):
         Please refer to :ref:`policy_state` for more detailed explanation.

         :param sample_size: 0 means it will extract all the data from the buffer,
-            otherwise it will sample a batch with given sample_size.
+            otherwise it will sample a batch with given sample_size. None also
+            means it will extract all the data from the buffer, but it will be shuffled
+            first. TODO: remove the option for 0?
         :param buffer: the corresponding replay buffer.

         :return: A dict, including the data needed to be logged (e.g., loss) from
@@ -139,7 +139,7 @@ class GAILPolicy(PPOPolicy):
     def learn(  # type: ignore
         self,
         batch: RolloutBatchProtocol,
-        batch_size: int,
+        batch_size: int | None,
         repeat: int,
         **kwargs: Any,
     ) -> dict[str, list[float]]:
@@ -142,14 +142,15 @@ class A2CPolicy(PGPolicy):
     def learn(  # type: ignore
         self,
         batch: RolloutBatchProtocol,
-        batch_size: int,
+        batch_size: int | None,
         repeat: int,
         *args: Any,
         **kwargs: Any,
     ) -> dict[str, list[float]]:
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
+        split_batch_size = batch_size or -1
         for _ in range(repeat):
-            for minibatch in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(split_batch_size, merge_last=True):
                 # calculate loss for actor
                 dist = self(minibatch).dist
                 log_prob = dist.log_prob(minibatch.act)
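The split_batch_size = batch_size or -1 pattern recurs in the NPG, PG, PPO and TRPO hunks below. A plain-Python sketch of the sentinel (not repo code): both None and 0 collapse to -1, which this change relies on Batch.split treating as "yield the whole batch as a single minibatch":

# Plain Python, no tianshou imports; only shows how the falsy values map to -1.
for batch_size in (None, 0, 64):
    split_batch_size = batch_size or -1
    print(batch_size, "->", split_batch_size)
# None -> -1
# 0 -> -1
# 64 -> 64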
@@ -109,13 +109,14 @@ class NPGPolicy(A2CPolicy):
     def learn(  # type: ignore
         self,
         batch: Batch,
-        batch_size: int,
+        batch_size: int | None,
         repeat: int,
         **kwargs: Any,
     ) -> dict[str, list[float]]:
         actor_losses, vf_losses, kls = [], [], []
+        split_batch_size = batch_size or -1
         for _ in range(repeat):
-            for minibatch in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(split_batch_size, merge_last=True):
                 # optimize actor
                 # direction: calculate villia gradient
                 dist = self(minibatch).dist
@@ -193,14 +193,15 @@ class PGPolicy(BasePolicy):
     def learn(  # type: ignore
         self,
         batch: RolloutBatchProtocol,
-        batch_size: int,
+        batch_size: int | None,
         repeat: int,
         *args: Any,
         **kwargs: Any,
     ) -> dict[str, list[float]]:
         losses = []
+        split_batch_size = batch_size or -1
         for _ in range(repeat):
-            for minibatch in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(split_batch_size, merge_last=True):
                 self.optim.zero_grad()
                 result = self(minibatch)
                 dist = result.dist
@@ -128,16 +128,17 @@ class PPOPolicy(A2CPolicy):
     def learn(  # type: ignore
         self,
         batch: RolloutBatchProtocol,
-        batch_size: int,
+        batch_size: int | None,
         repeat: int,
         *args: Any,
         **kwargs: Any,
     ) -> dict[str, list[float]]:
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
+        split_batch_size = batch_size or -1
         for step in range(repeat):
             if self.recompute_adv and step > 0:
                 batch = self._compute_returns(batch, self._buffer, self._indices)
-            for minibatch in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(split_batch_size, merge_last=True):
                 # calculate loss for actor
                 dist = self(minibatch).dist
                 if self.norm_adv:
@@ -91,13 +91,14 @@ class TRPOPolicy(NPGPolicy):
     def learn(  # type: ignore
         self,
         batch: Batch,
-        batch_size: int,
+        batch_size: int | None,
         repeat: int,
         **kwargs: Any,
     ) -> dict[str, list[float]]:
         actor_losses, vf_losses, step_sizes, kls = [], [], [], []
+        split_batch_size = batch_size or -1
         for _ in range(repeat):
-            for minibatch in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(split_batch_size, merge_last=True):
                 # optimize actor
                 # direction: calculate villia gradient
                 dist = self(minibatch).dist  # TODO could come from batch
@@ -31,7 +31,7 @@ class BaseTrainer(ABC):

     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
     :param batch_size: the batch size of sample data, which is going to feed in
-        the policy network.
+        the policy network. If None, will use the whole buffer in each gradient step.
     :param train_collector: the collector used for training.
     :param test_collector: the collector used for testing. If it's None,
         then no testing will be performed.
@@ -141,7 +141,7 @@ class BaseTrainer(ABC):
         self,
         policy: BasePolicy,
         max_epoch: int,
-        batch_size: int,
+        batch_size: int | None,
         train_collector: Collector | None = None,
         test_collector: Collector | None = None,
         buffer: ReplayBuffer | None = None,
@@ -320,7 +320,9 @@ class BaseTrainer(ABC):

             # for offline RL
             if self.train_collector is None:
-                self.env_step = self.gradient_step * self.batch_size
+                assert self.buffer is not None
+                batch_size = self.batch_size or len(self.buffer)
+                self.env_step = self.gradient_step * batch_size

         if not self.stop_fn_flag:
             self.logger.save_data(
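A tiny numeric sketch (hypothetical numbers, not repository code) of the fallback added above: with batch_size=None, the offline env_step estimate is now based on the full buffer length instead of failing on None:

# Hypothetical numbers; the point is only the `or` fallback used above.
gradient_step = 10
buffer_len = 2048
for configured in (256, None):
    batch_size = configured or buffer_len      # None -> 2048
    print(configured, "->", gradient_step * batch_size)
# 256 -> 2560
# None -> 20480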
@@ -565,9 +567,9 @@ class OnpolicyTrainer(BaseTrainer):
         """Perform one on-policy update."""
         assert self.train_collector is not None
         losses = self.policy.update(
-            0,
-            self.train_collector.buffer,
-            # Note: sample_size is 0, so the whole buffer is used for the update.
+            sample_size=0,
+            buffer=self.train_collector.buffer,
+            # Note: sample_size is None, so the whole buffer is used for the update.
             # The kwargs are in the end passed to the .learn method, which uses
             # batch_size to iterate through the buffer in mini-batches
             # Off-policy algos typically don't use the batch_size kwarg at all
@@ -579,7 +581,9 @@ class OnpolicyTrainer(BaseTrainer):
         # TODO: remove the gradient step counting in trainers? Doesn't seem like
         # it's important and it adds complexity
-        self.gradient_step += 1
-        if self.batch_size > 0:
+        if self.batch_size is None:
+            self.gradient_step += 1
+        elif self.batch_size > 0:
             self.gradient_step += int((len(self.train_collector.buffer) - 0.1) // self.batch_size)

         # Note: this is the main difference to the off-policy trainer!
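For the counting branch that is kept above, a small numeric sketch (hypothetical buffer sizes, not repository code) of how many extra gradient steps are recorded per update; with batch_size=None the new branch simply records one step per update:

# Mirrors the arithmetic of the kept line above, with made-up buffer sizes.
def extra_gradient_steps(buffer_len: int, batch_size: int) -> int:
    return int((buffer_len - 0.1) // batch_size)

print(extra_gradient_steps(2048, 64))     # 31
print(extra_gradient_steps(100, 30))      # 3
print(extra_gradient_steps(2048, 99999))  # 0  (the old 99999 default meant one huge minibatch)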