Saving and loading replay buffer with HDF5 (#261)

As mentioned in #260, this pull request implements saving and loading of the replay buffer with HDF5.
Author: Nico Gürtler
Committed: 2020-12-17 01:58:43 +01:00 (by GitHub)
Commit: 5d13d8a453 (parent: cd481423dc)
7 changed files with 211 additions and 5 deletions
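
To give a quick sense of the user-facing API this PR adds, here is a minimal usage sketch based on the docstring examples and the unit test in the diff below (file names are arbitrary):

import pickle

from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=20)
for i in range(3):
    buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})

# previously the whole buffer could only be pickled
pickle.dump(buf, open('buf.pkl', 'wb'))

# new in this PR: save to and load from an HDF5 file
buf.save_hdf5('buf.hdf5')
buf2 = ReplayBuffer.load_hdf5('buf.hdf5')
assert len(buf2) == len(buf) == 3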

.gitignore

@@ -144,3 +144,4 @@ MUJOCO_LOG.TXT
.DS_Store
*.zip
*.pstats
*.swp

docs/bibtex.json

@@ -0,0 +1,9 @@
{
"cited": {
"tutorials/dqn": [
"DQN",
"DDPG",
"PPO"
]
}
}

docs/conf.py

@@ -70,6 +70,7 @@ autodoc_default_options = {
]
)
}
bibtex_bibfiles = ['refs.bib']
# -- Options for HTML output -------------------------------------------------

setup.py

@@ -47,10 +47,11 @@ setup(
install_requires=[
"gym>=0.15.4",
"tqdm",
"numpy",
"numpy!=1.16.0", # https://github.com/numpy/numpy/issues/12793
"tensorboard",
"torch>=1.4.0",
"numba>=0.51.0",
"h5py>=3.1.0"
],
extras_require={
"dev": [

test/base/test_buffer.py

@@ -1,11 +1,15 @@
import os
import torch
import pickle
import pytest
import tempfile
import h5py
import numpy as np
from timeit import timeit
from tianshou.data import Batch, SegmentTree, \
ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.utils.converter import to_hdf5
if __name__ == '__main__':
from env import MyTestEnv
@@ -278,7 +282,73 @@ def test_pickle():
pbuf.weight[np.arange(len(pbuf))])
def test_hdf5():
size = 100
buffers = {
"array": ReplayBuffer(size, stack_num=2),
"list": ListReplayBuffer(),
"prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4)
}
buffer_types = {k: b.__class__ for k, b in buffers.items()}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
rew = torch.tensor([1.]).to(device)
for i in range(4):
kwargs = {
'obs': Batch(index=np.array([i])),
'act': i,
'rew': rew,
'done': 0,
'info': {"number": {"n": i}, 'extra': None},
}
buffers["array"].add(**kwargs)
buffers["list"].add(**kwargs)
buffers["prioritized"].add(weight=np.random.rand(), **kwargs)
# save
paths = {}
for k, buf in buffers.items():
f, path = tempfile.mkstemp(suffix='.hdf5')
os.close(f)
buf.save_hdf5(path)
paths[k] = path
# load replay buffer
_buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}
# compare
for k in buffers.keys():
assert len(_buffers[k]) == len(buffers[k])
assert np.allclose(_buffers[k].act, buffers[k].act)
assert _buffers[k].stack_num == buffers[k].stack_num
assert _buffers[k]._maxsize == buffers[k]._maxsize
assert _buffers[k]._index == buffers[k]._index
assert np.all(_buffers[k]._indices == buffers[k]._indices)
for k in ["array", "prioritized"]:
assert isinstance(buffers[k].get(0, "info"), Batch)
assert isinstance(_buffers[k].get(0, "info"), Batch)
for k in ["array"]:
assert np.all(
buffers[k][:].info.number.n == _buffers[k][:].info.number.n)
assert np.all(
buffers[k][:].info.extra == _buffers[k][:].info.extra)
for path in paths.values():
os.remove(path)
# raise exception when value cannot be pickled
data = {"not_supported": lambda x: x*x}
grp = h5py.Group
with pytest.raises(NotImplementedError):
to_hdf5(data, grp)
# ndarray with data type not supported by HDF5 that cannot be pickled
data = {"not_supported": np.array(lambda x: x*x)}
grp = h5py.Group
with pytest.raises(RuntimeError):
to_hdf5(data, grp)
if __name__ == '__main__':
test_hdf5()
test_replaybuffer()
test_ignore_obs_next()
test_stack()

tianshou/data/buffer.py

@@ -1,10 +1,12 @@
import h5py
import torch
import numpy as np
from numbers import Number
from typing import Any, Dict, List, Tuple, Union, Optional
from tianshou.data import Batch, SegmentTree, to_numpy
from tianshou.data.batch import _create_value
from tianshou.data import Batch, SegmentTree, to_numpy
from tianshou.data.utils.converter import to_hdf5, from_hdf5
class ReplayBuffer:
@@ -38,7 +40,10 @@ class ReplayBuffer:
>>> # but there are only three valid items, so len(buf) == 3.
>>> len(buf)
3
>>> pickle.dump(buf, open('buf.pkl', 'wb')) # save to file "buf.pkl"
>>> # save to file "buf.pkl"
>>> pickle.dump(buf, open('buf.pkl', 'wb'))
>>> # save to HDF5 file
>>> buf.save_hdf5('buf.hdf5')
>>> buf2 = ReplayBuffer(size=10)
>>> for i in range(15):
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
@@ -54,7 +59,7 @@ class ReplayBuffer:
0., 0., 0., 0., 0., 0., 0.])
>>> # get a random sample from buffer
>>> # the batch_data is equal to buf[incide].
>>> # the batch_data is equal to buf[indice].
>>> batch_data, indice = buf.sample(batch_size=4)
>>> batch_data.obs == buf[indice].obs
array([ True, True, True, True])
@@ -63,6 +68,15 @@
>>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
>>> len(buf)
3
>>> # load complete buffer from HDF5 file
>>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
>>> len(buf)
3
>>> # load contents of HDF5 file into existing buffer
>>> # (only possible if size of buffer and data in file match)
>>> buf.load_contents_hdf5('buf.hdf5')
>>> len(buf)
3
:class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
(typically for RNN usage, see issue#19), ignoring storing the next
@@ -167,8 +181,14 @@ class ReplayBuffer:
We need it because pickling buffer does not work out-of-the-box
("buffer.__getattr__" is customized).
"""
self._indices = np.arange(state["_maxsize"])
self.__dict__.update(state)
def __getstate__(self) -> dict:
exclude = {"_indices"}
state = {k: v for k, v in self.__dict__.items() if k not in exclude}
return state
def _add_to_buffer(self, name: str, inst: Any) -> None:
try:
value = self._meta.__dict__[name]
@@ -359,6 +379,21 @@ class ReplayBuffer:
policy=self.get(index, "policy"),
)
def save_hdf5(self, path: str) -> None:
"""Save replay buffer to HDF5 file."""
with h5py.File(path, "w") as f:
to_hdf5(self.__getstate__(), f)
@classmethod
def load_hdf5(
cls, path: str, device: Optional[str] = None
) -> "ReplayBuffer":
"""Load replay buffer from HDF5 file."""
with h5py.File(path, "r") as f:
buf = cls.__new__(cls)
buf.__setstate__(from_hdf5(f, device=device))
return buf
class ListReplayBuffer(ReplayBuffer):
"""List-based replay buffer.

tianshou/data/utils/converter.py

@@ -1,8 +1,10 @@
import h5py
import torch
import pickle
import numpy as np
from copy import deepcopy
from numbers import Number
from typing import Union, Optional
from typing import Dict, Union, Optional
from tianshou.data.batch import _parse_value, Batch
@@ -80,3 +82,90 @@ def to_torch_as(
"""
assert isinstance(y, torch.Tensor)
return to_torch(x, dtype=y.dtype, device=y.device)
# Note: object is used as a proxy for objects that can be pickled
# Note: mypy does not support cyclic definition currently
Hdf5ConvertibleValues = Union[ # type: ignore
int, float, Batch, np.ndarray, torch.Tensor, object,
'Hdf5ConvertibleType', # type: ignore
]
Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues] # type: ignore
def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
"""Copy object into HDF5 group."""
def to_hdf5_via_pickle(x: object, y: h5py.Group, key: str) -> None:
"""Pickle, convert to numpy array and write to HDF5 dataset."""
data = np.frombuffer(pickle.dumps(x), dtype=np.byte)
y.create_dataset(key, data=data)
for k, v in x.items():
if isinstance(v, (Batch, dict)):
# dicts and batches are both represented by groups
subgrp = y.create_group(k)
if isinstance(v, Batch):
subgrp_data = v.__getstate__()
subgrp.attrs["__data_type__"] = "Batch"
else:
subgrp_data = v
to_hdf5(subgrp_data, subgrp)
elif isinstance(v, torch.Tensor):
# PyTorch tensors are written to datasets
y.create_dataset(k, data=to_numpy(v))
y[k].attrs["__data_type__"] = "Tensor"
elif isinstance(v, np.ndarray):
try:
# NumPy arrays are written to datasets
y.create_dataset(k, data=v)
y[k].attrs["__data_type__"] = "ndarray"
except TypeError:
# If data type is not supported by HDF5 fall back to pickle.
# This happens if dtype=object (e.g. due to entries being None)
# and possibly in other cases like structured arrays.
try:
to_hdf5_via_pickle(v, y, k)
except Exception as e:
raise RuntimeError(
f"Attempted to pickle {v.__class__.__name__} due to "
"data type not supported by HDF5 and failed."
) from e
y[k].attrs["__data_type__"] = "pickled_ndarray"
elif isinstance(v, (int, float)):
# ints and floats are stored as attributes of groups
y.attrs[k] = v
else: # resort to pickle for any other type of object
try:
to_hdf5_via_pickle(v, y, k)
except Exception as e:
raise NotImplementedError(
f"No conversion to HDF5 for object of type '{type(v)}' "
"implemented and fallback to pickle failed."
) from e
y[k].attrs["__data_type__"] = v.__class__.__name__
def from_hdf5(
x: h5py.Group, device: Optional[str] = None
) -> Hdf5ConvertibleType:
"""Restore object from HDF5 group."""
if isinstance(x, h5py.Dataset):
# handle datasets
if x.attrs["__data_type__"] == "ndarray":
y = np.array(x)
elif x.attrs["__data_type__"] == "Tensor":
y = torch.tensor(x, device=device)
else:
y = pickle.loads(x[()])
else:
# handle groups representing a dict or a Batch
y = {k: v for k, v in x.attrs.items() if k != "__data_type__"}
for k, v in x.items():
y[k] = from_hdf5(v, device)
if "__data_type__" in x.attrs:
# if dictionary represents Batch, convert to Batch
if x.attrs["__data_type__"] == "Batch":
y = Batch(y)
return y
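
Finally, to illustrate the converter itself: nested dicts and Batches become HDF5 groups, NumPy arrays and torch tensors become datasets tagged via the __data_type__ attribute, ints and floats become group attributes, and anything else falls back to pickle. A minimal sketch (the file name is arbitrary):

import h5py
import numpy as np
import torch

from tianshou.data import Batch
from tianshou.data.utils.converter import to_hdf5, from_hdf5

data = {
    "meta": {"step": 7, "lr": 1e-3},             # ints/floats -> group attributes
    "batch": Batch(obs=np.zeros((2, 3))),        # Batch -> group tagged "Batch"
    "tensor": torch.ones(4),                     # tensor -> dataset tagged "Tensor"
    "object_array": np.array([None, {"a": 1}]),  # dtype=object -> pickled dataset
}
with h5py.File("demo.hdf5", "w") as f:
    to_hdf5(data, f)
with h5py.File("demo.hdf5", "r") as f:
    restored = from_hdf5(f, device="cpu")
    assert isinstance(restored["batch"], Batch)
    assert restored["meta"]["step"] == 7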