From e76ccaee80249fc54b7e9eb822c459feb9a823f9 Mon Sep 17 00:00:00 2001
From: Wenbo Hu
Date: Tue, 16 Jan 2018 14:10:56 +0800
Subject: [PATCH 1/7] add union-find set for do_move and is_valid

---
 AlphaGo/game.py |  10 ++-
 AlphaGo/go.py   | 203 +++++++++++++++++++++++++++++++++---------------
 AlphaGo/play.py |   6 +-
 3 files changed, 152 insertions(+), 67 deletions(-)

diff --git a/AlphaGo/game.py b/AlphaGo/game.py
index 0d3ca59..a299e97 100644
--- a/AlphaGo/game.py
+++ b/AlphaGo/game.py
@@ -40,6 +40,9 @@ class Game:
             self.history_hashtable = set()
             self.game_engine = go.Go(size=self.size, komi=self.komi)
             self.board = [utils.EMPTY] * (self.size ** 2)
+            self.group_ancestors = {}  # key: idx, value: ancestor idx
+            self.liberty = {}  # key: ancestor idx, value: set of liberties
+            self.stones = {}  # key: ancestor idx, value: set of stones
         elif self.name == "reversi":
             self.size = 8
             self.history_length = 1
@@ -62,6 +65,9 @@ class Game:
         self.board = [utils.EMPTY] * (self.size ** 2)
         del self.history[:]
         self.history_hashtable.clear()
+        self.group_ancestors.clear()
+        self.liberty.clear()
+        self.stones.clear()
         if self.name == "reversi":
             self.board = self.game_engine.get_board()
         for _ in range(self.history_length):
@@ -84,7 +90,7 @@ class Game:
         evaluator = lambda state:self.model(role, state)
         mcts = MCTS(self.game_engine, evaluator, [latest_boards, color], self.size ** 2 + 1,
                     role=role, debug=self.debug, inverse=True)
-        mcts.search(max_step=100)
+        mcts.search(max_step=5)
         if self.debug:
             file = open("mcts_debug.log", 'ab')
             np.savetxt(file, mcts.root.Q, header="\n" + role + " Q value : ", fmt='%.4f', newline=", ")
@@ -109,7 +115,7 @@ class Game:
             res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
         if self.name == "go":
             res = self.game_engine.executor_do_move(self.history, self.history_hashtable, self.latest_boards, self.board,
-                                                    color, vertex)
+                                                    self.group_ancestors, self.liberty, self.stones, color, vertex)
         return res
 
     def think_play_move(self, color):
diff --git a/AlphaGo/go.py b/AlphaGo/go.py
index 38efd17..84e6b7d 100644
--- a/AlphaGo/go.py
+++ b/AlphaGo/go.py
@@ -44,6 +44,23 @@ class Go:
                 nei.append((_x, _y))
         return nei
 
+    def _neighbor_color(self, current_board, vertex, color):
+        # classify the neighbors of vertex into three lists by color
+        color_neighbor = []  # 1) neighbors of the same color
+        reverse_color_neighbor = []  # 2) neighbors of the reverse color
+        empty_neighbor = []  # 3) empty neighbors
+        reverse_color = utils.BLACK if color == utils.WHITE else utils.WHITE
+        for n in self._neighbor(vertex):
+            if current_board[self._flatten(n)] == color:
+                color_neighbor.append(self._flatten(n))
+            elif current_board[self._flatten(n)] == utils.EMPTY:
+                empty_neighbor.append(self._flatten(n))
+            elif current_board[self._flatten(n)] == reverse_color:
+                reverse_color_neighbor.append(self._flatten(n))
+            else:
+                raise ValueError("the board has positions other than BLACK, WHITE and EMPTY")
+        return color_neighbor, reverse_color_neighbor, empty_neighbor
+
     def _corner(self, vertex):
         x, y = vertex
         corner = []
@@ -56,13 +73,11 @@ class Go:
 
     def _find_group(self, current_board, vertex):
         color = current_board[self._flatten(vertex)]
-        # print ("color : ", color)
         chain = set()
         frontier = [vertex]
         has_liberty = False
         while frontier:
             current = frontier.pop()
-            # print ("current : ", current)
             chain.add(current)
             for n in self._neighbor(current):
                 if current_board[self._flatten(n)] == color and not n in chain:
                     frontier.append(n)
                 if current_board[self._flatten(n)] == utils.EMPTY:
                     has_liberty = True
         return has_liberty, chain
 
-    def _is_suicide(self, current_board, color, vertex):
-        current_board[self._flatten(vertex)] = color  # assume that we already take this move
-        suicide = False
+    def _find_ancestor(self, group_ancestors, idx):
+        # union-find lookup with path compression
+        r = idx
+        while group_ancestors[r] != r:
+            r = group_ancestors[r]
+        group_ancestors[idx] = r
+        return r
 
-        has_liberty, group = self._find_group(current_board, vertex)
-        if not has_liberty:
-            suicide = True  # no liberty, suicide
-            for n in self._neighbor(vertex):
-                if current_board[self._flatten(n)] == utils.another_color(color):
-                    opponent_liberty, group = self._find_group(current_board, n)
-                    if not opponent_liberty:
-                        suicide = False  # this move is able to take opponent's stone, not suicide
-
-        current_board[self._flatten(vertex)] = utils.EMPTY  # undo this move
-        return suicide
+    def _is_suicide(self, current_board, group_ancestors, liberty, color, vertex):
+        color_neighbor, reverse_color_neighbor, empty_neighbor = self._neighbor_color(current_board, vertex, color)
+        if empty_neighbor:
+            return False  # the new stone itself keeps a liberty
+        for idx in color_neighbor:
+            if len(liberty[self._find_ancestor(group_ancestors, idx)]) > 1:
+                return False  # joins a friendly group that keeps at least one liberty
+        for idx in reverse_color_neighbor:
+            if len(liberty[self._find_ancestor(group_ancestors, idx)]) == 1:
+                return False  # captures an opponent group in atari, so not suicide
+        return True
 
     def _process_board(self, current_board, color, vertex):
         nei = self._neighbor(vertex)
@@ -107,25 +127,15 @@ class Go:
         return repeat
 
     def _is_eye(self, current_board, color, vertex):
-        nei = self._neighbor(vertex)
-        cor = self._corner(vertex)
-        ncolor = {color == current_board[self._flatten(n)] for n in nei}
-        if False in ncolor:
-            # print "not all neighbors are in same color with us"
+        # return whether this position is a real eye of color
+        color_neighbor, reverse_color_neighbor, empty_neighbor = self._neighbor_color(current_board, vertex, color)
+        if reverse_color_neighbor or empty_neighbor:  # not an eye
             return False
-        _, group = self._find_group(current_board, nei[0])
-        if set(nei) < group:
-            # print "all neighbors are in same group and same color with us"
-            return True
-        else:
-            opponent_number = [current_board[self._flatten(c)] for c in cor].count(-color)
-            opponent_propotion = float(opponent_number) / float(len(cor))
-            if opponent_propotion < 0.5:
-                # print "few opponents, real eye"
-                return True
-            else:
-                # print "many opponents, fake eye"
-                return False
+        cor = self._corner(vertex)
+        opponent_number = [current_board[self._flatten(c)] for c in cor].count(-color)
+        opponent_proportion = float(opponent_number) / float(len(cor))
+        # opponents on fewer than half of the corners means a real eye
+        return opponent_proportion < 0.5
 
     def _knowledge_prunning(self, current_board, color, vertex):
         # forbid some stupid selfplay using human knowledge
 
         # forbid position on its own eye.
return True - def _is_game_finished(self, current_board, color): - ''' - for each empty position, if it has both BLACK and WHITE neighbors, the game is still not finished - :return: return the game is finished - ''' - board = copy.deepcopy(current_board) - empty_idx = [i for i, x in enumerate(board) if x == utils.EMPTY] # find all empty idx - for idx in empty_idx: - neighbor_idx = self._neighbor(self.deflatten(idx)) - if len(neighbor_idx) > 1: - first_idx = neighbor_idx[0] - for other_idx in neighbor_idx[1:]: - if board[self.flatten(other_idx)] != board[self.flatten(first_idx)]: - return False - - return True - def _action2vertex(self, action): if action == self.size ** 2: vertex = (0, 0) @@ -158,7 +151,7 @@ class Go: vertex = self._deflatten(action) return vertex - def _rule_check(self, history_hashtable, current_board, color, vertex, is_thinking=True): + def _rule_check(self, history_hashtable, current_board, group_ancestors, liberty, color, vertex, is_thinking=True): ### in board if not self._in_board(vertex): if not is_thinking: @@ -174,7 +167,7 @@ class Go: return False ### check if it is suicide - if self._is_suicide(current_board, color, vertex): + if self._is_suicide(current_board, group_ancestors, liberty, color, vertex): if not is_thinking: raise ValueError("Target point causes suicide, Current Board: {}, color: {}, vertex : {}".format(current_board, color, vertex)) else: @@ -189,34 +182,59 @@ class Go: return True - def _is_valid(self, state, action, history_hashtable): + def _is_valid(self, state, action, history_hashtable, group_ancestors, liberty): history_boards, color = state vertex = self._action2vertex(action) current_board = history_boards[-1] - if not self._rule_check(history_hashtable, current_board, color, vertex): + if not self._rule_check(history_hashtable, current_board, group_ancestors, liberty, color, vertex): return False if not self._knowledge_prunning(current_board, color, vertex): return False return True + def _get_groups(self, board): + group_ancestors = {} # key: idx, value: ancestor idx + liberty = {} # key: ancestor idx, value: set of liberty + for idx, color in enumerate(board): + if color and idx not in group_ancestors: + # build group + group_ancestors[idx] = idx + color_neighbor, _, empty_neighbor = \ + self._neighbor_color(board, self._deflatten(idx), color) + liberty[idx] = set(empty_neighbor) + group_list = copy.deepcopy(color_neighbor) + while group_list: + add_idx = group_list.pop() + if add_idx not in group_ancestors: + group_ancestors[add_idx] = idx + color_neighbor_add, _, empty_neighbor_add = \ + self._neighbor_color(board, self._deflatten(add_idx), color) + group_list += color_neighbor_add + liberty[idx] |= set(empty_neighbor_add) + return group_ancestors, liberty + def simulate_get_mask(self, state, action_set): # find all the invalid actions invalid_action_mask = [] history_boards, color = state + group_ancestors, liberty = self._get_groups(history_boards[-1]) history_hashtable = set() for board in history_boards: history_hashtable.add(tuple(board)) for action_candidate in action_set[:-1]: # go through all the actions excluding pass - if not self._is_valid(state, action_candidate, history_hashtable): + if not self._is_valid(state, action_candidate, history_hashtable, group_ancestors, liberty): invalid_action_mask.append(action_candidate) if len(invalid_action_mask) < len(action_set) - 1: invalid_action_mask.append(action_set[-1]) # forbid pass, if we have other choices # TODO: In fact we should not do this. 
In some extreme cases, we should permit pass. del history_hashtable + del group_ancestors + del liberty + # del stones return invalid_action_mask def _do_move(self, board, color, vertex): @@ -243,15 +261,73 @@ class Go: # since go is MDP, we only need the last board for hashing return tuple(state[0][-1]) - def executor_do_move(self, history, history_hashtable, latest_boards, current_board, color, vertex): - if not self._rule_check(history_hashtable, current_board, color, vertex, is_thinking=False): - # raise ValueError("!!! We have more than four ko at the same time !!!") + def _join_group(self, idx, idx_list, empty_neighbor, group_ancestors, liberty, stones): + # idx joins its neighbors id_list + # empty_neighbor: empty neighbors of idx + color_ancestor = set() + for color_idx in idx_list: + color_ancestor.add(self._find_ancestor(group_ancestors, color_idx)) + joined_ancestor = color_ancestor.pop() + liberty[joined_ancestor] |= set(empty_neighbor) + stones[joined_ancestor].add(idx) + group_ancestors[idx] = joined_ancestor + # add other groups + for color_idx in color_ancestor: + liberty[joined_ancestor] |= liberty[color_idx] + stones[joined_ancestor] |= stones[color_idx] + del liberty[color_idx] + for stone in stones[color_idx]: + group_ancestors[stone] = joined_ancestor + del stones[color_idx] + liberty[joined_ancestor].remove(idx) + + def _add_captured_liberty(self, board, liberty, group_ancestors, stones): + for captured_stone in stones: + color_neighbor, reverse_color_neighbor, empty_neighbor = \ + self._neighbor_color(board, self._deflatten(captured_stone), board[captured_stone]) + assert not empty_neighbor # make sure no empty spaces + for reverse_color_idx in reverse_color_neighbor: + reverse_color_idx_ancestor = self._find_ancestor(group_ancestors, reverse_color_idx) + liberty[reverse_color_idx_ancestor].add(captured_stone) + + def _remove_liberty(self, idx, reverse_color_neighbor, current_board, group_ancestors, liberty, stones): + # reverse_color_neighbor: stones near idx in the reverse color + reverse_color_ancestor = set() + for reverse_idx in reverse_color_neighbor: + reverse_color_ancestor.add(self._find_ancestor(group_ancestors, reverse_idx)) + for reverse_color_ancestor_idx in reverse_color_ancestor: + if len(liberty[reverse_color_ancestor_idx]) == 1: + # capture this group if no liberty left + self._add_captured_liberty(current_board, liberty, group_ancestors, stones[reverse_color_ancestor_idx]) + for captured_stone in stones[reverse_color_ancestor_idx]: + current_board[captured_stone] = utils.EMPTY + del group_ancestors[captured_stone] + del liberty[reverse_color_ancestor_idx] + del stones[reverse_color_ancestor_idx] + else: + # remove this liberty + liberty[reverse_color_ancestor_idx].remove(idx) + + def executor_do_move(self, history, history_hashtable, latest_boards, current_board, group_ancestors, liberty, stones, color, vertex): + print("===") + print(color, vertex) + print(group_ancestors, liberty, stones) + if not self._rule_check(history_hashtable, current_board, group_ancestors, liberty, color, vertex): return False - current_board[self._flatten(vertex)] = color - self._process_board(current_board, color, vertex) + idx = self._flatten(vertex) + current_board[idx] = color + color_neighbor, reverse_color_neighbor, empty_neighbor = self._neighbor_color(current_board, vertex, color) + if color_neighbor: # join nearby groups + self._join_group(idx, color_neighbor, empty_neighbor, group_ancestors, liberty, stones) + else: # build a new group + group_ancestors[idx] = 
idx + liberty[idx] = set(empty_neighbor) + stones[idx] = {idx} + if reverse_color_neighbor: # remove liberty for nearby reverse color + self._remove_liberty(idx, reverse_color_neighbor, current_board, group_ancestors, liberty, stones) history.append(copy.deepcopy(current_board)) latest_boards.append(copy.deepcopy(current_board)) - history_hashtable.add(copy.deepcopy(tuple(current_board))) + history_hashtable.add(tuple(current_board)) return True def _find_empty(self, current_board): @@ -362,6 +438,8 @@ if __name__ == "__main__": 1, 0, 1, 1, 1, 1, 1, -1, 0, 1, 1, 0, 1, -1, -1, -1, -1, -1 ] + + ''' time0 = time.time() score = go.executor_get_score(endgame) time1 = time.time() @@ -370,6 +448,7 @@ if __name__ == "__main__": time2 = time.time() print(score, time2 - time1) ''' + ''' ### do unit test for Go class pure_test = [ 0, 1, 0, 1, 0, 1, 0, 0, 0, diff --git a/AlphaGo/play.py b/AlphaGo/play.py index 1066624..be54ad2 100644 --- a/AlphaGo/play.py +++ b/AlphaGo/play.py @@ -28,10 +28,10 @@ def play(engine, data_path): size = {"go": 9, "reversi": 8} show = ['.', 'X', 'O'] - # evaluate_rounds = 100 + evaluate_rounds = 1 game_num = 0 - while True: - # while game_num < evaluate_rounds: + #while True: + while game_num < evaluate_rounds: engine._game.model.check_latest_model() num = 0 pass_flag = [False, False] From 0e4aa44ebb50e88bdb25b51e030b5e7ed230bf8a Mon Sep 17 00:00:00 2001 From: Wenbo Date: Wed, 17 Jan 2018 15:54:46 +0800 Subject: [PATCH 2/7] add deepcopy for hash, add some testing --- AlphaGo/game.py | 2 +- AlphaGo/go.py | 8 ++++---- AlphaGo/play.py | 13 ++++++++++++- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/AlphaGo/game.py b/AlphaGo/game.py index a299e97..9b4ba1e 100644 --- a/AlphaGo/game.py +++ b/AlphaGo/game.py @@ -90,7 +90,7 @@ class Game: evaluator = lambda state:self.model(role, state) mcts = MCTS(self.game_engine, evaluator, [latest_boards, color], self.size ** 2 + 1, role=role, debug=self.debug, inverse=True) - mcts.search(max_step=5) + mcts.search(max_step=100) if self.debug: file = open("mcts_debug.log", 'ab') np.savetxt(file, mcts.root.Q, header="\n" + role + " Q value : ", fmt='%.4f', newline=", ") diff --git a/AlphaGo/go.py b/AlphaGo/go.py index 84e6b7d..5d4d21e 100644 --- a/AlphaGo/go.py +++ b/AlphaGo/go.py @@ -309,9 +309,9 @@ class Go: liberty[reverse_color_ancestor_idx].remove(idx) def executor_do_move(self, history, history_hashtable, latest_boards, current_board, group_ancestors, liberty, stones, color, vertex): - print("===") - print(color, vertex) - print(group_ancestors, liberty, stones) + #print("===") + #print(color, vertex) + #print(group_ancestors, liberty, stones) if not self._rule_check(history_hashtable, current_board, group_ancestors, liberty, color, vertex): return False idx = self._flatten(vertex) @@ -327,7 +327,7 @@ class Go: self._remove_liberty(idx, reverse_color_neighbor, current_board, group_ancestors, liberty, stones) history.append(copy.deepcopy(current_board)) latest_boards.append(copy.deepcopy(current_board)) - history_hashtable.add(tuple(current_board)) + history_hashtable.add(copy.deepcopy(tuple(current_board))) return True def _find_empty(self, current_board): diff --git a/AlphaGo/play.py b/AlphaGo/play.py index be54ad2..b877b87 100644 --- a/AlphaGo/play.py +++ b/AlphaGo/play.py @@ -1,3 +1,4 @@ +from __future__ import division import argparse import sys import re @@ -28,10 +29,13 @@ def play(engine, data_path): size = {"go": 9, "reversi": 8} show = ['.', 'X', 'O'] - evaluate_rounds = 1 + evaluate_rounds = 5 game_num = 0 + 
total_time = 0 + f=open('time.txt','w') #while True: while game_num < evaluate_rounds: + start = time.time() engine._game.model.check_latest_model() num = 0 pass_flag = [False, False] @@ -77,6 +81,13 @@ def play(engine, data_path): cPickle.dump(data, file) data.reset() game_num += 1 + + this_time = time.time() - start + total += this_time + f.write('time:'+ str(this_time)+'\n') + f.write('Avg time:' + str(total/evaluate_rounds)) + f.close() + if __name__ == '__main__': From 0131bcdc85070a37abd3d4a4471ae69da4a4a78d Mon Sep 17 00:00:00 2001 From: Wenbo Date: Wed, 17 Jan 2018 15:57:41 +0800 Subject: [PATCH 3/7] fix minor --- AlphaGo/play.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/AlphaGo/play.py b/AlphaGo/play.py index b877b87..e1947a4 100644 --- a/AlphaGo/play.py +++ b/AlphaGo/play.py @@ -31,7 +31,7 @@ def play(engine, data_path): evaluate_rounds = 5 game_num = 0 - total_time = 0 + total = 0 f=open('time.txt','w') #while True: while game_num < evaluate_rounds: From 764f7fb5f116c0915767214c06fd2f4887b1ce77 Mon Sep 17 00:00:00 2001 From: Dong Yan Date: Fri, 23 Feb 2018 23:15:04 +0800 Subject: [PATCH 4/7] minor fix of play.py --- AlphaGo/play.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/AlphaGo/play.py b/AlphaGo/play.py index 22d4422..f6740ca 100644 --- a/AlphaGo/play.py +++ b/AlphaGo/play.py @@ -29,12 +29,12 @@ def play(engine, data_path): size = {"go": 9, "reversi": 8} show = ['.', 'X', 'O'] - evaluate_rounds = 5 + evaluate_rounds = 0 game_num = 0 total = 0 f=open('time.txt','w') - #while True: - while game_num < evaluate_rounds: + while True: + #while game_num < evaluate_rounds: start = time.time() engine._game.model.check_latest_model() num = 0 @@ -85,6 +85,7 @@ def play(engine, data_path): this_time = time.time() - start total += this_time f.write('time:'+ str(this_time)+'\n') + evaluate_rounds += 1 f.write('Avg time:' + str(total/evaluate_rounds)) f.close() From f3aee448e0f37369cc530dcbb5ad8924d7649e95 Mon Sep 17 00:00:00 2001 From: Dong Yan Date: Sat, 24 Feb 2018 10:53:39 +0800 Subject: [PATCH 5/7] add option to show the running result of cartpole --- examples/ppo_cartpole.py | 6 +++++- tianshou/data/batch.py | 5 ++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/ppo_cartpole.py b/examples/ppo_cartpole.py index e88a379..074331d 100755 --- a/examples/ppo_cartpole.py +++ b/examples/ppo_cartpole.py @@ -5,6 +5,7 @@ import tensorflow as tf import gym import numpy as np import time +import argparse # our lib imports here! It's ok to append path in examples import sys @@ -16,6 +17,9 @@ import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--render", action="store_true", default=False) + args = parser.parse_args() env = gym.make('CartPole-v0') observation_dim = env.observation_space.shape action_dim = env.action_space.n @@ -55,7 +59,7 @@ if __name__ == '__main__': train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables) ### 3. define data collection - training_data = Batch(env, pi, [advantage_estimation.full_return], [pi]) + training_data = Batch(env, pi, [advantage_estimation.full_return], [pi], render = args.render) ### 4. start training config = tf.ConfigProto() diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 01ca78d..9c7405d 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -9,7 +9,7 @@ class Batch(object): class for batch datasets. 
Collect multiple observations (actions, rewards, etc.) on-policy.
     """
-    def __init__(self, env, pi, reward_processors, networks):  # how to name the function?
+    def __init__(self, env, pi, reward_processors, networks, render=False):  # how to name the function?
         """
         constructor
         :param env:
@@ -24,6 +24,7 @@ class Batch(object):
 
         self.reward_processors = reward_processors
         self.networks = networks
+        self.render = render
 
         self.required_placeholders = {}
         for net in self.networks:
@@ -108,6 +109,8 @@ class Batch(object):
 
                 ac = self._pi.act(ob, my_feed_dict)
                 actions.append(ac)
+                if self.render:
+                    self._env.render()
                 ob, reward, done, _ = self._env.step(ac)
                 rewards.append(reward)
 

From a40e5aec54e0483a3395eb1d221c8044bd1d1026 Mon Sep 17 00:00:00 2001
From: rtz19970824
Date: Sat, 24 Feb 2018 16:26:19 +0800
Subject: [PATCH 6/7] modified README

---
 README.md | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index bc2414a..7f72b02 100644
--- a/README.md
+++ b/README.md
@@ -83,12 +83,6 @@ Try to use full names. Don't use abbrevations for class/function/variable names
 
 The """xxx""" comment should be written right after class/function. Also comment the part that's not intuitive during the code. We must comment, but for now we don't need to polish them.
 
-# High Priority TODO
-
-For Haosheng and Tongzheng: separate actor and critic, rewrite the interfaces for policy
-
-Others can still focus on the task below.
-
 ## TODO
 
 Search based method parallel.
@@ -106,6 +100,18 @@ Note: install openai/gym first to run the Atari environment; note that interface
 
 Without preprocessing and other tricks, this example will not train to any meaningful results. Codes should past two tests: individual module test and run through this example code.
 
+## Some bugs to fix
+
+For DQN and other deterministic policies: $\epsilon$-greedy or other exploration during collection?
+
+In Batch.py, notice that we cannot stop by setting num_timesteps
+
+Magic numbers
+
+## One idea
+
+Like zhusuan, we can register losses in the background so that we need not declare them in the example.
+
 ## Dependency
 
 Tensorflow (Version >= 1.4)
 Gym

From 0bc1b63e389f3a3213bf8469a0a2ff96289634cf Mon Sep 17 00:00:00 2001
From: Dong Yan
Date: Sun, 25 Feb 2018 16:31:35 +0800
Subject: [PATCH 7/7] add epsilon-greedy for dqn

---
 examples/dqn_example.py     |  6 ++++--
 tianshou/core/policy/dqn.py |  5 ++---
 tianshou/data/batch.py      | 14 +++++++++-----
 3 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/examples/dqn_example.py b/examples/dqn_example.py
index 70c9e4b..ee18863 100644
--- a/examples/dqn_example.py
+++ b/examples/dqn_example.py
@@ -66,9 +66,11 @@ if __name__ == '__main__':
         pi.sync_weights()  # TODO: automate this for policies with target network
 
     start_time = time.time()
-    for i in range(100):
+    # TODO: repeat_num should be defined in some configuration file
+    repeat_num = 100
+    for i in range(repeat_num):
         # collect data
-        data_collector.collect(num_episodes=50)
+        data_collector.collect(num_episodes=50, epsilon_greedy=(repeat_num - i + 0.0) / repeat_num)
 
         # print current return
         print('Epoch {}:'.format(i))
diff --git a/tianshou/core/policy/dqn.py b/tianshou/core/policy/dqn.py
index bc5db67..5cef57a 100644
--- a/tianshou/core/policy/dqn.py
+++ b/tianshou/core/policy/dqn.py
@@ -18,7 +18,7 @@ class DQN(PolicyBase):
         else:
             self.interaction_count = -1
 
-    def act(self, observation, exploration=None):
+    def act(self, observation, my_feed_dict):
         sess = tf.get_default_session()
         if self.weight_update > 1:
             if self.interaction_count % self.weight_update == 0:
@@ -30,8 +30,7 @@ class DQN(PolicyBase):
         if self.weight_update > 0:
             self.interaction_count += 1
 
-        if not exploration:
-            return np.squeeze(action)
+        return np.squeeze(action)
 
     @property
     def q_net(self):
diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py
index 9c7405d..d559ded 100644
--- a/tianshou/data/batch.py
+++ b/tianshou/data/batch.py
@@ -34,7 +34,7 @@ class Batch(object):
         self._is_first_collect = True
 
     def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={},
-                process_reward=True):  # specify how many data to collect here, or fix it in __init__()
+                process_reward=True, epsilon_greedy=0):  # specify how many data to collect here, or fix it in __init__()
         assert sum(
             [num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!"
 
@@ -106,7 +106,11 @@ class Batch(object):
             episode_start_flags.append(True)
 
             while True:
-                ac = self._pi.act(ob, my_feed_dict)
+                # a simple implementation of epsilon-greedy
+                if epsilon_greedy > 0 and np.random.random() < epsilon_greedy:
+                    ac = np.random.randint(low=0, high=self._env.action_space.n)
+                else:
+                    ac = self._pi.act(ob, my_feed_dict)
                 actions.append(ac)
 
                 if self.render:
                     self._env.render()
                 ob, reward, done, _ = self._env.step(ac)
                 rewards.append(reward)
-                t_count += 1
-                if t_count >= 100:  # force episode stop, just to test if memory still grows
-                    break
+                #t_count += 1
+                #if t_count >= 100:  # force episode stop, just to test if memory still grows
+                #    break
 
                 if done:  # end of episode, discard s_T
                     # TODO: for num_timesteps collection, has to store terminal flag instead of start flag!
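
A note on the data structure this series introduces in [PATCH 1/7]: group_ancestors, liberty and stones together form a union-find over flattened board indices, with _find_ancestor doing path compression. The sketch below shows the same idea end to end (incremental group merging, liberty updates, capture) as one self-contained class. It is an illustration rather than the patched code: the names GroupBoard, place_stone, _merge and _capture are invented for this note, and legality checks (suicide, ko) are assumed to have happened before place_stone is called.

    # Illustrative sketch, not AlphaGo/go.py itself: union-find over board
    # indices, with one liberty set and one stone set per group root.
    class GroupBoard(object):
        EMPTY, BLACK, WHITE = 0, 1, -1

        def __init__(self, size):
            self.size = size
            self.board = [self.EMPTY] * (size * size)
            self.ancestor = {}  # stone idx -> parent idx; a root points to itself
            self.liberty = {}   # root idx -> set of adjacent empty points
            self.stones = {}    # root idx -> set of member stones

        def _find(self, idx):
            # walk to the root, then shortcut this entry (path compression),
            # mirroring _find_ancestor in patch 1/7
            root = idx
            while self.ancestor[root] != root:
                root = self.ancestor[root]
            self.ancestor[idx] = root
            return root

        def _neighbors(self, idx):
            x, y = divmod(idx, self.size)
            for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                nx, ny = x + dx, y + dy
                if 0 <= nx < self.size and 0 <= ny < self.size:
                    yield nx * self.size + ny

        def place_stone(self, idx, color):
            # assumes the move has already been checked to be legal
            self.board[idx] = color
            self.ancestor[idx] = idx
            self.stones[idx] = {idx}
            self.liberty[idx] = {n for n in self._neighbors(idx)
                                 if self.board[n] == self.EMPTY}
            for n in self._neighbors(idx):
                if self.board[n] == color:
                    self._merge(self._find(idx), self._find(n))
                elif self.board[n] == -color:
                    root = self._find(n)
                    self.liberty[root].discard(idx)  # this point is now occupied
                    if not self.liberty[root]:
                        self._capture(root)

        def _merge(self, a, b):
            if a == b:
                return
            # fold group b into group a, much as _join_group folds every
            # other ancestor into joined_ancestor
            self.liberty[a] |= self.liberty.pop(b)
            for s in self.stones.pop(b):
                self.ancestor[s] = a
                self.stones[a].add(s)
            self.liberty[a] -= self.stones[a]  # occupied points are not liberties

        def _capture(self, root):
            # remove a zero-liberty group; every freed point becomes a liberty
            # of each adjacent surviving group (cf. _add_captured_liberty)
            captured = self.stones.pop(root)
            del self.liberty[root]
            for s in captured:
                self.board[s] = self.EMPTY
                del self.ancestor[s]
            for s in captured:
                for n in self._neighbors(s):
                    if self.board[n] != self.EMPTY:
                        self.liberty[self._find(n)].add(s)

The payoff is the cheap legality test: with these sets maintained incrementally, a check like _is_suicide only inspects the four neighbors (an empty point, a friendly group with a spare liberty, or an enemy group in atari each settle it), instead of re-flooding whole groups through _find_group for every candidate move.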
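
A note on the repetition check in patches 1/7 and 2/7: every historical position is stored as tuple(board) in history_hashtable, so testing a candidate position is a set lookup. Patch 2/7 additionally wraps the tuple in copy.deepcopy before adding it; since a tuple of plain ints is immutable all the way down, deepcopy returns the very same object, so tuple(current_board) alone already snapshots the mutable board list. A small sketch (would_repeat is an invented helper, and a complete rule check would apply captures before hashing, as executor_do_move effectively does):

    import copy

    board = [0, 1, -1, 0]    # mutable position, as in Game.board
    snapshot = tuple(board)  # immutable, hence usable as a set member
    assert copy.deepcopy(snapshot) is snapshot  # deepcopy of an int tuple is a no-op

    history_hashtable = {snapshot}

    def would_repeat(board, idx, color, history_hashtable):
        # invented helper: does playing color at idx reproduce a previous
        # whole-board position?  Captures would have to be applied to
        # trial first for a faithful ko/superko test.
        trial = list(board)
        trial[idx] = color
        return tuple(trial) in history_hashtable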
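
A note on the exploration added in [PATCH 7/7]: Batch.collect now takes an epsilon_greedy probability, and dqn_example.py decays it linearly from 1.0 on the first epoch to about 1/repeat_num on the last; the "+ 0.0" in the patch forces float division under Python 2. A standalone sketch of the selection rule and the schedule, with policy_act and env standing in for the real policy and gym environment:

    import numpy as np

    def select_action(policy_act, env, ob, epsilon):
        # with probability epsilon take a uniform random action, otherwise
        # act greedily: the rule patch 7/7 adds inside Batch.collect()
        if epsilon > 0 and np.random.random() < epsilon:
            return np.random.randint(low=0, high=env.action_space.n)
        return policy_act(ob)

    repeat_num = 100
    for i in range(repeat_num):
        epsilon = (repeat_num - i + 0.0) / repeat_num  # 1.0 down to 1/repeat_num
        # data_collector.collect(num_episodes=50, epsilon_greedy=epsilon)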