2018-04-12 21:10:50 +08:00

56 lines
2.4 KiB
Python

from __future__ import absolute_import
from __future__ import division
import tensorflow as tf
class PolicyBase(object):
    r"""
    Base class for policies. Mandatory methods for a policy class are:

    - :func:`act`. It is used when interacting with the environment during
      training, so exploration noise should be added in this method.
    - :func:`act_test`. Since RL usually adds extra exploration noise during
      training, a separate method for testing the policy should be defined,
      with a different exploration specification. For example, DQN uses a
      different :math:`\epsilon` in :math:`\epsilon`-greedy, and DDPG removes
      exploration noise during test.
    - :func:`reset`. It is mainly used to reset the states of the exploration
      random process. It is also used when the policy has internal states that
      should be reset at the beginning of each new episode. Otherwise, this
      method does nothing.

    The docstring is a raw string so that the LaTeX ``\epsilon`` is not parsed
    as an (invalid) Python escape sequence.
    """

    def act(self, observation, my_feed_dict):
        """
        Return action given observation, when interacting with the environment during training.

        :param observation: An array-like with rank the same as a single observation of the environment.
            Its "batch_size" is 1, but should not be explicitly set. This method will add the dimension
            of "batch_size" to the first dimension.
        :param my_feed_dict: A dict. Specifies placeholders such as dropout and batch_norm except observation.
        :return: A numpy array. Action given the single observation. Its "batch_size" is 1,
            but should not be explicitly set.
        :raises NotImplementedError: Always, in the base class; subclasses must override.
        """
        raise NotImplementedError()

    def act_test(self, observation, my_feed_dict):
        """
        Return action given observation, when interacting with the environment during test.

        :param observation: An array-like with rank the same as a single observation of the environment.
            Its "batch_size" is 1, but should not be explicitly set. This method will add the dimension
            of "batch_size" to the first dimension.
        :param my_feed_dict: A dict. Specifies placeholders such as dropout and batch_norm except observation.
        :return: A numpy array. Action given the single observation. Its "batch_size" is 1,
            but should not be explicitly set.
        :raises NotImplementedError: Always, in the base class; subclasses must override.
        """
        # Raise an instance, consistent with ``act``.
        raise NotImplementedError()

    def reset(self):
        """
        Reset the internal states of the policy. Does nothing by default.
        """