fix #59
This commit is contained in:
parent
d2b2fa87c0
commit
be9ce44290
@ -402,7 +402,7 @@ class Collector(object):
|
|||||||
lens = [len(b) for b in self.buffer]
|
lens = [len(b) for b in self.buffer]
|
||||||
total = sum(lens)
|
total = sum(lens)
|
||||||
batch_index = np.random.choice(
|
batch_index = np.random.choice(
|
||||||
total, batch_size, p=np.array(lens) / total)
|
len(self.buffer), batch_size, p=np.array(lens) / total)
|
||||||
else:
|
else:
|
||||||
batch_index = np.array([])
|
batch_index = np.array([])
|
||||||
batch_data = Batch()
|
batch_data = Batch()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user