Fix a bug in loading offline data (#768)
This PR fixes #766. Co-authored-by: Yi Su <yi_su@apple.com>
This commit is contained in:
parent
7ff12b909d
commit
06aaad460e
@ -132,7 +132,7 @@ python3 convert_rl_unplugged_atari.py --task Breakout --run-id 1 --shard-id 1
|
||||
Then you can use it to train an agent by:
|
||||
|
||||
```bash
|
||||
python3 atari_bcq.py --task BreakoutNoFrameskip-v4 --load-buffer-name ~/.rl_unplugged/buffers/Breakout/run_1-00001-of-00100.hdf5 --buffer-from-rl-unplugged --epoch 12
|
||||
python3 atari_bcq.py --task BreakoutNoFrameskip-v4 --load-buffer-name ~/.rl_unplugged/datasets/Breakout/run_1-00001-of-00100.hdf5 --buffer-from-rl-unplugged --epoch 12
|
||||
```
|
||||
|
||||
Note:
|
||||
|
@ -195,9 +195,8 @@ def download(url: str, fname: str, chunk_size=1024):
|
||||
bar.update(size)
|
||||
|
||||
|
||||
def process_shard(url: str, fname: str, ofname: str) -> None:
|
||||
def process_shard(url: str, fname: str, ofname: str, maxsize: int = 500000) -> None:
|
||||
download(url, fname)
|
||||
maxsize = 500000
|
||||
obs = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
|
||||
act = np.ndarray((maxsize, ), dtype="int64")
|
||||
rew = np.ndarray((maxsize, ), dtype="float32")
|
||||
@ -206,6 +205,8 @@ def process_shard(url: str, fname: str, ofname: str) -> None:
|
||||
i = 0
|
||||
file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
|
||||
for example in file_ds:
|
||||
if i >= maxsize:
|
||||
break
|
||||
batch = _tf_example_to_tianshou_batch(example)
|
||||
obs[i], act[i], rew[i], done[i], obs_next[i] = (
|
||||
batch.obs, batch.act, batch.rew, batch.done, batch.obs_next
|
||||
|
@ -16,7 +16,9 @@ def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
|
||||
act=dataset["actions"],
|
||||
rew=dataset["rewards"],
|
||||
done=dataset["terminals"],
|
||||
obs_next=dataset["next_observations"]
|
||||
obs_next=dataset["next_observations"],
|
||||
terminated=dataset["terminals"],
|
||||
truncated=np.zeros(len(dataset["terminals"]))
|
||||
)
|
||||
return replay_buffer
|
||||
|
||||
@ -28,7 +30,9 @@ def load_buffer(buffer_path: str) -> ReplayBuffer:
|
||||
act=dataset["actions"],
|
||||
rew=dataset["rewards"],
|
||||
done=dataset["terminals"],
|
||||
obs_next=dataset["next_observations"]
|
||||
obs_next=dataset["next_observations"],
|
||||
terminated=dataset["terminals"],
|
||||
truncated=np.zeros(len(dataset["terminals"]))
|
||||
)
|
||||
return buffer
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user