From 06aaad460e402bfa4ae3b631a5e7f64d0b1d5405 Mon Sep 17 00:00:00 2001 From: Yi Su Date: Thu, 3 Nov 2022 16:12:33 -0700 Subject: [PATCH] Fix a bug in loading offline data (#768) This PR fixes #766. Co-authored-by: Yi Su --- examples/offline/README.md | 2 +- examples/offline/convert_rl_unplugged_atari.py | 5 +++-- examples/offline/utils.py | 8 ++++++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/examples/offline/README.md b/examples/offline/README.md index f52fb91..0212979 100644 --- a/examples/offline/README.md +++ b/examples/offline/README.md @@ -132,7 +132,7 @@ python3 convert_rl_unplugged_atari.py --task Breakout --run-id 1 --shard-id 1 Then you can use it to train an agent by: ```bash -python3 atari_bcq.py --task BreakoutNoFrameskip-v4 --load-buffer-name ~/.rl_unplugged/buffers/Breakout/run_1-00001-of-00100.hdf5 --buffer-from-rl-unplugged --epoch 12 +python3 atari_bcq.py --task BreakoutNoFrameskip-v4 --load-buffer-name ~/.rl_unplugged/datasets/Breakout/run_1-00001-of-00100.hdf5 --buffer-from-rl-unplugged --epoch 12 ``` Note: diff --git a/examples/offline/convert_rl_unplugged_atari.py b/examples/offline/convert_rl_unplugged_atari.py index 9fc8c25..46696f5 100755 --- a/examples/offline/convert_rl_unplugged_atari.py +++ b/examples/offline/convert_rl_unplugged_atari.py @@ -195,9 +195,8 @@ def download(url: str, fname: str, chunk_size=1024): bar.update(size) -def process_shard(url: str, fname: str, ofname: str) -> None: +def process_shard(url: str, fname: str, ofname: str, maxsize: int = 500000) -> None: download(url, fname) - maxsize = 500000 obs = np.ndarray((maxsize, 4, 84, 84), dtype="uint8") act = np.ndarray((maxsize, ), dtype="int64") rew = np.ndarray((maxsize, ), dtype="float32") @@ -206,6 +205,8 @@ def process_shard(url: str, fname: str, ofname: str) -> None: i = 0 file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP") for example in file_ds: + if i >= maxsize: + break batch = _tf_example_to_tianshou_batch(example) obs[i], 
act[i], rew[i], done[i], obs_next[i] = ( batch.obs, batch.act, batch.rew, batch.done, batch.obs_next diff --git a/examples/offline/utils.py b/examples/offline/utils.py index c605279..07c693c 100644 --- a/examples/offline/utils.py +++ b/examples/offline/utils.py @@ -16,7 +16,9 @@ def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer: act=dataset["actions"], rew=dataset["rewards"], done=dataset["terminals"], - obs_next=dataset["next_observations"] + obs_next=dataset["next_observations"], + terminated=dataset["terminals"], + truncated=np.zeros(len(dataset["terminals"])) ) return replay_buffer @@ -28,7 +30,9 @@ def load_buffer(buffer_path: str) -> ReplayBuffer: act=dataset["actions"], rew=dataset["rewards"], done=dataset["terminals"], - obs_next=dataset["next_observations"] + obs_next=dataset["next_observations"], + terminated=dataset["terminals"], + truncated=np.zeros(len(dataset["terminals"])) ) return buffer