fix the bug of calling .update() on an unnamed dict
import cleaning in examples/*.py
commit 2527030838 (parent d84c9d121c)
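The bug being fixed is the classic Python pitfall of calling dict.update() on an unnamed (temporary) dict literal: update() merges in place and returns None, so the merged dict is thrown away and feed_dict ends up as None. A minimal sketch of the failure mode and of the fix applied throughout this commit (names are illustrative, not from the repo):

    # Broken: update() returns None, so the dict built from the literal is lost.
    feed_dict = {'x': 1}.update({'y': 2})
    assert feed_dict is None

    # Fixed: build the dict first, then merge in place.
    feed_dict = {'x': 1}
    feed_dict.update({'y': 2})
    assert feed_dict == {'x': 1, 'y': 2}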
@@ -7,16 +7,6 @@ import gym
 import tianshou as ts
 
-import sys
-sys.path.append('..')
-from tianshou.core import losses
-import tianshou.data.advantage_estimation as advantage_estimation
-import tianshou.core.policy.distributional as policy
-import tianshou.core.value_function.state_value as value_function
-
-from tianshou.data.data_buffer.batch_set import BatchSet
-from tianshou.data.data_collector import DataCollector
-
 
 if __name__ == '__main__':
     env = gym.make('CartPole-v0')
@@ -9,20 +9,6 @@ import argparse
 
 import tianshou as ts
 
-# our lib imports here! It's ok to append path in examples
-import sys
-sys.path.append('..')
-from tianshou.core import losses
-import tianshou.data.advantage_estimation as advantage_estimation
-import tianshou.core.policy as policy
-import tianshou.core.value_function.action_value as value_function
-import tianshou.core.opt as opt
-
-from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
-from tianshou.data.data_collector import DataCollector
-from tianshou.data.tester import test_policy_in_env
-from tianshou.core.utils import get_soft_update_op
-
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
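The two hunks above strip the sys.path hack and the per-module imports from the example scripts, leaving only the single top-level import. A hedged sketch of the before/after pattern (it assumes the examples reach the submodules through the ts namespace afterwards, which is not visible in these hunks, and that the package is installed, e.g. with pip install -e .):

    # Before: path hacking so examples work from an uninstalled source checkout.
    import sys
    sys.path.append('..')
    import tianshou.data.advantage_estimation as advantage_estimation

    # After: one import; submodules are reached via the package namespace,
    # assuming tianshou's __init__ exposes them (an assumption, not shown here).
    import tianshou as ts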
@@ -146,7 +146,8 @@ class Deterministic(PolicyBase):
         """
         sess = tf.get_default_session()
 
-        feed_dict = {self.observation_placeholder: observation}.update(my_feed_dict)
+        feed_dict = {self.observation_placeholder: observation}
+        feed_dict.update(my_feed_dict)
         action = sess.run(self.action, feed_dict=feed_dict)
 
         return action
@@ -164,7 +165,8 @@ class Deterministic(PolicyBase):
         """
         sess = tf.get_default_session()
 
-        feed_dict = {self.observation_placeholder: observation}.update(my_feed_dict)
+        feed_dict = {self.observation_placeholder: observation}
+        feed_dict.update(my_feed_dict)
         action = sess.run(self.action_old, feed_dict=feed_dict)
 
         return action
@@ -80,8 +80,9 @@ class ActionValue(ValueFunctionBase):
         :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
         """
         sess = tf.get_default_session()
-        return sess.run(self.value_tensor, feed_dict=
-            {self.observation_placeholder: observation, self.action_placeholder: action}.update(my_feed_dict))
+        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self.value_tensor, feed_dict=feed_dict)
 
     def eval_value_old(self, observation, action, my_feed_dict={}):
         """
@@ -95,7 +96,8 @@ class ActionValue(ValueFunctionBase):
         :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
         """
         sess = tf.get_default_session()
-        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}.update(my_feed_dict)
+        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
+        feed_dict.update(my_feed_dict)
         return sess.run(self.value_tensor_old, feed_dict=feed_dict)
 
     def sync_weights(self):
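A side note on the signatures visible in these hunks: my_feed_dict={} is a mutable default argument, a well-known Python hazard, but the fixed code stays safe because update() is only ever called on the freshly built dict, so the shared default is read and never mutated. A small self-contained illustration (the function name is hypothetical, a stand-in for the methods in the diff):

    # eval_value is a stand-in for the eval_value* methods above.
    def eval_value(observation, my_feed_dict={}):
        feed_dict = {'obs': observation}   # a fresh dict on every call
        feed_dict.update(my_feed_dict)     # the shared default {} is only read
        return feed_dict

    assert eval_value([1.0]) == {'obs': [1.0]}
    assert eval_value([1.0], {'k': 2}) == {'obs': [1.0], 'k': 2}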
@@ -198,7 +200,8 @@ class DQN(ValueFunctionBase):
         :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
         """
         sess = tf.get_default_session()
-        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}.update(my_feed_dict)
+        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
+        feed_dict.update(my_feed_dict)
         return sess.run(self.value_tensor, feed_dict=feed_dict)
 
     def eval_value_old(self, observation, action, my_feed_dict={}):
@@ -213,7 +216,8 @@ class DQN(ValueFunctionBase):
         :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
         """
         sess = tf.get_default_session()
-        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}.update(my_feed_dict)
+        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
+        feed_dict.update(my_feed_dict)
         return sess.run(self.value_tensor_old, feed_dict=feed_dict)
 
     @property
@@ -232,7 +236,9 @@ class DQN(ValueFunctionBase):
         :return: A numpy array of shape (batch_size, num_actions). The corresponding action values for each observation.
         """
         sess = tf.get_default_session()
-        return sess.run(self._value_tensor_all_actions, feed_dict={self.observation_placeholder: observation}.update(my_feed_dict))
+        feed_dict = {self.observation_placeholder: observation}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self._value_tensor_all_actions, feed_dict=feed_dict)
 
     def eval_value_all_actions_old(self, observation, my_feed_dict={}):
         """
@@ -245,7 +251,9 @@ class DQN(ValueFunctionBase):
         :return: A numpy array of shape (batch_size, num_actions). The corresponding action values for each observation.
         """
         sess = tf.get_default_session()
-        return sess.run(self.value_tensor_all_actions_old, feed_dict={self.observation_placeholder: observation}.update(my_feed_dict))
+        feed_dict = {self.observation_placeholder: observation}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self.value_tensor_all_actions_old, feed_dict=feed_dict)
 
     def sync_weights(self):
         """
@@ -79,7 +79,9 @@ class StateValue(ValueFunctionBase):
         :return: A numpy array of shape (batch_size,). The corresponding state value for each observation.
         """
         sess = tf.get_default_session()
-        return sess.run(self.value_tensor, feed_dict={self.observation_placeholder: observation}.update(my_feed_dict))
+        feed_dict = {self.observation_placeholder: observation}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self.value_tensor, feed_dict=feed_dict)
 
     def eval_value_old(self, observation, my_feed_dict={}):
         """
@@ -92,7 +94,9 @@ class StateValue(ValueFunctionBase):
         :return: A numpy array of shape (batch_size,). The corresponding state value for each observation.
         """
         sess = tf.get_default_session()
-        return sess.run(self.value_tensor_old, feed_dict={self.observation_placeholder: observation}.update(my_feed_dict))
+        feed_dict = {self.observation_placeholder: observation}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self.value_tensor_old, feed_dict=feed_dict)
 
     def sync_weights(self):
         """
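All ten hunks apply the same two-step merge. For reference, an equivalent single-expression merge exists in Python 3.5+; a sketch with stand-in names (obs_ph plays the role of self.observation_placeholder; none of this is code from the commit):

    obs_ph = 'observation_placeholder'   # stand-in for a TF placeholder
    observation = [[0.0, 1.0]]
    my_feed_dict = {'keep_prob': 1.0}

    # The pattern applied throughout this commit: build, then merge in place.
    feed_dict = {obs_ph: observation}
    feed_dict.update(my_feed_dict)       # mutates feed_dict, returns None

    # Equivalent one-liner using dict unpacking (Python 3.5+):
    assert feed_dict == {obs_ph: observation, **my_feed_dict}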