471 Commits

Author SHA1 Message Date
ChenDRAG
150d0ec51b
Step collector implementation (#280)
This is the third PR of 6 commits mentioned in #274, which features refactor of Collector to fix #245. You can check #274 for more detail.

Things changed in this PR:

1. refactor collector to be more cleaner, split AsyncCollector to support asyncvenv;
2. change buffer.add api to add(batch, bffer_ids); add several types of buffer (VectorReplayBuffer, PrioritizedVectorReplayBuffer, etc.)
3. add policy.exploration_noise(act, batch) -> act
4. small change in BasePolicy.compute_*_returns
5. move reward_metric from collector to trainer
6. fix np.asanyarray issue (different version's numpy will result in different output)
7. flake8 maxlength=88
8. polish docs and fix test

Co-authored-by: n+e <trinkle23897@gmail.com>
2021-02-19 10:33:49 +08:00
Trinkle23897
d918022ce9 merge master into dev 2021-02-18 12:46:55 +08:00
n+e
cb65b56b13
v0.3.2 (#292)
Throw a warning in ListReplayBuffer.

This version update is needed because of #289, the previous v0.3.1 cannot work well under torch<=1.6.0 with cuda environment.
v0.3.2
2021-02-16 09:31:46 +08:00
n+e
d003c8e566
fix 2 bugs of batch (#284)
1. `_create_value(Batch(a={}, b=[1, 2, 3]), 10, False)`

before:
```python
TypeError: cannot concatenate with Batch() which is scalar
```
after:
```python
Batch(
    a: Batch(),
    b: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
)
```

2. creating keys in a batch's subkey, e.g. 
```python
a = Batch(info={"key1": [0, 1], "key2": [2, 3]})
a[0] = Batch(info={"key1": 2, "key3": 4})
print(a)
```
before:
```python
Batch(
    info: Batch(
              key1: array([0, 1]),
              key2: array([0, 3]),
          ),
)
```
after:
```python
ValueError: Creating keys is not supported by item assignment.
```

3. small optimization for `Batch.stack_` and `Batch.cat_`
2021-02-16 09:01:54 +08:00
ChenDRAG
f528131da1
hotfix:fix test failure in cuda environment (#289) 2021-02-09 17:13:40 +08:00
Trinkle23897
e3ee415b1a temporary fix numpy<1.20.0 (#281) 2021-02-08 12:59:37 +08:00
n+e
c838f2f0e9
fix 2 bugs of batch (#284)
1. `_create_value(Batch(a={}, b=[1, 2, 3]), 10, False)`

before:
```python
TypeError: cannot concatenate with Batch() which is scalar
```
after:
```python
Batch(
    a: Batch(),
    b: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
)
```

2. creating keys in a batch's subkey, e.g. 
```python
a = Batch(info={"key1": [0, 1], "key2": [2, 3]})
a[0] = Batch(info={"key1": 2, "key3": 4})
print(a)
```
before:
```python
Batch(
    info: Batch(
              key1: array([0, 1]),
              key2: array([0, 3]),
          ),
)
```
after:
```python
ValueError: Creating keys is not supported by item assignment.
```

3. small optimization for `Batch.stack_` and `Batch.cat_`, raise ValueError when receiving invalid data format.
2021-02-02 19:28:05 +08:00
ChenDRAG
f0129f4ca7
Add CachedReplayBuffer and ReplayBufferManager (#278)
This is the second commit of 6 commits mentioned in #274, which features minor refactor of ReplayBuffer and adding two new ReplayBuffer classes called CachedReplayBuffer and ReplayBufferManager. You can check #274 for more detail.

1. Add ReplayBufferManager (handle a list of buffers) and CachedReplayBuffer;
2. Make sure the reserved keys cannot be edited by methods like `buffer.done = xxx`;
3. Add `set_batch` method for manually choosing the batch the ReplayBuffer wants to handle;
4. Add `sample_index` method, same as `sample` but only return index instead of both index and batch data;
5. Add `prev` (one-step previous transition index), `next` (one-step next transition index) and `unfinished_index` (the last modified index whose done==False);
6. Separate `alloc_fn` method for allocating new memory for `self._meta` when a new `(key, value)` pair comes in;
7. Move buffer's documentation to `docs/tutorials/concepts.rst`.

Co-authored-by: n+e <trinkle23897@gmail.com>
2021-01-29 12:23:18 +08:00
wizardsheng
1eb6137645
Add QR-DQN algorithm (#276)
This is the PR for QR-DQN algorithm: https://arxiv.org/abs/1710.10044

1. add QR-DQN policy in tianshou/policy/modelfree/qrdqn.py.
2. add QR-DQN net in examples/atari/atari_network.py.
3. add QR-DQN atari example in examples/atari/atari_qrdqn.py.
4. add QR-DQN statement in tianshou/policy/init.py.
5. add QR-DQN unit test in test/discrete/test_qrdqn.py.
6. add QR-DQN atari results in examples/atari/results/qrdqn/.
7. add compute_q_value in DQNPolicy and C51Policy for simplify forward function.
8. move `with torch.no_grad():` from `_target_q` to BasePolicy

By running "python3 atari_qrdqn.py --task "PongNoFrameskip-v4" --batch-size 64", get best_result': '19.8 ± 0.40', in epoch 8.
2021-01-28 09:27:05 +08:00
Jialu Zhu
a511cb4779
Add offline trainer and discrete BCQ algorithm (#263)
The result needs to be tuned after `done` issue fixed.

Co-authored-by: n+e <trinkle23897@gmail.com>
v0.3.1
2021-01-20 18:13:04 +08:00
ChenDRAG
a633a6a028
update utils.network (#275)
This is the first commit of 6 commits mentioned in #274, which features

1. Refactor of `Class Net` to support any form of MLP.
2. Enable type check in utils.network.
3. Relative change in docs/test/examples.
4. Move atari-related network to examples/atari/atari_network.py

Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
2021-01-20 16:54:13 +08:00
蔡舒起
866e35d550
fix readme (#273) 2021-01-16 19:27:35 +08:00
wizardsheng
c6f2648e87
Add C51 algorithm (#266)
This is the PR for C51algorithm: https://arxiv.org/abs/1707.06887

1. add C51 policy in tianshou/policy/modelfree/c51.py.
2. add C51 net in tianshou/utils/net/discrete.py.
3. add C51 atari example in examples/atari/atari_c51.py.
4. add C51 statement in tianshou/policy/__init__.py.
5. add C51 test in test/discrete/test_c51.py.
6. add C51 atari results in examples/atari/results/c51/.

By running "python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64", get  best_result': '20.50 ± 0.50', in epoch 9.

By running "python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1 --epoch 40", get best_reward: 407.400000 ± 31.155096 in epoch 39.
2021-01-06 10:17:45 +08:00
Nico Gürtler
5d13d8a453
Saving and loading replay buffer with HDF5 (#261)
As mentioned in #260, this pull request is about an implementation of saving and loading the replay buffer with HDF5.
2020-12-17 08:58:43 +08:00
Trinkle23897
cd481423dc sac mujoco result (#246) 2020-11-09 16:43:55 +08:00
rocknamx
c97aa4065e
add singleton pattern version of summary_writter (#230)
Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
2020-10-31 16:38:54 +08:00
Trinkle23897
b364f1a26f specify the meaning of logits in documentation (#238) v0.3.0.post1 2020-10-08 23:16:15 +08:00
n+e
5ed6c1c7aa
change the step in trainer (#235)
This PR separates the `global_step` into `env_step` and `gradient_step`. In the future, the data from the collecting state will be stored under `env_step`, and the data from the updating state will be stored under `gradient_step`.

Others:
- add `rew_std` and `best_result` into the monitor
- fix network unbounded in `test/continuous/test_sac_with_il.py` and `examples/box2d/bipedal_hardcore_sac.py`
- change the dependency of ray to 1.0.0 since ray-project/ray#10134 has been resolved
2020-10-04 21:55:43 +08:00
n+e
710966eda7
change API of train_fn and test_fn (#229)
train_fn(epoch) -> train_fn(epoch, num_env_step)
test_fn(epoch) -> test_fn(epoch, num_env_step)
v0.3.0
2020-09-26 16:35:37 +08:00
n+e
d87d31a705
Update Anaconda support (#228)
conda install -c conda-forge tianshou
Related PR: conda-forge/staged-recipes#12719
2020-09-25 15:07:36 +08:00
Joshua Adelman
83bd1ec9e2
Add MANIFEST.in to include license file in source distribution (#227)
Related to Anaconda support (conda-forge/staged-recipes#12719)
2020-09-25 08:15:24 +08:00
Yao Feng
dcfcbb37f4
add PSRL policy (#202)
Add PSRL policy in tianshou/policy/modelbase/psrl.py.

Co-authored-by: n+e <trinkle23897@cmu.edu>
v0.3.0rc0
2020-09-23 20:57:33 +08:00
rocknamx
bf39b9ef7d
clarify updating state (#224)
Add an indicator(i.e. `self.learning`) of learning will be convenient for distinguishing state of policy.
Meanwhile, the state of `self.training` will be undisputed in the training stage.
Related issue: #211 

Others:
- fix a bug in DDQN: target_q could not be sampled from np.random.rand
- fix a bug in DQN atari net: it should add a ReLU before the last layer
- fix a bug in collector timing

Co-authored-by: n+e <463003665@qq.com>
2020-09-22 16:28:46 +08:00
n+e
eec0826fd3
change PER update interface in BasePolicy (#217)
* fix #215
2020-09-16 17:43:19 +08:00
n+e
623bf24f0c
fix unittest (#218) 2020-09-14 15:59:32 +08:00
danagi
a6ee979609
implement sac for discrete action settings (#216)
Co-authored-by: n+e <trinkle23897@cmu.edu>
2020-09-14 14:59:23 +08:00
n+e
b284ace102
type check in unit test (#200)
Fix #195: Add mypy test in .github/workflows/docs_and_lint.yml.

Also remove the out-of-the-date api
2020-09-13 19:31:50 +08:00
n+e
c91def6cbc
code format and update function signatures (#213)
Cherry-pick from #200 

- update the function signature
- format code-style
- move _compile into separate functions
- fix a bug in to_torch and to_numpy (Batch)
- remove None in action_range

In short, the code-format only contains function-signature style and `'` -> `"`. (pick up from [black](https://github.com/psf/black))
2020-09-12 15:39:01 +08:00
danagi
16d8e9b051
SAC implementation update (#212)
- replace DiagGuassian with Independent(Normal) (pytorch has already supported this)
- detach alpha from autograd
- add value/alpha to result (more informational)
- revert #204 to fix #211

Co-authored-by: Trinkle23897 <463003665@qq.com>
2020-09-12 08:44:50 +08:00
n+e
b86d78766b
fix docs and add docstring check (#210)
- fix broken links and out-of-the-date content
- add pydocstyle and doc8 check
- remove collector.seed and collector.render
2020-09-11 07:55:37 +08:00
n+e
64af7ea839
fix critical bugs in MAPolicy and docs update (#207)
- fix a bug in MAPolicy: `buffer.rew = Batch()` doesn't change `buffer.rew` (thanks mypy)
- polish examples/box2d/bipedal_hardcore_sac.py
- several docs update
- format setup.py and bump version to 0.2.7
v0.2.7
2020-09-08 21:10:48 +08:00
n+e
380e9e911d
fix atari examples (#206) 2020-09-06 23:05:33 +08:00
n+e
8bb8ecba6e
set policy.eval() before collector.collect (#204)
* fix #203

* no_grad argument in collector.collect
2020-09-06 16:20:16 +08:00
Trinkle23897
34f714a677 Numba acceleration (#193)
Training FPS improvement (base commit is 94bfb32):
test_pdqn: 1660 (without numba) -> 1930
discrete/test_ppo: 5100 -> 5170

since nstep has little impact on overall performance, the unit test result is:
GAE: 4.1s -> 0.057s
nstep: 0.3s -> 0.15s (little improvement)

Others:
- fix a bug in ttt set_eps
- keep only sumtree in segment tree implementation
- dirty fix for asyncVenv check_id test
2020-09-02 13:03:32 +08:00
yingchengyang
5b49192a48
DQN Atari examples (#187)
This PR aims to provide the script of Atari DQN setting:
- A speedrun of PongNoFrameskip-v4 (finished, about half an hour in i7-8750 + GTX1060 with 1M environment steps)
- A general script for all atari game
Since we use multiple env for simulation, the result is slightly different from the original paper, but consider to be acceptable.

It also adds another parameter save_only_last_obs for replay buffer in order to save the memory.

Co-authored-by: Trinkle23897 <463003665@qq.com>
2020-08-30 05:48:09 +08:00
n+e
94bfb32cc1
optimize training procedure and improve code coverage (#189)
1. add policy.eval() in all test scripts' "watch performance"
2. remove dict return support for collector preprocess_fn
3. add `__contains__` and `pop` in batch: `key in batch`, `batch.pop(key, deft)`
4. exact n_episode for a list of n_episode limitation and save fake data in cache_buffer when self.buffer is None (#184)
5. fix tensorboard logging: h-axis stands for env step instead of gradient step; add test results into tensorboard
6. add test_returns (both GAE and nstep)
7. change the type-checking order in batch.py and converter.py in order to meet the most often case first
8. fix shape inconsistency for torch.Tensor in replay buffer
9. remove `**kwargs` in ReplayBuffer
10. remove default value in batch.split() and add merge_last argument (#185)
11. improve nstep efficiency
12. add max_batchsize in onpolicy algorithms
13. potential bugfix for subproc.wait
14. fix RecurrentActorProb
15. improve the code-coverage (from 90% to 95%) and remove the dead code
16. fix some incorrect type annotation

The above improvement also increases the training FPS: on my computer, the previous version is only ~1800 FPS and after that, it can reach ~2050 (faster than v0.2.4.post1).
2020-08-27 12:15:18 +08:00
youkaichao
a9f9940d17
code refactor for venv (#179)
- Refacor code to remove duplicate code

- Enable async simulation for all vector envs

- Remove `collector.close` and rename `VectorEnv` to `DummyVectorEnv`

The abstraction of vector env changed.

Prior to this pr, each vector env is almost independent.

After this pr, each env is wrapped into a worker, and vector envs differ with their worker type. In fact, users can just use `BaseVectorEnv` with different workers, I keep `SubprocVectorEnv`, `ShmemVectorEnv` for backward compatibility.

Co-authored-by: n+e <463003665@qq.com>
Co-authored-by: magicly <magicly007@gmail.com>
v0.2.6
2020-08-19 15:00:24 +08:00
n+e
311a2beafb
Pickle compatible for replay buffer and improve buffer.get (#182)
fix #84 and make buffer more efficient
2020-08-16 16:26:23 +08:00
youkaichao
7f3b817b24
add policy.update to enable post process and remove collector.sample (#180)
* add policy.update to enable post process and remove collector.sample

* update doc in policy concept

* remove collector.sample in doc

* doc update of concepts

* docs

* polish

* polish policy

* remove collector.sample in docs

* minor fix

* Apply suggestions from code review

just a test

* doc fix

Co-authored-by: Trinkle23897 <463003665@qq.com>
2020-08-15 16:10:42 +08:00
n+e
140b1c2cab
Improve PER (#159)
- use segment tree to rewrite the previous PrioReplayBuffer code, add the test

- enable all Q-learning algorithms to use PER
2020-08-06 10:26:24 +08:00
Imone
312b7551cc
Add BipedalWalkerHardcore-v3 SAC example (#177) 2020-08-05 10:29:41 +08:00
ChenDRAG
f2bcc55a25
ShmemVectorEnv Implementation (#174)
* add shmem vecenv, some add&fix in test_env

* generalize test_env IO

* pep8 fix

* comment update

* style change

* pep8 fix

* style fix

* minor fix

* fix a bug

* test fix

* change env

* testenv bug fix& shmem support recurse dict

* bugfix

* pep8 fix

* _NP_TO_CT enhance

* doc update

* docstring update

* pep8 fix

* style change

* style fix

* remove assert

* minor

Co-authored-by: Trinkle23897 <463003665@qq.com>
2020-08-04 13:39:05 +08:00
ChenDRAG
996e2f7c9b
Add profile workflow (#143)
* add a workflow to profile batch
* buffer profiling
* collector profiling

Co-authored-by: Trinkle23897 <463003665@qq.com>
Co-authored-by: Huayu Chen(陈华玉) <chenhuay17@gamil.com>
2020-08-02 18:24:40 +08:00
youkaichao
32df0567bb
use nn.Sequential in DQN (#176) 2020-08-02 15:14:44 +08:00
yingchengyang
99a1d40e85
Dueling DQN (#170)
Co-authored-by: n+e <463003665@qq.com>
2020-07-29 19:44:42 +08:00
youkaichao
ad395b5235
bugfix for test_async_env (#171) 2020-07-28 20:06:01 +08:00
Trinkle23897
b7a4015db7 doc update and do not force save 'policy' in np format (#168) 2020-07-27 16:54:14 +08:00
Alexis DUBURCQ
e024afab8c
Asynchronous sampling vector environment (#134)
Fix #103

Co-authored-by: youkaichao <youkaichao@126.com>
Co-authored-by: Trinkle23897 <463003665@qq.com>
2020-07-26 18:01:21 +08:00
Alexis DUBURCQ
30368c29a6
Replay buffer allows stack_num = 1 (#165)
* stack_num starts at 1 (for no stacking) instead of 0.

* Use getter/stepper for stack_num.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
2020-07-25 19:33:44 +08:00
n+e
38a95c19da
Yet another 3 fix (#160)
1. DQN learn should keep eps=0

2. Add a warning of env.seed in VecEnv

3. fix #162 of multi-dim action
2020-07-24 17:38:12 +08:00