set policy.eval() before collector.collect (#204)
* fix #203
* no_grad argument in collector.collect
commit 8bb8ecba6e (parent 34f714a677)
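The motivation for calling `policy.eval()` before collection: torch modules such as `Dropout` and `BatchNorm` behave differently in train and eval mode, so rollouts collected while the network is still in train mode are stochastic in a way the deployed policy is not. A minimal self-contained illustration (not part of the diff; the toy network below is an assumption for demonstration only):

```python
import torch
import torch.nn as nn

# Toy network with dropout, standing in for a policy network.
net = nn.Sequential(nn.Linear(4, 16), nn.Dropout(p=0.5), nn.Linear(16, 2))
obs = torch.randn(1, 4)

net.train()                      # train mode: dropout masks are resampled
a, b = net(obs), net(obs)
print(torch.allclose(a, b))      # almost surely False

net.eval()                       # eval mode: dropout disabled, deterministic
a, b = net(obs), net(obs)
print(torch.allclose(a, b))      # True
```

The full diff follows.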
@@ -56,10 +56,8 @@ $ pip install tianshou
 You can also install with the newest version through GitHub:
 
 ```bash
-# latest release
+# latest version
 $ pip install git+https://github.com/thu-ml/tianshou.git@master
-# develop version
-$ pip install git+https://github.com/thu-ml/tianshou.git@dev
 ```
 
 If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
@@ -8,7 +8,6 @@ To install Tianshou in an "editable" mode, run
 
 .. code-block:: bash
 
-    $ git checkout dev
     $ pip install -e ".[dev]"
 
 in the main directory. This installation is removable by
@@ -70,9 +69,4 @@ To compile documentation into webpages, run
 
 under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers.
 
-Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/, and the develop version of documentation is in https://tianshou.readthedocs.io/en/dev/.
-
-Pull Request
-------------
-
-All of the commits should merge through the pull request to the ``dev`` branch. The pull request must have 2 approvals before merging.
+Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/.
@@ -46,10 +46,8 @@ You can also install with the newest version through GitHub:
 
 .. code-block:: bash
 
-    # latest release
+    # latest version
     $ pip install git+https://github.com/thu-ml/tianshou.git@master
-    # develop version
-    $ pip install git+https://github.com/thu-ml/tianshou.git@dev
 
 If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
 
@@ -70,7 +68,7 @@ After installation, open your python console and type
 
 If no error occurs, you have successfully installed Tianshou.
 
-Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_ and the develop version through `tianshou.readthedocs.io/en/dev/ <https://tianshou.readthedocs.io/en/dev/>`_.
+Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_.
 
 .. toctree::
    :maxdepth: 1
@@ -173,6 +173,7 @@ class Collector(object):
                 n_episode: Optional[Union[int, List[int]]] = None,
                 random: bool = False,
                 render: Optional[float] = None,
+                no_grad: bool = True,
                 ) -> Dict[str, float]:
         """Collect a specified number of step or episode.
 
@@ -185,6 +186,8 @@ class Collector(object):
             defaults to ``False``.
         :param float render: the sleep time between rendering consecutive
             frames, defaults to ``None`` (no rendering).
+        :param bool no_grad: whether to run policy.forward under
+            ``torch.no_grad()``, defaults to ``True`` (no gradient retaining).
 
         .. note::
 
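What the flag means in practice: with `no_grad=True` (the default) actions are computed without building an autograd graph; with `no_grad=False` the outputs keep `requires_grad` so gradients can flow through the collection step. A small sketch of the same pattern outside the collector (the helper name is hypothetical):

```python
import torch
import torch.nn as nn

# Hypothetical helper mirroring the branch the collector gains below.
def policy_forward(net: nn.Module, x: torch.Tensor, no_grad: bool = True):
    if no_grad:
        with torch.no_grad():    # no autograd graph: faster, less memory
            return net(x)
    return net(x)                # graph retained: gradients can flow

net = nn.Linear(4, 2)
print(policy_forward(net, torch.randn(3, 4)).requires_grad)                 # False
print(policy_forward(net, torch.randn(3, 4), no_grad=False).requires_grad)  # True
```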
@@ -252,7 +255,10 @@ class Collector(object):
                 result = Batch(
                     act=[spaces[i].sample() for i in self._ready_env_ids])
             else:
-                with torch.no_grad():
+                if no_grad:
+                    with torch.no_grad():  # faster than retain_grad version
+                        result = self.policy(self.data, last_state)
+                else:
                     result = self.policy(self.data, last_state)
 
             state = result.get('state', Batch())
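The inline comment claims the no_grad path is faster; a rough, self-contained way to check that claim (layer sizes and iteration counts are arbitrary):

```python
import time
import torch
import torch.nn as nn

net = nn.Sequential(*[nn.Linear(256, 256) for _ in range(8)])
x = torch.randn(64, 256)

def bench(no_grad: bool, iters: int = 200) -> float:
    start = time.time()
    for _ in range(iters):
        if no_grad:
            with torch.no_grad():
                net(x)
        else:
            net(x)               # builds the autograd graph on every call
    return time.time() - start

print(f"no_grad=True:  {bench(True):.3f}s")   # typically the faster of the two
print(f"no_grad=False: {bench(False):.3f}s")
```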
@@ -76,13 +76,13 @@ def offpolicy_trainer(
     start_time = time.time()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
-        # train
-        policy.train()
-        if train_fn:
-            train_fn(epoch)
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
+                # collect
+                if train_fn:
+                    train_fn(epoch)
+                policy.eval()
                 result = train_collector.collect(n_step=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -99,9 +99,10 @@ def offpolicy_trainer(
                             start_time, train_collector, test_collector,
                             test_result['rew'])
                     else:
-                        policy.train()
                         if train_fn:
                             train_fn(epoch)
+                # train
+                policy.train()
                 for i in range(update_per_step * min(
                         result['n/st'] // collect_per_step, t.total - t.n)):
                     global_step += collect_per_step
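Net effect on the off-policy loop: mode switching now happens per iteration, eval before each collect and train before each batch of updates. A runnable skeleton of that control flow, with stubs standing in for the real collector and update step:

```python
import torch.nn as nn

policy = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))

def collect():                   # stand-in for train_collector.collect(...)
    assert not policy.training, "collection should run in eval mode"

def update():                    # stand-in for the gradient update loop
    assert policy.training, "updates should run in train mode"

for epoch in range(1, 3):
    for step in range(3):        # stand-in for the tqdm while-loop
        policy.eval()            # collect
        collect()
        policy.train()           # train
        update()
print("mode switching verified")
```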
@@ -76,13 +76,13 @@ def onpolicy_trainer(
     start_time = time.time()
     test_in_train = test_in_train and train_collector.policy == policy
     for epoch in range(1, 1 + max_epoch):
-        # train
-        policy.train()
-        if train_fn:
-            train_fn(epoch)
         with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
                        **tqdm_config) as t:
             while t.n < t.total:
+                # collect
+                if train_fn:
+                    train_fn(epoch)
+                policy.eval()
                 result = train_collector.collect(n_episode=collect_per_step)
                 data = {}
                 if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -99,9 +99,10 @@ def onpolicy_trainer(
                             start_time, train_collector, test_collector,
                             test_result['rew'])
                     else:
-                        policy.train()
                         if train_fn:
                             train_fn(epoch)
+                # train
+                policy.train()
                 losses = policy.update(
                     0, train_collector.buffer, batch_size, repeat_per_collect)
                 train_collector.reset_buffer()
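The on-policy trainer gets the identical mode-switching fix; the structural differences are that it collects whole episodes (`n_episode` rather than `n_step`) and calls `train_collector.reset_buffer()` after each update, because on-policy algorithms cannot reuse data gathered by an older policy. A toy sketch of that collect-update-discard cycle (stub buffer, not tianshou's ReplayBuffer):

```python
# Stub transition buffer illustrating the on-policy data lifecycle.
buffer: list = []

def collect_episodes(n: int) -> None:
    buffer.extend(("obs", "act", "rew") for _ in range(n))

def update(buf: list) -> None:
    assert buf, "on-policy update consumes freshly collected data"

collect_episodes(5)       # policy.eval() would precede this
update(buffer)            # policy.train() would precede this
buffer.clear()            # mirrors train_collector.reset_buffer()
print("buffer emptied:", not buffer)
```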