add cross-platform test and release 0.4.1 (#331)
* bump to 0.4.1 * add cross-platform test
This commit is contained in:
parent
09692c84fe
commit
825da9bc53
4
.github/ISSUE_TEMPLATE.md
vendored
4
.github/ISSUE_TEMPLATE.md
vendored
@ -7,6 +7,6 @@
|
|||||||
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
|
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
|
||||||
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
|
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
|
||||||
```python
|
```python
|
||||||
import tianshou, torch, sys
|
import tianshou, torch, numpy, sys
|
||||||
print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
|
print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
|
||||||
```
|
```
|
||||||
|
|||||||
4
.github/PULL_REQUEST_TEMPLATE.md
vendored
4
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -11,6 +11,6 @@ Less important but also useful:
|
|||||||
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
|
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
|
||||||
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
|
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
|
||||||
```python
|
```python
|
||||||
import tianshou, torch, sys
|
import tianshou, torch, numpy, sys
|
||||||
print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
|
print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
|
||||||
```
|
```
|
||||||
|
|||||||
27
.github/workflows/extra_sys.yml
vendored
Normal file
27
.github/workflows/extra_sys.yml
vendored
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
name: Unittest
|
||||||
|
|
||||||
|
on: [push, pull_request]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
if: "!contains(github.event.head_commit.message, 'ci skip')"
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [macos-latest, windows-latest]
|
||||||
|
python-version: [3.6, 3.7, 3.8]
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v2
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
- name: Upgrade pip
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip setuptools wheel
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install ".[dev]" --upgrade
|
||||||
|
- name: Test with pytest
|
||||||
|
run: |
|
||||||
|
pytest test/base test/continuous --ignore-glob "*env.py" --cov=tianshou --durations=0 -v
|
||||||
6
.github/workflows/pytest.yml
vendored
6
.github/workflows/pytest.yml
vendored
@ -15,12 +15,6 @@ jobs:
|
|||||||
uses: actions/setup-python@v2
|
uses: actions/setup-python@v2
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
# - uses: actions/cache@v2
|
|
||||||
# with:
|
|
||||||
# path: /opt/hostedtoolcache/Python/
|
|
||||||
# key: ${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
|
|
||||||
# restore-keys: |
|
|
||||||
# ${{ runner.os }}-${{ matrix.python-version }}-
|
|
||||||
- name: Upgrade pip
|
- name: Upgrade pip
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip setuptools wheel
|
python -m pip install --upgrade pip setuptools wheel
|
||||||
|
|||||||
@ -73,10 +73,10 @@ class MyTestEnv(gym.Env):
|
|||||||
elif self.recurse_state:
|
elif self.recurse_state:
|
||||||
return {'index': np.array([self.index], dtype=np.float32),
|
return {'index': np.array([self.index], dtype=np.float32),
|
||||||
'dict': {"tuple": (np.array([1],
|
'dict': {"tuple": (np.array([1],
|
||||||
dtype=np.int64), self.rng.rand(2)),
|
dtype=int), self.rng.rand(2)),
|
||||||
"rand": self.rng.rand(1, 2)}}
|
"rand": self.rng.rand(1, 2)}}
|
||||||
elif self.array_state:
|
elif self.array_state:
|
||||||
img = np.zeros([4, 84, 84], np.int)
|
img = np.zeros([4, 84, 84], int)
|
||||||
img[3, np.arange(84), np.arange(84)] = self.index
|
img[3, np.arange(84), np.arange(84)] = self.index
|
||||||
img[2, np.arange(84)] = self.index
|
img[2, np.arange(84)] = self.index
|
||||||
img[1, :, np.arange(84)] = self.index
|
img[1, :, np.arange(84)] = self.index
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import sys
|
||||||
import copy
|
import copy
|
||||||
import torch
|
import torch
|
||||||
import pickle
|
import pickle
|
||||||
@ -373,7 +374,10 @@ def test_batch_over_batch_to_torch():
|
|||||||
assert batch.a.dtype == torch.float64
|
assert batch.a.dtype == torch.float64
|
||||||
assert batch.b.c.dtype == torch.float32
|
assert batch.b.c.dtype == torch.float32
|
||||||
assert batch.b.d.dtype == torch.float64
|
assert batch.b.d.dtype == torch.float64
|
||||||
assert batch.b.e.dtype == torch.int64
|
if sys.platform in ["win32", "cygwin"]: # windows
|
||||||
|
assert batch.b.e.dtype == torch.int32
|
||||||
|
else:
|
||||||
|
assert batch.b.e.dtype == torch.int64
|
||||||
batch.to_torch(dtype=torch.float32)
|
batch.to_torch(dtype=torch.float32)
|
||||||
assert batch.a.dtype == torch.float32
|
assert batch.a.dtype == torch.float32
|
||||||
assert batch.b.c.dtype == torch.float32
|
assert batch.b.c.dtype == torch.float32
|
||||||
@ -439,7 +443,10 @@ def test_utils_to_torch_numpy():
|
|||||||
assert to_numpy(to_numpy).item() == to_numpy
|
assert to_numpy(to_numpy).item() == to_numpy
|
||||||
# additional test for to_torch, for code-coverage
|
# additional test for to_torch, for code-coverage
|
||||||
assert isinstance(to_torch(1), torch.Tensor)
|
assert isinstance(to_torch(1), torch.Tensor)
|
||||||
assert to_torch(1).dtype == torch.int64
|
if sys.platform in ["win32", "cygwin"]: # windows
|
||||||
|
assert to_torch(1).dtype == torch.int32
|
||||||
|
else:
|
||||||
|
assert to_torch(1).dtype == torch.int64
|
||||||
assert to_torch(1.).dtype == torch.float64
|
assert to_torch(1.).dtype == torch.float64
|
||||||
assert isinstance(to_torch({'a': [1]})['a'], torch.Tensor)
|
assert isinstance(to_torch({'a': [1]})['a'], torch.Tensor)
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
|
|||||||
@ -13,8 +13,8 @@ else: # pytest
|
|||||||
|
|
||||||
def has_ray():
|
def has_ray():
|
||||||
try:
|
try:
|
||||||
import ray
|
import ray # noqa: F401
|
||||||
return hasattr(ray, 'init') # avoid PEP8 F401 Error
|
return True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -22,7 +22,7 @@ def has_ray():
|
|||||||
def recurse_comp(a, b):
|
def recurse_comp(a, b):
|
||||||
try:
|
try:
|
||||||
if isinstance(a, np.ndarray):
|
if isinstance(a, np.ndarray):
|
||||||
if a.dtype == np.object:
|
if a.dtype == object:
|
||||||
return np.array(
|
return np.array(
|
||||||
[recurse_comp(m, n) for m, n in zip(a, b)]).all()
|
[recurse_comp(m, n) for m, n in zip(a, b)]).all()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -122,7 +122,7 @@ def target_q_fn_multidim(buffer, indice):
|
|||||||
|
|
||||||
|
|
||||||
def compute_nstep_return_base(nstep, gamma, buffer, indice):
|
def compute_nstep_return_base(nstep, gamma, buffer, indice):
|
||||||
returns = np.zeros_like(indice, dtype=np.float)
|
returns = np.zeros_like(indice, dtype=float)
|
||||||
buf_len = len(buffer)
|
buf_len = len(buffer)
|
||||||
for i in range(len(indice)):
|
for i in range(len(indice)):
|
||||||
flag, r = False, 0.
|
flag, r = False, 0.
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from tianshou import data, env, utils, policy, trainer, exploration
|
from tianshou import data, env, utils, policy, trainer, exploration
|
||||||
|
|
||||||
|
|
||||||
__version__ = "0.4.0"
|
__version__ = "0.4.1"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"env",
|
"env",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user