Fix a bug in loading offline data (#768)
This PR fixes #766 . Co-authored-by: Yi Su <yi_su@apple.com>
This commit is contained in:
parent
7ff12b909d
commit
06aaad460e
@ -132,7 +132,7 @@ python3 convert_rl_unplugged_atari.py --task Breakout --run-id 1 --shard-id 1
|
|||||||
Then you can use it to train an agent by:
|
Then you can use it to train an agent by:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python3 atari_bcq.py --task BreakoutNoFrameskip-v4 --load-buffer-name ~/.rl_unplugged/buffers/Breakout/run_1-00001-of-00100.hdf5 --buffer-from-rl-unplugged --epoch 12
|
python3 atari_bcq.py --task BreakoutNoFrameskip-v4 --load-buffer-name ~/.rl_unplugged/datasets/Breakout/run_1-00001-of-00100.hdf5 --buffer-from-rl-unplugged --epoch 12
|
||||||
```
|
```
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
|
@ -195,9 +195,8 @@ def download(url: str, fname: str, chunk_size=1024):
|
|||||||
bar.update(size)
|
bar.update(size)
|
||||||
|
|
||||||
|
|
||||||
def process_shard(url: str, fname: str, ofname: str) -> None:
|
def process_shard(url: str, fname: str, ofname: str, maxsize: int = 500000) -> None:
|
||||||
download(url, fname)
|
download(url, fname)
|
||||||
maxsize = 500000
|
|
||||||
obs = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
|
obs = np.ndarray((maxsize, 4, 84, 84), dtype="uint8")
|
||||||
act = np.ndarray((maxsize, ), dtype="int64")
|
act = np.ndarray((maxsize, ), dtype="int64")
|
||||||
rew = np.ndarray((maxsize, ), dtype="float32")
|
rew = np.ndarray((maxsize, ), dtype="float32")
|
||||||
@ -206,6 +205,8 @@ def process_shard(url: str, fname: str, ofname: str) -> None:
|
|||||||
i = 0
|
i = 0
|
||||||
file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
|
file_ds = tf.data.TFRecordDataset(fname, compression_type="GZIP")
|
||||||
for example in file_ds:
|
for example in file_ds:
|
||||||
|
if i >= maxsize:
|
||||||
|
break
|
||||||
batch = _tf_example_to_tianshou_batch(example)
|
batch = _tf_example_to_tianshou_batch(example)
|
||||||
obs[i], act[i], rew[i], done[i], obs_next[i] = (
|
obs[i], act[i], rew[i], done[i], obs_next[i] = (
|
||||||
batch.obs, batch.act, batch.rew, batch.done, batch.obs_next
|
batch.obs, batch.act, batch.rew, batch.done, batch.obs_next
|
||||||
|
@ -16,7 +16,9 @@ def load_buffer_d4rl(expert_data_task: str) -> ReplayBuffer:
|
|||||||
act=dataset["actions"],
|
act=dataset["actions"],
|
||||||
rew=dataset["rewards"],
|
rew=dataset["rewards"],
|
||||||
done=dataset["terminals"],
|
done=dataset["terminals"],
|
||||||
obs_next=dataset["next_observations"]
|
obs_next=dataset["next_observations"],
|
||||||
|
terminated=dataset["terminals"],
|
||||||
|
truncated=np.zeros(len(dataset["terminals"]))
|
||||||
)
|
)
|
||||||
return replay_buffer
|
return replay_buffer
|
||||||
|
|
||||||
@ -28,7 +30,9 @@ def load_buffer(buffer_path: str) -> ReplayBuffer:
|
|||||||
act=dataset["actions"],
|
act=dataset["actions"],
|
||||||
rew=dataset["rewards"],
|
rew=dataset["rewards"],
|
||||||
done=dataset["terminals"],
|
done=dataset["terminals"],
|
||||||
obs_next=dataset["next_observations"]
|
obs_next=dataset["next_observations"],
|
||||||
|
terminated=dataset["terminals"],
|
||||||
|
truncated=np.zeros(len(dataset["terminals"]))
|
||||||
)
|
)
|
||||||
return buffer
|
return buffer
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user