add vizdoom example, bump version to 0.4.2 (#384)

parent c0bc8e00ca
commit ebaca6f8da

.github/workflows/pytest.yml (2 changes, vendored)
@ -24,7 +24,7 @@ jobs:
      - name: Test with pytest
        # ignore test/throughput which only profiles the code
        run: |
-         pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --durations=0 -v
+         pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v1
        with:

@ -6,7 +6,7 @@

[](https://pypi.org/project/tianshou/)
[](https://github.com/conda-forge/tianshou-feedstock)
-[](https://tianshou.readthedocs.io/en/latest)
+[](https://tianshou.readthedocs.io/en/master)
[](https://tianshou.readthedocs.io/zh/latest/)
[](https://github.com/thu-ml/tianshou/actions)
[](https://codecov.io/gh/thu-ml/tianshou)

@ -14,7 +14,6 @@

[](https://github.com/thu-ml/tianshou/stargazers)
[](https://github.com/thu-ml/tianshou/network)
[](https://github.com/thu-ml/tianshou/blob/master/LICENSE)
[](https://gitter.im/thu-ml/tianshou?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)

**Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly APIs, or slow speed, Tianshou provides a fast, modularized framework and a pythonic API for building deep reinforcement learning agents with the least number of lines of code. The supported interface algorithms currently include:

examples/vizdoom/.gitignore (new file, 1 line, vendored)

@ -0,0 +1 @@
_vizdoom.ini

examples/vizdoom/README.md (new file, 66 lines)

@ -0,0 +1,66 @@
# ViZDoom

[ViZDoom](https://github.com/mwydmuch/ViZDoom) is a popular RL environment built on the classic first-person shooter Doom. Here we provide some results and intuitions for this scenario.

## Train

To train an agent:

```bash
python3 vizdoom_c51.py --task {D1_basic|D3_battle|D4_battle2}
```

D1 (health gathering) should finish training (no deaths) in less than 500k env steps (5 epochs);

D3 can reach 1600+ reward (75+ killcount in 5 minutes);

D4 can reach 700+ reward. Here is the result:

(Episode length; the maximum is 2625 because we use frameskip=4, i.e. 10500 / 4 = 2625.)

![](results/c51/length.png)

(Episode reward)

![](results/c51/reward.png)

To evaluate an agent's performance:

```bash
python3 vizdoom_c51.py --test-num 100 --resume-path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2}
```

To save `.lmp` files for recording:

```bash
python3 vizdoom_c51.py --save-lmp --test-num 100 --resume-path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2}
```

This stores the `.lmp` files in the `lmps/` directory. To watch these `.lmp` files (for example, a D3 recording):

```bash
python3 replay.py maps/D3_battle.cfg episode_8_25.lmp
```

We provide two `.lmp` files (the best D3 and D4 episodes) under `results/c51`; you can replay them with the following commands:

```bash
python3 replay.py maps/D3_battle.cfg results/c51/d3.lmp
python3 replay.py maps/D4_battle2.cfg results/c51/d4.lmp
```

## Maps

See [maps/README.md](maps/README.md)

## Algorithms

The setting is exactly the same as in the Atari examples, so you can also try the other algorithms listed there; a minimal sketch of swapping in DQN is given below.
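
For instance, a rough sketch of replacing C51 with DQN, assuming the same `args` and environment setup as in `vizdoom_c51.py` and the Atari `DQN` network re-exported through the `network.py` symlink (illustrative, not a tested configuration):

```python
import torch
from tianshou.policy import DQNPolicy
from network import DQN  # provided by the symlink to ../atari/atari_network.py

# drop-in replacement for the model/policy block inside test_c51();
# `args` is the namespace returned by get_args()
net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
policy = DQNPolicy(net, optim, args.gamma, args.n_step,
                   target_update_freq=args.target_update_freq)
# the replay buffer, collectors and offpolicy_trainer call stay unchanged
```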

## Reward

Some empirical observations (see the shaping sketch after this list):

1. A living reward hurts performance.
2. Combo-actions (pressing several buttons at once) are really important.
3. A negative reward for losing health and AMMO2 is really helpful for D3/D4.
4. Using only a positive reward for health gain is really helpful for D1.
5. Removing MOVE_BACKWARD may make training converge faster, but the final performance may be lower.
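
Concretely, the per-step shaping implemented in `env.py` (added in this commit) amounts roughly to the following sketch; the function and argument names are only illustrative:

```python
def shaped_reward(delta_health, delta_killcount, delta_ammo2, battle=True):
    # kills dominate the signal
    reward = 20.0 * delta_killcount
    # battle maps (D3/D4) use the full health delta, so taking damage is penalized;
    # the navigation maps (D1/D2) only reward health pickups
    reward += delta_health if battle else max(delta_health, 0.0)
    # ammo pickups are rewarded, wasted shots are penalized
    reward += delta_ammo2
    return reward
```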

examples/vizdoom/env.py (new file, 129 lines)

@ -0,0 +1,129 @@
import os
import cv2
import gym
import numpy as np
import vizdoom as vzd


def normal_button_comb():
    # all combinations of MOVE_FORWARD (on/off) with TURN_LEFT / TURN_RIGHT (at most one)
    actions = []
    m_forward = [[0.0], [1.0]]
    t_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
    for i in m_forward:
        for j in t_left_right:
            actions.append(i + j)
    return actions


def battle_button_comb():
    # all combinations of moving, strafing, turning, ATTACK and SPEED for the battle maps
    actions = []
    m_forward_backward = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
    m_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
    t_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
    attack = [[0.0], [1.0]]
    speed = [[0.0], [1.0]]

    for m in attack:
        for n in speed:
            for j in m_left_right:
                for i in m_forward_backward:
                    for k in t_left_right:
                        actions.append(i + j + k + m + n)
    return actions


class Env(gym.Env):
    def __init__(
        self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False
    ):
        super().__init__()
        self.save_lmp = save_lmp
        self.health_setting = "battle" in cfg_path
        if save_lmp:
            os.makedirs("lmps", exist_ok=True)
        self.res = res
        self.skip = frameskip
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=res, dtype=np.float32
        )
        self.game = vzd.DoomGame()
        self.game.load_config(cfg_path)
        self.game.init()
        if "battle" in cfg_path:
            self.available_actions = battle_button_comb()
        else:
            self.available_actions = normal_button_comb()
        self.action_num = len(self.available_actions)
        self.action_space = gym.spaces.Discrete(self.action_num)
        self.spec = gym.envs.registration.EnvSpec("vizdoom-v0")
        self.count = 0

    def get_obs(self):
        # shift the frame-stack buffer and append the latest (resized) frame
        state = self.game.get_state()
        if state is None:
            return
        obs = state.screen_buffer
        self.obs_buffer[:-1] = self.obs_buffer[1:]
        self.obs_buffer[-1] = cv2.resize(obs, (self.res[-1], self.res[-2]))

    def reset(self):
        if self.save_lmp:
            self.game.new_episode(f"lmps/episode_{self.count}.lmp")
        else:
            self.game.new_episode()
        self.count += 1
        self.obs_buffer = np.zeros(self.res, dtype=np.uint8)
        self.get_obs()
        self.health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
        self.killcount = self.game.get_game_variable(
            vzd.GameVariable.KILLCOUNT)
        self.ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2)
        return self.obs_buffer

    def step(self, action):
        # each chosen action is repeated `skip` frames
        self.game.make_action(self.available_actions[action], self.skip)
        reward = 0.0
        self.get_obs()
        health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
        if self.health_setting:
            reward += health - self.health
        elif health > self.health:  # positive health reward only for d1/d2
            reward += health - self.health
        self.health = health
        killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT)
        reward += 20 * (killcount - self.killcount)
        self.killcount = killcount
        ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2)
        # if ammo2 > self.ammo2:
        reward += ammo2 - self.ammo2
        self.ammo2 = ammo2
        done = False
        info = {}
        if self.game.is_player_dead() or self.game.get_state() is None:
            done = True
        elif self.game.is_episode_finished():
            done = True
            info["TimeLimit.truncated"] = True
        return self.obs_buffer, reward, done, info

    def render(self):
        pass

    def close(self):
        self.game.close()


if __name__ == '__main__':
    # env = Env("maps/D1_basic.cfg", 4, (4, 84, 84))
    env = Env("maps/D3_battle.cfg", 4, (4, 84, 84))
    print(env.available_actions)
    action_num = env.action_space.n
    obs = env.reset()
    print(env.spec.reward_threshold)
    print(obs.shape, action_num)
    for i in range(4000):
        obs, rew, done, info = env.step(0)
        if done:
            env.reset()
        print(obs.shape, rew, done)
    cv2.imwrite("test.png", obs.transpose(1, 2, 0)[..., :3])

examples/vizdoom/maps/D1_basic.cfg (new file, 39 lines)

@ -0,0 +1,39 @@
# Lines starting with # are treated as comments (or with whitespaces+#).
# It doesn't matter if you use capital letters or not.
# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.

doom_scenario_path = D1_basic.wad
doom_map = map01

# Rewards

# Each step is good for you!
living_reward = 0
# And death is not!
death_penalty = 0

# Rendering options
screen_resolution = RES_160X120
screen_format = GRAY8
render_hud = false
render_crosshair = false
render_weapon = false
render_decals = false
render_particles = false
window_visible = false

# make episodes finish after 10500 actions (tics)
episode_timeout = 10500

# Available buttons
available_buttons =
    {
        MOVE_FORWARD
        TURN_LEFT
        TURN_RIGHT
    }

# Game variables that will be in the state
available_game_variables = { HEALTH }

mode = PLAYER

examples/vizdoom/maps/D1_basic.wad (new binary file, not shown)

examples/vizdoom/maps/D2_navigation.cfg (new file, 39 lines)

@ -0,0 +1,39 @@
# Lines starting with # are treated as comments (or with whitespaces+#).
# It doesn't matter if you use capital letters or not.
# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.

doom_scenario_path = D2_navigation.wad
doom_map = map01

# Rewards

# Each step is good for you!
living_reward = 0
# And death is not!
death_penalty = 0

# Rendering options
screen_resolution = RES_160X120
screen_format = GRAY8
render_hud = false
render_crosshair = false
render_weapon = false
render_decals = false
render_particles = false
window_visible = false

# make episodes finish after 10500 actions (tics)
episode_timeout = 10500

# Available buttons
available_buttons =
    {
        MOVE_FORWARD
        TURN_LEFT
        TURN_RIGHT
    }

# Game variables that will be in the state
available_game_variables = { HEALTH }

mode = PLAYER

examples/vizdoom/maps/D2_navigation.wad (new binary file, not shown)

examples/vizdoom/maps/D3_battle.cfg (new file, 48 lines)

@ -0,0 +1,48 @@
# Lines starting with # are treated as comments (or with whitespaces+#).
# It doesn't matter if you use capital letters or not.
# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.

doom_scenario_path = D3_battle.wad
doom_map = map01

# Rewards

living_reward = 0
death_penalty = 100

# Rendering options
screen_resolution = RES_160X120
screen_format = GRAY8
render_hud = false
render_crosshair = true
render_weapon = true
render_decals = false
render_particles = false
window_visible = false

# make episodes finish after 10500 actions (tics)
episode_timeout = 10500

# Available buttons
available_buttons =
    {
        MOVE_FORWARD
        MOVE_BACKWARD
        MOVE_LEFT
        MOVE_RIGHT
        TURN_LEFT
        TURN_RIGHT
        ATTACK
        SPEED
    }

# Game variables that will be in the state
available_game_variables =
    {
        KILLCOUNT
        AMMO2
        HEALTH
    }

mode = PLAYER
doom_skill = 2

examples/vizdoom/maps/D3_battle.wad (new binary file, not shown)

examples/vizdoom/maps/D4_battle2.cfg (new file, 48 lines)

@ -0,0 +1,48 @@
# Lines starting with # are treated as comments (or with whitespaces+#).
# It doesn't matter if you use capital letters or not.
# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.

doom_scenario_path = D4_battle2.wad
doom_map = map01

# Rewards

living_reward = 0
death_penalty = 100

# Rendering options
screen_resolution = RES_160X120
screen_format = GRAY8
render_hud = false
render_crosshair = true
render_weapon = true
render_decals = false
render_particles = false
window_visible = false

# make episodes finish after 10500 actions (tics)
episode_timeout = 10500

# Available buttons
available_buttons =
    {
        MOVE_FORWARD
        MOVE_BACKWARD
        MOVE_LEFT
        MOVE_RIGHT
        TURN_LEFT
        TURN_RIGHT
        ATTACK
        SPEED
    }

# Game variables that will be in the state
available_game_variables =
    {
        KILLCOUNT
        AMMO2
        HEALTH
    }

mode = PLAYER
doom_skill = 2

examples/vizdoom/maps/D4_battle2.wad (new binary file, not shown)

examples/vizdoom/maps/README.md (new file, 3 lines)

@ -0,0 +1,3 @@
D1-D4 maps are from https://github.com/intel-isl/DirectFuturePrediction/

More maps and cfgs: https://github.com/mwydmuch/ViZDoom/tree/master/scenarios

examples/vizdoom/maps/spectator.py (new file, 71 lines)

@ -0,0 +1,71 @@
#!/usr/bin/env python3

#####################################################################
# This script presents SPECTATOR mode. In SPECTATOR mode you play and
# your agent can learn from it.
# Configuration is loaded from "../../scenarios/<SCENARIO_NAME>.cfg" file.
#
# To see the scenario description go to "../../scenarios/README.md"
#####################################################################

from __future__ import print_function

from time import sleep
import vizdoom as vzd
from argparse import ArgumentParser
# import cv2


if __name__ == "__main__":
    parser = ArgumentParser("ViZDoom example showing how to use SPECTATOR mode.")
    parser.add_argument('-c', type=str, dest="config", default="D3_battle.cfg")
    parser.add_argument('-w', type=str, dest="wad_file", default="D3_battle.wad")
    args = parser.parse_args()
    game = vzd.DoomGame()

    # Choose the scenario config file you wish to watch.
    # Don't load two configs because the second will overwrite the first one.
    # Multiple config files are ok but combining these ones doesn't make much sense.

    game.load_config(args.config)
    game.set_doom_scenario_path(args.wad_file)
    # Enables freelook in engine
    game.add_game_args("+freelook 1")

    game.set_screen_resolution(vzd.ScreenResolution.RES_640X480)

    # Enables spectator mode, so you can play.
    # Sounds strange, but it is the agent who is supposed to watch, not you.
    game.set_window_visible(True)
    game.set_mode(vzd.Mode.SPECTATOR)

    game.init()

    episodes = 1

    for i in range(episodes):
        print("Episode #" + str(i + 1))

        game.new_episode()
        while not game.is_episode_finished():
            state = game.get_state()
            print(state.screen_buffer.dtype, state.screen_buffer.shape)
            # cv2.imwrite(f'imgs/{state.number}.png', state.screen_buffer)

            # game.make_action([0, 0, 0])
            game.advance_action()
            last_action = game.get_last_action()
            reward = game.get_last_reward()

            print("State #" + str(state.number))
            print("Game variables: ", state.game_variables)
            print("Action:", last_action)
            print("Reward:", reward)
            print("=====================")

        print("Episode finished!")
        print("Total reward:", game.get_total_reward())
        print("************************")
        sleep(2.0)

    game.close()

examples/vizdoom/network.py (new symbolic link)

@ -0,0 +1 @@
../atari/atari_network.py

examples/vizdoom/replay.py (new executable file, 35 lines)

@ -0,0 +1,35 @@
# import cv2
import sys
import time
import tqdm
import vizdoom as vzd


def main(cfg_path="maps/D3_battle.cfg", lmp_path="test.lmp"):
    game = vzd.DoomGame()
    game.load_config(cfg_path)
    game.set_screen_format(vzd.ScreenFormat.CRCGCB)
    game.set_screen_resolution(vzd.ScreenResolution.RES_1024X576)
    game.set_window_visible(True)
    game.set_render_hud(True)
    game.init()
    game.replay_episode(lmp_path)

    killcount = 0
    with tqdm.trange(10500) as tq:
        while not game.is_episode_finished():
            game.advance_action()
            state = game.get_state()
            if state is None:
                break
            killcount = game.get_game_variable(vzd.GameVariable.KILLCOUNT)
            time.sleep(1 / 35)
            # cv2.imwrite(f"imgs/{tq.n}.png",
            #             state.screen_buffer.transpose(1, 2, 0)[..., ::-1])
            tq.update(1)
    game.close()
    print("killcount:", killcount)


if __name__ == '__main__':
    main(*sys.argv[-2:])

examples/vizdoom/results/c51/d3.lmp (new binary file, not shown)
examples/vizdoom/results/c51/d4.lmp (new binary file, not shown)
examples/vizdoom/results/c51/length.png (new binary file, 111 KiB, not shown)
examples/vizdoom/results/c51/reward.png (new binary file, 90 KiB, not shown)

examples/vizdoom/vizdoom_c51.py (new file, 177 lines)

@ -0,0 +1,177 @@
import os
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import C51Policy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer

from env import Env
from network import C51


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='D1_basic')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps-test', type=float, default=0.005)
    parser.add_argument('--eps-train', type=float, default=1.)
    parser.add_argument('--eps-train-final', type=float, default=0.05)
    parser.add_argument('--buffer-size', type=int, default=2000000)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--num-atoms', type=int, default=51)
    parser.add_argument('--v-min', type=float, default=-10.)
    parser.add_argument('--v-max', type=float, default=10.)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=500)
    parser.add_argument('--epoch', type=int, default=300)
    parser.add_argument('--step-per-epoch', type=int, default=100000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--frames-stack', type=int, default=4)
    parser.add_argument('--skip-num', type=int, default=4)
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument('--watch', default=False, action='store_true',
                        help='watch the play of pre-trained policy only')
    parser.add_argument('--save-lmp', default=False, action='store_true',
                        help='save lmp file for replay whole episode')
    parser.add_argument('--save-buffer-name', type=str, default=None)
    return parser.parse_args()


def test_c51(args=get_args()):
    args.cfg_path = f"maps/{args.task}.cfg"
    args.wad_path = f"maps/{args.task}.wad"
    args.res = (args.skip_num, 84, 84)
    env = Env(args.cfg_path, args.frames_stack, args.res)
    args.state_shape = args.res
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv([
        lambda: Env(args.cfg_path, args.frames_stack, args.res)
        for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv([
        lambda: Env(args.cfg_path, args.frames_stack,
                    args.res, args.save_lmp)
        for _ in range(min(os.cpu_count() - 1, args.test_num))])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = C51(*args.state_shape, args.action_shape,
              args.num_atoms, args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = C51Policy(
        net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
        args.n_step, target_update_freq=args.target_update_freq
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(
        args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
        save_only_last_obs=True, stack_num=args.frames_stack)
    # collector
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    log_path = os.path.join(args.logdir, args.task, 'c51')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = BasicLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        logger.write('train/eps', env_step, eps)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(
                args.buffer_size, buffer_num=len(test_envs),
                ignore_obs_next=True, save_only_last_obs=True,
                stack_num=args.frames_stack)
            collector = Collector(policy, test_envs, buffer,
                                  exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=args.test_num,
                                            render=args.render)
        rew = result["rews"].mean()
        lens = result["lens"].mean() * args.skip_num
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
        print(f'Mean length (over {result["n/ep"]} episodes): {lens}')

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.step_per_collect, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, logger=logger,
        update_per_step=args.update_per_step, test_in_train=False)

    pprint.pprint(result)
    watch()


if __name__ == '__main__':
    test_c51(get_args())

setup.py (2 changes)

@ -55,7 +55,7 @@ setup(
    ],
    extras_require={
        "dev": [
-           "Sphinx",
+           "sphinx<4",
            "sphinx_rtd_theme",
            "sphinxcontrib-bibtex",
            "flake8",

@ -22,7 +22,7 @@ def get_args():
    parser.add_argument('--step-per-epoch', type=int, default=1000)
    parser.add_argument('--episode-per-collect', type=int, default=1)
    parser.add_argument('--training-num', type=int, default=1)
-   parser.add_argument('--test-num', type=int, default=100)
+   parser.add_argument('--test-num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.0)
    parser.add_argument('--rew-mean-prior', type=float, default=0.0)

@ -36,12 +36,12 @@ def get_args():
def test_psrl(args=get_args()):
    env = gym.make(args.task)
    if args.task == "NChain-v0":
-       env.spec.reward_threshold = 3647  # described in PSRL paper
+       env.spec.reward_threshold = 3400
+       # env.spec.reward_threshold = 3647  # described in PSRL paper
    print("reward threshold:", env.spec.reward_threshold)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    # train_envs = gym.make(args.task)
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)

@ -1,7 +1,7 @@
from tianshou import data, env, utils, policy, trainer, exploration


-__version__ = "0.4.1"
+__version__ = "0.4.2"

__all__ = [
    "env",

@ -526,7 +526,12 @@ class Batch:
            elif all(isinstance(e, (Batch, dict)) for e in v):  # third often
                self.__dict__[k] = Batch.stack(v, axis)
            else:  # most often case is np.ndarray
-               self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis))
+               try:
+                   self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis))
+               except ValueError:
+                   warnings.warn("You are using tensors with different shape,"
+                                 " fallback to dtype=object by default.")
+                   self.__dict__[k] = np.array(v, dtype=object)
        # all the keys
        keys_total = set.union(*[set(b.keys()) for b in batches])
        # keys that are reserved in all batches
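
A minimal sketch of the behaviour this fallback enables (assuming tianshou's `Batch`; the shapes are arbitrary):

```python
import numpy as np
from tianshou.data import Batch

b1 = Batch(obs=np.zeros((3, 4)))
b2 = Batch(obs=np.zeros((3, 5)))   # mismatched trailing dimension
stacked = Batch.stack([b1, b2])    # warns, then falls back to dtype=object
print(stacked.obs.dtype)           # object
```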

@ -53,6 +53,7 @@ class ReplayBuffer:
        self._save_only_last_obs = save_only_last_obs
        self._sample_avail = sample_avail
        self._meta: Batch = Batch()
+       self._ep_rew: Union[float, np.ndarray]
        self.reset()

    def __len__(self) -> int:

@ -56,7 +56,8 @@ class Collector(object):
        exploration_noise: bool = False,
    ) -> None:
        super().__init__()
        if not isinstance(env, BaseVectorEnv):
+           if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
                warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
            env = DummyVectorEnv([lambda: env])
        self.env = env
        self.env_num = len(env)

@ -223,7 +224,8 @@ class Collector(object):
            # get bounded and remapped actions first (not saved into buffer)
            action_remap = self.policy.map_action(self.data.act)
            # step in env
-           obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids)
+           obs_next, rew, done, info = self.env.step(
+               action_remap, ready_env_ids)  # type: ignore

            self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
            if self.preprocess_fn:

@ -426,7 +428,8 @@ class AsyncCollector(Collector):
            # get bounded and remapped actions first (not saved into buffer)
            action_remap = self.policy.map_action(self.data.act)
            # step in env
-           obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids)
+           obs_next, rew, done, info = self.env.step(
+               action_remap, ready_env_ids)  # type: ignore

            # change self.data here because ready_env_ids has changed
            ready_env_ids = np.array([i["env_id"] for i in info])

@ -75,8 +75,8 @@ class DiscreteCRRPolicy(PGPolicy):
        self._min_q_weight = min_q_weight

    def sync_weight(self) -> None:
-       self.actor_old.load_state_dict(self.actor.state_dict())  # type: ignore
-       self.critic_old.load_state_dict(self.critic.state_dict())  # type: ignore
+       self.actor_old.load_state_dict(self.actor.state_dict())
+       self.critic_old.load_state_dict(self.critic.state_dict())

    def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:  # type: ignore
        if self._target and self._iter % self._freq == 0:

@ -72,7 +72,7 @@ class DQNPolicy(BasePolicy):

    def sync_weight(self) -> None:
        """Synchronize the weight for the target network."""
-       self.model_old.load_state_dict(self.model.state_dict())  # type: ignore
+       self.model_old.load_state_dict(self.model.state_dict())

    def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
        batch = buffer[indice]  # batch.obs_next: s_{t+n}

@ -64,6 +64,7 @@ class TRPOPolicy(NPGPolicy):
        self._max_backtracks = max_backtracks
        self._delta = max_kl
        self._backtrack_coeff = backtrack_coeff
+       self._optim_critic_iters: int

    def learn(  # type: ignore
        self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any

@ -88,7 +88,7 @@ class BasicLogger(BaseLogger):
    You can also rewrite write() func to use your own writer.

    :param SummaryWriter writer: the writer to log data.
-   :param int train_interval: the log interval in log_train_data(). Default to 1.
+   :param int train_interval: the log interval in log_train_data(). Default to 1000.
    :param int test_interval: the log interval in log_test_data(). Default to 1.
    :param int update_interval: the log interval in log_update_data(). Default to 1000.
    :param int save_interval: the save interval in save_data(). Default to 1 (save at

@ -98,7 +98,7 @@ class BasicLogger(BaseLogger):
    def __init__(
        self,
        writer: SummaryWriter,
-       train_interval: int = 1,
+       train_interval: int = 1000,
        test_interval: int = 1,
        update_interval: int = 1000,
        save_interval: int = 1,
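
A small usage sketch of the new default (the log directory here is hypothetical; pass `train_interval=1` to restore the previous per-step behaviour):

```python
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger

writer = SummaryWriter("log/D3_battle/c51")
logger = BasicLogger(writer)                            # train data now written every 1000 env steps
verbose_logger = BasicLogger(writer, train_interval=1)  # old per-step behaviour
```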