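"""Deterministic policy for deterministic policy gradient methods (e.g. DDPG), built on TensorFlow 1.x graphs."""
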
import math

import tensorflow as tf

from .base import PolicyBase


class Deterministic(PolicyBase):
    """
    Deterministic policy as used in deterministic policy gradient (DPG/DDPG) methods.

    :param policy_callable: callable that builds the policy network; it must return a tuple
        whose first element is the action tensor (the second element is ignored here).
    :param observation_placeholder: the placeholder that observations are fed into.
    :param weight_update: controls the old (target) network. 1 (the default) builds no target
        network; 0 means the target is updated only through explicit ``sync_weights()`` calls;
        a value in (0, 1) enables soft (Polyak) updates as in DDPG; a value greater than 1 is
        treated as a hard-update period, as in DQN.
    """
    def __init__(self, policy_callable, observation_placeholder, weight_update=1):
        self._observation_placeholder = observation_placeholder
        self.managed_placeholders = {'observation': observation_placeholder}
        self.weight_update = weight_update
        self.interaction_count = -1  # defaults to -1. only useful if weight_update > 1.

        # build network, action and value
        with tf.variable_scope('network', reuse=tf.AUTO_REUSE):
            action, _ = policy_callable()
            self.action = action
            # TODO: self._action should be exactly the action tensor to run that directly gives action_dim

        self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')

        # deal with target network
        if self.weight_update == 1:
            self.weight_update_ops = None
            self.sync_weights_ops = None
        else:  # then we need to build another tf graph as target network
            with tf.variable_scope('net_old', reuse=tf.AUTO_REUSE):
                action, _ = policy_callable()
                self.action_old = action

            network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
            network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')
            # TODO: use a scope that the user will almost surely not use, so get_collection will return
            # the correct weights and old_weights, since it filters by regular expression,
            # or we write a util to parse the variable names and use only the topmost scope

            assert len(network_weights) == len(network_old_weights)
            self.sync_weights_ops = [tf.assign(variable_old, variable)
                                     for (variable_old, variable) in zip(network_old_weights, network_weights)]

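            # How the old (target) network tracks the live network, summarizing the branches below:
            #   weight_update == 0     -> updated only when sync_weights() is called explicitly
            #   0 < weight_update < 1  -> soft (Polyak) update theta_old <- tau * theta + (1 - tau) * theta_old,
            #                             with tau = weight_update, as in DDPG
            #   weight_update > 1      -> hard copy every ceil(weight_update) interactions, as in DQN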
            if weight_update == 0:
                self.weight_update_ops = self.sync_weights_ops
            elif 0 < weight_update < 1:  # as in DDPG
                self.weight_update_ops = [tf.assign(variable_old,
                                                    weight_update * variable + (1 - weight_update) * variable_old)
                                          for (variable_old, variable) in zip(network_old_weights, network_weights)]
            else:
                self.interaction_count = 0  # as in DQN
                self.weight_update = math.ceil(weight_update)

    @property
    def action_shape(self):
        return self.action.shape.as_list()[1:]

    def act(self, observation, my_feed_dict={}):
        """
        Return the action for a single, un-batched observation.

        :param observation: a single observation, without the batch dimension.
        :param my_feed_dict: extra entries merged into the feed_dict before running the graph.
        :return: numpy array with shape ``action_shape``.
        """
        # TODO: this may be ugly. also maybe huge problem when parallel
        sess = tf.get_default_session()

        # observation[None] adds the batch dimension at the beginning
        feed_dict = {self._observation_placeholder: observation[None]}
        feed_dict.update(my_feed_dict)
        sampled_action = sess.run(self.action, feed_dict=feed_dict)

        # strip the batch dimension again
        sampled_action = sampled_action[0]

        return sampled_action

    def update_weights(self):
        """
        Update the weights of the old (target) network using ``weight_update_ops``, if they were built.

        :return:
        """
        if self.weight_update_ops is not None:
            sess = tf.get_default_session()
            sess.run(self.weight_update_ops)

    def sync_weights(self):
        """
        Sync the weights of the old (target) network by directly copying the weights of the live network.

        :return:
        """
        if self.sync_weights_ops is not None:
            sess = tf.get_default_session()
            sess.run(self.sync_weights_ops)

    def eval_action(self, observation):
        """
        Evaluate actions for a minibatch of observations.

        :param observation: a batch of observations, with the batch dimension first.
        :return: 2-D numpy array of actions.
        """
        sess = tf.get_default_session()

        feed_dict = {self._observation_placeholder: observation}
        action = sess.run(self.action, feed_dict=feed_dict)

        return action

    def eval_action_old(self, observation):
        """
        Evaluate actions for a minibatch of observations, using the old (target) network.

        :param observation: a batch of observations, with the batch dimension first.
        :return: 2-D numpy array of actions.
        """
        sess = tf.get_default_session()

        feed_dict = {self._observation_placeholder: observation}
        action = sess.run(self.action_old, feed_dict=feed_dict)

        return action
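

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It assumes
# the module is run as a package module so the relative import of PolicyBase
# resolves, and that `policy_callable` builds the actor from the observation
# placeholder and returns (action_tensor, None). The layer sizes, dimensions,
# and names below are made up for demonstration.
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np

    observation_ph = tf.placeholder(tf.float32, shape=(None, 3))

    def actor():
        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.relu)
        action = tf.layers.dense(net, 1, activation=tf.nn.tanh)
        return action, None

    # weight_update in (0, 1) -> build a target network with soft (Polyak) updates, DDPG style
    pi = Deterministic(actor, observation_ph, weight_update=1e-3)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        pi.sync_weights()    # hard-copy the live weights into net_old once
        single_action = pi.act(np.zeros(3, dtype=np.float32))               # shape (1,)
        batch_actions = pi.eval_action(np.zeros((8, 3), dtype=np.float32))  # shape (8, 1)
        pi.update_weights()  # one soft update of net_old
        print(single_action.shape, batch_actions.shape)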