add vizdoom example, bump version to 0.4.2 (#384)
parent c0bc8e00ca
commit ebaca6f8da
2  .github/workflows/pytest.yml  (vendored)

@@ -24,7 +24,7 @@ jobs:
       - name: Test with pytest
         # ignore test/throughput which only profiles the code
         run: |
-          pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --durations=0 -v
+          pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v1
         with:
@@ -6,7 +6,7 @@

 [](https://pypi.org/project/tianshou/)
 [](https://github.com/conda-forge/tianshou-feedstock)
-[](https://tianshou.readthedocs.io/en/latest)
+[](https://tianshou.readthedocs.io/en/master)
 [](https://tianshou.readthedocs.io/zh/latest/)
 [](https://github.com/thu-ml/tianshou/actions)
 [](https://codecov.io/gh/thu-ml/tianshou)
@@ -14,7 +14,6 @@
 [](https://github.com/thu-ml/tianshou/stargazers)
 [](https://github.com/thu-ml/tianshou/network)
 [](https://github.com/thu-ml/tianshou/blob/master/LICENSE)
-[](https://gitter.im/thu-ml/tianshou?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)

 **Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed modularized framework and pythonic API for building the deep reinforcement learning agent with the least number of lines of code. The supported interface algorithms currently include:

1  examples/vizdoom/.gitignore  (vendored, new file)

@@ -0,0 +1 @@
_vizdoom.ini

66  examples/vizdoom/README.md  (new file)

@@ -0,0 +1,66 @@
# ViZDoom

[ViZDoom](https://github.com/mwydmuch/ViZDoom) is a popular RL environment built on the classic first-person shooter Doom. Here we provide some results and intuition for this scenario.

## Train

To train an agent:

```bash
python3 vizdoom_c51.py --task {D1_basic|D3_battle|D4_battle2}
```

D1 (health gathering) should finish training (no death) in fewer than 500k env steps (5 epochs);

D3 can reach 1600+ reward (75+ killcount in 5 minutes);

D4 can reach 700+ reward. Here is the result:

(episode length; the maximum is 2625 because we use frameskip=4, i.e. 10500/4 = 2625)



(episode reward)



To evaluate an agent's performance:

```bash
python3 vizdoom_c51.py --test-num 100 --resume-path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2}
```

To save `.lmp` files for recording:

```bash
python3 vizdoom_c51.py --save-lmp --test-num 100 --resume-path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2}
```

This stores the `.lmp` files in the `lmps/` directory. To watch these `.lmp` files (for example, a D3 lmp):

```bash
python3 replay.py maps/D3_battle.cfg episode_8_25.lmp
```

We provide two `.lmp` files (the best D3 and D4 episodes) under `results/c51`; you can replay them with:

```bash
python3 replay.py maps/D3_battle.cfg results/c51/d3.lmp
python3 replay.py maps/D4_battle2.cfg results/c51/d4.lmp
```

## Maps

See [maps/README.md](maps/README.md)

## Algorithms

The setting is exactly the same as in the Atari examples, so you can also try the other algorithms listed there.

## Reward

1. A living reward hurts performance.
2. Combo-actions (pressing several buttons at once) are really important.
3. A negative reward for losing health and ammo (AMMO2) is really helpful for D3/D4 (see the sketch below).
4. Rewarding only health gains (with no penalty for losses) is really helpful for D1.
5. Removing MOVE_BACKWARD may make training converge faster, but the final performance may be lower.
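
Points 3 and 4 are implemented in `env.py` as delta-based shaping on top of the kill reward. Below is a condensed sketch (the standalone helper and its name are ours for illustration only; the actual logic lives in `Env.step`):

```python
def shaped_reward(killcount, last_killcount, health, last_health,
                  ammo2, last_ammo2, is_battle_map):
    """Delta-based shaping distilled from Env.step() in env.py."""
    reward = 20.0 * (killcount - last_killcount)  # kills dominate the signal
    if is_battle_map:             # D3/D4: losing health is penalized
        reward += health - last_health
    elif health > last_health:    # D1/D2: only health gains are rewarded
        reward += health - last_health
    reward += ammo2 - last_ammo2  # firing spends AMMO2, so wasted shots are penalized
    return reward

# e.g. one kill while losing 10 health and 2 ammo on a battle map: 20 - 10 - 2 = 8.0
print(shaped_reward(3, 2, 90, 100, 48, 50, True))
```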

129  examples/vizdoom/env.py  (new file)

@@ -0,0 +1,129 @@
import os
import cv2
import gym
import numpy as np
import vizdoom as vzd


def normal_button_comb():
    actions = []
    m_forward = [[0.0], [1.0]]
    t_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
    for i in m_forward:
        for j in t_left_right:
            actions.append(i + j)
    return actions


def battle_button_comb():
    actions = []
    m_forward_backward = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
    m_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
    t_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
    attack = [[0.0], [1.0]]
    speed = [[0.0], [1.0]]

    for m in attack:
        for n in speed:
            for j in m_left_right:
                for i in m_forward_backward:
                    for k in t_left_right:
                        actions.append(i + j + k + m + n)
    return actions


class Env(gym.Env):
    def __init__(
        self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False
    ):
        super().__init__()
        self.save_lmp = save_lmp
        self.health_setting = "battle" in cfg_path
        if save_lmp:
            os.makedirs("lmps", exist_ok=True)
        self.res = res
        self.skip = frameskip
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=res, dtype=np.float32
        )
        self.game = vzd.DoomGame()
        self.game.load_config(cfg_path)
        self.game.init()
        if "battle" in cfg_path:
            self.available_actions = battle_button_comb()
        else:
            self.available_actions = normal_button_comb()
        self.action_num = len(self.available_actions)
        self.action_space = gym.spaces.Discrete(self.action_num)
        self.spec = gym.envs.registration.EnvSpec("vizdoom-v0")
        self.count = 0

    def get_obs(self):
        state = self.game.get_state()
        if state is None:
            return
        obs = state.screen_buffer
        self.obs_buffer[:-1] = self.obs_buffer[1:]
        self.obs_buffer[-1] = cv2.resize(obs, (self.res[-1], self.res[-2]))

    def reset(self):
        if self.save_lmp:
            self.game.new_episode(f"lmps/episode_{self.count}.lmp")
        else:
            self.game.new_episode()
        self.count += 1
        self.obs_buffer = np.zeros(self.res, dtype=np.uint8)
        self.get_obs()
        self.health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
        self.killcount = self.game.get_game_variable(
            vzd.GameVariable.KILLCOUNT)
        self.ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2)
        return self.obs_buffer

    def step(self, action):
        self.game.make_action(self.available_actions[action], self.skip)
        reward = 0.0
        self.get_obs()
        health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
        if self.health_setting:
            reward += health - self.health
        elif health > self.health:  # positive health reward only for d1/d2
            reward += health - self.health
        self.health = health
        killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT)
        reward += 20 * (killcount - self.killcount)
        self.killcount = killcount
        ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2)
        # if ammo2 > self.ammo2:
        reward += ammo2 - self.ammo2
        self.ammo2 = ammo2
        done = False
        info = {}
        if self.game.is_player_dead() or self.game.get_state() is None:
            done = True
        elif self.game.is_episode_finished():
            done = True
            info["TimeLimit.truncated"] = True
        return self.obs_buffer, reward, done, info

    def render(self):
        pass

    def close(self):
        self.game.close()


if __name__ == '__main__':
    # env = Env("maps/D1_basic.cfg", 4, (4, 84, 84))
    env = Env("maps/D3_battle.cfg", 4, (4, 84, 84))
    print(env.available_actions)
    action_num = env.action_space.n
    obs = env.reset()
    print(env.spec.reward_threshold)
    print(obs.shape, action_num)
    for i in range(4000):
        obs, rew, done, info = env.step(0)
        if done:
            env.reset()
    print(obs.shape, rew, done)
    cv2.imwrite("test.png", obs.transpose(1, 2, 0)[..., :3])

39  examples/vizdoom/maps/D1_basic.cfg  (new file)

@@ -0,0 +1,39 @@
# Lines starting with # are treated as comments (or with whitespaces+#).
# It doesn't matter if you use capital letters or not.
# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.

doom_scenario_path = D1_basic.wad
doom_map = map01

# Rewards

# Each step is good for you!
living_reward = 0
# And death is not!
death_penalty = 0

# Rendering options
screen_resolution = RES_160X120
screen_format = GRAY8
render_hud = false
render_crosshair = false
render_weapon = false
render_decals = false
render_particles = false
window_visible = false

# make episodes finish after 10500 actions (tics)
episode_timeout = 10500

# Available buttons
available_buttons =
    {
        MOVE_FORWARD
        TURN_LEFT
        TURN_RIGHT
    }

# Game variables that will be in the state
available_game_variables = { HEALTH }

mode = PLAYER

BIN  examples/vizdoom/maps/D1_basic.wad  (new file, binary file not shown)

39  examples/vizdoom/maps/D2_navigation.cfg  (new file)

@@ -0,0 +1,39 @@
# Lines starting with # are treated as comments (or with whitespaces+#).
# It doesn't matter if you use capital letters or not.
# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.

doom_scenario_path = D2_navigation.wad
doom_map = map01

# Rewards

# Each step is good for you!
living_reward = 0
# And death is not!
death_penalty = 0

# Rendering options
screen_resolution = RES_160X120
screen_format = GRAY8
render_hud = false
render_crosshair = false
render_weapon = false
render_decals = false
render_particles = false
window_visible = false

# make episodes finish after 10500 actions (tics)
episode_timeout = 10500

# Available buttons
available_buttons =
    {
        MOVE_FORWARD
        TURN_LEFT
        TURN_RIGHT
    }

# Game variables that will be in the state
available_game_variables = { HEALTH }

mode = PLAYER

BIN  examples/vizdoom/maps/D2_navigation.wad  (new file, binary file not shown)

48  examples/vizdoom/maps/D3_battle.cfg  (new file)

@@ -0,0 +1,48 @@
# Lines starting with # are treated as comments (or with whitespaces+#).
# It doesn't matter if you use capital letters or not.
# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.

doom_scenario_path = D3_battle.wad
doom_map = map01

# Rewards

living_reward = 0
death_penalty = 100

# Rendering options
screen_resolution = RES_160X120
screen_format = GRAY8
render_hud = false
render_crosshair = true
render_weapon = true
render_decals = false
render_particles = false
window_visible = false

# make episodes finish after 10500 actions (tics)
episode_timeout = 10500

# Available buttons
available_buttons =
    {
        MOVE_FORWARD
        MOVE_BACKWARD
        MOVE_LEFT
        MOVE_RIGHT
        TURN_LEFT
        TURN_RIGHT
        ATTACK
        SPEED
    }

# Game variables that will be in the state
available_game_variables =
    {
        KILLCOUNT
        AMMO2
        HEALTH
    }

mode = PLAYER
doom_skill = 2

BIN  examples/vizdoom/maps/D3_battle.wad  (new file, binary file not shown)

48  examples/vizdoom/maps/D4_battle2.cfg  (new file)

@@ -0,0 +1,48 @@
# Lines starting with # are treated as comments (or with whitespaces+#).
# It doesn't matter if you use capital letters or not.
# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.

doom_scenario_path = D4_battle2.wad
doom_map = map01

# Rewards

living_reward = 0
death_penalty = 100

# Rendering options
screen_resolution = RES_160X120
screen_format = GRAY8
render_hud = false
render_crosshair = true
render_weapon = true
render_decals = false
render_particles = false
window_visible = false

# make episodes finish after 10500 actions (tics)
episode_timeout = 10500

# Available buttons
available_buttons =
    {
        MOVE_FORWARD
        MOVE_BACKWARD
        MOVE_LEFT
        MOVE_RIGHT
        TURN_LEFT
        TURN_RIGHT
        ATTACK
        SPEED
    }

# Game variables that will be in the state
available_game_variables =
    {
        KILLCOUNT
        AMMO2
        HEALTH
    }

mode = PLAYER
doom_skill = 2

BIN  examples/vizdoom/maps/D4_battle2.wad  (new file, binary file not shown)

3  examples/vizdoom/maps/README.md  (new file)

@@ -0,0 +1,3 @@
The D1-D4 maps are from https://github.com/intel-isl/DirectFuturePrediction/

More maps and cfg files: https://github.com/mwydmuch/ViZDoom/tree/master/scenarios

71  examples/vizdoom/maps/spectator.py  (new file)

@@ -0,0 +1,71 @@
#!/usr/bin/env python3

#####################################################################
# This script presents SPECTATOR mode. In SPECTATOR mode you play and
# your agent can learn from it.
# Configuration is loaded from "../../scenarios/<SCENARIO_NAME>.cfg" file.
#
# To see the scenario description go to "../../scenarios/README.md"
#####################################################################

from __future__ import print_function

from time import sleep
import vizdoom as vzd
from argparse import ArgumentParser
# import cv2


if __name__ == "__main__":
    parser = ArgumentParser("ViZDoom example showing how to use SPECTATOR mode.")
    parser.add_argument('-c', type=str, dest="config", default="D3_battle.cfg")
    parser.add_argument('-w', type=str, dest="wad_file", default="D3_battle.wad")
    args = parser.parse_args()
    game = vzd.DoomGame()

    # Choose the scenario config file you wish to watch.
    # Don't load two configs because the second will overwrite the first one.
    # Multiple config files are ok but combining these ones doesn't make much sense.

    game.load_config(args.config)
    game.set_doom_scenario_path(args.wad_file)
    # Enables freelook in engine
    game.add_game_args("+freelook 1")

    game.set_screen_resolution(vzd.ScreenResolution.RES_640X480)

    # Enables spectator mode, so you can play.
    # Sounds strange but it is the agent who is supposed to watch, not you.
    game.set_window_visible(True)
    game.set_mode(vzd.Mode.SPECTATOR)

    game.init()

    episodes = 1

    for i in range(episodes):
        print("Episode #" + str(i + 1))

        game.new_episode()
        while not game.is_episode_finished():
            state = game.get_state()
            print(state.screen_buffer.dtype, state.screen_buffer.shape)
            # cv2.imwrite(f'imgs/{state.number}.png', state.screen_buffer)

            # game.make_action([0, 0, 0])
            game.advance_action()
            last_action = game.get_last_action()
            reward = game.get_last_reward()

            print("State #" + str(state.number))
            print("Game variables: ", state.game_variables)
            print("Action:", last_action)
            print("Reward:", reward)
            print("=====================")

        print("Episode finished!")
        print("Total reward:", game.get_total_reward())
        print("************************")
        sleep(2.0)

    game.close()

1  examples/vizdoom/network.py  (symbolic link, new file)

@@ -0,0 +1 @@
../atari/atari_network.py

35  examples/vizdoom/replay.py  (new executable file)

@@ -0,0 +1,35 @@
# import cv2
import sys
import time
import tqdm
import vizdoom as vzd


def main(cfg_path="maps/D3_battle.cfg", lmp_path="test.lmp"):
    game = vzd.DoomGame()
    game.load_config(cfg_path)
    game.set_screen_format(vzd.ScreenFormat.CRCGCB)
    game.set_screen_resolution(vzd.ScreenResolution.RES_1024X576)
    game.set_window_visible(True)
    game.set_render_hud(True)
    game.init()
    game.replay_episode(lmp_path)

    killcount = 0
    with tqdm.trange(10500) as tq:
        while not game.is_episode_finished():
            game.advance_action()
            state = game.get_state()
            if state is None:
                break
            killcount = game.get_game_variable(vzd.GameVariable.KILLCOUNT)
            time.sleep(1 / 35)
            # cv2.imwrite(f"imgs/{tq.n}.png",
            #             state.screen_buffer.transpose(1, 2, 0)[..., ::-1])
            tq.update(1)
    game.close()
    print("killcount:", killcount)


if __name__ == '__main__':
    main(*sys.argv[-2:])

BIN  examples/vizdoom/results/c51/d3.lmp  (new file, binary file not shown)
BIN  examples/vizdoom/results/c51/d4.lmp  (new file, binary file not shown)
BIN  examples/vizdoom/results/c51/length.png  (new file, 111 KiB, binary file not shown)
BIN  examples/vizdoom/results/c51/reward.png  (new file, 90 KiB, binary file not shown)

177  examples/vizdoom/vizdoom_c51.py  (new file)

@@ -0,0 +1,177 @@
import os
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import C51Policy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer

from env import Env
from network import C51


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='D1_basic')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps-test', type=float, default=0.005)
    parser.add_argument('--eps-train', type=float, default=1.)
    parser.add_argument('--eps-train-final', type=float, default=0.05)
    parser.add_argument('--buffer-size', type=int, default=2000000)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--num-atoms', type=int, default=51)
    parser.add_argument('--v-min', type=float, default=-10.)
    parser.add_argument('--v-max', type=float, default=10.)
    parser.add_argument('--n-step', type=int, default=3)
    parser.add_argument('--target-update-freq', type=int, default=500)
    parser.add_argument('--epoch', type=int, default=300)
    parser.add_argument('--step-per-epoch', type=int, default=100000)
    parser.add_argument('--step-per-collect', type=int, default=10)
    parser.add_argument('--update-per-step', type=float, default=0.1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--training-num', type=int, default=10)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--frames-stack', type=int, default=4)
    parser.add_argument('--skip-num', type=int, default=4)
    parser.add_argument('--resume-path', type=str, default=None)
    parser.add_argument('--watch', default=False, action='store_true',
                        help='watch the play of pre-trained policy only')
    parser.add_argument('--save-lmp', default=False, action='store_true',
                        help='save lmp file for replay whole episode')
    parser.add_argument('--save-buffer-name', type=str, default=None)
    return parser.parse_args()


def test_c51(args=get_args()):
    args.cfg_path = f"maps/{args.task}.cfg"
    args.wad_path = f"maps/{args.task}.wad"
    args.res = (args.skip_num, 84, 84)
    env = Env(args.cfg_path, args.frames_stack, args.res)
    args.state_shape = args.res
    args.action_shape = env.action_space.shape or env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv([
        lambda: Env(args.cfg_path, args.frames_stack, args.res)
        for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv([
        lambda: Env(args.cfg_path, args.frames_stack,
                    args.res, args.save_lmp)
        for _ in range(min(os.cpu_count() - 1, args.test_num))])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
    net = C51(*args.state_shape, args.action_shape,
              args.num_atoms, args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = C51Policy(
        net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
        args.n_step, target_update_freq=args.target_update_freq
    ).to(args.device)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_last_obs` and `stack_num` can be removed together
    # when you have enough RAM
    buffer = VectorReplayBuffer(
        args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
        save_only_last_obs=True, stack_num=args.frames_stack)
    # collector
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
    # log
    log_path = os.path.join(args.logdir, args.task, 'c51')
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger = BasicLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        if env.spec.reward_threshold:
            return mean_rewards >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return mean_rewards >= 20
        else:
            return False

    def train_fn(epoch, env_step):
        # nature DQN setting, linear decay in the first 1M steps
        if env_step <= 1e6:
            eps = args.eps_train - env_step / 1e6 * \
                (args.eps_train - args.eps_train_final)
        else:
            eps = args.eps_train_final
        policy.set_eps(eps)
        logger.write('train/eps', env_step, eps)

    def test_fn(epoch, env_step):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Setup test envs ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        if args.save_buffer_name:
            print(f"Generate buffer with size {args.buffer_size}")
            buffer = VectorReplayBuffer(
                args.buffer_size, buffer_num=len(test_envs),
                ignore_obs_next=True, save_only_last_obs=True,
                stack_num=args.frames_stack)
            collector = Collector(policy, test_envs, buffer,
                                  exploration_noise=True)
            result = collector.collect(n_step=args.buffer_size)
            print(f"Save buffer into {args.save_buffer_name}")
            # Unfortunately, pickle will cause oom with 1M buffer size
            buffer.save_hdf5(args.save_buffer_name)
        else:
            print("Testing agent ...")
            test_collector.reset()
            result = test_collector.collect(n_episode=args.test_num,
                                            render=args.render)
        rew = result["rews"].mean()
        lens = result["lens"].mean() * args.skip_num
        print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
        print(f'Mean length (over {result["n/ep"]} episodes): {lens}')

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * args.training_num)
    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.step_per_collect, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, logger=logger,
        update_per_step=args.update_per_step, test_in_train=False)

    pprint.pprint(result)
    watch()


if __name__ == '__main__':
    test_c51(get_args())

2  setup.py

@@ -55,7 +55,7 @@ setup(
     ],
     extras_require={
         "dev": [
-            "Sphinx",
+            "sphinx<4",
             "sphinx_rtd_theme",
             "sphinxcontrib-bibtex",
             "flake8",

@@ -22,7 +22,7 @@ def get_args():
     parser.add_argument('--step-per-epoch', type=int, default=1000)
     parser.add_argument('--episode-per-collect', type=int, default=1)
     parser.add_argument('--training-num', type=int, default=1)
-    parser.add_argument('--test-num', type=int, default=100)
+    parser.add_argument('--test-num', type=int, default=10)
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.0)
     parser.add_argument('--rew-mean-prior', type=float, default=0.0)
@@ -36,12 +36,12 @@ def get_args():
 def test_psrl(args=get_args()):
     env = gym.make(args.task)
     if args.task == "NChain-v0":
-        env.spec.reward_threshold = 3647  # described in PSRL paper
+        env.spec.reward_threshold = 3400
+        # env.spec.reward_threshold = 3647  # described in PSRL paper
     print("reward threshold:", env.spec.reward_threshold)
     args.state_shape = env.observation_space.shape or env.observation_space.n
     args.action_shape = env.action_space.shape or env.action_space.n
     # train_envs = gym.make(args.task)
-    # train_envs = gym.make(args.task)
     train_envs = DummyVectorEnv(
         [lambda: gym.make(args.task) for _ in range(args.training_num)])
     # test_envs = gym.make(args.task)

@@ -1,7 +1,7 @@
 from tianshou import data, env, utils, policy, trainer, exploration


-__version__ = "0.4.1"
+__version__ = "0.4.2"

 __all__ = [
     "env",

@@ -526,7 +526,12 @@
             elif all(isinstance(e, (Batch, dict)) for e in v):  # third often
                 self.__dict__[k] = Batch.stack(v, axis)
             else:  # most often case is np.ndarray
-                self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis))
+                try:
+                    self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis))
+                except ValueError:
+                    warnings.warn("You are using tensors with different shape,"
+                                  " fallback to dtype=object by default.")
+                    self.__dict__[k] = np.array(v, dtype=object)
         # all the keys
         keys_total = set.union(*[set(b.keys()) for b in batches])
         # keys that are reserved in all batches
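
A minimal sketch of the new fallback (variable names are illustrative): stacking batches whose arrays have mismatched shapes now emits a warning and degrades to a `dtype=object` array instead of raising.

```python
import numpy as np
from tianshou.data import Batch

b1 = Batch(obs=np.zeros((3, 4)))
b2 = Batch(obs=np.zeros((5, 4)))  # mismatched first dimension
stacked = Batch.stack([b1, b2])   # warns, falls back to dtype=object
print(stacked.obs.dtype)          # object
```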

@@ -53,6 +53,7 @@ class ReplayBuffer:
         self._save_only_last_obs = save_only_last_obs
         self._sample_avail = sample_avail
         self._meta: Batch = Batch()
+        self._ep_rew: Union[float, np.ndarray]
         self.reset()

     def __len__(self) -> int:

@@ -56,7 +56,8 @@ class Collector(object):
         exploration_noise: bool = False,
     ) -> None:
         super().__init__()
-        if not isinstance(env, BaseVectorEnv):
+        if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
+            warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
             env = DummyVectorEnv([lambda: env])
         self.env = env
         self.env_num = len(env)
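
For illustration only, the same check outside the `Collector` (the environment id is just an example): a bare `gym.Env` now triggers an explicit warning in `Collector.__init__` before being wrapped.

```python
import gym
from tianshou.env import DummyVectorEnv

single_env = gym.make("CartPole-v0")
# the check Collector.__init__ now performs before wrapping:
if isinstance(single_env, gym.Env) and not hasattr(single_env, "__len__"):
    venv = DummyVectorEnv([lambda: single_env])
print(len(venv))  # 1
```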

@@ -223,7 +224,8 @@ class Collector(object):
             # get bounded and remapped actions first (not saved into buffer)
             action_remap = self.policy.map_action(self.data.act)
             # step in env
-            obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids)
+            obs_next, rew, done, info = self.env.step(
+                action_remap, ready_env_ids)  # type: ignore

             self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
             if self.preprocess_fn:
@@ -426,7 +428,8 @@ class AsyncCollector(Collector):
             # get bounded and remapped actions first (not saved into buffer)
             action_remap = self.policy.map_action(self.data.act)
             # step in env
-            obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids)
+            obs_next, rew, done, info = self.env.step(
+                action_remap, ready_env_ids)  # type: ignore

             # change self.data here because ready_env_ids has changed
             ready_env_ids = np.array([i["env_id"] for i in info])

@@ -75,8 +75,8 @@ class DiscreteCRRPolicy(PGPolicy):
         self._min_q_weight = min_q_weight

     def sync_weight(self) -> None:
-        self.actor_old.load_state_dict(self.actor.state_dict())  # type: ignore
-        self.critic_old.load_state_dict(self.critic.state_dict())  # type: ignore
+        self.actor_old.load_state_dict(self.actor.state_dict())
+        self.critic_old.load_state_dict(self.critic.state_dict())

     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:  # type: ignore
         if self._target and self._iter % self._freq == 0:

@@ -72,7 +72,7 @@ class DQNPolicy(BasePolicy):

     def sync_weight(self) -> None:
         """Synchronize the weight for the target network."""
-        self.model_old.load_state_dict(self.model.state_dict())  # type: ignore
+        self.model_old.load_state_dict(self.model.state_dict())

     def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
         batch = buffer[indice]  # batch.obs_next: s_{t+n}

@@ -64,6 +64,7 @@ class TRPOPolicy(NPGPolicy):
         self._max_backtracks = max_backtracks
         self._delta = max_kl
         self._backtrack_coeff = backtrack_coeff
+        self._optim_critic_iters: int

     def learn(  # type: ignore
         self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any

@@ -88,7 +88,7 @@ class BasicLogger(BaseLogger):
     You can also rewrite write() func to use your own writer.

     :param SummaryWriter writer: the writer to log data.
-    :param int train_interval: the log interval in log_train_data(). Default to 1.
+    :param int train_interval: the log interval in log_train_data(). Default to 1000.
     :param int test_interval: the log interval in log_test_data(). Default to 1.
     :param int update_interval: the log interval in log_update_data(). Default to 1000.
     :param int save_interval: the save interval in save_data(). Default to 1 (save at
@@ -98,7 +98,7 @@ class BasicLogger(BaseLogger):
     def __init__(
         self,
         writer: SummaryWriter,
-        train_interval: int = 1,
+        train_interval: int = 1000,
         test_interval: int = 1,
         update_interval: int = 1000,
         save_interval: int = 1,
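
A hedged usage sketch of the new default (the writer path is illustrative): training statistics are now written every 1000 environment steps unless `train_interval` is overridden.

```python
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger

writer = SummaryWriter("log/example")   # illustrative log directory
logger = BasicLogger(writer)            # new default: train_interval=1000
verbose_logger = BasicLogger(writer, train_interval=1)  # restores the old behaviour
```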