2017-12-23 15:36:10 +08:00
|
|
|
from __future__ import absolute_import
|
|
|
|
|
|
|
|
from .base import PolicyBase
|
2017-12-17 12:52:00 +08:00
|
|
|
import tensorflow as tf
|
2018-01-18 12:19:48 +08:00
|
|
|
import numpy as np
|
2017-12-22 00:22:23 +08:00
|
|
|
|
|
|
|
|
2018-01-18 12:19:48 +08:00
|
|
|
class DQN(PolicyBase):
    """
    Epsilon-greedy policy built on top of a DQN action-value function.

    The action-value object (presumably the DQN from ``value_function``) is kept
    as a member. This policy reads its tensors/placeholders to pick greedy
    actions and delegates old-network weight syncing/updating to it.

    :param dqn: the action-value function. This class reads its
        ``value_tensor_all_actions``, ``_observation_placeholder``,
        ``num_actions``, ``weight_update``, ``sync_weights_ops`` and
        ``weight_update_ops`` attributes and calls its ``sync_weights()`` and
        ``update_weights()`` methods.
    :param epsilon_train: probability of taking a uniformly random action in
        :meth:`act`.
    :param epsilon_test: probability of taking a uniformly random action in
        :meth:`act_test`.
    """

    def __init__(self, dqn, epsilon_train=0.1, epsilon_test=0.05):
        self.action_value = dqn
        # Index of the highest-valued action for each row of the batch.
        self._argmax_action = tf.argmax(dqn.value_tensor_all_actions, axis=1)

        self.weight_update = dqn.weight_update
        # When weight_update > 1, update_weights() is triggered every
        # `weight_update` interactions, so the step counter starts at 0.
        # Otherwise it starts at -1; it is still incremented when
        # weight_update > 0 but is never read inside this class.
        self.interaction_count = 0 if self.weight_update > 1 else -1

        self.epsilon_train = epsilon_train
        self.epsilon_test = epsilon_test

    def _act_epsilon_greedy(self, observation, epsilon, my_feed_dict):
        """
        Shared epsilon-greedy step used by :meth:`act` and :meth:`act_test`.

        :param observation: a single (unbatched) observation; a leading batch
            axis of size 1 is added before feeding it to the network.
        :param epsilon: probability of returning a uniformly random action.
        :param my_feed_dict: optional extra entries merged into the feed dict
            (may override the observation entry), or ``None``.
        :return: the chosen action, passed through ``np.squeeze``.
        """
        # Periodic old-network update, counted in interactions.
        if self.weight_update > 1 and self.interaction_count % self.weight_update == 0:
            self.update_weights()

        # Draw the exploration coin first so the greedy forward pass is only
        # run when its result is actually used. RNG consumption is the same as
        # computing the greedy action first: one rand(), plus one randint()
        # only on exploration steps.
        if np.random.rand() < epsilon:
            action = np.random.randint(self.action_value.num_actions)
        else:
            feed_dict = {self.action_value._observation_placeholder: observation[None]}
            if my_feed_dict:
                feed_dict.update(my_feed_dict)
            sess = tf.get_default_session()
            action = sess.run(self._argmax_action, feed_dict=feed_dict)

        if self.weight_update > 0:
            self.interaction_count += 1

        return np.squeeze(action)

    def act(self, observation, my_feed_dict=None):
        """
        Choose an action for one observation with training-time exploration.

        :param observation: a single (unbatched) observation.
        :param my_feed_dict: optional extra feed-dict entries for the forward
            pass, or ``None``.
        :return: the chosen action index.
        """
        return self._act_epsilon_greedy(observation, self.epsilon_train, my_feed_dict)

    def act_test(self, observation, my_feed_dict=None):
        """
        Choose an action for one observation with test-time exploration.

        NOTE(review): like :meth:`act`, this advances ``interaction_count`` and
        may trigger ``update_weights()``. That is preserved from the original
        implementation; confirm it is intended at evaluation time.

        :param observation: a single (unbatched) observation.
        :param my_feed_dict: optional extra feed-dict entries for the forward
            pass, or ``None``.
        :return: the chosen action index.
        """
        return self._act_epsilon_greedy(observation, self.epsilon_test, my_feed_dict)

    @property
    def q_net(self):
        """The underlying action-value function passed to the constructor."""
        return self.action_value

    def sync_weights(self):
        """
        Sync the old network's weights by directly copying the current
        network's weights (delegated to the action-value function).

        Does nothing when the action-value function has no sync ops.

        :return: None
        """
        if self.action_value.sync_weights_ops is not None:
            self.action_value.sync_weights()

    def update_weights(self):
        """
        Update the old network's weights via the action-value function.

        Does nothing when the action-value function has no weight-update ops.

        :return: None
        """
        if self.action_value.weight_update_ops is not None:
            self.action_value.update_weights()

    def set_epsilon_train(self, epsilon):
        """Set the exploration probability used by :meth:`act`."""
        self.epsilon_train = epsilon

    def set_epsilon_test(self, epsilon):
        """Set the exploration probability used by :meth:`act_test`."""
        self.epsilon_test = epsilon
|