refactor code to avoid memory leak

This commit is contained in:
Dong Yan 2018-01-11 17:02:36 +08:00
parent 284cc64c18
commit afc55ed9c2
5 changed files with 17 additions and 21 deletions

2
AlphaGo/.gitignore vendored
View File

@ -1,5 +1,5 @@
data data
checkpoints checkpoint*
random random
*.log *.log
*.txt *.txt

View File

@ -58,8 +58,9 @@ class Game:
def clear(self): def clear(self):
if self.name == "go": if self.name == "go":
del self.board[:]
self.board = [utils.EMPTY] * (self.size ** 2) self.board = [utils.EMPTY] * (self.size ** 2)
self.history = [] del self.history[:]
if self.name == "reversi": if self.name == "reversi":
self.board = self.game_engine.get_board() self.board = self.game_engine.get_board()
for _ in range(self.history_length): for _ in range(self.history_length):

View File

@ -9,6 +9,7 @@ import tensorflow as tf
import tensorflow.contrib.layers as layers import tensorflow.contrib.layers as layers
import multi_gpu import multi_gpu
from utils import Data
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
@ -71,14 +72,6 @@ def value_head(input, is_training):
h = layers.fully_connected(h, 1, activation_fn=tf.nn.tanh, weights_regularizer=layers.l2_regularizer(1e-4)) h = layers.fully_connected(h, 1, activation_fn=tf.nn.tanh, weights_regularizer=layers.l2_regularizer(1e-4))
return h return h
class Data(object):
def __init__(self):
self.boards = []
self.probs = []
self.winner = 0
class ResNet(object): class ResNet(object):
def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, black_checkpoint_path=None, def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, black_checkpoint_path=None,
white_checkpoint_path=None): white_checkpoint_path=None):
@ -238,7 +231,6 @@ class ResNet(object):
os.mkdir(self.save_path + 'black') os.mkdir(self.save_path + 'black')
os.mkdir(self.save_path + 'white') os.mkdir(self.save_path + 'white')
new_file_list = []
all_file_list = [] all_file_list = []
training_data = {'states': [], 'probs': [], 'winner': []} training_data = {'states': [], 'probs': [], 'winner': []}
@ -257,6 +249,7 @@ class ResNet(object):
self.training_data['probs'].append(probs) self.training_data['probs'].append(probs)
self.training_data['winner'].append(winner) self.training_data['winner'].append(winner)
self.training_data['length'].append(states.shape[0]) self.training_data['length'].append(states.shape[0])
del new_file_list[:]
new_file_list = list(set(os.listdir(data_path)).difference(all_file_list)) new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
if len(self.training_data['states']) != self.window_length: if len(self.training_data['states']) != self.window_length:
@ -300,7 +293,8 @@ class ResNet(object):
self.white_saver.save(self.sess, self.save_path + 'white/' + ckpt_file) self.white_saver.save(self.sess, self.save_path + 'white/' + ckpt_file)
for key in training_data.keys(): for key in training_data.keys():
training_data[key] = [] del training_data[key][:]
#training_data[key] = []
iters += 1 iters += 1
def _file_to_training_data(self, file_name): def _file_to_training_data(self, file_name):

View File

@ -5,6 +5,7 @@ import time
import os import os
from game import Game from game import Game
from engine import GTPEngine from engine import GTPEngine
from utils import Data
import utils import utils
from time import gmtime, strftime from time import gmtime, strftime
@ -15,15 +16,6 @@ if python_version < (3, 0):
else: else:
import _pickle as cPickle import _pickle as cPickle
class Data(object):
def __init__(self):
self.boards = []
self.probs = []
self.winner = 0
def reset(self):
self.__init__()
if __name__ == '__main__': if __name__ == '__main__':
""" """

View File

@ -5,6 +5,15 @@
# $Author: renyong15 © <mails.tsinghua.edu.cn> # $Author: renyong15 © <mails.tsinghua.edu.cn>
# #
class Data(object):
def __init__(self):
self.boards = []
self.probs = []
self.winner = 0
def reset(self):
self.__init__()
WHITE = -1 WHITE = -1
EMPTY = 0 EMPTY = 0
BLACK = +1 BLACK = +1