Saving and loading replay buffer with HDF5 (#261)
As mentioned in #260, this pull request implements saving and loading of the replay buffer via HDF5.
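For a quick sense of the API this diff adds, here is a minimal usage sketch, drawn from the docstring examples below (the file name 'buf.hdf5' is illustrative):

    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=10)
    for i in range(3):
        buf.add(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={})
    buf.save_hdf5('buf.hdf5')                  # persist the buffer state to disk
    buf2 = ReplayBuffer.load_hdf5('buf.hdf5')  # reconstruct it from the file
    assert len(buf2) == len(buf)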
This commit is contained in:
Parent: cd481423dc
Commit: 5d13d8a453
.gitignore (vendored)
@@ -144,3 +144,4 @@ MUJOCO_LOG.TXT
 .DS_Store
 *.zip
 *.pstats
+*.swp
docs/bibtex.json (new file)
@@ -0,0 +1,9 @@
+{
+    "cited": {
+        "tutorials/dqn": [
+            "DQN",
+            "DDPG",
+            "PPO"
+        ]
+    }
+}
docs/conf.py
@@ -70,6 +70,7 @@ autodoc_default_options = {
         ]
     )
 }
+bibtex_bibfiles = ['refs.bib']

 # -- Options for HTML output -------------------------------------------------

setup.py
@@ -47,10 +47,11 @@ setup(
     install_requires=[
         "gym>=0.15.4",
         "tqdm",
-        "numpy",
+        "numpy!=1.16.0",  # https://github.com/numpy/numpy/issues/12793
         "tensorboard",
         "torch>=1.4.0",
         "numba>=0.51.0",
+        "h5py>=3.1.0"
     ],
     extras_require={
         "dev": [
test/base/test_buffer.py
@@ -1,11 +1,15 @@
+import os
 import torch
 import pickle
 import pytest
+import tempfile
+import h5py
 import numpy as np
 from timeit import timeit

 from tianshou.data import Batch, SegmentTree, \
     ReplayBuffer, ListReplayBuffer, PrioritizedReplayBuffer
+from tianshou.data.utils.converter import to_hdf5

 if __name__ == '__main__':
     from env import MyTestEnv
@@ -278,7 +282,73 @@ def test_pickle():
         pbuf.weight[np.arange(len(pbuf))])


+def test_hdf5():
+    size = 100
+    buffers = {
+        "array": ReplayBuffer(size, stack_num=2),
+        "list": ListReplayBuffer(),
+        "prioritized": PrioritizedReplayBuffer(size, 0.6, 0.4)
+    }
+    buffer_types = {k: b.__class__ for k, b in buffers.items()}
+    device = 'cuda' if torch.cuda.is_available() else 'cpu'
+    rew = torch.tensor([1.]).to(device)
+    for i in range(4):
+        kwargs = {
+            'obs': Batch(index=np.array([i])),
+            'act': i,
+            'rew': rew,
+            'done': 0,
+            'info': {"number": {"n": i}, 'extra': None},
+        }
+        buffers["array"].add(**kwargs)
+        buffers["list"].add(**kwargs)
+        buffers["prioritized"].add(weight=np.random.rand(), **kwargs)
+
+    # save
+    paths = {}
+    for k, buf in buffers.items():
+        f, path = tempfile.mkstemp(suffix='.hdf5')
+        os.close(f)
+        buf.save_hdf5(path)
+        paths[k] = path
+
+    # load replay buffer
+    _buffers = {k: buffer_types[k].load_hdf5(paths[k]) for k in paths.keys()}
+
+    # compare
+    for k in buffers.keys():
+        assert len(_buffers[k]) == len(buffers[k])
+        assert np.allclose(_buffers[k].act, buffers[k].act)
+        assert _buffers[k].stack_num == buffers[k].stack_num
+        assert _buffers[k]._maxsize == buffers[k]._maxsize
+        assert _buffers[k]._index == buffers[k]._index
+        assert np.all(_buffers[k]._indices == buffers[k]._indices)
+    for k in ["array", "prioritized"]:
+        assert isinstance(buffers[k].get(0, "info"), Batch)
+        assert isinstance(_buffers[k].get(0, "info"), Batch)
+    for k in ["array"]:
+        assert np.all(
+            buffers[k][:].info.number.n == _buffers[k][:].info.number.n)
+        assert np.all(
+            buffers[k][:].info.extra == _buffers[k][:].info.extra)
+
+    for path in paths.values():
+        os.remove(path)
+
+    # raise exception when value cannot be pickled
+    data = {"not_supported": lambda x: x*x}
+    grp = h5py.Group
+    with pytest.raises(NotImplementedError):
+        to_hdf5(data, grp)
+    # ndarray with data type not supported by HDF5 that cannot be pickled
+    data = {"not_supported": np.array(lambda x: x*x)}
+    grp = h5py.Group
+    with pytest.raises(RuntimeError):
+        to_hdf5(data, grp)
+
+
 if __name__ == '__main__':
+    test_hdf5()
     test_replaybuffer()
     test_ignore_obs_next()
     test_stack()
tianshou/data/buffer.py
@@ -1,10 +1,12 @@
+import h5py
 import torch
 import numpy as np
 from numbers import Number
 from typing import Any, Dict, List, Tuple, Union, Optional

-from tianshou.data import Batch, SegmentTree, to_numpy
 from tianshou.data.batch import _create_value
+from tianshou.data import Batch, SegmentTree, to_numpy
+from tianshou.data.utils.converter import to_hdf5, from_hdf5


 class ReplayBuffer:
@@ -38,7 +40,10 @@ class ReplayBuffer:
     >>> # but there are only three valid items, so len(buf) == 3.
     >>> len(buf)
     3
-    >>> pickle.dump(buf, open('buf.pkl', 'wb'))  # save to file "buf.pkl"
+    >>> # save to file "buf.pkl"
+    >>> pickle.dump(buf, open('buf.pkl', 'wb'))
+    >>> # save to HDF5 file
+    >>> buf.save_hdf5('buf.hdf5')
     >>> buf2 = ReplayBuffer(size=10)
     >>> for i in range(15):
     ...     buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
@@ -54,7 +59,7 @@ class ReplayBuffer:
            0., 0., 0., 0., 0., 0., 0.])

     >>> # get a random sample from buffer
-    >>> # the batch_data is equal to buf[incide].
+    >>> # the batch_data is equal to buf[indice].
     >>> batch_data, indice = buf.sample(batch_size=4)
     >>> batch_data.obs == buf[indice].obs
     array([ True,  True,  True,  True])
@@ -63,6 +68,15 @@ class ReplayBuffer:
     >>> buf = pickle.load(open('buf.pkl', 'rb'))  # load from "buf.pkl"
     >>> len(buf)
     3
+    >>> # load complete buffer from HDF5 file
+    >>> buf = ReplayBuffer.load_hdf5('buf.hdf5')
+    >>> len(buf)
+    3
+    >>> # load contents of HDF5 file into existing buffer
+    >>> # (only possible if size of buffer and data in file match)
+    >>> buf.load_contents_hdf5('buf.hdf5')
+    >>> len(buf)
+    3

     :class:`~tianshou.data.ReplayBuffer` also supports frame_stack sampling
     (typically for RNN usage, see issue#19), ignoring storing the next
@@ -167,8 +181,14 @@ class ReplayBuffer:
         We need it because pickling buffer does not work out-of-the-box
         ("buffer.__getattr__" is customized).
         """
+        self._indices = np.arange(state["_maxsize"])
         self.__dict__.update(state)

+    def __getstate__(self) -> dict:
+        exclude = {"_indices"}
+        state = {k: v for k, v in self.__dict__.items() if k not in exclude}
+        return state
+
     def _add_to_buffer(self, name: str, inst: Any) -> None:
         try:
             value = self._meta.__dict__[name]
@@ -359,6 +379,21 @@ class ReplayBuffer:
             policy=self.get(index, "policy"),
         )

+    def save_hdf5(self, path: str) -> None:
+        """Save replay buffer to HDF5 file."""
+        with h5py.File(path, "w") as f:
+            to_hdf5(self.__getstate__(), f)
+
+    @classmethod
+    def load_hdf5(
+        cls, path: str, device: Optional[str] = None
+    ) -> "ReplayBuffer":
+        """Load replay buffer from HDF5 file."""
+        with h5py.File(path, "r") as f:
+            buf = cls.__new__(cls)
+            buf.__setstate__(from_hdf5(f, device=device))
+        return buf
+

 class ListReplayBuffer(ReplayBuffer):
     """List-based replay buffer.
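Note that load_hdf5 takes an optional device argument, which is forwarded to from_hdf5 when torch.Tensor entries are rebuilt. A small sketch of restoring onto a GPU when one is available ('buf.hdf5' again illustrative):

    import torch
    from tianshou.data import ReplayBuffer

    # stored tensor entries are created directly on `device`
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    buf = ReplayBuffer.load_hdf5('buf.hdf5', device=device)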
tianshou/data/utils/converter.py
@@ -1,8 +1,10 @@
+import h5py
 import torch
+import pickle
 import numpy as np
 from copy import deepcopy
 from numbers import Number
-from typing import Union, Optional
+from typing import Dict, Union, Optional

 from tianshou.data.batch import _parse_value, Batch

@@ -80,3 +82,90 @@ def to_torch_as(
     """
     assert isinstance(y, torch.Tensor)
     return to_torch(x, dtype=y.dtype, device=y.device)
+
+
+# Note: object is used as a proxy for objects that can be pickled
+# Note: mypy does not support cyclic definition currently
+Hdf5ConvertibleValues = Union[  # type: ignore
+    int, float, Batch, np.ndarray, torch.Tensor, object,
+    'Hdf5ConvertibleType',  # type: ignore
+]
+
+Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues]  # type: ignore
+
+
+def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
+    """Copy object into HDF5 group."""
+
+    def to_hdf5_via_pickle(x: object, y: h5py.Group, key: str) -> None:
+        """Pickle, convert to numpy array and write to HDF5 dataset."""
+        data = np.frombuffer(pickle.dumps(x), dtype=np.byte)
+        y.create_dataset(key, data=data)
+
+    for k, v in x.items():
+        if isinstance(v, (Batch, dict)):
+            # dicts and batches are both represented by groups
+            subgrp = y.create_group(k)
+            if isinstance(v, Batch):
+                subgrp_data = v.__getstate__()
+                subgrp.attrs["__data_type__"] = "Batch"
+            else:
+                subgrp_data = v
+            to_hdf5(subgrp_data, subgrp)
+        elif isinstance(v, torch.Tensor):
+            # PyTorch tensors are written to datasets
+            y.create_dataset(k, data=to_numpy(v))
+            y[k].attrs["__data_type__"] = "Tensor"
+        elif isinstance(v, np.ndarray):
+            try:
+                # NumPy arrays are written to datasets
+                y.create_dataset(k, data=v)
+                y[k].attrs["__data_type__"] = "ndarray"
+            except TypeError:
+                # If data type is not supported by HDF5 fall back to pickle.
+                # This happens if dtype=object (e.g. due to entries being None)
+                # and possibly in other cases like structured arrays.
+                try:
+                    to_hdf5_via_pickle(v, y, k)
+                except Exception as e:
+                    raise RuntimeError(
+                        f"Attempted to pickle {v.__class__.__name__} due to "
+                        "data type not supported by HDF5 and failed."
+                    ) from e
+                y[k].attrs["__data_type__"] = "pickled_ndarray"
+        elif isinstance(v, (int, float)):
+            # ints and floats are stored as attributes of groups
+            y.attrs[k] = v
+        else:  # resort to pickle for any other type of object
+            try:
+                to_hdf5_via_pickle(v, y, k)
+            except Exception as e:
+                raise NotImplementedError(
+                    f"No conversion to HDF5 for object of type '{type(v)}' "
+                    "implemented and fallback to pickle failed."
+                ) from e
+            y[k].attrs["__data_type__"] = v.__class__.__name__
+
+
+def from_hdf5(
+    x: h5py.Group, device: Optional[str] = None
+) -> Hdf5ConvertibleType:
+    """Restore object from HDF5 group."""
+    if isinstance(x, h5py.Dataset):
+        # handle datasets
+        if x.attrs["__data_type__"] == "ndarray":
+            y = np.array(x)
+        elif x.attrs["__data_type__"] == "Tensor":
+            y = torch.tensor(x, device=device)
+        else:
+            y = pickle.loads(x[()])
+    else:
+        # handle groups representing a dict or a Batch
+        y = {k: v for k, v in x.attrs.items() if k != "__data_type__"}
+        for k, v in x.items():
+            y[k] = from_hdf5(v, device)
+        if "__data_type__" in x.attrs:
+            # if dictionary represents Batch, convert to Batch
+            if x.attrs["__data_type__"] == "Batch":
+                y = Batch(y)
+    return y
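To illustrate the conversion scheme, here is a minimal round-trip sketch calling to_hdf5/from_hdf5 directly (the file name and dictionary contents are made up; each comment names the storage rule from the code above):

    import h5py
    import numpy as np
    import torch
    from tianshou.data import Batch
    from tianshou.data.utils.converter import to_hdf5, from_hdf5

    data = {
        'scalar': 3,                    # int -> stored as a group attribute
        'array': np.arange(4),          # ndarray -> dataset tagged "ndarray"
        'tensor': torch.ones(2),        # tensor -> dataset tagged "Tensor"
        'batch': Batch(a=np.zeros(2)),  # Batch -> subgroup tagged "Batch"
    }
    with h5py.File('demo.hdf5', 'w') as f:
        to_hdf5(data, f)
    with h5py.File('demo.hdf5', 'r') as f:
        restored = from_hdf5(f)
    assert np.all(restored['array'] == data['array'])
    assert isinstance(restored['batch'], Batch)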