add __all__

This commit is contained in:
haoshengzou 2018-05-20 22:36:04 +08:00
parent eb8c82636e
commit 6f206759ab
18 changed files with 74 additions and 6 deletions

View File

@ -1,5 +1,11 @@
import tensorflow as tf
__all__ = [
'ppo_clip',
'REINFORCE',
'value_mse'
]
def ppo_clip(policy, clip_param):
    """

View File

@ -1,5 +1,9 @@
import tensorflow as tf
__all__ = [
'DPG',
]
def DPG(policy, action_value):
    """

View File

@ -1,7 +1,7 @@
from __future__ import absolute_import
from __future__ import division
import tensorflow as tf __all__ = []
class PolicyBase(object):

View File

@ -1,10 +1,13 @@
import tensorflow as tf
import logging
from .base import PolicyBase
from ..random import OrnsteinUhlenbeckProcess
from ..utils import identify_dependent_variables
__all__ = [
'Deterministic',
]
class Deterministic(PolicyBase):
    """

View File

@ -1,8 +1,11 @@
import tensorflow as tf
import logging
from .base import PolicyBase
from ..utils import identify_dependent_variables
__all__ = [
'Distributional',
]
class Distributional(PolicyBase):
    """

View File

@ -4,6 +4,10 @@ from .base import PolicyBase
import tensorflow as tf
import numpy as np
__all__ = [
'DQN',
]
class DQN(PolicyBase):
    """

View File

@ -5,6 +5,11 @@ adapted from keras-rl
from __future__ import division
import numpy as np
__all__ = [
'GaussianWhiteNoiseProcess',
'OrnsteinUhlenbeckProcess',
]
class RandomProcess(object):
    """

View File

@ -1,5 +1,9 @@
import tensorflow as tf
__all__ = [
'get_soft_update_op',
]
def identify_dependent_variables(tensor, candidate_variables):
    """

View File

@ -1,10 +1,14 @@
from __future__ import absolute_import
import logging
import tensorflow as tf
from .base import ValueFunctionBase
from ..utils import identify_dependent_variables
__all__ = [
'ActionValue',
'DQN',
]
class ActionValue(ValueFunctionBase):
    """

View File

@ -2,6 +2,9 @@ from __future__ import absolute_import
import tensorflow as tf
__all__ = []
class ValueFunctionBase(object):
    """
    Base class for value functions, including S-values and Q-values. The only

View File

@ -1,11 +1,14 @@
from __future__ import absolute_import
import tensorflow as tf
import logging
from .base import ValueFunctionBase
from ..utils import identify_dependent_variables
__all__ = [
'StateValue',
]
class StateValue(ValueFunctionBase):
    """

View File

@ -1,5 +1,12 @@
import logging
import numpy as np
__all__ = [
'full_return',
'nstep_return',
'nstep_q_return',
'ddpg_return',
]
STATE = 0
ACTION = 1

View File

@ -1,3 +1,4 @@
__all__ = []
class DataBufferBase(object):

View File

@ -4,6 +4,11 @@ import logging
from .base import DataBufferBase
__all__ = [
'BatchSet'
]
STATE = 0
ACTION = 1
REWARD = 2

View File

@ -1,5 +1,8 @@
from .base import DataBufferBase
__all__ = []
class ReplayBufferBase(DataBufferBase):
    """
    Base class for replay buffer.

View File

@ -3,6 +3,11 @@ import numpy as np
from .replay_buffer_base import ReplayBufferBase
__all__ = [
'VanillaReplayBuffer',
]
STATE = 0
ACTION = 1
REWARD = 2

View File

@ -7,6 +7,10 @@ from .data_buffer.batch_set import BatchSet
from .utils import internal_key_match from .utils import internal_key_match
from ..core.policy.deterministic import Deterministic from ..core.policy.deterministic import Deterministic
__all__ = [
'DataCollector',
]
class DataCollector(object):
    """

View File

@ -4,6 +4,10 @@ import gym
import logging
import numpy as np
__all__ = [
'test_policy_in_env',
]
def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0,
                       discount_factor=0.99, seed=0, episode_cutoff=None):