from __future__ import absolute_import

import tensorflow as tf

# NOTE: QValuePolicy is assumed to be defined next to PolicyBase in .base;
# it is required by DQNOld below but was missing from the original imports.
from .base import PolicyBase, QValuePolicy
from ..value_function.action_value import DQN


class DQNRefactor(PolicyBase):
    """
    Greedy policy that uses :class:`DQN` from ``value_function`` as a member.
    """
    def __init__(self, value_tensor, observation_placeholder, action_placeholder):
        self._q_net = DQN(value_tensor, observation_placeholder, action_placeholder)
        # greedy action: argmax over the Q-values of all actions
        self._argmax_action = tf.argmax(value_tensor, axis=1)

        super(DQNRefactor, self).__init__(observation_placeholder=observation_placeholder)

    def act(self, observation, exploration=None):
        sess = tf.get_default_session()
        if not exploration:  # no exploration: act greedily w.r.t. the Q-values
            action = sess.run(self._argmax_action,
                              feed_dict={self._observation_placeholder: observation})
        else:
            # exploration schemes (e.g. epsilon-greedy) are not implemented yet
            raise NotImplementedError('exploration is not supported yet')

        return action

    @property
    def q_net(self):
        return self._q_net
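

# Usage sketch for DQNRefactor (illustrative only): the placeholder shapes,
# the `build_q_network` helper and `batch_of_observations` below are
# assumptions, not part of this module.
#
#     observation_ph = tf.placeholder(tf.float32, shape=(None, 84, 84, 4))
#     action_ph = tf.placeholder(tf.int32, shape=(None,))
#     q_values = build_q_network(observation_ph)  # hypothetical Q-value head
#     policy = DQNRefactor(q_values, observation_ph, action_ph)
#     with tf.Session() as sess:  # installs itself as the default session
#         sess.run(tf.global_variables_initializer())
#         greedy_action = policy.act(batch_of_observations)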


class DQNOld(QValuePolicy):
    """
    The policy as in DQN.
    """

    def __init__(self, logits, observation_placeholder, dtype=None, **kwargs):
        # TODO: this version only supports discrete action spaces; extend it to continuous ones
        self._logits = tf.convert_to_tensor(logits)
        if dtype is None:
            dtype = tf.int32
        self._n_categories = self._logits.get_shape()[-1].value

        super(DQNOld, self).__init__(observation_placeholder)

        # TODO: put the net definition outside of the class
        net = tf.layers.conv2d(self._observation_placeholder, 16, 8, 4, 'valid', activation=tf.nn.relu)
        net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
        net = tf.layers.flatten(net)
        net = tf.layers.dense(net, 256, activation=tf.nn.relu, use_bias=True)
        self._value = tf.layers.dense(net, self._n_categories)

    def _act(self, observation, exploration=None):  # first implement no exploration
        """
        Return the action (int) to be executed.

        No exploration when exploration=None.
        """
        # TODO: ensure thread safety
        sess = tf.get_default_session()
        sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1),
                                  feed_dict={self._observation_placeholder: observation[None]})
        return sampled_action

    @property
    def logits(self):
        """
        :return: action values
        """
        return self._logits

    @property
    def n_categories(self):
        """
        :return: dimension of the action space if not continuous
        """
        return self._n_categories

    def values(self, observation):
        """
        Return the Q(s, a) values (float) for all actions a at observation s.
        """
        sess = tf.get_default_session()
        value = sess.run(self._value, feed_dict={self._observation_placeholder: observation[None]})
        return value

    def values_tensor(self):
        """
        Return the tensor of the values for all actions a at observation s.
        """
        return self._value
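

# Usage sketch for DQNOld (illustrative only): `q_logits` and
# `single_observation` are assumptions, not part of this module. Note that
# _act samples an action from tf.multinomial over the given logits, while
# values/values_tensor expose the Q(s, a) head built in __init__.
#
#     observation_ph = tf.placeholder(tf.float32, shape=(None, 84, 84, 4))
#     q_logits = build_q_network(observation_ph)   # hypothetical logits tensor
#     policy = DQNOld(q_logits, observation_ph)
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         q_values = policy.values(single_observation)   # Q(s, a) for all a
#         action = policy._act(single_observation)       # sampled action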