diff --git a/examples/actor_critic.py b/examples/actor_critic.py
index 97825b9..2e3a8a5 100755
--- a/examples/actor_critic.py
+++ b/examples/actor_critic.py
@@ -7,16 +7,6 @@ import gym
 
 import tianshou as ts
 
-import sys
-sys.path.append('..')
-from tianshou.core import losses
-import tianshou.data.advantage_estimation as advantage_estimation
-import tianshou.core.policy.distributional as policy
-import tianshou.core.value_function.state_value as value_function
-
-from tianshou.data.data_buffer.batch_set import BatchSet
-from tianshou.data.data_collector import DataCollector
-
 
 if __name__ == '__main__':
     env = gym.make('CartPole-v0')
diff --git a/examples/ddpg.py b/examples/ddpg.py
index 9579129..17a67a7 100644
--- a/examples/ddpg.py
+++ b/examples/ddpg.py
@@ -9,20 +9,6 @@ import argparse
 
 import tianshou as ts
 
-# our lib imports here! It's ok to append path in examples
-import sys
-sys.path.append('..')
-from tianshou.core import losses
-import tianshou.data.advantage_estimation as advantage_estimation
-import tianshou.core.policy as policy
-import tianshou.core.value_function.action_value as value_function
-import tianshou.core.opt as opt
-
-from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
-from tianshou.data.data_collector import DataCollector
-from tianshou.data.tester import test_policy_in_env
-from tianshou.core.utils import get_soft_update_op
-
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
diff --git a/tianshou/core/policy/deterministic.py b/tianshou/core/policy/deterministic.py
index 5ea3683..a1cd4da 100644
--- a/tianshou/core/policy/deterministic.py
+++ b/tianshou/core/policy/deterministic.py
@@ -146,7 +146,8 @@ class Deterministic(PolicyBase):
         """
         sess = tf.get_default_session()
 
-        feed_dict = {self.observation_placeholder: observation}.update(my_feed_dict)
+        feed_dict = {self.observation_placeholder: observation}
+        feed_dict.update(my_feed_dict)
 
         action = sess.run(self.action, feed_dict=feed_dict)
         return action
@@ -164,7 +165,8 @@ class Deterministic(PolicyBase):
         """
         sess = tf.get_default_session()
 
-        feed_dict = {self.observation_placeholder: observation}.update(my_feed_dict)
+        feed_dict = {self.observation_placeholder: observation}
+        feed_dict.update(my_feed_dict)
 
         action = sess.run(self.action_old, feed_dict=feed_dict)
         return action
\ No newline at end of file
diff --git a/tianshou/core/value_function/action_value.py b/tianshou/core/value_function/action_value.py
index 599b1da..56ae1d3 100644
--- a/tianshou/core/value_function/action_value.py
+++ b/tianshou/core/value_function/action_value.py
@@ -80,8 +80,9 @@ class ActionValue(ValueFunctionBase):
         :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
         """
         sess = tf.get_default_session()
-        return sess.run(self.value_tensor, feed_dict=
-            {self.observation_placeholder: observation, self.action_placeholder: action}.update(my_feed_dict))
+        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self.value_tensor, feed_dict=feed_dict)
 
     def eval_value_old(self, observation, action, my_feed_dict={}):
         """
@@ -95,7 +96,8 @@ class ActionValue(ValueFunctionBase):
         :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
         """
         sess = tf.get_default_session()
-        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}.update(my_feed_dict)
+        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
+        feed_dict.update(my_feed_dict)
         return sess.run(self.value_tensor_old, feed_dict=feed_dict)
 
     def sync_weights(self):
@@ -198,7 +200,8 @@ class DQN(ValueFunctionBase):
         :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
         """
         sess = tf.get_default_session()
-        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}.update(my_feed_dict)
+        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
+        feed_dict.update(my_feed_dict)
         return sess.run(self.value_tensor, feed_dict=feed_dict)
 
     def eval_value_old(self, observation, action, my_feed_dict={}):
@@ -213,7 +216,8 @@ class DQN(ValueFunctionBase):
         :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
         """
         sess = tf.get_default_session()
-        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}.update(my_feed_dict)
+        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
+        feed_dict.update(my_feed_dict)
         return sess.run(self.value_tensor_old, feed_dict=feed_dict)
 
     @property
@@ -232,7 +236,9 @@ class DQN(ValueFunctionBase):
         :return: A numpy array of shape (batch_size, num_actions). The corresponding action values for each observation.
         """
         sess = tf.get_default_session()
-        return sess.run(self._value_tensor_all_actions, feed_dict={self.observation_placeholder: observation}.update(my_feed_dict))
+        feed_dict = {self.observation_placeholder: observation}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self._value_tensor_all_actions, feed_dict=feed_dict)
 
     def eval_value_all_actions_old(self, observation, my_feed_dict={}):
         """
@@ -245,7 +251,9 @@ class DQN(ValueFunctionBase):
         :return: A numpy array of shape (batch_size, num_actions). The corresponding action values for each observation.
         """
         sess = tf.get_default_session()
-        return sess.run(self.value_tensor_all_actions_old, feed_dict={self.observation_placeholder: observation}.update(my_feed_dict))
+        feed_dict = {self.observation_placeholder: observation}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self.value_tensor_all_actions_old, feed_dict=feed_dict)
 
     def sync_weights(self):
         """
diff --git a/tianshou/core/value_function/state_value.py b/tianshou/core/value_function/state_value.py
index 76da7b2..7d43a28 100644
--- a/tianshou/core/value_function/state_value.py
+++ b/tianshou/core/value_function/state_value.py
@@ -79,7 +79,9 @@ class StateValue(ValueFunctionBase):
         :return: A numpy array of shape (batch_size,). The corresponding state value for each observation.
         """
         sess = tf.get_default_session()
-        return sess.run(self.value_tensor, feed_dict={self.observation_placeholder: observation}.update(my_feed_dict))
+        feed_dict = {self.observation_placeholder: observation}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self.value_tensor, feed_dict=feed_dict)
 
     def eval_value_old(self, observation, my_feed_dict={}):
         """
@@ -92,7 +94,9 @@ class StateValue(ValueFunctionBase):
         :return: A numpy array of shape (batch_size,). The corresponding state value for each observation.
         """
         sess = tf.get_default_session()
-        return sess.run(self.value_tensor_old, feed_dict={self.observation_placeholder: observation}.update(my_feed_dict))
+        feed_dict = {self.observation_placeholder: observation}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self.value_tensor_old, feed_dict=feed_dict)
 
     def sync_weights(self):
         """