Merge remote-tracking branch 'origin/master'
# Conflicts: # README.md
This commit is contained in:
commit
40190a282e
@ -40,6 +40,9 @@ class Game:
|
||||
self.history_hashtable = set()
|
||||
self.game_engine = go.Go(size=self.size, komi=self.komi)
|
||||
self.board = [utils.EMPTY] * (self.size ** 2)
|
||||
self.group_ancestors = {} # key: idx, value: ancestor idx
|
||||
self.liberty = {} # key: ancestor idx, value: set of liberty
|
||||
self.stones = {} # key: ancestor idx, value: set of stones
|
||||
elif self.name == "reversi":
|
||||
self.size = 8
|
||||
self.history_length = 1
|
||||
@ -62,6 +65,9 @@ class Game:
|
||||
self.board = [utils.EMPTY] * (self.size ** 2)
|
||||
del self.history[:]
|
||||
self.history_hashtable.clear()
|
||||
self.group_ancestors.clear()
|
||||
self.liberty.clear()
|
||||
self.stones.clear()
|
||||
if self.name == "reversi":
|
||||
self.board = self.game_engine.get_board()
|
||||
for _ in range(self.history_length):
|
||||
@ -109,7 +115,7 @@ class Game:
|
||||
res = self.game_engine.executor_do_move(self.history, self.latest_boards, self.board, color, vertex)
|
||||
if self.name == "go":
|
||||
res = self.game_engine.executor_do_move(self.history, self.history_hashtable, self.latest_boards, self.board,
|
||||
color, vertex)
|
||||
self.group_ancestors, self.liberty, self.stones, color, vertex)
|
||||
return res
|
||||
|
||||
def think_play_move(self, color):
|
||||
|
||||
201
AlphaGo/go.py
201
AlphaGo/go.py
@ -44,6 +44,23 @@ class Go:
|
||||
nei.append((_x, _y))
|
||||
return nei
|
||||
|
||||
def _neighbor_color(self, current_board, vertex, color):
    """Sort the neighbors of ``vertex`` by what occupies them, relative to ``color``.

    :param current_board: flat board, indexed by the values ``self._flatten`` returns
    :param vertex: position whose neighbors (as given by ``self._neighbor``) are inspected
    :param color: reference color; its opposite is WHITE when ``color`` is not WHITE's
        opposite, i.e. BLACK if ``color == utils.WHITE`` and WHITE otherwise
    :return: three lists of flattened indices, in this order:
        1) neighbors holding ``color``,
        2) neighbors holding the opposite color,
        3) empty neighbors
    :raises ValueError: if a neighbor holds something other than BLACK, WHITE or EMPTY
    """
    opposite = utils.BLACK if color == utils.WHITE else utils.WHITE
    same_color, opposite_color, vacant = [], [], []
    for neighbor in self._neighbor(vertex):
        flat = self._flatten(neighbor)
        occupant = current_board[flat]
        # The order of the comparisons matters when color is not BLACK/WHITE,
        # so it is kept: same color first, then empty, then opposite color.
        if occupant == color:
            same_color.append(flat)
            continue
        if occupant == utils.EMPTY:
            vacant.append(flat)
            continue
        if occupant == opposite:
            opposite_color.append(flat)
            continue
        raise ValueError("board have other positions excluding BLACK, WHITE and EMPTY")
    return same_color, opposite_color, vacant
