import tensorflow as tf
import logging

from .base import PolicyBase
from ..random import OrnsteinUhlenbeckProcess
from ..utils import identify_dependent_variables


class Deterministic(PolicyBase):
"""
|
2018-04-11 14:23:40 +08:00
|
|
|
deterministic policy as used in deterministic policy gradient (DDPG) methods
|
2018-01-18 17:38:52 +08:00
|
|
|
"""
|
2018-04-11 14:23:40 +08:00
|
|
|
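    # A minimal construction sketch (hypothetical names and shapes, for illustration
    # only) -- the network callable builds the actor under the active variable scope
    # and returns the action tensor first:
    #
    #     observation_ph = tf.placeholder(tf.float32, shape=(None, observation_dim))
    #
    #     def my_actor():
    #         net = tf.layers.dense(observation_ph, 64, activation=tf.nn.relu)
    #         action = tf.layers.dense(net, action_dim, activation=tf.nn.tanh)
    #         return action, None
    #
    #     policy = Deterministic(my_actor, observation_ph, has_old_net=True)
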
    def __init__(self, network_callable, observation_placeholder, has_old_net=False, random_process=None):
        self.observation_placeholder = observation_placeholder
        self.managed_placeholders = {'observation': observation_placeholder}
        self.has_old_net = has_old_net

        network_scope = 'network'
        net_old_scope = 'net_old'

        # build the network and the action tensor
        with tf.variable_scope(network_scope, reuse=tf.AUTO_REUSE):
            # the network callable returns the action tensor as its first output
            action = network_callable()[0]
            assert action is not None
            self.action = action

        weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        self.network_weights = identify_dependent_variables(self.action, weights)
        self._trainable_variables = [var for var in self.network_weights
                                     if var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]

        # deal with the target network
        if not has_old_net:
            self.sync_weights_ops = None
        else:
            # build a second copy of the graph to serve as the target (old) network
            with tf.variable_scope(net_old_scope, reuse=tf.AUTO_REUSE):
                self.action_old = network_callable()[0]

            old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=net_old_scope)

            # re-filter by name prefix to rule out some edge cases
            old_weights = [var for var in old_weights if var.name[:len(net_old_scope)] == net_old_scope]

            self.network_old_weights = identify_dependent_variables(self.action_old, old_weights)
            assert len(self.network_weights) == len(self.network_old_weights)

            self.sync_weights_ops = [tf.assign(variable_old, variable)
                                     for (variable_old, variable) in zip(self.network_old_weights, self.network_weights)]

        # random process for exploration, since the policy itself is deterministic
        self.random_process = random_process or OrnsteinUhlenbeckProcess(
            theta=0.15, sigma=0.3, size=self.action.shape.as_list()[-1])
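        # Note (descriptive only): the noise has one component per action dimension
        # (size is taken from the last axis of self.action); it is added to the
        # deterministic action in act() and reset between episodes via reset().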

    @property
    def trainable_variables(self):
        """
        The set of trainable variables that the action tensor depends on.
        """
        return set(self._trainable_variables)

    def act(self, observation, my_feed_dict={}):
        """
        Return the action for a single observation, with exploration noise added.

        :param observation: a single observation, without the batch dimension.
        :param my_feed_dict: optional extra entries for the feed dict.
        :return: 1-D numpy array, the noisy action.
        """
        sess = tf.get_default_session()

        # observation[None] adds one dimension at the beginning
        feed_dict = {self.observation_placeholder: observation[None]}
        feed_dict.update(my_feed_dict)
        sampled_action = sess.run(self.action, feed_dict=feed_dict)

        # add exploration noise from the random process
        sampled_action = sampled_action[0] + self.random_process.sample()

        return sampled_action
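
    # A rough rollout sketch (hypothetical environment API, for illustration only);
    # act() requires a default session, which a Session context manager provides:
    #
    #     with tf.Session() as sess:
    #         sess.run(tf.global_variables_initializer())
    #         observation = env.reset()
    #         policy.reset()
    #         done = False
    #         while not done:
    #             action = policy.act(observation)
    #             observation, reward, done, _ = env.step(action)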

    def reset(self):
        """
        Reset the state of the exploration random process, typically at the start of an episode.
        """
        self.random_process.reset_states()

    def act_test(self, observation, my_feed_dict={}):
        """
        Return the action for a single observation without exploration noise, for evaluation.

        :param observation: a single observation, without the batch dimension.
        :param my_feed_dict: optional extra entries for the feed dict.
        :return: 1-D numpy array, the deterministic action.
        """
        sess = tf.get_default_session()

        # observation[None] adds one dimension at the beginning
        feed_dict = {self.observation_placeholder: observation[None]}
        feed_dict.update(my_feed_dict)
        sampled_action = sess.run(self.action, feed_dict=feed_dict)

        sampled_action = sampled_action[0]

        return sampled_action

    def sync_weights(self):
        """
        Sync the weights of the old (target) network by directly copying the weights
        of the current network. Does nothing if the policy was built without an old network.
        """
        if self.sync_weights_ops is not None:
            sess = tf.get_default_session()
            sess.run(self.sync_weights_ops)
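
    # Note (descriptive only): a DDPG-style training loop might refresh the target
    # network from the trained network at some interval, e.g.
    #
    #     if step % target_update_interval == 0:
    #         policy.sync_weights()
    #
    # where target_update_interval is a hypothetical training-loop parameter, not
    # part of this class.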

    def eval_action(self, observation):
        """
        Evaluate the action for a minibatch of observations.

        :param observation: 2-D numpy array, a minibatch of observations.
        :return: 2-D numpy array, the corresponding actions.
        """
        sess = tf.get_default_session()

        feed_dict = {self.observation_placeholder: observation}
        action = sess.run(self.action, feed_dict=feed_dict)

        return action
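
    # For example (shapes are illustrative): given observation_batch of shape
    # (batch_size, observation_dim), eval_action returns an array of shape
    # (batch_size, action_dim):
    #
    #     actions = policy.eval_action(observation_batch)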

    def eval_action_old(self, observation):
        """
        Evaluate the action of the old (target) network for a minibatch of observations.

        :param observation: 2-D numpy array, a minibatch of observations.
        :return: 2-D numpy array, the corresponding actions.
        """
        sess = tf.get_default_session()

        feed_dict = {self.observation_placeholder: observation}
        action = sess.run(self.action_old, feed_dict=feed_dict)

        return action