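Fix the bug of calling .update() on an unnamed dict literal: dict.update() returns None, so feed_dict was None. Build the dict first, then update it.
Clean up imports in examples/*.py.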
This commit is contained in:
parent
d84c9d121c
commit
2527030838
@ -7,16 +7,6 @@ import gym
|
|||||||
|
|
||||||
import tianshou as ts
|
import tianshou as ts
|
||||||
|
|
||||||
import sys
|
|
||||||
sys.path.append('..')
|
|
||||||
from tianshou.core import losses
|
|
||||||
import tianshou.data.advantage_estimation as advantage_estimation
|
|
||||||
import tianshou.core.policy.distributional as policy
|
|
||||||
import tianshou.core.value_function.state_value as value_function
|
|
||||||
|
|
||||||
from tianshou.data.data_buffer.batch_set import BatchSet
|
|
||||||
from tianshou.data.data_collector import DataCollector
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
env = gym.make('CartPole-v0')
|
env = gym.make('CartPole-v0')
|
||||||
|
@ -9,20 +9,6 @@ import argparse
|
|||||||
|
|
||||||
import tianshou as ts
|
import tianshou as ts
|
||||||
|
|
||||||
# our lib imports here! It's ok to append path in examples
|
|
||||||
import sys
|
|
||||||
sys.path.append('..')
|
|
||||||
from tianshou.core import losses
|
|
||||||
import tianshou.data.advantage_estimation as advantage_estimation
|
|
||||||
import tianshou.core.policy as policy
|
|
||||||
import tianshou.core.value_function.action_value as value_function
|
|
||||||
import tianshou.core.opt as opt
|
|
||||||
|
|
||||||
from tianshou.data.data_buffer.vanilla import VanillaReplayBuffer
|
|
||||||
from tianshou.data.data_collector import DataCollector
|
|
||||||
from tianshou.data.tester import test_policy_in_env
|
|
||||||
from tianshou.core.utils import get_soft_update_op
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
@ -146,7 +146,8 @@ class Deterministic(PolicyBase):
|
|||||||
"""
|
"""
|
||||||
sess = tf.get_default_session()
|
sess = tf.get_default_session()
|
||||||
|
|
||||||
feed_dict = {self.observation_placeholder: observation}.update(my_feed_dict)
|
feed_dict = {self.observation_placeholder: observation}
|
||||||
|
feed_dict.update(my_feed_dict)
|
||||||
action = sess.run(self.action, feed_dict=feed_dict)
|
action = sess.run(self.action, feed_dict=feed_dict)
|
||||||
|
|
||||||
return action
|
return action
|
||||||
@ -164,7 +165,8 @@ class Deterministic(PolicyBase):
|
|||||||
"""
|
"""
|
||||||
sess = tf.get_default_session()
|
sess = tf.get_default_session()
|
||||||
|
|
||||||
feed_dict = {self.observation_placeholder: observation}.update(my_feed_dict)
|
feed_dict = {self.observation_placeholder: observation}
|
||||||
|
feed_dict.update(my_feed_dict)
|
||||||
action = sess.run(self.action_old, feed_dict=feed_dict)
|
action = sess.run(self.action_old, feed_dict=feed_dict)
|
||||||
|
|
||||||
return action
|
return action
|
@ -80,8 +80,9 @@ class ActionValue(ValueFunctionBase):
|
|||||||
:return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
|
:return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
|
||||||
"""
|
"""
|
||||||
sess = tf.get_default_session()
|
sess = tf.get_default_session()
|
||||||
return sess.run(self.value_tensor, feed_dict=
|
feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
|
||||||
{self.observation_placeholder: observation, self.action_placeholder: action}.update(my_feed_dict))
|
feed_dict.update(my_feed_dict)
|
||||||
|
return sess.run(self.value_tensor, feed_dict=feed_dict)
|
||||||
|
|
||||||
def eval_value_old(self, observation, action, my_feed_dict={}):
|
def eval_value_old(self, observation, action, my_feed_dict={}):
|
||||||
"""
|
"""
|
||||||
@ -95,7 +96,8 @@ class ActionValue(ValueFunctionBase):
|
|||||||
:return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
|
:return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
|
||||||
"""
|
"""
|
||||||
sess = tf.get_default_session()
|
sess = tf.get_default_session()
|
||||||
feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}.update(my_feed_dict)
|
feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
|
||||||
|
feed_dict.update(my_feed_dict)
|
||||||
return sess.run(self.value_tensor_old, feed_dict=feed_dict)
|
return sess.run(self.value_tensor_old, feed_dict=feed_dict)
|
||||||
|
|
||||||
def sync_weights(self):
|
def sync_weights(self):
|
||||||
@ -198,7 +200,8 @@ class DQN(ValueFunctionBase):
|
|||||||
:return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
|
:return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
|
||||||
"""
|
"""
|
||||||
sess = tf.get_default_session()
|
sess = tf.get_default_session()
|
||||||
feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}.update(my_feed_dict)
|
feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
|
||||||
|
feed_dict.update(my_feed_dict)
|
||||||
return sess.run(self.value_tensor, feed_dict=feed_dict)
|
return sess.run(self.value_tensor, feed_dict=feed_dict)
|
||||||
|
|
||||||
def eval_value_old(self, observation, action, my_feed_dict={}):
|
def eval_value_old(self, observation, action, my_feed_dict={}):
|
||||||
@ -213,7 +216,8 @@ class DQN(ValueFunctionBase):
|
|||||||
:return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
|
:return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
|
||||||
"""
|
"""
|
||||||
sess = tf.get_default_session()
|
sess = tf.get_default_session()
|
||||||
feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}.update(my_feed_dict)
|
feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
|
||||||
|
feed_dict.update(my_feed_dict)
|
||||||
return sess.run(self.value_tensor_old, feed_dict=feed_dict)
|
return sess.run(self.value_tensor_old, feed_dict=feed_dict)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -232,7 +236,9 @@ class DQN(ValueFunctionBase):
|
|||||||
:return: A numpy array of shape (batch_size, num_actions). The corresponding action values for each observation.
|
:return: A numpy array of shape (batch_size, num_actions). The corresponding action values for each observation.
|
||||||
"""
|
"""
|
||||||
sess = tf.get_default_session()
|
sess = tf.get_default_session()
|
||||||
return sess.run(self._value_tensor_all_actions, feed_dict={self.observation_placeholder: observation}.update(my_feed_dict))
|
feed_dict = {self.observation_placeholder: observation}
|
||||||
|
feed_dict.update(my_feed_dict)
|
||||||
|
return sess.run(self._value_tensor_all_actions, feed_dict=feed_dict)
|
||||||
|
|
||||||
def eval_value_all_actions_old(self, observation, my_feed_dict={}):
|
def eval_value_all_actions_old(self, observation, my_feed_dict={}):
|
||||||
"""
|
"""
|
||||||
@ -245,7 +251,9 @@ class DQN(ValueFunctionBase):
|
|||||||
:return: A numpy array of shape (batch_size, num_actions). The corresponding action values for each observation.
|
:return: A numpy array of shape (batch_size, num_actions). The corresponding action values for each observation.
|
||||||
"""
|
"""
|
||||||
sess = tf.get_default_session()
|
sess = tf.get_default_session()
|
||||||
return sess.run(self.value_tensor_all_actions_old, feed_dict={self.observation_placeholder: observation}.update(my_feed_dict))
|
feed_dict = {self.observation_placeholder: observation}
|
||||||
|
feed_dict.update(my_feed_dict)
|
||||||
|
return sess.run(self.value_tensor_all_actions_old, feed_dict=feed_dict)
|
||||||
|
|
||||||
def sync_weights(self):
|
def sync_weights(self):
|
||||||
"""
|
"""
|
||||||
|
@ -79,7 +79,9 @@ class StateValue(ValueFunctionBase):
|
|||||||
:return: A numpy array of shape (batch_size,). The corresponding state value for each observation.
|
:return: A numpy array of shape (batch_size,). The corresponding state value for each observation.
|
||||||
"""
|
"""
|
||||||
sess = tf.get_default_session()
|
sess = tf.get_default_session()
|
||||||
return sess.run(self.value_tensor, feed_dict={self.observation_placeholder: observation}.update(my_feed_dict))
|
feed_dict = {self.observation_placeholder: observation}
|
||||||
|
feed_dict.update(my_feed_dict)
|
||||||
|
return sess.run(self.value_tensor, feed_dict=feed_dict)
|
||||||
|
|
||||||
def eval_value_old(self, observation, my_feed_dict={}):
|
def eval_value_old(self, observation, my_feed_dict={}):
|
||||||
"""
|
"""
|
||||||
@ -92,7 +94,9 @@ class StateValue(ValueFunctionBase):
|
|||||||
:return: A numpy array of shape (batch_size,). The corresponding state value for each observation.
|
:return: A numpy array of shape (batch_size,). The corresponding state value for each observation.
|
||||||
"""
|
"""
|
||||||
sess = tf.get_default_session()
|
sess = tf.get_default_session()
|
||||||
return sess.run(self.value_tensor_old, feed_dict={self.observation_placeholder: observation}.update(my_feed_dict))
|
feed_dict = {self.observation_placeholder: observation}
|
||||||
|
feed_dict.update(my_feed_dict)
|
||||||
|
return sess.run(self.value_tensor_old, feed_dict=feed_dict)
|
||||||
|
|
||||||
def sync_weights(self):
|
def sync_weights(self):
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user