|
||||
|
||||
def _corner(self, vertex):
|
||||
x, y = vertex
|
||||
corner = []
|
||||
@ -56,13 +73,11 @@ class Go:
|
||||
|
||||
def _find_group(self, current_board, vertex):
|
||||
color = current_board[self._flatten(vertex)]
|
||||
# print ("color : ", color)
|
||||
chain = set()
|
||||
frontier = [vertex]
|
||||
has_liberty = False
|
||||
while frontier:
|
||||
current = frontier.pop()
|
||||
# print ("current : ", current)
|
||||
chain.add(current)
|
||||
for n in self._neighbor(current):
|
||||
if current_board[self._flatten(n)] == color and not n in chain:
|
||||
@ -71,21 +86,26 @@ class Go:
|
||||
has_liberty = True
|
||||
return has_liberty, chain
|
||||
|
||||
def _is_suicide(self, current_board, color, vertex):
|
||||
current_board[self._flatten(vertex)] = color # assume that we already take this move
|
||||
suicide = False
|
||||
def _find_ancestor(self, group_ancestors, idx):
|
||||
r = idx
|
||||
while group_ancestors[r] != r:
|
||||
r = group_ancestors[r]
|
||||
group_ancestors[idx] = r
|
||||
return r
|
||||
|
||||
has_liberty, group = self._find_group(current_board, vertex)
|
||||
if not has_liberty:
|
||||
suicide = True # no liberty, suicide
|
||||
for n in self._neighbor(vertex):
|
||||
if current_board[self._flatten(n)] == utils.another_color(color):
|
||||
opponent_liberty, group = self._find_group(current_board, n)
|
||||
if not opponent_liberty:
|
||||
suicide = False # this move is able to take opponent's stone, not suicide
|
||||
|
||||
current_board[self._flatten(vertex)] = utils.EMPTY # undo this move
|
||||
return suicide
|
||||
def _is_suicide(self, current_board, group_ancestors, liberty, color, vertex):
    """Tell whether playing ``color`` at ``vertex`` would be a suicide move.

    :param current_board: flat board; ``vertex`` is expected to be still empty
    :param group_ancestors: dict, key: idx, value: ancestor idx
    :param liberty: dict, key: ancestor idx, value: set of liberties of that group
        (assumed to be in sync with ``current_board`` -- verify against callers)
    :param color: color of the stone to be played
    :param vertex: position of the stone to be played
    :return: False if the stone would end up with a liberty, True otherwise

    The move is NOT a suicide as soon as one of these holds:
      1) the point has an empty neighbor;
      2) a neighboring group of the same color has more than one liberty, so it
         keeps a liberty after this point is filled;
      3) a neighboring group of the reverse color has exactly one liberty; since
         the point has no empty neighbor here, that liberty is the point itself
         and the move captures the group.

    Checks 2) and 3) are both always performed. Previously check 3) was skipped
    whenever the point had a same-color neighbor, so a capturing move that
    connects to a friendly group in atari was wrongly reported as suicide.
    """
    color_neighbor, reverse_color_neighbor, empty_neighbor = \
        self._neighbor_color(current_board, vertex, color)
    if empty_neighbor:
        return False  # neighbors have empty spaces

    # a friendly group with a spare liberty keeps the new stone alive
    for idx in color_neighbor:
        if len(liberty[self._find_ancestor(group_ancestors, idx)]) > 1:
            return False

    # an opponent group whose only liberty is this point gets captured
    for idx in reverse_color_neighbor:
        if len(liberty[self._find_ancestor(group_ancestors, idx)]) == 1:
            return False

    return True
