set policy.eval() before collector.collect (#204)

* fix #203

* no_grad argument in collector.collect
n+e 2020-09-06 16:20:16 +08:00 committed by GitHub
parent 34f714a677
commit 8bb8ecba6e
6 changed files with 23 additions and 25 deletions
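In short, the trainers now switch to evaluation mode for environment interaction and back to training mode for gradient updates, and the collector's forward passes run under `torch.no_grad()` by default. A minimal sketch of the resulting pattern (illustrative only, not the literal trainer code; `policy`, `train_collector`, and `collect_per_step` refer to the objects used in the diffs below):

```python
# Illustrative sketch: collect in eval mode without building an autograd
# graph, then switch back to train mode before updating.
policy.eval()                                              # inference behavior for Dropout/BatchNorm-style layers
result = train_collector.collect(n_step=collect_per_step)  # forward passes run under torch.no_grad() (default no_grad=True)
policy.train()                                             # restore train-mode behavior
# ... the gradient-based update (policy.update / policy.learn) happens here ...
```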

View File

@@ -56,10 +56,8 @@ $ pip install tianshou
 You can also install with the newest version through GitHub:
 ```bash
-# latest release
+# latest version
 $ pip install git+https://github.com/thu-ml/tianshou.git@master
-# develop version
-$ pip install git+https://github.com/thu-ml/tianshou.git@dev
 ```
 If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:

View File

@@ -8,7 +8,6 @@ To install Tianshou in an "editable" mode, run
 .. code-block:: bash
-    $ git checkout dev
     $ pip install -e ".[dev]"
 in the main directory. This installation is removable by
@@ -70,9 +69,4 @@ To compile documentation into webpages, run
 under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers.
-Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/, and the develop version of documentation is in https://tianshou.readthedocs.io/en/dev/.
-Pull Request
-------------
-All of the commits should merge through the pull request to the ``dev`` branch. The pull request must have 2 approvals before merging.
+Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/.

View File

@@ -46,10 +46,8 @@ You can also install with the newest version through GitHub:
 .. code-block:: bash
-    # latest release
+    # latest version
     $ pip install git+https://github.com/thu-ml/tianshou.git@master
-    # develop version
-    $ pip install git+https://github.com/thu-ml/tianshou.git@dev
 If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
@@ -70,7 +68,7 @@ After installation, open your python console and type
 If no error occurs, you have successfully installed Tianshou.
-Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_ and the develop version through `tianshou.readthedocs.io/en/dev/ <https://tianshou.readthedocs.io/en/dev/>`_.
+Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_.
 .. toctree::
    :maxdepth: 1

View File

@@ -173,6 +173,7 @@ class Collector(object):
                 n_episode: Optional[Union[int, List[int]]] = None,
                 random: bool = False,
                 render: Optional[float] = None,
+                no_grad: bool = True,
                 ) -> Dict[str, float]:
         """Collect a specified number of step or episode.
@@ -185,6 +186,8 @@ class Collector(object):
             defaults to ``False``.
         :param float render: the sleep time between rendering consecutive
             frames, defaults to ``None`` (no rendering).
+        :param bool no_grad: whether to retain gradient in policy.forward,
+            defaults to ``True`` (no gradient retaining).
         .. note::
@@ -252,7 +255,10 @@ class Collector(object):
                 result = Batch(
                     act=[spaces[i].sample() for i in self._ready_env_ids])
             else:
-                with torch.no_grad():
+                if no_grad:
+                    with torch.no_grad():  # faster than retain_grad version
+                        result = self.policy(self.data, last_state)
+                else:
                     result = self.policy(self.data, last_state)
             state = result.get('state', Batch())
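Usage note (a sketch, not part of the diff): `no_grad=True`, the default, only controls whether the policy's forward pass during collection runs inside `torch.no_grad()`; pass `no_grad=False` to keep the autograd graph alive for algorithms that need gradients through the actions computed at collection time. Assuming a `Collector` instance named `collector`:

```python
result = collector.collect(n_step=64)                  # default: forward passes under torch.no_grad()
result = collector.collect(n_step=64, no_grad=False)   # keep the autograd graph during collection
```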

View File

@@ -76,13 +76,13 @@ def offpolicy_trainer(
     start_time = time.time()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
-        # train
-        policy.train()
-        if train_fn:
-            train_fn(epoch)
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
+                # collect
+                if train_fn:
+                    train_fn(epoch)
+                policy.eval()
                 result = train_collector.collect(n_step=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -99,9 +99,10 @@
                             start_time, train_collector, test_collector,
                             test_result['rew'])
                     else:
-                        policy.train()
                         if train_fn:
                             train_fn(epoch)
+                # train
+                policy.train()
                 for i in range(update_per_step * min(
                         result['n/st'] // collect_per_step, t.total - t.n)):
                     global_step += collect_per_step
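Why toggle modes at all: `Module.train()` / `Module.eval()` change the behavior of layers such as Dropout and BatchNorm, independently of whether gradients are recorded, so collecting with a policy left in train mode can inject extra stochasticity into the rollouts. A self-contained PyTorch illustration (plain `torch`, no Tianshou objects):

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))
x = torch.ones(1, 4)

net.train()                    # Dropout active: repeated forward passes differ
a, b = net(x), net(x)
print(torch.allclose(a, b))    # usually False

net.eval()                     # Dropout disabled: outputs are deterministic
with torch.no_grad():          # and no autograd graph is built, as in Collector.collect
    c, d = net(x), net(x)
print(torch.allclose(c, d))    # True
print(c.requires_grad)         # False: nothing to backpropagate through
```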

View File

@@ -76,13 +76,13 @@ def onpolicy_trainer(
     start_time = time.time()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
-        # train
-        policy.train()
-        if train_fn:
-            train_fn(epoch)
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
+                # collect
+                if train_fn:
+                    train_fn(epoch)
+                policy.eval()
                 result = train_collector.collect(n_episode=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -99,9 +99,10 @@
                             start_time, train_collector, test_collector,
                             test_result['rew'])
                     else:
-                        policy.train()
                         if train_fn:
                             train_fn(epoch)
+                # train
+                policy.train()
                 losses = policy.update(
                     0, train_collector.buffer, batch_size, repeat_per_collect)
                 train_collector.reset_buffer()
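The same reordering in the on-policy trainer raises an obvious question: if rollouts are collected in eval mode under `torch.no_grad()`, where do the gradients for the policy-gradient loss come from? They come from the update step, which re-runs the network on the buffered data with gradients enabled (in the hunk above, that is what `policy.update(0, train_collector.buffer, ...)` triggers). A minimal, self-contained sketch of this two-phase pattern; the linear actor and placeholder returns are illustrative only:

```python
import torch
import torch.nn as nn

actor = nn.Linear(4, 2)
obs = torch.randn(8, 4)

# Collection phase: sample actions without building an autograd graph.
with torch.no_grad():
    acts = torch.distributions.Categorical(logits=actor(obs)).sample()

# Update phase: re-run the forward pass with gradients enabled, so the stored
# (obs, act) pairs are enough to form a policy-gradient loss.
log_prob = torch.distributions.Categorical(logits=actor(obs)).log_prob(acts)
returns = torch.ones_like(log_prob)           # placeholder returns/advantages
loss = -(log_prob * returns).mean()
loss.backward()
print(actor.weight.grad is not None)          # True: gradients come from the update pass
```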