docs for env
commit b6c9db6b0b (parent 9380368ca3)
@@ -49,7 +49,8 @@ If no error occurs, you have successfully installed Tianshou.

   tutorials/dqn
   tutorials/concepts
   tutorials/trick.rst
   tutorials/trick
   tutorials/tabular

.. toctree::
   :maxdepth: 1
@@ -3,7 +3,7 @@ Basic concepts in Tianshou

Tianshou splits a Reinforcement Learning agent training procedure into these parts: trainer, collector, policy, and data buffer. The general control flow can be described as:

.. image:: ../_static/images/concepts_arch.png
.. image:: /_static/images/concepts_arch.png
   :align: center
   :height: 300
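A sketch of how those four parts typically fit together in code (illustrative only: the ``Collector`` constructor shown below is an assumption, not taken from this commit, while ``VectorEnv``, ``ReplayBuffer`` and ``policy.learn(train_collector.sample(...))`` do appear elsewhere in these docs): ::

    envs = VectorEnv([lambda: gym.make(task) for _ in range(8)])  # vectorized environment
    buffer = ReplayBuffer(size=20000)                             # data buffer
    policy = ...                                                  # any policy; construction omitted here
    train_collector = Collector(policy, envs, buffer)             # assumed Collector signature
    # a trainer then repeatedly asks the collector for data and the policy to learn from it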

@@ -211,8 +211,9 @@ No problem! Tianshou supports user-defined training code. Here is the usage:
    # train policy with a sampled batch of data
    losses = policy.learn(train_collector.sample(batch_size=64))

For further usage, you can refer to :doc:`/tutorials/tabular`.
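As a minimal sketch of such a user-defined training loop (only ``policy.learn(train_collector.sample(batch_size=64))`` is taken from the snippet above; the outer loop and the ``train_collector.collect(n_step=10)`` call are assumptions): ::

    for epoch in range(max_epoch):
        # gather some transitions from the environments into the buffer
        result = train_collector.collect(n_step=10)
        # train policy with a sampled batch of data
        losses = policy.learn(train_collector.sample(batch_size=64))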

.. rubric:: References

.. bibliography:: ../refs.bib
.. bibliography:: /refs.bib
   :style: unsrtalpha

docs/tutorials/tabular.rst (new file)
@@ -0,0 +1,11 @@
Tabular Q Learning Implementation
=================================

This tutorial shows how to use Tianshou to develop new algorithms.


Background
----------

TODO
@@ -80,5 +80,5 @@ With fast-speed sampling, we could use large batch-size and large learning rate

RL algorithms are seed-sensitive. Try more seeds and pick the best. But for our demo we just used seed = 0 and found that it works surprisingly well on policy gradient, so we did not try other seeds.

.. image:: ../_static/images/testpg.gif
.. image:: /_static/images/testpg.gif
   :align: center
@@ -1,7 +1,6 @@
import time
import pytest
import numpy as np
from tianshou.env import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv
from tianshou.env import VectorEnv, SubprocVectorEnv, RayVectorEnv

if __name__ == '__main__':
    from env import MyTestEnv
@@ -9,32 +8,6 @@ else: # pytest
    from test.base.env import MyTestEnv


def test_framestack(k=4, size=10):
    env = MyTestEnv(size=size)
    fsenv = FrameStack(env, k)
    fsenv.seed()
    obs = fsenv.reset()
    assert abs(obs - np.array([0, 0, 0, 0])).sum() == 0
    for i in range(5):
        obs, rew, done, info = fsenv.step(1)
    assert abs(obs - np.array([2, 3, 4, 5])).sum() == 0
    for i in range(10):
        obs, rew, done, info = fsenv.step(0)
    assert abs(obs - np.array([0, 0, 0, 0])).sum() == 0
    for i in range(9):
        obs, rew, done, info = fsenv.step(1)
    assert abs(obs - np.array([6, 7, 8, 9])).sum() == 0
    assert (rew, done) == (0, False)
    obs, rew, done, info = fsenv.step(1)
    assert abs(obs - np.array([7, 8, 9, 10])).sum() == 0
    assert (rew, done) == (1, True)
    with pytest.raises(ValueError):
        obs, rew, done, info = fsenv.step(0)
        # assert abs(obs - np.array([8, 9, 10, 10])).sum() == 0
        # assert (rew, done) == (0, True)
    fsenv.close()


def test_vecenv(size=10, num=8, sleep=0.001):
    verbose = __name__ == '__main__'
    env_fns = [
@@ -86,5 +59,4 @@ def test_vecenv(size=10, num=8, sleep=0.001):


if __name__ == '__main__':
    test_framestack()
    test_vecenv()

@@ -67,9 +67,7 @@ class Batch(object):
        self.__dict__.update(kwargs)

    def __getitem__(self, index):
        """
        Return self[index].
        """
        """Return self[index]."""
        b = Batch()
        for k in self.__dict__.keys():
            if self.__dict__[k] is not None:
@@ -77,9 +75,7 @@ class Batch(object):
        return b

    def append(self, batch):
        """
        Append a :class:`~tianshou.data.Batch` object to the end.
        """
        """Append a :class:`~tianshou.data.Batch` object to the end."""
        assert isinstance(batch, Batch), 'Only append Batch is allowed!'
        for k in batch.__dict__.keys():
            if batch.__dict__[k] is None:
@@ -101,9 +97,7 @@ class Batch(object):
        raise TypeError(s)

    def __len__(self):
        """
        Return len(self).
        """
        """Return len(self)."""
        return min([
            len(self.__dict__[k]) for k in self.__dict__.keys()
            if self.__dict__[k] is not None])
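A small usage sketch of ``Batch``, consistent with the methods in the hunks above (illustrative only; the ``obs``/``act`` keys are just example keyword arguments, not required by the class):

    # not part of the diff -- illustrates construction, append, len and indexing
    import numpy as np
    from tianshou.data import Batch

    b1 = Batch(obs=np.array([0, 1]), act=np.array([0, 1]))
    b2 = Batch(obs=np.array([2]), act=np.array([1]))
    b1.append(b2)   # b1.obs now holds [0, 1, 2]
    len(b1)         # 3, the minimum length over the stored arrays
    b1[0]           # a new Batch holding element 0 of every stored array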

@@ -43,14 +43,8 @@ class ReplayBuffer(object):
        self._maxsize = size
        self.reset()

    def __del__(self):
        for k in list(self.__dict__.keys()):
            del self.__dict__[k]

    def __len__(self):
        """
        Return len(self).
        """
        """Return len(self)."""
        return self._size

    def _add_to_buffer(self, name, inst):
@@ -70,9 +64,7 @@ class ReplayBuffer(object):
        self.__dict__[name][self._index] = inst

    def update(self, buffer):
        """
        Move the data from the given buffer to self.
        """
        """Move the data from the given buffer to self."""
        i = begin = buffer._index % len(buffer)
        while True:
            self.add(
@@ -83,9 +75,7 @@ class ReplayBuffer(object):
            break

    def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
        '''
        Add a batch of data into replay buffer.
        '''
        """Add a batch of data into replay buffer."""
        assert isinstance(info, dict), \
            'You should return a dict in the last argument of env.step().'
        self._add_to_buffer('obs', obs)
@@ -101,9 +91,7 @@ class ReplayBuffer(object):
        self._size = self._index = self._index + 1

    def reset(self):
        """
        Clear all the data in replay buffer.
        """
        """Clear all the data in replay buffer."""
        self._index = self._size = 0
        self.indice = []

@@ -123,9 +111,7 @@ class ReplayBuffer(object):
        return self[indice], indice

    def __getitem__(self, index):
        """
        Return a data batch: self[index].
        """
        """Return a data batch: self[index]."""
        return Batch(
            obs=self.obs[index],
            act=self.act[index],
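A brief usage sketch of ``ReplayBuffer``, based on the methods shown above (illustrative only; the ``sample`` call and its ``batch_size`` argument are assumed from the ``return self[indice], indice`` hunk rather than shown in full here):

    # not part of the diff -- illustrates add, len, indexing and (assumed) sample
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=100)
    buf.add(obs=0, act=0, rew=0, done=0, obs_next=1)
    buf.add(obs=1, act=1, rew=1, done=1, obs_next=0)
    len(buf)                                    # 2
    buf[0]                                      # a Batch with obs, act, rew, ... for index 0
    batch, indice = buf.sample(batch_size=1)    # assumed signature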

tianshou/env/__init__.py (5 changed lines)
@@ -1,14 +1,9 @@
from tianshou.env.utils import CloudpickleWrapper
from tianshou.env.common import EnvWrapper, FrameStack
from tianshou.env.vecenv import BaseVectorEnv, VectorEnv, \
    SubprocVectorEnv, RayVectorEnv

__all__ = [
    'EnvWrapper',
    'FrameStack',
    'BaseVectorEnv',
    'VectorEnv',
    'SubprocVectorEnv',
    'RayVectorEnv',
    'CloudpickleWrapper',
]

tianshou/env/common.py (49 lines removed)
@@ -1,49 +0,0 @@
import numpy as np
from collections import deque


class EnvWrapper(object):
    def __init__(self, env):
        self.env = env

    def step(self, action):
        return self.env.step(action)

    def reset(self):
        return self.env.reset()

    def seed(self, seed=None):
        if hasattr(self.env, 'seed'):
            return self.env.seed(seed)

    def render(self, **kwargs):
        if hasattr(self.env, 'render'):
            return self.env.render(**kwargs)

    def close(self):
        self.env.close()


class FrameStack(EnvWrapper):
    def __init__(self, env, stack_num):
        """Stack last k frames."""
        super().__init__(env)
        self.stack_num = stack_num
        self._frames = deque([], maxlen=stack_num)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._frames.append(obs)
        return self._get_obs(), reward, done, info

    def reset(self):
        obs = self.env.reset()
        for _ in range(self.stack_num):
            self._frames.append(obs)
        return self._get_obs()

    def _get_obs(self):
        try:
            return np.concatenate(self._frames, axis=-1)
        except ValueError:
            return np.stack(self._frames, axis=-1)

tianshou/env/utils.py (2 lines added)
@@ -2,6 +2,8 @@ import cloudpickle


class CloudpickleWrapper(object):
    """A cloudpickle wrapper used in :class:`~tianshou.env.SubprocVectorEnv`"""

    def __init__(self, data):
        self.data = data
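A short aside on why this wrapper exists (a sketch under assumptions, not part of the diff): the environment factories handed to ``SubprocVectorEnv`` are often lambdas or closures, which the standard ``pickle`` module cannot serialize but ``cloudpickle`` can, so they are wrapped before being sent to the worker processes. For illustration, with ``gym.make('CartPole-v0')`` as an assumed example factory:

    import cloudpickle
    import gym

    env_fn = lambda: gym.make('CartPole-v0')    # a closure that plain pickle would reject
    payload = cloudpickle.dumps(env_fn)         # cloudpickle serializes it fine
    rebuilt_env = cloudpickle.loads(payload)()  # a worker can rebuild and call the factory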

tianshou/env/vecenv.py (87 changed lines)
@@ -1,3 +1,4 @@
import gym
import numpy as np
from abc import ABC, abstractmethod
from multiprocessing import Process, Pipe
@@ -7,40 +8,98 @@ try:
except ImportError:
    pass

from tianshou.env import EnvWrapper, CloudpickleWrapper
from tianshou.env.utils import CloudpickleWrapper


class BaseVectorEnv(ABC):
class BaseVectorEnv(ABC, gym.Wrapper):
    """
    Base class for vectorized environment wrappers. Usage:
    ::

        env_num = 8
        envs = VectorEnv([lambda: gym.make(task) for _ in range(env_num)])

    It accepts a list of environment generators. In other words, an environment
    generator ``efn`` of a specific task means that ``efn()`` returns the
    environment of the given task, for example, ``gym.make(task)``.

    All of the VectorEnv classes must inherit from :class:`~tianshou.env.BaseVectorEnv`.
    Here are some other usages:
    ::

        envs.seed(2)  # which is equal to the next line
        envs.seed([2, 3, 4, 5, 6, 7, 8, 9])  # set a specific seed for each env
        obs = envs.reset()  # reset all environments
        obs = envs.reset([0, 5, 7])  # reset 3 specific environments
        obs, rew, done, info = envs.step([1] * 8)  # step synchronously
        envs.render()  # render all environments
        envs.close()  # close all environments
    """

    def __init__(self, env_fns):
        self._env_fns = env_fns
        self.env_num = len(env_fns)

    def __len__(self):
        """Return len(self), which is the number of environments."""
        return self.env_num

    @abstractmethod
    def reset(self):
    def reset(self, id=None):
        """
        Reset the state of all the environments and return the initial
        observations if id is ``None``; otherwise reset the specific
        environments with the given id, either an int or a list.
        """
        pass

    @abstractmethod
    def step(self, action):
        """
        Run one timestep of all the environments' dynamics. When the end of an
        episode is reached, you are responsible for calling reset(id) to reset
        this environment's state.

        Accepts a batch of actions and returns a tuple (obs, rew, done, info).

        :args:
            action (numpy.ndarray): a batch of actions provided by the agent

        :return:
            * obs (numpy.ndarray): agent's observation of the current environments
            * rew (numpy.ndarray): amount of rewards returned after the previous \
                actions
            * done (numpy.ndarray): whether these episodes have ended, in \
                which case further step() calls will return undefined results
            * info (numpy.ndarray): contains auxiliary diagnostic information \
                (helpful for debugging, and sometimes learning)
        """
        pass

    @abstractmethod
    def seed(self, seed=None):
        """
        Set the seed for all environments. Accept ``None``, an int (which will
        extend ``i`` to ``[i, i + 1, i + 2, ...]``) or a list.
        """
        pass
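As an aside, the seed rule in the docstring above amounts to the following expansion (a standalone sketch with a hypothetical helper name, not part of the class):

    import numpy as np

    # illustration only: an int seed i becomes [i, i + 1, ..., i + env_num - 1]
    def _expand_seed(seed, env_num):
        if seed is None:
            return [None] * env_num
        if np.isscalar(seed):
            return [int(seed) + i for i in range(env_num)]
        return list(seed)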

    @abstractmethod
    def render(self, **kwargs):
        """Renders the environment."""
        pass

    @abstractmethod
    def close(self):
        """Close all of the environments."""
        pass


class VectorEnv(BaseVectorEnv):
    """docstring for VectorEnv"""
    """
    Dummy vectorized environment wrapper, implemented in a for-loop. The usage \
    is in :class:`~tianshou.env.BaseVectorEnv`.
    """

    def __init__(self, env_fns):
        super().__init__(env_fns)
@@ -85,8 +144,7 @@ class VectorEnv(BaseVectorEnv):
        return result

    def close(self):
        for e in self.envs:
            e.close()
        return [e.close() for e in self.envs]


def worker(parent, p, env_fn_wrapper):
@@ -100,6 +158,7 @@ def worker(parent, p, env_fn_wrapper):
        elif cmd == 'reset':
            p.send(env.reset())
        elif cmd == 'close':
            p.send(env.close())
            p.close()
            break
        elif cmd == 'render':
@@ -114,7 +173,10 @@ def worker(parent, p, env_fn_wrapper):


class SubprocVectorEnv(BaseVectorEnv):
    """docstring for SubProcVectorEnv"""
    """
    Vectorized environment wrapper based on subprocess. The usage is in \
    :class:`~tianshou.env.BaseVectorEnv`.
    """

    def __init__(self, env_fns):
        super().__init__(env_fns)
@@ -178,13 +240,20 @@ class SubprocVectorEnv(BaseVectorEnv):
            return
        for p in self.parent_remote:
            p.send(['close', None])
        result = [p.recv() for p in self.parent_remote]
        self.closed = True
        for p in self.processes:
            p.join()
        return result


class RayVectorEnv(BaseVectorEnv):
    """docstring for RayVectorEnv"""
    """
    Vectorized environment wrapper based on \
    `ray <https://github.com/ray-project/ray>`_. However, according to our \
    test, it is slower than :class:`~tianshou.env.SubprocVectorEnv`. The usage \
    is in :class:`~tianshou.env.BaseVectorEnv`.
    """

    def __init__(self, env_fns):
        super().__init__(env_fns)
@@ -195,7 +264,7 @@ class RayVectorEnv(BaseVectorEnv):
            raise ImportError(
                'Please install ray to support RayVectorEnv: pip3 install ray')
        self.envs = [
            ray.remote(EnvWrapper).options(num_cpus=0).remote(e())
            ray.remote(gym.Wrapper).options(num_cpus=0).remote(e())
            for e in env_fns]

    def step(self, action):
@@ -3,9 +3,7 @@ import numpy as np


def test_episode(policy, collector, test_fn, epoch, n_episode):
    """
    A simple wrapper of testing policy in collector.
    """
    """A simple wrapper of testing policy in collector."""
    collector.reset_env()
    collector.reset_buffer()
    policy.eval()

@@ -43,23 +43,17 @@ class MovAvg(object):
        return self.get()

    def get(self):
        """
        Get the average.
        """
        """Get the average."""
        if len(self.cache) == 0:
            return 0
        return np.mean(self.cache)

    def mean(self):
        """
        Get the average. Same as :meth:`get`.
        """
        """Get the average. Same as :meth:`get`."""
        return self.get()

    def std(self):
        """
        Get the standard deviation.
        """
        """Get the standard deviation."""
        if len(self.cache) == 0:
            return 0
        return np.std(self.cache)
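A quick usage sketch of ``MovAvg`` (illustrative only; the ``size`` constructor argument, the ``add`` method, and the ``tianshou.utils`` import path are assumptions based on the surrounding hunk, which ends an ``add``-like method with ``return self.get()`` above):

    # not part of the diff -- MovAvg keeps a window of scalars and reports smoothed stats
    from tianshou.utils import MovAvg   # assumed import path

    stat = MovAvg(size=100)             # assumed constructor argument
    for r in [1.0, 2.0, 3.0]:
        stat.add(r)
    stat.get()   # 2.0, the running average
    stat.std()   # standard deviation over the cached values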