finish dqn

commit 5983c6b33d (parent c804662457)
.gitignore (1 addition)

@@ -135,3 +135,4 @@ dmypy.json
 
 # customize
 flake8.sh
+log/

test/test_dqn.py (new file, 136 lines)

@@ -0,0 +1,136 @@
+import gym
+import tqdm
+import torch
+import argparse
+import numpy as np
+from torch import nn
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.policy import DQNPolicy
+from tianshou.env import SubprocVectorEnv
+from tianshou.utils import tqdm_config, MovAvg
+from tianshou.data import Collector, ReplayBuffer
+
+
+class Net(nn.Module):
+    def __init__(self, layer_num, state_shape, action_shape, device):
+        super().__init__()
+        self.device = device
+        self.model = [
+            nn.Linear(np.prod(state_shape), 128),
+            nn.ReLU(inplace=True)]
+        for i in range(layer_num):
+            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
+        self.model += [nn.Linear(128, np.prod(action_shape))]
+        self.model = nn.Sequential(*self.model)
+
+    def forward(self, s, **kwargs):
+        if not isinstance(s, torch.Tensor):
+            s = torch.Tensor(s)
+        s = s.to(self.device)
+        batch = s.shape[0]
+        q = self.model(s.view(batch, -1))
+        return q, None
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, default='CartPole-v0')
+    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--eps-test', type=float, default=0.05)
+    parser.add_argument('--eps-train', type=float, default=0.1)
+    parser.add_argument('--buffer-size', type=int, default=20000)
+    parser.add_argument('--lr', type=float, default=1e-3)
+    parser.add_argument('--gamma', type=float, default=0.9)
+    parser.add_argument('--n-step', type=int, default=1)
+    parser.add_argument('--epoch', type=int, default=100)
+    parser.add_argument('--step-per-epoch', type=int, default=320)
+    parser.add_argument('--collect-per-step', type=int, default=10)
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--layer-num', type=int, default=3)
+    parser.add_argument('--training-num', type=int, default=8)
+    parser.add_argument('--test-num', type=int, default=20)
+    parser.add_argument('--logdir', type=str, default='log')
+    parser.add_argument(
+        '--device', type=str,
+        default='cuda' if torch.cuda.is_available() else 'cpu')
+    args = parser.parse_known_args()[0]
+    return args
+
+
+def test_dqn(args=get_args()):
+    env = gym.make(args.task)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    train_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.training_num)],
+        reset_after_done=True)
+    test_envs = SubprocVectorEnv(
+        [lambda: gym.make(args.task) for _ in range(args.test_num)],
+        reset_after_done=False)
+    # seed
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    train_envs.seed(args.seed)
+    test_envs.seed(args.seed)
+    # model
+    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
+    net = net.to(args.device)
+    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
+    loss = nn.MSELoss()
+    policy = DQNPolicy(net, optim, loss, args.gamma, args.n_step)
+    # collector
+    training_collector = Collector(
+        policy, train_envs, ReplayBuffer(args.buffer_size))
+    test_collector = Collector(
+        policy, test_envs, ReplayBuffer(args.buffer_size), args.test_num)
+    training_collector.collect(n_step=args.batch_size)
+    # log
+    stat_loss = MovAvg()
+    global_step = 0
+    writer = SummaryWriter(args.logdir)
+    best_epoch = -1
+    best_reward = -1e10
+    for epoch in range(args.epoch):
+        desc = f"Epoch #{epoch + 1}"
+        # train
+        policy.train()
+        policy.sync_weight()
+        policy.set_eps(args.eps_train)
+        with tqdm.trange(
+                0, args.step_per_epoch, desc=desc, **tqdm_config) as t:
+            for _ in t:
+                training_collector.collect(n_step=args.collect_per_step)
+                global_step += 1
+                result = training_collector.stat()
+                loss = policy.learn(training_collector.sample(args.batch_size))
+                stat_loss.add(loss)
+                writer.add_scalar(
+                    'reward', result['reward'], global_step=global_step)
+                writer.add_scalar(
+                    'length', result['length'], global_step=global_step)
+                writer.add_scalar(
+                    'loss', stat_loss.get(), global_step=global_step)
+                t.set_postfix(loss=f'{stat_loss.get():.6f}',
+                              reward=f'{result["reward"]:.6f}',
+                              length=f'{result["length"]:.6f}')
+        # eval
+        test_collector.reset_env()
+        test_collector.reset_buffer()
+        policy.eval()
+        policy.set_eps(args.eps_test)
+        test_collector.collect(n_episode=args.test_num)
+        result = test_collector.stat()
+        if best_reward < result['reward']:
+            best_reward = result['reward']
+            best_epoch = epoch
+        print(f'Epoch #{epoch + 1} reward: {result["reward"]:.6f}, '
+              f'best_reward: {best_reward:.6f} in #{best_epoch}')
+        if args.task == 'CartPole-v0' and best_reward >= 200:
+            break
+    assert best_reward >= 200
+    return best_reward
+
+
+if __name__ == '__main__':
+    test_dqn(get_args())

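Illustrative sketch, not part of the commit: a quick shape check of the Net module defined above, assuming CartPole-like dimensions and Net in scope.

    import numpy as np

    net = Net(layer_num=3, state_shape=(4,), action_shape=2, device='cpu')
    obs = np.random.rand(8, 4).astype(np.float32)   # a batch of 8 observations
    q, state = net(obs)        # forward() converts ndarrays to tensors itself
    print(q.shape)             # torch.Size([8, 2]); state is None for this MLP
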
@@ -65,6 +65,16 @@ class ReplayBuffer(object):
             info=self.info[indice]
         ), indice
 
+    def __getitem__(self, index):
+        return Batch(
+            obs=self.obs[index],
+            act=self.act[index],
+            rew=self.rew[index],
+            done=self.done[index],
+            obs_next=self.obs_next[index],
+            info=self.info[index]
+        )
+
 
 class PrioritizedReplayBuffer(ReplayBuffer):
     """docstring for PrioritizedReplayBuffer"""

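The new __getitem__ is what lets DQNPolicy.process_fn further down index the buffer directly (buffer[terminal]) and get a Batch back. A minimal sketch, assuming add() accepts the same keyword fields the buffer stores:

    import numpy as np
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(20)
    buf.add(obs=1, act=0, rew=1.0, done=0, obs_next=2, info={})
    buf.add(obs=2, act=1, rew=0.0, done=1, obs_next=3, info={})
    batch = buf[np.array([0, 1])]    # Batch(obs, act, rew, done, obs_next, info)
    print(batch.rew, batch.done)
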
@@ -10,7 +10,7 @@ from tianshou.utils import MovAvg
 class Collector(object):
     """docstring for Collector"""
 
-    def __init__(self, policy, env, buffer, contiguous=True):
+    def __init__(self, policy, env, buffer, stat_size=100):
         super().__init__()
         self.env = env
         self.env_num = 1
@@ -19,27 +19,28 @@ class Collector(object):
         self.process_fn = policy.process_fn
         self._multi_env = isinstance(env, BaseVectorEnv)
         self._multi_buf = False  # buf is a list
-        # need multiple cache buffers only if contiguous in one buffer
+        # need multiple cache buffers only if storing in one buffer
         self._cached_buf = []
         if self._multi_env:
             self.env_num = len(env)
             if isinstance(self.buffer, list):
                 assert len(self.buffer) == self.env_num,\
-                    '# of data buffer does not match the # of input env.'
+                    'The number of data buffer does not match the number of '\
+                    'input env.'
                 self._multi_buf = True
-            elif isinstance(self.buffer, ReplayBuffer) and contiguous:
+            elif isinstance(self.buffer, ReplayBuffer):
                 self._cached_buf = [
                     deepcopy(buffer) for _ in range(self.env_num)]
             else:
                 raise TypeError('The buffer in data collector is invalid!')
         self.reset_env()
-        self.clear_buffer()
-        # state over batch is either a list, an np.ndarray, or torch.Tensor
+        self.reset_buffer()
+        # state over batch is either a list, an np.ndarray, or a torch.Tensor
         self.state = None
-        self.stat_reward = MovAvg()
-        self.stat_length = MovAvg()
+        self.stat_reward = MovAvg(stat_size)
+        self.stat_length = MovAvg(stat_size)
 
-    def clear_buffer(self):
+    def reset_buffer(self):
         if self._multi_buf:
             for b in self.buffer:
                 b.reset()
@@ -57,6 +58,18 @@ class Collector(object):
             for b in self._cached_buf:
                 b.reset()
 
+    def seed(self, seed=None):
+        if hasattr(self.env, 'seed'):
+            self.env.seed(seed)
+
+    def render(self):
+        if hasattr(self.env, 'render'):
+            self.env.render()
+
+    def close(self):
+        if hasattr(self.env, 'close'):
+            self.env.close()
+
     def _make_batch(data):
         if isinstance(data, np.ndarray):
             return data[None]
@@ -66,9 +79,10 @@ class Collector(object):
     def collect(self, n_step=0, n_episode=0):
         assert sum([(n_step > 0), (n_episode > 0)]) == 1,\
             "One and only one collection number specification permitted!"
-        cur_step, cur_episode = 0, 0
+        cur_step = 0
+        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
         while True:
-            if self.multi_env:
+            if self._multi_env:
                 batch_data = Batch(
                     obs=self._obs, act=self._act, rew=self._rew,
                     done=self._done, obs_next=None, info=self._info)
@@ -78,8 +92,9 @@ class Collector(object):
                     act=self._make_batch(self._act),
                     rew=self._make_batch(self._rew),
                     done=self._make_batch(self._done),
-                    obs_next=None, info=self._make_batch(self._info))
-            result = self.policy.act(batch_data, self.state)
+                    obs_next=None,
+                    info=self._make_batch(self._info))
+            result = self.policy(batch_data, self.state)
             self.state = result.state if hasattr(result, 'state') else None
             self._act = result.act
             obs_next, self._rew, self._done, self._info = self.env.step(
@@ -88,6 +103,9 @@ class Collector(object):
             self.reward += self._rew
             if self._multi_env:
                 for i in range(self.env_num):
+                    if not self.env.is_reset_after_done()\
+                            and cur_episode[i] > 0:
+                        continue
                     data = {
                         'obs': self._obs[i], 'act': self._act[i],
                         'rew': self._rew[i], 'done': self._done[i],
@@ -101,7 +119,7 @@ class Collector(object):
                         self.buffer.add(**data)
                     cur_step += 1
                     if self._done[i]:
-                        cur_episode += 1
+                        cur_episode[i] += 1
                         self.stat_reward.add(self.reward[i])
                         self.stat_length.add(self.length[i])
                         self.reward[i], self.length[i] = 0, 0
@@ -111,12 +129,12 @@ class Collector(object):
                             self._cached_buf[i].reset()
                         if isinstance(self.state, list):
                             self.state[i] = None
-                        else:
+                        elif self.state is not None:
                             self.state[i] = self.state[i] * 0
                             if isinstance(self.state, torch.Tensor):
-                                # remove ref in torch (?)
+                                # remove ref count in pytorch (?)
                                 self.state = self.state.detach()
-                if n_episode > 0 and cur_episode >= n_episode:
+                if n_episode > 0 and cur_episode.sum() >= n_episode:
                     break
             else:
                 self.buffer.add(
@@ -141,13 +159,13 @@ class Collector(object):
             if batch_size > 0:
                 lens = [len(b) for b in self.buffer]
                 total = sum(lens)
-                ib = np.random.choice(
+                batch_index = np.random.choice(
                     total, batch_size, p=np.array(lens) / total)
             else:
-                ib = np.array([])
+                batch_index = np.array([])
             batch_data = Batch()
             for i, b in enumerate(self.buffer):
-                cur_batch = (ib == i).sum()
+                cur_batch = (batch_index == i).sum()
                 if batch_size and cur_batch or batch_size <= 0:
                     batch, indice = b.sample(cur_batch)
                     batch = self.process_fn(batch, b, indice)

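Putting the renamed pieces together (reset_buffer, seed, and the new close), the training loop from the test script reduces to roughly the following; a sketch only, assuming policy and train_envs are built as in test_dqn.py above:

    collector = Collector(policy, train_envs, ReplayBuffer(20000))
    collector.seed(0)                 # forwarded to the wrapped env(s)
    collector.collect(n_step=64)      # warm up the replay buffer
    for _ in range(1000):
        collector.collect(n_step=10)
        loss = policy.learn(collector.sample(64))
    collector.reset_buffer()          # formerly clear_buffer()
    collector.close()
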
tianshou/env/wrapper.py (37 changes)

@@ -1,6 +1,6 @@
 import numpy as np
-from abc import ABC
 from collections import deque
+from abc import ABC, abstractmethod
 from multiprocessing import Process, Pipe
 try:
     import ray
@@ -63,6 +63,32 @@ class BaseVectorEnv(ABC):
         self.env_num = len(env_fns)
         self._reset_after_done = reset_after_done
 
+    def is_reset_after_done(self):
+        return self._reset_after_done
+
+    def __len__(self):
+        return self.env_num
+
+    @abstractmethod
+    def reset(self):
+        pass
+
+    @abstractmethod
+    def step(self, action):
+        pass
+
+    @abstractmethod
+    def seed(self, seed=None):
+        pass
+
+    @abstractmethod
+    def render(self):
+        pass
+
+    @abstractmethod
+    def close(self):
+        pass
+
 
 class VectorEnv(BaseVectorEnv):
     """docstring for VectorEnv"""
@@ -71,9 +97,6 @@ class VectorEnv(BaseVectorEnv):
         super().__init__(env_fns, reset_after_done)
         self.envs = [_() for _ in env_fns]
 
-    def __len__(self):
-        return len(self.envs)
-
     def reset(self):
         return np.stack([e.reset() for e in self.envs])
 
@@ -148,9 +171,6 @@ class SubprocVectorEnv(BaseVectorEnv):
         for c in self.child_remote:
             c.close()
 
-    def __len__(self):
-        return self.env_num
-
     def step(self, action):
         assert len(action) == self.env_num
         for p, a in zip(self.parent_remote, action):
@@ -203,9 +223,6 @@ class RayVectorEnv(BaseVectorEnv):
             ray.remote(EnvWrapper).options(num_cpus=0).remote(e())
             for e in env_fns]
 
-    def __len__(self):
-        return self.env_num
-
     def step(self, action):
         assert len(action) == self.env_num
         result_obj = [e.step.remote(a) for e, a in zip(self.envs, action)]

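A small sketch of the extended vector-env interface (assumes gym is installed; __len__ and is_reset_after_done() now come from BaseVectorEnv rather than each subclass):

    import gym
    from tianshou.env import SubprocVectorEnv

    envs = SubprocVectorEnv(
        [lambda: gym.make('CartPole-v0') for _ in range(4)],
        reset_after_done=True)
    envs.seed(0)
    assert len(envs) == 4 and envs.is_reset_after_done()
    obs = envs.reset()               # one stacked observation per worker
    envs.close()
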
@@ -6,25 +6,21 @@ class BasePolicy(ABC):
 
     def __init__(self):
         super().__init__()
+        self.model = None
 
     @abstractmethod
-    def act(self, batch, hidden_state=None):
+    def __call__(self, batch, hidden_state=None):
         # return Batch(policy, action, hidden)
         pass
 
-    def train(self):
-        pass
-
-    def eval(self):
-        pass
-
-    def reset(self):
+    @abstractmethod
+    def learn(self, batch):
         pass
 
     def process_fn(self, batch, buffer, indice):
         return batch
 
-    def sync_weights(self):
+    def sync_weight(self):
         pass
 
     def exploration(self):

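With act() renamed to __call__ and learn() made abstract, a policy now only has to implement those two methods and return a Batch with an act field (plus an optional state), which is what the Collector reads. A hypothetical minimal subclass, not part of the commit (Batch assumed importable from tianshou.data):

    import numpy as np
    from tianshou.data import Batch
    from tianshou.policy import BasePolicy

    class RandomPolicy(BasePolicy):
        """Uniformly random actions; learn() is a no-op."""
        def __init__(self, action_num):
            super().__init__()
            self.action_num = action_num

        def __call__(self, batch, hidden_state=None):
            act = np.random.randint(self.action_num, size=len(batch.obs))
            return Batch(act=act, state=None)

        def learn(self, batch):
            return 0.0
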
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 from torch import nn
 from copy import deepcopy
 
@@ -9,25 +10,86 @@ from tianshou.policy import BasePolicy
 class DQNPolicy(BasePolicy, nn.Module):
     """docstring for DQNPolicy"""
 
-    def __init__(self, model, discount_factor=0.99, estimation_step=1,
+    def __init__(self, model, optim, loss,
+                 discount_factor=0.99,
+                 estimation_step=1,
                  use_target_network=True):
         super().__init__()
         self.model = model
+        self.optim = optim
+        self.loss = loss
+        self.eps = 0
+        assert 0 <= discount_factor <= 1, 'discount_factor should in [0, 1]'
         self._gamma = discount_factor
+        assert estimation_step > 0, 'estimation_step should greater than 0'
         self._n_step = estimation_step
         self._target = use_target_network
         if use_target_network:
             self.model_old = deepcopy(self.model)
+            self.model_old.eval()
 
-    def act(self, batch, hidden_state=None):
-        batch_result = Batch()
-        return batch_result
+    def __call__(self, batch, hidden_state=None,
+                 model='model', input='obs', eps=None):
+        model = getattr(self, model)
+        obs = getattr(batch, input)
+        q, h = model(obs, hidden_state=hidden_state, info=batch.info)
+        act = q.max(dim=1)[1].detach().cpu().numpy()
+        # add eps to act
+        for i in range(len(q)):
+            if np.random.rand() < self.eps:
+                act[i] = np.random.randint(q.shape[1])
+        return Batch(Q=q, act=act, state=h)
 
-    def sync_weights(self):
-        if self._use_target_network:
-            for old, new in zip(
-                    self.model_old.parameters(), self.model.parameters()):
-                old.data.copy_(new.data)
+    def set_eps(self, eps):
+        self.eps = eps
+
+    def train(self):
+        self.training = True
+        self.model.train()
+
+    def eval(self):
+        self.training = False
+        self.model.eval()
+
+    def sync_weight(self):
+        if self._target:
+            self.model_old.load_state_dict(self.model.state_dict())
 
     def process_fn(self, batch, buffer, indice):
+        returns = np.zeros_like(indice)
+        gammas = np.zeros_like(indice) + self._n_step
+        for n in range(self._n_step - 1, -1, -1):
+            now = (indice + n) % len(buffer)
+            gammas[buffer.done[now] > 0] = n
+            returns[buffer.done[now] > 0] = 0
+            returns = buffer.rew[now] + self._gamma * returns
+        terminal = (indice + self._n_step - 1) % len(buffer)
+        if self._target:
+            # target_Q = Q_old(s_, argmax(Q_new(s_, *)))
+            a = self(buffer[terminal], input='obs_next').act
+            target_q = self(
+                buffer[terminal], model='model_old', input='obs_next').Q
+            if isinstance(target_q, torch.Tensor):
+                target_q = target_q.detach().cpu().numpy()
+            target_q = target_q[np.arange(len(a)), a]
+        else:
+            target_q = self(buffer[terminal], input='obs_next').Q
+            if isinstance(target_q, torch.Tensor):
+                target_q = target_q.detach().cpu().numpy()
+            target_q = target_q.max(axis=1)
+        target_q[gammas != self._n_step] = 0
+        returns += (self._gamma ** gammas) * target_q
+        batch.update(returns=returns)
         return batch
+
+    def learn(self, batch):
+        self.optim.zero_grad()
+        q = self(batch).Q
+        q = q[np.arange(len(q)), batch.act]
+        r = batch.returns
+        if isinstance(r, np.ndarray):
+            r = torch.tensor(r, device=q.device, dtype=q.dtype)
+        loss = self.loss(q, r)
+        loss.backward()
+        self.optim.step()
+        return loss.detach().cpu().numpy()

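process_fn builds the n-step target returns = r_t + gamma*r_{t+1} + ... + gamma^(n-1)*r_{t+n-1} + gamma^n * Q_old(s_{t+n}, argmax_a Q(s_{t+n}, a)), dropping the bootstrap term whenever a done flag falls inside the window. For n = 1 this collapses to the familiar r + gamma * max_a Q_old(s', a); a tiny numeric illustration (not part of the commit):

    import numpy as np

    gamma = 0.9
    rew = np.array([1.0, 1.0])        # rewards of two sampled transitions
    done = np.array([0, 1])           # the second transition ends its episode
    target_q = np.array([5.0, 7.0])   # pretend Q_old(s', argmax_a Q(s', a))
    target_q[done > 0] = 0            # no bootstrapping past a terminal state
    returns = rew + gamma * target_q
    print(returns)                    # [5.5 1. ]
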
@@ -21,3 +21,11 @@ class MovAvg(object):
         if len(self.cache) == 0:
             return 0
         return np.mean(self.cache)
+
+    def mean(self):
+        return self.get()
+
+    def std(self):
+        if len(self.cache) == 0:
+            return 0
+        return np.std(self.cache)

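mean()/std() round out the moving-average helper that the test script and the Collector use for logging; a quick sketch:

    from tianshou.utils import MovAvg

    stat = MovAvg()                  # the Collector now passes stat_size (default 100)
    for x in [1.0, 2.0, 3.0]:
        stat.add(x)
    print(stat.mean(), stat.std())   # 2.0 and ~0.816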