diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py
index 956da23..fbd2ec9 100755
--- a/examples/mujoco/mujoco_a2c.py
+++ b/examples/mujoco/mujoco_a2c.py
@@ -34,7 +34,7 @@ def get_args():
     parser.add_argument("--step-per-collect", type=int, default=80)
     parser.add_argument("--repeat-per-collect", type=int, default=1)
     # batch-size >> step-per-collect means calculating all data in one singe forward.
-    parser.add_argument("--batch-size", type=int, default=99999)
+    parser.add_argument("--batch-size", type=int, default=None)
     parser.add_argument("--training-num", type=int, default=16)
     parser.add_argument("--test-num", type=int, default=10)
     # a2c special
diff --git a/examples/mujoco/mujoco_a2c_hl.py b/examples/mujoco/mujoco_a2c_hl.py
index 1825b6c..7b7dba7 100644
--- a/examples/mujoco/mujoco_a2c_hl.py
+++ b/examples/mujoco/mujoco_a2c_hl.py
@@ -30,7 +30,7 @@ def main(
     step_per_epoch: int = 30000,
     step_per_collect: int = 80,
     repeat_per_collect: int = 1,
-    batch_size: int = 99999,
+    batch_size: int | None = None,
     training_num: int = 16,
     test_num: int = 10,
     rew_norm: bool = True,
diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py
index 4df3eb8..bad4f0e 100755
--- a/examples/mujoco/mujoco_npg.py
+++ b/examples/mujoco/mujoco_npg.py
@@ -39,7 +39,7 @@ def get_args():
     parser.add_argument("--step-per-collect", type=int, default=1024)
     parser.add_argument("--repeat-per-collect", type=int, default=1)
     # batch-size >> step-per-collect means calculating all data in one singe forward.
-    parser.add_argument("--batch-size", type=int, default=99999)
+    parser.add_argument("--batch-size", type=int, default=None)
     parser.add_argument("--training-num", type=int, default=16)
     parser.add_argument("--test-num", type=int, default=10)
     # npg special
diff --git a/examples/mujoco/mujoco_npg_hl.py b/examples/mujoco/mujoco_npg_hl.py
index 8d496f2..a423768 100644
--- a/examples/mujoco/mujoco_npg_hl.py
+++ b/examples/mujoco/mujoco_npg_hl.py
@@ -32,7 +32,7 @@ def main(
     step_per_epoch: int = 30000,
     step_per_collect: int = 1024,
     repeat_per_collect: int = 1,
-    batch_size: int = 99999,
+    batch_size: int | None = None,
     training_num: int = 16,
     test_num: int = 10,
     rew_norm: bool = True,
diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py
index 043bd34..50bf312 100755
--- a/examples/mujoco/mujoco_reinforce.py
+++ b/examples/mujoco/mujoco_reinforce.py
@@ -34,7 +34,7 @@ def get_args():
     parser.add_argument("--step-per-collect", type=int, default=2048)
     parser.add_argument("--repeat-per-collect", type=int, default=1)
     # batch-size >> step-per-collect means calculating all data in one singe forward.
-    parser.add_argument("--batch-size", type=int, default=99999)
+    parser.add_argument("--batch-size", type=int, default=None)
     parser.add_argument("--training-num", type=int, default=64)
     parser.add_argument("--test-num", type=int, default=10)
     # reinforce special
diff --git a/examples/mujoco/mujoco_reinforce_hl.py b/examples/mujoco/mujoco_reinforce_hl.py
index 1a08449..4f9bfdc 100644
--- a/examples/mujoco/mujoco_reinforce_hl.py
+++ b/examples/mujoco/mujoco_reinforce_hl.py
@@ -29,7 +29,7 @@ def main(
     step_per_epoch: int = 30000,
     step_per_collect: int = 2048,
     repeat_per_collect: int = 1,
-    batch_size: int = 99999,
+    batch_size: int | None = None,
     training_num: int = 64,
     test_num: int = 10,
     rew_norm: bool = True,
diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py
index 0bf952f..e37cf91 100755
--- a/examples/mujoco/mujoco_trpo.py
+++ b/examples/mujoco/mujoco_trpo.py
@@ -39,7 +39,7 @@ def get_args():
     parser.add_argument("--step-per-collect", type=int, default=1024)
     parser.add_argument("--repeat-per-collect", type=int, default=1)
     # batch-size >> step-per-collect means calculating all data in one singe forward.
-    parser.add_argument("--batch-size", type=int, default=99999)
+    parser.add_argument("--batch-size", type=int, default=None)
    parser.add_argument("--training-num", type=int, default=16)
     parser.add_argument("--test-num", type=int, default=10)
     # trpo special
diff --git a/examples/mujoco/mujoco_trpo_hl.py b/examples/mujoco/mujoco_trpo_hl.py
index 5901cc5..692e302 100644
--- a/examples/mujoco/mujoco_trpo_hl.py
+++ b/examples/mujoco/mujoco_trpo_hl.py
@@ -32,7 +32,7 @@ def main(
     step_per_epoch: int = 30000,
     step_per_collect: int = 1024,
     repeat_per_collect: int = 1,
-    batch_size: int = 99999,
+    batch_size: int | None = None,
     training_num: int = 16,
     test_num: int = 10,
     rew_norm: bool = True,
diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py
index 588b1fa..9e2033c 100644
--- a/tianshou/data/buffer/base.py
+++ b/tianshou/data/buffer/base.py
@@ -290,20 +290,32 @@ class ReplayBuffer:
         self._meta[ptr] = batch
         return ptr, ep_rew, ep_len, ep_idx

-    def sample_indices(self, batch_size: int) -> np.ndarray:
+    def sample_indices(self, batch_size: int | None) -> np.ndarray:
         """Get a random sample of index with size = batch_size.

         Return all available indices in the buffer if batch_size is 0; return an
         empty numpy array if batch_size < 0 or no available index can be sampled.
+
+        :param batch_size: the number of indices to be sampled. If None, it will be set
+            to the length of the buffer (i.e. return all available indices in a
+            random order).
         """
+        if batch_size is None:
+            batch_size = len(self)
         if self.stack_num == 1 or not self._sample_avail:  # most often case
             if batch_size > 0:
                 return np.random.choice(self._size, batch_size)
+            # TODO: is this behavior really desired?
             if batch_size == 0:  # construct current available indices
                 return np.concatenate([np.arange(self._index, self._size), np.arange(self._index)])
             return np.array([], int)
+        # TODO: raise error on negative batch_size instead?
         if batch_size < 0:
             return np.array([], int)
+        # TODO: simplify this code - shouldn't have such a large if-else
+        # with many returns for handling different stack nums.
+        # It is also not clear whether this is really necessary - frame stacking usually is handled
+        # by environment wrappers (e.g. FrameStack) and not by the replay buffer.
         all_indices = prev_indices = np.concatenate(
             [np.arange(self._index, self._size), np.arange(self._index)],
         )
@@ -314,7 +326,7 @@ class ReplayBuffer:
             return np.random.choice(all_indices, batch_size)
         return all_indices

-    def sample(self, batch_size: int) -> tuple[RolloutBatchProtocol, np.ndarray]:
+    def sample(self, batch_size: int | None) -> tuple[RolloutBatchProtocol, np.ndarray]:
         """Get a random sample from buffer with size = batch_size.

         Return all the data in the buffer if batch_size is 0.
diff --git a/tianshou/data/buffer/her.py b/tianshou/data/buffer/her.py
index 23a7fa6..1ae1e8f 100644
--- a/tianshou/data/buffer/her.py
+++ b/tianshou/data/buffer/her.py
@@ -84,7 +84,7 @@ class HERReplayBuffer(ReplayBuffer):
         self._restore_cache()
         return super().add(batch, buffer_ids)

-    def sample_indices(self, batch_size: int) -> np.ndarray:
+    def sample_indices(self, batch_size: int | None) -> np.ndarray:
         """Get a random sample of index with size = batch_size.

         Return all available indices in the buffer if batch_size is 0; return an \
diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py
index 760e078..e09b696 100644
--- a/tianshou/data/buffer/manager.py
+++ b/tianshou/data/buffer/manager.py
@@ -169,8 +169,10 @@ class ReplayBufferManager(ReplayBuffer):
         self._meta[ptrs] = batch
         return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs)

-    def sample_indices(self, batch_size: int) -> np.ndarray:
-        if batch_size < 0:
+    def sample_indices(self, batch_size: int | None) -> np.ndarray:
+        # TODO: simplify this code
+        if batch_size is not None and batch_size < 0:
+            # TODO: raise error instead?
             return np.array([], int)
         if self._sample_avail and self.stack_num > 1:
             all_indices = np.concatenate(
@@ -181,8 +183,10 @@
             )
             if batch_size == 0:
                 return all_indices
+            if batch_size is None:
+                batch_size = len(all_indices)
             return np.random.choice(all_indices, batch_size)
-        if batch_size == 0:  # get all available indices
+        if batch_size == 0 or batch_size is None:  # get all available indices
             sample_num = np.zeros(self.buffer_num, int)
         else:
             buffer_idx = np.random.choice(
diff --git a/tianshou/data/buffer/prio.py b/tianshou/data/buffer/prio.py
index a84e496..bef6a06 100644
--- a/tianshou/data/buffer/prio.py
+++ b/tianshou/data/buffer/prio.py
@@ -58,8 +58,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
         self.init_weight(ptr)
         return ptr, ep_rew, ep_len, ep_idx

-    def sample_indices(self, batch_size: int) -> np.ndarray:
-        if batch_size > 0 and len(self) > 0:
+    def sample_indices(self, batch_size: int | None) -> np.ndarray:
+        if batch_size is not None and batch_size > 0 and len(self) > 0:
             scalar = np.random.rand(batch_size) * self.weight.reduce()
             return self.weight.get_prefix_sum_idx(scalar)  # type: ignore
         return super().sample_indices(batch_size)
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 50db30b..3c8b8ad 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -370,7 +370,7 @@ class BasePolicy(ABC, nn.Module):

     def update(
         self,
-        sample_size: int,
+        sample_size: int | None,
         buffer: ReplayBuffer | None,
         **kwargs: Any,
     ) -> dict[str, Any]:
@@ -382,7 +382,9 @@
         Please refer to :ref:`policy_state` for more detailed explanation.

         :param sample_size: 0 means it will extract all the data from the buffer,
-            otherwise it will sample a batch with given sample_size.
+            otherwise it will sample a batch with given sample_size. None also
+            means it will extract all the data from the buffer, but it will be shuffled
+            first. TODO: remove the option for 0?
         :param buffer: the corresponding replay buffer.

         :return: A dict, including the data needed to be logged (e.g., loss) from
diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py
index 100199d..5461f36 100644
--- a/tianshou/policy/imitation/gail.py
+++ b/tianshou/policy/imitation/gail.py
@@ -139,7 +139,7 @@ class GAILPolicy(PPOPolicy):
     def learn(  # type: ignore
         self,
         batch: RolloutBatchProtocol,
-        batch_size: int,
+        batch_size: int | None,
         repeat: int,
         **kwargs: Any,
     ) -> dict[str, list[float]]:
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py
index a4583a0..5d43ec1 100644
--- a/tianshou/policy/modelfree/a2c.py
+++ b/tianshou/policy/modelfree/a2c.py
@@ -142,14 +142,15 @@ class A2CPolicy(PGPolicy):
     def learn(  # type: ignore
         self,
         batch: RolloutBatchProtocol,
-        batch_size: int,
+        batch_size: int | None,
         repeat: int,
         *args: Any,
         **kwargs: Any,
     ) -> dict[str, list[float]]:
         losses, actor_losses, vf_losses, ent_losses = [], [], [], []
+        split_batch_size = batch_size or -1
         for _ in range(repeat):
-            for minibatch in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(split_batch_size, merge_last=True):
                 # calculate loss for actor
                 dist = self(minibatch).dist
                 log_prob = dist.log_prob(minibatch.act)
diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py
index 8e838f6..bb1a1cc 100644
--- a/tianshou/policy/modelfree/npg.py
+++ b/tianshou/policy/modelfree/npg.py
@@ -109,13 +109,14 @@ class NPGPolicy(A2CPolicy):
     def learn(  # type: ignore
         self,
         batch: Batch,
-        batch_size: int,
+        batch_size: int | None,
         repeat: int,
         **kwargs: Any,
     ) -> dict[str, list[float]]:
         actor_losses, vf_losses, kls = [], [], []
+        split_batch_size = batch_size or -1
         for _ in range(repeat):
-            for minibatch in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(split_batch_size, merge_last=True):
                 # optimize actor
                 # direction: calculate villia gradient
                 dist = self(minibatch).dist
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 8d6e22f..1bbc095 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -193,14 +193,15 @@ class PGPolicy(BasePolicy):
     def learn(  # type: ignore
         self,
         batch: RolloutBatchProtocol,
-        batch_size: int,
+        batch_size: int | None,
         repeat: int,
         *args: Any,
         **kwargs: Any,
     ) -> dict[str, list[float]]:
         losses = []
+        split_batch_size = batch_size or -1
         for _ in range(repeat):
-            for minibatch in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(split_batch_size, merge_last=True):
                 self.optim.zero_grad()
                 result = self(minibatch)
                 dist = result.dist
diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py
index 58e4444..3f8c47f 100644
--- a/tianshou/policy/modelfree/ppo.py
+++ b/tianshou/policy/modelfree/ppo.py
@@ -128,16 +128,17 @@ class PPOPolicy(A2CPolicy):
     def learn(  # type: ignore
         self,
         batch: RolloutBatchProtocol,
-        batch_size: int,
+        batch_size: int | None,
         repeat: int,
         *args: Any,
         **kwargs: Any,
     ) -> dict[str, list[float]]:
         losses, clip_losses, vf_losses, ent_losses = [], [], [], []
+        split_batch_size = batch_size or -1
         for step in range(repeat):
             if self.recompute_adv and step > 0:
                 batch = self._compute_returns(batch, self._buffer, self._indices)
-            for minibatch in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(split_batch_size, merge_last=True):
                 # calculate loss for actor
                 dist = self(minibatch).dist
                 if self.norm_adv:
diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py
index e0cbcc7..7546a25 100644
--- a/tianshou/policy/modelfree/trpo.py
+++ b/tianshou/policy/modelfree/trpo.py
@@ -91,13 +91,14 @@ class TRPOPolicy(NPGPolicy):
     def learn(  # type: ignore
         self,
         batch: Batch,
-        batch_size: int,
+        batch_size: int | None,
         repeat: int,
         **kwargs: Any,
     ) -> dict[str, list[float]]:
         actor_losses, vf_losses, step_sizes, kls = [], [], [], []
+        split_batch_size = batch_size or -1
         for _ in range(repeat):
-            for minibatch in batch.split(batch_size, merge_last=True):
+            for minibatch in batch.split(split_batch_size, merge_last=True):
                 # optimize actor
                 # direction: calculate villia gradient
                 dist = self(minibatch).dist  # TODO could come from batch
diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index df8b11c..fbd3e91 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -31,7 +31,7 @@ class BaseTrainer(ABC):

     :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
     :param batch_size: the batch size of sample data, which is going to feed in
-        the policy network.
+        the policy network. If None, will use the whole buffer in each gradient step.
     :param train_collector: the collector used for training.
     :param test_collector: the collector used for testing. If it's None, then no
         testing will be performed.
@@ -141,7 +141,7 @@ class BaseTrainer(ABC):
         self,
         policy: BasePolicy,
         max_epoch: int,
-        batch_size: int,
+        batch_size: int | None,
         train_collector: Collector | None = None,
         test_collector: Collector | None = None,
         buffer: ReplayBuffer | None = None,
@@ -320,7 +320,9 @@ class BaseTrainer(ABC):

         # for offline RL
         if self.train_collector is None:
-            self.env_step = self.gradient_step * self.batch_size
+            assert self.buffer is not None
+            batch_size = self.batch_size or len(self.buffer)
+            self.env_step = self.gradient_step * batch_size

         if not self.stop_fn_flag:
             self.logger.save_data(
@@ -565,9 +567,9 @@ class OnpolicyTrainer(BaseTrainer):
         """Perform one on-policy update."""
         assert self.train_collector is not None
         losses = self.policy.update(
-            0,
-            self.train_collector.buffer,
-            # Note: sample_size is 0, so the whole buffer is used for the update.
+            sample_size=0,
+            buffer=self.train_collector.buffer,
+            # Note: sample_size is None, so the whole buffer is used for the update.
             # The kwargs are in the end passed to the .learn method, which uses
             # batch_size to iterate through the buffer in mini-batches
             # Off-policy algos typically don't use the batch_size kwarg at all
@@ -579,7 +581,9 @@ class OnpolicyTrainer(BaseTrainer):
         # TODO: remove the gradient step counting in trainers? Doesn't seem like
         # it's important and it adds complexity
         self.gradient_step += 1
-        if self.batch_size > 0:
+        if self.batch_size is None:
+            self.gradient_step += 1
+        elif self.batch_size > 0:
             self.gradient_step += int((len(self.train_collector.buffer) - 0.1) // self.batch_size)

         # Note: this is the main difference to the off-policy trainer!
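
Reviewer note (not part of the patch): the None handling added to ReplayBuffer.sample_indices in the tianshou/data/buffer/base.py hunk above can be paraphrased as a standalone numpy sketch for the common stack_num == 1 case. The name sample_indices_sketch and the parameters size/index are illustrative stand-ins for self._size and self._index, not library API.

import numpy as np

def sample_indices_sketch(size: int, index: int, batch_size: int | None) -> np.ndarray:
    # Mirrors the stack_num == 1 branch of ReplayBuffer.sample_indices after this patch.
    if batch_size is None:
        # New behavior: None is treated as len(buffer), i.e. one batch covering everything.
        batch_size = size
    if batch_size > 0:
        # np.random.choice draws indices with replacement by default.
        return np.random.choice(size, batch_size)
    if batch_size == 0:
        # All available indices in insertion order (oldest transition first).
        return np.concatenate([np.arange(index, size), np.arange(index)])
    # Negative batch_size still yields an empty index array.
    return np.array([], int)

# With 5 stored transitions and the insertion pointer at 2:
print(sample_indices_sketch(5, 2, None))  # 5 randomly drawn indices
print(sample_indices_sketch(5, 2, 0))     # [2 3 4 0 1]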
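
Reviewer note (not part of the patch): each on-policy learn() hunk above reduces batch_size=None to split_batch_size = batch_size or -1 before calling Batch.split. The sketch below shows what that pattern does, assuming Batch.split treats -1 as "use the full length" (the behavior the or -1 fallback relies on); the toy rollout data is made up for illustration.

import numpy as np
from tianshou.data import Batch

rollout = Batch(obs=np.arange(10).reshape(10, 1), act=np.zeros(10))

batch_size = None                    # the new default in the on-policy examples
split_batch_size = batch_size or -1  # None -> -1 -> a single minibatch

# One minibatch containing all 10 transitions, i.e. a single forward/backward pass:
chunks = list(rollout.split(split_batch_size, merge_last=True))
assert [len(c) for c in chunks] == [10]

# An explicit batch_size still yields mini-batches; merge_last folds the remainder in:
chunks = list(rollout.split(4, merge_last=True))
assert [len(c) for c in chunks] == [4, 6]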