diff --git a/README.md b/README.md index 2c027d5..2c5db38 100644 --- a/README.md +++ b/README.md @@ -139,12 +139,14 @@ Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page ### Modularized Policy -We decouple all of the algorithms into 4 parts: +We decouple all of the algorithms roughly into the following parts: - `__init__`: initialize the policy; - `forward`: to compute actions over given observations; - `process_fn`: to preprocess data from replay buffer (since we have reformulated all algorithms to replay-buffer based algorithms); -- `learn`: to learn from a given batch data. +- `learn`: to learn from a given batch data; +- `post_process_fn`: to update the replay buffer from the learning process (e.g., prioritized replay buffer needs to update the weight); +- `update`: the main interface for training, i.e., `process_fn -> learn -> post_process_fn`. Within this API, we can interact with different policies conveniently. @@ -165,7 +167,7 @@ result = collector.collect(n_episode=[1, 0, 3]) If you want to train the given policy with a sampled batch: ```python -result = policy.learn(collector.sample(batch_size)) +result = policy.update(batch_size, collector.buffer) ``` You can check out the [documentation](https://tianshou.readthedocs.io) for further usage. diff --git a/docs/tutorials/concepts.rst b/docs/tutorials/concepts.rst index 3f033ea..ba771ad 100644 --- a/docs/tutorials/concepts.rst +++ b/docs/tutorials/concepts.rst @@ -64,12 +64,14 @@ Policy Tianshou aims to modularizing RL algorithms. It comes into several classes of policies in Tianshou. All of the policy classes must inherit :class:`~tianshou.policy.BasePolicy`. 
-A policy class typically has four parts: +A policy class typically has the following parts: -* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including coping the target network and so on; +* :meth:`~tianshou.policy.BasePolicy.__init__`: initialize the policy, including copying the target network and so on; * :meth:`~tianshou.policy.BasePolicy.forward`: compute action with given observation; -* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer (this function can interact with replay buffer); +* :meth:`~tianshou.policy.BasePolicy.process_fn`: pre-process data from the replay buffer; * :meth:`~tianshou.policy.BasePolicy.learn`: update policy with a given batch of data. +* :meth:`~tianshou.policy.BasePolicy.post_process_fn`: update the buffer with a given batch of data; +* :meth:`~tianshou.policy.BasePolicy.update`: the main interface for training. This function samples data from the buffer, pre-processes it (such as computing the n-step return), learns from it, and finally post-processes it (such as updating the prioritized replay buffer); in short, ``process_fn -> learn -> post_process_fn``. Take 2-step return DQN as an example. The 2-step return DQN compute each frame's return as: @@ -125,10 +127,8 @@ Collector --------- The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently. -In short, :class:`~tianshou.data.Collector` has two main methods: -* :meth:`~tianshou.data.Collector.collect`: let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer; -* :meth:`~tianshou.data.Collector.sample`: sample a data batch from replay buffer; it will call :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the final batch data. 
+:class:`~tianshou.data.Collector` has one main method :meth:`~tianshou.data.Collector.collect`: it lets the policy perform (at least) a specified number of steps ``n_step`` or episodes ``n_episode`` and stores the data in the replay buffer. Why do we mention **at least** here? For multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically. @@ -144,8 +144,6 @@ Once you have a collector and a policy, you can start writing the training metho Tianshou has two types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` and :func:`~tianshou.trainer.offpolicy_trainer`, corresponding to on-policy algorithms (such as Policy Gradient) and off-policy algorithms (such as DQN). Please check out :doc:`/api/tianshou.trainer` for the usage. -There will be more types of trainers, for instance, multi-agent trainer. - .. _pseudocode: @@ -165,7 +163,8 @@ We give a high-level explanation through the pseudocode used in section :ref:`po buffer.store(s, a, s_, r, d) # collector.collect(...) s = s_ # collector.collect(...) if i % 1000 == 0: # done in trainer - b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # collector.sample(batch_size) + # the following is done in policy.update(batch_size, buffer) + b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # buffer.sample(batch_size) # compute 2-step returns. How? b_ret = compute_2_step_return(buffer, b_r, b_d, ...) # policy.process_fn(batch, buffer, indice) # update DQN policy diff --git a/docs/tutorials/dqn.rst b/docs/tutorials/dqn.rst index e5edbfd..764bb45 100644 --- a/docs/tutorials/dqn.rst +++ b/docs/tutorials/dqn.rst @@ -210,8 +210,8 @@ Tianshou supports user-defined training code. 
Here is the code snippet: # back to training eps policy.set_eps(0.1) - # train policy with a sampled batch data - losses = policy.learn(train_collector.sample(batch_size=64)) + # train policy with a sampled batch data from buffer + losses = policy.update(64, train_collector.buffer) For further usage, you can refer to the :doc:`/tutorials/cheatsheet`. diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 68cf5c7..38e5d93 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -174,7 +174,7 @@ def test_collector_with_dict_state(): c1.seed(0) c1.collect(n_step=10) c1.collect(n_episode=[2, 1, 1, 2]) - batch = c1.sample(10) + batch, _ = c1.buffer.sample(10) print(batch) c0.buffer.update(c1.buffer) assert np.allclose(c0.buffer[:len(c0.buffer)].obs.index, np.expand_dims([ @@ -184,7 +184,7 @@ def test_collector_with_dict_state(): c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4), Logger.single_preprocess_fn) c2.collect(n_episode=[0, 0, 0, 10]) - batch = c2.sample(10) + batch, _ = c2.buffer.sample(10) print(batch['obs_next']['index']) @@ -209,7 +209,7 @@ def test_collector_with_ma(): assert np.asanyarray(r).size == 1 and r == 4. r = c1.collect(n_episode=[2, 1, 1, 2])['rew'] assert np.asanyarray(r).size == 1 and r == 4. - batch = c1.sample(10) + batch, _ = c1.buffer.sample(10) print(batch) c0.buffer.update(c1.buffer) obs = np.array(np.expand_dims([ @@ -226,7 +226,7 @@ def test_collector_with_ma(): Logger.single_preprocess_fn, reward_metric=reward_metric) r = c2.collect(n_episode=[0, 0, 0, 10])['rew'] assert np.asanyarray(r).size == 1 and r == 4. 
- batch = c2.sample(10) + batch, _ = c2.buffer.sample(10) print(batch['obs_next']) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 9105374..8d2863c 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -64,16 +64,6 @@ class Collector(object): # sleep time between rendering consecutive frames) collector.collect(n_episode=1, render=0.03) - # sample data with a given number of batch-size: - batch_data = collector.sample(batch_size=64) - # policy.learn(batch_data) # btw, vanilla policy gradient only - # supports on-policy training, so here we pick all data in the buffer - batch_data = collector.sample(batch_size=0) - policy.learn(batch_data) - # on-policy algorithms use the collected data only once, so here we - # clear the buffer - collector.reset_buffer() - Collected data always consist of full episodes. So if only ``n_step`` argument is give, the collector may return the data more than the ``n_step`` limitation. Same as ``n_episode`` for the multiple environment @@ -357,13 +347,18 @@ class Collector(object): def sample(self, batch_size: int) -> Batch: """Sample a data batch from the internal replay buffer. It will call - :meth:`~tianshou.policy.BasePolicy.process_fn` before returning - the final batch data. + :meth:`~tianshou.policy.BasePolicy.process_fn` before returning the + final batch data. :param int batch_size: ``0`` means it will extract all the data from the buffer, otherwise it will extract the data with the given batch_size. """ + import warnings + warnings.warn( + 'Collector.sample is deprecated and will cause error if you use ' + 'prioritized experience replay! Collector.sample will be removed ' + 'upon version 0.3. 
Use policy.update instead!', Warning) batch_data, indice = self.buffer.sample(batch_size) batch_data = self.process_fn(batch_data, self.buffer, indice) return batch_data diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index cc1a593..a2f545e 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -93,9 +93,8 @@ class BasePolicy(ABC, nn.Module): # some code return Batch(..., policy=Batch(log_prob=dist.log_prob(act))) - # and in the sampled data batch, you can directly call - # batch.policy.log_prob to get your data, although it is stored in - # np.ndarray. + # and in the sampled data batch, you can directly use + # batch.policy.log_prob to get your data. """ pass @@ -123,6 +122,7 @@ class BasePolicy(ABC, nn.Module): v_s_: Optional[Union[np.ndarray, torch.Tensor]] = None, gamma: float = 0.99, gae_lambda: float = 0.95, + rew_norm: bool = False, ) -> Batch: """Compute returns over given full-length episodes, including the implementation of Generalized Advantage Estimator (arXiv:1506.02438). @@ -136,6 +136,8 @@ class BasePolicy(ABC, nn.Module): to 0.99. :param float gae_lambda: the parameter for Generalized Advantage Estimation, should be in [0, 1], defaults to 0.95. + :param bool rew_norm: normalize the reward to Normal(0, 1), defaults + to ``False``. :return: a Batch. The result will be stored in batch.returns as a numpy array with shape (bsz, ). 
@@ -150,6 +152,8 @@ class BasePolicy(ABC, nn.Module): for i in range(len(rew) - 1, -1, -1): gae = delta[i] + m[i] * gae returns[i] += gae + if rew_norm and not np.isclose(returns.std(), 0, atol=1e-2): + returns = (returns - returns.mean()) / returns.std() batch.returns = returns return batch @@ -196,7 +200,7 @@ class BasePolicy(ABC, nn.Module): if rew_norm: bfr = rew[:min(len(buffer), 1000)] # avoid large buffer mean, std = bfr.mean(), bfr.std() - if np.isclose(std, 0): + if np.isclose(std, 0, atol=1e-2): mean, std = 0, 1 else: mean, std = 0, 1 @@ -216,9 +220,30 @@ class BasePolicy(ABC, nn.Module): batch.returns = target_q * gammas + returns # prio buffer update if isinstance(buffer, PrioritizedReplayBuffer): - batch.update_weight = buffer.update_weight - batch.indice = indice batch.weight = to_torch_as(batch.weight, target_q) else: batch.weight = torch.ones_like(target_q) return batch + + def post_process_fn(self, batch: Batch, + buffer: ReplayBuffer, indice: np.ndarray): + """Post-process the data from the provided replay buffer. Typical + usage is to update the sampling weight in prioritized experience + replay. Check out :ref:`policy_concept` for more information. + """ + if isinstance(buffer, PrioritizedReplayBuffer): + buffer.update_weight(indice, batch.weight) + + def update(self, batch_size: int, buffer: ReplayBuffer, *args, **kwargs): + """Update the policy network and replay buffer (if needed). It includes + three function steps: process_fn, learn, and post_process_fn. + + :param int batch_size: 0 means it will extract all the data from the + buffer, otherwise it will sample a batch with the given batch_size. + :param ReplayBuffer buffer: the corresponding replay buffer. 
+ """ + batch, indice = buffer.sample(batch_size) + batch = self.process_fn(batch, buffer, indice) + result = self.learn(batch, *args, **kwargs) + self.post_process_fn(batch, buffer, indice) + return result diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 2a5c123..52d8dd2 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -67,7 +67,8 @@ class A2CPolicy(PGPolicy): v_.append(to_numpy(self.critic(b.obs_next))) v_ = np.concatenate(v_, axis=0) return self.compute_episodic_return( - batch, v_, gamma=self._gamma, gae_lambda=self._lambda) + batch, v_, gamma=self._gamma, gae_lambda=self._lambda, + rew_norm=self._rew_norm) def forward(self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, @@ -97,9 +98,6 @@ class A2CPolicy(PGPolicy): def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]: self._batch = batch_size - r = batch.returns - if self._rew_norm and not np.isclose(r.std(), 0): - batch.returns = (r - r.mean()) / r.std() losses, actor_losses, vf_losses, ent_losses = [], [], [], [] for _ in range(repeat): for b in batch.split(batch_size): diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 2205102..79a65d3 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -144,10 +144,8 @@ class DDPGPolicy(BasePolicy): current_q = self.critic(batch.obs, batch.act).flatten() target_q = batch.returns.flatten() td = current_q - target_q - if hasattr(batch, 'update_weight'): # prio-buffer - batch.update_weight(batch.indice, td) critic_loss = (td.pow(2) * batch.weight).mean() - # critic_loss = F.mse_loss(current_q, target_q) + batch.weight = td # prio-buffer self.critic_optim.zero_grad() critic_loss.backward() self.critic_optim.step() diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index c37dac5..f1a01a6 100644 --- a/tianshou/policy/modelfree/dqn.py 
+++ b/tianshou/policy/modelfree/dqn.py @@ -160,10 +160,8 @@ class DQNPolicy(BasePolicy): q = q[np.arange(len(q)), batch.act] r = to_torch_as(batch.returns, q).flatten() td = r - q - if hasattr(batch, 'update_weight'): # prio-buffer - batch.update_weight(batch.indice, td) loss = (td.pow(2) * batch.weight).mean() - # loss = F.mse_loss(q, r) + batch.weight = td # prio-buffer loss.backward() self.optim.step() self._cnt += 1 diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index d6176e6..8fded95 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -51,7 +51,7 @@ class PGPolicy(BasePolicy): # batch.returns = self._vectorized_returns(batch) # return batch return self.compute_episodic_return( - batch, gamma=self._gamma, gae_lambda=1.) + batch, gamma=self._gamma, gae_lambda=1., rew_norm=self._rew_norm) def forward(self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, @@ -81,9 +81,6 @@ class PGPolicy(BasePolicy): def learn(self, batch: Batch, batch_size: int, repeat: int, **kwargs) -> Dict[str, List[float]]: losses = [] - r = batch.returns - if self._rew_norm and not np.isclose(r.std(), 0): - batch.returns = (r - r.mean()) / r.std() for _ in range(repeat): for b in batch.split(batch_size): self.optim.zero_grad() diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index 2d1dece..3094be8 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -79,18 +79,29 @@ class PPOPolicy(PGPolicy): indice: np.ndarray) -> Batch: if self._rew_norm: mean, std = batch.rew.mean(), batch.rew.std() - if not np.isclose(std, 0): + if not np.isclose(std, 0, atol=1e-2): batch.rew = (batch.rew - mean) / std - if self._lambda in [0, 1]: - return self.compute_episodic_return( - batch, None, gamma=self._gamma, gae_lambda=self._lambda) - v_ = [] + v, v_, old_log_prob = [], [], [] with torch.no_grad(): for b in batch.split(self._batch, shuffle=False): 
v_.append(self.critic(b.obs_next)) + v.append(self.critic(b.obs)) + old_log_prob.append(self(b).dist.log_prob( + to_torch_as(b.act, v[0]))) v_ = to_numpy(torch.cat(v_, dim=0)) - return self.compute_episodic_return( - batch, v_, gamma=self._gamma, gae_lambda=self._lambda) + batch = self.compute_episodic_return( + batch, v_, gamma=self._gamma, gae_lambda=self._lambda, + rew_norm=self._rew_norm) + batch.v = torch.cat(v, dim=0).flatten() # old value + batch.act = to_torch_as(batch.act, v[0]) + batch.logp_old = torch.cat(old_log_prob, dim=0) + batch.returns = to_torch_as(batch.returns, v[0]) + batch.adv = batch.returns - batch.v + if self._rew_norm: + mean, std = batch.adv.mean(), batch.adv.std() + if not np.isclose(std.item(), 0, atol=1e-2): + batch.adv = (batch.adv - mean) / std + return batch def forward(self, batch: Batch, state: Optional[Union[dict, Batch, np.ndarray]] = None, @@ -123,26 +134,6 @@ class PPOPolicy(PGPolicy): **kwargs) -> Dict[str, List[float]]: self._batch = batch_size losses, clip_losses, vf_losses, ent_losses = [], [], [], [] - v = [] - old_log_prob = [] - with torch.no_grad(): - for b in batch.split(batch_size, shuffle=False): - v.append(self.critic(b.obs)) - old_log_prob.append(self(b).dist.log_prob( - to_torch_as(b.act, v[0]))) - batch.v = torch.cat(v, dim=0).flatten() # old value - batch.act = to_torch_as(batch.act, v[0]) - batch.logp_old = torch.cat(old_log_prob, dim=0) - batch.returns = to_torch_as(batch.returns, v[0]) - if self._rew_norm: - mean, std = batch.returns.mean(), batch.returns.std() - if not np.isclose(std.item(), 0): - batch.returns = (batch.returns - mean) / std - batch.adv = batch.returns - batch.v - if self._rew_norm: - mean, std = batch.adv.mean(), batch.adv.std() - if not np.isclose(std.item(), 0): - batch.adv = (batch.adv - mean) / std for _ in range(repeat): for b in batch.split(batch_size): dist = self(b).dist diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index ce4a5ba..341fe7b 100644 --- 
a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -154,9 +154,7 @@ class SACPolicy(DDPGPolicy): self.critic2_optim.zero_grad() critic2_loss.backward() self.critic2_optim.step() - # prio-buffer - if hasattr(batch, 'update_weight'): - batch.update_weight(batch.indice, (td1 + td2) / 2.) + batch.weight = (td1 + td2) / 2. # prio-buffer # actor obs_result = self(batch, explorating=False) a = obs_result.act diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index 698145f..9a34095 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -132,8 +132,7 @@ class TD3Policy(DDPGPolicy): self.critic2_optim.zero_grad() critic2_loss.backward() self.critic2_optim.step() - if hasattr(batch, 'update_weight'): # prio-buffer - batch.update_weight(batch.indice, (td1 + td2) / 2.) + batch.weight = (td1 + td2) / 2. # prio-buffer if self._cnt % self._freq == 0: actor_loss = -self.critic1( batch.obs, self(batch, eps=0).act).mean() diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 408d4e7..171cbb9 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -28,7 +28,8 @@ def offpolicy_trainer( verbose: bool = True, test_in_train: bool = True, ) -> Dict[str, Union[float, str]]: - """A wrapper for off-policy trainer procedure. + """A wrapper for off-policy trainer procedure. The ``step`` in trainer + means a policy network update. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. @@ -47,8 +48,9 @@ def offpolicy_trainer( :param int batch_size: the batch size of sample data, which is going to feed in the policy network. :param int update_per_step: the number of times the policy network would - be updated after frames be collected. In other words, collect some - frames and do some policy network update. 
+ be updated after frames are collected; for example, setting it to 256 + means the policy is updated 256 times each time ``collect_per_step`` + frames are collected. :param function train_fn: a function receives the current number of epoch index and performs some operations at the beginning of training in this epoch. @@ -103,7 +105,7 @@ def offpolicy_trainer( for i in range(update_per_step * min( result['n/st'] // collect_per_step, t.total - t.n)): global_step += 1 - losses = policy.learn(train_collector.sample(batch_size)) + losses = policy.update(batch_size, train_collector.buffer) for k in result.keys(): data[k] = f'{result[k]:.2f}' if writer and global_step % log_interval == 0: diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index dec42b2..e31724d 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -28,7 +28,8 @@ def onpolicy_trainer( verbose: bool = True, test_in_train: bool = True, ) -> Dict[str, Union[float, str]]: - """A wrapper for on-policy trainer procedure. + """A wrapper for on-policy trainer procedure. The ``step`` in trainer means + a policy network update. :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. @@ -101,8 +102,8 @@ def onpolicy_trainer( policy.train() if train_fn: train_fn(epoch) - losses = policy.learn( - train_collector.sample(0), batch_size, repeat_per_collect) + losses = policy.update( + 0, train_collector.buffer, batch_size, repeat_per_collect) train_collector.reset_buffer() step = 1 for k in losses.keys():
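
Note on the new `update` flow: the change above moves sampling out of the collector and into `policy.update`, which chains `process_fn -> learn -> post_process_fn`. The sketch below is only a toy illustration of that call order, not Tianshou's implementation. `ToyPrioritizedBuffer` and `ToyPolicy` are hypothetical stand-ins, the batch is a plain dict, and the "network" is a single scalar. It assumes, as in this diff, that `learn` leaves the TD errors in `batch["weight"]` so that `post_process_fn` can push them back into the buffer.

```python
import random
from typing import Dict, List, Tuple


class ToyPrioritizedBuffer:
    """Hypothetical stand-in for a prioritized replay buffer (not Tianshou's)."""

    def __init__(self, rewards: List[float]) -> None:
        self.rewards = list(rewards)
        self.weights = [1.0] * len(self.rewards)  # sampling priorities

    def sample(self, batch_size: int) -> Tuple[Dict[str, List[float]], List[int]]:
        # batch_size == 0 means "take all data", as in the update() docstring.
        if batch_size == 0:
            indice = list(range(len(self.rewards)))
        else:
            indice = random.choices(
                range(len(self.rewards)), weights=self.weights, k=batch_size)
        return {"rew": [self.rewards[i] for i in indice]}, indice

    def update_weight(self, indice: List[int], new_weight: List[float]) -> None:
        # Larger absolute TD error -> larger sampling priority.
        for i, w in zip(indice, new_weight):
            self.weights[i] = abs(w) + 1e-6


class ToyPolicy:
    """Hypothetical policy whose whole 'network' is one scalar value estimate."""

    def __init__(self) -> None:
        self.value = 0.0

    def process_fn(self, batch, buffer, indice):
        # Toy pre-processing: treat the reward itself as the return.
        batch["returns"] = list(batch["rew"])
        return batch

    def learn(self, batch, lr: float = 0.5) -> Dict[str, float]:
        td = [r - self.value for r in batch["returns"]]
        self.value += lr * sum(td) / len(td)
        batch["weight"] = td  # handed over to post_process_fn
        return {"loss": sum(e * e for e in td) / len(td)}

    def post_process_fn(self, batch, buffer, indice) -> None:
        buffer.update_weight(indice, batch["weight"])

    def update(self, batch_size: int, buffer, **kwargs) -> Dict[str, float]:
        # Same order as in the diff: sample, pre-process, learn, post-process.
        batch, indice = buffer.sample(batch_size)
        batch = self.process_fn(batch, buffer, indice)
        result = self.learn(batch, **kwargs)
        self.post_process_fn(batch, buffer, indice)
        return result


buffer = ToyPrioritizedBuffer([1.0, 3.0])
policy = ToyPolicy()
result = policy.update(0, buffer)  # 0 -> use the whole buffer
print(result, policy.value, buffer.weights)
```

Because the buffer is passed to `update` directly, the priority refresh no longer depends on callbacks attached to the batch, which is presumably why `Collector.sample` is deprecated for prioritized replay in this change.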