Fix a bug in loading offline data (#768)

This PR fixes #766.

Co-authored-by: Yi Su <yi_su@apple.com>
This commit is contained in:
Yi Su 2022-11-03 16:12:33 -07:00 committed by GitHub
parent 7ff12b909d
commit 06aaad460e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 5 deletions

View File

@@ -132,7 +132,7 @@ python3 convert_rl_unplugged_atari.py --task Breakout --run-id 1 --shard-id 1
Then you can use it to train an agent by:
```bash
python3 atari_bcq.py --task BreakoutNoFrameskip-v4 --load-buffer-name ~/.rl_unplugged/buffers/Breakout/run_1-00001-of-00100.hdf5 --buffer-from-rl-unplugged --epoch 12
python3 atari_bcq.py --task BreakoutNoFrameskip-v4 --load-buffer-name ~/.rl_unplugged/datasets/Breakout/run_1-00001-of-00100.hdf5 --buffer-from-rl-unplugged --epoch 12
```
Note:

View File

@@ -195,9 +195,8 @@ def download(url: str, fname: str, chunk_size=1024):
bar.update(size)
def process_shard(url: str, fname: str, ofname: str) -> None:
def process_shard(url: str, fname: str, ofname: str, maxsize: int = 500000) -> None:
download(url, fname)
maxsize = 500000
obs = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
act = np.ndarray((maxsize, ), dtype="int64")
rew = np.ndarray((maxsize, ), dtype="float32")
@@ -206,6 +205,8 @@ def process_shard(url: str, fname: str, ofname: str) -> None:
i = 0
file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
for example in file_ds:
if i >= maxsize:
break
batch = _tf_example_to_tianshou_batch(example)
obs[i], act[i], rew[i], done[i], obs_next[i] = (
batch.obs, batch.act, batch.rew, batch.done, batch.obs_next

View File

@@ -16,7 +16,9 @@ def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
act=dataset["actions"],
rew=dataset["rewards"],
done=dataset["terminals"],
obs_next=dataset["next_observations"]
obs_next=dataset["next_observations"],
terminated=dataset["terminals"],
truncated=np.zeros(len(dataset["terminals"]))
)
return replay_buffer
@@ -28,7 +30,9 @@ def load_buffer(buffer_path: str) -> ReplayBuffer:
act=dataset["actions"],
rew=dataset["rewards"],
done=dataset["terminals"],
obs_next=dataset["next_observations"]
obs_next=dataset["next_observations"],
terminated=dataset["terminals"],
truncated=np.zeros(len(dataset["terminals"]))
)
return buffer