56 lines
1.9 KiB
Python
Raw Normal View History

2020-03-14 21:48:31 +08:00
import torch
2020-03-13 17:49:22 +08:00
import numpy as np
2020-03-11 09:09:56 +08:00
def _take(value, index):
    """Return ``value[index]``, adding fancy-index support for Python lists.

    NumPy arrays and torch tensors accept integer-array and boolean-mask
    indices natively, but a plain ``list`` only accepts ints and slices.
    ``Batch.split`` indexes with a NumPy integer array, so list-valued
    fields need an explicit element gather here.
    """
    if isinstance(value, list) and not isinstance(
            index, (int, np.integer, slice)):
        idx = np.asarray(index)
        if idx.ndim == 0:
            return value[idx.item()]
        if idx.dtype == np.bool_:
            # Boolean mask -> positions of the True entries.
            idx = np.flatnonzero(idx)
        return [value[i] for i in idx]
    return value[index]


class Batch(object):
    """Container of named fields that are indexed together.

    Suggested keys: [obs, act, rew, done, obs_next, info]

    Every keyword passed to the constructor becomes an instance attribute.
    ``append`` knows how to concatenate ``np.ndarray``, ``torch.Tensor`` and
    ``list`` values. A value of ``None`` is skipped by indexing, appending
    and splitting.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.__dict__.update(kwargs)

    def __getitem__(self, index):
        """Return a new Batch with every non-None field indexed by ``index``.

        Fields whose value is ``None`` are omitted from the result. List
        fields also accept integer-array and boolean-mask indices, matching
        what NumPy arrays and tensors support.
        """
        b = Batch()
        for k, v in self.__dict__.items():
            if v is not None:
                b.update(**{k: _take(v, index)})
        return b

    def update(self, **kwargs):
        """Set (or overwrite) fields from keyword arguments."""
        self.__dict__.update(kwargs)

    def append(self, batch):
        """Concatenate the fields of ``batch`` onto this batch in place.

        For each non-None field of ``batch``, the behaviour is:

        * If this batch lacks the field or holds ``None``, the incoming
          value is stored.
        * Otherwise the values are joined along the first dimension:
          ``np.concatenate`` for arrays, ``torch.cat`` for tensors, and
          list concatenation for lists.
        * The concatenation routine is chosen by the type of the
          *incoming* value.

        Raises:
            TypeError: if an incoming value of an unsupported type must be
                concatenated with an existing value.
        """
        assert isinstance(batch, Batch), 'Only append Batch is allowed!'
        for k, new in batch.__dict__.items():
            if new is None:
                continue
            # ``dict.get`` rather than ``hasattr``: a field named like a
            # method (e.g. "split") exists as a class attribute but not in
            # the instance ``__dict__``.
            old = self.__dict__.get(k)
            if old is None:
                self.__dict__[k] = new
            elif isinstance(new, np.ndarray):
                self.__dict__[k] = np.concatenate([old, new])
            elif isinstance(new, torch.Tensor):
                self.__dict__[k] = torch.cat([old, new])
            elif isinstance(new, list):
                # Build a new list instead of ``+=``. The stored list may be
                # the very object taken from an earlier appended batch, and
                # extending it in place would silently modify that batch.
                self.__dict__[k] = old + new
            else:
                raise TypeError(
                    'No support for append with type {} in class Batch.'
                    .format(type(new)))

    def split(self, size=None, permute=True):
        """Yield consecutive sub-batches of at most ``size`` entries.

        Args:
            size: number of entries per sub-batch. ``None`` yields the
                whole batch as a single sub-batch. Must be positive.
            permute: if True, entries are visited in a random order
                (``np.random.permutation``). Otherwise they are visited in
                their stored order.

        The usable length is the shortest ``len()`` over all non-None
        fields, so entries beyond it in longer fields are never yielded.
        A batch with no non-None fields (or zero length) yields nothing.

        Raises:
            ValueError: if ``size`` is not positive.
        """
        if size is not None and size <= 0:
            # A non-positive step would never advance through the data.
            raise ValueError(
                'size must be a positive integer, got {}.'.format(size))
        lengths = [len(v) for v in self.__dict__.values() if v is not None]
        length = min(lengths) if lengths else 0
        if length == 0:
            return
        if size is None:
            size = length
        if permute:
            index = np.random.permutation(length)
        else:
            index = np.arange(length)
        for start in range(0, length, size):
            yield self[index[start:start + size]]