docs for env

parent 9380368ca3
commit b6c9db6b0b
@@ -49,7 +49,8 @@ If no error occurs, you have successfully installed Tianshou.

    tutorials/dqn
    tutorials/concepts
-   tutorials/trick.rst
+   tutorials/trick
+   tutorials/tabular

 .. toctree::
    :maxdepth: 1
@@ -3,7 +3,7 @@ Basic concepts in Tianshou

 Tianshou splits a Reinforcement Learning agent training procedure into these parts: trainer, collector, policy, and data buffer. The general control flow can be described as:

-.. image:: ../_static/images/concepts_arch.png
+.. image:: /_static/images/concepts_arch.png
    :align: center
    :height: 300

@@ -211,8 +211,9 @@ No problem! Tianshou supports user-defined training code. Here is the usage:
     # train policy with a sampled batch data
     losses = policy.learn(train_collector.sample(batch_size=64))

+For further usage, you can refer to :doc:`/tutorials/tabular`.

 .. rubric:: References

-.. bibliography:: ../refs.bib
+.. bibliography:: /refs.bib
    :style: unsrtalpha
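Since this hunk is about user-defined training code, here is a self-contained toy with the same loop shape (a collector-like loop fills a buffer, then a batch is sampled and passed to ``policy.learn``). The environment and policy below are illustrative stand-ins, not Tianshou classes:

    import random

    class ToyEnv:
        """1-D chain: action 1 moves right, reward 1 when position 5 is reached."""
        def reset(self):
            self.pos = 0
            return self.pos

        def step(self, action):
            self.pos = min(self.pos + action, 5)
            done = self.pos == 5
            return self.pos, float(done), done, {}

    class ToyPolicy:
        def __call__(self, obs):
            return random.choice([0, 1])

        def learn(self, batch):
            # a real policy would compute gradients from the sampled batch here
            return {'loss': sum(rew for _, _, rew, _, _ in batch) / len(batch)}

    env, policy, buffer = ToyEnv(), ToyPolicy(), []
    obs = env.reset()
    for step in range(1000):
        act = policy(obs)
        obs_next, rew, done, info = env.step(act)
        buffer.append((obs, act, rew, done, obs_next))   # collector + data buffer
        obs = env.reset() if done else obs_next
        if len(buffer) >= 64:
            # train policy with a sampled batch of data, mirroring the snippet above
            losses = policy.learn(random.sample(buffer, 64))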
docs/tutorials/tabular.rst (new file, 11 lines)

@@ -0,0 +1,11 @@
+Tabular Q Learning Implementation
+=================================
+
+This tutorial shows how to use Tianshou to develop new algorithms.
+
+
+Background
+----------
+
+TODO
+
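The body of this new tutorial is still TODO. As a placeholder illustration of the topic only, here is a generic tabular Q-learning sketch on a small chain environment (similar in spirit to the ``MyTestEnv`` used in the tests); it is not the content this file will eventually contain:

    import numpy as np

    class ChainEnv:
        """States 0..size; action 1 moves right, action 0 stays; reward 1 at the end."""
        def __init__(self, size=10):
            self.size = size

        def reset(self):
            self.pos = 0
            return self.pos

        def step(self, action):
            self.pos = min(self.pos + action, self.size)
            done = self.pos == self.size
            return self.pos, float(done), done, {}

    def q_learning(env, num_episodes=100, alpha=0.5, gamma=0.95, eps=0.3, seed=0):
        rng = np.random.default_rng(seed)
        q = np.zeros((env.size + 1, 2))            # Q-table: states x actions
        for _ in range(num_episodes):
            obs, done = env.reset(), False
            while not done:
                # epsilon-greedy action selection
                act = int(rng.integers(2)) if rng.random() < eps else int(q[obs].argmax())
                obs_next, rew, done, _ = env.step(act)
                # one-step temporal-difference update
                target = rew + (0.0 if done else gamma * q[obs_next].max())
                q[obs, act] += alpha * (target - q[obs, act])
                obs = obs_next
        return q

    q = q_learning(ChainEnv())
    print(q.argmax(axis=1))    # the learned greedy policy should mostly move right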
@@ -80,5 +80,5 @@ With fast-speed sampling, we could use large batch-size and large learning rate

 RL algorithms are seed-sensitive. Try more seeds and pick the best. But for our demo, we just used seed = 0 and found it work surprisingly well on policy gradient, so we did not try other seed.

-.. image:: ../_static/images/testpg.gif
+.. image:: /_static/images/testpg.gif
    :align: center
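A small sketch of the multi-seed advice quoted above; ``run_training`` is a stand-in for any training run that takes a seed and reports a final score, not a Tianshou function:

    import numpy as np

    def run_training(seed):
        # placeholder for a full training run; returns a noisy final score
        rng = np.random.default_rng(seed)
        return rng.normal(loc=100.0, scale=20.0)

    results = {seed: run_training(seed) for seed in range(10)}
    best_seed = max(results, key=results.get)
    print('best seed:', best_seed, 'score:', round(results[best_seed], 1))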
@@ -1,7 +1,6 @@
 import time
-import pytest
 import numpy as np
-from tianshou.env import FrameStack, VectorEnv, SubprocVectorEnv, RayVectorEnv
+from tianshou.env import VectorEnv, SubprocVectorEnv, RayVectorEnv

 if __name__ == '__main__':
     from env import MyTestEnv
@@ -9,32 +8,6 @@ else:  # pytest
     from test.base.env import MyTestEnv


-def test_framestack(k=4, size=10):
-    env = MyTestEnv(size=size)
-    fsenv = FrameStack(env, k)
-    fsenv.seed()
-    obs = fsenv.reset()
-    assert abs(obs - np.array([0, 0, 0, 0])).sum() == 0
-    for i in range(5):
-        obs, rew, done, info = fsenv.step(1)
-    assert abs(obs - np.array([2, 3, 4, 5])).sum() == 0
-    for i in range(10):
-        obs, rew, done, info = fsenv.step(0)
-    assert abs(obs - np.array([0, 0, 0, 0])).sum() == 0
-    for i in range(9):
-        obs, rew, done, info = fsenv.step(1)
-    assert abs(obs - np.array([6, 7, 8, 9])).sum() == 0
-    assert (rew, done) == (0, False)
-    obs, rew, done, info = fsenv.step(1)
-    assert abs(obs - np.array([7, 8, 9, 10])).sum() == 0
-    assert (rew, done) == (1, True)
-    with pytest.raises(ValueError):
-        obs, rew, done, info = fsenv.step(0)
-    # assert abs(obs - np.array([8, 9, 10, 10])).sum() == 0
-    # assert (rew, done) == (0, True)
-    fsenv.close()
-
-
 def test_vecenv(size=10, num=8, sleep=0.001):
     verbose = __name__ == '__main__'
     env_fns = [
@@ -86,5 +59,4 @@ def test_vecenv(size=10, num=8, sleep=0.001):


 if __name__ == '__main__':
-    test_framestack()
     test_vecenv()
@@ -67,9 +67,7 @@ class Batch(object):
         self.__dict__.update(kwargs)

     def __getitem__(self, index):
-        """
-        Return self[index].
-        """
+        """Return self[index]."""
         b = Batch()
         for k in self.__dict__.keys():
             if self.__dict__[k] is not None:
@@ -77,9 +75,7 @@ class Batch(object):
         return b

     def append(self, batch):
-        """
-        Append a :class:`~tianshou.data.Batch` object to the end.
-        """
+        """Append a :class:`~tianshou.data.Batch` object to the end."""
         assert isinstance(batch, Batch), 'Only append Batch is allowed!'
         for k in batch.__dict__.keys():
             if batch.__dict__[k] is None:
@@ -101,9 +97,7 @@ class Batch(object):
             raise TypeError(s)

     def __len__(self):
-        """
-        Return len(self).
-        """
+        """Return len(self)."""
         return min([
             len(self.__dict__[k]) for k in self.__dict__.keys()
             if self.__dict__[k] is not None])
@@ -43,14 +43,8 @@ class ReplayBuffer(object):
         self._maxsize = size
         self.reset()

-    def __del__(self):
-        for k in list(self.__dict__.keys()):
-            del self.__dict__[k]
-
     def __len__(self):
-        """
-        Return len(self).
-        """
+        """Return len(self)."""
         return self._size

     def _add_to_buffer(self, name, inst):
@@ -70,9 +64,7 @@ class ReplayBuffer(object):
         self.__dict__[name][self._index] = inst

     def update(self, buffer):
-        """
-        Move the data from the given buffer to self.
-        """
+        """Move the data from the given buffer to self."""
         i = begin = buffer._index % len(buffer)
         while True:
             self.add(
@@ -83,9 +75,7 @@ class ReplayBuffer(object):
                 break

     def add(self, obs, act, rew, done, obs_next=0, info={}, weight=None):
-        '''
-        Add a batch of data into replay buffer.
-        '''
+        """Add a batch of data into replay buffer."""
         assert isinstance(info, dict), \
             'You should return a dict in the last argument of env.step().'
         self._add_to_buffer('obs', obs)
@@ -101,9 +91,7 @@ class ReplayBuffer(object):
             self._size = self._index = self._index + 1

     def reset(self):
-        """
-        Clear all the data in replay buffer.
-        """
+        """Clear all the data in replay buffer."""
         self._index = self._size = 0
         self.indice = []

@@ -123,9 +111,7 @@ class ReplayBuffer(object):
         return self[indice], indice

     def __getitem__(self, index):
-        """
-        Return a data batch: self[index].
-        """
+        """Return a data batch: self[index]."""
         return Batch(
             obs=self.obs[index],
             act=self.act[index],
tianshou/env/__init__.py (5 changes)

@@ -1,14 +1,9 @@
-from tianshou.env.utils import CloudpickleWrapper
-from tianshou.env.common import EnvWrapper, FrameStack
 from tianshou.env.vecenv import BaseVectorEnv, VectorEnv, \
     SubprocVectorEnv, RayVectorEnv

 __all__ = [
-    'EnvWrapper',
-    'FrameStack',
     'BaseVectorEnv',
     'VectorEnv',
     'SubprocVectorEnv',
     'RayVectorEnv',
-    'CloudpickleWrapper',
 ]
tianshou/env/common.py (deleted, 49 lines)

@@ -1,49 +0,0 @@
-import numpy as np
-from collections import deque
-
-
-class EnvWrapper(object):
-    def __init__(self, env):
-        self.env = env
-
-    def step(self, action):
-        return self.env.step(action)
-
-    def reset(self):
-        return self.env.reset()
-
-    def seed(self, seed=None):
-        if hasattr(self.env, 'seed'):
-            return self.env.seed(seed)
-
-    def render(self, **kwargs):
-        if hasattr(self.env, 'render'):
-            return self.env.render(**kwargs)
-
-    def close(self):
-        self.env.close()
-
-
-class FrameStack(EnvWrapper):
-    def __init__(self, env, stack_num):
-        """Stack last k frames."""
-        super().__init__(env)
-        self.stack_num = stack_num
-        self._frames = deque([], maxlen=stack_num)
-
-    def step(self, action):
-        obs, reward, done, info = self.env.step(action)
-        self._frames.append(obs)
-        return self._get_obs(), reward, done, info
-
-    def reset(self):
-        obs = self.env.reset()
-        for _ in range(self.stack_num):
-            self._frames.append(obs)
-        return self._get_obs()
-
-    def _get_obs(self):
-        try:
-            return np.concatenate(self._frames, axis=-1)
-        except ValueError:
-            return np.stack(self._frames, axis=-1)
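This commit removes ``EnvWrapper`` and ``FrameStack`` from ``tianshou.env``. If frame stacking is still needed, an equivalent wrapper can be built directly on ``gym.Wrapper``; a sketch assuming the classic gym ``reset()``/``step()`` API, adapted from the deleted code above:

    from collections import deque

    import gym
    import numpy as np

    class FrameStackWrapper(gym.Wrapper):
        """Stack the last ``stack_num`` observations along the last axis."""
        def __init__(self, env, stack_num):
            super().__init__(env)
            self.stack_num = stack_num
            self._frames = deque([], maxlen=stack_num)

        def reset(self, **kwargs):
            obs = self.env.reset(**kwargs)
            for _ in range(self.stack_num):
                self._frames.append(obs)
            return self._get_obs()

        def step(self, action):
            obs, reward, done, info = self.env.step(action)
            self._frames.append(obs)
            return self._get_obs(), reward, done, info

        def _get_obs(self):
            # concatenate along the last axis; fall back for scalar observations
            try:
                return np.concatenate(self._frames, axis=-1)
            except ValueError:
                return np.stack(self._frames, axis=-1)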
tianshou/env/utils.py (2 changes)

@@ -2,6 +2,8 @@ import cloudpickle


 class CloudpickleWrapper(object):
+    """A cloudpickle wrapper used in :class:`~tianshou.env.SubprocVectorEnv`"""
+
     def __init__(self, data):
         self.data = data

tianshou/env/vecenv.py (87 changes)

@@ -1,3 +1,4 @@
+import gym
 import numpy as np
 from abc import ABC, abstractmethod
 from multiprocessing import Process, Pipe
@@ -7,40 +8,98 @@ try:
 except ImportError:
     pass

-from tianshou.env import EnvWrapper, CloudpickleWrapper
+from tianshou.env.utils import CloudpickleWrapper


-class BaseVectorEnv(ABC):
+class BaseVectorEnv(ABC, gym.Wrapper):
+    """
+    Base class for vectorized environment wrappers. Usage:
+    ::
+
+        env_num = 8
+        envs = VectorEnv([lambda: gym.make(task) for _ in range(env_num)])
+
+    It accepts a list of environment generators. In other words, an environment
+    generator ``efn`` of a specific task means that ``efn()`` returns the
+    environment of the given task, for example, ``gym.make(task)``.
+
+    All of the VectorEnv classes must inherit :class:`~tianshou.env.BaseVectorEnv`.
+    Here are some other usages:
+    ::
+
+        envs.seed(2)  # which is equal to the next line
+        envs.seed([2, 3, 4, 5, 6, 7, 8, 9])  # set specific seed for each env
+        obs = envs.reset()  # reset all environments
+        obs = envs.reset([0, 5, 7])  # reset 3 specific environments
+        obs, rew, done, info = envs.step([1] * 8)  # step synchronously
+        envs.render()  # render all environments
+        envs.close()  # close all environments
+    """
+
     def __init__(self, env_fns):
         self._env_fns = env_fns
         self.env_num = len(env_fns)

     def __len__(self):
+        """Return len(self), which is the number of environments."""
         return self.env_num

     @abstractmethod
-    def reset(self):
+    def reset(self, id=None):
+        """
+        Reset the state of all the environments and return initial
+        observations if id is ``None``, otherwise reset the specific
+        environments with the given id, either an int or a list.
+        """
         pass

     @abstractmethod
     def step(self, action):
+        """
+        Run one timestep of all the environments' dynamics. When the end of an
+        episode is reached, you are responsible for calling reset(id) to reset
+        this environment's state.
+
+        Accepts a batch of actions and returns a tuple (obs, rew, done, info).
+
+        :args:
+            action (numpy.ndarray): a batch of actions provided by the agent
+
+        :return:
+            * obs (numpy.ndarray): agent's observation of current environments
+            * rew (numpy.ndarray): amount of rewards returned after previous \
+                actions
+            * done (numpy.ndarray): whether these episodes have ended, in \
+                which case further step() calls will return undefined results
+            * info (numpy.ndarray): contains auxiliary diagnostic information \
+                (helpful for debugging, and sometimes learning)
+        """
         pass

     @abstractmethod
     def seed(self, seed=None):
+        """
+        Set the seed for all environments. Accepts ``None``, an int (which will
+        extend ``i`` to ``[i, i + 1, i + 2, ...]``) or a list.
+        """
         pass

     @abstractmethod
     def render(self, **kwargs):
+        """Render all the environments."""
         pass

     @abstractmethod
     def close(self):
+        """Close all of the environments."""
         pass


 class VectorEnv(BaseVectorEnv):
-    """docstring for VectorEnv"""
+    """
+    Dummy vectorized environment wrapper, implemented in a for-loop. The usage \
+    is in :class:`~tianshou.env.BaseVectorEnv`.
+    """

     def __init__(self, env_fns):
         super().__init__(env_fns)
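A tiny illustration of the seeding convention documented in ``BaseVectorEnv.seed`` above: an int ``i`` is broadcast to ``[i, i + 1, i + 2, ...]`` across the environments, while a list is used as-is. This is an illustrative helper, not the actual implementation:

    def normalize_seed(seed, env_num):
        if seed is None:
            return [None] * env_num
        if isinstance(seed, int):
            return [seed + i for i in range(env_num)]
        return list(seed)            # assume one explicit seed per environment

    assert normalize_seed(2, 8) == [2, 3, 4, 5, 6, 7, 8, 9]
    assert normalize_seed([7, 8, 9], 3) == [7, 8, 9]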
@@ -85,8 +144,7 @@ class VectorEnv(BaseVectorEnv):
         return result

     def close(self):
-        for e in self.envs:
-            e.close()
+        return [e.close() for e in self.envs]


 def worker(parent, p, env_fn_wrapper):
@@ -100,6 +158,7 @@ def worker(parent, p, env_fn_wrapper):
         elif cmd == 'reset':
             p.send(env.reset())
         elif cmd == 'close':
+            p.send(env.close())
             p.close()
             break
         elif cmd == 'render':
@@ -114,7 +173,10 @@ def worker(parent, p, env_fn_wrapper):


 class SubprocVectorEnv(BaseVectorEnv):
-    """docstring for SubProcVectorEnv"""
+    """
+    Vectorized environment wrapper based on subprocess. The usage is in \
+    :class:`~tianshou.env.BaseVectorEnv`.
+    """

     def __init__(self, env_fns):
         super().__init__(env_fns)
@@ -178,13 +240,20 @@ class SubprocVectorEnv(BaseVectorEnv):
             return
         for p in self.parent_remote:
             p.send(['close', None])
+        result = [p.recv() for p in self.parent_remote]
         self.closed = True
         for p in self.processes:
             p.join()
+        return result


 class RayVectorEnv(BaseVectorEnv):
-    """docstring for RayVectorEnv"""
+    """
+    Vectorized environment wrapper based on \
+    `ray <https://github.com/ray-project/ray>`_. However, according to our \
+    test, it is slower than :class:`~tianshou.env.SubprocVectorEnv`. The usage \
+    is in :class:`~tianshou.env.BaseVectorEnv`.
+    """

     def __init__(self, env_fns):
         super().__init__(env_fns)
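The worker and ``close()`` changes above make the subprocess reply to the ``'close'`` command before shutting down, so the parent can collect the return values over the pipes. A toy version of that pipe protocol (not Tianshou's actual worker):

    from multiprocessing import Process, Pipe

    def worker(parent, p):
        parent.close()                      # the child only uses its own end
        while True:
            cmd, data = p.recv()
            if cmd == 'step':
                p.send(data + 1)            # pretend to step an env
            elif cmd == 'close':
                p.send('closed')            # reply first, like p.send(env.close())
                p.close()
                break

    if __name__ == '__main__':
        parent_remote, child_remote = Pipe()
        proc = Process(target=worker, args=(parent_remote, child_remote), daemon=True)
        proc.start()
        child_remote.close()                # the parent only uses its own end
        parent_remote.send(['step', 41])
        print(parent_remote.recv())         # 42
        parent_remote.send(['close', None])
        print(parent_remote.recv())         # 'closed'
        proc.join()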
@@ -195,7 +264,7 @@ class RayVectorEnv(BaseVectorEnv):
             raise ImportError(
                 'Please install ray to support RayVectorEnv: pip3 install ray')
         self.envs = [
-            ray.remote(EnvWrapper).options(num_cpus=0).remote(e())
+            ray.remote(gym.Wrapper).options(num_cpus=0).remote(e())
             for e in env_fns]

     def step(self, action):
@@ -3,9 +3,7 @@ import numpy as np


 def test_episode(policy, collector, test_fn, epoch, n_episode):
-    """
-    A simple wrapper of testing policy in collector.
-    """
+    """A simple wrapper of testing policy in collector."""
     collector.reset_env()
     collector.reset_buffer()
     policy.eval()
@@ -43,23 +43,17 @@ class MovAvg(object):
         return self.get()

     def get(self):
-        """
-        Get the average.
-        """
+        """Get the average."""
         if len(self.cache) == 0:
             return 0
         return np.mean(self.cache)

     def mean(self):
-        """
-        Get the average. Same as :meth:`get`.
-        """
+        """Get the average. Same as :meth:`get`."""
         return self.get()

     def std(self):
-        """
-        Get the standard deviation.
-        """
+        """Get the standard deviation."""
         if len(self.cache) == 0:
             return 0
         return np.std(self.cache)