DQN Atari examples (#187)
This PR aims to provide the script of Atari DQN setting: - A speedrun of PongNoFrameskip-v4 (finished, about half an hour in i7-8750 + GTX1060 with 1M environment steps) - A general script for all atari game Since we use multiple env for simulation, the result is slightly different from the original paper, but consider to be acceptable. It also adds another parameter save_only_last_obs for replay buffer in order to save the memory. Co-authored-by: Trinkle23897 <463003665@qq.com>
@ -129,7 +129,7 @@ We select some of famous reinforcement learning platforms: 2 GitHub repos with m
|
||||
|
||||
All of the platforms use 5 different seeds for testing. We erase those trials which failed for training. The reward threshold is 195.0 in CartPole and -250.0 in Pendulum over consecutive 100 episodes' mean returns (except for PyTorch-DRL).
|
||||
|
||||
We will add results of Atari Pong / Mujoco these days.
|
||||
The Atari/Mujoco benchmark results are under [examples/atari/](examples/atari/) and [examples/mujoco/](examples/mujoco/) folders.
|
||||
|
||||
### Reproducible
|
||||
|
||||
|
@ -73,6 +73,12 @@ Tianshou has many short-but-efficient lines of code. For example, when we want t
|
||||
.. Jiayi: I write each line of code after quite a lot of time of consideration. Details make a difference.
|
||||
|
||||
|
||||
Atari/Mujoco Task Specific
|
||||
--------------------------
|
||||
|
||||
Please refer to `Atari examples page <https://github.com/thu-ml/tianshou/tree/master/examples/atari>`_ and `Mujoco examples page <https://github.com/thu-ml/tianshou/tree/master/examples/mujoco>`_.
|
||||
|
||||
|
||||
Finally
|
||||
-------
|
||||
|
||||
|
25
examples/atari/README.md
Normal file
@ -0,0 +1,25 @@
|
||||
# Atari General
|
||||
|
||||
The sample speed is \~3000 env step per second (\~12000 Atari frame per second in fact since we use frame_stack=4) under the normal mode (use a CNN policy and a collector, also storing data into the buffer). The main bottleneck is training the convolutional neural network.
|
||||
|
||||
The Atari env seed cannot be fixed due to the discussion [here](https://github.com/openai/gym/issues/1478), but it is not a big issue since on Atari it will always have the similar results.
|
||||
|
||||
The env wrapper is a crucial thing. Without wrappers, the agent cannot perform well enough on Atari games. Many existing RL codebases use [OpenAI wrapper](https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py), but it is not the original DeepMind version ([related issue](https://github.com/openai/baselines/issues/240)). Dopamine has a different [wrapper](https://github.com/google/dopamine/blob/master/dopamine/discrete_domains/atari_lib.py) but unfortunately it cannot work very well in our codebase.
|
||||
|
||||
# DQN (single run)
|
||||
|
||||
One epoch here is equal to 100,000 env step, 100 epochs stand for 10M.
|
||||
|
||||
| task | best reward | reward curve | parameters | time cost |
|
||||
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ | ------------------- |
|
||||
| PongNoFrameskip-v4 | 20 |  | `python3 atari_dqn.py --task "PongNoFrameskip-v4" --batch_size 64` | ~30 min (~15 epoch) |
|
||||
| BreakoutNoFrameskip-v4 | 316 |  | `python3 atari_dqn.py --task "BreakoutNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
|
||||
| EnduroNoFrameskip-v4 | 670 |  | `python3 atari_dqn.py --task "EnduroNoFrameskip-v4 " --test_num 100` | 3~4h (100 epoch) |
|
||||
| QbertNoFrameskip-v4 | 7307 |  | `python3 atari_dqn.py --task "QbertNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
|
||||
| MsPacmanNoFrameskip-v4 | 2107 |  | `python3 atari_dqn.py --task "MsPacmanNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
|
||||
| SeaquestNoFrameskip-v4 | 2088 |  | `python3 atari_dqn.py --task "SeaquestNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
|
||||
| SpaceInvadersNoFrameskip-v4 | 812.2 |  | `python3 atari_dqn.py --task "SpaceInvadersNoFrameskip-v4" --test_num 100` | 3~4h (100 epoch) |
|
||||
|
||||
Note: The eps_train_final and eps_test in the original DQN paper is 0.1 and 0.01, but [some works](https://github.com/google/dopamine/tree/master/baselines) found that smaller eps helps improve the performance. Also, a large batchsize (say 64 instead of 32) will help faster convergence but will slow down the training speed.
|
||||
|
||||
We haven't tuned this result to the best, so have fun with playing these hyperparameters!
|
147
examples/atari/atari_dqn.py
Normal file
@ -0,0 +1,147 @@
|
||||
import os
|
||||
import torch
|
||||
import pprint
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import DQNPolicy
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.utils.net.discrete import DQN
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, ReplayBuffer
|
||||
|
||||
from atari_wrapper import wrap_deepmind
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--eps_test', type=float, default=0.005)
|
||||
parser.add_argument('--eps_train', type=float, default=1.)
|
||||
parser.add_argument('--eps_train_final', type=float, default=0.05)
|
||||
parser.add_argument('--buffer-size', type=int, default=100000)
|
||||
parser.add_argument('--lr', type=float, default=0.0001)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--n_step', type=int, default=3)
|
||||
parser.add_argument('--target_update_freq', type=int, default=500)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--step_per_epoch', type=int, default=10000)
|
||||
parser.add_argument('--collect_per_step', type=int, default=10)
|
||||
parser.add_argument('--batch_size', type=int, default=32)
|
||||
parser.add_argument('--training_num', type=int, default=16)
|
||||
parser.add_argument('--test_num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
parser.add_argument('--frames_stack', type=int, default=4)
|
||||
parser.add_argument('--resume_path', type=str, default=None)
|
||||
parser.add_argument('--watch', default=False, action='store_true',
|
||||
help='watch the play of pre-trained policy only')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def make_atari_env(args):
|
||||
return wrap_deepmind(args.task, frame_stack=args.frames_stack)
|
||||
|
||||
|
||||
def make_atari_env_watch(args):
|
||||
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
|
||||
episode_life=False, clip_rewards=False)
|
||||
|
||||
|
||||
def test_dqn(args=get_args()):
|
||||
env = make_atari_env(args)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.env.action_space.shape or env.env.action_space.n
|
||||
# should be N_FRAMES x H x W
|
||||
print("Observations shape: ", args.state_shape)
|
||||
print("Actions shape: ", args.action_shape)
|
||||
# make environments
|
||||
train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
|
||||
for _ in range(args.training_num)])
|
||||
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
|
||||
for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# define model
|
||||
net = DQN(*args.state_shape,
|
||||
args.action_shape, args.device).to(args.device)
|
||||
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
|
||||
# define policy
|
||||
policy = DQNPolicy(net, optim, args.gamma, args.n_step,
|
||||
target_update_freq=args.target_update_freq)
|
||||
# load a previous policy
|
||||
if args.resume_path:
|
||||
policy.load_state_dict(torch.load(args.resume_path))
|
||||
print("Loaded agent from: ", args.resume_path)
|
||||
# replay buffer: `save_last_obs` and `stack_num` can be removed together
|
||||
# when you have enough RAM
|
||||
buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
|
||||
save_last_obs=True, stack_num=args.frames_stack)
|
||||
# collector
|
||||
train_collector = Collector(policy, train_envs, buffer)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'dqn')
|
||||
writer = SummaryWriter(log_path)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(x):
|
||||
if env.env.spec.reward_threshold:
|
||||
return x >= env.spec.reward_threshold
|
||||
elif 'Pong' in args.task:
|
||||
return x >= 20
|
||||
|
||||
def train_fn(x):
|
||||
# nature DQN setting, linear decay in the first 1M steps
|
||||
now = x * args.collect_per_step * args.step_per_epoch
|
||||
if now <= 1e6:
|
||||
eps = args.eps_train - now / 1e6 * \
|
||||
(args.eps_train - args.eps_train_final)
|
||||
policy.set_eps(eps)
|
||||
else:
|
||||
policy.set_eps(args.eps_train_final)
|
||||
print("set eps =", policy.eps)
|
||||
|
||||
def test_fn(x):
|
||||
policy.set_eps(args.eps_test)
|
||||
|
||||
# watch agent's performance
|
||||
def watch():
|
||||
print("Testing agent ...")
|
||||
policy.eval()
|
||||
policy.set_eps(args.eps_test)
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=[1] * args.test_num,
|
||||
render=args.render)
|
||||
pprint.pprint(result)
|
||||
|
||||
if args.watch:
|
||||
watch()
|
||||
exit(0)
|
||||
|
||||
# test train_collector and start filling replay buffer
|
||||
train_collector.collect(n_step=args.batch_size * 4)
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.collect_per_step, args.test_num,
|
||||
args.batch_size, train_fn=train_fn, test_fn=test_fn,
|
||||
stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False)
|
||||
|
||||
pprint.pprint(result)
|
||||
watch()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_dqn(get_args())
|
237
examples/atari/atari_wrapper.py
Normal file
@ -0,0 +1,237 @@
|
||||
# Borrow a lot from openai baselines:
|
||||
# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
|
||||
|
||||
import cv2
|
||||
import gym
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
|
||||
|
||||
class NoopResetEnv(gym.Wrapper):
|
||||
"""Sample initial states by taking random number of no-ops on reset.
|
||||
No-op is assumed to be action 0.
|
||||
|
||||
:param gym.Env env: the environment to wrap.
|
||||
:param int noop_max: the maximum value of no-ops to run.
|
||||
"""
|
||||
|
||||
def __init__(self, env, noop_max=30):
|
||||
super().__init__(env)
|
||||
self.noop_max = noop_max
|
||||
self.noop_action = 0
|
||||
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
|
||||
|
||||
def reset(self):
|
||||
self.env.reset()
|
||||
noops = np.random.randint(1, self.noop_max + 1)
|
||||
for _ in range(noops):
|
||||
obs, _, done, _ = self.env.step(self.noop_action)
|
||||
if done:
|
||||
obs = self.env.reset()
|
||||
return obs
|
||||
|
||||
|
||||
class MaxAndSkipEnv(gym.Wrapper):
|
||||
"""Return only every `skip`-th frame (frameskipping) using most recent raw
|
||||
observations (for max pooling across time steps)
|
||||
|
||||
:param gym.Env env: the environment to wrap.
|
||||
:param int skip: number of `skip`-th frame.
|
||||
"""
|
||||
|
||||
def __init__(self, env, skip=4):
|
||||
super().__init__(env)
|
||||
self._skip = skip
|
||||
|
||||
def step(self, action):
|
||||
"""Step the environment with the given action. Repeat action, sum
|
||||
reward, and max over last observations.
|
||||
"""
|
||||
obs_list, total_reward, done = [], 0., False
|
||||
for i in range(self._skip):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
obs_list.append(obs)
|
||||
total_reward += reward
|
||||
if done:
|
||||
break
|
||||
max_frame = np.max(obs_list[-2:], axis=0)
|
||||
return max_frame, total_reward, done, info
|
||||
|
||||
|
||||
class EpisodicLifeEnv(gym.Wrapper):
|
||||
"""Make end-of-life == end-of-episode, but only reset on true game over. It
|
||||
helps the value estimation.
|
||||
|
||||
:param gym.Env env: the environment to wrap.
|
||||
"""
|
||||
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
self.lives = 0
|
||||
self.was_real_done = True
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self.was_real_done = done
|
||||
# check current lives, make loss of life terminal, then update lives to
|
||||
# handle bonus lives
|
||||
lives = self.env.unwrapped.ale.lives()
|
||||
if 0 < lives < self.lives:
|
||||
# for Qbert sometimes we stay in lives == 0 condition for a few
|
||||
# frames, so its important to keep lives > 0, so that we only reset
|
||||
# once the environment is actually done.
|
||||
done = True
|
||||
self.lives = lives
|
||||
return obs, reward, done, info
|
||||
|
||||
def reset(self):
|
||||
"""Calls the Gym environment reset, only when lives are exhausted. This
|
||||
way all states are still reachable even though lives are episodic, and
|
||||
the learner need not know about any of this behind-the-scenes.
|
||||
"""
|
||||
if self.was_real_done:
|
||||
obs = self.env.reset()
|
||||
else:
|
||||
# no-op step to advance from terminal/lost life state
|
||||
obs = self.env.step(0)[0]
|
||||
self.lives = self.env.unwrapped.ale.lives()
|
||||
return obs
|
||||
|
||||
|
||||
class FireResetEnv(gym.Wrapper):
|
||||
"""Take action on reset for environments that are fixed until firing.
|
||||
Related discussion: https://github.com/openai/baselines/issues/240
|
||||
|
||||
:param gym.Env env: the environment to wrap.
|
||||
"""
|
||||
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
|
||||
assert len(env.unwrapped.get_action_meanings()) >= 3
|
||||
|
||||
def reset(self):
|
||||
self.env.reset()
|
||||
return self.env.step(1)[0]
|
||||
|
||||
|
||||
class WarpFrame(gym.ObservationWrapper):
|
||||
"""Warp frames to 84x84 as done in the Nature paper and later work.
|
||||
|
||||
:param gym.Env env: the environment to wrap.
|
||||
"""
|
||||
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
self.size = 84
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=np.min(env.observation_space.low),
|
||||
high=np.max(env.observation_space.high),
|
||||
shape=(self.size, self.size), dtype=env.observation_space.dtype)
|
||||
|
||||
def observation(self, frame):
|
||||
"""returns the current observation from a frame"""
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
||||
return cv2.resize(frame, (self.size, self.size),
|
||||
interpolation=cv2.INTER_AREA)
|
||||
|
||||
|
||||
class ScaledFloatFrame(gym.ObservationWrapper):
|
||||
"""Normalize observations to 0~1.
|
||||
|
||||
:param gym.Env env: the environment to wrap.
|
||||
"""
|
||||
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
low = np.min(env.observation_space.low)
|
||||
high = np.max(env.observation_space.high)
|
||||
self.bias = low
|
||||
self.scale = high - low
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=0., high=1., shape=env.observation_space.shape,
|
||||
dtype=np.float32)
|
||||
|
||||
def observation(self, observation):
|
||||
return (observation - self.bias) / self.scale
|
||||
|
||||
|
||||
class ClipRewardEnv(gym.RewardWrapper):
|
||||
"""clips the reward to {+1, 0, -1} by its sign.
|
||||
|
||||
:param gym.Env env: the environment to wrap.
|
||||
"""
|
||||
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
self.reward_range = (-1, 1)
|
||||
|
||||
def reward(self, reward):
|
||||
"""Bin reward to {+1, 0, -1} by its sign. Note: np.sign(0) == 0."""
|
||||
return np.sign(reward)
|
||||
|
||||
|
||||
class FrameStack(gym.Wrapper):
|
||||
"""Stack n_frames last frames.
|
||||
|
||||
:param gym.Env env: the environment to wrap.
|
||||
:param int n_frames: the number of frames to stack.
|
||||
"""
|
||||
|
||||
def __init__(self, env, n_frames):
|
||||
super().__init__(env)
|
||||
self.n_frames = n_frames
|
||||
self.frames = deque([], maxlen=n_frames)
|
||||
shape = (n_frames,) + env.observation_space.shape
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=np.min(env.observation_space.low),
|
||||
high=np.max(env.observation_space.high),
|
||||
shape=shape, dtype=env.observation_space.dtype)
|
||||
|
||||
def reset(self):
|
||||
obs = self.env.reset()
|
||||
for _ in range(self.n_frames):
|
||||
self.frames.append(obs)
|
||||
return self._get_ob()
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, done, info = self.env.step(action)
|
||||
self.frames.append(obs)
|
||||
return self._get_ob(), reward, done, info
|
||||
|
||||
def _get_ob(self):
|
||||
# the original wrapper use `LazyFrames` but since we use np buffer,
|
||||
# it has no effect
|
||||
return np.stack(self.frames, axis=0)
|
||||
|
||||
|
||||
def wrap_deepmind(env_id, episode_life=True, clip_rewards=True,
|
||||
frame_stack=4, scale=False, warp_frame=True):
|
||||
"""Configure environment for DeepMind-style Atari. The observation is
|
||||
channel-first: (c, h, w) instead of (h, w, c).
|
||||
|
||||
:param str env_id: the atari environment id.
|
||||
:param bool episode_life: wrap the episode life wrapper.
|
||||
:param bool clip_rewards: wrap the reward clipping wrapper.
|
||||
:param int frame_stack: wrap the frame stacking wrapper.
|
||||
:param bool scale: wrap the scaling observation wrapper.
|
||||
:param bool warp_frame: wrap the grayscale + resize observation wrapper.
|
||||
:return: the wrapped atari environment.
|
||||
"""
|
||||
assert 'NoFrameskip' in env_id
|
||||
env = gym.make(env_id)
|
||||
env = NoopResetEnv(env, noop_max=30)
|
||||
env = MaxAndSkipEnv(env, skip=4)
|
||||
if episode_life:
|
||||
env = EpisodicLifeEnv(env)
|
||||
if 'FIRE' in env.unwrapped.get_action_meanings():
|
||||
env = FireResetEnv(env)
|
||||
if warp_frame:
|
||||
env = WarpFrame(env)
|
||||
if scale:
|
||||
env = ScaledFloatFrame(env)
|
||||
if clip_rewards:
|
||||
env = ClipRewardEnv(env)
|
||||
if frame_stack:
|
||||
env = FrameStack(env, frame_stack)
|
||||
return env
|
BIN
examples/atari/results/dqn/Breakout_rew.png
Normal file
After Width: | Height: | Size: 70 KiB |
BIN
examples/atari/results/dqn/Enduro_rew.png
Normal file
After Width: | Height: | Size: 100 KiB |
BIN
examples/atari/results/dqn/MsPacman_rew.png
Normal file
After Width: | Height: | Size: 77 KiB |
BIN
examples/atari/results/dqn/Pong_rew.png
Normal file
After Width: | Height: | Size: 36 KiB |
BIN
examples/atari/results/dqn/Qbert_rew.png
Normal file
After Width: | Height: | Size: 72 KiB |
BIN
examples/atari/results/dqn/Seaquest_rew.png
Normal file
After Width: | Height: | Size: 85 KiB |
BIN
examples/atari/results/dqn/SpaceInvader_rew.png
Normal file
After Width: | Height: | Size: 51 KiB |
@ -91,11 +91,13 @@ def test_stack(size=5, bufsize=9, stack_num=4):
|
||||
env = MyTestEnv(size)
|
||||
buf = ReplayBuffer(bufsize, stack_num=stack_num)
|
||||
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
|
||||
buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
|
||||
obs = env.reset(1)
|
||||
for i in range(16):
|
||||
obs_next, rew, done, info = env.step(1)
|
||||
buf.add(obs, 1, rew, done, None, info)
|
||||
buf2.add(obs, 1, rew, done, None, info)
|
||||
buf3.add([None, None, obs], 1, rew, done, [None, obs], info)
|
||||
obs = obs_next
|
||||
if done:
|
||||
obs = env.reset(1)
|
||||
@ -104,6 +106,8 @@ def test_stack(size=5, bufsize=9, stack_num=4):
|
||||
[1, 1, 1, 2], [1, 1, 2, 3], [1, 2, 3, 4],
|
||||
[1, 1, 1, 1], [1, 1, 1, 2], [1, 1, 2, 3],
|
||||
[1, 2, 3, 4], [4, 4, 4, 4], [1, 1, 1, 1]])
|
||||
assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs'))
|
||||
assert np.allclose(buf.get(indice, 'obs'), buf3.get(indice, 'obs_next'))
|
||||
_, indice = buf2.sample(0)
|
||||
assert indice.tolist() == [2, 6]
|
||||
_, indice = buf2.sample(1)
|
||||
|
@ -115,6 +115,9 @@ class ReplayBuffer:
|
||||
than or equal to 1, defaults to 1 (no stacking).
|
||||
:param bool ignore_obs_next: whether to store obs_next, defaults to
|
||||
``False``.
|
||||
:param bool save_only_last_obs: only save the last obs/obs_next when it has
|
||||
a shape of (timestep, ...) because of temporal stacking, defaults to
|
||||
``False``.
|
||||
:param bool sample_avail: the parameter indicating sampling only available
|
||||
index when using frame-stack sampling method, defaults to ``False``.
|
||||
This feature is not supported in Prioritized Replay Buffer currently.
|
||||
@ -122,6 +125,7 @@ class ReplayBuffer:
|
||||
|
||||
def __init__(self, size: int, stack_num: int = 1,
|
||||
ignore_obs_next: bool = False,
|
||||
save_only_last_obs: bool = False,
|
||||
sample_avail: bool = False) -> None:
|
||||
super().__init__()
|
||||
self._maxsize = size
|
||||
@ -131,6 +135,7 @@ class ReplayBuffer:
|
||||
self._avail = sample_avail and stack_num > 1
|
||||
self._avail_index = []
|
||||
self._save_s_ = not ignore_obs_next
|
||||
self._last_obs = save_only_last_obs
|
||||
self._index = 0
|
||||
self._size = 0
|
||||
self._meta = Batch()
|
||||
@ -210,6 +215,8 @@ class ReplayBuffer:
|
||||
"""Add a batch of data into replay buffer."""
|
||||
assert isinstance(info, (dict, Batch)), \
|
||||
'You should return a dict in the last argument of env.step().'
|
||||
if self._last_obs:
|
||||
obs = obs[-1]
|
||||
self._add_to_buffer('obs', obs)
|
||||
self._add_to_buffer('act', act)
|
||||
self._add_to_buffer('rew', rew)
|
||||
@ -217,6 +224,8 @@ class ReplayBuffer:
|
||||
if self._save_s_:
|
||||
if obs_next is None:
|
||||
obs_next = Batch()
|
||||
elif self._last_obs:
|
||||
obs_next = obs_next[-1]
|
||||
self._add_to_buffer('obs_next', obs_next)
|
||||
self._add_to_buffer('info', info)
|
||||
self._add_to_buffer('policy', policy)
|
||||
|