|
||||
|
||||
def _process_board(self, current_board, color, vertex):
|
||||
nei = self._neighbor(vertex)
|
||||
@ -107,25 +127,15 @@ class Go:
|
||||
return repeat
|
||||
|
||||
def _is_eye(self, current_board, color, vertex):
|
||||
nei = self._neighbor(vertex)
|
||||
cor = self._corner(vertex)
|
||||
ncolor = {color == current_board[self._flatten(n)] for n in nei}
|
||||
if False in ncolor:
|
||||
# print "not all neighbors are in same color with us"
|
||||
# return whether this position is a real eye of color
|
||||
color_neighbor, reverse_color_neighbor, empty_neighbor = self._neighbor_color(current_board, vertex, color)
|
||||
if reverse_color_neighbor or empty_neighbor: # not an eye
|
||||
return False
|
||||
_, group = self._find_group(current_board, nei[0])
|
||||
if set(nei) < group:
|
||||
# print "all neighbors are in same group and same color with us"
|
||||
return True
|
||||
else:
|
||||
opponent_number = [current_board[self._flatten(c)] for c in cor].count(-color)
|
||||
opponent_propotion = float(opponent_number) / float(len(cor))
|
||||
if opponent_propotion < 0.5:
|
||||
# print "few opponents, real eye"
|
||||
return True
|
||||
else:
|
||||
# print "many opponents, fake eye"
|
||||
return False
|
||||
cor = self._corner(vertex)
|
||||
opponent_number = [current_board[self._flatten(c)] for c in cor].count(-color)
|
||||
opponent_propotion = float(opponent_number) / float(len(cor))
|
||||
# opponent_propotion < 0.5: real eye; otherwise fake eye
|
||||
return True if opponent_propotion < 0.5 else False
|
||||
|
||||
def _knowledge_prunning(self, current_board, color, vertex):
|
||||
# forbid some stupid selfplay using human knowledge
|
||||
@ -134,23 +144,6 @@ class Go:
|
||||
# forbid position on its own eye.
|
||||
return True
|
||||
|
||||
def _is_game_finished(self, current_board, color):
    '''
    Decide from the empty positions whether the game is finished.

    For each empty position that has more than one neighbor, every neighbor is
    compared with the first one; as soon as two neighbors of the same empty
    position hold different values the game is reported as not finished.
    NOTE(review): the intent was described as "it has both BLACK and WHITE
    neighbors", but the comparison also counts an EMPTY neighbor next to a
    stone as a difference -- confirm which one is wanted.

    :param current_board: flat board; it is only read, never modified
    :param color: not used by this check
    :return: return the game is finished
    '''
    # The board is only read, so no defensive copy is needed.
    for idx, occupant in enumerate(current_board):
        if occupant == utils.EMPTY:
            # use the same private helpers (_deflatten / _flatten) as the rest of the class
            neighbor_idx = self._neighbor(self._deflatten(idx))
            if len(neighbor_idx) > 1:
                first_value = current_board[self._flatten(neighbor_idx[0])]
                for other_idx in neighbor_idx[1:]:
                    if current_board[self._flatten(other_idx)] != first_value:
                        return False

    return True
|
||||
|
||||
def _action2vertex(self, action):
|
||||
if action == self.size ** 2:
|
||||
vertex = (0, 0)
|
||||
@ -158,7 +151,7 @@ class Go:
|
||||
vertex = self._deflatten(action)
|
||||
return vertex
|
||||
|
||||
def _rule_check(self, history_hashtable, current_board, color, vertex, is_thinking=True):
|
||||
def _rule_check(self, history_hashtable, current_board, group_ancestors, liberty, color, vertex, is_thinking=True):
|
||||
### in board
|
||||
if not self._in_board(vertex):
|
||||
if not is_thinking:
|
||||
@ -174,7 +167,7 @@ class Go:
|
||||
return False
|
||||
|
||||
### check if it is suicide
|
||||
if self._is_suicide(current_board, color, vertex):
|
||||
if self._is_suicide(current_board, group_ancestors, liberty, color, vertex):
|
||||
if not is_thinking:
|
||||
raise ValueError("Target point causes suicide, Current Board: {}, color: {}, vertex : {}".format(current_board, color, vertex))
|
||||
else:
|
||||
@ -189,34 +182,59 @@ class Go:
|
||||
|
||||
return True
|
||||
|
||||
def _is_valid(self, state, action, history_hashtable):
|
||||
def _is_valid(self, state, action, history_hashtable, group_ancestors, liberty):
|
||||
history_boards, color = state
|
||||
vertex = self._action2vertex(action)
|
||||
current_board = history_boards[-1]
|
||||
|
||||
if not self._rule_check(history_hashtable, current_board, color, vertex):
|
||||
if not self._rule_check(history_hashtable, current_board, group_ancestors, liberty, color, vertex):
|
||||
return False
|
||||
|
||||
if not self._knowledge_prunning(current_board, color, vertex):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _get_groups(self, board):
    """Rebuild the group bookkeeping of ``board`` from scratch.

    Every position with a truthy value that is not yet assigned starts a new
    group with itself as ancestor; the group is then grown through same-color
    neighbors (as reported by ``self._neighbor_color``), collecting the empty
    neighbors of every member as liberties.

    :param board: flat board
    :return: ``(group_ancestors, liberty)`` where ``group_ancestors`` maps
        idx -> ancestor idx and ``liberty`` maps ancestor idx -> set of liberties
    """
    group_ancestors = {}  # key: idx, value: ancestor idx
    liberty = {}  # key: ancestor idx, value: set of liberty
    for root, color in enumerate(board):
        if not color or root in group_ancestors:
            continue
        # build group
        group_ancestors[root] = root
        same_color, _, empties = self._neighbor_color(board, self._deflatten(root), color)
        group_liberty = set(empties)
        liberty[root] = group_liberty
        pending = list(same_color)
        while pending:
            member = pending.pop()
            if member in group_ancestors:
                continue
            group_ancestors[member] = root
            same_color, _, empties = self._neighbor_color(board, self._deflatten(member), color)
            pending.extend(same_color)
            group_liberty.update(empties)
    return group_ancestors, liberty
