2020-03-14 21:48:31 +08:00
|
|
|
import torch
|
2020-03-13 17:49:22 +08:00
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
2020-03-11 09:09:56 +08:00
|
|
|
class Batch(object):
    """Attribute-style container holding one value per key.

    Suggested keys: [obs, act, rew, done, obs_next, info]

    Every keyword given to the constructor or to :meth:`update` is stored as
    an instance attribute. A value of ``None`` marks a key as unset: such
    keys are skipped by indexing, :meth:`append` and :meth:`split`.
    """

    def __init__(self, **kwargs):
        """Store every keyword argument as an attribute of the batch."""
        super().__init__()
        self.__dict__.update(kwargs)

    def __getitem__(self, index):
        """Return a new Batch holding ``value[index]`` for each non-None key.

        Keys whose value is ``None`` are not present in the returned batch.
        """
        return Batch(**{
            key: value[index]
            for key, value in self.__dict__.items()
            if value is not None
        })

    def update(self, **kwargs):
        """Set (or overwrite) attributes from the given keyword arguments."""
        self.__dict__.update(kwargs)

    def append(self, batch):
        """Concatenate the contents of ``batch`` onto this batch, key by key.

        Keys that are ``None`` in ``batch`` are ignored. A key that is missing
        or ``None`` here simply adopts the incoming value. Otherwise the
        values are joined according to the incoming value's type:
        ``np.ndarray`` via ``np.concatenate``, ``torch.Tensor`` via
        ``torch.cat`` and ``list`` via list concatenation.

        :raises TypeError: if an incoming value has any other type.
        """
        assert isinstance(batch, Batch), 'Only append Batch is allowed!'
        for key, incoming in batch.__dict__.items():
            if incoming is None:
                continue
            # Look the key up in __dict__ directly: hasattr() would also
            # match method names (e.g. "split") that are not stored data.
            current = self.__dict__.get(key)
            if current is None:
                self.__dict__[key] = incoming
            elif isinstance(incoming, np.ndarray):
                self.__dict__[key] = np.concatenate([current, incoming])
            elif isinstance(incoming, torch.Tensor):
                self.__dict__[key] = torch.cat([current, incoming])
            elif isinstance(incoming, list):
                # Build a new list instead of extending in place; the current
                # list may be shared with a previously appended batch, which
                # must not be mutated behind its owner's back.
                self.__dict__[key] = current + incoming
            else:
                raise TypeError(
                    'No support for append with type {} in class Batch.'
                    .format(type(incoming)))

    def split(self, size=None, permute=True):
        """Yield consecutive sub-batches of at most ``size`` elements.

        The number of elements is the minimum ``len`` over all non-None
        values. A batch without any such value yields nothing.

        :param size: elements per yielded batch; ``None`` yields the whole
            batch at once. The last batch may be smaller.
        :param permute: if true, elements are visited in an order drawn from
            ``np.random.permutation``; otherwise in their stored order.
        :raises ValueError: if ``size`` is not positive (raised when the
            generator is first advanced).
        """
        length = min(
            (len(value) for value in self.__dict__.values()
             if value is not None),
            default=0)
        if length == 0:
            return
        if size is None:
            size = length
        if size <= 0:
            # A non-positive step would never advance past ``length``.
            raise ValueError(
                'size should be a positive integer, got {}.'.format(size))
        if permute:
            index = np.random.permutation(length)
        else:
            index = np.arange(length)
        for start in range(0, length, size):
            yield self[index[start:start + size]]
|