add multi-GPU support (#461)

add a new class DataParallelNet
Jiayi Weng 2021-10-05 13:39:14 -04:00 committed by GitHub
parent 5df64800f4
commit e45e2096d8
7 changed files with 75 additions and 12 deletions

README.md

@@ -52,8 +52,9 @@ Here is Tianshou's other features:
- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#customize-training-process)
- Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
-- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
+- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#multi-agent-reinforcement-learning)
- Support both [TensorBoard](https://www.tensorflow.org/tensorboard) and [W&B](https://wandb.ai/) log tools
- Support multi-GPU training [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#multi-gpu)
- Comprehensive documentation, PEP8 code-style checking, type checking and [unit tests](https://github.com/thu-ml/tianshou/actions)
In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment.

docs/index.rst

@@ -45,6 +45,7 @@ Here is Tianshou's other features:
* Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
* Support :doc:`/tutorials/tictactoe`
* Support both `TensorBoard <https://www.tensorflow.org/tensorboard>`_ and `W&B <https://wandb.ai/>`_ log tools
* Support multi-GPU training :ref:`multi_gpu`
* Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking
中文文档位于 `https://tianshou.readthedocs.io/zh/master/ <https://tianshou.readthedocs.io/zh/master/>`_

docs/spelling_wordlist.txt

@@ -33,6 +33,7 @@ fn
boolean
pre
np
cuda
rnn
rew
pre

docs/tutorials/cheatsheet.rst

@@ -197,6 +197,35 @@ The above code supports only stacked-observation. If you want to use stacked-act
- After applying wrapper: ``([s, a], a, [s', a'], r, d)`` stored in replay buffer, and get both stacked s and a.

.. _multi_gpu:

Multi-GPU Training
------------------

To train an RL agent with multiple GPUs on a standard environment (i.e., one without nested observation), using the default networks provided by Tianshou:

1. Import :class:`~tianshou.utils.net.common.DataParallelNet` from ``tianshou.utils.net.common``;
2. Change the ``device`` argument to ``None`` in the existing networks such as ``Net``, ``Actor``, ``Critic``, ``ActorProb``;
3. Apply the ``DataParallelNet`` wrapper to these networks.

::

    from tianshou.utils.net.common import Net, DataParallelNet
    from tianshou.utils.net.discrete import Actor, Critic

    # net is the preprocessing network built beforehand (see the complete script linked below)
    actor = DataParallelNet(Actor(net, args.action_shape, device=None).to(args.device))
    critic = DataParallelNet(Critic(net, device=None).to(args.device))

Yes, that's all! This general approach can be applied to almost all kinds of algorithms implemented in Tianshou.
We provide a complete script to show how to run multi-GPU: `test/discrete/test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_ppo.py>`_
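Step 2 above also lists ``ActorProb``; as a rough sketch of the same wrapping pattern for a continuous-control actor-critic (not taken from the linked script, and assuming the usual ``args`` fields from the test scripts)::

    import torch
    from tianshou.utils.net.common import DataParallelNet, Net
    from tianshou.utils.net.continuous import ActorProb, Critic

    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    if torch.cuda.is_available():
        # device=None: the wrapper, not the sub-network, moves data onto the GPUs
        actor = DataParallelNet(ActorProb(net, args.action_shape, device=None).to(args.device))
        critic = DataParallelNet(Critic(net, device=None).to(args.device))
    else:
        actor = ActorProb(net, args.action_shape, device=args.device).to(args.device)
        critic = Critic(net, device=args.device).to(args.device)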
As for other cases, such as a customized network or an environment with a nested observation, here are the rules:

1. The data format transformation (numpy -> cuda) is done in the ``DataParallelNet`` wrapper; your customized network should not apply any kind of data format transformation;
2. Create a similar class that inherits ``DataParallelNet`` and is only in charge of the data format transformation (numpy -> cuda), as sketched below;
3. Then follow the same steps as above.
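For instance, a minimal sketch of such a subclass, assuming a dict observation whose values are plain arrays (the class name ``DictDataParallelNet`` is hypothetical, not part of Tianshou)::

    from typing import Any, Dict, Tuple, Union

    import numpy as np
    import torch

    from tianshou.utils.net.common import DataParallelNet


    class DictDataParallelNet(DataParallelNet):
        """Hypothetical DataParallelNet variant for dict observations.

        It only converts each value of the observation dict from numpy to a
        cuda tensor; the wrapped network must not do any conversion itself.
        """

        def forward(self, s: Dict[str, Union[np.ndarray, torch.Tensor]],
                    *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
            s = {k: torch.as_tensor(v, dtype=torch.float32).cuda() for k, v in s.items()}
            return self.net(s=s, *args, **kwargs)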
.. _self_defined_env:

User-defined Environment and Different State Representation

test/continuous/test_ppo.py

@@ -110,8 +110,7 @@ def test_ppo(args=get_args()):
reward_normalization=args.rew_norm,
advantage_normalization=args.norm_adv,
recompute_advantage=args.recompute_adv,
-# dual_clip=args.dual_clip,
-# dual clip cause monotonically increasing log_std :)
+dual_clip=args.dual_clip,
value_clip=args.value_clip,
gae_lambda=args.gae_lambda,
action_space=env.action_space

test/discrete/test_ppo.py

@@ -8,11 +8,11 @@ import torch
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import DummyVectorEnv
+from tianshou.env import SubprocVectorEnv
from tianshou.policy import PPOPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import ActorCritic, Net
+from tianshou.utils.net.common import ActorCritic, DataParallelNet, Net
from tianshou.utils.net.discrete import Actor, Critic
@@ -57,11 +57,11 @@ def test_ppo(args=get_args()):
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
# you can also use tianshou.env.SubprocVectorEnv
-train_envs = DummyVectorEnv(
+train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)]
)
# test_envs = gym.make(args.task)
-test_envs = DummyVectorEnv(
+test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)]
)
# seed
@@ -71,8 +71,14 @@ def test_ppo(args=get_args()):
test_envs.seed(args.seed)
# model
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
-actor = Actor(net, args.action_shape, device=args.device).to(args.device)
-critic = Critic(net, device=args.device).to(args.device)
+if torch.cuda.is_available():
+    actor = DataParallelNet(
+        Actor(net, args.action_shape, device=None).to(args.device)
+    )
+    critic = DataParallelNet(Critic(net, device=None).to(args.device))
+else:
+    actor = Actor(net, args.action_shape, device=args.device).to(args.device)
+    critic = Critic(net, device=args.device).to(args.device)
actor_critic = ActorCritic(actor, critic)
# orthogonal initialization
for m in actor_critic.modules():

tianshou/utils/net/common.py

@@ -87,9 +87,14 @@ class MLP(nn.Module):
        self.output_dim = output_dim or hidden_sizes[-1]
        self.model = nn.Sequential(*model)

-    def forward(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
-        x = torch.as_tensor(x, device=self.device, dtype=torch.float32)  # type: ignore
-        return self.model(x.flatten(1))
+    def forward(self, s: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+        if self.device is not None:
+            s = torch.as_tensor(
+                s,
+                device=self.device,  # type: ignore
+                dtype=torch.float32,
+            )
+        return self.model(s.flatten(1))  # type: ignore

class Net(nn.Module):
@@ -278,3 +283,24 @@ class ActorCritic(nn.Module):
        super().__init__()
        self.actor = actor
        self.critic = critic

class DataParallelNet(nn.Module):
    """DataParallel wrapper for training an agent with multiple GPUs.

    This class only converts the input data type from numpy array to torch
    Tensor. If the input is a nested dictionary, the user should create a
    similar class to do the same thing.

    :param nn.Module net: the network to be distributed across different GPUs.
    """

    def __init__(self, net: nn.Module) -> None:
        super().__init__()
        self.net = nn.DataParallel(net)

    def forward(self, s: Union[np.ndarray, torch.Tensor], *args: Any,
                **kwargs: Any) -> Tuple[Any, Any]:
        # convert numpy input to a cuda tensor here, so the wrapped networks
        # (built with device=None) do not need to do any conversion themselves
        if not isinstance(s, torch.Tensor):
            s = torch.as_tensor(s, dtype=torch.float32)
        return self.net(s=s.cuda(), *args, **kwargs)
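A rough usage sketch of the wrapper above (assuming at least one visible CUDA device; the shapes are illustrative): with ``device=None``, ``MLP`` skips the numpy-to-tensor conversion and ``DataParallelNet`` performs it instead:

    import numpy as np
    from tianshou.utils.net.common import DataParallelNet, Net

    # build the default network with device=None and let the wrapper handle devices
    net = DataParallelNet(Net(state_shape=4, action_shape=2, device=None).to("cuda"))
    obs = np.random.rand(8, 4)  # a numpy batch, e.g. straight from the replay buffer
    logits, state = net(obs)    # numpy -> float32 cuda tensor happens inside the wrapper
    print(logits.shape)         # expected: torch.Size([8, 2])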