refactor code to avoid memory leak
This commit is contained in:
parent
284cc64c18
commit
afc55ed9c2
2
AlphaGo/.gitignore
vendored
2
AlphaGo/.gitignore
vendored
@ -1,5 +1,5 @@
|
|||||||
data
|
data
|
||||||
checkpoints
|
checkpoint*
|
||||||
random
|
random
|
||||||
*.log
|
*.log
|
||||||
*.txt
|
*.txt
|
||||||
|
@ -58,8 +58,9 @@ class Game:
|
|||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
if self.name == "go":
|
if self.name == "go":
|
||||||
|
del self.board[:]
|
||||||
self.board = [utils.EMPTY] * (self.size ** 2)
|
self.board = [utils.EMPTY] * (self.size ** 2)
|
||||||
self.history = []
|
del self.history[:]
|
||||||
if self.name == "reversi":
|
if self.name == "reversi":
|
||||||
self.board = self.game_engine.get_board()
|
self.board = self.game_engine.get_board()
|
||||||
for _ in range(self.history_length):
|
for _ in range(self.history_length):
|
||||||
|
@ -9,6 +9,7 @@ import tensorflow as tf
|
|||||||
import tensorflow.contrib.layers as layers
|
import tensorflow.contrib.layers as layers
|
||||||
|
|
||||||
import multi_gpu
|
import multi_gpu
|
||||||
|
from utils import Data
|
||||||
|
|
||||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||||
|
|
||||||
@ -71,14 +72,6 @@ def value_head(input, is_training):
|
|||||||
h = layers.fully_connected(h, 1, activation_fn=tf.nn.tanh, weights_regularizer=layers.l2_regularizer(1e-4))
|
h = layers.fully_connected(h, 1, activation_fn=tf.nn.tanh, weights_regularizer=layers.l2_regularizer(1e-4))
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
|
||||||
class Data(object):
|
|
||||||
def __init__(self):
|
|
||||||
self.boards = []
|
|
||||||
self.probs = []
|
|
||||||
self.winner = 0
|
|
||||||
|
|
||||||
|
|
||||||
class ResNet(object):
|
class ResNet(object):
|
||||||
def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, black_checkpoint_path=None,
|
def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, black_checkpoint_path=None,
|
||||||
white_checkpoint_path=None):
|
white_checkpoint_path=None):
|
||||||
@ -238,7 +231,6 @@ class ResNet(object):
|
|||||||
os.mkdir(self.save_path + 'black')
|
os.mkdir(self.save_path + 'black')
|
||||||
os.mkdir(self.save_path + 'white')
|
os.mkdir(self.save_path + 'white')
|
||||||
|
|
||||||
new_file_list = []
|
|
||||||
all_file_list = []
|
all_file_list = []
|
||||||
training_data = {'states': [], 'probs': [], 'winner': []}
|
training_data = {'states': [], 'probs': [], 'winner': []}
|
||||||
|
|
||||||
@ -257,6 +249,7 @@ class ResNet(object):
|
|||||||
self.training_data['probs'].append(probs)
|
self.training_data['probs'].append(probs)
|
||||||
self.training_data['winner'].append(winner)
|
self.training_data['winner'].append(winner)
|
||||||
self.training_data['length'].append(states.shape[0])
|
self.training_data['length'].append(states.shape[0])
|
||||||
|
del new_file_list[:]
|
||||||
new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
|
new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
|
||||||
|
|
||||||
if len(self.training_data['states']) != self.window_length:
|
if len(self.training_data['states']) != self.window_length:
|
||||||
@ -300,7 +293,8 @@ class ResNet(object):
|
|||||||
self.white_saver.save(self.sess, self.save_path + 'white/' + ckpt_file)
|
self.white_saver.save(self.sess, self.save_path + 'white/' + ckpt_file)
|
||||||
|
|
||||||
for key in training_data.keys():
|
for key in training_data.keys():
|
||||||
training_data[key] = []
|
del training_data[key][:]
|
||||||
|
#training_data[key] = []
|
||||||
iters += 1
|
iters += 1
|
||||||
|
|
||||||
def _file_to_training_data(self, file_name):
|
def _file_to_training_data(self, file_name):
|
||||||
|
@ -5,6 +5,7 @@ import time
|
|||||||
import os
|
import os
|
||||||
from game import Game
|
from game import Game
|
||||||
from engine import GTPEngine
|
from engine import GTPEngine
|
||||||
|
from utils import Data
|
||||||
import utils
|
import utils
|
||||||
from time import gmtime, strftime
|
from time import gmtime, strftime
|
||||||
|
|
||||||
@ -15,15 +16,6 @@ if python_version < (3, 0):
|
|||||||
else:
|
else:
|
||||||
import _pickle as cPickle
|
import _pickle as cPickle
|
||||||
|
|
||||||
class Data(object):
|
|
||||||
def __init__(self):
|
|
||||||
self.boards = []
|
|
||||||
self.probs = []
|
|
||||||
self.winner = 0
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.__init__()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
"""
|
"""
|
||||||
|
@ -5,6 +5,15 @@
|
|||||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||||
#
|
#
|
||||||
|
|
||||||
|
class Data(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.boards = []
|
||||||
|
self.probs = []
|
||||||
|
self.winner = 0
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.__init__()
|
||||||
|
|
||||||
WHITE = -1
|
WHITE = -1
|
||||||
EMPTY = 0
|
EMPTY = 0
|
||||||
BLACK = +1
|
BLACK = +1
|
||||||
|
Loading…
x
Reference in New Issue
Block a user