From 6f206759ab689313808071ed90fb57dc56827b5d Mon Sep 17 00:00:00 2001
From: haoshengzou
Date: Sun, 20 May 2018 22:36:04 +0800
Subject: [PATCH] add __all__

---
 tianshou/core/losses.py                         | 6 ++++++
 tianshou/core/opt.py                            | 4 ++++
 tianshou/core/policy/base.py                    | 2 +-
 tianshou/core/policy/deterministic.py           | 5 ++++-
 tianshou/core/policy/distributional.py          | 5 ++++-
 tianshou/core/policy/dqn.py                     | 4 ++++
 tianshou/core/random.py                         | 5 +++++
 tianshou/core/utils.py                          | 4 ++++
 tianshou/core/value_function/action_value.py    | 6 +++++-
 tianshou/core/value_function/base.py            | 3 +++
 tianshou/core/value_function/state_value.py     | 5 ++++-
 tianshou/data/advantage_estimation.py           | 9 ++++++++-
 tianshou/data/data_buffer/base.py               | 1 +
 tianshou/data/data_buffer/batch_set.py          | 5 +++++
 tianshou/data/data_buffer/replay_buffer_base.py | 3 +++
 tianshou/data/data_buffer/vanilla.py            | 5 +++++
 tianshou/data/data_collector.py                 | 4 ++++
 tianshou/data/tester.py                         | 4 ++++
 18 files changed, 74 insertions(+), 6 deletions(-)

diff --git a/tianshou/core/losses.py b/tianshou/core/losses.py
index 1c329de..ecef8f0 100644
--- a/tianshou/core/losses.py
+++ b/tianshou/core/losses.py
@@ -1,5 +1,11 @@
 import tensorflow as tf
 
+__all__ = [
+    'ppo_clip',
+    'REINFORCE',
+    'value_mse'
+]
+
 
 def ppo_clip(policy, clip_param):
     """
diff --git a/tianshou/core/opt.py b/tianshou/core/opt.py
index 96f263a..1df939c 100644
--- a/tianshou/core/opt.py
+++ b/tianshou/core/opt.py
@@ -1,5 +1,9 @@
 import tensorflow as tf
 
+__all__ = [
+    'DPG',
+]
+
 
 def DPG(policy, action_value):
     """
diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py
index 42e8f30..d4abe70 100644
--- a/tianshou/core/policy/base.py
+++ b/tianshou/core/policy/base.py
@@ -1,7 +1,7 @@
 from __future__ import absolute_import
 from __future__ import division
 
-import tensorflow as tf
+__all__ = []
 
 
 class PolicyBase(object):
diff --git a/tianshou/core/policy/deterministic.py b/tianshou/core/policy/deterministic.py
index a1cd4da..b9a76fc 100644
--- a/tianshou/core/policy/deterministic.py
+++ b/tianshou/core/policy/deterministic.py
@@ -1,10 +1,13 @@
 import tensorflow as tf
-import logging
 
 from .base import PolicyBase
 from ..random import OrnsteinUhlenbeckProcess
 from ..utils import identify_dependent_variables
 
+__all__ = [
+    'Deterministic',
+]
+
 
 class Deterministic(PolicyBase):
     """
diff --git a/tianshou/core/policy/distributional.py b/tianshou/core/policy/distributional.py
index dc2dcc5..e04a8ed 100644
--- a/tianshou/core/policy/distributional.py
+++ b/tianshou/core/policy/distributional.py
@@ -1,8 +1,11 @@
 import tensorflow as tf
-import logging
 
 from .base import PolicyBase
 from ..utils import identify_dependent_variables
 
+__all__ = [
+    'Distributional',
+]
+
 class Distributional(PolicyBase):
     """
diff --git a/tianshou/core/policy/dqn.py b/tianshou/core/policy/dqn.py
index 70f4a56..80bf59d 100644
--- a/tianshou/core/policy/dqn.py
+++ b/tianshou/core/policy/dqn.py
@@ -4,6 +4,10 @@ from .base import PolicyBase
 import tensorflow as tf
 import numpy as np
 
+__all__ = [
+    'DQN',
+]
+
 
 class DQN(PolicyBase):
     """
diff --git a/tianshou/core/random.py b/tianshou/core/random.py
index 5807670..9d4e40b 100644
--- a/tianshou/core/random.py
+++ b/tianshou/core/random.py
@@ -5,6 +5,11 @@ adapted from keras-rl
 from __future__ import division
 import numpy as np
 
+__all__ = [
+    'GaussianWhiteNoiseProcess',
+    'OrnsteinUhlenbeckProcess',
+]
+
 
 class RandomProcess(object):
     """
diff --git a/tianshou/core/utils.py b/tianshou/core/utils.py
index 749e72f..2562408 100644
--- a/tianshou/core/utils.py
+++ b/tianshou/core/utils.py
@@ -1,5 +1,9 @@
 import tensorflow as tf
 
+__all__ = [
+    'get_soft_update_op',
+]
+
 
 def identify_dependent_variables(tensor, candidate_variables):
     """
diff --git a/tianshou/core/value_function/action_value.py b/tianshou/core/value_function/action_value.py
index 56ae1d3..2377ab7 100644
--- a/tianshou/core/value_function/action_value.py
+++ b/tianshou/core/value_function/action_value.py
@@ -1,10 +1,14 @@
 from __future__ import absolute_import
 
-import logging
 import tensorflow as tf
 
 from .base import ValueFunctionBase
 from ..utils import identify_dependent_variables
 
+__all__ = [
+    'ActionValue',
+    'DQN',
+]
+
 class ActionValue(ValueFunctionBase):
     """
diff --git a/tianshou/core/value_function/base.py b/tianshou/core/value_function/base.py
index 7c4ce88..0d4dadf 100644
--- a/tianshou/core/value_function/base.py
+++ b/tianshou/core/value_function/base.py
@@ -2,6 +2,9 @@ from __future__ import absolute_import
 
 import tensorflow as tf
 
+__all__ = []
+
+
 class ValueFunctionBase(object):
     """
     Base class for value functions, including S-values and Q-values. The only
diff --git a/tianshou/core/value_function/state_value.py b/tianshou/core/value_function/state_value.py
index 7d43a28..3258b21 100644
--- a/tianshou/core/value_function/state_value.py
+++ b/tianshou/core/value_function/state_value.py
@@ -1,11 +1,14 @@
 from __future__ import absolute_import
 
 import tensorflow as tf
-import logging
 
 from .base import ValueFunctionBase
 from ..utils import identify_dependent_variables
 
+__all__ = [
+    'StateValue',
+]
+
 
 class StateValue(ValueFunctionBase):
     """
diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py
index 5818a90..3b6d930 100644
--- a/tianshou/data/advantage_estimation.py
+++ b/tianshou/data/advantage_estimation.py
@@ -1,5 +1,12 @@
 import logging
-import numpy as np
+
+__all__ = [
+    'full_return',
+    'nstep_return',
+    'nstep_q_return',
+    'ddpg_return',
+]
+
 
 STATE = 0
 ACTION = 1
diff --git a/tianshou/data/data_buffer/base.py b/tianshou/data/data_buffer/base.py
index 49c3ddc..c60dff2 100644
--- a/tianshou/data/data_buffer/base.py
+++ b/tianshou/data/data_buffer/base.py
@@ -1,3 +1,4 @@
+__all__ = []
 
 
 class DataBufferBase(object):
diff --git a/tianshou/data/data_buffer/batch_set.py b/tianshou/data/data_buffer/batch_set.py
index b4d8044..572f4d1 100644
--- a/tianshou/data/data_buffer/batch_set.py
+++ b/tianshou/data/data_buffer/batch_set.py
@@ -4,6 +4,11 @@ import logging
 from .base import DataBufferBase
 
 
+__all__ = [
+    'BatchSet'
+]
+
+
 STATE = 0
 ACTION = 1
 REWARD = 2
diff --git a/tianshou/data/data_buffer/replay_buffer_base.py b/tianshou/data/data_buffer/replay_buffer_base.py
index c0539be..8bbae4c 100644
--- a/tianshou/data/data_buffer/replay_buffer_base.py
+++ b/tianshou/data/data_buffer/replay_buffer_base.py
@@ -1,5 +1,8 @@
 from .base import DataBufferBase
 
+__all__ = []
+
+
 class ReplayBufferBase(DataBufferBase):
     """
     Base class for replay buffer.
diff --git a/tianshou/data/data_buffer/vanilla.py b/tianshou/data/data_buffer/vanilla.py
index 005327a..cb83524 100644
--- a/tianshou/data/data_buffer/vanilla.py
+++ b/tianshou/data/data_buffer/vanilla.py
@@ -3,6 +3,11 @@ import numpy as np
 from .replay_buffer_base import ReplayBufferBase
 
 
+__all__ = [
+    'VanillaReplayBuffer',
+]
+
+
 STATE = 0
 ACTION = 1
 REWARD = 2
diff --git a/tianshou/data/data_collector.py b/tianshou/data/data_collector.py
index c70b3e5..67d346c 100644
--- a/tianshou/data/data_collector.py
+++ b/tianshou/data/data_collector.py
@@ -7,6 +7,10 @@ from .data_buffer.batch_set import BatchSet
 from .utils import internal_key_match
 from ..core.policy.deterministic import Deterministic
 
+__all__ = [
+    'DataCollector',
+]
+
 
 class DataCollector(object):
     """
diff --git a/tianshou/data/tester.py b/tianshou/data/tester.py
index 46fea1b..0eb1e0b 100644
--- a/tianshou/data/tester.py
+++ b/tianshou/data/tester.py
@@ -4,6 +4,10 @@ import gym
 import logging
 import numpy as np
 
+__all__ = [
+    'test_policy_in_env',
+]
+
 
 def test_policy_in_env(policy, env, num_timesteps=0, num_episodes=0, discount_factor=0.99, seed=0, episode_cutoff=None):
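
Note on the effect of the patch: __all__ defines a module's public names, so a star-import such as "from tianshou.core.losses import *" now brings in only ppo_clip, REINFORCE, and value_mse, and module-level imports like tensorflow or numpy no longer leak into the caller's namespace; the empty lists in the base modules declare that those modules export nothing that way. Below is a minimal illustrative sketch of the mechanism on a hypothetical standalone module (the file name and helper are made up for the example, not part of tianshou):

# module_with_all.py -- illustrative stand-in, not a file in tianshou
import math              # module-level import; without __all__ a star-import would re-export this name

__all__ = [
    'public_fn',         # only names listed here are exported by "from module_with_all import *"
]


def public_fn(x):
    """Exported: listed in __all__."""
    return math.sqrt(x)


def _internal_helper(x):
    """Not exported: underscore-prefixed names are skipped by star-imports anyway."""
    return x * x


# usage, from another file:
#   from module_with_all import *
#   public_fn(4.0)       # -> 2.0
#   math                 # NameError: __all__ keeps the math import out of this namespace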