diff --git a/AlphaGo/model.py b/AlphaGo/model.py index e8b5eb9..3cfb900 100644 --- a/AlphaGo/model.py +++ b/AlphaGo/model.py @@ -215,9 +215,10 @@ class ResNet(object): self.training_data['states'].append(states) self.training_data['probs'].append(probs) self.training_data['winner'].append(winner) - training_data['states'] = np.concatenate(self.training_data['states'], axis=0) - training_data['probs'] = np.concatenate(self.training_data['probs'], axis=0) - training_data['winner'] = np.concatenate(self.training_data['winner'], axis=0) + if len(self.training_data['states']) == self.window_length: + training_data['states'] = np.concatenate(self.training_data['states'], axis=0) + training_data['probs'] = np.concatenate(self.training_data['probs'], axis=0) + training_data['winner'] = np.concatenate(self.training_data['winner'], axis=0) if len(self.training_data['states']) != self.window_length: continue