This commit is contained in:
Trinkle23897 2020-05-29 11:49:47 +08:00
parent d2b2fa87c0
commit be9ce44290

View File

@ -402,7 +402,7 @@ class Collector(object):
lens = [len(b) for b in self.buffer] lens = [len(b) for b in self.buffer]
total = sum(lens) total = sum(lens)
batch_index = np.random.choice( batch_index = np.random.choice(
total, batch_size, p=np.array(lens) / total) len(self.buffer), batch_size, p=np.array(lens) / total)
else: else:
batch_index = np.array([]) batch_index = np.array([])
batch_data = Batch() batch_data = Batch()