|
||||
|
||||
def simulate_get_mask(self, state, action_set):
|
||||
# find all the invalid actions
|
||||
invalid_action_mask = []
|
||||
history_boards, color = state
|
||||
group_ancestors, liberty = self._get_groups(history_boards[-1])
|
||||
history_hashtable = set()
|
||||
for board in history_boards:
|
||||
history_hashtable.add(tuple(board))
|
||||
for action_candidate in action_set[:-1]:
|
||||
# go through all the actions excluding pass
|
||||
if not self._is_valid(state, action_candidate, history_hashtable):
|
||||
if not self._is_valid(state, action_candidate, history_hashtable, group_ancestors, liberty):
|
||||
invalid_action_mask.append(action_candidate)
|
||||
if len(invalid_action_mask) < len(action_set) - 1:
|
||||
invalid_action_mask.append(action_set[-1])
|
||||
# forbid pass, if we have other choices
|
||||
# TODO: In fact we should not do this. In some extreme cases, we should permit pass.
|
||||
del history_hashtable
|
||||
del group_ancestors
|
||||
del liberty
|
||||
# del stones
|
||||
return invalid_action_mask
|
||||
|
||||
def _do_move(self, board, color, vertex):
|
||||
@ -243,12 +261,70 @@ class Go:
|
||||
# since go is MDP, we only need the last board for hashing
|
||||
return tuple(state[0][-1])
|
||||
|
||||
def executor_do_move(self, history, history_hashtable, latest_boards, current_board, color, vertex):
|
||||
if not self._rule_check(history_hashtable, current_board, color, vertex, is_thinking=False):
|
||||
# raise ValueError("!!! We have more than four ko at the same time !!!")
|
||||
def _join_group(self, idx, idx_list, empty_neighbor, group_ancestors, liberty, stones):
    """Attach the stone at ``idx`` to its same-color neighbors, merging their groups.

    :param idx: flattened index of the stone that joins
    :param idx_list: same-color neighbors of ``idx`` (must not be empty)
    :param empty_neighbor: empty neighbors of ``idx``
    :param group_ancestors: dict, key: idx, value: ancestor idx (updated in place)
    :param liberty: dict, key: ancestor idx, value: set of liberty (updated in place)
    :param stones: dict, key: ancestor idx, value: set of stones (updated in place)

    One of the neighboring groups is kept; all other neighboring groups are
    folded into it and their entries are removed from ``liberty`` and ``stones``.
    """
    # idx joins its neighbors idx_list
    roots = {self._find_ancestor(group_ancestors, neighbor) for neighbor in idx_list}
    survivor = roots.pop()

    merged_liberty = liberty[survivor]
    merged_liberty.update(empty_neighbor)
    merged_stones = stones[survivor]
    merged_stones.add(idx)
    group_ancestors[idx] = survivor

    # add other groups
    for absorbed in roots:
        merged_liberty.update(liberty[absorbed])
        absorbed_stones = stones[absorbed]
        merged_stones.update(absorbed_stones)
        del liberty[absorbed]
        for stone in absorbed_stones:
            group_ancestors[stone] = survivor
        del stones[absorbed]

    # the joined point is occupied now, so it is no longer a liberty
    merged_liberty.remove(idx)
