add __all__
This commit is contained in:
parent
eb8c82636e
commit
6f206759ab
@ -1,5 +1,11 @@
|
||||
import tensorflow as tf
|
||||
|
||||
__all__ = [
|
||||
'ppo_clip',
|
||||
'REINFORCE',
|
||||
'value_mse'
|
||||
]
|
||||
|
||||
|
||||
def ppo_clip(policy, clip_param):
|
||||
"""
|
||||
|
||||
@ -1,5 +1,9 @@
|
||||
import tensorflow as tf
|
||||
|
||||
__all__ = [
|
||||
'DPG',
|
||||
]
|
||||
|
||||
|
||||
def DPG(policy, action_value):
|
||||
"""
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
import tensorflow as tf
|
||||
__all__ = []
|
||||
|
||||
|
||||
class PolicyBase(object):
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
import tensorflow as tf
|
||||
import logging
|
||||
|
||||
from .base import PolicyBase
|
||||
from ..random import OrnsteinUhlenbeckProcess
|
||||
from ..utils import identify_dependent_variables
|
||||
|
||||
__all__ = [
|
||||
'Deterministic',
|
||||
]
|
||||
|
||||
|
||||
class Deterministic(PolicyBase):
|
||||
"""
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
import tensorflow as tf
|
||||
import logging
|
||||
from .base import PolicyBase
|
||||
from ..utils import identify_dependent_variables
|
||||
|
||||
__all__ = [
|
||||
'Distributional',
|
||||
]
|
||||
|
||||
|
||||
class Distributional(PolicyBase):
|
||||
"""
|
||||
|
||||
@ -4,6 +4,10 @@ from .base import PolicyBase
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
__all__ = [
|
||||
'DQN',
|
||||
]
|
||||
|
||||
|
||||
class DQN(PolicyBase):
|
||||
"""
|
||||
|
||||
@ -5,6 +5,11 @@ adapted from keras-rl
|
||||
from __future__ import division
|
||||
import numpy as np
|
||||
|
||||
__all__ = [
|
||||
'GaussianWhiteNoiseProcess',
|
||||
'OrnsteinUhlenbeckProcess',
|
||||
]
|
||||
|
||||
|
||||
class RandomProcess(object):
|
||||
"""
|
||||
|
||||
@ -1,5 +1,9 @@
|
||||
import tensorflow as tf
|
||||
|
||||
__all__ = [
|
||||
'get_soft_update_op',
|
||||
]
|
||||
|
||||
|
||||
def identify_dependent_variables(tensor, candidate_variables):
|
||||
"""
|
||||
|
||||
@ -1,10 +1,14 @@
|
||||
from __future__ import absolute_import
|
||||
import logging
|
||||
import tensorflow as tf
|
||||
|
||||
from .base import ValueFunctionBase
|
||||
from ..utils import identify_dependent_variables
|
||||
|
||||
__all__ = [
|
||||
'ActionValue',
|
||||
'DQN',
|
||||
]
|
||||
|
||||
|
||||
class ActionValue(ValueFunctionBase):
|
||||
"""
|
||||
|
||||
@ -2,6 +2,9 @@ from __future__ import absolute_import
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
__all__ = []
|
||||
|
||||
|
||||
class ValueFunctionBase(object):
|
||||
"""
|
||||
Base class for value functions, including S-values and Q-values. The only
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import tensorflow as tf
|
||||
import logging
|
||||
|
||||
from .base import ValueFunctionBase
|
||||
from ..utils import identify_dependent_variables
|
||||
|
||||
__all__ = [
|
||||
'StateValue',
|
||||
]
|
||||
|
||||
|
||||
class StateValue(ValueFunctionBase):
|
||||
"""
|
||||
|
||||
@ -1,5 +1,12 @@
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
__all__ = [
|
||||
'full_return',
|
||||
'nstep_return',
|
||||
'nstep_q_return',
|
||||
'ddpg_return',
|
||||
]
|
||||
|
||||
|
||||
STATE = 0
|
||||
ACTION = 1
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
__all__ = []
|
||||
|
||||
|
||||
class DataBufferBase(object):
|
||||
|
||||
@ -4,6 +4,11 @@ import logging
|
||||
|
||||
from .base import DataBufferBase
|
||||
|
||||
__all__ = [
|
||||
'BatchSet'
|
||||
]
|
||||
|
||||
|
||||
STATE = 0
|
||||
ACTION = 1
|
||||
REWARD = 2
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
from .base import DataBufferBase
|
||||
|
||||
__all__ = []
|
||||
|
||||
|
||||
class ReplayBufferBase(DataBufferBase):
|
||||
"""
|
||||
Base class for replay buffer.
|
||||
|
||||
@ -3,6 +3,11 @@ import numpy as np
|
||||
|
||||
from .replay_buffer_base import ReplayBufferBase
|
||||
|
||||
__all__ = [
|
||||
'VanillaReplayBuffer',
|
||||
]
|
||||
|
||||
|
||||
STATE = 0
|
||||
ACTION = 1
|
||||
REWARD = 2
|
||||
|
||||
@ -7,6 +7,10 @@ from .data_buffer.batch_set import BatchSet
|
||||
from .utils import internal_key_match
|
||||
from ..core.policy.deterministic import Deterministic
|
||||
|
||||
__all__ = [
|
||||
'DataCollector',
|
||||
]
|
||||
|
||||
|
||||
class DataCollector(object):
|
||||
"""
|
||||
|
||||
@ -4,6 +4,10 @@ import gym
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
__all__ = [
|
||||
'test_policy_in_env',
|
||||
]
|
||||
|
||||
|
||||
def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0,
|
||||
discount_factor=0.99, seed=0, episode_cutoff=None):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user