update version to 0.5.0 (#826)
This commit is contained in:
parent
73600edc58
commit
f0afdeaf6a
2
.github/ISSUE_TEMPLATE.md
vendored
2
.github/ISSUE_TEMPLATE.md
vendored
@ -7,6 +7,6 @@
|
|||||||
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
|
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
|
||||||
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
|
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
|
||||||
```python
|
```python
|
||||||
import tianshou, gym, torch, numpy, sys
|
import tianshou, gymnasium as gym, torch, numpy, sys
|
||||||
print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
|
print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
|
||||||
```
|
```
|
||||||
|
6
.github/workflows/extra_sys.yml
vendored
6
.github/workflows/extra_sys.yml
vendored
@ -12,12 +12,12 @@ jobs:
|
|||||||
python-version: [3.7, 3.8]
|
python-version: [3.7, 3.8]
|
||||||
steps:
|
steps:
|
||||||
- name: Cancel previous run
|
- name: Cancel previous run
|
||||||
uses: styfle/cancel-workflow-action@0.9.1
|
uses: styfle/cancel-workflow-action@0.11.0
|
||||||
with:
|
with:
|
||||||
access_token: ${{ github.token }}
|
access_token: ${{ github.token }}
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v3
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
- name: Upgrade pip
|
- name: Upgrade pip
|
||||||
|
6
.github/workflows/gputest.yml
vendored
6
.github/workflows/gputest.yml
vendored
@ -8,12 +8,12 @@ jobs:
|
|||||||
if: "!contains(github.event.head_commit.message, 'ci skip')"
|
if: "!contains(github.event.head_commit.message, 'ci skip')"
|
||||||
steps:
|
steps:
|
||||||
- name: Cancel previous run
|
- name: Cancel previous run
|
||||||
uses: styfle/cancel-workflow-action@0.9.1
|
uses: styfle/cancel-workflow-action@0.11.0
|
||||||
with:
|
with:
|
||||||
access_token: ${{ github.token }}
|
access_token: ${{ github.token }}
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v3
|
||||||
- name: Set up Python 3.8
|
- name: Set up Python 3.8
|
||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: 3.8
|
python-version: 3.8
|
||||||
- name: Upgrade pip
|
- name: Upgrade pip
|
||||||
|
6
.github/workflows/lint_and_docs.yml
vendored
6
.github/workflows/lint_and_docs.yml
vendored
@ -7,12 +7,12 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Cancel previous run
|
- name: Cancel previous run
|
||||||
uses: styfle/cancel-workflow-action@0.9.1
|
uses: styfle/cancel-workflow-action@0.11.0
|
||||||
with:
|
with:
|
||||||
access_token: ${{ github.token }}
|
access_token: ${{ github.token }}
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v3
|
||||||
- name: Set up Python 3.8
|
- name: Set up Python 3.8
|
||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: 3.8
|
python-version: 3.8
|
||||||
- name: Upgrade pip
|
- name: Upgrade pip
|
||||||
|
27
.github/workflows/profile.yml
vendored
27
.github/workflows/profile.yml
vendored
@ -1,27 +0,0 @@
|
|||||||
name: Data Profile
|
|
||||||
|
|
||||||
on: [push, pull_request]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
profile:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
if: "!contains(github.event.head_commit.message, 'ci skip')"
|
|
||||||
steps:
|
|
||||||
- name: Cancel previous run
|
|
||||||
uses: styfle/cancel-workflow-action@0.9.1
|
|
||||||
with:
|
|
||||||
access_token: ${{ github.token }}
|
|
||||||
- uses: actions/checkout@v2
|
|
||||||
- name: Set up Python 3.8
|
|
||||||
uses: actions/setup-python@v2
|
|
||||||
with:
|
|
||||||
python-version: 3.8
|
|
||||||
- name: Upgrade pip
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip setuptools wheel
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
python -m pip install ".[dev]" --upgrade
|
|
||||||
- name: Test with pytest
|
|
||||||
run: |
|
|
||||||
pytest test/throughput --durations=0 -v --color=yes
|
|
6
.github/workflows/pytest.yml
vendored
6
.github/workflows/pytest.yml
vendored
@ -11,12 +11,12 @@ jobs:
|
|||||||
python-version: [3.7, 3.8, 3.9]
|
python-version: [3.7, 3.8, 3.9]
|
||||||
steps:
|
steps:
|
||||||
- name: Cancel previous run
|
- name: Cancel previous run
|
||||||
uses: styfle/cancel-workflow-action@0.9.1
|
uses: styfle/cancel-workflow-action@0.11.0
|
||||||
with:
|
with:
|
||||||
access_token: ${{ github.token }}
|
access_token: ${{ github.token }}
|
||||||
- uses: actions/checkout@v2
|
- uses: actions/checkout@v3
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
- name: Upgrade pip
|
- name: Upgrade pip
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
gym
|
gym
|
||||||
numba
|
numba
|
||||||
numpy>=1.20
|
numpy>=1.20
|
||||||
sphinx<4
|
sphinx
|
||||||
sphinxcontrib-bibtex
|
sphinxcontrib-bibtex
|
||||||
tensorboard
|
tensorboard
|
||||||
torch
|
torch
|
||||||
tqdm
|
tqdm
|
||||||
protobuf~=3.19.0
|
protobuf
|
||||||
pettingzoo
|
pettingzoo
|
||||||
|
@ -40,7 +40,7 @@ def get_args():
|
|||||||
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
|
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
|
||||||
parser.add_argument("--test-num", type=int, default=100)
|
parser.add_argument("--test-num", type=int, default=100)
|
||||||
parser.add_argument("--logdir", type=str, default="log")
|
parser.add_argument("--logdir", type=str, default="log")
|
||||||
parser.add_argument("--render", type=float, default=0.)
|
parser.add_argument("--render", type=float, default=0.0)
|
||||||
parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
|
parser.add_argument("--load-buffer-name", type=str, default=expert_file_name())
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--device",
|
"--device",
|
||||||
@ -59,7 +59,7 @@ def test_discrete_bcq(args=get_args()):
|
|||||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||||
args.action_shape = env.action_space.shape or env.action_space.n
|
args.action_shape = env.action_space.shape or env.action_space.n
|
||||||
if args.reward_threshold is None:
|
if args.reward_threshold is None:
|
||||||
default_reward_threshold = {"CartPole-v0": 190}
|
default_reward_threshold = {"CartPole-v0": 185}
|
||||||
args.reward_threshold = default_reward_threshold.get(
|
args.reward_threshold = default_reward_threshold.get(
|
||||||
args.task, env.spec.reward_threshold
|
args.task, env.spec.reward_threshold
|
||||||
)
|
)
|
||||||
@ -123,7 +123,8 @@ def test_discrete_bcq(args=get_args()):
|
|||||||
{
|
{
|
||||||
"model": policy.state_dict(),
|
"model": policy.state_dict(),
|
||||||
"optim": optim.state_dict(),
|
"optim": optim.state_dict(),
|
||||||
}, ckpt_path
|
},
|
||||||
|
ckpt_path,
|
||||||
)
|
)
|
||||||
return ckpt_path
|
return ckpt_path
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from tianshou import data, env, exploration, policy, trainer, utils
|
from tianshou import data, env, exploration, policy, trainer, utils
|
||||||
|
|
||||||
__version__ = "0.4.11"
|
__version__ = "0.5.0"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"env",
|
"env",
|
||||||
|
49
tianshou/env/venvs.py
vendored
49
tianshou/env/venvs.py
vendored
@ -5,7 +5,6 @@ import gymnasium as gym
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import packaging
|
import packaging
|
||||||
|
|
||||||
from tianshou.env.pettingzoo_env import PettingZooEnv
|
|
||||||
from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type
|
from tianshou.env.utils import ENV_TYPE, gym_new_venv_step_type
|
||||||
from tianshou.env.worker import (
|
from tianshou.env.worker import (
|
||||||
DummyEnvWorker,
|
DummyEnvWorker,
|
||||||
@ -14,8 +13,14 @@ from tianshou.env.worker import (
|
|||||||
SubprocEnvWorker,
|
SubprocEnvWorker,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tianshou.env.pettingzoo_env import PettingZooEnv
|
||||||
|
except ImportError:
|
||||||
|
PettingZooEnv = None # type: ignore
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import gym as old_gym
|
import gym as old_gym
|
||||||
|
|
||||||
has_old_gym = True
|
has_old_gym = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
has_old_gym = False
|
has_old_gym = False
|
||||||
@ -152,11 +157,13 @@ class BaseVectorEnv(object):
|
|||||||
|
|
||||||
self.env_num = len(env_fns)
|
self.env_num = len(env_fns)
|
||||||
self.wait_num = wait_num or len(env_fns)
|
self.wait_num = wait_num or len(env_fns)
|
||||||
assert 1 <= self.wait_num <= len(env_fns), \
|
assert (
|
||||||
f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
|
1 <= self.wait_num <= len(env_fns)
|
||||||
|
), f"wait_num should be in [1, {len(env_fns)}], but got {wait_num}"
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
assert self.timeout is None or self.timeout > 0, \
|
assert (
|
||||||
f"timeout is {timeout}, it should be positive if provided!"
|
self.timeout is None or self.timeout > 0
|
||||||
|
), f"timeout is {timeout}, it should be positive if provided!"
|
||||||
self.is_async = self.wait_num != len(env_fns) or timeout is not None
|
self.is_async = self.wait_num != len(env_fns) or timeout is not None
|
||||||
self.waiting_conn: List[EnvWorker] = []
|
self.waiting_conn: List[EnvWorker] = []
|
||||||
# environments in self.ready_id is actually ready
|
# environments in self.ready_id is actually ready
|
||||||
@ -169,8 +176,9 @@ class BaseVectorEnv(object):
|
|||||||
self.is_closed = False
|
self.is_closed = False
|
||||||
|
|
||||||
def _assert_is_not_closed(self) -> None:
|
def _assert_is_not_closed(self) -> None:
|
||||||
assert not self.is_closed, \
|
assert (
|
||||||
f"Methods of {self.__class__.__name__} cannot be called after close."
|
not self.is_closed
|
||||||
|
), f"Methods of {self.__class__.__name__} cannot be called after close."
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Return len(self), which is the number of environments."""
|
"""Return len(self), which is the number of environments."""
|
||||||
@ -245,10 +253,12 @@ class BaseVectorEnv(object):
|
|||||||
|
|
||||||
def _assert_id(self, id: Union[List[int], np.ndarray]) -> None:
|
def _assert_id(self, id: Union[List[int], np.ndarray]) -> None:
|
||||||
for i in id:
|
for i in id:
|
||||||
assert i not in self.waiting_id, \
|
assert (
|
||||||
f"Cannot interact with environment {i} which is stepping now."
|
i not in self.waiting_id
|
||||||
assert i in self.ready_id, \
|
), f"Cannot interact with environment {i} which is stepping now."
|
||||||
f"Can only interact with ready environments {self.ready_id}."
|
assert (
|
||||||
|
i in self.ready_id
|
||||||
|
), f"Can only interact with ready environments {self.ready_id}."
|
||||||
|
|
||||||
def reset(
|
def reset(
|
||||||
self,
|
self,
|
||||||
@ -271,9 +281,10 @@ class BaseVectorEnv(object):
|
|||||||
self.workers[i].send(None, **kwargs)
|
self.workers[i].send(None, **kwargs)
|
||||||
ret_list = [self.workers[i].recv() for i in id]
|
ret_list = [self.workers[i].recv() for i in id]
|
||||||
|
|
||||||
assert isinstance(ret_list[0], (tuple, list)) and len(
|
assert (
|
||||||
ret_list[0]
|
isinstance(ret_list[0], (tuple, list)) and len(ret_list[0]) == 2
|
||||||
) == 2 and isinstance(ret_list[0][1], dict)
|
and isinstance(ret_list[0][1], dict)
|
||||||
|
)
|
||||||
|
|
||||||
obs_list = [r[0] for r in ret_list]
|
obs_list = [r[0] for r in ret_list]
|
||||||
|
|
||||||
@ -367,9 +378,13 @@ class BaseVectorEnv(object):
|
|||||||
obs_stack = np.stack(obs_list)
|
obs_stack = np.stack(obs_list)
|
||||||
except ValueError: # different len(obs)
|
except ValueError: # different len(obs)
|
||||||
obs_stack = np.array(obs_list, dtype=object)
|
obs_stack = np.array(obs_list, dtype=object)
|
||||||
return obs_stack, np.stack(rew_list), np.stack(term_list), np.stack(
|
return (
|
||||||
trunc_list
|
obs_stack,
|
||||||
), np.stack(info_list)
|
np.stack(rew_list),
|
||||||
|
np.stack(term_list),
|
||||||
|
np.stack(trunc_list),
|
||||||
|
np.stack(info_list),
|
||||||
|
)
|
||||||
|
|
||||||
def seed(
|
def seed(
|
||||||
self,
|
self,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user