2017-12-23 15:36:10 +08:00
|
|
|
from __future__ import absolute_import
|
|
|
|
|
|
|
|
import tensorflow as tf
|
2017-12-22 00:22:23 +08:00
|
|
|
|
|
|
|
# TODO: also implement the linear feature baseline in TensorFlow?
|
|
|
|
class ValueFunctionBase(object):
    """
    Base class of value functions. Children include state values V(s) and action values Q(s, a).

    :param value_tensor: tensor (or tensor-convertible value) holding the values. Size-1
        non-batch dimensions, e.g. the trailing dim of a ``(batchsize, 1)`` output, are removed
        so the stored tensor has the canonical shape ``(batchsize, )``.
    :param observation_placeholder: stored as ``self._observation_placeholder`` for subclasses;
        presumably the placeholder observations are fed through when evaluating values
        (NOTE(review): confirm against subclasses).
    """

    def __init__(self, value_tensor, observation_placeholder):
        self._observation_placeholder = observation_placeholder

        value_tensor = tf.convert_to_tensor(value_tensor)
        static_shape = value_tensor.shape
        if static_shape.ndims is None:
            # Rank unknown at graph-construction time, so no axis can be named safely:
            # keep the previous behaviour of removing every size-1 dimension.
            squeezed = tf.squeeze(value_tensor)
        else:
            # Only squeeze non-batch axes that are statically size 1. A bare tf.squeeze() would
            # also drop axis 0 when the batch size is 1, turning (1, 1) into a scalar () instead
            # of the canonical (1, ).
            squeeze_axes = [axis for axis, dim in enumerate(static_shape.as_list())
                            if axis > 0 and dim == 1]
            # An empty axis list makes tf.squeeze remove *all* size-1 dims, hence the guard.
            if squeeze_axes:
                squeezed = tf.squeeze(value_tensor, axis=squeeze_axes)
            else:
                squeezed = value_tensor
        self._value_tensor = squeezed  # canonical values has shape (batchsize, )

    def eval_value(self, **kwargs):
        """
        Evaluate the value function on a batch of inputs. Subclasses must override this.

        :param kwargs: inputs needed to compute the values; the accepted keys are defined by
            each subclass.
        :return: batch of corresponding values in numpy array
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    @property
    def value_tensor(self):
        """
        The value tensor stored at construction time.

        :return: tensor of the corresponding values
        """
        return self._value_tensor
|