Support batch_size=None and use it in various scripts (#993)

Closes #986
This commit is contained in:
Michael Panchenko 2023-11-24 19:13:10 +01:00 committed by GitHub
parent f134bc20b5
commit 8d3d1f164b
20 changed files with 63 additions and 36 deletions

View File

@ -34,7 +34,7 @@ def get_args():
parser.add_argument("--step-per-collect", type=int, default=80)
parser.add_argument("--repeat-per-collect", type=int, default=1)
# batch-size >> step-per-collect means calculating all data in one single forward.
parser.add_argument("--batch-size", type=int, default=99999)
parser.add_argument("--batch-size", type=int, default=None)
parser.add_argument("--training-num", type=int, default=16)
parser.add_argument("--test-num", type=int, default=10)
# a2c special
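
Note (not part of the diff): with argparse, `type=int, default=None` means the flag stays None when omitted and is parsed as an int when given, which is exactly what the scripts now rely on. A quick standalone check:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--batch-size", type=int, default=None)

print(parser.parse_args([]).batch_size)                      # None -> use the whole buffer
print(parser.parse_args(["--batch-size", "64"]).batch_size)  # 64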

View File

@ -30,7 +30,7 @@ def main(
step_per_epoch: int = 30000,
step_per_collect: int = 80,
repeat_per_collect: int = 1,
batch_size: int = 99999,
batch_size: int | None = None,
training_num: int = 16,
test_num: int = 10,
rew_norm: bool = True,

View File

@ -39,7 +39,7 @@ def get_args():
parser.add_argument("--step-per-collect", type=int, default=1024)
parser.add_argument("--repeat-per-collect", type=int, default=1)
# batch-size >> step-per-collect means calculating all data in one single forward.
parser.add_argument("--batch-size", type=int, default=99999)
parser.add_argument("--batch-size", type=int, default=None)
parser.add_argument("--training-num", type=int, default=16)
parser.add_argument("--test-num", type=int, default=10)
# npg special

View File

@ -32,7 +32,7 @@ def main(
step_per_epoch: int = 30000,
step_per_collect: int = 1024,
repeat_per_collect: int = 1,
batch_size: int = 99999,
batch_size: int | None = None,
training_num: int = 16,
test_num: int = 10,
rew_norm: bool = True,

View File

@ -34,7 +34,7 @@ def get_args():
parser.add_argument("--step-per-collect", type=int, default=2048)
parser.add_argument("--repeat-per-collect", type=int, default=1)
# batch-size >> step-per-collect means calculating all data in one single forward.
parser.add_argument("--batch-size", type=int, default=99999)
parser.add_argument("--batch-size", type=int, default=None)
parser.add_argument("--training-num", type=int, default=64)
parser.add_argument("--test-num", type=int, default=10)
# reinforce special

View File

@ -29,7 +29,7 @@ def main(
step_per_epoch: int = 30000,
step_per_collect: int = 2048,
repeat_per_collect: int = 1,
batch_size: int = 99999,
batch_size: int | None = None,
training_num: int = 64,
test_num: int = 10,
rew_norm: bool = True,

View File

@ -39,7 +39,7 @@ def get_args():
parser.add_argument("--step-per-collect", type=int, default=1024)
parser.add_argument("--repeat-per-collect", type=int, default=1)
# batch-size >> step-per-collect means calculating all data in one single forward.
parser.add_argument("--batch-size", type=int, default=99999)
parser.add_argument("--batch-size", type=int, default=None)
parser.add_argument("--training-num", type=int, default=16)
parser.add_argument("--test-num", type=int, default=10)
# trpo special

View File

@ -32,7 +32,7 @@ def main(
step_per_epoch: int = 30000,
step_per_collect: int = 1024,
repeat_per_collect: int = 1,
batch_size: int = 99999,
batch_size: int | None = None,
training_num: int = 16,
test_num: int = 10,
rew_norm: bool = True,

View File

@ -290,20 +290,32 @@ class ReplayBuffer:
self._meta[ptr] = batch
return ptr, ep_rew, ep_len, ep_idx
def sample_indices(self, batch_size: int) -> np.ndarray:
def sample_indices(self, batch_size: int | None) -> np.ndarray:
"""Get a random sample of index with size = batch_size.
Return all available indices in the buffer if batch_size is 0; return an empty
numpy array if batch_size < 0 or no available index can be sampled.
:param batch_size: the number of indices to be sampled. If None, it will be set
to the length of the buffer (i.e. return all available indices in a
random order).
"""
if batch_size is None:
batch_size = len(self)
if self.stack_num == 1 or not self._sample_avail: # most often case
if batch_size > 0:
return np.random.choice(self._size, batch_size)
# TODO: is this behavior really desired?
if batch_size == 0: # construct current available indices
return np.concatenate([np.arange(self._index, self._size), np.arange(self._index)])
return np.array([], int)
# TODO: raise error on negative batch_size instead?
if batch_size < 0:
return np.array([], int)
# TODO: simplify this code - shouldn't have such a large if-else
# with many returns for handling different stack nums.
# It is also not clear whether this is really necessary - frame stacking usually is handled
# by environment wrappers (e.g. FrameStack) and not by the replay buffer.
all_indices = prev_indices = np.concatenate(
[np.arange(self._index, self._size), np.arange(self._index)],
)
@ -314,7 +326,7 @@ class ReplayBuffer:
return np.random.choice(all_indices, batch_size)
return all_indices
def sample(self, batch_size: int) -> tuple[RolloutBatchProtocol, np.ndarray]:
def sample(self, batch_size: int | None) -> tuple[RolloutBatchProtocol, np.ndarray]:
"""Get a random sample from buffer with size = batch_size.
Return all the data in the buffer if batch_size is 0.
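
Note (not part of the diff): a minimal self-contained sketch of the batch_size semantics implemented above for the common case (stack_num == 1); `size` and `index` stand in for the buffer's `_size` and `_index` attributes:

import numpy as np

def sample_indices_sketch(batch_size: int | None, size: int, index: int) -> np.ndarray:
    if batch_size is None:    # None -> sample as many indices as the buffer currently holds
        batch_size = size
    if batch_size > 0:        # usual case: random draw (with replacement, per np.random.choice)
        return np.random.choice(size, batch_size)
    if batch_size == 0:       # all available indices, oldest first, no shuffling
        return np.concatenate([np.arange(index, size), np.arange(index)])
    return np.array([], int)  # negative -> empty array (may become an error, see TODO above)

print(sample_indices_sketch(None, size=5, index=2))  # 5 random indices
print(sample_indices_sketch(0, size=5, index=2))     # [2 3 4 0 1]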

View File

@ -84,7 +84,7 @@ class HERReplayBuffer(ReplayBuffer):
self._restore_cache()
return super().add(batch, buffer_ids)
def sample_indices(self, batch_size: int) -> np.ndarray:
def sample_indices(self, batch_size: int | None) -> np.ndarray:
"""Get a random sample of index with size = batch_size.
Return all available indices in the buffer if batch_size is 0; return an \

View File

@ -169,8 +169,10 @@ class ReplayBufferManager(ReplayBuffer):
self._meta[ptrs] = batch
return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs)
def sample_indices(self, batch_size: int) -> np.ndarray:
if batch_size < 0:
def sample_indices(self, batch_size: int | None) -> np.ndarray:
# TODO: simplify this code
if batch_size is not None and batch_size < 0:
# TODO: raise error instead?
return np.array([], int)
if self._sample_avail and self.stack_num > 1:
all_indices = np.concatenate(
@ -181,8 +183,10 @@ class ReplayBufferManager(ReplayBuffer):
)
if batch_size == 0:
return all_indices
if batch_size is None:
batch_size = len(all_indices)
return np.random.choice(all_indices, batch_size)
if batch_size == 0: # get all available indices
if batch_size == 0 or batch_size is None: # get all available indices
sample_num = np.zeros(self.buffer_num, int)
else:
buffer_idx = np.random.choice(
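
Note (not part of the diff): a hedged sketch of what the branch above sets up. A sample_num of all zeros acts as a "take everything from every sub-buffer" sentinel when batch_size is 0 or None; otherwise the draws are spread across sub-buffers. The proportional weighting below is an assumption, since the np.random.choice call is cut off in this hunk:

import numpy as np

def allocate_sample_num(batch_size: int | None, lengths: np.ndarray) -> np.ndarray:
    """Sketch: how many indices to draw from each sub-buffer."""
    if batch_size == 0 or batch_size is None:  # sentinel: take all available indices
        return np.zeros(len(lengths), int)
    # Assumption: draws are allocated in proportion to each sub-buffer's fill level.
    buffer_idx = np.random.choice(len(lengths), batch_size, p=lengths / lengths.sum())
    return np.bincount(buffer_idx, minlength=len(lengths))

print(allocate_sample_num(None, np.array([3, 5, 2])))  # [0 0 0]
print(allocate_sample_num(8, np.array([3, 5, 2])))     # e.g. [2 4 2]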

View File

@ -58,8 +58,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self.init_weight(ptr)
return ptr, ep_rew, ep_len, ep_idx
def sample_indices(self, batch_size: int) -> np.ndarray:
if batch_size > 0 and len(self) > 0:
def sample_indices(self, batch_size: int | None) -> np.ndarray:
if batch_size is not None and batch_size > 0 and len(self) > 0:
scalar = np.random.rand(batch_size) * self.weight.reduce()
return self.weight.get_prefix_sum_idx(scalar) # type: ignore
return super().sample_indices(batch_size)
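
Note (not part of the diff): the two lines above perform inverse-CDF sampling over the priorities via the sum tree; a plain-numpy sketch of the same effect (O(n) here instead of the tree's O(log n)):

import numpy as np

weights = np.array([1.0, 4.0, 2.0, 1.0])   # per-transition priorities
batch_size = 5
scalar = np.random.rand(batch_size) * weights.sum()
# Map each uniform draw to the index whose cumulative-weight bracket contains it;
# indices are therefore drawn with probability proportional to their priority.
indices = np.searchsorted(np.cumsum(weights), scalar, side="right")
print(indices)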

View File

@ -370,7 +370,7 @@ class BasePolicy(ABC, nn.Module):
def update(
self,
sample_size: int,
sample_size: int | None,
buffer: ReplayBuffer | None,
**kwargs: Any,
) -> dict[str, Any]:
@ -382,7 +382,9 @@ class BasePolicy(ABC, nn.Module):
Please refer to :ref:`policy_state` for more detailed explanation.
:param sample_size: 0 means it will extract all the data from the buffer,
otherwise it will sample a batch with given sample_size.
otherwise it will sample a batch with given sample_size. None also
means it will extract all the data from the buffer, but it will be shuffled
first. TODO: remove the option for 0?
:param buffer: the corresponding replay buffer.
:return: A dict, including the data needed to be logged (e.g., loss) from

View File

@ -139,7 +139,7 @@ class GAILPolicy(PPOPolicy):
def learn( # type: ignore
self,
batch: RolloutBatchProtocol,
batch_size: int,
batch_size: int | None,
repeat: int,
**kwargs: Any,
) -> dict[str, list[float]]:

View File

@ -142,14 +142,15 @@ class A2CPolicy(PGPolicy):
def learn( # type: ignore
self,
batch: RolloutBatchProtocol,
batch_size: int,
batch_size: int | None,
repeat: int,
*args: Any,
**kwargs: Any,
) -> dict[str, list[float]]:
losses, actor_losses, vf_losses, ent_losses = [], [], [], []
split_batch_size = batch_size or -1
for _ in range(repeat):
for minibatch in batch.split(batch_size, merge_last=True):
for minibatch in batch.split(split_batch_size, merge_last=True):
# calculate loss for actor
dist = self(minibatch).dist
log_prob = dist.log_prob(minibatch.act)
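
Note (not part of the diff): a small standalone illustration of the `batch_size or -1` idiom introduced in these learn() methods. Both None and 0 collapse to -1, which is assumed here to make the split yield the whole batch as one minibatch (shown with a plain list in place of a Batch):

def split_sketch(data: list, size: int) -> list[list]:
    if size == -1:  # -1 -> a single minibatch containing everything
        size = len(data)
    return [data[i:i + size] for i in range(0, len(data), size)]

for batch_size in (None, 0, 2):
    split_batch_size = batch_size or -1
    print(batch_size, split_sketch([1, 2, 3, 4, 5], split_batch_size))
# None [[1, 2, 3, 4, 5]]
# 0 [[1, 2, 3, 4, 5]]
# 2 [[1, 2], [3, 4], [5]]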

View File

@ -109,13 +109,14 @@ class NPGPolicy(A2CPolicy):
def learn( # type: ignore
self,
batch: Batch,
batch_size: int,
batch_size: int | None,
repeat: int,
**kwargs: Any,
) -> dict[str, list[float]]:
actor_losses, vf_losses, kls = [], [], []
split_batch_size = batch_size or -1
for _ in range(repeat):
for minibatch in batch.split(batch_size, merge_last=True):
for minibatch in batch.split(split_batch_size, merge_last=True):
# optimize actor
# direction: calculate vanilla gradient
dist = self(minibatch).dist

View File

@ -193,14 +193,15 @@ class PGPolicy(BasePolicy):
def learn( # type: ignore
self,
batch: RolloutBatchProtocol,
batch_size: int,
batch_size: int | None,
repeat: int,
*args: Any,
**kwargs: Any,
) -> dict[str, list[float]]:
losses = []
split_batch_size = batch_size or -1
for _ in range(repeat):
for minibatch in batch.split(batch_size, merge_last=True):
for minibatch in batch.split(split_batch_size, merge_last=True):
self.optim.zero_grad()
result = self(minibatch)
dist = result.dist

View File

@ -128,16 +128,17 @@ class PPOPolicy(A2CPolicy):
def learn( # type: ignore
self,
batch: RolloutBatchProtocol,
batch_size: int,
batch_size: int | None,
repeat: int,
*args: Any,
**kwargs: Any,
) -> dict[str, list[float]]:
losses, clip_losses, vf_losses, ent_losses = [], [], [], []
split_batch_size = batch_size or -1
for step in range(repeat):
if self.recompute_adv and step > 0:
batch = self._compute_returns(batch, self._buffer, self._indices)
for minibatch in batch.split(batch_size, merge_last=True):
for minibatch in batch.split(split_batch_size, merge_last=True):
# calculate loss for actor
dist = self(minibatch).dist
if self.norm_adv:

View File

@ -91,13 +91,14 @@ class TRPOPolicy(NPGPolicy):
def learn( # type: ignore
self,
batch: Batch,
batch_size: int,
batch_size: int | None,
repeat: int,
**kwargs: Any,
) -> dict[str, list[float]]:
actor_losses, vf_losses, step_sizes, kls = [], [], [], []
split_batch_size = batch_size or -1
for _ in range(repeat):
for minibatch in batch.split(batch_size, merge_last=True):
for minibatch in batch.split(split_batch_size, merge_last=True):
# optimize actor
# direction: calculate vanilla gradient
dist = self(minibatch).dist # TODO could come from batch

View File

@ -31,7 +31,7 @@ class BaseTrainer(ABC):
:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
:param batch_size: the batch size of sample data, which is going to feed in
the policy network.
the policy network. If None, will use the whole buffer in each gradient step.
:param train_collector: the collector used for training.
:param test_collector: the collector used for testing. If it's None,
then no testing will be performed.
@ -141,7 +141,7 @@ class BaseTrainer(ABC):
self,
policy: BasePolicy,
max_epoch: int,
batch_size: int,
batch_size: int | None,
train_collector: Collector | None = None,
test_collector: Collector | None = None,
buffer: ReplayBuffer | None = None,
@ -320,7 +320,9 @@ class BaseTrainer(ABC):
# for offline RL
if self.train_collector is None:
self.env_step = self.gradient_step * self.batch_size
assert self.buffer is not None
batch_size = self.batch_size or len(self.buffer)
self.env_step = self.gradient_step * batch_size
if not self.stop_fn_flag:
self.logger.save_data(
@ -565,9 +567,9 @@ class OnpolicyTrainer(BaseTrainer):
"""Perform one on-policy update."""
assert self.train_collector is not None
losses = self.policy.update(
0,
self.train_collector.buffer,
# Note: sample_size is 0, so the whole buffer is used for the update.
sample_size=0,
buffer=self.train_collector.buffer,
# Note: sample_size is None, so the whole buffer is used for the update.
# The kwargs are in the end passed to the .learn method, which uses
# batch_size to iterate through the buffer in mini-batches
# Off-policy algos typically don't use the batch_size kwarg at all
@ -579,7 +581,9 @@ class OnpolicyTrainer(BaseTrainer):
# TODO: remove the gradient step counting in trainers? Doesn't seem like
# it's important and it adds complexity
self.gradient_step += 1
if self.batch_size > 0:
if self.batch_size is None:
self.gradient_step += 1
elif self.batch_size > 0:
self.gradient_step += int((len(self.train_collector.buffer) - 0.1) // self.batch_size)
# Note: this is the main difference to the off-policy trainer!
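
Note (not part of the diff): for reference, what the counting expression above, `int((n - 0.1) // batch_size)`, evaluates to for a few buffer lengths n. The -0.1 only changes the result when batch_size divides n exactly, where it lowers it by one (pure arithmetic, not library code):

for n, bs in [(9, 3), (10, 3), (256, 64)]:
    print(f"n={n}, batch_size={bs} -> {int((n - 0.1) // bs)}")
# n=9, batch_size=3 -> 2
# n=10, batch_size=3 -> 3
# n=256, batch_size=64 -> 3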