commit e45e2096d8 (parent 5df64800f4)
@@ -52,8 +52,9 @@ Here is Tianshou's other features:
 - Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
 - Support customized training process [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#customize-training-process)
 - Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
-- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html##multi-agent-reinforcement-learning)
+- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#multi-agent-reinforcement-learning)
 - Support both [TensorBoard](https://www.tensorflow.org/tensorboard) and [W&B](https://wandb.ai/) log tools
+- Support multi-GPU training [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#multi-gpu)
 - Comprehensive documentation, PEP8 code-style checking, type checking and [unit tests](https://github.com/thu-ml/tianshou/actions)

 In Chinese, Tianshou means divinely ordained, and by extension refers to an innate gift. Tianshou is a reinforcement learning platform: its RL algorithms do not learn from humans. So taking the name "Tianshou" means that there is no teacher to study with; instead, the agent learns by itself through constant interaction with the environment.

@@ -45,6 +45,7 @@ Here is Tianshou's other features:
 * Support n-step returns estimation :meth:`~tianshou.policy.BasePolicy.compute_nstep_return` and prioritized experience replay :class:`~tianshou.data.PrioritizedReplayBuffer` for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
 * Support :doc:`/tutorials/tictactoe`
 * Support both `TensorBoard <https://www.tensorflow.org/tensorboard>`_ and `W&B <https://wandb.ai/>`_ log tools
+* Support multi-GPU training :ref:`multi_gpu`
 * Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking

 The Chinese documentation is available at `https://tianshou.readthedocs.io/zh/master/ <https://tianshou.readthedocs.io/zh/master/>`_

@@ -33,6 +33,7 @@ fn
 boolean
 pre
 np
+cuda
 rnn
 rew
 pre

@@ -197,6 +197,35 @@ The above code supports only stacked-observation. If you want to use stacked-act
 - After applying wrapper: ``([s, a], a, [s', a'], r, d)`` stored in replay buffer, and get both stacked s and a.

+
+.. _multi_gpu:
+
+Multi-GPU Training
+------------------
+
+To enable training an RL agent with multiple GPUs for a standard environment (i.e., without nested observation), using the default networks provided by Tianshou:
+
+1. Import :class:`~tianshou.utils.net.common.DataParallelNet` from ``tianshou.utils.net.common``;
+2. Change the ``device`` argument to ``None`` in the existing networks such as ``Net``, ``Actor``, ``Critic``, ``ActorProb``;
+3. Apply the ``DataParallelNet`` wrapper to these networks.
+
+::
+
+    from tianshou.utils.net.common import Net, DataParallelNet
+    from tianshou.utils.net.discrete import Actor, Critic
+
+    actor = DataParallelNet(Actor(net, args.action_shape, device=None).to(args.device))
+    critic = DataParallelNet(Critic(net, device=None).to(args.device))
+
+Yes, that's all! This general approach can be applied to almost all kinds of algorithms implemented in Tianshou.
+We provide a complete script to show how to run multi-GPU: `test/discrete/test_ppo.py <https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_ppo.py>`_
+
+As for other cases, such as a customized network or an environment with a nested observation, here are the rules:
+
+1. The data format transformation (numpy -> cuda) is done in the ``DataParallelNet`` wrapper; your customized network should not apply any kind of data format transformation;
+2. Create a similar class that inherits from ``DataParallelNet`` and is only in charge of the data format transformation (numpy -> cuda);
+3. Follow the same steps as above.
+
 .. _self_defined_env:

 User-defined Environment and Different State Representation
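
Rule 2 above asks for a class similar to ``DataParallelNet`` when the observation is a nested dictionary. Below is a minimal sketch of such a wrapper; the class name ``DictDataParallelNet`` and the assumption that every entry of the dict can be cast to a float32 tensor are illustrative only, not part of Tianshou.

::

    from typing import Any, Dict, Tuple, Union

    import numpy as np
    import torch

    from tianshou.utils.net.common import DataParallelNet


    class DictDataParallelNet(DataParallelNet):
        """Hypothetical variant of DataParallelNet for dict observations.

        Only the numpy -> cuda conversion happens here; the wrapped network
        itself must not move data between devices (its ``device`` is ``None``).
        """

        def forward(self, s: Dict[str, Union[np.ndarray, torch.Tensor]],
                    *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
            # convert each entry to a cuda tensor so that nn.DataParallel can
            # scatter every tensor along the batch dimension across the GPUs
            s = {
                k: torch.as_tensor(v, dtype=torch.float32).cuda()
                for k, v in s.items()
            }
            return self.net(s=s, *args, **kwargs)

The customized network wrapped inside then receives a dict of cuda tensors and, per rule 1, should not call ``torch.as_tensor`` or move data between devices itself.
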
@@ -110,8 +110,7 @@ def test_ppo(args=get_args()):
         reward_normalization=args.rew_norm,
         advantage_normalization=args.norm_adv,
         recompute_advantage=args.recompute_adv,
-        # dual_clip=args.dual_clip,
-        # dual clip cause monotonically increasing log_std :)
+        dual_clip=args.dual_clip,
         value_clip=args.value_clip,
         gae_lambda=args.gae_lambda,
         action_space=env.action_space

@@ -8,11 +8,11 @@ import torch
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector, VectorReplayBuffer
-from tianshou.env import DummyVectorEnv
+from tianshou.env import SubprocVectorEnv
 from tianshou.policy import PPOPolicy
 from tianshou.trainer import onpolicy_trainer
 from tianshou.utils import TensorboardLogger
-from tianshou.utils.net.common import ActorCritic, Net
+from tianshou.utils.net.common import ActorCritic, DataParallelNet, Net
 from tianshou.utils.net.discrete import Actor, Critic


@@ -57,11 +57,11 @@ def test_ppo(args=get_args()):
     args.action_shape = env.action_space.shape or env.action_space.n
     # train_envs = gym.make(args.task)
     # you can also use tianshou.env.SubprocVectorEnv
-    train_envs = DummyVectorEnv(
+    train_envs = SubprocVectorEnv(
         [lambda: gym.make(args.task) for _ in range(args.training_num)]
     )
     # test_envs = gym.make(args.task)
-    test_envs = DummyVectorEnv(
+    test_envs = SubprocVectorEnv(
         [lambda: gym.make(args.task) for _ in range(args.test_num)]
     )
     # seed

@@ -71,8 +71,14 @@ def test_ppo(args=get_args()):
     test_envs.seed(args.seed)
     # model
     net = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
-    actor = Actor(net, args.action_shape, device=args.device).to(args.device)
-    critic = Critic(net, device=args.device).to(args.device)
+    if torch.cuda.is_available():
+        actor = DataParallelNet(
+            Actor(net, args.action_shape, device=None).to(args.device)
+        )
+        critic = DataParallelNet(Critic(net, device=None).to(args.device))
+    else:
+        actor = Actor(net, args.action_shape, device=args.device).to(args.device)
+        critic = Critic(net, device=args.device).to(args.device)
     actor_critic = ActorCritic(actor, critic)
     # orthogonal initialization
     for m in actor_critic.modules():

@@ -87,9 +87,14 @@ class MLP(nn.Module):
         self.output_dim = output_dim or hidden_sizes[-1]
         self.model = nn.Sequential(*model)

-    def forward(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
-        x = torch.as_tensor(x, device=self.device, dtype=torch.float32)  # type: ignore
-        return self.model(x.flatten(1))
+    def forward(self, s: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+        if self.device is not None:
+            s = torch.as_tensor(
+                s,
+                device=self.device,  # type: ignore
+                dtype=torch.float32,
+            )
+        return self.model(s.flatten(1))  # type: ignore


 class Net(nn.Module):

@@ -278,3 +283,24 @@ class ActorCritic(nn.Module):
         super().__init__()
         self.actor = actor
         self.critic = critic
+
+
+class DataParallelNet(nn.Module):
+    """DataParallel wrapper for training agent with multi-GPU.
+
+    This class does only the conversion of input data type, from numpy array to torch's
+    Tensor. If the input is a nested dictionary, the user should create a similar class
+    to do the same thing.
+
+    :param nn.Module net: the network to be distributed in different GPUs.
+    """
+
+    def __init__(self, net: nn.Module) -> None:
+        super().__init__()
+        self.net = nn.DataParallel(net)
+
+    def forward(self, s: Union[np.ndarray, torch.Tensor], *args: Any,
+                **kwargs: Any) -> Tuple[Any, Any]:
+        if not isinstance(s, torch.Tensor):
+            s = torch.as_tensor(s, dtype=torch.float32)
+        return self.net(s=s.cuda(), *args, **kwargs)
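
For context, here is a minimal usage sketch of the wrapper added above. It assumes a machine with at least one CUDA device; the network sizes and the 8x4 observation batch are illustrative only.

::

    import numpy as np

    from tianshou.utils.net.common import DataParallelNet, Net
    from tianshou.utils.net.discrete import Actor

    # device=None inside the networks, so they never move data themselves;
    # DataParallelNet performs the numpy -> cuda conversion in forward().
    net = Net(state_shape=4, hidden_sizes=[64, 64], device=None)
    actor = DataParallelNet(Actor(net, action_shape=2, device=None).to("cuda"))

    obs = np.random.rand(8, 4).astype(np.float32)  # batch of numpy observations
    logits, hidden = actor(obs)  # tensors are scattered across all visible GPUs
    print(logits.shape)  # torch.Size([8, 2])

With a single visible GPU, ``nn.DataParallel`` simply runs the module on that device, so the same code path works on one- and multi-GPU machines.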