set policy.eval() before collector.collect (#204)
* fix #203
* no_grad argument in collector.collect
parent 34f714a677 · commit 8bb8ecba6e
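In short: `Collector.collect` gains a `no_grad` argument (default `True`), so the policy's forward pass runs under `torch.no_grad()` during rollouts, and both trainers now call `policy.eval()` before every `collector.collect` and `policy.train()` before every update step, instead of toggling the mode once per epoch.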
@@ -56,10 +56,8 @@ $ pip install tianshou
 You can also install with the newest version through GitHub:
 
 ```bash
-# latest release
+# latest version
 $ pip install git+https://github.com/thu-ml/tianshou.git@master
-# develop version
-$ pip install git+https://github.com/thu-ml/tianshou.git@dev
 ```
 
 If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
@@ -8,7 +8,6 @@ To install Tianshou in an "editable" mode, run
 
 .. code-block:: bash
 
-    $ git checkout dev
     $ pip install -e ".[dev]"
 
 in the main directory. This installation is removable by
@@ -70,9 +69,4 @@ To compile documentation into webpages, run
 
 under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers.
 
-Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/, and the develop version of documentation is in https://tianshou.readthedocs.io/en/dev/.
-
-Pull Request
-------------
-
-All of the commits should merge through the pull request to the ``dev`` branch. The pull request must have 2 approvals before merging.
+Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/.
@@ -46,10 +46,8 @@ You can also install with the newest version through GitHub:
 
 .. code-block:: bash
 
-    # latest release
+    # latest version
     $ pip install git+https://github.com/thu-ml/tianshou.git@master
-    # develop version
-    $ pip install git+https://github.com/thu-ml/tianshou.git@dev
 
 If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
 
@@ -70,7 +68,7 @@ After installation, open your python console and type
 
 If no error occurs, you have successfully installed Tianshou.
 
-Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_ and the develop version through `tianshou.readthedocs.io/en/dev/ <https://tianshou.readthedocs.io/en/dev/>`_.
+Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_.
 
 .. toctree::
    :maxdepth: 1
@@ -173,6 +173,7 @@ class Collector(object):
         n_episode: Optional[Union[int, List[int]]] = None,
         random: bool = False,
         render: Optional[float] = None,
+        no_grad: bool = True,
     ) -> Dict[str, float]:
         """Collect a specified number of step or episode.
@@ -185,6 +186,8 @@ class Collector(object):
             defaults to ``False``.
         :param float render: the sleep time between rendering consecutive
             frames, defaults to ``None`` (no rendering).
+        :param bool no_grad: whether to run ``policy.forward`` inside
+            ``torch.no_grad()``, defaults to ``True`` (no gradient retained).
 
         .. note::
@@ -252,7 +255,10 @@ class Collector(object):
                 result = Batch(
                     act=[spaces[i].sample() for i in self._ready_env_ids])
             else:
-                with torch.no_grad():
+                if no_grad:
+                    with torch.no_grad():  # faster than retain_grad version
+                        result = self.policy(self.data, last_state)
+                else:
                     result = self.policy(self.data, last_state)
 
             state = result.get('state', Batch())
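For illustration, a minimal usage sketch of the new flag; `policy` is assumed to be an already-constructed tianshou `BasePolicy`, and the environment choice is arbitrary:

```python
import gym
from tianshou.data import Collector

# `policy` is assumed: any tianshou BasePolicy instance works here
collector = Collector(policy, gym.make('CartPole-v0'))

collector.collect(n_step=100)                 # default: forward pass under torch.no_grad()
collector.collect(n_step=100, no_grad=False)  # keep the autograd graph through policy.forward
```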
@@ -76,13 +76,13 @@ def offpolicy_trainer(
     start_time = time.time()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
-        # train
-        policy.train()
-        if train_fn:
-            train_fn(epoch)
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
+                # collect
+                if train_fn:
+                    train_fn(epoch)
+                policy.eval()
                 result = train_collector.collect(n_step=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
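Why `policy.eval()` before collecting matters: modules such as `Dropout` and `BatchNorm` behave differently in train mode, which would distort the collected trajectories. A small, self-contained torch illustration:

```python
import torch
import torch.nn as nn

net = nn.Dropout(p=0.5)
x = torch.ones(1, 4)

net.train()
print(net(x))  # train mode: activations randomly zeroed, survivors scaled by 1/(1-p)

net.eval()
print(net(x))  # eval mode: Dropout is the identity, output is deterministic
```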
@@ -99,9 +99,10 @@ def offpolicy_trainer(
                             start_time, train_collector, test_collector,
                             test_result['rew'])
                     else:
-                        policy.train()
                         if train_fn:
                             train_fn(epoch)
+                # train
+                policy.train()
                 for i in range(update_per_step * min(
                         result['n/st'] // collect_per_step, t.total - t.n)):
                     global_step += collect_per_step
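The same mode-toggling pattern carries over to hand-written training loops. A hedged sketch assuming an existing `policy` and `collector`; the sampling and update calls are schematic, and the trainer hunks above show the exact sequencing:

```python
for epoch in range(1, 1 + max_epoch):
    # collect: eval mode, and (by default) no autograd inside policy.forward
    policy.eval()
    result = collector.collect(n_step=collect_per_step)

    # update: switch back to train mode before taking gradient steps
    policy.train()
    batch = collector.sample(batch_size)  # schematic: draw a batch from the buffer
    losses = policy.learn(batch)          # schematic off-policy update
```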
@@ -76,13 +76,13 @@ def onpolicy_trainer(
     start_time = time.time()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
-        # train
-        policy.train()
-        if train_fn:
-            train_fn(epoch)
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
+                # collect
+                if train_fn:
+                    train_fn(epoch)
+                policy.eval()
                 result = train_collector.collect(n_episode=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -99,9 +99,10 @@ def onpolicy_trainer(
                             start_time, train_collector, test_collector,
                             test_result['rew'])
                     else:
-                        policy.train()
                         if train_fn:
                             train_fn(epoch)
+                # train
+                policy.train()
                 losses = policy.update(
                     0, train_collector.buffer, batch_size, repeat_per_collect)
                 train_collector.reset_buffer()
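Since `train_fn(epoch)` now runs right before each collect step, per-epoch exploration schedules keep working; a hypothetical `train_fn` for a DQN-style policy (which exposes `set_eps`):

```python
# hypothetical epsilon schedule passed to the trainer as train_fn
def train_fn(epoch):
    policy.set_eps(max(0.1, 0.5 - 0.05 * epoch))  # anneal exploration over epochs
```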