From 8bb8ecba6ea28fdba8f6130755b0ddf493073c84 Mon Sep 17 00:00:00 2001 From: n+e <463003665@qq.com> Date: Sun, 6 Sep 2020 16:20:16 +0800 Subject: [PATCH] set policy.eval() before collector.collect (#204) * fix #203 * no_grad argument in collector.collect --- README.md | 4 +--- docs/contributing.rst | 8 +------- docs/index.rst | 6 ++---- tianshou/data/collector.py | 8 +++++++- tianshou/trainer/offpolicy.py | 11 ++++++----- tianshou/trainer/onpolicy.py | 11 ++++++----- 6 files changed, 23 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index bc89aee..b2a2079 100644 --- a/README.md +++ b/README.md @@ -56,10 +56,8 @@ $ pip install tianshou You can also install with the newest version through GitHub: ```bash -# latest release +# latest version $ pip install git+https://github.com/thu-ml/tianshou.git@master -# develop version -$ pip install git+https://github.com/thu-ml/tianshou.git@dev ``` If you use Anaconda or Miniconda, you can install Tianshou through the following command lines: diff --git a/docs/contributing.rst b/docs/contributing.rst index 8c60f7e..063db78 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -8,7 +8,6 @@ To install Tianshou in an "editable" mode, run .. code-block:: bash - $ git checkout dev $ pip install -e ".[dev]" in the main directory. This installation is removable by @@ -70,9 +69,4 @@ To compile documentation into webpages, run under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers. -Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/, and the develop version of documentation is in https://tianshou.readthedocs.io/en/dev/. - -Pull Request ------------- - -All of the commits should merge through the pull request to the ``dev`` branch. The pull request must have 2 approvals before merging. +Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/. diff --git a/docs/index.rst b/docs/index.rst index 9ef598a..dba58bd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -46,10 +46,8 @@ You can also install with the newest version through GitHub: .. code-block:: bash - # latest release + # latest version $ pip install git+https://github.com/thu-ml/tianshou.git@master - # develop version - $ pip install git+https://github.com/thu-ml/tianshou.git@dev If you use Anaconda or Miniconda, you can install Tianshou through the following command lines: @@ -70,7 +68,7 @@ After installation, open your python console and type If no error occurs, you have successfully installed Tianshou. -Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ `_ and the develop version through `tianshou.readthedocs.io/en/dev/ `_. +Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ `_. .. toctree:: :maxdepth: 1 diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 57f885d..4ffacac 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -173,6 +173,7 @@ class Collector(object): n_episode: Optional[Union[int, List[int]]] = None, random: bool = False, render: Optional[float] = None, + no_grad: bool = True, ) -> Dict[str, float]: """Collect a specified number of step or episode. @@ -185,6 +186,8 @@ class Collector(object): defaults to ``False``. :param float render: the sleep time between rendering consecutive frames, defaults to ``None`` (no rendering). + :param bool no_grad: whether to retain gradient in policy.forward, + defaults to ``True`` (no gradient retaining). .. note:: @@ -252,7 +255,10 @@ class Collector(object): result = Batch( act=[spaces[i].sample() for i in self._ready_env_ids]) else: - with torch.no_grad(): + if no_grad: + with torch.no_grad(): # faster than retain_grad version + result = self.policy(self.data, last_state) + else: result = self.policy(self.data, last_state) state = result.get('state', Batch()) diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py index 153f94d..a7c6ffa 100644 --- a/tianshou/trainer/offpolicy.py +++ b/tianshou/trainer/offpolicy.py @@ -76,13 +76,13 @@ def offpolicy_trainer( start_time = time.time() test_in_train = test_in_train and train_collector.policy == policy for epoch in range(1, 1 + max_epoch): - # train - policy.train() - if train_fn: - train_fn(epoch) with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}', **tqdm_config) as t: while t.n < t.total: + # collect + if train_fn: + train_fn(epoch) + policy.eval() result = train_collector.collect(n_step=collect_per_step) data = {} if test_in_train and stop_fn and stop_fn(result['rew']): @@ -99,9 +99,10 @@ def offpolicy_trainer( start_time, train_collector, test_collector, test_result['rew']) else: - policy.train() if train_fn: train_fn(epoch) + # train + policy.train() for i in range(update_per_step * min( result['n/st'] // collect_per_step, t.total - t.n)): global_step += collect_per_step diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py index ea57ed1..0a564fc 100644 --- a/tianshou/trainer/onpolicy.py +++ b/tianshou/trainer/onpolicy.py @@ -76,13 +76,13 @@ def onpolicy_trainer( start_time = time.time() test_in_train = test_in_train and train_collector.policy == policy for epoch in range(1, 1 + max_epoch): - # train - policy.train() - if train_fn: - train_fn(epoch) with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}', **tqdm_config) as t: while t.n < t.total: + # collect + if train_fn: + train_fn(epoch) + policy.eval() result = train_collector.collect(n_episode=collect_per_step) data = {} if test_in_train and stop_fn and stop_fn(result['rew']): @@ -99,9 +99,10 @@ def onpolicy_trainer( start_time, train_collector, test_collector, test_result['rew']) else: - policy.train() if train_fn: train_fn(epoch) + # train + policy.train() losses = policy.update( 0, train_collector.buffer, batch_size, repeat_per_collect) train_collector.reset_buffer()