add __all__
This commit is contained in:
parent
eb8c82636e
commit
6f206759ab
@ -1,5 +1,11 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'ppo_clip',
|
||||||
|
'REINFORCE',
|
||||||
|
'value_mse'
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def ppo_clip(policy, clip_param):
|
def ppo_clip(policy, clip_param):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,5 +1,9 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'DPG',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def DPG(policy, action_value):
|
def DPG(policy, action_value):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
|
||||||
import tensorflow as tf
|
__all__ = []
|
||||||
|
|
||||||
|
|
||||||
class PolicyBase(object):
|
class PolicyBase(object):
|
||||||
|
|||||||
@ -1,10 +1,13 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import logging
|
|
||||||
|
|
||||||
from .base import PolicyBase
|
from .base import PolicyBase
|
||||||
from ..random import OrnsteinUhlenbeckProcess
|
from ..random import OrnsteinUhlenbeckProcess
|
||||||
from ..utils import identify_dependent_variables
|
from ..utils import identify_dependent_variables
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'Deterministic',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class Deterministic(PolicyBase):
|
class Deterministic(PolicyBase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import logging
|
|
||||||
from .base import PolicyBase
|
from .base import PolicyBase
|
||||||
from ..utils import identify_dependent_variables
|
from ..utils import identify_dependent_variables
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'Distributional',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class Distributional(PolicyBase):
|
class Distributional(PolicyBase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -4,6 +4,10 @@ from .base import PolicyBase
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'DQN',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class DQN(PolicyBase):
|
class DQN(PolicyBase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -5,6 +5,11 @@ adapted from keras-rl
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'GaussianWhiteNoiseProcess',
|
||||||
|
'OrnsteinUhlenbeckProcess',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class RandomProcess(object):
|
class RandomProcess(object):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,5 +1,9 @@
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'get_soft_update_op',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def identify_dependent_variables(tensor, candidate_variables):
|
def identify_dependent_variables(tensor, candidate_variables):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,10 +1,14 @@
|
|||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
import logging
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from .base import ValueFunctionBase
|
from .base import ValueFunctionBase
|
||||||
from ..utils import identify_dependent_variables
|
from ..utils import identify_dependent_variables
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'ActionValue',
|
||||||
|
'DQN',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class ActionValue(ValueFunctionBase):
|
class ActionValue(ValueFunctionBase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -2,6 +2,9 @@ from __future__ import absolute_import
|
|||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
__all__ = []
|
||||||
|
|
||||||
|
|
||||||
class ValueFunctionBase(object):
|
class ValueFunctionBase(object):
|
||||||
"""
|
"""
|
||||||
Base class for value functions, including S-values and Q-values. The only
|
Base class for value functions, including S-values and Q-values. The only
|
||||||
|
|||||||
@ -1,11 +1,14 @@
|
|||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import logging
|
|
||||||
|
|
||||||
from .base import ValueFunctionBase
|
from .base import ValueFunctionBase
|
||||||
from ..utils import identify_dependent_variables
|
from ..utils import identify_dependent_variables
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'StateValue',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class StateValue(ValueFunctionBase):
|
class StateValue(ValueFunctionBase):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,5 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
import numpy as np
|
|
||||||
|
__all__ = [
|
||||||
|
'full_return',
|
||||||
|
'nstep_return',
|
||||||
|
'nstep_q_return',
|
||||||
|
'ddpg_return',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
STATE = 0
|
STATE = 0
|
||||||
ACTION = 1
|
ACTION = 1
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
__all__ = []
|
||||||
|
|
||||||
|
|
||||||
class DataBufferBase(object):
|
class DataBufferBase(object):
|
||||||
|
|||||||
@ -4,6 +4,11 @@ import logging
|
|||||||
|
|
||||||
from .base import DataBufferBase
|
from .base import DataBufferBase
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'BatchSet'
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
STATE = 0
|
STATE = 0
|
||||||
ACTION = 1
|
ACTION = 1
|
||||||
REWARD = 2
|
REWARD = 2
|
||||||
|
|||||||
@ -1,5 +1,8 @@
|
|||||||
from .base import DataBufferBase
|
from .base import DataBufferBase
|
||||||
|
|
||||||
|
__all__ = []
|
||||||
|
|
||||||
|
|
||||||
class ReplayBufferBase(DataBufferBase):
|
class ReplayBufferBase(DataBufferBase):
|
||||||
"""
|
"""
|
||||||
Base class for replay buffer.
|
Base class for replay buffer.
|
||||||
|
|||||||
@ -3,6 +3,11 @@ import numpy as np
|
|||||||
|
|
||||||
from .replay_buffer_base import ReplayBufferBase
|
from .replay_buffer_base import ReplayBufferBase
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'VanillaReplayBuffer',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
STATE = 0
|
STATE = 0
|
||||||
ACTION = 1
|
ACTION = 1
|
||||||
REWARD = 2
|
REWARD = 2
|
||||||
|
|||||||
@ -7,6 +7,10 @@ from .data_buffer.batch_set import BatchSet
|
|||||||
from .utils import internal_key_match
|
from .utils import internal_key_match
|
||||||
from ..core.policy.deterministic import Deterministic
|
from ..core.policy.deterministic import Deterministic
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'DataCollector',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class DataCollector(object):
|
class DataCollector(object):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -4,6 +4,10 @@ import gym
|
|||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'test_policy_in_env',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0,
|
def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0,
|
||||||
discount_factor=0.99, seed=0, episode_cutoff=None):
|
discount_factor=0.99, seed=0, episode_cutoff=None):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user