Fix the bug of unnamed_dict.update(): dict.update() returns None, so `{...}.update(my_feed_dict)` left feed_dict as None. Clean up imports in examples/*.py.

This commit is contained in:
haoshengzou 2018-04-16 20:17:41 +08:00
parent d84c9d121c
commit 2527030838
5 changed files with 25 additions and 35 deletions

View File

@ -7,16 +7,6 @@ import gym
import tianshou as ts import tianshou as ts
import sys
sys.path.append('..')
from tianshou.core import losses
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy.distributional as policy
import tianshou.core.value_function.state_value as value_function
from tianshou.data.data_buffer.batch_set import BatchSet
from tianshou.data.data_collector import DataCollector
if __name__ == '__main__': if __name__ == '__main__':
env = gym.make('CartPole-v0') env = gym.make('CartPole-v0')

View File

@ -9,20 +9,6 @@ import argparse
import tianshou as ts import tianshou as ts
# our lib imports here! It's ok to append path in examples
import sys
sys.path.append('..')
from tianshou.core import losses
import tianshou.data.advantage_estimation as advantage_estimation
import tianshou.core.policy as policy
import tianshou.core.value_function.action_value as value_function
import tianshou.core.opt as opt
from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
from tianshou.data.data_collector import DataCollector
from tianshou.data.tester import test_policy_in_env
from tianshou.core.utils import get_soft_update_op
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -146,7 +146,8 @@ class Deterministic(PolicyBase):
""" """
sess = tf.get_default_session() sess = tf.get_default_session()
feed_dict = {self.observation_placeholder: observation}.update(my_feed_dict) feed_dict = {self.observation_placeholder: observation}
feed_dict.update(my_feed_dict)
action = sess.run(self.action, feed_dict=feed_dict) action = sess.run(self.action, feed_dict=feed_dict)
return action return action
@ -164,7 +165,8 @@ class Deterministic(PolicyBase):
""" """
sess = tf.get_default_session() sess = tf.get_default_session()
feed_dict = {self.observation_placeholder: observation}.update(my_feed_dict) feed_dict = {self.observation_placeholder: observation}
feed_dict.update(my_feed_dict)
action = sess.run(self.action_old, feed_dict=feed_dict) action = sess.run(self.action_old, feed_dict=feed_dict)
return action return action

View File

@ -80,8 +80,9 @@ class ActionValue(ValueFunctionBase):
:return: A numpy array of shape (batch_size,). The corresponding action value for each observation. :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
""" """
sess = tf.get_default_session() sess = tf.get_default_session()
return sess.run(self.value_tensor, feed_dict= feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
{self.observation_placeholder: observation, self.action_placeholder: action}.update(my_feed_dict)) feed_dict.update(my_feed_dict)
return sess.run(self.value_tensor, feed_dict=feed_dict)
def eval_value_old(self, observation, action, my_feed_dict={}): def eval_value_old(self, observation, action, my_feed_dict={}):
""" """
@ -95,7 +96,8 @@ class ActionValue(ValueFunctionBase):
:return: A numpy array of shape (batch_size,). The corresponding action value for each observation. :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
""" """
sess = tf.get_default_session() sess = tf.get_default_session()
feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}.update(my_feed_dict) feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
feed_dict.update(my_feed_dict)
return sess.run(self.value_tensor_old, feed_dict=feed_dict) return sess.run(self.value_tensor_old, feed_dict=feed_dict)
def sync_weights(self): def sync_weights(self):
@ -198,7 +200,8 @@ class DQN(ValueFunctionBase):
:return: A numpy array of shape (batch_size,). The corresponding action value for each observation. :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
""" """
sess = tf.get_default_session() sess = tf.get_default_session()
feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}.update(my_feed_dict) feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
feed_dict.update(my_feed_dict)
return sess.run(self.value_tensor, feed_dict=feed_dict) return sess.run(self.value_tensor, feed_dict=feed_dict)
def eval_value_old(self, observation, action, my_feed_dict={}): def eval_value_old(self, observation, action, my_feed_dict={}):
@ -213,7 +216,8 @@ class DQN(ValueFunctionBase):
:return: A numpy array of shape (batch_size,). The corresponding action value for each observation. :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
""" """
sess = tf.get_default_session() sess = tf.get_default_session()
feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}.update(my_feed_dict) feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
feed_dict.update(my_feed_dict)
return sess.run(self.value_tensor_old, feed_dict=feed_dict) return sess.run(self.value_tensor_old, feed_dict=feed_dict)
@property @property
@ -232,7 +236,9 @@ class DQN(ValueFunctionBase):
:return: A numpy array of shape (batch_size, num_actions). The corresponding action values for each observation. :return: A numpy array of shape (batch_size, num_actions). The corresponding action values for each observation.
""" """
sess = tf.get_default_session() sess = tf.get_default_session()
return sess.run(self._value_tensor_all_actions, feed_dict={self.observation_placeholder: observation}.update(my_feed_dict)) feed_dict = {self.observation_placeholder: observation}
feed_dict.update(my_feed_dict)
return sess.run(self._value_tensor_all_actions, feed_dict=feed_dict)
def eval_value_all_actions_old(self, observation, my_feed_dict={}): def eval_value_all_actions_old(self, observation, my_feed_dict={}):
""" """
@ -245,7 +251,9 @@ class DQN(ValueFunctionBase):
:return: A numpy array of shape (batch_size, num_actions). The corresponding action values for each observation. :return: A numpy array of shape (batch_size, num_actions). The corresponding action values for each observation.
""" """
sess = tf.get_default_session() sess = tf.get_default_session()
return sess.run(self.value_tensor_all_actions_old, feed_dict={self.observation_placeholder: observation}.update(my_feed_dict)) feed_dict = {self.observation_placeholder: observation}
feed_dict.update(my_feed_dict)
return sess.run(self.value_tensor_all_actions_old, feed_dict=feed_dict)
def sync_weights(self): def sync_weights(self):
""" """

View File

@ -79,7 +79,9 @@ class StateValue(ValueFunctionBase):
:return: A numpy array of shape (batch_size,). The corresponding state value for each observation. :return: A numpy array of shape (batch_size,). The corresponding state value for each observation.
""" """
sess = tf.get_default_session() sess = tf.get_default_session()
return sess.run(self.value_tensor, feed_dict={self.observation_placeholder: observation}.update(my_feed_dict)) feed_dict = {self.observation_placeholder: observation}
feed_dict.update(my_feed_dict)
return sess.run(self.value_tensor, feed_dict=feed_dict)
def eval_value_old(self, observation, my_feed_dict={}): def eval_value_old(self, observation, my_feed_dict={}):
""" """
@ -92,7 +94,9 @@ class StateValue(ValueFunctionBase):
:return: A numpy array of shape (batch_size,). The corresponding state value for each observation. :return: A numpy array of shape (batch_size,). The corresponding state value for each observation.
""" """
sess = tf.get_default_session() sess = tf.get_default_session()
return sess.run(self.value_tensor_old, feed_dict={self.observation_placeholder: observation}.update(my_feed_dict)) feed_dict = {self.observation_placeholder: observation}
feed_dict.update(my_feed_dict)
return sess.run(self.value_tensor_old, feed_dict=feed_dict)
def sync_weights(self): def sync_weights(self):
""" """