From 8bb8ecba6ea28fdba8f6130755b0ddf493073c84 Mon Sep 17 00:00:00 2001
From: n+e <463003665@qq.com>
Date: Sun, 6 Sep 2020 16:20:16 +0800
Subject: [PATCH] set policy.eval() before collector.collect (#204)
* fix #203
* no_grad argument in collector.collect
---
README.md | 4 +---
docs/contributing.rst | 8 +-------
docs/index.rst | 6 ++----
tianshou/data/collector.py | 8 +++++++-
tianshou/trainer/offpolicy.py | 11 ++++++-----
tianshou/trainer/onpolicy.py | 11 ++++++-----
6 files changed, 23 insertions(+), 25 deletions(-)
diff --git a/README.md b/README.md
index bc89aee..b2a2079 100644
--- a/README.md
+++ b/README.md
@@ -56,10 +56,8 @@ $ pip install tianshou
You can also install with the newest version through GitHub:
```bash
-# latest release
+# latest version
$ pip install git+https://github.com/thu-ml/tianshou.git@master
-# develop version
-$ pip install git+https://github.com/thu-ml/tianshou.git@dev
```
If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
diff --git a/docs/contributing.rst b/docs/contributing.rst
index 8c60f7e..063db78 100644
--- a/docs/contributing.rst
+++ b/docs/contributing.rst
@@ -8,7 +8,6 @@ To install Tianshou in an "editable" mode, run
.. code-block:: bash
- $ git checkout dev
$ pip install -e ".[dev]"
in the main directory. This installation is removable by
@@ -70,9 +69,4 @@ To compile documentation into webpages, run
under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers.
-Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/, and the develop version of documentation is in https://tianshou.readthedocs.io/en/dev/.
-
-Pull Request
-------------
-
-All of the commits should merge through the pull request to the ``dev`` branch. The pull request must have 2 approvals before merging.
+Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/.
diff --git a/docs/index.rst b/docs/index.rst
index 9ef598a..dba58bd 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -46,10 +46,8 @@ You can also install with the newest version through GitHub:
.. code-block:: bash
- # latest release
+ # latest version
$ pip install git+https://github.com/thu-ml/tianshou.git@master
- # develop version
- $ pip install git+https://github.com/thu-ml/tianshou.git@dev
If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
@@ -70,7 +68,7 @@ After installation, open your python console and type
If no error occurs, you have successfully installed Tianshou.
-Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_ and the develop version through `tianshou.readthedocs.io/en/dev/ <https://tianshou.readthedocs.io/en/dev/>`_.
+Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_.
.. toctree::
:maxdepth: 1
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 57f885d..4ffacac 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -173,6 +173,7 @@ class Collector(object):
n_episode: Optional[Union[int, List[int]]] = None,
random: bool = False,
render: Optional[float] = None,
+ no_grad: bool = True,
) -> Dict[str, float]:
"""Collect a specified number of step or episode.
@@ -185,6 +186,8 @@ class Collector(object):
defaults to ``False``.
:param float render: the sleep time between rendering consecutive
frames, defaults to ``None`` (no rendering).
+ :param bool no_grad: whether to disable gradient computation in
+ policy.forward, defaults to ``True`` (gradients are not retained).
.. note::
@@ -252,7 +255,10 @@ class Collector(object):
result = Batch(
act=[spaces[i].sample() for i in self._ready_env_ids])
else:
- with torch.no_grad():
+ if no_grad:
+ with torch.no_grad(): # faster than retain_grad version
+ result = self.policy(self.data, last_state)
+ else:
result = self.policy(self.data, last_state)
state = result.get('state', Batch())
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index 153f94d..a7c6ffa 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -76,13 +76,13 @@ def offpolicy_trainer(
start_time = time.time()
test_in_train = test_in_train and train_collector.policy == policy
for epoch in range(1, 1 + max_epoch):
- # train
- policy.train()
- if train_fn:
- train_fn(epoch)
with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
**tqdm_config) as t:
while t.n < t.total:
+ # collect
+ if train_fn:
+ train_fn(epoch)
+ policy.eval()
result = train_collector.collect(n_step=collect_per_step)
data = {}
if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -99,9 +99,10 @@ def offpolicy_trainer(
start_time, train_collector, test_collector,
test_result['rew'])
else:
- policy.train()
if train_fn:
train_fn(epoch)
+ # train
+ policy.train()
for i in range(update_per_step * min(
result['n/st'] // collect_per_step, t.total - t.n)):
global_step += collect_per_step
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index ea57ed1..0a564fc 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -76,13 +76,13 @@ def onpolicy_trainer(
start_time = time.time()
test_in_train = test_in_train and train_collector.policy == policy
for epoch in range(1, 1 + max_epoch):
- # train
- policy.train()
- if train_fn:
- train_fn(epoch)
with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
**tqdm_config) as t:
while t.n < t.total:
+ # collect
+ if train_fn:
+ train_fn(epoch)
+ policy.eval()
result = train_collector.collect(n_episode=collect_per_step)
data = {}
if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -99,9 +99,10 @@ def onpolicy_trainer(
start_time, train_collector, test_collector,
test_result['rew'])
else:
- policy.train()
if train_fn:
train_fn(epoch)
+ # train
+ policy.train()
losses = policy.update(
0, train_collector.buffer, batch_size, repeat_per_collect)
train_collector.reset_buffer()