56 Commits

Author SHA1 Message Date
n+e
ec23c7efe9
fix qvalue mask_action error for obs_next (#310)
* fix #309
* remove for-loop in dqn expl_noise
2021-03-15 08:06:24 +08:00
ChenDRAG
f22b539761
Remove reward_normaliztion option in offpolicy algorithm (#298)
* remove rew_norm in nstep implementation
* improve test
* remove runnable/
* various doc fix

Co-authored-by: n+e <trinkle23897@gmail.com>
2021-02-27 11:20:43 +08:00
ChenDRAG
3108b9db0d
Add Timelimit trick to optimize policies (#296)
* consider timelimit.truncated in calculating returns by default
* remove ignore_done
2021-02-26 13:23:18 +08:00
ChenDRAG
9b61bc620c add logger (#295)
This PR focus on refactor of logging method to solve bug of nan reward and log interval. After these two pr, hopefully fundamental change of tianshou/data is finished. We then can concentrate on building benchmarks of tianshou finally.

Things changed:

1. trainer now accepts logger (BasicLogger or LazyLogger) instead of writer;
2. remove utils.SummaryWriter;
2021-02-24 14:48:42 +08:00
ChenDRAG
7036073649
Trainer refactor : some definition change (#293)
This PR focus on some definition change of trainer to make it more friendly to use and be consistent with typical usage in research papers, typically change `collect-per-step` to `step-per-collect`, add `update-per-step` / `episode-per-collect` accordingly, and modify the documentation.
2021-02-21 13:06:02 +08:00
ChenDRAG
150d0ec51b
Step collector implementation (#280)
This is the third PR of 6 commits mentioned in #274, which features refactor of Collector to fix #245. You can check #274 for more detail.

Things changed in this PR:

1. refactor collector to be more cleaner, split AsyncCollector to support asyncvenv;
2. change buffer.add api to add(batch, bffer_ids); add several types of buffer (VectorReplayBuffer, PrioritizedVectorReplayBuffer, etc.)
3. add policy.exploration_noise(act, batch) -> act
4. small change in BasePolicy.compute_*_returns
5. move reward_metric from collector to trainer
6. fix np.asanyarray issue (different version's numpy will result in different output)
7. flake8 maxlength=88
8. polish docs and fix test

Co-authored-by: n+e <trinkle23897@gmail.com>
2021-02-19 10:33:49 +08:00
Trinkle23897
d918022ce9 merge master into dev 2021-02-18 12:46:55 +08:00
ChenDRAG
f528131da1
hotfix:fix test failure in cuda environment (#289) 2021-02-09 17:13:40 +08:00
ChenDRAG
f0129f4ca7
Add CachedReplayBuffer and ReplayBufferManager (#278)
This is the second commit of 6 commits mentioned in #274, which features minor refactor of ReplayBuffer and adding two new ReplayBuffer classes called CachedReplayBuffer and ReplayBufferManager. You can check #274 for more detail.

1. Add ReplayBufferManager (handle a list of buffers) and CachedReplayBuffer;
2. Make sure the reserved keys cannot be edited by methods like `buffer.done = xxx`;
3. Add `set_batch` method for manually choosing the batch the ReplayBuffer wants to handle;
4. Add `sample_index` method, same as `sample` but only return index instead of both index and batch data;
5. Add `prev` (one-step previous transition index), `next` (one-step next transition index) and `unfinished_index` (the last modified index whose done==False);
6. Separate `alloc_fn` method for allocating new memory for `self._meta` when a new `(key, value)` pair comes in;
7. Move buffer's documentation to `docs/tutorials/concepts.rst`.

Co-authored-by: n+e <trinkle23897@gmail.com>
2021-01-29 12:23:18 +08:00
wizardsheng
1eb6137645
Add QR-DQN algorithm (#276)
This is the PR for QR-DQN algorithm: https://arxiv.org/abs/1710.10044

1. add QR-DQN policy in tianshou/policy/modelfree/qrdqn.py.
2. add QR-DQN net in examples/atari/atari_network.py.
3. add QR-DQN atari example in examples/atari/atari_qrdqn.py.
4. add QR-DQN statement in tianshou/policy/init.py.
5. add QR-DQN unit test in test/discrete/test_qrdqn.py.
6. add QR-DQN atari results in examples/atari/results/qrdqn/.
7. add compute_q_value in DQNPolicy and C51Policy for simplify forward function.
8. move `with torch.no_grad():` from `_target_q` to BasePolicy

By running "python3 atari_qrdqn.py --task "PongNoFrameskip-v4" --batch-size 64", get best_result': '19.8 ± 0.40', in epoch 8.
2021-01-28 09:27:05 +08:00
Jialu Zhu
a511cb4779
Add offline trainer and discrete BCQ algorithm (#263)
The result needs to be tuned after `done` issue fixed.

Co-authored-by: n+e <trinkle23897@gmail.com>
2021-01-20 18:13:04 +08:00
ChenDRAG
a633a6a028
update utils.network (#275)
This is the first commit of 6 commits mentioned in #274, which features

1. Refactor of `Class Net` to support any form of MLP.
2. Enable type check in utils.network.
3. Relative change in docs/test/examples.
4. Move atari-related network to examples/atari/atari_network.py

Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
2021-01-20 16:54:13 +08:00
wizardsheng
c6f2648e87
Add C51 algorithm (#266)
This is the PR for C51algorithm: https://arxiv.org/abs/1707.06887

1. add C51 policy in tianshou/policy/modelfree/c51.py.
2. add C51 net in tianshou/utils/net/discrete.py.
3. add C51 atari example in examples/atari/atari_c51.py.
4. add C51 statement in tianshou/policy/__init__.py.
5. add C51 test in test/discrete/test_c51.py.
6. add C51 atari results in examples/atari/results/c51/.

By running "python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64", get  best_result': '20.50 ± 0.50', in epoch 9.

By running "python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1 --epoch 40", get best_reward: 407.400000 ± 31.155096 in epoch 39.
2021-01-06 10:17:45 +08:00
Trinkle23897
cd481423dc sac mujoco result (#246) 2020-11-09 16:43:55 +08:00
rocknamx
c97aa4065e
add singleton pattern version of summary_writter (#230)
Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
2020-10-31 16:38:54 +08:00
n+e
710966eda7
change API of train_fn and test_fn (#229)
train_fn(epoch) -> train_fn(epoch, num_env_step)
test_fn(epoch) -> test_fn(epoch, num_env_step)
2020-09-26 16:35:37 +08:00
danagi
a6ee979609
implement sac for discrete action settings (#216)
Co-authored-by: n+e <trinkle23897@cmu.edu>
2020-09-14 14:59:23 +08:00
Trinkle23897
34f714a677 Numba acceleration (#193)
Training FPS improvement (base commit is 94bfb32):
test_pdqn: 1660 (without numba) -> 1930
discrete/test_ppo: 5100 -> 5170

since nstep has little impact on overall performance, the unit test result is:
GAE: 4.1s -> 0.057s
nstep: 0.3s -> 0.15s (little improvement)

Others:
- fix a bug in ttt set_eps
- keep only sumtree in segment tree implementation
- dirty fix for asyncVenv check_id test
2020-09-02 13:03:32 +08:00
n+e
94bfb32cc1
optimize training procedure and improve code coverage (#189)
1. add policy.eval() in all test scripts' "watch performance"
2. remove dict return support for collector preprocess_fn
3. add `__contains__` and `pop` in batch: `key in batch`, `batch.pop(key, deft)`
4. exact n_episode for a list of n_episode limitation and save fake data in cache_buffer when self.buffer is None (#184)
5. fix tensorboard logging: h-axis stands for env step instead of gradient step; add test results into tensorboard
6. add test_returns (both GAE and nstep)
7. change the type-checking order in batch.py and converter.py in order to meet the most often case first
8. fix shape inconsistency for torch.Tensor in replay buffer
9. remove `**kwargs` in ReplayBuffer
10. remove default value in batch.split() and add merge_last argument (#185)
11. improve nstep efficiency
12. add max_batchsize in onpolicy algorithms
13. potential bugfix for subproc.wait
14. fix RecurrentActorProb
15. improve the code-coverage (from 90% to 95%) and remove the dead code
16. fix some incorrect type annotation

The above improvement also increases the training FPS: on my computer, the previous version is only ~1800 FPS and after that, it can reach ~2050 (faster than v0.2.4.post1).
2020-08-27 12:15:18 +08:00
youkaichao
a9f9940d17
code refactor for venv (#179)
- Refacor code to remove duplicate code

- Enable async simulation for all vector envs

- Remove `collector.close` and rename `VectorEnv` to `DummyVectorEnv`

The abstraction of vector env changed.

Prior to this pr, each vector env is almost independent.

After this pr, each env is wrapped into a worker, and vector envs differ with their worker type. In fact, users can just use `BaseVectorEnv` with different workers, I keep `SubprocVectorEnv`, `ShmemVectorEnv` for backward compatibility.

Co-authored-by: n+e <463003665@qq.com>
Co-authored-by: magicly <magicly007@gmail.com>
2020-08-19 15:00:24 +08:00
n+e
140b1c2cab
Improve PER (#159)
- use segment tree to rewrite the previous PrioReplayBuffer code, add the test

- enable all Q-learning algorithms to use PER
2020-08-06 10:26:24 +08:00
yingchengyang
99a1d40e85
Dueling DQN (#170)
Co-authored-by: n+e <463003665@qq.com>
2020-07-29 19:44:42 +08:00
n+e
38a95c19da
Yet another 3 fix (#160)
1. DQN learn should keep eps=0

2. Add a warning of env.seed in VecEnv

3. fix #162 of multi-dim action
2020-07-24 17:38:12 +08:00
youkaichao
8c32d99c65
Add multi-agent example: tic-tac-toe (#122)
* make fileds with empty Batch rather than None after reset

* dummy code

* remove dummy

* add reward_length argument for collector

* Improve Batch (#126)

* make sure the key type of Batch is string, and add unit tests

* add is_empty() function and unit tests

* enable cat of mixing dict and Batch, just like stack

* bugfix for reward_length

* add get_final_reward_fn argument to collector to deal with marl

* minor polish

* remove multibuf

* minor polish

* improve and implement Batch.cat_

* bugfix for buffer.sample with field impt_weight

* restore the usage of a.cat_(b)

* fix 2 bugs in batch and add corresponding unittest

* code fix for update

* update is_empty to recognize empty over empty; bugfix for len

* bugfix for update and add testcase

* add testcase of update

* make fileds with empty Batch rather than None after reset

* dummy code

* remove dummy

* add reward_length argument for collector

* bugfix for reward_length

* add get_final_reward_fn argument to collector to deal with marl

* make sure the key type of Batch is string, and add unit tests

* add is_empty() function and unit tests

* enable cat of mixing dict and Batch, just like stack

* dummy code

* remove dummy

* add multi-agent example: tic-tac-toe

* move TicTacToeEnv to a separate file

* remove dummy MANet

* code refactor

* move tic-tac-toe example to test

* update doc with marl-example

* fix docs

* reduce the threshold

* revert

* update player id to start from 1 and change player to agent; keep coding

* add reward_length argument for collector

* Improve Batch (#128)

* minor polish

* improve and implement Batch.cat_

* bugfix for buffer.sample with field impt_weight

* restore the usage of a.cat_(b)

* fix 2 bugs in batch and add corresponding unittest

* code fix for update

* update is_empty to recognize empty over empty; bugfix for len

* bugfix for update and add testcase

* add testcase of update

* fix docs

* fix docs

* fix docs [ci skip]

* fix docs [ci skip]

Co-authored-by: Trinkle23897 <463003665@qq.com>

* refact

* re-implement Batch.stack and add testcases

* add doc for Batch.stack

* reward_metric

* modify flag

* minor fix

* reuse _create_values and refactor stack_ & cat_

* fix pep8

* fix reward stat in collector

* fix stat of collector, simplify test/base/env.py

* fix docs

* minor fix

* raise exception for stacking with partial keys and axis!=0

* minor fix

* minor fix

* minor fix

* marl-examples

* add condense; bugfix for torch.Tensor; code refactor

* marl example can run now

* enable tic tac toe with larger board size and win-size

* add test dependency

* Fix padding of inconsistent keys with Batch.stack and Batch.cat (#130)

* re-implement Batch.stack and add testcases

* add doc for Batch.stack

* reuse _create_values and refactor stack_ & cat_

* fix pep8

* fix docs

* raise exception for stacking with partial keys and axis!=0

* minor fix

* minor fix

Co-authored-by: Trinkle23897 <463003665@qq.com>

* stash

* let agent learn to play as agent 2 which is harder

* code refactor

* Improve collector (#125)

* remove multibuf

* reward_metric

* make fileds with empty Batch rather than None after reset

* many fixes and refactor
Co-authored-by: Trinkle23897 <463003665@qq.com>

* marl for tic-tac-toe and general gomoku

* update default gamma to 0.1 for tic tac toe to win earlier

* fix name typo; change default game config; add rew_norm option

* fix pep8

* test commit

* mv test dir name

* add rew flag

* fix torch.optim import error and madqn rew_norm

* remove useless kwargs

* Vector env enable select worker (#132)

* Enable selecting worker for vector env step method.

* Update collector to match new vecenv selective worker behavior.

* Bug fix.

* Fix rebase

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>

* show the last move of tictactoe by capital letters

* add multi-agent tutorial

* fix link

* Standardized behavior of Batch.cat and misc code refactor (#137)

* code refactor; remove unused kwargs; add reward_normalization for dqn

* bugfix for __setitem__ with torch.Tensor; add Batch.condense

* minor fix

* support cat with empty Batch

* remove the dependency of is_empty on len; specify the semantic of empty Batch by test cases

* support stack with empty Batch

* remove condense

* refactor code to reflect the shared / partial / reserved categories of keys

* add is_empty(recursive=False)

* doc fix

* docfix and bugfix for _is_batch_set

* add doc for key reservation

* bugfix for algebra operators

* fix cat with lens hint

* code refactor

* bugfix for storing None

* use ValueError instead of exception

* hide lens away from users

* add comment for __cat

* move the computation of the initial value of lens in cat_ itself.

* change the place of doc string

* doc fix for Batch doc string

* change recursive to recurse

* doc string fix

* minor fix for batch doc

* write tutorials to specify the standard of Batch (#142)

* add doc for len exceptions

* doc move; unify is_scalar_value function

* remove some issubclass check

* bugfix for shape of Batch(a=1)

* keep moving doc

* keep writing batch tutorial

* draft version of Batch tutorial done

* improving doc

* keep improving doc

* batch tutorial done

* rename _is_number

* rename _is_scalar

* shape property do not raise exception

* restore some doc string

* grammarly [ci skip]

* grammarly + fix warning of building docs

* polish docs

* trim and re-arrange batch tutorial

* go straight to the point

* minor fix for batch doc

* add shape / len in basic usage

* keep improving tutorial

* unify _to_array_with_correct_type to remove duplicate code

* delegate type convertion to Batch.__init__

* further delegate type convertion to Batch.__init__

* bugfix for setattr

* add a _parse_value function

* remove dummy function call

* polish docs

Co-authored-by: Trinkle23897 <463003665@qq.com>

* bugfix for mapolicy

* pretty code

* remove debug code; remove condense

* doc fix

* check before get_agents in tutorials/tictactoe

* tutorial

* fix

* minor fix for batch doc

* minor polish

* faster test_ttt

* improve tic-tac-toe environment

* change default epoch and step-per-epoch for tic-tac-toe

* fix mapolicy

* minor polish for mapolicy

* 90% to 80% (need to change the tutorial)

* win rate

* show step number at board

* simplify mapolicy

* minor polish for mapolicy

* remove MADQN

* fix pep8

* change legal_actions to mask (need to update docs)

* simplify maenv

* fix typo

* move basevecenv to single file

* separate RandomAgent

* update docs

* grammarly

* fix pep8

* win rate typo

* format in cheatsheet

* use bool mask directly

* update doc for boolean mask

Co-authored-by: Trinkle23897 <463003665@qq.com>
Co-authored-by: Alexis DUBURCQ <alexis.duburcq@gmail.com>
Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
2020-07-21 14:59:49 +08:00
n+e
47e8e2686c
move atari wrapper to examples and publish v0.2.4 (#124)
* move atari wrapper to examples

* consistency

* change drqn seed since it is quite unstable in current seed

* minor fix

* 0.2.4
2020-07-10 17:20:39 +08:00
youkaichao
e767de044b
Remove dummy net code (#123)
* remove dummy net; delete two files

* split code to have backbone and head

* rename class

* change torch.float to torch.float32

* use flatten(1) instead of view(batch, -1)

* remove dummy net in docs

* bugfix for rnn

* fix cuda error

* minor fix of docs

* do not change the example code in dqn tutorial, since it is for demonstration

Co-authored-by: Trinkle23897 <463003665@qq.com>
2020-07-09 22:57:01 +08:00
rocknamx
506cc97ba5
fix #91 (#94) 2020-06-25 07:02:59 +08:00
Trinkle23897
81e4a16ef2 fix a bug in re-index replay buffer (fix #82) 2020-06-17 16:37:51 +08:00
Trinkle23897
dc451dfe88 nstep all (fix #51) 2020-06-03 13:59:47 +08:00
Trinkle23897
ff81a18f42 compute_nstep_returns (item 2 of #51) 2020-06-02 22:29:50 +08:00
Trinkle23897
70122dc03d oinit with 0 bias 2020-05-17 17:06:20 +08:00
Trinkle23897
3271c92609 orthogonal init for ppo in test script 2020-05-16 20:27:01 +08:00
Trinkle23897
0eef0ca198 fix optional type syntax 2020-05-16 20:08:32 +08:00
Trinkle23897
c2a7caf806 add recurrent actor and critic 2020-04-30 16:31:40 +08:00
Trinkle23897
959955fa2a fix historical issues 2020-04-26 16:13:51 +08:00
rocknamx
b23749463e
Prioritized DQN (#30)
* add sum_tree.py

* add prioritized replay buffer

* del sum_tree.py

* fix some format issues

* fix weight_update bug

* simply replace replaybuffer in test_dqn without weight update

* weight default set to 1

* fix sampling bug when buffer is not full

* rename parameter

* fix formula error, add accuracy check

* add PrioritizedDQN test

* add test_pdqn.py

* add update_weight() doc

* add ref of prio dqn in readme.md and index.rst

* restore test_dqn.py, fix args of test_pdqn.py
2020-04-26 12:05:58 +08:00
Trinkle23897
815f3522bb imitation with discrete action space 2020-04-20 11:25:20 +08:00
Trinkle23897
6bf1ea644d fix ppo 2020-04-19 14:30:42 +08:00
Trinkle23897
680fc0ffbe gae 2020-04-14 21:11:06 +08:00
Trinkle23897
6a244d1fbb save_fn 2020-04-11 16:54:27 +08:00
Trinkle23897
13086b7f64 add ignore_obs_next in buffer 2020-04-10 09:01:17 +08:00
Trinkle23897
6da80e045a fix rnn (#19), add __repr__, and fix #26 2020-04-09 19:53:45 +08:00
Trinkle23897
86572c66d4 maybe finished rnn? 2020-04-08 21:13:15 +08:00
Trinkle23897
e0809ff135 add policy docs (#21) 2020-04-06 19:36:59 +08:00
Trinkle23897
974ade8019 add some docs 2020-04-03 21:28:12 +08:00
ShenDezhou
4da857d86e
Fix windows env setup bugs and other typo. (#11) 2020-03-31 17:22:32 +08:00
Trinkle23897
d9e4b9d16f upd doc 2020-03-29 10:22:03 +08:00
Trinkle23897
f68f23292e update readme and force flake8 2020-03-28 13:27:01 +08:00
Minghao Zhang
068c4068ec
fix atari/mujoco env (#7)
* update atari.py

* fix setup.py
pass the pytest

* fix setup.py
pass the pytest

* add args "render"

* change the tensorboard writter

* change the tensorboard writter

* change device, render, tensorboard log location

* change device, render, tensorboard log location

* remove some wrong local files

* fix some tab mistakes and the envs name in continuous/test_xx.py

* add examples and point robot maze environment

* fix some bugs during testing examples

* add dqn network and fix some args

* change back the tensorboard writter's frequency to ensure ppo and a2c can write things normally

* add a warning to collector

* rm some unrelated files

* reformat

* fix a bug in test_dqn due to the model wrong selection

* change atari frame skip and observation to improve performance

* readd some files

* change import

* modified readme

* rm tensorboard log

* update atari and mujoco which are ignored

* rm the wrong lines
2020-03-28 12:03:49 +08:00
Trinkle23897
c42990c725 add rllib result and fix pep8 2020-03-28 09:43:35 +08:00