refactor code to avoid memory leak

This commit is contained in:
Dong Yan 2018-01-11 17:02:36 +08:00
parent 284cc64c18
commit afc55ed9c2
5 changed files with 17 additions and 21 deletions

2
AlphaGo/.gitignore vendored
View File

@ -1,5 +1,5 @@
data
checkpoints
checkpoint*
random
*.log
*.txt

View File

@ -58,8 +58,9 @@ class Game:
def clear(self):
if self.name == "go":
del self.board[:]
self.board = [utils.EMPTY] * (self.size ** 2)
self.history = []
del self.history[:]
if self.name == "reversi":
self.board = self.game_engine.get_board()
for _ in range(self.history_length):

View File

@ -9,6 +9,7 @@ import tensorflow as tf
import tensorflow.contrib.layers as layers
import multi_gpu
from utils import Data
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
@ -71,14 +72,6 @@ def value_head(input, is_training):
h = layers.fully_connected(h, 1, activation_fn=tf.nn.tanh, weights_regularizer=layers.l2_regularizer(1e-4))
return h
class Data(object):
def __init__(self):
self.boards = []
self.probs = []
self.winner = 0
class ResNet(object):
def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, black_checkpoint_path=None,
white_checkpoint_path=None):
@ -238,7 +231,6 @@ class ResNet(object):
os.mkdir(self.save_path + 'black')
os.mkdir(self.save_path + 'white')
new_file_list = []
all_file_list = []
training_data = {'states': [], 'probs': [], 'winner': []}
@ -257,6 +249,7 @@ class ResNet(object):
self.training_data['probs'].append(probs)
self.training_data['winner'].append(winner)
self.training_data['length'].append(states.shape[0])
del new_file_list[:]
new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
if len(self.training_data['states']) != self.window_length:
@ -300,7 +293,8 @@ class ResNet(object):
self.white_saver.save(self.sess, self.save_path + 'white/' + ckpt_file)
for key in training_data.keys():
training_data[key] = []
del training_data[key][:]
#training_data[key] = []
iters += 1
def _file_to_training_data(self, file_name):

View File

@ -5,6 +5,7 @@ import time
import os
from game import Game
from engine import GTPEngine
from utils import Data
import utils
from time import gmtime, strftime
@ -15,15 +16,6 @@ if python_version < (3, 0):
else:
import _pickle as cPickle
class Data(object):
def __init__(self):
self.boards = []
self.probs = []
self.winner = 0
def reset(self):
self.__init__()
if __name__ == '__main__':
"""

View File

@ -5,6 +5,15 @@
# $Author: renyong15 © <mails.tsinghua.edu.cn>
#
class Data(object):
def __init__(self):
self.boards = []
self.probs = []
self.winner = 0
def reset(self):
self.__init__()
WHITE = -1
EMPTY = 0
BLACK = +1