diff --git a/AlphaGo/utils.py b/AlphaGo/utils.py index d005c42..8075381 100644 --- a/AlphaGo/utils.py +++ b/AlphaGo/utils.py @@ -26,26 +26,32 @@ import go KGS_COLUMNS = 'ABCDEFGHJKLMNOPQRST' SGF_COLUMNS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + def parse_sgf_to_flat(sgf): return flatten_coords(parse_sgf_coords(sgf)) + def flatten_coords(c): return go.N * c[0] + c[1] + def unflatten_coords(f): return divmod(f, go.N) + def parse_sgf_coords(s): 'Interprets coords. aa is top left corner; sa is top right corner' if s is None or s == '': return None return SGF_COLUMNS.index(s[1]), SGF_COLUMNS.index(s[0]) + def unparse_sgf_coords(c): if c is None: return '' return SGF_COLUMNS[c[1]] + SGF_COLUMNS[c[0]] + def parse_kgs_coords(s): 'Interprets coords. A1 is bottom left; A9 is top left.' if s == 'pass': @@ -55,17 +61,20 @@ def parse_kgs_coords(s): row_from_bottom = int(s[1:]) - 1 return go.N - row_from_bottom - 1, col + def parse_pygtp_coords(vertex): 'Interprets coords. (1, 1) is bottom left; (1, 9) is top left.' if vertex in (gtp.PASS, gtp.RESIGN): return None return go.N - vertex[1], vertex[0] - 1 + def unparse_pygtp_coords(c): if c is None: return gtp.PASS return c[1] + 1, go.N - c[0] + def parse_game_result(result): if re.match(r'[bB]\+', result): return go.BLACK @@ -74,12 +83,15 @@ def parse_game_result(result): else: return None + def product(numbers): return functools.reduce(operator.mul, numbers) + def take_n(n, iterable): return list(itertools.islice(iterable, n)) + def iter_chunks(chunk_size, iterator): while True: next_chunk = take_n(chunk_size, iterator) @@ -89,7 +101,8 @@ def iter_chunks(chunk_size, iterator): else: break -def shuffler(iterator, pool_size=10**5, refill_threshold=0.9): + +def shuffler(iterator, pool_size=10 ** 5, refill_threshold=0.9): yields_between_refills = round(pool_size * (1 - refill_threshold)) # initialize pool; this step may or may not exhaust the iterator. pool = take_n(pool_size, iterator) @@ -102,18 +115,23 @@ def shuffler(iterator, pool_size=10**5, refill_threshold=0.9): break pool.extend(next_batch) # finish consuming whatever's left - no need for further randomization. - yield from pool + yield pool + class timer(object): all_times = defaultdict(float) + def __init__(self, label): self.label = label + def __enter__(self): self.tick = time.time() + def __exit__(self, type, value, traceback): self.tock = time.time() self.all_times[self.label] += self.tock - self.tick + @classmethod def print_times(cls): for k, v in cls.all_times.items(): - print("%s: %.3f" % (k, v)) \ No newline at end of file + print("%s: %.3f" % (k, v))