import pickle
from copy import deepcopy
from numbers import Number
from typing import Any, Union, no_type_check

import h5py
import numpy as np
import torch

from tianshou.data.batch import Batch, _parse_value


# TODO: confusing name, could actually return a batch...
# Overrides and generic types should be added
@no_type_check
def to_numpy(x: Any) -> Batch | np.ndarray:
    """Return a NumPy-based version of the object, i.e. one without any torch.Tensor.

    Tensors are detached, moved to the CPU and converted to np.ndarray; dicts are
    converted to Batch objects whose leaves are ndarrays.
    """
    if isinstance(x, torch.Tensor):  # most common case
        return x.detach().cpu().numpy()
    if isinstance(x, np.ndarray):  # second most common case
        return x
    if isinstance(x, np.number | np.bool_ | Number):
        return np.asanyarray(x)
    if x is None:
        return np.array(None, dtype=object)
    if isinstance(x, dict | Batch):
        x = Batch(x) if isinstance(x, dict) else deepcopy(x)
        x.to_numpy()
        return x
    if isinstance(x, list | tuple):
        return to_numpy(_parse_value(x))
    # fallback
    return np.asanyarray(x)
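
# Example usage (illustrative sketch; the shapes and dict keys below are arbitrary):
#
#     arr = to_numpy(torch.ones(2, 3))           # np.ndarray of shape (2, 3)
#     batch = to_numpy({"obs": torch.zeros(4)})  # dict -> Batch with ndarray leaves
#     seq = to_numpy([1, 2, 3])                  # list -> np.array([1, 2, 3])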


@no_type_check
def to_torch(
    x: Any,
    dtype: torch.dtype | None = None,
    device: str | int | torch.device = "cpu",
) -> Batch | torch.Tensor:
    """Return a torch.Tensor-based version of the object, i.e. one without any np.ndarray.

    Arrays and scalars of boolean or numeric dtype are converted to tensors on the
    given device (and cast to ``dtype`` if one is given); dicts are converted to
    Batch objects whose leaves are tensors. Unsupported objects raise a TypeError.
    """
    if isinstance(x, np.ndarray) and issubclass(
        x.dtype.type,
        np.bool_ | np.number,
    ):  # most common case
        x = torch.from_numpy(x).to(device)
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, torch.Tensor):  # second most common case
        if dtype is not None:
            x = x.type(dtype)
        return x.to(device)
    if isinstance(x, np.number | np.bool_ | Number):
        return to_torch(np.asanyarray(x), dtype, device)
    if isinstance(x, dict | Batch):
        x = Batch(x, copy=True) if isinstance(x, dict) else deepcopy(x)
        x.to_torch(dtype, device)
        return x
    if isinstance(x, list | tuple):
        return to_torch(_parse_value(x), dtype, device)
    # fallback
    raise TypeError(f"object {x} cannot be converted to torch.")
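
# Example usage (illustrative sketch; dtype, device and keys are arbitrary choices):
#
#     t = to_torch(np.zeros((2, 3)), dtype=torch.float32)  # float32 tensor on the CPU
#     b = to_torch({"obs": np.zeros(4)}, device="cpu")      # dict -> Batch with tensor leaves
#     to_torch(np.array([None], dtype=object))              # raises TypeError (unsupported dtype)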


@no_type_check
def to_torch_as(x: Any, y: torch.Tensor) -> Batch | torch.Tensor:
    """Return a torch.Tensor-based version of the object with the dtype and device of ``y``.

    Same as ``to_torch(x, dtype=y.dtype, device=y.device)``.
    """
    assert isinstance(y, torch.Tensor)
    return to_torch(x, dtype=y.dtype, device=y.device)
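
# Example usage (illustrative sketch):
#
#     target = torch.zeros(3, dtype=torch.float64)
#     to_torch_as(np.ones(3, dtype=np.float32), target)  # float64 tensor on target's device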


# Note: object is used as a proxy for objects that can be pickled
# Note: mypy does not support cyclic definition currently
Hdf5ConvertibleValues = Union[
    int,
    float,
    Batch,
    np.ndarray,
    torch.Tensor,
    object,
    "Hdf5ConvertibleType",
]

Hdf5ConvertibleType = dict[str, Hdf5ConvertibleValues]
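
# For illustration, a valid Hdf5ConvertibleType is a string-keyed, possibly nested dict
# (the keys and shapes below are arbitrary examples):
#
#     {"obs": np.zeros((10, 4)), "rew": np.ones(10), "info": {"episode": 3}}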


def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group, compression: str | None = None) -> None:
    """Copy object into HDF5 group."""

    def to_hdf5_via_pickle(
        x: object,
        y: h5py.Group,
        key: str,
        compression: str | None = None,
    ) -> None:
        """Pickle, convert to numpy array and write to HDF5 dataset."""
        data = np.frombuffer(pickle.dumps(x), dtype=np.byte)
        y.create_dataset(key, data=data, compression=compression)

    for k, v in x.items():
        if isinstance(v, Batch | dict):
            # dicts and batches are both represented by groups
            subgrp = y.create_group(k)
            if isinstance(v, Batch):
                subgrp_data = v.__getstate__()
                subgrp.attrs["__data_type__"] = "Batch"
            else:
                subgrp_data = v
            to_hdf5(subgrp_data, subgrp, compression=compression)
        elif isinstance(v, torch.Tensor):
            # PyTorch tensors are written to datasets
            y.create_dataset(k, data=to_numpy(v), compression=compression)
            y[k].attrs["__data_type__"] = "Tensor"
        elif isinstance(v, np.ndarray):
            try:
                # NumPy arrays are written to datasets
                y.create_dataset(k, data=v, compression=compression)
                y[k].attrs["__data_type__"] = "ndarray"
            except TypeError:
                # If the data type is not supported by HDF5, fall back to pickle.
                # This happens if dtype=object (e.g. due to entries being None)
                # and possibly in other cases like structured arrays.
                try:
                    to_hdf5_via_pickle(v, y, k, compression=compression)
                except Exception as exception:
                    raise RuntimeError(
                        f"Attempted to pickle {v.__class__.__name__} because its "
                        "data type is not supported by HDF5, but pickling failed.",
                    ) from exception
                y[k].attrs["__data_type__"] = "pickled_ndarray"
        elif isinstance(v, int | float):
            # ints and floats are stored as attributes of groups
            y.attrs[k] = v
        else:  # resort to pickle for any other type of object
            try:
                to_hdf5_via_pickle(v, y, k, compression=compression)
            except Exception as exception:
                raise NotImplementedError(
                    f"No conversion to HDF5 implemented for object of type "
                    f"'{type(v)}', and the fallback to pickle failed.",
                ) from exception
            y[k].attrs["__data_type__"] = v.__class__.__name__
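
# Example usage (illustrative sketch; the file name "example.h5" and the keys are arbitrary,
# and "gzip" is one of h5py's built-in compression filters):
#
#     with h5py.File("example.h5", "w") as f:
#         to_hdf5({"obs": np.zeros((10, 4)), "rew": np.ones(10)}, f, compression="gzip")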


def from_hdf5(x: h5py.Group, device: str | None = None) -> Hdf5ConvertibleValues:
    """Restore object from HDF5 group."""
    if isinstance(x, h5py.Dataset):
        # handle datasets
        if x.attrs["__data_type__"] == "ndarray":
            return np.array(x)
        if x.attrs["__data_type__"] == "Tensor":
            return torch.tensor(x, device=device)
        return pickle.loads(x[()])
    # handle groups representing a dict or a Batch
    y = dict(x.attrs.items())
    data_type = y.pop("__data_type__", None)
    for k, v in x.items():
        y[k] = from_hdf5(v, device)
    return Batch(y) if data_type == "Batch" else y
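
# Example round trip (illustrative sketch; assumes a file written as in the to_hdf5 sketch above):
#
#     with h5py.File("example.h5", "r") as f:
#         data = from_hdf5(f)  # a plain dict here; a Batch is returned if the group stored one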