fix critical bugs in MAPolicy and docs update (#207)
- fix a bug in MAPolicy: `buffer.rew = Batch()` doesn't change `buffer.rew` (thanks mypy)
- polish examples/box2d/bipedal_hardcore_sac.py
- several docs updates
- format setup.py and bump version to 0.2.7
commit 64af7ea839
parent 380e9e911d
README.md
@@ -36,7 +36,7 @@ Here is Tianshou's other features:
 - Elegant framework, using only ~2000 lines of code
 - Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling)
 - Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
-- Support any type of environment state (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
+- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
 - Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
 - Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
 - Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
@@ -74,8 +74,8 @@ $ pip install tianshou
 After installation, open your python console and type
 
 ```python
-import tianshou as ts
-print(ts.__version__)
+import tianshou
+print(tianshou.__version__)
 ```
 
 If no error occurs, you have successfully installed Tianshou.
docs/index.rst
@@ -24,11 +24,11 @@ Welcome to Tianshou!
 Here is Tianshou's other features:
 
 * Elegant framework, using only ~2000 lines of code
-* Support parallel environment sampling for all algorithms: :ref:`parallel_sampling`
-* Support recurrent state representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training`
+* Support parallel environment simulation (synchronous or asynchronous) for all algorithms: :ref:`parallel_sampling`
+* Support recurrent state/action representation in actor network and critic network (RNN-style training for POMDP): :ref:`rnn_training`
 * Support any type of environment state (e.g. a dict, a self-defined class, ...): :ref:`self_defined_env`
 * Support customized training process: :ref:`customize_training`
-* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay for all Q-learning based algorithms
+* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
 * Support multi-agent RL: :doc:`/tutorials/tictactoe`
 
 中文文档位于 `https://tianshou.readthedocs.io/zh/latest/ <https://tianshou.readthedocs.io/zh/latest/>`_
@@ -63,8 +63,8 @@ If you use Anaconda or Miniconda, you can install Tianshou through the following
 After installation, open your python console and type
 ::
 
-    import tianshou as ts
-    print(ts.__version__)
+    import tianshou
+    print(tianshou.__version__)
 
 If no error occurs, you have successfully installed Tianshou.
 
examples/box2d/README.md (new file)
@@ -0,0 +1,7 @@
+# Bipedal-Hardcore-SAC
+
+- Our default choice: remove the done flag penalty, will soon converge to \~250 reward within 100 epochs (10M env steps, 3~4 hours, see the image below)
+- If the done penalty is not removed, it converges much slower than before, about 200 epochs (20M env steps) to reach the same performance (\~200 reward)
+- Action noise is only necessary in the beginning. It is a negative impact at the end of the training. Removing it can reach \~255 (our best result under the original env, no done penalty removed).
+
+
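The "remove the done flag penalty" trick mentioned above is not itself part of this diff. A minimal sketch of the idea follows; this is a hypothetical wrapper (assuming the classic 4-tuple `gym` step API and BipedalWalkerHardcore's -100 reward on falling), not the exact change the example uses.

```python
import gym


class RemoveDonePenalty(gym.Wrapper):
    """Drop the -100 'fell over' penalty of BipedalWalkerHardcore.

    The large negative terminal reward tends to dominate early critic targets;
    zeroing it out is presumably what "remove the done flag penalty" refers to.
    """

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        if rew == -100:  # the env's falling penalty
            rew = 0.0
        return obs, rew, done, info


# usage sketch (environment id is illustrative)
env = RemoveDonePenalty(gym.make("BipedalWalkerHardcore-v3"))
```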
examples/box2d/bipedal_hardcore_sac.py
@@ -24,13 +24,13 @@ def get_args():
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--alpha', type=float, default=0.1)
-    parser.add_argument('--epoch', type=int, default=1000)
-    parser.add_argument('--step-per-epoch', type=int, default=2400)
+    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--step-per-epoch', type=int, default=10000)
     parser.add_argument('--collect-per-step', type=int, default=10)
     parser.add_argument('--batch-size', type=int, default=128)
     parser.add_argument('--layer-num', type=int, default=1)
     parser.add_argument('--training-num', type=int, default=8)
-    parser.add_argument('--test-num', type=int, default=8)
+    parser.add_argument('--test-num', type=int, default=100)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
     parser.add_argument('--rew-norm', type=int, default=0)
@@ -39,14 +39,14 @@ def get_args():
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
+    parser.add_argument('--resume_path', type=str, default=None)
     return parser.parse_args()
 
 
 class EnvWrapper(object):
     """Env wrapper for reward scale, action repeat and action noise"""
 
-    def __init__(self, task, action_repeat=3,
-                 reward_scale=5, act_noise=0.3):
+    def __init__(self, task, action_repeat=3, reward_scale=5, act_noise=0.3):
         self._env = gym.make(task)
         self.action_repeat = action_repeat
         self.reward_scale = reward_scale
@@ -70,8 +70,6 @@ class EnvWrapper(object):
 
 
 def test_sac_bipedal(args=get_args()):
-    torch.set_num_threads(1)  # we just need only one thread for NN
-
     env = EnvWrapper(args.task)
 
     def IsStop(reward):
@@ -118,6 +116,10 @@ def test_sac_bipedal(args=get_args()):
         reward_normalization=args.rew_norm,
         ignore_done=args.ignore_done,
         estimation_step=args.n_step)
+    # load a previous policy
+    if args.resume_path:
+        policy.load_state_dict(torch.load(args.resume_path))
+        print("Loaded agent from: ", args.resume_path)
 
     # collector
     train_collector = Collector(
@@ -135,7 +137,8 @@ def test_sac_bipedal(args=get_args()):
     result = offpolicy_trainer(
         policy, train_collector, test_collector, args.epoch,
         args.step_per_epoch, args.collect_per_step, args.test_num,
-        args.batch_size, stop_fn=IsStop, save_fn=save_fn, writer=writer)
+        args.batch_size, stop_fn=IsStop, save_fn=save_fn, writer=writer,
+        test_in_train=False)
 
     if __name__ == '__main__':
         pprint.pprint(result)
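With the new `--resume_path` argument, training can be warm-started from a checkpoint written by `save_fn`. A hypothetical way to drive this from Python; the checkpoint path is only illustrative and depends on `--task` and `--logdir`:

```python
# Warm-start the example from an existing checkpoint (path is illustrative).
args = get_args()
args.resume_path = "log/BipedalWalkerHardcore-v3/sac/policy.pth"
test_sac_bipedal(args)
```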
examples/box2d/results/sac/BipedalHardcore.png (new binary file, 40 KiB, not shown)
setup.py
@@ -1,12 +1,19 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
+import os
 from setuptools import setup, find_packages
 
 
+def get_version() -> str:
+    # https://packaging.python.org/guides/single-sourcing-package-version/
+    init = open(os.path.join("tianshou", "__init__.py"), "r").read().split()
+    return init[init.index("__version__") + 2][1:-1]
+
+
 setup(
     name='tianshou',
-    version='0.2.6',
+    version=get_version(),
     description='A Library for Deep Reinforcement Learning',
     long_description=open('README.md', encoding='utf8').read(),
     long_description_content_type='text/markdown',
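setup.py now single-sources the version from `tianshou/__init__.py`, so the package and the distribution metadata cannot drift apart. A quick standalone check of the parsing logic used by `get_version` (not part of the commit): after splitting the file on whitespace, the quoted version sits two tokens after `__version__` (the `=` is in between), and `[1:-1]` strips the quotes.

```python
# Mirrors get_version(): the token two places after "__version__" is the quoted
# version string; slicing off the first and last characters removes the quotes.
tokens = "__version__ = '0.2.7'".split()
version = tokens[tokens.index("__version__") + 2][1:-1]
assert version == "0.2.7"
```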
tianshou/__init__.py
@@ -5,7 +5,7 @@ from tianshou import data, env, utils, policy, trainer, exploration
 utils.pre_compile()
 
 
-__version__ = '0.2.6'
+__version__ = '0.2.7'
 
 __all__ = [
     'env',
@@ -8,7 +8,7 @@ from tianshou.policy import BasePolicy
 
 
 class ImitationPolicy(BasePolicy):
-    """Implementation of vanilla imitation learning (for continuous action space).
+    """Implementation of vanilla imitation learning.
 
     :param torch.nn.Module model: a model following the rules in
         :class:`~tianshou.policy.BasePolicy`. (s -> a)
@@ -36,7 +36,9 @@ class MultiAgentPolicyManager(BasePolicy):
         # reward can be empty Batch (after initial reset) or nparray.
         has_rew = isinstance(buffer.rew, np.ndarray)
         if has_rew:  # save the original reward in save_rew
-            save_rew, buffer.rew = buffer.rew, Batch()
+            # Since we do not override buffer.__setattr__, here we use _meta to
+            # change buffer.rew, otherwise buffer.rew = Batch() has no effect.
+            save_rew, buffer._meta.rew = buffer.rew, Batch()
         for policy in self.policies:
             agent_index = np.nonzero(batch.obs.agent_id == policy.agent_id)[0]
             if len(agent_index) == 0:
|
|||||||
tmp_batch, tmp_indice = batch[agent_index], indice[agent_index]
|
tmp_batch, tmp_indice = batch[agent_index], indice[agent_index]
|
||||||
if has_rew:
|
if has_rew:
|
||||||
tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
|
tmp_batch.rew = tmp_batch.rew[:, policy.agent_id - 1]
|
||||||
buffer.rew = save_rew[:, policy.agent_id - 1]
|
buffer._meta.rew = save_rew[:, policy.agent_id - 1]
|
||||||
results[f'agent_{policy.agent_id}'] = \
|
results[f'agent_{policy.agent_id}'] = \
|
||||||
policy.process_fn(tmp_batch, buffer, tmp_indice)
|
policy.process_fn(tmp_batch, buffer, tmp_indice)
|
||||||
if has_rew: # restore from save_rew
|
if has_rew: # restore from save_rew
|
||||||
buffer.rew = save_rew
|
buffer._meta.rew = save_rew
|
||||||
return Batch(results)
|
return Batch(results)
|
||||||
|
|
||||||
def forward(self, batch: Batch,
|
def forward(self, batch: Batch,
|
||||||
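The core MAPolicy fix above temporarily swaps each agent's reward column into the buffer's underlying storage (`_meta`) and restores it afterwards. Here is a self-contained toy illustrating that pattern; it is not Tianshou's actual `ReplayBuffer`, whose internals are only assumed to look roughly like this.

```python
import numpy as np
from types import SimpleNamespace


class ToyBuffer:
    """Toy stand-in for a buffer that keeps all data in an internal ``_meta``.

    Reads such as ``buf.rew`` fall through to ``_meta`` via ``__getattr__``;
    assigning ``buf.rew = ...`` would only attach a plain instance attribute
    and leave ``_meta`` untouched, mirroring the failure mode the commit fixes.
    """

    def __init__(self, rew):
        self._meta = SimpleNamespace(rew=rew)

    def __getattr__(self, key):
        # only reached when normal attribute lookup fails (e.g. for "rew")
        return getattr(self._meta, key)


buf = ToyBuffer(rew=np.arange(6).reshape(3, 2))  # 3 transitions, 2 agents
save_rew = buf.rew                               # keep the full reward matrix
for agent_id in (1, 2):
    # expose only this agent's reward column while its process_fn would run
    buf._meta.rew = save_rew[:, agent_id - 1]
    assert buf.rew.shape == (3,)
buf._meta.rew = save_rew                         # restore the original rewards
```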