Custom keys support in ReplayBuffer (#903)
Issue: Custom keys support in ReplayBuffer (#902). Modified the `ReplayBuffer` `add` and `__getitem__` methods, and added `test_custom_key()` to test_buffer.py.
This commit is contained in:
parent
61182450b6
commit
80a698be52
@ -166,3 +166,13 @@ isort
|
|||||||
yapf
|
yapf
|
||||||
pydocstyle
|
pydocstyle
|
||||||
Args
|
Args
|
||||||
|
tuples
|
||||||
|
tuple
|
||||||
|
Multi
|
||||||
|
multi
|
||||||
|
parameterized
|
||||||
|
Proximal
|
||||||
|
metadata
|
||||||
|
GPU
|
||||||
|
Dopamine
|
||||||
|
builtin
|
||||||
|
@ -1336,6 +1336,65 @@ def test_from_data():
|
|||||||
os.remove(path)
|
os.remove(path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_key():
    """Check that a sampled batch keeps every key of the batch that was added.

    A single transition that carries an extra ``returns`` entry next to the
    usual transition entries is added to a ``ReplayBuffer`` and sampled back.
    The sampled batch must expose exactly the same keys, every array-valued
    entry must match numerically, and the empty ``info`` / ``policy``
    batches must still be empty.
    """
    batch = Batch(
        **{
            'obs_next':
            np.array(
                [
                    [
                        1.174, -0.1151, -0.609, -0.5205, -0.9316, 3.236, -2.418,
                        0.386, 0.2227, -0.5117, 2.293
                    ]
                ]
            ),
            'rew':
            np.array([4.28125]),
            'act':
            np.array([[-0.3088, -0.4636, 0.4956]]),
            'truncated':
            np.array([False]),
            'obs':
            np.array(
                [
                    [
                        1.193, -0.1203, -0.6123, -0.519, -0.9434, 3.32, -2.266,
                        0.9116, 0.623, 0.1259, 0.363
                    ]
                ]
            ),
            'terminated':
            np.array([False]),
            'done':
            np.array([False]),
            'returns':
            np.array([74.70343082]),
            'info':
            Batch(),
            'policy':
            Batch(),
        }
    )
    buffer_size = len(batch.rew)
    buffer = ReplayBuffer(buffer_size)
    buffer.add(batch)
    sampled_batch, _ = buffer.sample(1)

    # Check if they have the same keys
    expected_keys = set(batch.keys())
    sampled_keys = set(sampled_batch.keys())
    assert expected_keys == sampled_keys, \
        "Batches have different keys: {} and {}".format(
            expected_keys, sampled_keys)

    # Compare the values for each key. The sampled value's type is asserted
    # explicitly: only comparing when both types happen to agree would let a
    # key whose value changed type pass without any check at all.
    for key in batch.keys():
        expected = batch[key]
        sampled = sampled_batch[key]
        if isinstance(expected, np.ndarray):
            assert isinstance(sampled, np.ndarray), \
                "Type mismatch for key: {}".format(key)
            assert np.allclose(expected, sampled), \
                "Value mismatch for key: {}".format(key)
        elif isinstance(expected, Batch):
            assert isinstance(sampled, Batch), \
                "Type mismatch for key: {}".format(key)
            assert expected.is_empty()
            assert sampled.is_empty()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_replaybuffer()
|
test_replaybuffer()
|
||||||
test_ignore_obs_next()
|
test_ignore_obs_next()
|
||||||
@ -1351,3 +1410,4 @@ if __name__ == '__main__':
|
|||||||
test_multibuf_hdf5()
|
test_multibuf_hdf5()
|
||||||
test_from_data()
|
test_from_data()
|
||||||
test_herreplaybuffer()
|
test_herreplaybuffer()
|
||||||
|
test_custom_key()
|
||||||
|
@ -220,9 +220,8 @@ class ReplayBuffer:
|
|||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||||
"""Add a batch of data into replay buffer.
|
"""Add a batch of data into replay buffer.
|
||||||
|
|
||||||
:param Batch batch: the input data batch. Its keys must belong to the 7
|
:param Batch batch: the input data batch. "obs", "act", "rew",
|
||||||
input keys, and "obs", "act", "rew", "terminated", "truncated" is
|
"terminated", "truncated" are required keys.
|
||||||
required.
|
|
||||||
:param buffer_ids: to make consistent with other buffer's add function; if it
|
:param buffer_ids: to make consistent with other buffer's add function; if it
|
||||||
is not None, we assume the input batch's first dimension is always 1.
|
is not None, we assume the input batch's first dimension is always 1.
|
||||||
|
|
||||||
@ -232,12 +231,12 @@ class ReplayBuffer:
|
|||||||
"""
|
"""
|
||||||
# preprocess batch
|
# preprocess batch
|
||||||
new_batch = Batch()
|
new_batch = Batch()
|
||||||
for key in set(self._input_keys).intersection(batch.keys()):
|
for key in batch.keys():
|
||||||
new_batch.__dict__[key] = batch[key]
|
new_batch.__dict__[key] = batch[key]
|
||||||
batch = new_batch
|
batch = new_batch
|
||||||
batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
|
batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
|
||||||
assert set(["obs", "act", "rew", "terminated", "truncated",
|
assert set(["obs", "act", "rew", "terminated", "truncated", "done"]
|
||||||
"done"]).issubset(batch.keys())
|
).issubset(batch.keys()) # important to do after preprocess batch
|
||||||
stacked_batch = buffer_ids is not None
|
stacked_batch = buffer_ids is not None
|
||||||
if stacked_batch:
|
if stacked_batch:
|
||||||
assert len(batch) == 1
|
assert len(batch) == 1
|
||||||
@ -376,14 +375,18 @@ class ReplayBuffer:
|
|||||||
obs_next = self.get(indices, "obs_next", Batch())
|
obs_next = self.get(indices, "obs_next", Batch())
|
||||||
else:
|
else:
|
||||||
obs_next = self.get(self.next(indices), "obs", Batch())
|
obs_next = self.get(self.next(indices), "obs", Batch())
|
||||||
return Batch(
|
batch_dict = {
|
||||||
obs=obs,
|
"obs": obs,
|
||||||
act=self.act[indices],
|
"act": self.act[indices],
|
||||||
rew=self.rew[indices],
|
"rew": self.rew[indices],
|
||||||
terminated=self.terminated[indices],
|
"terminated": self.terminated[indices],
|
||||||
truncated=self.truncated[indices],
|
"truncated": self.truncated[indices],
|
||||||
done=self.done[indices],
|
"done": self.done[indices],
|
||||||
obs_next=obs_next,
|
"obs_next": obs_next,
|
||||||
info=self.get(indices, "info", Batch()),
|
"info": self.get(indices, "info", Batch()),
|
||||||
policy=self.get(indices, "policy", Batch()),
|
"policy": self.get(indices, "policy", Batch()),
|
||||||
)
|
}
|
||||||
|
for key in self._meta.__dict__.keys():
|
||||||
|
if key not in self._input_keys:
|
||||||
|
batch_dict[key] = self._meta[key][indices]
|
||||||
|
return Batch(batch_dict)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user