code refactor for venv (#179)
- Refactor code to remove duplication
- Enable async simulation for all vector envs
- Remove `collector.close` and rename `VectorEnv` to `DummyVectorEnv`

The abstraction of the vector env has changed. Prior to this PR, each vector env was almost independent. After this PR, each env is wrapped into a worker, and vector envs differ only in their worker type. In fact, users can simply use `BaseVectorEnv` with different workers; I keep `SubprocVectorEnv` and `ShmemVectorEnv` for backward compatibility.

Co-authored-by: n+e <463003665@qq.com>
Co-authored-by: magicly <magicly007@gmail.com>
parent 311a2beafb
commit a9f9940d17
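A minimal sketch of the unified vector-env usage after this refactor (only classes named in this PR are used; `make_env` and the `CartPole-v0` task are illustrative placeholders):

```python
import gym
from tianshou.env import DummyVectorEnv  # SubprocVectorEnv / ShmemVectorEnv / RayVectorEnv share the same API

def make_env():
    # placeholder factory: any callable returning a gym.Env works
    return gym.make('CartPole-v0')

env_fns = [make_env for _ in range(4)]

# After this PR each env is wrapped into a worker, and the vector env classes
# differ only in the worker type (for-loop, subprocess, shared memory, ray),
# so they are all constructed and used the same way.
venv = DummyVectorEnv(env_fns)  # was `VectorEnv` before this PR
obs = venv.reset()              # initial observations of each environment
actions = [sp.sample() for sp in venv.action_space]  # action_space is a per-env list
obs, rew, done, info = venv.step(actions)
# async simulation is built in, e.g. SubprocVectorEnv(env_fns, wait_num=3, timeout=0.2)
venv.close()
```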
@@ -34,7 +34,7 @@
 Here is Tianshou's other features:
 
 - Elegant framework, using only ~2000 lines of code
-- Support parallel environment sampling for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling)
+- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling)
 - Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
 - Support any type of environment state (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
 - Support customized training process [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#customize-training-process)
@@ -152,7 +152,7 @@ Within this API, we can interact with different policies conveniently.
 
 ### Elegant and Flexible
 
-Currently, the overall code of Tianshou platform is less than 1500 lines without environment wrappers for Atari and Mujoco. Most of the implemented algorithms are less than 100 lines of python code. It is quite easy to go through the framework and understand how it works. We provide many flexible API as you wish, for instance, if you want to use your policy to interact with the environment with (at least) `n` steps:
+Currently, the overall code of Tianshou platform is less than 2500 lines. Most of the implemented algorithms are less than 100 lines of python code. It is quite easy to go through the framework and understand how it works. We provide many flexible API as you wish, for instance, if you want to use your policy to interact with the environment with (at least) `n` steps:
 
 ```python
 result = collector.collect(n_step=n)
@@ -201,8 +201,8 @@ Make environments:
 
 ```python
 # you can also try with SubprocVectorEnv
-train_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(train_num)])
-test_envs = ts.env.VectorEnv([lambda: gym.make(task) for _ in range(test_num)])
+train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
+test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
 ```
 
 Define the network:
@@ -249,7 +249,6 @@ Watch the performance with 35 FPS:
 ```python
 collector = ts.data.Collector(policy, env)
 collector.collect(n_episode=1, render=1 / 35)
-collector.close()
 ```
 
 Look at the result saved in tensorboard: (with bash script in your terminal)
BIN docs/_static/images/async.png (new vendored binary file, 55 KiB; not shown)
@@ -5,3 +5,8 @@ tianshou.env
     :members:
     :undoc-members:
     :show-inheritance:
+
+.. automodule:: tianshou.env.worker
+    :members:
+    :undoc-members:
+    :show-inheritance:
@@ -31,22 +31,50 @@ See :ref:`customized_trainer`.
 
 Parallel Sampling
 -----------------
 
-Use :class:`~tianshou.env.VectorEnv`, :class:`~tianshou.env.SubprocVectorEnv` or :class:`~tianshou.env.ShmemVectorEnv`.
-::
-
-    env_fns = [
-        lambda: MyTestEnv(size=2),
-        lambda: MyTestEnv(size=3),
-        lambda: MyTestEnv(size=4),
-        lambda: MyTestEnv(size=5),
-    ]
-    venv = SubprocVectorEnv(env_fns)
-
-where ``env_fns`` is a list of callable env hooker. The above code can be written in for-loop as well:
+Tianshou provides the following classes for parallel environment simulation:
+
+- :class:`~tianshou.env.DummyVectorEnv` is for pseudo-parallel simulation (implemented with a for-loop, useful for debugging).
+
+- :class:`~tianshou.env.SubprocVectorEnv` uses multiple processes for parallel simulation. This is the most common choice for parallel simulation.
+
+- :class:`~tianshou.env.ShmemVectorEnv` has a similar implementation to :class:`~tianshou.env.SubprocVectorEnv`, but is optimized (in terms of both memory footprint and simulation speed) for environments with large observations such as images.
+
+- :class:`~tianshou.env.RayVectorEnv` is currently the only choice for parallel simulation in a cluster with multiple machines.
+
+Although these classes are optimized for different scenarios, they have exactly the same APIs because they are sub-classes of :class:`~tianshou.env.BaseVectorEnv`. Just provide a list of functions that return environments when called, and you are all set.
+
 ::
 
     env_fns = [lambda x=i: MyTestEnv(size=x) for i in [2, 3, 4, 5]]
-    venv = SubprocVectorEnv(env_fns)
+    venv = SubprocVectorEnv(env_fns)  # DummyVectorEnv, ShmemVectorEnv, or RayVectorEnv, whichever you like.
+    venv.reset()  # returns the initial observations of each environment
+    venv.step(actions)  # provide actions for each environment and get their results
+
+.. sidebar:: An example of sync/async VectorEnv (steps with the same color end up in one batch that is processed by the policy at the same time).
+
+    .. Figure:: ../_static/images/async.png
+
+By default, parallel environment simulation is synchronous: a step is done only after all environments have finished the current step. Synchronous simulation works well if each environment step costs roughly the same time.
+
+When the time cost of environments varies a lot (e.g. 90% of steps cost 1s but 10% cost 10s), slow environments hold the fast ones back; in this case asynchronous simulation can be used (related to `Issue 103 <https://github.com/thu-ml/tianshou/issues/103>`_). The idea is to start the finished environments again without waiting for the slow ones.
+
+Asynchronous simulation is a built-in functionality of :class:`~tianshou.env.BaseVectorEnv`. Just provide ``wait_num`` or ``timeout`` (or both) and async simulation works.
+
+::
+
+    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=x) for i in [2, 3, 4, 5]]
+    # DummyVectorEnv, ShmemVectorEnv, or RayVectorEnv, whichever you like.
+    venv = SubprocVectorEnv(env_fns, wait_num=3, timeout=0.2)
+    venv.reset()  # returns the initial observations of each environment
+    # returns ``wait_num`` steps or finished steps after ``timeout`` seconds,
+    # whichever occurs first.
+    venv.step(actions, ready_id)
+
+If we have 4 envs and set ``wait_num = 3``, each step only returns the results of 3 of these 4 envs.
+
+You can treat the ``timeout`` parameter as a dynamic ``wait_num``. In each vectorized step it only returns the environments that finished within the given time. If there is no such environment, it waits until any of them has finished.
+
+The figure on the right gives an intuitive comparison between synchronous and asynchronous simulation.
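For completeness, a minimal sketch of a full asynchronous collection loop, built only from the calls shown above and in this PR's tests (``step(action, id)``, ``reset(id)``, the per-env ``action_space`` list, and the ``env_id`` field of ``info``); ``MyTestEnv`` is the toy environment from the tests, and the random action choice stands in for a real policy::

    import numpy as np
    from tianshou.env import SubprocVectorEnv

    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=x) for i in [2, 3, 4, 5]]
    venv = SubprocVectorEnv(env_fns, wait_num=3, timeout=0.2)
    venv.reset()
    ready_id = np.arange(len(env_fns))  # initially every environment is ready

    for _ in range(100):
        # one action per *ready* environment (random placeholder policy)
        actions = [venv.action_space[j].sample() for j in ready_id]
        # only the environments that finished this step are returned
        obs, rew, done, info = venv.step(actions, ready_id)
        ready_id = np.array([i['env_id'] for i in info])  # ids of the returned envs
        if done.any():
            venv.reset(ready_id[done])  # restart finished episodes; they stay ready
    venv.close()

In each iteration the loop feeds actions only to the environments reported back by the previous ``step`` call, which is exactly the bookkeeping the asynchronous test in this PR performs.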
.. warning::
@@ -139,9 +167,9 @@ First of all, your self-defined environment must follow the Gym's API, some of t
 
 - step(action) -> state, reward, done, info
 
-- seed(s) -> None
+- seed(s) -> List[int]
 
-- render(mode) -> None
+- render(mode) -> Any
 
 - close() -> None
 
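To make the interface above concrete, here is a minimal sketch of a self-defined environment that follows it, modeled on the ``MyTestEnv`` used in this PR's tests; the state, reward, and rendering logic are placeholders::

    import gym
    import numpy as np
    from gym.spaces import Discrete

    class MyEnv(gym.Env):
        """A tiny environment obeying the Gym API listed above (sketch)."""

        def __init__(self, size=5):
            self.size = size
            self.action_space = Discrete(2)
            self.observation_space = Discrete(size + 1)
            self.index = 0

        def seed(self, seed=0):
            self.rng = np.random.RandomState(seed)
            return [seed]  # seed(s) -> List[int]

        def reset(self, state=0):
            self.index = state
            return self.index

        def step(self, action):
            self.index = min(self.index + action, self.size)
            done = self.index == self.size
            return self.index, float(done), done, {}  # state, reward, done, info

        def render(self, mode='human'):
            return f'index = {self.index}'  # render(mode) -> Any

        def close(self):
            pass  # close() -> None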
@@ -30,11 +30,11 @@ It is available if you want the original ``gym.Env``:
     train_envs = gym.make('CartPole-v0')
     test_envs = gym.make('CartPole-v0')
 
-Tianshou supports parallel sampling for all algorithms. It provides four types of vectorized environment wrapper: :class:`~tianshou.env.VectorEnv`, :class:`~tianshou.env.SubprocVectorEnv`, :class:`~tianshou.env.ShmemVectorEnv`, and :class:`~tianshou.env.RayVectorEnv`. It can be used as follows:
+Tianshou supports parallel sampling for all algorithms. It provides four types of vectorized environment wrappers: :class:`~tianshou.env.DummyVectorEnv`, :class:`~tianshou.env.SubprocVectorEnv`, :class:`~tianshou.env.ShmemVectorEnv`, and :class:`~tianshou.env.RayVectorEnv`. They can be used as follows (more explanation can be found at :ref:`parallel_sampling`):
 ::
 
-    train_envs = ts.env.VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])
-    test_envs = ts.env.VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)])
+    train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(8)])
+    test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)])
 
 Here, we set up 8 environments in ``train_envs`` and 100 environments in ``test_envs``.
 
@@ -178,7 +178,6 @@ Watch the Agent's Performance
 
     collector = ts.data.Collector(policy, env)
     collector.collect(n_episode=1, render=1 / 35)
-    collector.close()
 
 .. _customized_trainer:
@@ -158,7 +158,6 @@ Tianshou already provides some builtin classes for multi-agent learning. You can
 ===x _ o x _ _===
 ===x _ _ _ x x===
 =================
->>> collector.close()
 
 Random agents perform badly. In the above game, although agent 2 eventually wins, it is clear that a smart agent 1 would place an ``x`` at row 4 col 4 to win directly.
 
@@ -175,7 +174,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul
 from copy import deepcopy
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.env import VectorEnv
+from tianshou.env import DummyVectorEnv
 from tianshou.utils.net.common import Net
 from tianshou.trainer import offpolicy_trainer
 from tianshou.data import Collector, ReplayBuffer
@@ -220,8 +219,7 @@ The explanation of each Tianshou class/function will be deferred to their first
     help='the path of opponent agent pth file for resuming from a pre-trained agent')
 parser.add_argument('--device', type=str,
     default='cuda' if torch.cuda.is_available() else 'cpu')
-args = parser.parse_known_args()[0]
-return args
+return parser.parse_args()
 
 .. sidebar:: The relationship between MultiAgentPolicyManager (Manager) and BasePolicy (Agent)
 
@@ -290,15 +288,14 @@ With the above preparation, we are close to the first learned agent. The followi
 collector = Collector(policy, env)
 result = collector.collect(n_episode=1, render=args.render)
 print(f'Final reward: {result["rew"]}, length: {result["len"]}')
-collector.close()
 if args.watch:
     watch(args)
     exit(0)
 
 # ======== environment setup =========
 env_func = lambda: TicTacToeEnv(args.board_size, args.win_size)
-train_envs = VectorEnv([env_func for _ in range(args.training_num)])
-test_envs = VectorEnv([env_func for _ in range(args.test_num)])
+train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)])
+test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
 # seed
 np.random.seed(args.seed)
 torch.manual_seed(args.seed)
@@ -351,9 +348,6 @@ With the above preparation, we are close to the first learned agent. The followi
     stop_fn=stop_fn, save_fn=save_fn, writer=writer,
     test_in_train=False)
 
-train_collector.close()
-test_collector.close()
 
 agent = policy.policies[args.agent_id - 1]
 # let's watch the match!
 watch(args, agent)
@ -40,8 +40,7 @@ def get_args():
|
||||
parser.add_argument('--ent-coef', type=float, default=0.001)
|
||||
parser.add_argument('--max-grad-norm', type=float, default=None)
|
||||
parser.add_argument('--max_episode_steps', type=int, default=2000)
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_a2c(args=get_args()):
|
||||
@ -90,8 +89,6 @@ def test_a2c(args=get_args()):
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -99,7 +96,6 @@ def test_a2c(args=get_args()):
|
||||
collector = Collector(policy, env, preprocess_fn=preprocess_fn)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
@ -36,8 +36,7 @@ def get_args():
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_dqn(args=get_args()):
|
||||
@ -96,8 +95,6 @@ def test_dqn(args=get_args()):
|
||||
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||
stop_fn=stop_fn, writer=writer)
|
||||
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -105,7 +102,6 @@ def test_dqn(args=get_args()):
|
||||
collector = Collector(policy, env, preprocess_fn=preprocess_fn)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
@ -40,8 +40,7 @@ def get_args():
|
||||
parser.add_argument('--eps-clip', type=float, default=0.2)
|
||||
parser.add_argument('--max-grad-norm', type=float, default=0.5)
|
||||
parser.add_argument('--max_episode_steps', type=int, default=2000)
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_ppo(args=get_args()):
|
||||
@ -94,8 +93,6 @@ def test_ppo(args=get_args()):
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.repeat_per_collect,
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -103,7 +100,6 @@ def test_ppo(args=get_args()):
|
||||
collector = Collector(policy, env, preprocess_fn=preprocess_fn)
|
||||
result = collector.collect(n_step=2000, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
@ -6,7 +6,7 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.policy import DQNPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
@ -36,8 +36,7 @@ def get_args():
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_dqn(args=get_args()):
|
||||
@ -46,10 +45,10 @@ def test_dqn(args=get_args()):
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# train_envs = gym.make(args.task)
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
train_envs = VectorEnv(
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = VectorEnv(
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -100,8 +99,6 @@ def test_dqn(args=get_args()):
|
||||
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -109,7 +106,6 @@ def test_dqn(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
@ -39,8 +39,7 @@ def get_args():
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
class EnvWrapper(object):
|
||||
@ -136,7 +135,6 @@ def test_sac_bipedal(args=get_args()):
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=IsStop, save_fn=save_fn, writer=writer)
|
||||
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -144,7 +142,6 @@ def test_sac_bipedal(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=16, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
examples/box2d/lunarlander_dqn.py (new file, 109 lines)
@ -0,0 +1,109 @@
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import pprint
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import DQNPolicy
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
# the parameters are found by Optuna
|
||||
parser.add_argument('--task', type=str, default='LunarLander-v2')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--eps-test', type=float, default=0.05)
|
||||
parser.add_argument('--eps-train', type=float, default=0.73)
|
||||
parser.add_argument('--buffer-size', type=int, default=100000)
|
||||
parser.add_argument('--lr', type=float, default=0.013)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--n-step', type=int, default=4)
|
||||
parser.add_argument('--target-update-freq', type=int, default=500)
|
||||
parser.add_argument('--epoch', type=int, default=10)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=5000)
|
||||
parser.add_argument('--collect-per-step', type=int, default=16)
|
||||
parser.add_argument('--batch-size', type=int, default=128)
|
||||
parser.add_argument('--layer-num', type=int, default=1)
|
||||
parser.add_argument('--training-num', type=int, default=10)
|
||||
parser.add_argument('--test-num', type=int, default=100)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_dqn(args=get_args()):
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# train_envs = gym.make(args.task)
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
net = Net(args.layer_num, args.state_shape,
|
||||
args.action_shape, args.device,
|
||||
dueling=(2, 2)).to(args.device)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||
policy = DQNPolicy(
|
||||
net, optim, args.gamma, args.n_step,
|
||||
target_update_freq=args.target_update_freq)
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy, train_envs, ReplayBuffer(args.buffer_size))
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'dqn')
|
||||
writer = SummaryWriter(log_path)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
return x >= env.spec.reward_threshold
|
||||
|
||||
def train_fn(x):
|
||||
args.eps_train = max(args.eps_train * 0.6, 0.01)
|
||||
policy.set_eps(args.eps_train)
|
||||
|
||||
def test_fn(x):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
|
||||
test_in_train=False)
|
||||
|
||||
assert stop_fn(result['best_reward'])
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
env = gym.make(args.task)
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_dqn(get_args())
|
@ -9,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
from tianshou.policy import SACPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.exploration import OUNoise
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
@ -41,8 +41,7 @@ def get_args():
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_sac(args=get_args()):
|
||||
@ -51,10 +50,10 @@ def test_sac(args=get_args()):
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = VectorEnv(
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = VectorEnv(
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -110,8 +109,6 @@ def test_sac(args=get_args()):
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -119,7 +116,6 @@ def test_sac(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
@ -8,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
from tianshou.policy import DDPGPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.exploration import GaussianNoise
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.continuous import Actor, Critic
|
||||
@ -36,8 +36,7 @@ def get_args():
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_ddpg(args=get_args()):
|
||||
@ -46,7 +45,7 @@ def test_ddpg(args=get_args()):
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = VectorEnv(
|
||||
train_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
@ -86,8 +85,6 @@ def test_ddpg(args=get_args()):
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -95,7 +92,6 @@ def test_ddpg(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
@ -9,7 +9,7 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
from tianshou.policy import SACPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
@ -37,8 +37,7 @@ def get_args():
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_sac(args=get_args()):
|
||||
@ -47,7 +46,7 @@ def test_sac(args=get_args()):
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = VectorEnv(
|
||||
train_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
@ -96,8 +95,6 @@ def test_sac(args=get_args()):
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -105,7 +102,6 @@ def test_sac(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
@ -8,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
from tianshou.policy import TD3Policy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.exploration import GaussianNoise
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.continuous import Actor, Critic
|
||||
@ -39,8 +39,7 @@ def get_args():
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_td3(args=get_args()):
|
||||
@ -49,7 +48,7 @@ def test_td3(args=get_args()):
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = VectorEnv(
|
||||
train_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
@ -96,8 +95,6 @@ def test_td3(args=get_args()):
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -105,7 +102,6 @@ def test_td3(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
@ -43,8 +43,7 @@ def get_args():
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_sac(args=get_args()):
|
||||
@ -102,8 +101,6 @@ def test_sac(args=get_args()):
|
||||
args.batch_size, stop_fn=stop_fn,
|
||||
writer=writer, log_interval=args.log_interval)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -111,7 +108,6 @@ def test_sac(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
@ -8,7 +8,7 @@ from torch.utils.tensorboard import SummaryWriter
|
||||
from tianshou.policy import TD3Policy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.exploration import GaussianNoise
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.net.continuous import Actor, Critic
|
||||
@ -41,9 +41,7 @@ def get_args():
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
parser.add_argument('--max_episode_steps', type=int, default=2000)
|
||||
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_td3(args=get_args()):
|
||||
@ -53,7 +51,7 @@ def test_td3(args=get_args()):
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = VectorEnv(
|
||||
train_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
@ -103,8 +101,6 @@ def test_td3(args=get_args()):
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -112,7 +108,6 @@ def test_td3(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_step=1000, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
setup.py (1 line changed)
@ -60,6 +60,7 @@ setup(
|
||||
'flake8',
|
||||
'pytest',
|
||||
'pytest-cov',
|
||||
'ray>=0.8.0',
|
||||
],
|
||||
'atari': [
|
||||
'atari_py',
|
||||
|
@ -46,6 +46,7 @@ class MyTestEnv(gym.Env):
|
||||
|
||||
def seed(self, seed=0):
|
||||
self.rng = np.random.RandomState(seed)
|
||||
return [seed]
|
||||
|
||||
def reset(self, state=0):
|
||||
self.done = False
|
||||
|
@ -2,7 +2,7 @@ import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv, AsyncVectorEnv
|
||||
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
|
||||
from tianshou.data import Collector, Batch, ReplayBuffer
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -12,14 +12,24 @@ else: # pytest
|
||||
|
||||
|
||||
class MyPolicy(BasePolicy):
|
||||
def __init__(self, dict_state=False):
|
||||
def __init__(self, dict_state: bool = False, need_state: bool = True):
|
||||
"""
|
||||
:param bool dict_state: if the observation of the environment is a dict
|
||||
:param bool need_state: if the policy needs the hidden state (for RNN)
|
||||
"""
|
||||
super().__init__()
|
||||
self.dict_state = dict_state
|
||||
self.need_state = need_state
|
||||
|
||||
def forward(self, batch, state=None):
|
||||
if self.need_state:
|
||||
if state is None:
|
||||
state = np.zeros((len(batch.obs), 2))
|
||||
else:
|
||||
state += 1
|
||||
if self.dict_state:
|
||||
return Batch(act=np.ones(len(batch.obs['index'])))
|
||||
return Batch(act=np.ones(len(batch.obs)))
|
||||
return Batch(act=np.ones(len(batch.obs['index'])), state=state)
|
||||
return Batch(act=np.ones(len(batch.obs)), state=state)
|
||||
|
||||
def learn(self):
|
||||
pass
|
||||
@ -66,32 +76,32 @@ def test_collector():
|
||||
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
|
||||
|
||||
venv = SubprocVectorEnv(env_fns)
|
||||
dum = VectorEnv(env_fns)
|
||||
dum = DummyVectorEnv(env_fns)
|
||||
policy = MyPolicy()
|
||||
env = env_fns[0]()
|
||||
c0 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=False),
|
||||
logger.preprocess_fn)
|
||||
c0.collect(n_step=3)
|
||||
assert np.allclose(c0.buffer.obs[:4], np.expand_dims(
|
||||
[0, 1, 0, 1], axis=-1))
|
||||
assert np.allclose(c0.buffer[:4].obs_next, np.expand_dims(
|
||||
[1, 2, 1, 2], axis=-1))
|
||||
assert np.allclose(c0.buffer.obs[:4],
|
||||
np.expand_dims([0, 1, 0, 1], axis=-1))
|
||||
assert np.allclose(c0.buffer[:4].obs_next,
|
||||
np.expand_dims([1, 2, 1, 2], axis=-1))
|
||||
c0.collect(n_episode=3)
|
||||
assert np.allclose(c0.buffer.obs[:10], np.expand_dims(
|
||||
[0, 1, 0, 1, 0, 1, 0, 1, 0, 1], axis=-1))
|
||||
assert np.allclose(c0.buffer[:10].obs_next, np.expand_dims(
|
||||
[1, 2, 1, 2, 1, 2, 1, 2, 1, 2], axis=-1))
|
||||
assert np.allclose(c0.buffer.obs[:10],
|
||||
np.expand_dims([0, 1, 0, 1, 0, 1, 0, 1, 0, 1], axis=-1))
|
||||
assert np.allclose(c0.buffer[:10].obs_next,
|
||||
np.expand_dims([1, 2, 1, 2, 1, 2, 1, 2, 1, 2], axis=-1))
|
||||
c0.collect(n_step=3, random=True)
|
||||
c1 = Collector(policy, venv, ReplayBuffer(size=100, ignore_obs_next=False),
|
||||
logger.preprocess_fn)
|
||||
c1.collect(n_step=6)
|
||||
assert np.allclose(c1.buffer.obs[:11], np.expand_dims(
|
||||
[0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3], axis=-1))
|
||||
assert np.allclose(c1.buffer[:11].obs_next, np.expand_dims([
|
||||
1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4], axis=-1))
|
||||
[0, 1, 0, 1, 2, 0, 1, 0, 1, 2, 3], axis=-1))
|
||||
assert np.allclose(c1.buffer[:11].obs_next, np.expand_dims(
|
||||
[1, 2, 1, 2, 3, 1, 2, 1, 2, 3, 4], axis=-1))
|
||||
c1.collect(n_episode=2)
|
||||
assert np.allclose(c1.buffer.obs[11:21], np.expand_dims(
|
||||
[0, 1, 2, 3, 4, 0, 1, 0, 1, 2], axis=-1))
|
||||
assert np.allclose(c1.buffer.obs[11:21],
|
||||
np.expand_dims([0, 1, 2, 3, 4, 0, 1, 0, 1, 2], axis=-1))
|
||||
assert np.allclose(c1.buffer[11:21].obs_next,
|
||||
np.expand_dims([1, 2, 3, 4, 5, 1, 2, 1, 2, 3], axis=-1))
|
||||
c1.collect(n_episode=3, random=True)
|
||||
@ -116,7 +126,7 @@ def test_collector_with_async():
|
||||
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.1, random_sleep=True)
|
||||
for i in env_lens]
|
||||
|
||||
venv = AsyncVectorEnv(env_fns)
|
||||
venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
|
||||
policy = MyPolicy()
|
||||
c1 = Collector(policy, venv,
|
||||
ReplayBuffer(size=1000, ignore_obs_next=False),
|
||||
@ -129,8 +139,6 @@ def test_collector_with_async():
|
||||
size = len(c1.buffer)
|
||||
obs = c1.buffer.obs[:size]
|
||||
done = c1.buffer.done[:size]
|
||||
print(env_id[:size])
|
||||
print(obs)
|
||||
obs_ground_truth = []
|
||||
i = 0
|
||||
while i < size:
|
||||
@ -165,7 +173,7 @@ def test_collector_with_dict_state():
|
||||
c0.collect(n_episode=2)
|
||||
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True)
|
||||
for i in [2, 3, 4, 5]]
|
||||
envs = VectorEnv(env_fns)
|
||||
envs = DummyVectorEnv(env_fns)
|
||||
envs.seed(666)
|
||||
obs = envs.reset()
|
||||
assert not np.isclose(obs[0]['rand'], obs[1]['rand'])
|
||||
@ -185,7 +193,6 @@ def test_collector_with_dict_state():
|
||||
Logger.single_preprocess_fn)
|
||||
c2.collect(n_episode=[0, 0, 0, 10])
|
||||
batch, _ = c2.buffer.sample(10)
|
||||
print(batch['obs_next']['index'])
|
||||
|
||||
|
||||
def test_collector_with_ma():
|
||||
@ -202,7 +209,7 @@ def test_collector_with_ma():
|
||||
assert np.asanyarray(r).size == 1 and r == 4.
|
||||
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4)
|
||||
for i in [2, 3, 4, 5]]
|
||||
envs = VectorEnv(env_fns)
|
||||
envs = DummyVectorEnv(env_fns)
|
||||
c1 = Collector(policy, envs, ReplayBuffer(size=100),
|
||||
Logger.single_preprocess_fn, reward_metric=reward_metric)
|
||||
r = c1.collect(n_step=10)['rew']
|
||||
@ -227,7 +234,6 @@ def test_collector_with_ma():
|
||||
r = c2.collect(n_episode=[0, 0, 0, 10])['rew']
|
||||
assert np.asanyarray(r).size == 1 and r == 4.
|
||||
batch, _ = c2.buffer.sample(10)
|
||||
print(batch['obs_next'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -2,8 +2,8 @@ import time
|
||||
import numpy as np
|
||||
from gym.spaces.discrete import Discrete
|
||||
from tianshou.data import Batch
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv, \
|
||||
RayVectorEnv, AsyncVectorEnv, ShmemVectorEnv
|
||||
from tianshou.env import DummyVectorEnv, SubprocVectorEnv, \
|
||||
ShmemVectorEnv, RayVectorEnv
|
||||
|
||||
if __name__ == '__main__':
|
||||
from env import MyTestEnv
|
||||
@ -11,6 +11,14 @@ else: # pytest
|
||||
from test.base.env import MyTestEnv
|
||||
|
||||
|
||||
def has_ray():
|
||||
try:
|
||||
import ray
|
||||
return hasattr(ray, 'init') # avoid PEP8 F401 Error
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def recurse_comp(a, b):
|
||||
try:
|
||||
if isinstance(a, np.ndarray):
|
||||
@ -29,79 +37,111 @@ def recurse_comp(a, b):
|
||||
return False
|
||||
|
||||
|
||||
def test_async_env(num=8, sleep=0.1):
|
||||
def test_async_env(size=10000, num=8, sleep=0.1):
|
||||
# simplify the test case, just keep stepping
|
||||
size = 10000
|
||||
env_fns = [
|
||||
lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True)
|
||||
for i in range(size, size + num)
|
||||
]
|
||||
v = AsyncVectorEnv(env_fns, wait_num=num // 2)
|
||||
v.seed()
|
||||
v.reset()
|
||||
# for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un}
|
||||
# P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1}
|
||||
# expectation of v is n / (n + 1)
|
||||
# for a synchronous environment, the following actions should take
|
||||
# about 7 * sleep * num / (num + 1) seconds
|
||||
# for AsyncVectorEnv, the analysis is complicated, but the time cost
|
||||
# should be smaller
|
||||
action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
|
||||
current_index_start = 0
|
||||
action = action_list[:num]
|
||||
env_ids = list(range(num))
|
||||
o = []
|
||||
spent_time = time.time()
|
||||
while current_index_start < len(action_list):
|
||||
A, B, C, D = v.step(action=action, id=env_ids)
|
||||
b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
|
||||
env_ids = b.info.env_id
|
||||
o.append(b)
|
||||
current_index_start += len(action)
|
||||
# len of action may be smaller than len(A) in the end
|
||||
action = action_list[current_index_start: current_index_start + len(A)]
|
||||
# truncate env_ids with the first terms
|
||||
# typically len(env_ids) == len(A) == len(action), except for the
|
||||
# last batch when actions are not enough
|
||||
env_ids = env_ids[: len(action)]
|
||||
spent_time = time.time() - spent_time
|
||||
data = Batch.cat(o)
|
||||
# assure 1/7 improvement
|
||||
assert spent_time < 6.0 * sleep * num / (num + 1)
|
||||
return spent_time, data
|
||||
test_cls = [SubprocVectorEnv, ShmemVectorEnv]
|
||||
if has_ray():
|
||||
test_cls += [RayVectorEnv]
|
||||
for cls in test_cls:
|
||||
v = cls(env_fns, wait_num=num // 2, timeout=1e-3)
|
||||
v.reset()
|
||||
# for a random variable u ~ U[0, 1], let v = max{u1, u2, ..., un}
|
||||
# P(v <= x) = x^n (0 <= x <= 1), pdf of v is nx^{n-1}
|
||||
# expectation of v is n / (n + 1)
|
||||
# for a synchronous environment, the following actions should take
|
||||
# about 7 * sleep * num / (num + 1) seconds
|
||||
# for async simulation, the analysis is complicated, but the time cost
|
||||
# should be smaller
|
||||
action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
|
||||
current_idx_start = 0
|
||||
action = action_list[:num]
|
||||
env_ids = list(range(num))
|
||||
o = []
|
||||
spent_time = time.time()
|
||||
while current_idx_start < len(action_list):
|
||||
A, B, C, D = v.step(action=action, id=env_ids)
|
||||
b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
|
||||
env_ids = b.info.env_id
|
||||
o.append(b)
|
||||
current_idx_start += len(action)
|
||||
# len of action may be smaller than len(A) in the end
|
||||
action = action_list[current_idx_start:current_idx_start + len(A)]
|
||||
# truncate env_ids with the first terms
|
||||
# typically len(env_ids) == len(A) == len(action), except for the
|
||||
# last batch when actions are not enough
|
||||
env_ids = env_ids[: len(action)]
|
||||
spent_time = time.time() - spent_time
|
||||
Batch.cat(o)
|
||||
v.close()
|
||||
# assure 1/7 improvement
|
||||
assert spent_time < 6.0 * sleep * num / (num + 1)
|
||||
|
||||
|
||||
def test_async_check_id(size=100, num=4, sleep=.2, timeout=.7):
|
||||
env_fns = [lambda: MyTestEnv(size=size, sleep=sleep * 2),
|
||||
lambda: MyTestEnv(size=size, sleep=sleep * 3),
|
||||
lambda: MyTestEnv(size=size, sleep=sleep * 5),
|
||||
lambda: MyTestEnv(size=size, sleep=sleep * 7)]
|
||||
test_cls = [SubprocVectorEnv, ShmemVectorEnv]
|
||||
if has_ray():
|
||||
test_cls += [RayVectorEnv]
|
||||
for cls in test_cls:
|
||||
v = cls(env_fns, wait_num=num - 1, timeout=timeout)
|
||||
v.reset()
|
||||
expect_result = [
|
||||
[0, 1],
|
||||
[0, 1, 2],
|
||||
[0, 1, 3],
|
||||
[0, 1, 2],
|
||||
[0, 1],
|
||||
[0, 2, 3],
|
||||
[0, 1],
|
||||
]
|
||||
ids = np.arange(num)
|
||||
for res in expect_result:
|
||||
t = time.time()
|
||||
_, _, _, info = v.step([1] * len(ids), ids)
|
||||
t = time.time() - t
|
||||
ids = Batch(info).env_id
|
||||
print(ids, t)
|
||||
if cls != RayVectorEnv: # ray-project/ray#10134
|
||||
assert np.allclose(sorted(ids), res)
|
||||
assert (t < timeout) == (len(res) == num - 1)
|
||||
|
||||
|
||||
def test_vecenv(size=10, num=8, sleep=0.001):
|
||||
verbose = __name__ == '__main__'
|
||||
env_fns = [
|
||||
lambda i=i: MyTestEnv(size=i, sleep=sleep, recurse_state=True)
|
||||
for i in range(size, size + num)
|
||||
]
|
||||
venv = [
|
||||
VectorEnv(env_fns),
|
||||
DummyVectorEnv(env_fns),
|
||||
SubprocVectorEnv(env_fns),
|
||||
ShmemVectorEnv(env_fns),
|
||||
]
|
||||
if verbose:
|
||||
venv.append(RayVectorEnv(env_fns))
|
||||
if has_ray():
|
||||
venv += [RayVectorEnv(env_fns)]
|
||||
for v in venv:
|
||||
v.seed(0)
|
||||
action_list = [1] * 5 + [0] * 10 + [1] * 20
|
||||
if not verbose:
|
||||
o = [v.reset() for v in venv]
|
||||
for i, a in enumerate(action_list):
|
||||
o = []
|
||||
for v in venv:
|
||||
A, B, C, D = v.step([a] * num)
|
||||
if sum(C):
|
||||
A = v.reset(np.where(C)[0])
|
||||
o.append([A, B, C, D])
|
||||
for index, infos in enumerate(zip(*o)):
|
||||
if index == 3: # do not check info here
|
||||
continue
|
||||
for info in infos:
|
||||
assert recurse_comp(infos[0], info)
|
||||
else:
|
||||
o = [v.reset() for v in venv]
|
||||
for i, a in enumerate(action_list):
|
||||
o = []
|
||||
for v in venv:
|
||||
A, B, C, D = v.step([a] * num)
|
||||
if sum(C):
|
||||
A = v.reset(np.where(C)[0])
|
||||
o.append([A, B, C, D])
|
||||
for index, infos in enumerate(zip(*o)):
|
||||
if index == 3: # do not check info here
|
||||
continue
|
||||
for info in infos:
|
||||
assert recurse_comp(infos[0], info)
|
||||
if __name__ == '__main__':
|
||||
t = [0] * len(venv)
|
||||
for i, e in enumerate(venv):
|
||||
t[i] = time.time()
|
||||
@ -117,7 +157,6 @@ def test_vecenv(size=10, num=8, sleep=0.001):
|
||||
assert v.size == list(range(size, size + num))
|
||||
assert v.env_num == num
|
||||
assert v.action_space == [Discrete(2)] * num
|
||||
|
||||
for v in venv:
|
||||
v.close()
|
||||
|
||||
@ -125,3 +164,4 @@ def test_vecenv(size=10, num=8, sleep=0.001):
|
||||
if __name__ == '__main__':
|
||||
test_vecenv()
|
||||
test_async_env()
|
||||
test_async_check_id()
|
||||
|
@ -6,7 +6,7 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.policy import DDPGPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
@ -55,10 +55,10 @@ def test_ddpg(args=get_args()):
|
||||
args.max_action = env.action_space.high[0]
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = VectorEnv(
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = VectorEnv(
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -104,8 +104,6 @@ def test_ddpg(args=get_args()):
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -113,7 +111,6 @@ def test_ddpg(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -6,7 +6,7 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.policy import PPOPolicy
|
||||
from tianshou.policy.dist import DiagGaussian
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
@ -58,10 +58,10 @@ def test_ppo(args=get_args()):
|
||||
args.max_action = env.action_space.high[0]
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = VectorEnv(
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = VectorEnv(
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -119,8 +119,6 @@ def test_ppo(args=get_args()):
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
|
||||
writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -128,7 +126,6 @@ def test_ppo(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -6,7 +6,7 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.policy import SACPolicy, ImitationPolicy
|
||||
@ -54,10 +54,10 @@ def test_sac_with_il(args=get_args()):
|
||||
args.max_action = env.action_space.high[0]
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = VectorEnv(
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = VectorEnv(
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -105,7 +105,6 @@ def test_sac_with_il(args=get_args()):
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -113,7 +112,6 @@ def test_sac_with_il(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
# here we define an imitation collector with a trivial policy
|
||||
if args.task == 'Pendulum-v0':
|
||||
@ -123,15 +121,17 @@ def test_sac_with_il(args=get_args()):
|
||||
).to(args.device)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
|
||||
il_policy = ImitationPolicy(net, optim, mode='continuous')
|
||||
il_test_collector = Collector(il_policy, test_envs)
|
||||
il_test_collector = Collector(
|
||||
il_policy,
|
||||
DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
)
|
||||
train_collector.reset()
|
||||
result = offpolicy_trainer(
|
||||
il_policy, train_collector, il_test_collector, args.epoch,
|
||||
args.step_per_epoch // 5, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
il_test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -139,7 +139,6 @@ def test_sac_with_il(args=get_args()):
|
||||
collector = Collector(il_policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -6,7 +6,7 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.policy import TD3Policy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
@ -57,10 +57,10 @@ def test_td3(args=get_args()):
|
||||
args.max_action = env.action_space.high[0]
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = VectorEnv(
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = VectorEnv(
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -109,8 +109,6 @@ def test_td3(args=get_args()):
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -118,7 +116,6 @@ def test_td3(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -6,7 +6,7 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
from tianshou.policy import A2CPolicy, ImitationPolicy
|
||||
from tianshou.trainer import onpolicy_trainer, offpolicy_trainer
|
||||
@ -52,10 +52,10 @@ def test_a2c_with_il(args=get_args()):
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = VectorEnv(
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = VectorEnv(
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -94,7 +94,6 @@ def test_a2c_with_il(args=get_args()):
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
|
||||
writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -102,7 +101,6 @@ def test_a2c_with_il(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
# here we define an imitation collector with a trivial policy
|
||||
if args.task == 'CartPole-v0':
|
||||
@ -111,15 +109,17 @@ def test_a2c_with_il(args=get_args()):
|
||||
net = Actor(net, args.action_shape).to(args.device)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
|
||||
il_policy = ImitationPolicy(net, optim, mode='discrete')
|
||||
il_test_collector = Collector(il_policy, test_envs)
|
||||
il_test_collector = Collector(
|
||||
il_policy,
|
||||
DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
)
|
||||
train_collector.reset()
|
||||
result = offpolicy_trainer(
|
||||
il_policy, train_collector, il_test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
il_test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -127,7 +127,6 @@ def test_a2c_with_il(args=get_args()):
|
||||
collector = Collector(il_policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -6,7 +6,7 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.policy import DQNPolicy
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
@ -49,10 +49,10 @@ def test_dqn(args=get_args()):
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# train_envs = gym.make(args.task)
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
train_envs = VectorEnv(
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = VectorEnv(
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -110,8 +110,6 @@ def test_dqn(args=get_args()):
|
||||
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -119,7 +117,6 @@ def test_dqn(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
def test_pdqn(args=get_args()):
|
||||
|
@ -6,7 +6,7 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.policy import DQNPolicy
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
@ -47,10 +47,10 @@ def test_drqn(args=get_args()):
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# train_envs = gym.make(args.task)
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
train_envs = VectorEnv(
|
||||
[lambda: gym.make(args.task)for _ in range(args.training_num)])
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = VectorEnv(
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -96,8 +96,6 @@ def test_drqn(args=get_args()):
|
||||
stop_fn=stop_fn, save_fn=save_fn, writer=writer)
|
||||
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -105,7 +103,6 @@ def test_drqn(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -8,7 +8,7 @@ import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.policy import PGPolicy
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.data import Batch, Collector, ReplayBuffer
|
||||
@ -112,10 +112,10 @@ def test_pg(args=get_args()):
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# train_envs = gym.make(args.task)
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
train_envs = VectorEnv(
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = VectorEnv(
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -151,8 +151,6 @@ def test_pg(args=get_args()):
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
|
||||
writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -160,7 +158,6 @@ def test_pg(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -6,7 +6,7 @@ import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.policy import PPOPolicy
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
@ -54,10 +54,10 @@ def test_ppo(args=get_args()):
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
# train_envs = gym.make(args.task)
|
||||
# you can also use tianshou.env.SubprocVectorEnv
|
||||
train_envs = VectorEnv(
|
||||
train_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = VectorEnv(
|
||||
test_envs = DummyVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
@ -108,8 +108,6 @@ def test_ppo(args=get_args()):
|
||||
args.test_num, args.batch_size, stop_fn=stop_fn, save_fn=save_fn,
|
||||
writer=writer)
|
||||
assert stop_fn(result['best_reward'])
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
if __name__ == '__main__':
|
||||
pprint.pprint(result)
|
||||
# Let's watch its performance!
|
||||
@ -117,7 +115,6 @@ def test_ppo(args=get_args()):
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -4,7 +4,7 @@ import numpy as np
|
||||
from copy import deepcopy
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.data import Collector
|
||||
from tianshou.policy import RandomPolicy
|
||||
|
||||
@ -37,7 +37,7 @@ def gomoku(args=get_args()):
|
||||
|
||||
def env_func():
|
||||
return TicTacToeEnv(args.board_size, args.win_size)
|
||||
test_envs = VectorEnv([env_func for _ in range(args.test_num)])
|
||||
test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
|
||||
for r in range(args.self_play_round):
|
||||
rews = []
|
||||
agent_learn.set_eps(0.0)
|
||||
|
@ -6,7 +6,7 @@ from copy import deepcopy
|
||||
from typing import Optional, Tuple
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.env import VectorEnv
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
@ -106,8 +106,8 @@ def train_agent(args: argparse.Namespace = get_args(),
|
||||
) -> Tuple[dict, BasePolicy]:
|
||||
def env_func():
|
||||
return TicTacToeEnv(args.board_size, args.win_size)
|
||||
train_envs = VectorEnv([env_func for _ in range(args.training_num)])
|
||||
test_envs = VectorEnv([env_func for _ in range(args.test_num)])
|
||||
train_envs = DummyVectorEnv([env_func for _ in range(args.training_num)])
|
||||
test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
@ -159,9 +159,6 @@ def train_agent(args: argparse.Namespace = get_args(),
|
||||
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
|
||||
test_in_train=False)
|
||||
|
||||
train_collector.close()
|
||||
test_collector.close()
|
||||
|
||||
return result, policy.policies[args.agent_id - 1]
|
||||
|
||||
|
||||
@ -175,4 +172,3 @@ def watch(args: argparse.Namespace = get_args(),
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
|
||||
collector.close()
|
||||
|
@ -29,10 +29,10 @@ def data():
|
||||
batch3 = Batch(obs=[np.arange(20) for _ in np.arange(batch_len)],
|
||||
reward=np.arange(batch_len))
|
||||
indexs = np.random.choice(batch_len,
|
||||
size=batch_len//10, replace=False)
|
||||
size=batch_len // 10, replace=False)
|
||||
slice_dict = {'obs': [np.arange(20)
|
||||
for _ in np.arange(batch_len//10)],
|
||||
'reward': np.arange(batch_len//10)}
|
||||
for _ in np.arange(batch_len // 10)],
|
||||
'reward': np.arange(batch_len // 10)}
|
||||
dict_set = [{'obs': np.arange(20), 'info': "this is info", 'reward': 0}
|
||||
for _ in np.arange(1e2)]
|
||||
batch4 = Batch(
|
||||
@ -45,16 +45,17 @@ def data():
|
||||
)
|
||||
|
||||
print("Initialised")
|
||||
return {'batch_set': batch_set,
|
||||
'batch0': batch0,
|
||||
'batchs1': batchs1,
|
||||
'batchs2': batchs2,
|
||||
'batch3': batch3,
|
||||
'indexs': indexs,
|
||||
'dict_set': dict_set,
|
||||
'slice_dict': slice_dict,
|
||||
'batch4': batch4
|
||||
}
|
||||
return {
|
||||
'batch_set': batch_set,
|
||||
'batch0': batch0,
|
||||
'batchs1': batchs1,
|
||||
'batchs2': batchs2,
|
||||
'batch3': batch3,
|
||||
'indexs': indexs,
|
||||
'dict_set': dict_set,
|
||||
'slice_dict': slice_dict,
|
||||
'batch4': batch4
|
||||
}
|
||||
|
||||
|
||||
def test_init(data):
|
||||
|
@ -8,15 +8,15 @@ from tianshou.data import (ListReplayBuffer, PrioritizedReplayBuffer,
|
||||
@pytest.fixture(scope="module")
|
||||
def data():
|
||||
np.random.seed(0)
|
||||
obs = {'observable': np.random.rand(
|
||||
100, 100), 'hidden': np.random.randint(1000, size=200)}
|
||||
obs = {'observable': np.random.rand(100, 100),
|
||||
'hidden': np.random.randint(1000, size=200)}
|
||||
info = {'policy': "dqn", 'base': np.arange(10)}
|
||||
add_data = {'obs': obs, 'rew': 1., 'act': np.random.rand(30),
|
||||
'done': False, 'obs_next': obs, 'info': info}
|
||||
buffer = ReplayBuffer(int(1e3), stack_num=100)
|
||||
buffer2 = ReplayBuffer(int(1e4), stack_num=100)
|
||||
indexes = np.random.choice(int(1e3), size=3, replace=False)
|
||||
return{
|
||||
return {
|
||||
'add_data': add_data,
|
||||
'buffer': buffer,
|
||||
'buffer2': buffer2,
|
||||
|
@ -5,7 +5,7 @@ from gym.spaces.discrete import Discrete
|
||||
from gym.utils import seeding
|
||||
|
||||
from tianshou.data import Batch, Collector, ReplayBuffer
|
||||
from tianshou.env import VectorEnv, SubprocVectorEnv
|
||||
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
|
||||
from tianshou.policy import BasePolicy
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ class SimplePolicy(BasePolicy):
|
||||
return super().learn(batch, **kwargs)
|
||||
|
||||
def forward(self, batch, state=None, **kwargs):
|
||||
return Batch(act=np.array([30]*len(batch)), state=None, logits=None)
|
||||
return Batch(act=np.array([30] * len(batch)), state=None, logits=None)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@ -56,7 +56,7 @@ def data():
|
||||
np.random.seed(0)
|
||||
env = SimpleEnv()
|
||||
env.seed(0)
|
||||
env_vec = VectorEnv(
|
||||
env_vec = DummyVectorEnv(
|
||||
[lambda: SimpleEnv() for _ in range(100)])
|
||||
env_vec.seed(np.random.randint(1000, size=100).tolist())
|
||||
env_subproc = SubprocVectorEnv(
|
||||
@ -70,7 +70,7 @@ def data():
|
||||
collector = Collector(policy, env, ReplayBuffer(50000))
|
||||
collector_vec = Collector(policy, env_vec, ReplayBuffer(50000))
|
||||
collector_subproc = Collector(policy, env_subproc, ReplayBuffer(50000))
|
||||
return{
|
||||
return {
|
||||
"env": env,
|
||||
"env_vec": env_vec,
|
||||
"env_subproc": env_subproc,
|
||||
@ -79,14 +79,13 @@ def data():
|
||||
"buffer": buffer,
|
||||
"collector": collector,
|
||||
"collector_vec": collector_vec,
|
||||
"collector_subproc": collector_subproc
|
||||
}
|
||||
"collector_subproc": collector_subproc,
|
||||
}
|
||||
|
||||
|
||||
def test_init(data):
|
||||
for _ in range(5000):
|
||||
c = Collector(data["policy"], data["env"], data["buffer"])
|
||||
c.close()
|
||||
Collector(data["policy"], data["env"], data["buffer"])
|
||||
|
||||
|
||||
def test_reset(data):
|
||||
@ -111,8 +110,7 @@ def test_sample(data):
|
||||
|
||||
def test_init_vec_env(data):
|
||||
for _ in range(5000):
|
||||
c = Collector(data["policy"], data["env_vec"], data["buffer"])
|
||||
c.close()
|
||||
Collector(data["policy"], data["env_vec"], data["buffer"])
|
||||
|
||||
|
||||
def test_reset_vec_env(data):
|
||||
@ -137,10 +135,7 @@ def test_sample_vec_env(data):
|
||||
|
||||
def test_init_subproc_env(data):
|
||||
for _ in range(5000):
|
||||
c = Collector(data["policy"], data["env_subproc_init"], data["buffer"])
|
||||
"""TODO: This should be changed to c.close() in theory,
|
||||
but currently subproc_env doesn't support that."""
|
||||
c.reset()
|
||||
Collector(data["policy"], data["env_subproc_init"], data["buffer"])
|
||||
|
||||
|
||||
def test_reset_subproc_env(data):
|
||||
|
@ -1,7 +1,7 @@
|
||||
from tianshou import data, env, utils, policy, trainer, \
|
||||
exploration
|
||||
|
||||
__version__ = '0.2.5'
|
||||
__version__ = '0.2.6'
|
||||
__all__ = [
|
||||
'env',
|
||||
'data',
|
||||
|
@ -5,7 +5,7 @@ import warnings
|
||||
import numpy as np
|
||||
from typing import Any, Dict, List, Union, Optional, Callable
|
||||
|
||||
from tianshou.env import BaseVectorEnv, VectorEnv, AsyncVectorEnv
|
||||
from tianshou.env import BaseVectorEnv, DummyVectorEnv
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.exploration import BaseNoise
|
||||
from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
|
||||
@ -51,7 +51,8 @@ class Collector(object):
|
||||
collector = Collector(policy, env, buffer=replay_buffer)
|
||||
|
||||
# the collector supports vectorized environments as well
|
||||
envs = VectorEnv([lambda: gym.make('CartPole-v0') for _ in range(3)])
|
||||
envs = DummyVectorEnv([lambda: gym.make('CartPole-v0')
|
||||
for _ in range(3)])
|
||||
collector = Collector(policy, envs, buffer=replay_buffer)
|
||||
|
||||
# collect 3 episodes
|
||||
@ -84,7 +85,7 @@ class Collector(object):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if not isinstance(env, BaseVectorEnv):
|
||||
env = VectorEnv([lambda: env])
|
||||
env = DummyVectorEnv([lambda: env])
|
||||
self.env = env
|
||||
self.env_num = len(env)
|
||||
# environments that are available in step()
|
||||
@ -93,7 +94,7 @@ class Collector(object):
|
||||
self._ready_env_ids = np.arange(self.env_num)
|
||||
# self.is_async is a flag to indicate whether this collector works
|
||||
# with asynchronous simulation
|
||||
self.is_async = isinstance(env, AsyncVectorEnv)
|
||||
self.is_async = env.is_async
|
||||
# need cache buffers before storing in the main buffer
|
||||
self._cached_buf = [ListReplayBuffer() for _ in range(self.env_num)]
|
||||
self.collect_time, self.collect_step, self.collect_episode = 0., 0, 0
|
||||
@ -101,6 +102,7 @@ class Collector(object):
|
||||
self.policy = policy
|
||||
self.preprocess_fn = preprocess_fn
|
||||
self.process_fn = policy.process_fn
|
||||
self._action_space = env.action_space
|
||||
self._action_noise = action_noise
|
||||
self._rew_metric = reward_metric or Collector._default_rew_metric
|
||||
# avoid creating attribute outside __init__
|
||||
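Since the collector now reads the asynchrony flag straight off the vector env (`self.is_async = env.is_async` above), any vector env constructed with `wait_num` or `timeout` is collected asynchronously without a dedicated `AsyncVectorEnv`. A minimal sketch, assuming an already-constructed `policy` (any `BasePolicy`, not shown here):

```python
import gym
from tianshou.data import Collector, ReplayBuffer
from tianshou.env import SubprocVectorEnv

# `policy` is assumed to be an existing BasePolicy instance (not constructed here)
envs = SubprocVectorEnv(
    [lambda: gym.make('CartPole-v0') for _ in range(8)],
    wait_num=4)  # setting wait_num (or timeout) turns on envs.is_async
collector = Collector(policy, envs, buffer=ReplayBuffer(20000))
collector.collect(n_step=256)  # asynchronous collection, no AsyncVectorEnv needed
```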
@ -119,6 +121,8 @@ class Collector(object):
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all related variables in the collector."""
|
||||
# use empty Batch for ``state`` so that ``self.data`` supports slicing
|
||||
# convert empty Batch to None when passing data to policy
|
||||
self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
|
||||
obs_next={}, policy={})
|
||||
self.reset_env()
|
||||
@ -156,10 +160,6 @@ class Collector(object):
|
||||
"""Render all the environment(s)."""
|
||||
return self.env.render(**kwargs)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the environment(s)."""
|
||||
self.env.close()
|
||||
|
||||
def _reset_state(self, id: Union[int, List[int]]) -> None:
|
||||
"""Reset the hidden state: self.data.state[id]."""
|
||||
state = self.data.state # it is a reference
|
||||
@ -228,20 +228,13 @@ class Collector(object):
|
||||
|
||||
# restore the state and the input data
|
||||
last_state = self.data.state
|
||||
if last_state.is_empty():
|
||||
if isinstance(last_state, Batch) and last_state.is_empty():
|
||||
last_state = None
|
||||
self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())
|
||||
|
||||
# calculate the next action
|
||||
if random:
|
||||
if self.is_async:
|
||||
# TODO self.env.action_space will invoke remote call for
|
||||
# all environments, which may hang in async simulation.
|
||||
# This can be avoided by using a random policy, but not
|
||||
# in the collector level. Leave it as a future work.
|
||||
raise RuntimeError("cannot use random "
|
||||
"sampling in async simulation!")
|
||||
spaces = self.env.action_space
|
||||
spaces = self._action_space
|
||||
result = Batch(
|
||||
act=[spaces[i].sample() for i in self._ready_env_ids])
|
||||
else:
|
||||
@ -254,7 +247,9 @@ class Collector(object):
|
||||
state = Batch()
|
||||
self.data.update(state=state, policy=result.get('policy', Batch()))
|
||||
# save hidden state to policy._state, in order to save into buffer
|
||||
self.data.policy._state = self.data.state
|
||||
if not (isinstance(self.data.state, Batch)
|
||||
and self.data.state.is_empty()):
|
||||
self.data.policy._state = self.data.state
|
||||
|
||||
self.data.act = to_numpy(result.act)
|
||||
if self._action_noise is not None:
|
||||
@ -354,7 +349,6 @@ class Collector(object):
|
||||
the buffer, otherwise it will extract the data with the given
|
||||
batch_size.
|
||||
"""
|
||||
import warnings
|
||||
warnings.warn(
|
||||
'Collector.sample is deprecated and will cause error if you use '
|
||||
'prioritized experience replay! Collector.sample will be removed '
|
||||
@ -363,23 +357,36 @@ class Collector(object):
|
||||
batch_data = self.process_fn(batch_data, self.buffer, indice)
|
||||
return batch_data
|
||||
|
||||
def close(self) -> None:
|
||||
warnings.warn(
|
||||
'Collector.close is deprecated and will be removed upon version '
|
||||
'0.3.', Warning)
|
||||
|
||||
|
||||
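Because `Collector.close` is now only a deprecation shim, migrating a script is a one-line deletion. A hedged sketch in the style of the test scripts above (`collector` and `args` stand for objects that already exist in those scripts):

```python
# `collector` and `args` are assumed to exist, as in the test scripts above
result = collector.collect(n_episode=1, render=args.render)
# collector.close()  # deprecated: now only emits a warning; removed in 0.3
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
```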
def _batch_set_item(source: Batch, indices: np.ndarray,
|
||||
target: Batch, size: int):
|
||||
# for any key chain k, there are three cases
|
||||
# for any key chain k, there are four cases
|
||||
# 1. source[k] is non-reserved, but target[k] does not exist or is reserved
|
||||
# 2. source[k] does not exist or is reserved, but target[k] is non-reserved
|
||||
# 3. both source[k] and target[k] is non-reserved
|
||||
for k, v in target.items():
|
||||
if not isinstance(v, Batch) or not v.is_empty():
|
||||
# 3. both source[k] and target[k] are non-reserved
|
||||
# 4. both source[k] and target[k] do not exist or are reserved, do nothing.
|
||||
# A special case in case 4, if target[k] is reserved but source[k] does
|
||||
# not exist, make source[k] reserved, too.
|
||||
for k, vt in target.items():
|
||||
if not isinstance(vt, Batch) or not vt.is_empty():
|
||||
# target[k] is non-reserved
|
||||
vs = source.get(k, Batch())
|
||||
if isinstance(vs, Batch) and vs.is_empty():
|
||||
# case 2
|
||||
# use __dict__ to avoid many type checks
|
||||
source.__dict__[k] = _create_value(v[0], size)
|
||||
if isinstance(vs, Batch):
|
||||
if vs.is_empty():
|
||||
# case 2, use __dict__ to avoid many type checks
|
||||
source.__dict__[k] = _create_value(vt[0], size)
|
||||
else:
|
||||
assert isinstance(vt, Batch)
|
||||
_batch_set_item(source.__dict__[k], indices, vt, size)
|
||||
else:
|
||||
# target[k] is reserved
|
||||
# case 1
|
||||
# case 1 or special case of case 4
|
||||
if k not in source.__dict__:
|
||||
source.__dict__[k] = Batch()
|
||||
continue
|
||||
source.__dict__[k][indices] = v
|
||||
source.__dict__[k][indices] = vt
|
||||
|
tianshou/env/__init__.py (14 changed lines)
@ -1,17 +1,13 @@
from tianshou.env.vecenv.base import BaseVectorEnv
from tianshou.env.vecenv.dummy import VectorEnv
from tianshou.env.vecenv.subproc import SubprocVectorEnv
from tianshou.env.vecenv.asyncenv import AsyncVectorEnv
from tianshou.env.vecenv.rayenv import RayVectorEnv
from tianshou.env.vecenv.shmemenv import ShmemVectorEnv
from tianshou.env.venvs import BaseVectorEnv, DummyVectorEnv, VectorEnv, \
    SubprocVectorEnv, ShmemVectorEnv, RayVectorEnv
from tianshou.env.maenv import MultiAgentEnv

__all__ = [
    'BaseVectorEnv',
    'VectorEnv',
    'DummyVectorEnv',
    'VectorEnv',  # TODO: remove in later version
    'SubprocVectorEnv',
    'AsyncVectorEnv',
    'RayVectorEnv',
    'ShmemVectorEnv',
    'RayVectorEnv',
    'MultiAgentEnv',
]
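For downstream code the rename is backward compatible: `VectorEnv` survives as a thin alias of `DummyVectorEnv` that emits a deprecation warning (see `tianshou/env/venvs.py` below). A minimal sketch of both spellings:

```python
import gym
from tianshou.env import DummyVectorEnv, VectorEnv

env_fns = [lambda: gym.make('CartPole-v0') for _ in range(4)]

envs = DummyVectorEnv(env_fns)  # preferred spelling after this refactor
legacy = VectorEnv(env_fns)     # still works, but warns about the rename
assert len(envs) == len(legacy) == 4
envs.close()
legacy.close()
```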
tianshou/env/vecenv/__init__.py (removed)
tianshou/env/vecenv/asyncenv.py (removed, 104 lines)
@ -1,104 +0,0 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
from multiprocessing import connection
|
||||
from typing import List, Tuple, Union, Optional, Callable, Any
|
||||
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
|
||||
|
||||
class AsyncVectorEnv(SubprocVectorEnv):
|
||||
"""Vectorized asynchronous environment wrapper based on subprocess.
|
||||
|
||||
:param wait_num: used in asynchronous simulation if the time cost of
|
||||
``env.step`` varies with time and synchronously waiting for all
|
||||
environments to finish a step is time-wasting. In that case, we can
|
||||
return when ``wait_num`` environments finish a step and keep on
|
||||
simulation in these environments. If ``None``, asynchronous simulation
|
||||
is disabled; else, ``1 <= wait_num <= env_num``.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]],
|
||||
wait_num: Optional[int] = None) -> None:
|
||||
super().__init__(env_fns)
|
||||
self.wait_num = wait_num or len(env_fns)
|
||||
assert 1 <= self.wait_num <= len(env_fns), \
|
||||
f'wait_num should be in [1, {len(env_fns)}], but got {wait_num}'
|
||||
self.waiting_conn = []
|
||||
# environments in self.ready_id is actually ready
|
||||
# but environments in self.waiting_id are just waiting when checked,
|
||||
# and they may be ready now, but this is not known until we check it
|
||||
# in the step() function
|
||||
self.waiting_id = []
|
||||
# all environments are ready in the beginning
|
||||
self.ready_id = list(range(self.env_num))
|
||||
|
||||
def _assert_and_transform_id(self,
|
||||
id: Optional[Union[int, List[int]]] = None
|
||||
) -> List[int]:
|
||||
if id is None:
|
||||
id = list(range(self.env_num))
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
for i in id:
|
||||
assert i not in self.waiting_id, \
|
||||
f'Cannot reset environment {i} which is stepping now!'
|
||||
assert i in self.ready_id, \
|
||||
f'Can only reset ready environments {self.ready_id}.'
|
||||
return id
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
|
||||
id = self._assert_and_transform_id(id)
|
||||
return super().reset(id)
|
||||
|
||||
def render(self, **kwargs) -> List[Any]:
|
||||
if len(self.waiting_id) > 0:
|
||||
raise RuntimeError(
|
||||
f"Environments {self.waiting_id} are still "
|
||||
f"stepping, cannot render them now.")
|
||||
return super().render(**kwargs)
|
||||
|
||||
def close(self) -> List[Any]:
|
||||
if self.closed:
|
||||
return []
|
||||
# finish remaining steps, and close
|
||||
self.step(None)
|
||||
return super().close()
|
||||
|
||||
def step(self,
|
||||
action: Optional[np.ndarray],
|
||||
id: Optional[Union[int, List[int]]] = None
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Provide the given action to the environments. The action sequence
|
||||
should correspond to the ``id`` argument, and the ``id`` argument
|
||||
should be a subset of the ``env_id`` in the last returned ``info``
|
||||
(initially they are env_ids of all the environments). If action is
|
||||
``None``, fetch unfinished step() calls instead.
|
||||
"""
|
||||
if action is not None:
|
||||
id = self._assert_and_transform_id(id)
|
||||
assert len(action) == len(id)
|
||||
for i, (act, env_id) in enumerate(zip(action, id)):
|
||||
self.parent_remote[env_id].send(['step', act])
|
||||
self.waiting_conn.append(self.parent_remote[env_id])
|
||||
self.waiting_id.append(env_id)
|
||||
self.ready_id = [x for x in self.ready_id if x not in id]
|
||||
result = []
|
||||
while len(self.waiting_conn) > 0 and len(result) < self.wait_num:
|
||||
ready_conns = connection.wait(self.waiting_conn)
|
||||
for conn in ready_conns:
|
||||
waiting_index = self.waiting_conn.index(conn)
|
||||
self.waiting_conn.pop(waiting_index)
|
||||
env_id = self.waiting_id.pop(waiting_index)
|
||||
ans = conn.recv()
|
||||
obs, rew, done, info = ans
|
||||
info["env_id"] = env_id
|
||||
result.append((obs, rew, done, info))
|
||||
self.ready_id.append(env_id)
|
||||
obs, rew, done, info = map(np.stack, zip(*result))
|
||||
return obs, rew, done, info
|
tianshou/env/vecenv/base.py (removed, 127 lines)
@ -1,127 +0,0 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Union, Optional, Callable
|
||||
|
||||
|
||||
class BaseVectorEnv(ABC, gym.Env):
|
||||
"""Base class for vectorized environments wrapper. Usage:
|
||||
::
|
||||
|
||||
env_num = 8
|
||||
envs = VectorEnv([lambda: gym.make(task) for _ in range(env_num)])
|
||||
assert len(envs) == env_num
|
||||
|
||||
It accepts a list of environment generators. In other words, an environment
|
||||
generator ``efn`` of a specific task means that ``efn()`` returns the
|
||||
environment of the given task, for example, ``gym.make(task)``.
|
||||
|
||||
All of the VectorEnv must inherit :class:`~tianshou.env.BaseVectorEnv`.
|
||||
Here are some other usages:
|
||||
::
|
||||
|
||||
envs.seed(2) # which is equal to the next line
|
||||
envs.seed([2, 3, 4, 5, 6, 7, 8, 9]) # set specific seed for each env
|
||||
obs = envs.reset() # reset all environments
|
||||
obs = envs.reset([0, 5, 7]) # reset 3 specific environments
|
||||
obs, rew, done, info = envs.step([1] * 8) # step synchronously
|
||||
envs.render() # render all environments
|
||||
envs.close() # close all environments
|
||||
|
||||
.. warning::
|
||||
|
||||
If you use your own environment, please make sure the ``seed`` method
|
||||
is set up properly, e.g.,
|
||||
::
|
||||
|
||||
def seed(self, seed):
|
||||
np.random.seed(seed)
|
||||
|
||||
Otherwise, the outputs of these envs may be the same with each other.
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
|
||||
self._env_fns = env_fns
|
||||
self.env_num = len(env_fns)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return len(self), which is the number of environments."""
|
||||
return self.env_num
|
||||
|
||||
def __getattribute__(self, key: str):
|
||||
"""Switch between the default attribute getter or one
|
||||
looking at wrapped environment level depending on the key."""
|
||||
if key not in ('observation_space', 'action_space'):
|
||||
return super().__getattribute__(key)
|
||||
else:
|
||||
return self.__getattr__(key)
|
||||
|
||||
@abstractmethod
|
||||
def __getattr__(self, key: str):
|
||||
"""Try to retrieve an attribute from each individual wrapped
|
||||
environment, if it does not belong to the wrapping vector
|
||||
environment class."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None):
|
||||
"""Reset the state of all the environments and return initial
|
||||
observations if id is ``None``, otherwise reset the specific
|
||||
environments with given id, either an int or a list.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def step(self,
|
||||
action: np.ndarray,
|
||||
id: Optional[Union[int, List[int]]] = None
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Run one timestep of all the environments’ dynamics if id is
|
||||
``None``, otherwise run one timestep for some environments
|
||||
with given id, either an int or a list. When the end of
|
||||
episode is reached, you are responsible for calling reset(id)
|
||||
to reset this environment’s state.
|
||||
|
||||
Accept a batch of action and return a tuple (obs, rew, done, info).
|
||||
|
||||
:param numpy.ndarray action: a batch of action provided by the agent.
|
||||
|
||||
:return: A tuple including four items:
|
||||
|
||||
* ``obs`` a numpy.ndarray, the agent's observation of current \
|
||||
environments
|
||||
* ``rew`` a numpy.ndarray, the amount of rewards returned after \
|
||||
previous actions
|
||||
* ``done`` a numpy.ndarray, whether these episodes have ended, in \
|
||||
which case further step() calls will return undefined results
|
||||
* ``info`` a numpy.ndarray, contains auxiliary diagnostic \
|
||||
information (helpful for debugging, and sometimes learning)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
|
||||
"""Set the seed for all environments.
|
||||
|
||||
Accept ``None``, an int (which will extend ``i`` to
|
||||
``[i, i + 1, i + 2, ...]``) or a list.
|
||||
|
||||
:return: The list of seeds used in this env's random number \
|
||||
generators. The first value in the list should be the "main" seed, or \
|
||||
the value which a reproducer pass to "seed".
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def render(self, **kwargs) -> None:
|
||||
"""Render all of the environments."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""Close all of the environments.
|
||||
|
||||
Environments will automatically close() themselves when garbage
|
||||
collected or when the program exits.
|
||||
"""
|
||||
pass
|
tianshou/env/vecenv/dummy.py (removed, 65 lines)
@ -1,65 +0,0 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Union, Optional, Callable, Any
|
||||
|
||||
from tianshou.env import BaseVectorEnv
|
||||
|
||||
|
||||
class VectorEnv(BaseVectorEnv):
|
||||
"""Dummy vectorized environment wrapper, implemented in for-loop.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
|
||||
super().__init__(env_fns)
|
||||
self.envs = [_() for _ in env_fns]
|
||||
|
||||
def __getattr__(self, key):
|
||||
return [getattr(env, key) if hasattr(env, key) else None
|
||||
for env in self.envs]
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
|
||||
if id is None:
|
||||
id = range(self.env_num)
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
obs = np.stack([self.envs[i].reset() for i in id])
|
||||
return obs
|
||||
|
||||
def step(self,
|
||||
action: np.ndarray,
|
||||
id: Optional[Union[int, List[int]]] = None
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
if id is None:
|
||||
id = range(self.env_num)
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
assert len(action) == len(id)
|
||||
result = [self.envs[i].step(action[i]) for i in id]
|
||||
obs, rew, done, info = map(np.stack, zip(*result))
|
||||
return obs, rew, done, info
|
||||
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
|
||||
if np.isscalar(seed):
|
||||
seed = [seed + _ for _ in range(self.env_num)]
|
||||
elif seed is None:
|
||||
seed = [seed] * self.env_num
|
||||
result = []
|
||||
for e, s in zip(self.envs, seed):
|
||||
if hasattr(e, 'seed'):
|
||||
result.append(e.seed(s))
|
||||
return result
|
||||
|
||||
def render(self, **kwargs) -> List[Any]:
|
||||
result = []
|
||||
for e in self.envs:
|
||||
if hasattr(e, 'render'):
|
||||
result.append(e.render(**kwargs))
|
||||
return result
|
||||
|
||||
def close(self) -> List[Any]:
|
||||
return [e.close() for e in self.envs]
|
tianshou/env/vecenv/rayenv.py (removed, 76 lines)
@ -1,76 +0,0 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Union, Optional, Callable, Any
|
||||
|
||||
try:
|
||||
import ray
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
from tianshou.env import BaseVectorEnv
|
||||
|
||||
|
||||
class RayVectorEnv(BaseVectorEnv):
|
||||
"""Vectorized environment wrapper based on
|
||||
`ray <https://github.com/ray-project/ray>`_. This is a choice to run
|
||||
distributed environments in a cluster.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
|
||||
super().__init__(env_fns)
|
||||
try:
|
||||
if not ray.is_initialized():
|
||||
ray.init()
|
||||
except NameError:
|
||||
raise ImportError(
|
||||
'Please install ray to support RayVectorEnv: pip install ray')
|
||||
self.envs = [
|
||||
ray.remote(gym.Wrapper).options(num_cpus=0).remote(e())
|
||||
for e in env_fns]
|
||||
|
||||
def __getattr__(self, key):
|
||||
return ray.get([e.__getattr__.remote(key) for e in self.envs])
|
||||
|
||||
def step(self,
|
||||
action: np.ndarray,
|
||||
id: Optional[Union[int, List[int]]] = None
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
if id is None:
|
||||
id = range(self.env_num)
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
assert len(action) == len(id)
|
||||
result = ray.get([self.envs[j].step.remote(action[i])
|
||||
for i, j in enumerate(id)])
|
||||
obs, rew, done, info = map(np.stack, zip(*result))
|
||||
return obs, rew, done, info
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
|
||||
if id is None:
|
||||
id = range(self.env_num)
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
obs = np.stack(ray.get([self.envs[i].reset.remote() for i in id]))
|
||||
return obs
|
||||
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
|
||||
if not hasattr(self.envs[0], 'seed'):
|
||||
return []
|
||||
if np.isscalar(seed):
|
||||
seed = [seed + _ for _ in range(self.env_num)]
|
||||
elif seed is None:
|
||||
seed = [seed] * self.env_num
|
||||
return ray.get([e.seed.remote(s) for e, s in zip(self.envs, seed)])
|
||||
|
||||
def render(self, **kwargs) -> List[Any]:
|
||||
if not hasattr(self.envs[0], 'render'):
|
||||
return [None for e in self.envs]
|
||||
return ray.get([e.render.remote(**kwargs) for e in self.envs])
|
||||
|
||||
def close(self) -> List[Any]:
|
||||
return ray.get([e.close.remote() for e in self.envs])
|
tianshou/env/vecenv/shmemenv.py (removed, 177 lines)
@ -1,177 +0,0 @@
|
||||
import gym
|
||||
import ctypes
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
from multiprocessing import Pipe, Process, Array
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
|
||||
from tianshou.env import BaseVectorEnv, SubprocVectorEnv
|
||||
from tianshou.env.utils import CloudpickleWrapper
|
||||
|
||||
_NP_TO_CT = {np.bool: ctypes.c_bool,
|
||||
np.bool_: ctypes.c_bool,
|
||||
np.uint8: ctypes.c_uint8,
|
||||
np.uint16: ctypes.c_uint16,
|
||||
np.uint32: ctypes.c_uint32,
|
||||
np.uint64: ctypes.c_uint64,
|
||||
np.int8: ctypes.c_int8,
|
||||
np.int16: ctypes.c_int16,
|
||||
np.int32: ctypes.c_int32,
|
||||
np.int64: ctypes.c_int64,
|
||||
np.float32: ctypes.c_float,
|
||||
np.float64: ctypes.c_double}
|
||||
|
||||
|
||||
def _shmem_worker(parent, p, env_fn_wrapper, obs_bufs):
|
||||
"""Control a single environment instance using IPC and shared memory."""
|
||||
def _encode_obs(obs, buffer):
|
||||
if isinstance(obs, np.ndarray):
|
||||
buffer.save(obs)
|
||||
elif isinstance(obs, tuple):
|
||||
for o, b in zip(obs, buffer):
|
||||
_encode_obs(o, b)
|
||||
elif isinstance(obs, dict):
|
||||
for k in obs.keys():
|
||||
_encode_obs(obs[k], buffer[k])
|
||||
return None
|
||||
parent.close()
|
||||
env = env_fn_wrapper.data()
|
||||
try:
|
||||
while True:
|
||||
cmd, data = p.recv()
|
||||
if cmd == 'step':
|
||||
obs, reward, done, info = env.step(data)
|
||||
p.send((_encode_obs(obs, obs_bufs), reward, done, info))
|
||||
elif cmd == 'reset':
|
||||
p.send(_encode_obs(env.reset(), obs_bufs))
|
||||
elif cmd == 'close':
|
||||
p.send(env.close())
|
||||
p.close()
|
||||
break
|
||||
elif cmd == 'render':
|
||||
p.send(env.render(**data) if hasattr(env, 'render') else None)
|
||||
elif cmd == 'seed':
|
||||
p.send(env.seed(data) if hasattr(env, 'seed') else None)
|
||||
elif cmd == 'getattr':
|
||||
p.send(getattr(env, data) if hasattr(env, data) else None)
|
||||
else:
|
||||
p.close()
|
||||
raise NotImplementedError
|
||||
except KeyboardInterrupt:
|
||||
p.close()
|
||||
|
||||
|
||||
class ShArray:
|
||||
"""Wrapper of multiprocessing Array"""
|
||||
|
||||
def __init__(self, dtype, shape):
|
||||
self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
|
||||
self.dtype = dtype
|
||||
self.shape = shape
|
||||
|
||||
def save(self, ndarray):
|
||||
assert isinstance(ndarray, np.ndarray)
|
||||
dst = self.arr.get_obj()
|
||||
dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
|
||||
np.copyto(dst_np, ndarray)
|
||||
|
||||
def get(self):
|
||||
return np.frombuffer(self.arr.get_obj(),
|
||||
dtype=self.dtype).reshape(self.shape)
|
||||
|
||||
|
||||
class ShmemVectorEnv(SubprocVectorEnv):
|
||||
"""Optimized version of SubprocVectorEnv that uses shared variables to
|
||||
communicate observations. SubprocVectorEnv has exactly the same API as
|
||||
SubprocVectorEnv.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.env.SubprocVectorEnv` for more
|
||||
detailed explanation.
|
||||
|
||||
ShmemVectorEnv Class was inspired by openai baseline's implementation.
|
||||
Please refer to 'https://github.com/openai/baselines/blob/master/baselines/
|
||||
common/vec_env/shmem_vec_env.py' for more info if you are interested.
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
|
||||
BaseVectorEnv.__init__(self, env_fns)
|
||||
# Mind that SubprocVectorEnv is not initialised.
|
||||
self.closed = False
|
||||
dummy = env_fns[0]()
|
||||
obs_space = dummy.observation_space
|
||||
dummy.close()
|
||||
del dummy
|
||||
self.obs_bufs = [ShmemVectorEnv._setup_buf(obs_space)
|
||||
for _ in range(self.env_num)]
|
||||
self.parent_remote, self.child_remote = \
|
||||
zip(*[Pipe() for _ in range(self.env_num)])
|
||||
self.processes = [
|
||||
Process(target=_shmem_worker, args=(
|
||||
parent, child, CloudpickleWrapper(env_fn),
|
||||
obs_buf), daemon=True)
|
||||
for (parent, child, env_fn, obs_buf) in zip(
|
||||
self.parent_remote, self.child_remote, env_fns, self.obs_bufs)
|
||||
]
|
||||
for p in self.processes:
|
||||
p.start()
|
||||
for c in self.child_remote:
|
||||
c.close()
|
||||
|
||||
def step(self,
|
||||
action: np.ndarray,
|
||||
id: Optional[Union[int, List[int]]] = None
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
if id is None:
|
||||
id = range(self.env_num)
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
assert len(action) == len(id)
|
||||
for i, j in enumerate(id):
|
||||
self.parent_remote[j].send(['step', action[i]])
|
||||
result = []
|
||||
for i in id:
|
||||
obs, rew, done, info = self.parent_remote[i].recv()
|
||||
obs = self._decode_obs(obs, i)
|
||||
result.append((obs, rew, done, info))
|
||||
obs, rew, done, info = map(np.stack, zip(*result))
|
||||
return obs, rew, done, info
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
|
||||
if id is None:
|
||||
id = range(self.env_num)
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
for i in id:
|
||||
self.parent_remote[i].send(['reset', None])
|
||||
obs = np.stack(
|
||||
[self._decode_obs(self.parent_remote[i].recv(), i) for i in id])
|
||||
return obs
|
||||
|
||||
@staticmethod
|
||||
def _setup_buf(space):
|
||||
if isinstance(space, gym.spaces.Dict):
|
||||
assert isinstance(space.spaces, OrderedDict)
|
||||
buffer = {k: ShmemVectorEnv._setup_buf(v)
|
||||
for k, v in space.spaces.items()}
|
||||
elif isinstance(space, gym.spaces.Tuple):
|
||||
assert isinstance(space.spaces, tuple)
|
||||
buffer = tuple([ShmemVectorEnv._setup_buf(t)
|
||||
for t in space.spaces])
|
||||
else:
|
||||
buffer = ShArray(space.dtype, space.shape)
|
||||
return buffer
|
||||
|
||||
def _decode_obs(self, isNone, index):
|
||||
def decode_obs(buffer):
|
||||
if isinstance(buffer, ShArray):
|
||||
return buffer.get()
|
||||
elif isinstance(buffer, tuple):
|
||||
return tuple([decode_obs(b) for b in buffer])
|
||||
elif isinstance(buffer, dict):
|
||||
return {k: decode_obs(v) for k, v in buffer.items()}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return decode_obs(self.obs_bufs[index])
|
tianshou/env/vecenv/subproc.py (removed, 115 lines)
@ -1,115 +0,0 @@
|
||||
import gym
|
||||
import numpy as np
|
||||
from multiprocessing import Process, Pipe
|
||||
from typing import List, Tuple, Union, Optional, Callable, Any
|
||||
|
||||
from tianshou.env import BaseVectorEnv
|
||||
from tianshou.env.utils import CloudpickleWrapper
|
||||
|
||||
|
||||
def _worker(parent, p, env_fn_wrapper):
|
||||
parent.close()
|
||||
env = env_fn_wrapper.data()
|
||||
try:
|
||||
while True:
|
||||
cmd, data = p.recv()
|
||||
if cmd == 'step':
|
||||
p.send(env.step(data))
|
||||
elif cmd == 'reset':
|
||||
p.send(env.reset())
|
||||
elif cmd == 'close':
|
||||
p.send(env.close())
|
||||
p.close()
|
||||
break
|
||||
elif cmd == 'render':
|
||||
p.send(env.render(**data) if hasattr(env, 'render') else None)
|
||||
elif cmd == 'seed':
|
||||
p.send(env.seed(data) if hasattr(env, 'seed') else None)
|
||||
elif cmd == 'getattr':
|
||||
p.send(getattr(env, data) if hasattr(env, data) else None)
|
||||
else:
|
||||
p.close()
|
||||
raise NotImplementedError
|
||||
except KeyboardInterrupt:
|
||||
p.close()
|
||||
|
||||
|
||||
class SubprocVectorEnv(BaseVectorEnv):
|
||||
"""Vectorized environment wrapper based on subprocess.
|
||||
|
||||
.. seealso::
|
||||
|
||||
Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
|
||||
explanation.
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns: List[Callable[[], gym.Env]]) -> None:
|
||||
super().__init__(env_fns)
|
||||
self.closed = False
|
||||
self.parent_remote, self.child_remote = \
|
||||
zip(*[Pipe() for _ in range(self.env_num)])
|
||||
self.processes = [
|
||||
Process(target=_worker, args=(
|
||||
parent, child, CloudpickleWrapper(env_fn)), daemon=True)
|
||||
for (parent, child, env_fn) in zip(
|
||||
self.parent_remote, self.child_remote, env_fns)
|
||||
]
|
||||
for p in self.processes:
|
||||
p.start()
|
||||
for c in self.child_remote:
|
||||
c.close()
|
||||
|
||||
def __getattr__(self, key):
|
||||
for p in self.parent_remote:
|
||||
p.send(['getattr', key])
|
||||
return [p.recv() for p in self.parent_remote]
|
||||
|
||||
def step(self,
|
||||
action: np.ndarray,
|
||||
id: Optional[Union[int, List[int]]] = None
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
if id is None:
|
||||
id = range(self.env_num)
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
assert len(action) == len(id)
|
||||
for i, j in enumerate(id):
|
||||
self.parent_remote[j].send(['step', action[i]])
|
||||
result = [self.parent_remote[i].recv() for i in id]
|
||||
obs, rew, done, info = map(np.stack, zip(*result))
|
||||
return obs, rew, done, info
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
|
||||
if id is None:
|
||||
id = range(self.env_num)
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
for i in id:
|
||||
self.parent_remote[i].send(['reset', None])
|
||||
obs = np.stack([self.parent_remote[i].recv() for i in id])
|
||||
return obs
|
||||
|
||||
def seed(self, seed: Optional[Union[int, List[int]]] = None) -> List[int]:
|
||||
if np.isscalar(seed):
|
||||
seed = [seed + _ for _ in range(self.env_num)]
|
||||
elif seed is None:
|
||||
seed = [seed] * self.env_num
|
||||
for p, s in zip(self.parent_remote, seed):
|
||||
p.send(['seed', s])
|
||||
return [p.recv() for p in self.parent_remote]
|
||||
|
||||
def render(self, **kwargs) -> List[Any]:
|
||||
for p in self.parent_remote:
|
||||
p.send(['render', kwargs])
|
||||
return [p.recv() for p in self.parent_remote]
|
||||
|
||||
def close(self) -> List[Any]:
|
||||
if self.closed:
|
||||
return []
|
||||
for p in self.parent_remote:
|
||||
p.send(['close', None])
|
||||
result = [p.recv() for p in self.parent_remote]
|
||||
self.closed = True
|
||||
for p in self.processes:
|
||||
p.join()
|
||||
return result
|
tianshou/env/venvs.py (new file, 336 lines)
@ -0,0 +1,336 @@
|
||||
import gym
|
||||
import warnings
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Union, Optional, Callable, Any
|
||||
|
||||
from tianshou.env.worker import EnvWorker, DummyEnvWorker, SubprocEnvWorker, \
|
||||
RayEnvWorker
|
||||
|
||||
|
||||
class BaseVectorEnv(gym.Env):
|
||||
"""Base class for vectorized environments wrapper. Usage:
|
||||
::
|
||||
|
||||
env_num = 8
|
||||
envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(env_num)])
|
||||
assert len(envs) == env_num
|
||||
|
||||
It accepts a list of environment generators. In other words, an environment
|
||||
generator ``efn`` of a specific task means that ``efn()`` returns the
|
||||
environment of the given task, for example, ``gym.make(task)``.
|
||||
|
||||
All of the vector env classes must inherit :class:`~tianshou.env.BaseVectorEnv`.
|
||||
Here are some other usages:
|
||||
::
|
||||
|
||||
envs.seed(2) # which is equal to the next line
|
||||
envs.seed([2, 3, 4, 5, 6, 7, 8, 9]) # set specific seed for each env
|
||||
obs = envs.reset() # reset all environments
|
||||
obs = envs.reset([0, 5, 7]) # reset 3 specific environments
|
||||
obs, rew, done, info = envs.step([1] * 8) # step synchronously
|
||||
envs.render() # render all environments
|
||||
envs.close() # close all environments
|
||||
|
||||
.. warning::
|
||||
|
||||
If you use your own environment, please make sure the ``seed`` method
|
||||
is set up properly, e.g.,
|
||||
::
|
||||
|
||||
def seed(self, seed):
|
||||
np.random.seed(seed)
|
||||
|
||||
Otherwise, the outputs of these envs may be the same with each other.
|
||||
|
||||
:param env_fns: a list of callable envs, ``env_fns[i]()`` generates the ith
|
||||
env.
|
||||
:param worker_fn: a callable worker, ``worker_fn(env_fns[i])`` generates a
|
||||
worker which contains this env.
|
||||
:param int wait_num: used in asynchronous simulation if the time cost of
    ``env.step`` varies with time and synchronously waiting for all
    environments to finish a step is time-wasting. In that case, we can
    return when ``wait_num`` environments finish a step and keep on
    simulation in these environments. If ``None``, asynchronous simulation
    is disabled; else, ``1 <= wait_num <= env_num``.
:param float timeout: used in asynchronous simulation in the same way as
    above; each vectorized step only deals with those environments that
    finish within ``timeout`` seconds.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
env_fns: List[Callable[[], gym.Env]],
|
||||
worker_fn: Callable[[Callable[[], gym.Env]], EnvWorker],
|
||||
wait_num: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> None:
|
||||
self._env_fns = env_fns
|
||||
# A VectorEnv contains a pool of EnvWorkers, each of which wraps and
# interacts with one of the given envs (one worker <-> one env).
|
||||
self.workers = [worker_fn(fn) for fn in env_fns]
|
||||
self.worker_class = type(self.workers[0])
|
||||
assert issubclass(self.worker_class, EnvWorker)
|
||||
assert all([isinstance(w, self.worker_class) for w in self.workers])
|
||||
|
||||
self.env_num = len(env_fns)
|
||||
self.wait_num = wait_num or len(env_fns)
|
||||
assert 1 <= self.wait_num <= len(env_fns), \
|
||||
f'wait_num should be in [1, {len(env_fns)}], but got {wait_num}'
|
||||
self.timeout = timeout
|
||||
assert self.timeout is None or self.timeout > 0, \
|
||||
f'timeout is {timeout}, it should be positive if provided!'
|
||||
self.is_async = self.wait_num != len(env_fns) or timeout is not None
|
||||
self.waiting_conn = []
|
||||
# environments in self.ready_id are actually ready,
# but environments in self.waiting_id are just waiting when checked,
# and they may be ready now, but this is not known until we check it
# in the step() function
|
||||
self.waiting_id = []
|
||||
# all environments are ready in the beginning
|
||||
self.ready_id = list(range(self.env_num))
|
||||
self.is_closed = False
|
||||
|
||||
def _assert_is_not_closed(self) -> None:
|
||||
assert not self.is_closed, f"Methods of {self.__class__.__name__} "\
|
||||
"should not be called after close."
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return len(self), which is the number of environments."""
|
||||
return self.env_num
|
||||
|
||||
def __getattribute__(self, key: str) -> Any:
|
||||
"""Any class who inherits ``gym.Env`` will inherit some attributes,
|
||||
like ``action_space``. However, we would like the attribute lookup to
|
||||
go straight into the worker (in fact, this vector env's action_space
|
||||
is always ``None``).
|
||||
"""
|
||||
if key in ['metadata', 'reward_range', 'spec', 'action_space',
|
||||
'observation_space']: # reserved keys in gym.Env
|
||||
return self.__getattr__(key)
|
||||
else:
|
||||
return super().__getattribute__(key)
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
"""Try to retrieve an attribute from each individual wrapped
|
||||
environment, if it does not belong to the wrapping vector environment
|
||||
class.
|
||||
"""
|
||||
return [getattr(worker, key) for worker in self.workers]
|
||||
|
||||
def _wrap_id(
|
||||
self, id: Optional[Union[int, List[int]]] = None) -> List[int]:
|
||||
if id is None:
|
||||
id = list(range(self.env_num))
|
||||
elif np.isscalar(id):
|
||||
id = [id]
|
||||
return id
|
||||
|
||||
def _assert_id(
|
||||
self, id: Optional[Union[int, List[int]]] = None) -> None:
|
||||
for i in id:
|
||||
assert i not in self.waiting_id, \
|
||||
f'Cannot interact with environment {i} which is stepping now.'
|
||||
assert i in self.ready_id, \
|
||||
f'Can only interact with ready environments {self.ready_id}.'
|
||||
|
||||
def reset(self, id: Optional[Union[int, List[int]]] = None) -> np.ndarray:
|
||||
"""Reset the state of all the environments and return initial
|
||||
observations if id is ``None``, otherwise reset the specific
|
||||
environments with the given id, either an int or a list.
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
id = self._wrap_id(id)
|
||||
if self.is_async:
|
||||
self._assert_id(id)
|
||||
obs = np.stack([self.workers[i].reset() for i in id])
|
||||
return obs
|
||||
|
||||
def step(self,
|
||||
action: Optional[np.ndarray],
|
||||
id: Optional[Union[int, List[int]]] = None
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Run one timestep of all the environments’ dynamics if id is "None",
|
||||
otherwise run one timestep for some environments with given id, either
|
||||
an int or a list. When the end of episode is reached, you are
|
||||
responsible for calling reset(id) to reset this environment’s state.
|
||||
|
||||
Accept a batch of action and return a tuple (obs, rew, done, info).
|
||||
|
||||
:param numpy.ndarray action: a batch of action provided by the agent.
|
||||
|
||||
:return: A tuple including four items:
|
||||
|
||||
* ``obs`` a numpy.ndarray, the agent's observation of current \
|
||||
environments
|
||||
* ``rew`` a numpy.ndarray, the amount of rewards returned after \
|
||||
previous actions
|
||||
* ``done`` a numpy.ndarray, whether these episodes have ended, in \
|
||||
which case further step() calls will return undefined results
|
||||
* ``info`` a numpy.ndarray, contains auxiliary diagnostic \
|
||||
information (helpful for debugging, and sometimes learning)
|
||||
|
||||
For the async simulation:
|
||||
|
||||
Provide the given action to the environments. The action sequence
|
||||
should correspond to the ``id`` argument, and the ``id`` argument
|
||||
should be a subset of the ``env_id`` in the last returned ``info``
|
||||
(initially they are env_ids of all the environments). If action is
|
||||
``None``, fetch unfinished step() calls instead.
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
id = self._wrap_id(id)
|
||||
if not self.is_async:
|
||||
assert len(action) == len(id)
|
||||
for i, j in enumerate(id):
|
||||
self.workers[j].send_action(action[i])
|
||||
result = [self.workers[j].get_result() for j in id]
|
||||
else:
|
||||
if action is not None:
|
||||
self._assert_id(id)
|
||||
assert len(action) == len(id)
|
||||
for i, (act, env_id) in enumerate(zip(action, id)):
|
||||
self.workers[env_id].send_action(act)
|
||||
self.waiting_conn.append(self.workers[env_id])
|
||||
self.waiting_id.append(env_id)
|
||||
self.ready_id = [x for x in self.ready_id if x not in id]
|
||||
ready_conns, result = [], []
|
||||
while not ready_conns:
|
||||
ready_conns = self.worker_class.wait(
|
||||
self.waiting_conn, self.wait_num, self.timeout)
|
||||
for conn in ready_conns:
|
||||
waiting_index = self.waiting_conn.index(conn)
|
||||
self.waiting_conn.pop(waiting_index)
|
||||
env_id = self.waiting_id.pop(waiting_index)
|
||||
obs, rew, done, info = conn.get_result()
|
||||
info["env_id"] = env_id
|
||||
result.append((obs, rew, done, info))
|
||||
self.ready_id.append(env_id)
|
||||
return list(map(np.stack, zip(*result)))
|
||||
|
||||
    def seed(self,
             seed: Optional[Union[int, List[int]]] = None) -> List[List[int]]:
        """Set the seed for all environments.

        Accept ``None``, an int (which will extend ``i`` to
        ``[i, i + 1, i + 2, ...]``) or a list.

        :return: The list of seeds used in this env's random number generators.
            The first value in the list should be the "main" seed, or the value
            which a reproducer passes to "seed".
        """
        self._assert_is_not_closed()
        if np.isscalar(seed):
            seed = [seed + _ for _ in range(self.env_num)]
        elif seed is None:
            seed = [seed] * self.env_num
        return [w.seed(s) for w, s in zip(self.workers, seed)]

    def render(self, **kwargs) -> List[Any]:
        """Render all of the environments."""
        self._assert_is_not_closed()
        if self.is_async and len(self.waiting_id) > 0:
            raise RuntimeError(
                f"Environments {self.waiting_id} are still "
                f"stepping, cannot render them now.")
        return [w.render(**kwargs) for w in self.workers]

    def close(self) -> None:
        """Close all of the environments. This function will be called only
        once (if not, it will be called during garbage collection). This way,
        ``close`` of all workers can be assured.
        """
        self._assert_is_not_closed()
        for w in self.workers:
            w.close()
        self.is_closed = True

    def __del__(self) -> None:
        if not self.is_closed:
            self.close()
class DummyVectorEnv(BaseVectorEnv):
    """Dummy vectorized environment wrapper, implemented in for-loop.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]],
                 wait_num: Optional[int] = None,
                 timeout: Optional[float] = None) -> None:
        super().__init__(env_fns, DummyEnvWorker,
                         wait_num=wait_num, timeout=timeout)


class VectorEnv(DummyVectorEnv):
    def __init__(self, *args, **kwargs) -> None:
        warnings.warn(
            'VectorEnv is renamed to DummyVectorEnv, and will be removed in '
            '0.3. Use DummyVectorEnv instead!', Warning)
        super().__init__(*args, **kwargs)


class SubprocVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on subprocess.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]],
                 wait_num: Optional[int] = None,
                 timeout: Optional[float] = None) -> None:
        def worker_fn(fn):
            return SubprocEnvWorker(fn, share_memory=False)
        super().__init__(env_fns, worker_fn,
                         wait_num=wait_num, timeout=timeout)


class ShmemVectorEnv(BaseVectorEnv):
    """Optimized version of SubprocVectorEnv which uses shared variables to
    communicate observations. ShmemVectorEnv has exactly the same API as
    SubprocVectorEnv.

    .. seealso::

        Please refer to :class:`~tianshou.env.SubprocVectorEnv` for more
        detailed explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]],
                 wait_num: Optional[int] = None,
                 timeout: Optional[float] = None) -> None:
        def worker_fn(fn):
            return SubprocEnvWorker(fn, share_memory=True)
        super().__init__(env_fns, worker_fn,
                         wait_num=wait_num, timeout=timeout)


class RayVectorEnv(BaseVectorEnv):
    """Vectorized environment wrapper based on
    `ray <https://github.com/ray-project/ray>`_. This is a choice to run
    distributed environments in a cluster.

    .. seealso::

        Please refer to :class:`~tianshou.env.BaseVectorEnv` for more detailed
        explanation.
    """

    def __init__(self, env_fns: List[Callable[[], gym.Env]],
                 wait_num: Optional[int] = None,
                 timeout: Optional[float] = None) -> None:
        try:
            import ray
        except ImportError as e:
            raise ImportError(
                'Please install ray to support RayVectorEnv: pip install ray'
            ) from e
        if not ray.is_initialized():
            ray.init()
        super().__init__(env_fns, RayEnvWorker,
                         wait_num=wait_num, timeout=timeout)
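For orientation, here is a minimal sketch of how the synchronous and asynchronous `step` paths above are meant to be driven from user code. The environment name, random actions, loop count, and `__main__` guard are illustrative assumptions, not part of this diff.

```python
import gym
import numpy as np
from tianshou.env import SubprocVectorEnv

if __name__ == '__main__':
    # synchronous: every env steps exactly once per call
    venv = SubprocVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
    obs = venv.reset()
    obs, rew, done, info = venv.step(np.random.randint(2, size=4))
    venv.close()

    # asynchronous: step() returns once wait_num envs finish (or timeout expires)
    venv = SubprocVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)],
                            wait_num=2, timeout=0.1)
    obs = venv.reset()
    env_ids = list(range(4))                      # initially every env is ready
    for _ in range(20):
        acts = np.random.randint(2, size=len(env_ids))
        obs, rew, done, info = venv.step(acts, id=env_ids)
        env_ids = [i['env_id'] for i in info]     # act only on envs that returned
        for env_id, d in zip(env_ids, done):
            if d:
                venv.reset(env_id)                # the user resets finished episodes
    venv.close()
```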
tianshou/env/worker/__init__.py
@ -0,0 +1,11 @@
from tianshou.env.worker.base import EnvWorker
from tianshou.env.worker.dummy import DummyEnvWorker
from tianshou.env.worker.subproc import SubprocEnvWorker
from tianshou.env.worker.ray import RayEnvWorker

__all__ = [
    'EnvWorker',
    'DummyEnvWorker',
    'SubprocEnvWorker',
    'RayEnvWorker',
]
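Since the vector-env classes above only differ in the worker they hand to `BaseVectorEnv`, the exported workers can also be combined with it directly. A hedged sketch, assuming `BaseVectorEnv` is exported from `tianshou.env` and using an arbitrary env name:

```python
import gym
from tianshou.env import BaseVectorEnv
from tianshou.env.worker import SubprocEnvWorker


def worker_fn(env_fn):
    # same factory SubprocVectorEnv uses internally
    return SubprocEnvWorker(env_fn, share_memory=False)


if __name__ == '__main__':
    venv = BaseVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(2)],
                         worker_fn)
    obs = venv.reset()
    venv.close()
```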
tianshou/env/worker/base.py
@ -0,0 +1,64 @@
import gym
import numpy as np
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, Callable, Any


class EnvWorker(ABC):
    """An abstract worker for an environment."""

    def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
        self._env_fn = env_fn
        self.is_closed = False

    @abstractmethod
    def __getattr__(self, key: str) -> Any:
        pass

    @abstractmethod
    def reset(self) -> Any:
        pass

    @abstractmethod
    def send_action(self, action: np.ndarray) -> None:
        pass

    def get_result(self) -> Tuple[
            np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        return self.result

    def step(self, action: np.ndarray
             ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """``send_action`` and ``get_result`` are coupled in sync simulation,
        so typically users only call the ``step`` function. They can be called
        separately in async simulation, i.e. one calls ``send_action`` first
        and ``get_result`` later.
        """
        self.send_action(action)
        return self.get_result()

    @staticmethod
    def wait(workers: List['EnvWorker'],
             wait_num: int,
             timeout: Optional[float] = None) -> List['EnvWorker']:
        """Given a list of workers, return those that are ready."""
        raise NotImplementedError

    @abstractmethod
    def seed(self, seed: Optional[int] = None) -> List[int]:
        pass

    @abstractmethod
    def render(self, **kwargs) -> Any:
        """Render the environment."""
        pass

    @abstractmethod
    def close_env(self) -> None:
        pass

    def close(self) -> None:
        if self.is_closed:
            return None
        self.is_closed = True
        self.close_env()
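To make the `send_action` / `get_result` split concrete, here is a hedged sketch of driving two subprocess workers by hand. The environment name, the fixed action, and the `__main__` guard are illustrative assumptions, not part of this diff.

```python
import gym
import numpy as np
from tianshou.env.worker import SubprocEnvWorker

if __name__ == '__main__':
    workers = [SubprocEnvWorker(lambda: gym.make('CartPole-v0'))
               for _ in range(2)]
    for w in workers:
        w.reset()

    # decoupled async use: send actions first, collect results later
    for w in workers:
        w.send_action(np.int64(0))
    ready = SubprocEnvWorker.wait(workers, wait_num=1)
    obs, rew, done, info = ready[0].get_result()

    # drain any remaining pending result, then shut the workers down
    for w in workers:
        if w not in ready:
            w.get_result()
        w.close()
```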
tianshou/env/worker/dummy.py
@ -0,0 +1,41 @@
import gym
import numpy as np
from typing import List, Callable, Optional, Any

from tianshou.env.worker import EnvWorker


class DummyEnvWorker(EnvWorker):
    """Dummy worker used in sequential vector environments."""

    def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
        super().__init__(env_fn)
        self.env = env_fn()

    def __getattr__(self, key: str):
        if hasattr(self.env, key):
            return getattr(self.env, key)
        return None

    def reset(self) -> Any:
        return self.env.reset()

    @staticmethod
    def wait(workers: List['DummyEnvWorker'],
             wait_num: int,
             timeout: Optional[float] = None) -> List['DummyEnvWorker']:
        # DummyEnvWorker objects are always ready
        return workers

    def send_action(self, action: np.ndarray) -> None:
        self.result = self.env.step(action)

    def seed(self, seed: Optional[int] = None) -> List[int]:
        return self.env.seed(seed) if hasattr(self.env, 'seed') else None

    def render(self, **kwargs) -> Any:
        return self.env.render(**kwargs) \
            if hasattr(self.env, 'render') else None

    def close_env(self) -> None:
        self.env.close()
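A small illustration of the attribute forwarding above; the env name is an arbitrary example rather than part of the diff:

```python
import gym
from tianshou.env.worker import DummyEnvWorker

w = DummyEnvWorker(lambda: gym.make('CartPole-v0'))
print(w.action_space)        # forwarded to the wrapped env via __getattr__
print(w.no_such_attribute)   # missing attributes resolve to None, not an error
obs = w.reset()
obs, rew, done, info = w.step(w.action_space.sample())
w.close()
```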
tianshou/env/worker/ray.py
@ -0,0 +1,54 @@
import gym
import numpy as np
from typing import List, Callable, Tuple, Optional, Any

from tianshou.env.worker import EnvWorker

try:
    import ray
except ImportError:
    pass


class RayEnvWorker(EnvWorker):
    """Ray worker used in RayVectorEnv."""

    def __init__(self, env_fn: Callable[[], gym.Env]) -> None:
        super().__init__(env_fn)
        self.env = ray.remote(gym.Wrapper).options(num_cpus=0).remote(env_fn())

    def __getattr__(self, key: str):
        return ray.get(self.env.__getattr__.remote(key))

    def reset(self) -> Any:
        return ray.get(self.env.reset.remote())

    @staticmethod
    def wait(workers: List['RayEnvWorker'],
             wait_num: int,
             timeout: Optional[float] = None) -> List['RayEnvWorker']:
        results = [x.result for x in workers]
        ready_results, _ = ray.wait(results,
                                    num_returns=wait_num, timeout=timeout)
        return [workers[results.index(result)] for result in ready_results]

    def send_action(self, action: np.ndarray) -> None:
        # self.result is actually a handle to the pending remote step,
        # which is resolved in get_result
        self.result = self.env.step.remote(action)

    def get_result(self) -> Tuple[
            np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        return ray.get(self.result)

    def seed(self, seed: Optional[int] = None) -> List[int]:
        if hasattr(self.env, 'seed'):
            return ray.get(self.env.seed.remote(seed))
        return None

    def render(self, **kwargs) -> Any:
        if hasattr(self.env, 'render'):
            return ray.get(self.env.render.remote(**kwargs))
        return None

    def close_env(self) -> None:
        ray.get(self.env.close.remote())
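A hedged usage sketch for the ray-backed path; it assumes `ray` is installed (`pip install ray`) and uses an arbitrary env name:

```python
import gym
import numpy as np
from tianshou.env import RayVectorEnv

# ray.init() is called by RayVectorEnv itself if ray is not yet initialized
venv = RayVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
obs = venv.reset()
obs, rew, done, info = venv.step(np.random.randint(2, size=4))
venv.close()
```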
tianshou/env/worker/subproc.py
@ -0,0 +1,202 @@
import gym
import time
import ctypes
import numpy as np
from collections import OrderedDict
from multiprocessing.context import Process
from multiprocessing import Array, Pipe, connection
from typing import Callable, Any, List, Tuple, Optional

from tianshou.env.worker import EnvWorker
from tianshou.env.utils import CloudpickleWrapper


def _worker(parent, p, env_fn_wrapper, obs_bufs=None):
    def _encode_obs(obs, buffer):
        if isinstance(obs, np.ndarray):
            buffer.save(obs)
        elif isinstance(obs, tuple):
            for o, b in zip(obs, buffer):
                _encode_obs(o, b)
        elif isinstance(obs, dict):
            for k in obs.keys():
                _encode_obs(obs[k], buffer[k])
        return None

    parent.close()
    env = env_fn_wrapper.data()
    try:
        while True:
            try:
                cmd, data = p.recv()
            except EOFError:  # the pipe has been closed
                p.close()
                break
            if cmd == 'step':
                obs, reward, done, info = env.step(data)
                if obs_bufs is not None:
                    obs = _encode_obs(obs, obs_bufs)
                p.send((obs, reward, done, info))
            elif cmd == 'reset':
                obs = env.reset()
                if obs_bufs is not None:
                    obs = _encode_obs(obs, obs_bufs)
                p.send(obs)
            elif cmd == 'close':
                p.send(env.close())
                p.close()
                break
            elif cmd == 'render':
                p.send(env.render(**data) if hasattr(env, 'render') else None)
            elif cmd == 'seed':
                p.send(env.seed(data) if hasattr(env, 'seed') else None)
            elif cmd == 'getattr':
                p.send(getattr(env, data) if hasattr(env, data) else None)
            else:
                p.close()
                raise NotImplementedError
    except KeyboardInterrupt:
        p.close()


_NP_TO_CT = {
    np.bool: ctypes.c_bool,
    np.bool_: ctypes.c_bool,
    np.uint8: ctypes.c_uint8,
    np.uint16: ctypes.c_uint16,
    np.uint32: ctypes.c_uint32,
    np.uint64: ctypes.c_uint64,
    np.int8: ctypes.c_int8,
    np.int16: ctypes.c_int16,
    np.int32: ctypes.c_int32,
    np.int64: ctypes.c_int64,
    np.float32: ctypes.c_float,
    np.float64: ctypes.c_double,
}


class ShArray:
    """Wrapper of multiprocessing Array."""

    def __init__(self, dtype, shape):
        self.arr = Array(_NP_TO_CT[dtype.type], int(np.prod(shape)))
        self.dtype = dtype
        self.shape = shape

    def save(self, ndarray):
        assert isinstance(ndarray, np.ndarray)
        dst = self.arr.get_obj()
        dst_np = np.frombuffer(dst, dtype=self.dtype).reshape(self.shape)
        np.copyto(dst_np, ndarray)

    def get(self):
        return np.frombuffer(self.arr.get_obj(),
                             dtype=self.dtype).reshape(self.shape)


def _setup_buf(space):
    if isinstance(space, gym.spaces.Dict):
        assert isinstance(space.spaces, OrderedDict)
        buffer = {k: _setup_buf(v) for k, v in space.spaces.items()}
    elif isinstance(space, gym.spaces.Tuple):
        assert isinstance(space.spaces, tuple)
        buffer = tuple([_setup_buf(t) for t in space.spaces])
    else:
        buffer = ShArray(space.dtype, space.shape)
    return buffer


class SubprocEnvWorker(EnvWorker):
    """Subprocess worker used in SubprocVectorEnv and ShmemVectorEnv."""

    def __init__(self, env_fn: Callable[[], gym.Env],
                 share_memory=False) -> None:
        super().__init__(env_fn)
        self.parent_remote, self.child_remote = Pipe()
        self.share_memory = share_memory
        self.buffer = None
        if self.share_memory:
            dummy = env_fn()
            obs_space = dummy.observation_space
            dummy.close()
            del dummy
            self.buffer = _setup_buf(obs_space)
        args = (self.parent_remote, self.child_remote,
                CloudpickleWrapper(env_fn), self.buffer)
        self.process = Process(target=_worker, args=args, daemon=True)
        self.process.start()
        self.child_remote.close()

    def __getattr__(self, key: str):
        self.parent_remote.send(['getattr', key])
        return self.parent_remote.recv()

    def _decode_obs(self, isNone):
        def decode_obs(buffer):
            if isinstance(buffer, ShArray):
                return buffer.get()
            elif isinstance(buffer, tuple):
                return tuple([decode_obs(b) for b in buffer])
            elif isinstance(buffer, dict):
                return {k: decode_obs(v) for k, v in buffer.items()}
            else:
                raise NotImplementedError

        return decode_obs(self.buffer)

    def reset(self) -> Any:
        self.parent_remote.send(['reset', None])
        obs = self.parent_remote.recv()
        if self.share_memory:
            obs = self._decode_obs(obs)
        return obs

    @staticmethod
    def wait(workers: List['SubprocEnvWorker'],
             wait_num: int,
             timeout: Optional[float] = None) -> List['SubprocEnvWorker']:
        conns, ready_conns = [x.parent_remote for x in workers], []
        remain_conns = conns
        t1 = time.time()
        while len(remain_conns) > 0 and len(ready_conns) < wait_num:
            if timeout:
                remain_time = timeout - (time.time() - t1)
                if remain_time <= 0:
                    break
            else:
                remain_time = timeout
            remain_conns = [conn for conn in remain_conns
                            if conn not in ready_conns]
            new_ready_conns = connection.wait(
                remain_conns, timeout=remain_time)
            ready_conns.extend(new_ready_conns)
        return [workers[conns.index(con)] for con in ready_conns]

    def send_action(self, action: np.ndarray) -> None:
        self.parent_remote.send(['step', action])

    def get_result(self) -> Tuple[
            np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        obs, rew, done, info = self.parent_remote.recv()
        if self.share_memory:
            obs = self._decode_obs(obs)
        return obs, rew, done, info

    def seed(self, seed: Optional[int] = None) -> List[int]:
        self.parent_remote.send(['seed', seed])
        return self.parent_remote.recv()

    def render(self, **kwargs) -> Any:
        self.parent_remote.send(['render', kwargs])
        return self.parent_remote.recv()

    def close_env(self) -> None:
        try:
            self.parent_remote.send(['close', None])
            # mp may be deleted so it may raise AttributeError
            self.parent_remote.recv()
            self.process.join()
        except (BrokenPipeError, EOFError, AttributeError):
            pass
        # ensure the subproc is terminated
        self.process.terminate()
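To illustrate the shared-memory mechanism above (`ShArray` plus `_setup_buf`), here is a hedged, standalone sketch of the encode/decode round trip, independent of any environment; `_setup_buf` and `ShArray` are internal helpers and are imported here only for illustration:

```python
import gym
import numpy as np
from tianshou.env.worker.subproc import ShArray, _setup_buf

# build a buffer matching an observation space, as SubprocEnvWorker does
space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
buf = _setup_buf(space)            # a single ShArray for a Box space
assert isinstance(buf, ShArray)

obs = np.array([0.1, -0.2, 0.3, -0.4], dtype=np.float32)
buf.save(obs)                      # the child process writes into shared memory
recovered = buf.get()              # the parent process reads it back out
assert np.allclose(recovered, obs)
```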
@ -7,7 +7,7 @@ class BaseNoise(ABC, object):
    """The action noise base class."""

    def __init__(self, **kwargs) -> None:
        super(BaseNoise, self).__init__()
        super().__init__()

    @abstractmethod
    def __call__(self, **kwargs) -> np.ndarray: