refactor code to avoid memory leak
This commit is contained in:
parent
284cc64c18
commit
afc55ed9c2
2
AlphaGo/.gitignore
vendored
2
AlphaGo/.gitignore
vendored
@ -1,5 +1,5 @@
|
||||
data
|
||||
checkpoints
|
||||
checkpoint*
|
||||
random
|
||||
*.log
|
||||
*.txt
|
||||
|
@ -58,8 +58,9 @@ class Game:
|
||||
|
||||
def clear(self):
|
||||
if self.name == "go":
|
||||
del self.board[:]
|
||||
self.board = [utils.EMPTY] * (self.size ** 2)
|
||||
self.history = []
|
||||
del self.history[:]
|
||||
if self.name == "reversi":
|
||||
self.board = self.game_engine.get_board()
|
||||
for _ in range(self.history_length):
|
||||
|
@ -9,6 +9,7 @@ import tensorflow as tf
|
||||
import tensorflow.contrib.layers as layers
|
||||
|
||||
import multi_gpu
|
||||
from utils import Data
|
||||
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
|
||||
@ -71,14 +72,6 @@ def value_head(input, is_training):
|
||||
h = layers.fully_connected(h, 1, activation_fn=tf.nn.tanh, weights_regularizer=layers.l2_regularizer(1e-4))
|
||||
return h
|
||||
|
||||
|
||||
class Data(object):
|
||||
def __init__(self):
|
||||
self.boards = []
|
||||
self.probs = []
|
||||
self.winner = 0
|
||||
|
||||
|
||||
class ResNet(object):
|
||||
def __init__(self, board_size, action_num, history_length=1, residual_block_num=10, black_checkpoint_path=None,
|
||||
white_checkpoint_path=None):
|
||||
@ -238,7 +231,6 @@ class ResNet(object):
|
||||
os.mkdir(self.save_path + 'black')
|
||||
os.mkdir(self.save_path + 'white')
|
||||
|
||||
new_file_list = []
|
||||
all_file_list = []
|
||||
training_data = {'states': [], 'probs': [], 'winner': []}
|
||||
|
||||
@ -257,6 +249,7 @@ class ResNet(object):
|
||||
self.training_data['probs'].append(probs)
|
||||
self.training_data['winner'].append(winner)
|
||||
self.training_data['length'].append(states.shape[0])
|
||||
del new_file_list[:]
|
||||
new_file_list = list(set(os.listdir(data_path)).difference(all_file_list))
|
||||
|
||||
if len(self.training_data['states']) != self.window_length:
|
||||
@ -300,7 +293,8 @@ class ResNet(object):
|
||||
self.white_saver.save(self.sess, self.save_path + 'white/' + ckpt_file)
|
||||
|
||||
for key in training_data.keys():
|
||||
training_data[key] = []
|
||||
del training_data[key][:]
|
||||
#training_data[key] = []
|
||||
iters += 1
|
||||
|
||||
def _file_to_training_data(self, file_name):
|
||||
|
@ -5,6 +5,7 @@ import time
|
||||
import os
|
||||
from game import Game
|
||||
from engine import GTPEngine
|
||||
from utils import Data
|
||||
import utils
|
||||
from time import gmtime, strftime
|
||||
|
||||
@ -15,15 +16,6 @@ if python_version < (3, 0):
|
||||
else:
|
||||
import _pickle as cPickle
|
||||
|
||||
class Data(object):
|
||||
def __init__(self):
|
||||
self.boards = []
|
||||
self.probs = []
|
||||
self.winner = 0
|
||||
|
||||
def reset(self):
|
||||
self.__init__()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""
|
||||
|
@ -5,6 +5,15 @@
|
||||
# $Author: renyong15 © <mails.tsinghua.edu.cn>
|
||||
#
|
||||
|
||||
class Data(object):
|
||||
def __init__(self):
|
||||
self.boards = []
|
||||
self.probs = []
|
||||
self.winner = 0
|
||||
|
||||
def reset(self):
|
||||
self.__init__()
|
||||
|
||||
WHITE = -1
|
||||
EMPTY = 0
|
||||
BLACK = +1
|
||||
|
Loading…
x
Reference in New Issue
Block a user