add cross-platform test and release 0.4.1 (#331)
* bump to 0.4.1 * add cross-platform test
This commit is contained in:
parent
09692c84fe
commit
825da9bc53
4
.github/ISSUE_TEMPLATE.md
vendored
4
.github/ISSUE_TEMPLATE.md
vendored
@ -7,6 +7,6 @@
|
||||
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
|
||||
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
|
||||
```python
|
||||
import tianshou, torch, sys
|
||||
print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
|
||||
import tianshou, torch, numpy, sys
|
||||
print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
|
||||
```
|
||||
|
4
.github/PULL_REQUEST_TEMPLATE.md
vendored
4
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -11,6 +11,6 @@ Less important but also useful:
|
||||
- [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
|
||||
- [ ] I have mentioned version numbers, operating system and environment, where applicable:
|
||||
```python
|
||||
import tianshou, torch, sys
|
||||
print(tianshou.__version__, torch.__version__, sys.version, sys.platform)
|
||||
import tianshou, torch, numpy, sys
|
||||
print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
|
||||
```
|
||||
|
27
.github/workflows/extra_sys.yml
vendored
Normal file
27
.github/workflows/extra_sys.yml
vendored
Normal file
@ -0,0 +1,27 @@
|
||||
name: Unittest
|
||||
|
||||
on: [push, pull_request]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ${{ matrix.os }}
|
||||
if: "!contains(github.event.head_commit.message, 'ci skip')"
|
||||
strategy:
|
||||
matrix:
|
||||
os: [macos-latest, windows-latest]
|
||||
python-version: [3.6, 3.7, 3.8]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Upgrade pip
|
||||
run: |
|
||||
python -m pip install --upgrade pip setuptools wheel
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install ".[dev]" --upgrade
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest test/base test/continuous --ignore-glob "*env.py" --cov=tianshou --durations=0 -v
|
6
.github/workflows/pytest.yml
vendored
6
.github/workflows/pytest.yml
vendored
@ -15,12 +15,6 @@ jobs:
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
# - uses: actions/cache@v2
|
||||
# with:
|
||||
# path: /opt/hostedtoolcache/Python/
|
||||
# key: ${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}
|
||||
# restore-keys: |
|
||||
# ${{ runner.os }}-${{ matrix.python-version }}-
|
||||
- name: Upgrade pip
|
||||
run: |
|
||||
python -m pip install --upgrade pip setuptools wheel
|
||||
|
@ -73,10 +73,10 @@ class MyTestEnv(gym.Env):
|
||||
elif self.recurse_state:
|
||||
return {'index': np.array([self.index], dtype=np.float32),
|
||||
'dict': {"tuple": (np.array([1],
|
||||
dtype=np.int64), self.rng.rand(2)),
|
||||
dtype=int), self.rng.rand(2)),
|
||||
"rand": self.rng.rand(1, 2)}}
|
||||
elif self.array_state:
|
||||
img = np.zeros([4, 84, 84], np.int)
|
||||
img = np.zeros([4, 84, 84], int)
|
||||
img[3, np.arange(84), np.arange(84)] = self.index
|
||||
img[2, np.arange(84)] = self.index
|
||||
img[1, :, np.arange(84)] = self.index
|
||||
|
@ -1,3 +1,4 @@
|
||||
import sys
|
||||
import copy
|
||||
import torch
|
||||
import pickle
|
||||
@ -373,7 +374,10 @@ def test_batch_over_batch_to_torch():
|
||||
assert batch.a.dtype == torch.float64
|
||||
assert batch.b.c.dtype == torch.float32
|
||||
assert batch.b.d.dtype == torch.float64
|
||||
assert batch.b.e.dtype == torch.int64
|
||||
if sys.platform in ["win32", "cygwin"]: # windows
|
||||
assert batch.b.e.dtype == torch.int32
|
||||
else:
|
||||
assert batch.b.e.dtype == torch.int64
|
||||
batch.to_torch(dtype=torch.float32)
|
||||
assert batch.a.dtype == torch.float32
|
||||
assert batch.b.c.dtype == torch.float32
|
||||
@ -439,7 +443,10 @@ def test_utils_to_torch_numpy():
|
||||
assert to_numpy(to_numpy).item() == to_numpy
|
||||
# additional test for to_torch, for code-coverage
|
||||
assert isinstance(to_torch(1), torch.Tensor)
|
||||
assert to_torch(1).dtype == torch.int64
|
||||
if sys.platform in ["win32", "cygwin"]: # windows
|
||||
assert to_torch(1).dtype == torch.int32
|
||||
else:
|
||||
assert to_torch(1).dtype == torch.int64
|
||||
assert to_torch(1.).dtype == torch.float64
|
||||
assert isinstance(to_torch({'a': [1]})['a'], torch.Tensor)
|
||||
with pytest.raises(TypeError):
|
||||
|
@ -13,8 +13,8 @@ else: # pytest
|
||||
|
||||
def has_ray():
|
||||
try:
|
||||
import ray
|
||||
return hasattr(ray, 'init') # avoid PEP8 F401 Error
|
||||
import ray # noqa: F401
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
@ -22,7 +22,7 @@ def has_ray():
|
||||
def recurse_comp(a, b):
|
||||
try:
|
||||
if isinstance(a, np.ndarray):
|
||||
if a.dtype == np.object:
|
||||
if a.dtype == object:
|
||||
return np.array(
|
||||
[recurse_comp(m, n) for m, n in zip(a, b)]).all()
|
||||
else:
|
||||
|
@ -122,7 +122,7 @@ def target_q_fn_multidim(buffer, indice):
|
||||
|
||||
|
||||
def compute_nstep_return_base(nstep, gamma, buffer, indice):
|
||||
returns = np.zeros_like(indice, dtype=np.float)
|
||||
returns = np.zeros_like(indice, dtype=float)
|
||||
buf_len = len(buffer)
|
||||
for i in range(len(indice)):
|
||||
flag, r = False, 0.
|
||||
|
@ -1,7 +1,7 @@
|
||||
from tianshou import data, env, utils, policy, trainer, exploration
|
||||
|
||||
|
||||
__version__ = "0.4.0"
|
||||
__version__ = "0.4.1"
|
||||
|
||||
__all__ = [
|
||||
"env",
|
||||
|
Loading…
x
Reference in New Issue
Block a user