from __future__ import absolute_import

import tensorflow as tf

# QValuePolicy is assumed to live in .base alongside PolicyBase (DQNOld below subclasses it)
from .base import PolicyBase, QValuePolicy
from ..value_function.action_value import DQN


class DQNRefactor(PolicyBase):
    """
    Policy that uses the DQN from ``value_function`` as a member and acts
    greedily with respect to its Q-values.
    """
    def __init__(self, value_tensor, observation_placeholder, action_placeholder):
        self._q_net = DQN(value_tensor, observation_placeholder, action_placeholder)
        self._argmax_action = tf.argmax(value_tensor, axis=1)

        super(DQNRefactor, self).__init__(observation_placeholder=observation_placeholder)

    def act(self, observation, exploration=None):
        sess = tf.get_default_session()
        if not exploration:  # no exploration: act greedily w.r.t. the Q-values
            action = sess.run(self._argmax_action,
                              feed_dict={self._observation_placeholder: observation})
            return action
        # TODO: exploration (e.g. epsilon-greedy) is not implemented yet

    @property
    def q_net(self):
        return self._q_net
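
# A minimal usage sketch for DQNRefactor, kept as a comment so that importing this
# module stays side-effect free. The 4-dimensional observation, the 3 discrete
# actions and the tiny fully-connected Q-network below are illustrative
# assumptions, not part of the actual API:
#
#     import numpy as np
#     observation_ph = tf.placeholder(tf.float32, shape=(None, 4))
#     action_ph = tf.placeholder(tf.int32, shape=(None,))
#     hidden = tf.layers.dense(observation_ph, 64, activation=tf.nn.relu)
#     q_values = tf.layers.dense(hidden, 3)  # one Q(s, a) per discrete action
#     policy = DQNRefactor(q_values, observation_ph, action_ph)
#     with tf.Session() as sess:
#         sess.run(tf.global_variables_initializer())
#         greedy_action = policy.act(np.ones((1, 4)))  # batched observation in, greedy action out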


class DQNOld(QValuePolicy):
    """
    The policy as in DQN.
    """
    def __init__(self, logits, observation_placeholder, dtype=None, **kwargs):
        # TODO: this version only supports discrete action spaces; extend it to continuous ones
        self._logits = tf.convert_to_tensor(logits)
        if dtype is None:
            dtype = tf.int32
        self._n_categories = self._logits.get_shape()[-1].value

        super(DQNOld, self).__init__(observation_placeholder)

        # TODO: put the net definition outside of the class
        net = tf.layers.conv2d(self._observation_placeholder, 16, 8, 4, 'valid', activation=tf.nn.relu)
        net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
        net = tf.layers.flatten(net)
        net = tf.layers.dense(net, 256, activation=tf.nn.relu, use_bias=True)
        self._value = tf.layers.dense(net, self._n_categories)
    def _act(self, observation, exploration=None):  # first implement no exploration
        """
        Return the action (int) to be executed; no exploration when exploration=None.
        """
        # TODO: ensure thread safety; build the tf.multinomial op once in __init__ instead of per call
        sess = tf.get_default_session()
        sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1),
                                  feed_dict={self._observation_placeholder: observation[None]})
        # tf.multinomial returns a [batch_size, num_samples] array, so unwrap it to a single int
        return sampled_action[0, 0]
    @property
    def logits(self):
        """
        :return: action values
        """
        return self._logits

    @property
    def n_categories(self):
        """
        :return: dimension of action space if not continuous
        """
        return self._n_categories
    def values(self, observation):
        """
        Returns the Q(s, a) values (float) of all actions a at observation s.
        """
        sess = tf.get_default_session()
        value = sess.run(self._value, feed_dict={self._observation_placeholder: observation[None]})
        return value

    def values_tensor(self):
        """
        Returns the tensor of the values of all actions a at observation s.
        """
        return self._value
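

def _demo_dqn_old():
    """
    A minimal, hedged construction sketch for DQNOld, not part of the original
    module. The 84x84x4 Atari-style observation shape and the 6-action stand-in
    logits are illustrative assumptions; call this manually from a package
    context to see the (untrained) Q-values and a sampled action.
    """
    import numpy as np

    obs_ph = tf.placeholder(tf.float32, shape=(None, 84, 84, 4))
    # stand-in logits over 6 discrete actions; DQNOld builds its own value network internally
    logits = tf.layers.dense(tf.layers.flatten(obs_ph), 6)
    policy = DQNOld(logits, obs_ph)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        frame = np.zeros((84, 84, 4), dtype=np.float32)
        print(policy.values(frame))  # Q(s, a) for every action at this observation
        print(policy._act(frame))    # index of the sampled action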