Custom keys support in ReplayBuffer (#903)

Issue: Custom keys support in ReplayBuffer #902
Modified `ReplayBuffer` `add` and `__getitem__` methods.
Added `test_custom_key()` to test_buffer.py
This commit is contained in:
Anas BELFADIL 2023-08-11 01:06:10 +02:00 committed by GitHub
parent 61182450b6
commit 80a698be52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 90 additions and 17 deletions

View File

@ -166,3 +166,13 @@ isort
yapf yapf
pydocstyle pydocstyle
Args Args
tuples
tuple
Multi
multi
parameterized
Proximal
metadata
GPU
Dopamine
builtin

View File

@ -1336,6 +1336,65 @@ def test_from_data():
os.remove(path) os.remove(path)
def test_custom_key():
    """Check that a non-standard key survives a ReplayBuffer round trip.

    A one-transition batch carrying the extra ``returns`` entry next to the
    usual fields is added to a ``ReplayBuffer`` constructed with
    ``len(batch.rew)`` as its size and then sampled back. The sampled batch
    must expose exactly the same keys as the input, every array value must be
    numerically close to the input value, and the empty ``info``/``policy``
    batches must still be empty.
    """
    batch = Batch(
        **{
            'obs_next':
            np.array(
                [
                    [
                        1.174, -0.1151, -0.609, -0.5205, -0.9316, 3.236, -2.418,
                        0.386, 0.2227, -0.5117, 2.293
                    ]
                ]
            ),
            'rew':
            np.array([4.28125]),
            'act':
            np.array([[-0.3088, -0.4636, 0.4956]]),
            'truncated':
            np.array([False]),
            'obs':
            np.array(
                [
                    [
                        1.193, -0.1203, -0.6123, -0.519, -0.9434, 3.32, -2.266,
                        0.9116, 0.623, 0.1259, 0.363
                    ]
                ]
            ),
            'terminated':
            np.array([False]),
            'done':
            np.array([False]),
            # "returns" is the custom (non-reserved) key under test.
            'returns':
            np.array([74.70343082]),
            'info':
            Batch(),
            'policy':
            Batch(),
        }
    )
    buffer_size = len(batch.rew)
    buffer = ReplayBuffer(buffer_size)
    buffer.add(batch)
    # sample() returns a pair; only the sampled batch is needed here.
    sampled_batch, _ = buffer.sample(1)

    # Check if they have the same keys
    expected_keys = set(batch.keys())
    sampled_keys = set(sampled_batch.keys())
    assert expected_keys == sampled_keys, \
        f"Batches have different keys: {expected_keys} and {sampled_keys}"

    # Compare the values for each key. Each key must fall into exactly one
    # checked category, so a value whose type changed during the round trip
    # fails loudly instead of being skipped.
    for key in batch.keys():
        expected = batch[key]
        actual = sampled_batch[key]
        if isinstance(expected, np.ndarray):
            assert isinstance(actual, np.ndarray), \
                f"Type mismatch for key: {key}"
            assert np.allclose(expected, actual), \
                f"Value mismatch for key: {key}"
        elif isinstance(expected, Batch):
            assert isinstance(actual, Batch), f"Type mismatch for key: {key}"
            assert expected.is_empty()
            assert actual.is_empty()
        else:
            raise AssertionError(f"Unexpected value type for key: {key}")
if __name__ == '__main__': if __name__ == '__main__':
test_replaybuffer() test_replaybuffer()
test_ignore_obs_next() test_ignore_obs_next()
@ -1351,3 +1410,4 @@ if __name__ == '__main__':
test_multibuf_hdf5() test_multibuf_hdf5()
test_from_data() test_from_data()
test_herreplaybuffer() test_herreplaybuffer()
test_custom_key()

View File

@ -220,9 +220,8 @@ class ReplayBuffer:
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""Add a batch of data into replay buffer. """Add a batch of data into replay buffer.
:param Batch batch: the input data batch. Its keys must belong to the 7 :param Batch batch: the input data batch. "obs", "act", "rew",
input keys, and "obs", "act", "rew", "terminated", "truncated" is "terminated", "truncated" are required keys.
required.
:param buffer_ids: to make consistent with other buffer's add function; if it :param buffer_ids: to make consistent with other buffer's add function; if it
is not None, we assume the input batch's first dimension is always 1. is not None, we assume the input batch's first dimension is always 1.
@ -232,12 +231,12 @@ class ReplayBuffer:
""" """
# preprocess batch # preprocess batch
new_batch = Batch() new_batch = Batch()
for key in set(self._input_keys).intersection(batch.keys()): for key in batch.keys():
new_batch.__dict__[key] = batch[key] new_batch.__dict__[key] = batch[key]
batch = new_batch batch = new_batch
batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated) batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated)
assert set(["obs", "act", "rew", "terminated", "truncated", assert set(["obs", "act", "rew", "terminated", "truncated", "done"]
"done"]).issubset(batch.keys()) ).issubset(batch.keys()) # important to do after preprocess batch
stacked_batch = buffer_ids is not None stacked_batch = buffer_ids is not None
if stacked_batch: if stacked_batch:
assert len(batch) == 1 assert len(batch) == 1
@ -376,14 +375,18 @@ class ReplayBuffer:
obs_next = self.get(indices, "obs_next", Batch()) obs_next = self.get(indices, "obs_next", Batch())
else: else:
obs_next = self.get(self.next(indices), "obs", Batch()) obs_next = self.get(self.next(indices), "obs", Batch())
return Batch( batch_dict = {
obs=obs, "obs": obs,
act=self.act[indices], "act": self.act[indices],
rew=self.rew[indices], "rew": self.rew[indices],
terminated=self.terminated[indices], "terminated": self.terminated[indices],
truncated=self.truncated[indices], "truncated": self.truncated[indices],
done=self.done[indices], "done": self.done[indices],
obs_next=obs_next, "obs_next": obs_next,
info=self.get(indices, "info", Batch()), "info": self.get(indices, "info", Batch()),
policy=self.get(indices, "policy", Batch()), "policy": self.get(indices, "policy", Batch()),
) }
for key in self._meta.__dict__.keys():
if key not in self._input_keys:
batch_dict[key] = self._meta[key][indices]
return Batch(batch_dict)