diff --git a/AlphaGo/model.py b/AlphaGo/model.py index 6fde6e5..8d4c508 100644 --- a/AlphaGo/model.py +++ b/AlphaGo/model.py @@ -284,7 +284,7 @@ class ResNet(object): history.append(board) states.append(self._history2state(history, color)) probs.append(np.array(prob).reshape(1, self.board_size ** 2 + 1)) - winner.append(np.array(data.winner).reshape(1, 1)) + winner.append(np.array(data.winner * color).reshape(1, 1)) color *= -1 states = np.concatenate(states, axis=0) probs = np.concatenate(probs, axis=0)