2017-12-23 15:36:10 +08:00
|
|
|
from __future__ import absolute_import
|
|
|
|
|
|
|
|
import tensorflow as tf
|
2017-12-22 00:22:23 +08:00
|
|
|
|
|
|
|
class ValueFunctionBase(object):
    """
    Common base for value-function wrappers, covering both state values
    V(s) and action values Q(s, a).

    A concrete subclass only has to provide :func:`eval_value`, which runs
    the graph and evaluates the corresponding value. This base class itself
    just stores the observation placeholder and a squeezed copy of the value
    tensor.

    :param value_tensor: a Tensor. The tensor of V(s) or Q(s, a).
    :param observation_placeholder: a :class:`tf.placeholder`. The observation
        placeholder of the network graph.
    """

    def __init__(self, value_tensor, observation_placeholder):
        self.observation_placeholder = observation_placeholder

        # The canonical value is meant to have shape (batchsize, ), so the
        # size-1 dimensions of the incoming tensor are squeezed away.
        # NOTE(review): tf.squeeze without `axis` drops *every* size-1
        # dimension, so a batch of size 1 would presumably lose its batch
        # axis as well -- confirm against the callers.
        canonical_value = tf.squeeze(value_tensor)
        self._value_tensor = canonical_value

    def eval_value(self, **kwargs):
        """
        Runs the graph and evaluates the corresponding value.

        The base class provides no implementation; calling it here always
        raises :class:`NotImplementedError`.
        """
        raise NotImplementedError()

    @property
    def value_tensor(self):
        """The squeezed value tensor stored at construction time."""
        return self._value_tensor
|