|
||||
|
||||
def _add_captured_liberty(self, board, liberty, group_ancestors, stones):
    """Hand the points of a group being captured to the adjacent opposing groups as liberties.

    :param board: flat board on which the captured stones are still present
    :param liberty: dict, key: ancestor idx, value: set of liberty (updated in place)
    :param group_ancestors: dict, key: idx, value: ancestor idx
    :param stones: iterable of flattened indices of the stones being captured
    """
    for captured_stone in stones:
        _, capturing_neighbors, vacant = self._neighbor_color(
            board, self._deflatten(captured_stone), board[captured_stone])
        assert not vacant  # make sure no empty spaces
        for capturer in capturing_neighbors:
            capturer_root = self._find_ancestor(group_ancestors, capturer)
            liberty[capturer_root].add(captured_stone)
|
||||
|
||||
def _remove_liberty(self, idx, reverse_color_neighbor, current_board, group_ancestors, liberty, stones):
    """Take the point ``idx`` away from the adjacent opposing groups, capturing those left without liberty.

    :param idx: flattened index of the stone that was just placed
    :param reverse_color_neighbor: stones near idx in the reverse color
    :param current_board: flat board (captured stones are set to ``utils.EMPTY``)
    :param group_ancestors: dict, key: idx, value: ancestor idx (updated in place)
    :param liberty: dict, key: ancestor idx, value: set of liberty (updated in place)
    :param stones: dict, key: ancestor idx, value: set of stones (updated in place)
    """
    opponent_roots = {self._find_ancestor(group_ancestors, neighbor)
                      for neighbor in reverse_color_neighbor}
    for root in opponent_roots:
        if len(liberty[root]) != 1:
            # remove this liberty
            liberty[root].remove(idx)
            continue
        # capture this group if no liberty left
        doomed = stones[root]
        self._add_captured_liberty(current_board, liberty, group_ancestors, doomed)
        for captured_stone in doomed:
            current_board[captured_stone] = utils.EMPTY
            del group_ancestors[captured_stone]
        del liberty[root]
        del stones[root]
