add vizdoom example, bump version to 0.4.2 (#384)

n+e 2021-06-26 18:08:41 +08:00 committed by GitHub
parent c0bc8e00ca
commit ebaca6f8da
32 changed files with 683 additions and 17 deletions

@@ -24,7 +24,7 @@ jobs:
- name: Test with pytest
# ignore test/throughput which only profiles the code
run: |
pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --durations=0 -v
pytest test --ignore-glob='*profile.py' --cov=tianshou --cov-report=xml --cov-report=term-missing --durations=0 -v
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v1
with:

@@ -6,7 +6,7 @@
[![PyPI](https://img.shields.io/pypi/v/tianshou)](https://pypi.org/project/tianshou/)
[![Conda](https://img.shields.io/conda/vn/conda-forge/tianshou)](https://github.com/conda-forge/tianshou-feedstock)
[![Read the Docs](https://img.shields.io/readthedocs/tianshou)](https://tianshou.readthedocs.io/en/latest)
[![Read the Docs](https://img.shields.io/readthedocs/tianshou)](https://tianshou.readthedocs.io/en/master)
[![Read the Docs](https://img.shields.io/readthedocs/tianshou-docs-zh-cn?label=%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3)](https://tianshou.readthedocs.io/zh/latest/)
[![Unittest](https://github.com/thu-ml/tianshou/workflows/Unittest/badge.svg?branch=master)](https://github.com/thu-ml/tianshou/actions)
[![codecov](https://img.shields.io/codecov/c/gh/thu-ml/tianshou)](https://codecov.io/gh/thu-ml/tianshou)
@@ -14,7 +14,6 @@
[![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers)
[![GitHub forks](https://img.shields.io/github/forks/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/network)
[![GitHub license](https://img.shields.io/github/license/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/blob/master/LICENSE)
[![Gitter](https://badges.gitter.im/thu-ml/tianshou.svg)](https://gitter.im/thu-ml/tianshou?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge)
**Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch. Unlike existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly APIs, or are slow, Tianshou provides a fast, modularized framework and a pythonic API for building deep reinforcement learning agents with the least amount of code. The supported interface algorithms currently include:

examples/vizdoom/.gitignore vendored Normal file

@@ -0,0 +1 @@
_vizdoom.ini

@@ -0,0 +1,66 @@
# ViZDoom
[ViZDoom](https://github.com/mwydmuch/ViZDoom) is a popular RL environment based on the famous first-person shooter Doom. Here we provide some results and intuitions for this scenario.
## Train
To train an agent:
```bash
python3 vizdoom_c51.py --task {D1_basic|D3_battle|D4_battle2}
```
D1 (health gathering) should finish training (no deaths) in less than 500k env steps (5 epochs);
D3 can reach a reward of 1600+ (a kill count of 75+ in 5 minutes);
D4 can reach a reward of 700+. Here are the results:
(episode length; the maximum is 2625 because we use frameskip=4, i.e. 10500 / 4 = 2625)
![](results/c51/length.png)
(episode reward)
![](results/c51/reward.png)
To evaluate an agent's performance:
```bash
python3 vizdoom_c51.py --test-num 100 --resume-path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2}
```
To save `.lmp` files for recording:
```bash
python3 vizdoom_c51.py --save-lmp --test-num 100 --resume-path policy.pth --watch --task {D1_basic|D3_battle|D4_battle2}
```
This will store `.lmp` files in the `lmps/` directory. To watch these `.lmp` files (for example, a D3 recording):
```bash
python3 replay.py maps/D3_battle.cfg episode_8_25.lmp
```
We provide two `.lmp` files (the best D3 and D4 episodes) under `results/c51`; you can replay them with the following commands:
```bash
python3 replay.py maps/D3_battle.cfg results/c51/d3.lmp
python3 replay.py maps/D4_battle2.cfg results/c51/d4.lmp
```
## Maps
See [maps/README.md](maps/README.md)
## Algorithms
The setting is exactly the same as in the Atari examples. You can also try the other algorithms listed there.
## Reward
1. A living reward hurts performance.
2. Combined (combo) actions are really important.
3. A negative reward for losing health and ammo2 is very helpful for D3/D4.
4. Rewarding only health gains (positive reward only) is very helpful for D1.
5. Removing MOVE_BACKWARD may make training converge faster, but the final performance may be lower.
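The shaping these notes refer to is implemented in `env.py` (shown below); here is a condensed sketch, with an illustrative helper name and arguments that are not part of the actual code:

```python
# condensed sketch of the reward shaping in env.py (names are illustrative)
def shaped_reward(delta_health, delta_kills, delta_ammo2, is_battle):
    reward = 0.0
    if is_battle or delta_health > 0:  # D1 only keeps the positive health part
        reward += delta_health
    reward += 20 * delta_kills         # 20 points per kill
    reward += delta_ammo2              # reward ammo pickups, penalize usage
    return reward
```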

examples/vizdoom/env.py Normal file

@@ -0,0 +1,129 @@
import os
import cv2
import gym
import numpy as np
import vizdoom as vzd
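# D1/D2 use three buttons (MOVE_FORWARD, TURN_LEFT, TURN_RIGHT):
# 2 forward options x 3 turn options = 6 combined actions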
def normal_button_comb():
actions = []
m_forward = [[0.0], [1.0]]
t_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
for i in m_forward:
for j in t_left_right:
actions.append(i + j)
return actions
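# D3/D4 use eight buttons; enumerate
# 3 (fwd/back) x 3 (strafe) x 3 (turn) x 2 (attack) x 2 (speed) = 108 combined actions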
def battle_button_comb():
actions = []
m_forward_backward = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
m_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
t_left_right = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]]
attack = [[0.0], [1.0]]
speed = [[0.0], [1.0]]
for m in attack:
for n in speed:
for j in m_left_right:
for i in m_forward_backward:
for k in t_left_right:
actions.append(i + j + k + m + n)
return actions
class Env(gym.Env):
def __init__(
self, cfg_path, frameskip=4, res=(4, 40, 60), save_lmp=False
):
super().__init__()
self.save_lmp = save_lmp
self.health_setting = "battle" in cfg_path
if save_lmp:
os.makedirs("lmps", exist_ok=True)
self.res = res
self.skip = frameskip
self.observation_space = gym.spaces.Box(
low=0, high=255, shape=res, dtype=np.float32
)
self.game = vzd.DoomGame()
self.game.load_config(cfg_path)
self.game.init()
if "battle" in cfg_path:
self.available_actions = battle_button_comb()
else:
self.available_actions = normal_button_comb()
self.action_num = len(self.available_actions)
self.action_space = gym.spaces.Discrete(self.action_num)
self.spec = gym.envs.registration.EnvSpec("vizdoom-v0")
self.count = 0
def get_obs(self):
state = self.game.get_state()
if state is None:
return
obs = state.screen_buffer
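# shift the frame stack left by one and append the newest frame, resized to (H, W)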
self.obs_buffer[:-1] = self.obs_buffer[1:]
self.obs_buffer[-1] = cv2.resize(obs, (self.res[-1], self.res[-2]))
def reset(self):
if self.save_lmp:
self.game.new_episode(f"lmps/episode_{self.count}.lmp")
else:
self.game.new_episode()
self.count += 1
self.obs_buffer = np.zeros(self.res, dtype=np.uint8)
self.get_obs()
self.health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
self.killcount = self.game.get_game_variable(
vzd.GameVariable.KILLCOUNT)
self.ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2)
return self.obs_buffer
def step(self, action):
self.game.make_action(self.available_actions[action], self.skip)
reward = 0.0
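# reward shaping: health delta (battle maps also keep the negative part),
# 20 points per kill, and ammo2 delta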
self.get_obs()
health = self.game.get_game_variable(vzd.GameVariable.HEALTH)
if self.health_setting:
reward += health - self.health
elif health > self.health: # positive health reward only for d1/d2
reward += health - self.health
self.health = health
killcount = self.game.get_game_variable(vzd.GameVariable.KILLCOUNT)
reward += 20 * (killcount - self.killcount)
self.killcount = killcount
ammo2 = self.game.get_game_variable(vzd.GameVariable.AMMO2)
# if ammo2 > self.ammo2:
reward += ammo2 - self.ammo2
self.ammo2 = ammo2
done = False
info = {}
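# dying ends the episode; hitting the 10500-tic timeout is reported as truncation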
if self.game.is_player_dead() or self.game.get_state() is None:
done = True
elif self.game.is_episode_finished():
done = True
info["TimeLimit.truncated"] = True
return self.obs_buffer, reward, done, info
def render(self):
pass
def close(self):
self.game.close()
if __name__ == '__main__':
# env = Env("maps/D1_basic.cfg", 4, (4, 84, 84))
env = Env("maps/D3_battle.cfg", 4, (4, 84, 84))
print(env.available_actions)
action_num = env.action_space.n
obs = env.reset()
print(env.spec.reward_threshold)
print(obs.shape, action_num)
for i in range(4000):
obs, rew, done, info = env.step(0)
if done:
env.reset()
print(obs.shape, rew, done)
cv2.imwrite("test.png", obs.transpose(1, 2, 0)[..., :3])

@@ -0,0 +1,39 @@
# Lines starting with # are treated as comments (or with whitespaces+#).
# It doesn't matter if you use capital letters or not.
# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.
doom_scenario_path = D1_basic.wad
doom_map = map01
# Rewards
# Each step is good for you!
living_reward = 0
# And death is not!
death_penalty = 0
# Rendering options
screen_resolution = RES_160X120
screen_format = GRAY8
render_hud = false
render_crosshair = false
render_weapon = false
render_decals = false
render_particles = false
window_visible = false
# make episodes finish after 10500 actions (tics)
episode_timeout = 10500
# Available buttons
available_buttons =
{
MOVE_FORWARD
TURN_LEFT
TURN_RIGHT
}
# Game variables that will be in the state
available_game_variables = { HEALTH }
mode = PLAYER

Binary file not shown.

@@ -0,0 +1,39 @@
# Lines starting with # are treated as comments (or with whitespaces+#).
# It doesn't matter if you use capital letters or not.
# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.
doom_scenario_path = D2_navigation.wad
doom_map = map01
# Rewards
# Each step is good for you!
living_reward = 0
# And death is not!
death_penalty = 0
# Rendering options
screen_resolution = RES_160X120
screen_format = GRAY8
render_hud = false
render_crosshair = false
render_weapon = false
render_decals = false
render_particles = false
window_visible = false
# make episodes finish after 10500 actions (tics)
episode_timeout = 10500
# Available buttons
available_buttons =
{
MOVE_FORWARD
TURN_LEFT
TURN_RIGHT
}
# Game variables that will be in the state
available_game_variables = { HEALTH }
mode = PLAYER

Binary file not shown.

@@ -0,0 +1,48 @@
# Lines starting with # are treated as comments (or with whitespaces+#).
# It doesn't matter if you use capital letters or not.
# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.
doom_scenario_path = D3_battle.wad
doom_map = map01
# Rewards
living_reward = 0
death_penalty = 100
# Rendering options
screen_resolution = RES_160X120
screen_format = GRAY8
render_hud = false
render_crosshair = true
render_weapon = true
render_decals = false
render_particles = false
window_visible = false
# make episodes finish after 10500 actions (tics)
episode_timeout = 10500
# Available buttons
available_buttons =
{
MOVE_FORWARD
MOVE_BACKWARD
MOVE_LEFT
MOVE_RIGHT
TURN_LEFT
TURN_RIGHT
ATTACK
SPEED
}
# Game variables that will be in the state
available_game_variables =
{
KILLCOUNT
AMMO2
HEALTH
}
mode = PLAYER
doom_skill = 2

Binary file not shown.

@@ -0,0 +1,48 @@
# Lines starting with # are treated as comments (or with whitespaces+#).
# It doesn't matter if you use capital letters or not.
# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout.
doom_scenario_path = D4_battle2.wad
doom_map = map01
# Rewards
living_reward = 0
death_penalty = 100
# Rendering options
screen_resolution = RES_160X120
screen_format = GRAY8
render_hud = false
render_crosshair = true
render_weapon = true
render_decals = false
render_particles = false
window_visible = false
# make episodes finish after 10500 actions (tics)
episode_timeout = 10500
# Available buttons
available_buttons =
{
MOVE_FORWARD
MOVE_BACKWARD
MOVE_LEFT
MOVE_RIGHT
TURN_LEFT
TURN_RIGHT
ATTACK
SPEED
}
# Game variables that will be in the state
available_game_variables =
{
KILLCOUNT
AMMO2
HEALTH
}
mode = PLAYER
doom_skill = 2

Binary file not shown.

@@ -0,0 +1,3 @@
D1-D4 maps are from https://github.com/intel-isl/DirectFuturePrediction/
More maps and cfgs: https://github.com/mwydmuch/ViZDoom/tree/master/scenarios

@@ -0,0 +1,71 @@
#!/usr/bin/env python3
#####################################################################
# This script presents SPECTATOR mode. In SPECTATOR mode you play and
# your agent can learn from it.
# Configuration is loaded from the file passed with -c (default: "D3_battle.cfg").
#
# Scenario descriptions: https://github.com/mwydmuch/ViZDoom/tree/master/scenarios
#####################################################################
from __future__ import print_function
from time import sleep
import vizdoom as vzd
from argparse import ArgumentParser
# import cv2
if __name__ == "__main__":
parser = ArgumentParser("ViZDoom example showing how to use SPECTATOR mode.")
parser.add_argument('-c', type=str, dest="config", default="D3_battle.cfg")
parser.add_argument('-w', type=str, dest="wad_file", default="D3_battle.wad")
args = parser.parse_args()
game = vzd.DoomGame()
# Choose scenario config file you wish to watch.
# Don't load two configs because the second will overwrite the first one.
# Multiple config files are OK, but combining them doesn't make much sense.
game.load_config(args.config)
game.set_doom_scenario_path(args.wad_file)
# Enables freelook in engine
game.add_game_args("+freelook 1")
game.set_screen_resolution(vzd.ScreenResolution.RES_640X480)
# Enables spectator mode, so you can play.
# Sounds strange but it is the agent who is supposed to watch not you.
game.set_window_visible(True)
game.set_mode(vzd.Mode.SPECTATOR)
game.init()
episodes = 1
for i in range(episodes):
print("Episode #" + str(i + 1))
game.new_episode()
while not game.is_episode_finished():
state = game.get_state()
print(state.screen_buffer.dtype, state.screen_buffer.shape)
# cv2.imwrite(f'imgs/{state.number}.png', state.screen_buffer)
# game.make_action([0, 0, 0])
game.advance_action()
last_action = game.get_last_action()
reward = game.get_last_reward()
print("State #" + str(state.number))
print("Game variables: ", state.game_variables)
print("Action:", last_action)
print("Reward:", reward)
print("=====================")
print("Episode finished!")
print("Total reward:", game.get_total_reward())
print("************************")
sleep(2.0)
game.close()

examples/vizdoom/network.py Symbolic link

@@ -0,0 +1 @@
../atari/atari_network.py

examples/vizdoom/replay.py Executable file

@@ -0,0 +1,35 @@
# import cv2
import sys
import time
import tqdm
import vizdoom as vzd
def main(cfg_path="maps/D3_battle.cfg", lmp_path="test.lmp"):
game = vzd.DoomGame()
game.load_config(cfg_path)
game.set_screen_format(vzd.ScreenFormat.CRCGCB)
game.set_screen_resolution(vzd.ScreenResolution.RES_1024X576)
game.set_window_visible(True)
game.set_render_hud(True)
game.init()
game.replay_episode(lmp_path)
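# step through the recorded episode in real time (Doom runs at 35 tics per second)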
killcount = 0
with tqdm.trange(10500) as tq:
while not game.is_episode_finished():
game.advance_action()
state = game.get_state()
if state is None:
break
killcount = game.get_game_variable(vzd.GameVariable.KILLCOUNT)
time.sleep(1 / 35)
# cv2.imwrite(f"imgs/{tq.n}.png",
# state.screen_buffer.transpose(1, 2, 0)[..., ::-1])
tq.update(1)
game.close()
print("killcount:", killcount)
if __name__ == '__main__':
main(*sys.argv[-2:])

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

@@ -0,0 +1,177 @@
import os
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import C51Policy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from env import Env
from network import C51
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='D1_basic')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--eps-test', type=float, default=0.005)
parser.add_argument('--eps-train', type=float, default=1.)
parser.add_argument('--eps-train-final', type=float, default=0.05)
parser.add_argument('--buffer-size', type=int, default=2000000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--num-atoms', type=int, default=51)
parser.add_argument('--v-min', type=float, default=-10.)
parser.add_argument('--v-max', type=float, default=10.)
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=500)
parser.add_argument('--epoch', type=int, default=300)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--step-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument('--skip-num', type=int, default=4)
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
parser.add_argument('--save-lmp', default=False, action='store_true',
help='save lmp file for replay whole episode')
parser.add_argument('--save-buffer-name', type=str, default=None)
return parser.parse_args()
def test_c51(args=get_args()):
args.cfg_path = f"maps/{args.task}.cfg"
args.wad_path = f"maps/{args.task}.wad"
args.res = (args.skip_num, 84, 84)
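# observation shape is (stack, H, W); the stack size reuses --skip-num,
# which equals --frames-stack by default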
env = Env(args.cfg_path, args.frames_stack, args.res)
args.state_shape = args.res
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
train_envs = SubprocVectorEnv([
lambda: Env(args.cfg_path, args.frames_stack, args.res)
for _ in range(args.training_num)])
test_envs = SubprocVectorEnv([
lambda: Env(args.cfg_path, args.frames_stack,
args.res, args.save_lmp)
for _ in range(min(os.cpu_count() - 1, args.test_num))])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# define model
net = C51(*args.state_shape, args.action_shape,
args.num_atoms, args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = C51Policy(
net, optim, args.gamma, args.num_atoms, args.v_min, args.v_max,
args.n_step, target_update_freq=args.target_update_freq
).to(args.device)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# replay buffer: `save_last_obs` and `stack_num` can be removed together
# when you have enough RAM
buffer = VectorReplayBuffer(
args.buffer_size, buffer_num=len(train_envs), ignore_obs_next=True,
save_only_last_obs=True, stack_num=args.frames_stack)
# collector
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(args.logdir, args.task, 'c51')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
return mean_rewards >= 20
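# (the 'Pong' branch is inherited from the Atari script; ViZDoom tasks never match it)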
else:
return False
def train_fn(epoch, env_step):
# nature DQN setting, linear decay in the first 1M steps
if env_step <= 1e6:
eps = args.eps_train - env_step / 1e6 * \
(args.eps_train - args.eps_train_final)
else:
eps = args.eps_train_final
policy.set_eps(eps)
logger.write('train/eps', env_step, eps)
def test_fn(epoch, env_step):
policy.set_eps(args.eps_test)
# watch agent's performance
def watch():
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
if args.save_buffer_name:
print(f"Generate buffer with size {args.buffer_size}")
buffer = VectorReplayBuffer(
args.buffer_size, buffer_num=len(test_envs),
ignore_obs_next=True, save_only_last_obs=True,
stack_num=args.frames_stack)
collector = Collector(policy, test_envs, buffer,
exploration_noise=True)
result = collector.collect(n_step=args.buffer_size)
print(f"Save buffer into {args.save_buffer_name}")
# Unfortunately, pickle will cause oom with 1M buffer size
buffer.save_hdf5(args.save_buffer_name)
else:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num,
render=args.render)
rew = result["rews"].mean()
lens = result["lens"].mean() * args.skip_num
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
print(f'Mean length (over {result["n/ep"]} episodes): {lens}')
if args.watch:
watch()
exit(0)
# test train_collector and start filling replay buffer
train_collector.collect(n_step=args.batch_size * args.training_num)
# trainer
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, logger=logger,
update_per_step=args.update_per_step, test_in_train=False)
pprint.pprint(result)
watch()
if __name__ == '__main__':
test_c51(get_args())

@@ -55,7 +55,7 @@ setup(
],
extras_require={
"dev": [
"Sphinx",
"sphinx<4",
"sphinx_rtd_theme",
"sphinxcontrib-bibtex",
"flake8",

@@ -22,7 +22,7 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--episode-per-collect', type=int, default=1)
parser.add_argument('--training-num', type=int, default=1)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.0)
parser.add_argument('--rew-mean-prior', type=float, default=0.0)
@@ -36,12 +36,12 @@ def get_args():
def test_psrl(args=get_args()):
env = gym.make(args.task)
if args.task == "NChain-v0":
env.spec.reward_threshold = 3647 # described in PSRL paper
env.spec.reward_threshold = 3400
# env.spec.reward_threshold = 3647 # described in PSRL paper
print("reward threshold:", env.spec.reward_threshold)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# train_envs = gym.make(args.task)
# train_envs = gym.make(args.task)
train_envs = DummyVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)])
# test_envs = gym.make(args.task)

@@ -1,7 +1,7 @@
from tianshou import data, env, utils, policy, trainer, exploration
__version__ = "0.4.1"
__version__ = "0.4.2"
__all__ = [
"env",

@@ -526,7 +526,12 @@ class Batch:
elif all(isinstance(e, (Batch, dict)) for e in v): # third often
self.__dict__[k] = Batch.stack(v, axis)
else: # most often case is np.ndarray
self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis))
try:
self.__dict__[k] = _to_array_with_correct_type(np.stack(v, axis))
except ValueError:
warnings.warn("You are using tensors with different shape,"
" fallback to dtype=object by default.")
self.__dict__[k] = np.array(v, dtype=object)
# all the keys
keys_total = set.union(*[set(b.keys()) for b in batches])
# keys that are reserved in all batches
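A minimal illustration of the new fallback (a sketch, assuming no earlier shape check rejects the input): stacking batches whose arrays differ in shape used to raise `ValueError` from `np.stack`; it now degrades to a `dtype=object` array with a warning.

```python
import numpy as np
from tianshou.data import Batch

b1 = Batch(obs=np.zeros((3, 4)))
b2 = Batch(obs=np.zeros((3, 5)))  # different trailing shape
stacked = Batch.stack([b1, b2])   # warns and falls back to dtype=object
print(stacked.obs.dtype)          # object
```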

@@ -53,6 +53,7 @@ class ReplayBuffer:
self._save_only_last_obs = save_only_last_obs
self._sample_avail = sample_avail
self._meta: Batch = Batch()
self._ep_rew: Union[float, np.ndarray]
self.reset()
def __len__(self) -> int:

@@ -56,7 +56,8 @@ class Collector(object):
exploration_noise: bool = False,
) -> None:
super().__init__()
if not isinstance(env, BaseVectorEnv):
if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
env = DummyVectorEnv([lambda: env])
self.env = env
self.env_num = len(env)
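The wrapping itself is not new; this change tightens the check and adds an explicit warning. A sketch of what triggers it (`policy` stands in for any initialized tianshou policy and is not defined here):

```python
import gym
from tianshou.data import Collector

env = gym.make("CartPole-v0")
# emits "Single environment detected, wrap to DummyVectorEnv." and wraps env
collector = Collector(policy, env)
```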
@@ -223,7 +224,8 @@
# get bounded and remapped actions first (not saved into buffer)
action_remap = self.policy.map_action(self.data.act)
# step in env
obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids)
obs_next, rew, done, info = self.env.step(
action_remap, ready_env_ids) # type: ignore
self.data.update(obs_next=obs_next, rew=rew, done=done, info=info)
if self.preprocess_fn:
@@ -426,7 +428,8 @@ class AsyncCollector(Collector):
# get bounded and remapped actions first (not saved into buffer)
action_remap = self.policy.map_action(self.data.act)
# step in env
obs_next, rew, done, info = self.env.step(action_remap, id=ready_env_ids)
obs_next, rew, done, info = self.env.step(
action_remap, ready_env_ids) # type: ignore
# change self.data here because ready_env_ids has changed
ready_env_ids = np.array([i["env_id"] for i in info])

@@ -75,8 +75,8 @@ class DiscreteCRRPolicy(PGPolicy):
self._min_q_weight = min_q_weight
def sync_weight(self) -> None:
self.actor_old.load_state_dict(self.actor.state_dict()) # type: ignore
self.critic_old.load_state_dict(self.critic.state_dict()) # type: ignore
self.actor_old.load_state_dict(self.actor.state_dict())
self.critic_old.load_state_dict(self.critic.state_dict())
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]: # type: ignore
if self._target and self._iter % self._freq == 0:

@@ -72,7 +72,7 @@ class DQNPolicy(BasePolicy):
def sync_weight(self) -> None:
"""Synchronize the weight for the target network."""
self.model_old.load_state_dict(self.model.state_dict()) # type: ignore
self.model_old.load_state_dict(self.model.state_dict())
def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}

@@ -64,6 +64,7 @@ class TRPOPolicy(NPGPolicy):
self._max_backtracks = max_backtracks
self._delta = max_kl
self._backtrack_coeff = backtrack_coeff
self._optim_critic_iters: int
def learn( # type: ignore
self, batch: Batch, batch_size: int, repeat: int, **kwargs: Any

@@ -88,7 +88,7 @@ class BasicLogger(BaseLogger):
You can also rewrite write() func to use your own writer.
:param SummaryWriter writer: the writer to log data.
:param int train_interval: the log interval in log_train_data(). Default to 1.
:param int train_interval: the log interval in log_train_data(). Default to 1000.
:param int test_interval: the log interval in log_test_data(). Default to 1.
:param int update_interval: the log interval in log_update_data(). Default to 1000.
:param int save_interval: the save interval in save_data(). Default to 1 (save at
@@ -98,7 +98,7 @@ def __init__(
def __init__(
self,
writer: SummaryWriter,
train_interval: int = 1,
train_interval: int = 1000,
test_interval: int = 1,
update_interval: int = 1000,
save_interval: int = 1,