Convert RL Unplugged Atari datasets to tianshou ReplayBuffer (#621)
parent 7f23748347
commit 41afc2584a
@@ -37,7 +37,7 @@ Tianshou provides an `offline_trainer` for offline reinforcement learning. You c
 
 ## Discrete control
 
-For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent. In the future, we can switch to better benchmarks such as the Atari portion of [RL Unplugged](https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged).
+For discrete control, we currently use ad hoc Atari data generated from a trained QRDQN agent.
 
 ### Gather Data
 
@@ -100,3 +100,24 @@ We test our CRR implementation on two example tasks (different from author's ver
 | BreakoutNoFrameskip-v4 | 394.3 | 46.9 | 23.3 (epoch 12) | 76.9 (epoch 12) | `python3 atari_crr.py --task BreakoutNoFrameskip-v4 --load-buffer-name log/BreakoutNoFrameskip-v4/qrdqn/expert.hdf5 --epoch 12 --min-q-weight 50` |
 
 Note that CRR itself does not work well in Atari tasks, but adding CQL loss/regularizer helps.
+
+### RL Unplugged Data
+
+We provide a script to convert the Atari datasets of [RL Unplugged](https://github.com/deepmind/deepmind-research/tree/master/rl_unplugged) to Tianshou ReplayBuffer.
+
+For example, the following command downloads the first shard of the first run of the Breakout game to `~/.rl_unplugged/datasets/Breakout/run_1-00001-of-00100`, converts it to a `tianshou.data.ReplayBuffer`, and saves it to `~/.rl_unplugged/buffers/Breakout/run_1-00001-of-00100.hdf5` (use `--dataset-dir` and `--buffer-dir` to change the default directories):
+
+```bash
+python3 convert_rl_unplugged_atari.py --task Breakout --run-id 1 --shard-id 1
+```
+
+Then you can use it to train an agent with:
+
+```bash
+python3 atari_bcq.py --task BreakoutNoFrameskip-v4 --load-buffer-name ~/.rl_unplugged/buffers/Breakout/run_1-00001-of-00100.hdf5 --buffer-from-rl-unplugged --epoch 12
+```
+
+Note:
+- Each shard contains about 500k transitions.
+- The conversion script depends on TensorFlow.
+- It takes about 1 hour to process one shard on my machine. YMMV.
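For a quick sanity check, a converted shard can also be opened directly with `tianshou.data.ReplayBuffer.load_hdf5`. A minimal sketch, assuming the default `--buffer-dir` and the shard converted by the command above:

```python
import os

from tianshou.data import ReplayBuffer

# Path produced by the conversion command above (default --buffer-dir assumed).
path = os.path.expanduser(
    "~/.rl_unplugged/buffers/Breakout/run_1-00001-of-00100.hdf5"
)
buf = ReplayBuffer.load_hdf5(path)

print(len(buf))  # roughly 500k transitions per shard
batch, indices = buf.sample(batch_size=32)
print(batch.obs.shape)  # (32, 4, 84, 84): channel-first stacking, per the converter
```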
@@ -12,7 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
 
 from examples.atari.atari_network import DQN
 from examples.atari.atari_wrapper import make_atari_env
-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
 from tianshou.policy import DiscreteBCQPolicy
 from tianshou.trainer import offline_trainer
 from tianshou.utils import TensorboardLogger, WandbLogger
@@ -59,6 +59,9 @@ def get_args():
     parser.add_argument(
         "--load-buffer-name", type=str, default="./expert_DQN_PongNoFrameskip-v4.hdf5"
     )
+    parser.add_argument(
+        "--buffer-from-rl-unplugged", action="store_true", default=False
+    )
     parser.add_argument(
         "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
     )
@@ -120,7 +123,10 @@ def test_discrete_bcq(args=get_args()):
     if args.load_buffer_name.endswith(".pkl"):
         buffer = pickle.load(open(args.load_buffer_name, "rb"))
     elif args.load_buffer_name.endswith(".hdf5"):
-        buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
+        if args.buffer_from_rl_unplugged:
+            buffer = ReplayBuffer.load_hdf5(args.load_buffer_name)
+        else:
+            buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
     else:
         print(f"Unknown buffer format: {args.load_buffer_name}")
         exit(0)
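The new `--buffer-from-rl-unplugged` flag exists because the two HDF5 sources are saved from different buffer classes: converted RL Unplugged shards are plain `ReplayBuffer` files, while the QRDQN expert data used elsewhere in these examples is saved from a `VectorReplayBuffer`. A standalone sketch of that branch (the `load_offline_buffer` helper is hypothetical, not part of the commit):

```python
from tianshou.data import ReplayBuffer, VectorReplayBuffer


def load_offline_buffer(path: str, from_rl_unplugged: bool) -> ReplayBuffer:
    """Load an offline HDF5 buffer, picking the class it was saved from."""
    if from_rl_unplugged:
        # Shards converted by convert_rl_unplugged_atari.py are plain ReplayBuffers.
        return ReplayBuffer.load_hdf5(path)
    # Locally collected expert data is saved from a VectorReplayBuffer.
    return VectorReplayBuffer.load_hdf5(path)
```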
examples/offline/convert_rl_unplugged_atari.py (new executable file, 278 lines)
@@ -0,0 +1,278 @@
+#!/usr/bin/env python3
+#
+# Adapted from
+# https://github.com/deepmind/deepmind-research/blob/master/rl_unplugged/atari.py
+#
+"""Convert Atari RL Unplugged datasets to Tianshou replay buffers.
+
+Examples in the dataset represent SARSA transitions stored during a
+DQN training run as described in https://arxiv.org/pdf/1907.04543.
+
+For every training run we have recorded all 50 million transitions corresponding
+to 200 million environment steps (4x factor because of frame skipping). There
+are 5 separate datasets for each of the 45 games.
+
+Every transition in the dataset is a tuple containing the following features:
+
+* o_t: Observation at time t. Observations have been processed using the
+    canonical Atari frame processing, including 4x frame stacking. The shape
+    of a single observation is [84, 84, 4].
+* a_t: Action taken at time t.
+* r_t: Reward after a_t.
+* d_t: Discount after a_t.
+* o_tp1: Observation at time t+1.
+* a_tp1: Action at time t+1.
+* extras:
+  * episode_id: Episode identifier.
+  * episode_return: Total episode return computed using per-step [-1, 1]
+      clipping.
+"""
+import os
+from argparse import ArgumentParser
+
+import requests
+import tensorflow as tf
+from tqdm import tqdm
+
+from tianshou.data import Batch, ReplayBuffer
+
+tf.config.set_visible_devices([], 'GPU')
+
+# 9 tuning games.
+TUNING_SUITE = [
+    "BeamRider",
+    "DemonAttack",
+    "DoubleDunk",
+    "IceHockey",
+    "MsPacman",
+    "Pooyan",
+    "RoadRunner",
+    "Robotank",
+    "Zaxxon",
+]
+
+# 36 testing games.
+TESTING_SUITE = [
+    "Alien",
+    "Amidar",
+    "Assault",
+    "Asterix",
+    "Atlantis",
+    "BankHeist",
+    "BattleZone",
+    "Boxing",
+    "Breakout",
+    "Carnival",
+    "Centipede",
+    "ChopperCommand",
+    "CrazyClimber",
+    "Enduro",
+    "FishingDerby",
+    "Freeway",
+    "Frostbite",
+    "Gopher",
+    "Gravitar",
+    "Hero",
+    "Jamesbond",
+    "Kangaroo",
+    "Krull",
+    "KungFuMaster",
+    "NameThisGame",
+    "Phoenix",
+    "Pong",
+    "Qbert",
+    "Riverraid",
+    "Seaquest",
+    "SpaceInvaders",
+    "StarGunner",
+    "TimePilot",
+    "UpNDown",
+    "VideoPinball",
+    "WizardOfWor",
+    "YarsRevenge",
+]
+
+# Total of 45 games.
+ALL_GAMES = TUNING_SUITE + TESTING_SUITE
+URL_PREFIX = "http://storage.googleapis.com/rl_unplugged/atari"
+
+
+def _filename(run_id: int, shard_id: int, total_num_shards: int = 100) -> str:
+    return f"run_{run_id}-{shard_id:05d}-of-{total_num_shards:05d}"
+
+
+def _decode_frames(pngs: tf.Tensor) -> tf.Tensor:
+    """Decode PNGs.
+
+    Args:
+        pngs: String Tensor of size (4,) containing PNG encoded images.
+
+    Returns:
+        4 84x84 grayscale images packed in a (84, 84, 4) uint8 Tensor.
+    """
+    # Statically unroll png decoding
+    frames = [tf.image.decode_png(pngs[i], channels=1) for i in range(4)]
+    # NOTE: to match tianshou's convention for framestacking
+    frames = tf.squeeze(tf.stack(frames, axis=0))
+    frames.set_shape((4, 84, 84))
+    return frames
+
+
+def _make_tianshou_batch(
+    o_t: tf.Tensor,
+    a_t: tf.Tensor,
+    r_t: tf.Tensor,
+    d_t: tf.Tensor,
+    o_tp1: tf.Tensor,
+    a_tp1: tf.Tensor,
+) -> Batch:
+    """Create Tianshou batch with offline data.
+
+    Args:
+        o_t: Observation at time t.
+        a_t: Action at time t.
+        r_t: Reward at time t.
+        d_t: Discount at time t.
+        o_tp1: Observation at time t+1.
+        a_tp1: Action at time t+1.
+
+    Returns:
+        A tianshou.data.Batch object.
+    """
+    return Batch(
+        obs=o_t.numpy(),
+        act=a_t.numpy(),
+        rew=r_t.numpy(),
+        done=1 - d_t.numpy(),
+        obs_next=o_tp1.numpy()
+    )
+
+
+def _tf_example_to_tianshou_batch(tf_example: tf.train.Example) -> Batch:
+    """Create a tianshou Batch replay sample from a TF example."""
+
+    # Parse tf.Example.
+    feature_description = {
+        "o_t": tf.io.FixedLenFeature([4], tf.string),
+        "o_tp1": tf.io.FixedLenFeature([4], tf.string),
+        "a_t": tf.io.FixedLenFeature([], tf.int64),
+        "a_tp1": tf.io.FixedLenFeature([], tf.int64),
+        "r_t": tf.io.FixedLenFeature([], tf.float32),
+        "d_t": tf.io.FixedLenFeature([], tf.float32),
+        "episode_id": tf.io.FixedLenFeature([], tf.int64),
+        "episode_return": tf.io.FixedLenFeature([], tf.float32),
+    }
+    data = tf.io.parse_single_example(tf_example, feature_description)
+
+    # Process data.
+    o_t = _decode_frames(data["o_t"])
+    o_tp1 = _decode_frames(data["o_tp1"])
+    a_t = tf.cast(data["a_t"], tf.int32)
+    a_tp1 = tf.cast(data["a_tp1"], tf.int32)
+
+    # Build tianshou Batch replay sample.
+    return _make_tianshou_batch(o_t, a_t, data["r_t"], data["d_t"], o_tp1, a_tp1)
+
+
+# Adapted From https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51
+def download(url: str, fname: str, chunk_size=1024):
+    resp = requests.get(url, stream=True)
+    total = int(resp.headers.get('content-length', 0))
+    if os.path.exists(fname):
+        print(f"Found cached file at {fname}.")
+        return
+    with open(fname, 'wb') as ofile, tqdm(
+        desc=fname,
+        total=total,
+        unit='iB',
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for data in resp.iter_content(chunk_size=chunk_size):
+            size = ofile.write(data)
+            bar.update(size)
+
+
+def process_shard(url: str, fname: str, ofname: str) -> None:
+    download(url, fname)
+    file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
+    buffer = ReplayBuffer(500000)
+    cnt = 0
+    for example in file_ds:
+        batch = _tf_example_to_tianshou_batch(example)
+        buffer.add(batch)
+        cnt += 1
+        if cnt % 1000 == 0:
+            print(f"...{cnt}", end="", flush=True)
+    print("\nReplayBuffer size:", len(buffer))
+    buffer.save_hdf5(ofname, compression="gzip")
+
+
+def process_dataset(
+    task: str,
+    download_path: str,
+    dst_path: str,
+    run_id: int = 1,
+    shard_id: int = 0,
+    total_num_shards: int = 100,
+) -> None:
+    fn = f"{task}/{_filename(run_id, shard_id, total_num_shards=total_num_shards)}"
+    url = f"{URL_PREFIX}/{fn}"
+    filepath = f"{download_path}/{fn}"
+    ofname = f"{dst_path}/{fn}.hdf5"
+    process_shard(url, filepath, ofname)
+
+
+def main(args):
+    if args.task not in ALL_GAMES:
+        raise KeyError(f"`{args.task}` is not in the list of games.")
+    fn = _filename(args.run_id, args.shard_id, total_num_shards=args.total_num_shards)
+    buffer_path = os.path.join(args.buffer_dir, args.task, f"{fn}.hdf5")
+    if os.path.exists(buffer_path):
+        raise IOError(f"Found existing buffer at {buffer_path}. Will not overwrite.")
+    args.dataset_dir = os.environ.get("RLU_DATASET_DIR", args.dataset_dir)
+    args.buffer_dir = os.environ.get("RLU_BUFFER_DIR", args.buffer_dir)
+    dataset_path = os.path.join(args.dataset_dir, args.task)
+    os.makedirs(dataset_path, exist_ok=True)
+    dst_path = os.path.join(args.buffer_dir, args.task)
+    os.makedirs(dst_path, exist_ok=True)
+    process_dataset(
+        args.task,
+        args.dataset_dir,
+        args.buffer_dir,
+        run_id=args.run_id,
+        shard_id=args.shard_id,
+        total_num_shards=args.total_num_shards
+    )
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser(usage=__doc__)
+    parser.add_argument("--task", required=True, help="Name of the Atari game.")
+    parser.add_argument(
+        "--run-id",
+        type=int,
+        default=1,
+        help="Run id to download and convert. Value in [1..5]."
+    )
+    parser.add_argument(
+        "--shard-id",
+        type=int,
+        default=0,
+        help="Shard id to download and convert. Value in [0..99]."
+    )
+    parser.add_argument(
+        "--total-num-shards", type=int, default=100, help="Total number of shards."
+    )
+    parser.add_argument(
+        "--dataset-dir",
+        default=os.path.expanduser("~/.rl_unplugged/datasets"),
+        help="Directory for downloaded original datasets.",
+    )
+    parser.add_argument(
+        "--buffer-dir",
+        default=os.path.expanduser("~/.rl_unplugged/buffers"),
+        help="Directory for converted replay buffers.",
+    )
+    args = parser.parse_args()
+    main(args)
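The same conversion can be driven from Python instead of the CLI by calling `process_dataset` directly. A minimal sketch, assuming the script's default directory layout and that it is run from the `examples/offline` directory so the module can be imported:

```python
import os

from convert_rl_unplugged_atari import process_dataset

dataset_dir = os.path.expanduser("~/.rl_unplugged/datasets")
buffer_dir = os.path.expanduser("~/.rl_unplugged/buffers")
# main() creates these per-task directories before converting; mirror that here.
os.makedirs(os.path.join(dataset_dir, "Breakout"), exist_ok=True)
os.makedirs(os.path.join(buffer_dir, "Breakout"), exist_ok=True)

# Downloads Breakout/run_1-00001-of-00100 and writes the converted buffer to
# ~/.rl_unplugged/buffers/Breakout/run_1-00001-of-00100.hdf5.
process_dataset("Breakout", dataset_dir, buffer_dir, run_id=1, shard_id=1)
```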
@@ -86,10 +86,10 @@ class ReplayBuffer:
         ), "key '{}' is reserved and cannot be assigned".format(key)
         super().__setattr__(key, value)
 
-    def save_hdf5(self, path: str) -> None:
+    def save_hdf5(self, path: str, compression: Optional[str] = None) -> None:
         """Save replay buffer to HDF5 file."""
         with h5py.File(path, "w") as f:
-            to_hdf5(self.__dict__, f)
+            to_hdf5(self.__dict__, f, compression=compression)
 
     @classmethod
     def load_hdf5(cls, path: str, device: Optional[str] = None) -> "ReplayBuffer":
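With the new `compression` keyword, callers can trade conversion time for file size; the conversion script above uses `compression="gzip"`, which matters for image observations. A short sketch of the extended API (the buffer contents here are placeholders):

```python
import numpy as np

from tianshou.data import Batch, ReplayBuffer

buf = ReplayBuffer(100)
# Add a single dummy Atari-like transition, mirroring the converter above.
buf.add(
    Batch(
        obs=np.zeros((4, 84, 84), np.uint8),
        act=0,
        rew=0.0,
        done=False,
        obs_next=np.zeros((4, 84, 84), np.uint8),
    )
)
# Datasets inside the HDF5 file are gzip-compressed; omit the argument to keep
# the previous uncompressed behavior.
buf.save_hdf5("demo_buffer.hdf5", compression="gzip")
restored = ReplayBuffer.load_hdf5("demo_buffer.hdf5")
```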
@@ -78,13 +78,17 @@ Hdf5ConvertibleValues = Union[  # type: ignore
 Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues]  # type: ignore
 
 
-def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
+def to_hdf5(
+    x: Hdf5ConvertibleType, y: h5py.Group, compression: Optional[str] = None
+) -> None:
     """Copy object into HDF5 group."""
 
-    def to_hdf5_via_pickle(x: object, y: h5py.Group, key: str) -> None:
+    def to_hdf5_via_pickle(
+        x: object, y: h5py.Group, key: str, compression: Optional[str] = None
+    ) -> None:
         """Pickle, convert to numpy array and write to HDF5 dataset."""
         data = np.frombuffer(pickle.dumps(x), dtype=np.byte)
-        y.create_dataset(key, data=data)
+        y.create_dataset(key, data=data, compression=compression)
 
     for k, v in x.items():
         if isinstance(v, (Batch, dict)):
@@ -95,22 +99,22 @@ def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
                 subgrp.attrs["__data_type__"] = "Batch"
             else:
                 subgrp_data = v
-            to_hdf5(subgrp_data, subgrp)
+            to_hdf5(subgrp_data, subgrp, compression=compression)
         elif isinstance(v, torch.Tensor):
             # PyTorch tensors are written to datasets
-            y.create_dataset(k, data=to_numpy(v))
+            y.create_dataset(k, data=to_numpy(v), compression=compression)
             y[k].attrs["__data_type__"] = "Tensor"
         elif isinstance(v, np.ndarray):
             try:
                 # NumPy arrays are written to datasets
-                y.create_dataset(k, data=v)
+                y.create_dataset(k, data=v, compression=compression)
                 y[k].attrs["__data_type__"] = "ndarray"
             except TypeError:
                 # If data type is not supported by HDF5 fall back to pickle.
                 # This happens if dtype=object (e.g. due to entries being None)
                 # and possibly in other cases like structured arrays.
                 try:
-                    to_hdf5_via_pickle(v, y, k)
+                    to_hdf5_via_pickle(v, y, k, compression=compression)
                 except Exception as exception:
                     raise RuntimeError(
                         f"Attempted to pickle {v.__class__.__name__} due to "
@@ -122,7 +126,7 @@ def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group) -> None:
             y.attrs[k] = v
         else:  # resort to pickle for any other type of object
             try:
-                to_hdf5_via_pickle(v, y, k)
+                to_hdf5_via_pickle(v, y, k, compression=compression)
             except Exception as exception:
                 raise NotImplementedError(
                     f"No conversion to HDF5 for object of type '{type(v)}' "
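As the hunks above show, `to_hdf5` forwards the keyword to every `h5py` `create_dataset` call, including the recursive sub-group and pickle-fallback paths, so nested dicts and arrays are compressed uniformly. A minimal sketch, assuming the converter module path used in the tianshou source layout at the time of this commit:

```python
import h5py
import numpy as np

# Module path assumed from the tianshou source layout (tianshou/data/utils/converter.py).
from tianshou.data.utils.converter import from_hdf5, to_hdf5

data = {
    "obs": np.zeros((1000, 4, 84, 84), dtype=np.uint8),  # stored as a gzip dataset
    "step": 3,  # ints/floats become group attributes, so compression does not apply
}
with h5py.File("demo.hdf5", "w") as f:
    to_hdf5(data, f, compression="gzip")
with h5py.File("demo.hdf5", "r") as f:
    restored = from_hdf5(f)
```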