|
||||
|
||||
def executor_do_move(self, history, history_hashtable, latest_boards, current_board, group_ancestors, liberty, stones, color, vertex):
|
||||
#print("===")
|
||||
#print(color, vertex)
|
||||
#print(group_ancestors, liberty, stones)
|
||||
if not self._rule_check(history_hashtable, current_board, group_ancestors, liberty, color, vertex):
|
||||
return False
|
||||
current_board[self._flatten(vertex)] = color
|
||||
self._process_board(current_board, color, vertex)
|
||||
idx = self._flatten(vertex)
|
||||
current_board[idx] = color
|
||||
color_neighbor, reverse_color_neighbor, empty_neighbor = self._neighbor_color(current_board, vertex, color)
|
||||
if color_neighbor: # join nearby groups
|
||||
self._join_group(idx, color_neighbor, empty_neighbor, group_ancestors, liberty, stones)
|
||||
else: # build a new group
|
||||
group_ancestors[idx] = idx
|
||||
liberty[idx] = set(empty_neighbor)
|
||||
stones[idx] = {idx}
|
||||
if reverse_color_neighbor: # remove liberty for nearby reverse color
|
||||
self._remove_liberty(idx, reverse_color_neighbor, current_board, group_ancestors, liberty, stones)
|
||||
history.append(copy.deepcopy(current_board))
|
||||
latest_boards.append(copy.deepcopy(current_board))
|
||||
history_hashtable.add(copy.deepcopy(tuple(current_board)))
|
||||
@ -362,6 +438,8 @@ if __name__ == "__main__":
|
||||
1, 0, 1, 1, 1, 1, 1, -1, 0,
|
||||
1, 1, 0, 1, -1, -1, -1, -1, -1
|
||||
]
|
||||
|
||||
'''
|
||||
time0 = time.time()
|
||||
score = go.executor_get_score(endgame)
|
||||
time1 = time.time()
|
||||
@ -370,6 +448,7 @@ if __name__ == "__main__":
|
||||
time2 = time.time()
|
||||
print(score, time2 - time1)
|
||||
'''
|
||||
'''
|
||||
### do unit test for Go class
|
||||
pure_test = [
|
||||
0, 1, 0, 1, 0, 1, 0, 0, 0,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from __future__ import division
|
||||
import argparse
|
||||
import sys
|
||||
import re
|
||||
@ -28,10 +29,13 @@ def play(engine, data_path):
|
||||
size = {"go": 9, "reversi": 8}
|
||||
show = ['.', 'X', 'O']
|
||||
|
||||
# evaluate_rounds = 100
|
||||
evaluate_rounds = 0
|
||||
game_num = 0
|
||||
total = 0
|
||||
f=open('time.txt','w')
|
||||
while True:
|
||||
# while game_num < evaluate_rounds:
|
||||
#while game_num < evaluate_rounds:
|
||||
start = time.time()
|
||||
engine._game.model.check_latest_model()
|
||||
num = 0
|
||||
pass_flag = [False, False]
|
||||
@ -77,6 +81,14 @@ def play(engine, data_path):
|
||||
cPickle.dump(data, file)
|
||||
data.reset()
|
||||
game_num += 1
|
||||
|
||||
this_time = time.time() - start
|
||||
total += this_time
|
||||
f.write('time:'+ str(this_time)+'\n')
|
||||
evaluate_rounds += 1
|
||||
f.write('Avg time:' + str(total/evaluate_rounds))
|
||||
f.close()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
23
README.md
23
README.md
@ -83,12 +83,6 @@ Try to use full names. Don't use abbrevations for class/function/variable names
|
||||
|
||||
The """xxx""" docstring should be written right after the class/function definition. Also comment the parts of the code that are not intuitive. We must write comments, but for now we don't need to polish them.
|
||||
|
||||
# High Priority TODO
|
||||
|
||||
For Haosheng and Tongzheng: separate actor and critic, rewrite the interfaces for policy
|
||||
|
||||
Others can still focus on the task below.
|
||||
|
||||
## TODO
|
||||
Search based method parallel.
|
||||
|
||||
@ -106,7 +100,18 @@ Note: install openai/gym first to run the Atari environment; note that interface
|
||||
|
||||
Without preprocessing and other tricks, this example will not train to any meaningful results. Code should pass two tests: the individual module test and a full run of this example code.
|
||||
|
||||
## Dependencies
|
||||
TensorFlow (Version >= 1.4)
|
||||
## Some bugs to fix
|
||||
|
||||
gym
|
||||
For DQN and other deterministic policy: $\epsilon$-greedy or other exploration during collection?
|
||||
|
||||
In Batch.py, notice that we cannot stop by setting num_timestep
|
||||
|
||||
Magic numbers
|
||||
|
||||
## One idea
|
||||
|
||||
Like zhusuan, we can register losses in the background so that we need not declare them in the example.
|
||||
|
||||
## Dependency
|
||||
Tensorflow (Version >= 1.4)
|
||||
Gym
|
||||
|
||||
@ -66,9 +66,11 @@ if __name__ == '__main__':
|
||||
pi.sync_weights() # TODO: automate this for policies with target network
|
||||
|
||||
start_time = time.time()
|
||||
for i in range(100):
|
||||
# TODO: repeat_num should be defined in a configuration file
|
||||
repeat_num = 100
|
||||
for i in range(repeat_num):
|
||||
# collect data
|
||||
data_collector.collect(num_episodes=50)
|
||||
data_collector.collect(num_episodes=50, epsilon_greedy= (repeat_num - i + 0.0) / repeat_num)
|
||||
|
||||
# print current return
|
||||
print('Epoch {}:'.format(i))
|
||||
|
||||
@ -5,6 +5,7 @@ import tensorflow as tf
|
||||
import gym
|
||||
import numpy as np
|
||||
import time
|
||||
import argparse
|
||||
|
||||
# our lib imports here! It's ok to append path in examples
|
||||
import sys
|
||||
@ -16,6 +17,9 @@ import tianshou.core.policy.stochastic as policy # TODO: fix imports as zhusuan
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--render", action="store_true", default=False)
|
||||
args = parser.parse_args()
|
||||
env = gym.make('CartPole-v0')
|
||||
observation_dim = env.observation_space.shape
|
||||
action_dim = env.action_space.n
|
||||
@ -55,7 +59,7 @@ if __name__ == '__main__':
|
||||
train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
|
||||
|
||||
### 3. define data collection
|
||||
training_data = Batch(env, pi, [advantage_estimation.full_return], [pi])
|
||||
training_data = Batch(env, pi, [advantage_estimation.full_return], [pi], render = args.render)
|
||||
|
||||
### 4. start training
|
||||
config = tf.ConfigProto()
|
||||
|
||||
@ -18,7 +18,7 @@ class DQN(PolicyBase):
|
||||
else:
|
||||
self.interaction_count = -1
|
||||
|
||||
def act(self, observation, exploration=None):
|
||||
def act(self, observation, my_feed_dict):
|
||||
sess = tf.get_default_session()
|
||||
if self.weight_update > 1:
|
||||
if self.interaction_count % self.weight_update == 0:
|
||||
@ -30,8 +30,7 @@ class DQN(PolicyBase):
|
||||
if self.weight_update > 0:
|
||||
self.interaction_count += 1
|
||||
|
||||
if not exploration:
|
||||
return np.squeeze(action)
|
||||
return np.squeeze(action)
|
||||
|
||||
@property
|
||||
def q_net(self):
|
||||
|
||||
@ -9,7 +9,7 @@ class Batch(object):
|
||||
class for batch datasets. Collect multiple observations (actions, rewards, etc.) on-policy.
|
||||
"""
|
||||
|
||||
def __init__(self, env, pi, reward_processors, networks): # how to name the function?
|
||||
def __init__(self, env, pi, reward_processors, networks, render=False): # how to name the function?
|
||||
"""
|
||||
constructor
|
||||
:param env:
|
||||
@ -24,6 +24,7 @@ class Batch(object):
|
||||
|
||||
self.reward_processors = reward_processors
|
||||
self.networks = networks
|
||||
self.render = render
|
||||
|
||||
self.required_placeholders = {}
|
||||
for net in self.networks:
|
||||
@ -33,7 +34,7 @@ class Batch(object):
|
||||
self._is_first_collect = True
|
||||
|
||||
def collect(self, num_timesteps=0, num_episodes=0, my_feed_dict={},
|
||||
process_reward=True): # specify how many data to collect here, or fix it in __init__()
|
||||
process_reward=True, epsilon_greedy=0): # specify how many data to collect here, or fix it in __init__()
|
||||
assert sum(
|
||||
[num_timesteps > 0, num_episodes > 0]) == 1, "One and only one collection number specification permitted!"
|
||||
|
||||
@ -105,15 +106,21 @@ class Batch(object):
|
||||
episode_start_flags.append(True)
|
||||
|
||||
while True:
|
||||
ac = self._pi.act(ob, my_feed_dict)
|
||||
# a simple implementation of epsilon greedy
|
||||
if epsilon_greedy > 0 and np.random.random() < epsilon_greedy:
|
||||
ac = np.random.randint(low = 0, high = self._env.action_space.n)
|
||||
else:
|
||||
ac = self._pi.act(ob, my_feed_dict)
|
||||
actions.append(ac)
|
||||
|
||||
if self.render:
|
||||
self._env.render()
|
||||
ob, reward, done, _ = self._env.step(ac)
|
||||
rewards.append(reward)
|
||||
|
||||
t_count += 1
|
||||
if t_count >= 100: # force episode stop, just to test if memory still grows
|
||||
break
|
||||
#t_count += 1
|
||||
#if t_count >= 100: # force episode stop, just to test if memory still grows
|
||||
# break
|
||||
|
||||
if done: # end of episode, discard s_T
|
||||
# TODO: for num_timesteps collection, has to store terminal flag instead of start flag!
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user