27 Commits

Author SHA1 Message Date
n+e
fc251ab0b8
bump to v0.4.3 (#432)
* add makefile
* bump version
* add isort and yapf
* update contributing.md
* update PR template
* spelling check
2021-09-03 05:05:04 +08:00
Yi Su
291be08d43
Add Rainbow DQN (#386)
- add RainbowPolicy
- add `set_beta` method in prio_buffer
- add NoisyLinear in utils/network
2021-08-29 23:34:59 +08:00
Andriy Drozdyuk
d161059c3d
Replaced indice by plural indices (#422) 2021-08-20 21:58:44 +08:00
n+e
09692c84fe
fix numpy>=1.20 typing check (#323)
Change the behavior of to_numpy and to_torch: from now on, dict is automatically converted to Batch and list is automatically converted to np.ndarray (if an error occurs, raise the exception instead of converting each element in the list).
2021-03-30 16:06:03 +08:00
Trinkle23897
e99e1b0fdd Improve buffer.prev() & buffer.next() (#294) 2021-02-22 19:19:22 +08:00
ChenDRAG
150d0ec51b
Step collector implementation (#280)
This is the third PR of 6 commits mentioned in #274, which features refactor of Collector to fix #245. You can check #274 for more detail.

Things changed in this PR:

1. refactor collector to be more cleaner, split AsyncCollector to support asyncvenv;
2. change buffer.add api to add(batch, bffer_ids); add several types of buffer (VectorReplayBuffer, PrioritizedVectorReplayBuffer, etc.)
3. add policy.exploration_noise(act, batch) -> act
4. small change in BasePolicy.compute_*_returns
5. move reward_metric from collector to trainer
6. fix np.asanyarray issue (different version's numpy will result in different output)
7. flake8 maxlength=88
8. polish docs and fix test

Co-authored-by: n+e <trinkle23897@gmail.com>
2021-02-19 10:33:49 +08:00
n+e
c838f2f0e9
fix 2 bugs of batch (#284)
1. `_create_value(Batch(a={}, b=[1, 2, 3]), 10, False)`

before:
```python
TypeError: cannot concatenate with Batch() which is scalar
```
after:
```python
Batch(
    a: Batch(),
    b: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
)
```

2. creating keys in a batch's subkey, e.g. 
```python
a = Batch(info={"key1": [0, 1], "key2": [2, 3]})
a[0] = Batch(info={"key1": 2, "key3": 4})
print(a)
```
before:
```python
Batch(
    info: Batch(
              key1: array([0, 1]),
              key2: array([0, 3]),
          ),
)
```
after:
```python
ValueError: Creating keys is not supported by item assignment.
```

3. small optimization for `Batch.stack_` and `Batch.cat_`, raise ValueError when receiving invalid data format.
2021-02-02 19:28:05 +08:00
ChenDRAG
f0129f4ca7
Add CachedReplayBuffer and ReplayBufferManager (#278)
This is the second commit of 6 commits mentioned in #274, which features minor refactor of ReplayBuffer and adding two new ReplayBuffer classes called CachedReplayBuffer and ReplayBufferManager. You can check #274 for more detail.

1. Add ReplayBufferManager (handle a list of buffers) and CachedReplayBuffer;
2. Make sure the reserved keys cannot be edited by methods like `buffer.done = xxx`;
3. Add `set_batch` method for manually choosing the batch the ReplayBuffer wants to handle;
4. Add `sample_index` method, same as `sample` but only return index instead of both index and batch data;
5. Add `prev` (one-step previous transition index), `next` (one-step next transition index) and `unfinished_index` (the last modified index whose done==False);
6. Separate `alloc_fn` method for allocating new memory for `self._meta` when a new `(key, value)` pair comes in;
7. Move buffer's documentation to `docs/tutorials/concepts.rst`.

Co-authored-by: n+e <trinkle23897@gmail.com>
2021-01-29 12:23:18 +08:00
Nico Gürtler
5d13d8a453
Saving and loading replay buffer with HDF5 (#261)
As mentioned in #260, this pull request is about an implementation of saving and loading the replay buffer with HDF5.
2020-12-17 08:58:43 +08:00
Trinkle23897
34f714a677 Numba acceleration (#193)
Training FPS improvement (base commit is 94bfb32):
test_pdqn: 1660 (without numba) -> 1930
discrete/test_ppo: 5100 -> 5170

since nstep has little impact on overall performance, the unit test result is:
GAE: 4.1s -> 0.057s
nstep: 0.3s -> 0.15s (little improvement)

Others:
- fix a bug in ttt set_eps
- keep only sumtree in segment tree implementation
- dirty fix for asyncVenv check_id test
2020-09-02 13:03:32 +08:00
yingchengyang
5b49192a48
DQN Atari examples (#187)
This PR aims to provide the script of Atari DQN setting:
- A speedrun of PongNoFrameskip-v4 (finished, about half an hour in i7-8750 + GTX1060 with 1M environment steps)
- A general script for all atari game
Since we use multiple env for simulation, the result is slightly different from the original paper, but consider to be acceptable.

It also adds another parameter save_only_last_obs for replay buffer in order to save the memory.

Co-authored-by: Trinkle23897 <463003665@qq.com>
2020-08-30 05:48:09 +08:00
n+e
94bfb32cc1
optimize training procedure and improve code coverage (#189)
1. add policy.eval() in all test scripts' "watch performance"
2. remove dict return support for collector preprocess_fn
3. add `__contains__` and `pop` in batch: `key in batch`, `batch.pop(key, deft)`
4. exact n_episode for a list of n_episode limitation and save fake data in cache_buffer when self.buffer is None (#184)
5. fix tensorboard logging: h-axis stands for env step instead of gradient step; add test results into tensorboard
6. add test_returns (both GAE and nstep)
7. change the type-checking order in batch.py and converter.py in order to meet the most often case first
8. fix shape inconsistency for torch.Tensor in replay buffer
9. remove `**kwargs` in ReplayBuffer
10. remove default value in batch.split() and add merge_last argument (#185)
11. improve nstep efficiency
12. add max_batchsize in onpolicy algorithms
13. potential bugfix for subproc.wait
14. fix RecurrentActorProb
15. improve the code-coverage (from 90% to 95%) and remove the dead code
16. fix some incorrect type annotation

The above improvement also increases the training FPS: on my computer, the previous version is only ~1800 FPS and after that, it can reach ~2050 (faster than v0.2.4.post1).
2020-08-27 12:15:18 +08:00
n+e
311a2beafb
Pickle compatible for replay buffer and improve buffer.get (#182)
fix #84 and make buffer more efficient
2020-08-16 16:26:23 +08:00
n+e
140b1c2cab
Improve PER (#159)
- use segment tree to rewrite the previous PrioReplayBuffer code, add the test

- enable all Q-learning algorithms to use PER
2020-08-06 10:26:24 +08:00
ChenDRAG
f2bcc55a25
ShmemVectorEnv Implementation (#174)
* add shmem vecenv, some add&fix in test_env

* generalize test_env IO

* pep8 fix

* comment update

* style change

* pep8 fix

* style fix

* minor fix

* fix a bug

* test fix

* change env

* testenv bug fix& shmem support recurse dict

* bugfix

* pep8 fix

* _NP_TO_CT enhance

* doc update

* docstring update

* pep8 fix

* style change

* style fix

* remove assert

* minor

Co-authored-by: Trinkle23897 <463003665@qq.com>
2020-08-04 13:39:05 +08:00
ChenDRAG
d09b69e594
buffer update bug fix (#154)
* buffer update bug fix

* some fix in buffer update

* polish

Co-authored-by: n+e <463003665@qq.com>
2020-07-20 22:12:57 +08:00
youkaichao
ff99662fe6
bugfix for update with empty buffer; remove duplicate variable _weight_sum in PrioritizedReplayBuffer (#120)
* bugfix for update with empty buffer; remove duplicate variable _weight_sum in PrioritizedReplayBuffer

* point out that ListReplayBuffer cannot be sampled

* remove useless _amortization_counter variable
2020-07-10 08:24:11 +08:00
Trinkle23897
e0f4862d01 store RNN hidden states in policy._state and add sample_avail in buffer (#19) 2020-06-29 12:18:52 +08:00
Alexis DUBURCQ
a951a32487
Enable partial stacking at Batch level (#100)
* Enable stacking of partially matching Batch instances.

* Fix list support for getitem.

* Fix Batch 'size' method.

* Update Batch documentation.
2020-06-27 09:06:40 +08:00
Alexis DUBURCQ
70aa7bf93e
Use lower-level API to reduce overhead. (#97)
* Use lower-level API to reduce overhead.

* Further improvements.

* Buffer _add_to_buffer improvement.

* Do not use _data field to store Batch data to avoid overhead. Add back _meta field in Buffer.

* Restore metadata attribute to store batch in Buffer.

* Move out nested methods.

* Update try/catch instead of actual check to efficiency.

* Remove unsed branches for efficiency.

* Use np.array over list when possible for efficiency.

* Final performance improvement.

* Add unit tests for Batch size method.

* Add missing stack unit tests.

* Enforce Buffer initialization to zero.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>
2020-06-26 18:37:50 +08:00
Trinkle23897
81e4a16ef2 fix a bug in re-index replay buffer (fix #82) 2020-06-17 16:37:51 +08:00
Trinkle23897
560116d0b2 cheat sheet 2020-06-08 21:53:00 +08:00
Trinkle23897
6b96f124ae fix pdqn 2020-04-26 15:11:20 +08:00
Trinkle23897
13086b7f64 add ignore_obs_next in buffer 2020-04-10 09:01:17 +08:00
Trinkle23897
6da80e045a fix rnn (#19), add __repr__, and fix #26 2020-04-09 19:53:45 +08:00
Minghao Zhang
3c0a09fefd
minor reformat (#2)
* update atari.py

* fix setup.py
pass the pytest

* fix setup.py
pass the pytest
2020-03-26 09:01:20 +08:00
Trinkle23897
8bd8246b16 refract test code 2020-03-21 10:58:01 +08:00