From 97161f37ef1831e396d4468873c13ea7077840b1 Mon Sep 17 00:00:00 2001 From: rtz19970824 <1289226405@qq.com> Date: Fri, 22 Dec 2017 13:42:53 +0800 Subject: [PATCH] faster the loading --- AlphaGo/model.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/AlphaGo/model.py b/AlphaGo/model.py index 5629128..c4338c8 100644 --- a/AlphaGo/model.py +++ b/AlphaGo/model.py @@ -215,9 +215,10 @@ class ResNet(object): self.training_data['states'].append(states) self.training_data['probs'].append(probs) self.training_data['winner'].append(winner) - training_data['states'] = np.concatenate(self.training_data['states'], axis=0) - training_data['probs'] = np.concatenate(self.training_data['probs'], axis=0) - training_data['winner'] = np.concatenate(self.training_data['winner'], axis=0) + if len(self.training_data['states']) == self.window_length: + training_data['states'] = np.concatenate(self.training_data['states'], axis=0) + training_data['probs'] = np.concatenate(self.training_data['probs'], axis=0) + training_data['winner'] = np.concatenate(self.training_data['winner'], axis=0) if len(self.training_data['states']) != self.window_length: continue