50 Commits

Author SHA1 Message Date
Michael Panchenko
07702fc007
Improved typing and reduced duplication (#912)
# Goals of the PR

The PR introduces **no changes to functionality**, apart from improved
input validation here and there. The main goals are to reduce some
complexity of the code, to improve types and IDE completions, and to
extend documentation and block comments where appropriate. Because of
the change to the trainer interfaces, many files are affected (more
details below), but the overall changes are still "small" in a certain
sense.

## Major Change 1 - BatchProtocol

**TL;DR:** One can now annotate which fields a batch is expected to
have in input parameters and which fields a returned batch has. This should be
useful for reading the code, getting meaningful IDE support, and
catching bugs with mypy. This annotation strategy will continue to work
if Batch is replaced by TensorDict or by something else.

**In more detail:** Batch itself has no fields, so using it for
annotations is of limited informational power. Batches with fields are
not separate classes but instances of Batch directly, so there
is no type that could be used for annotation. Fortunately, Python's
`Protocol` comes to the rescue. With these changes we can now do
things like

```python
# Imports added for context (exact import paths may differ slightly).
from typing import Any, Dict, Optional, Sequence, Union

import numpy as np
import torch

from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol
from tianshou.policy import BasePolicy


class ActionBatchProtocol(BatchProtocol):
    logits: Sequence[Union[tuple, torch.Tensor]]
    dist: torch.distributions.Distribution
    act: torch.Tensor
    state: Optional[torch.Tensor]


class RolloutBatchProtocol(BatchProtocol):
    obs: torch.Tensor
    obs_next: torch.Tensor
    info: Dict[str, Any]
    rew: torch.Tensor
    terminated: torch.Tensor
    truncated: torch.Tensor


class PGPolicy(BasePolicy):
    ...

    def forward(
        self,
        batch: RolloutBatchProtocol,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any,
    ) -> ActionBatchProtocol:
        ...
```

The IDE and mypy are now very helpful in finding errors and in
auto-completion, whereas before the tools couldn't assist in that at
all.
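
For illustration, a minimal sketch of how these annotations pay off downstream (building on the classes above; `choose_action` and `as_rollout` are hypothetical helpers, not part of the PR):

```python
from typing import cast

import torch

from tianshou.data import Batch


def choose_action(policy: PGPolicy, batch: RolloutBatchProtocol) -> torch.Tensor:
    result = policy.forward(batch)
    # mypy and the IDE know that `result` exposes `.act`, `.dist`, etc.;
    # a misspelled field such as `result.action` would be flagged.
    return result.act


def as_rollout(batch: Batch) -> RolloutBatchProtocol:
    # A dynamically constructed Batch can be cast to the protocol it satisfies.
    return cast(RolloutBatchProtocol, batch)
```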

## Major Change 2 - remove duplication in trainer package

**TL;DR:** There was a lot of duplication between `BaseTrainer` and its
subclasses. Even worse, it was almost-duplication. There was also
interface fragmentation through things like `onpolicy_trainer`. Now this
duplication is gone and all downstream code was adjusted.

**In more detail:** Since this change affects a lot of code, I would
like to explain why I thought it to be necessary.

1. The subclasses of `BaseTrainer` just duplicated docstrings and
constructors. What's worse, they changed the order of args there, even
turning some kwargs of BaseTrainer into args. They also had the arg
`learning_type` which was passed as kwarg to the base class and was
unused there. This made things difficult to maintain, and in fact some
errors were already present in the duplicated docstrings.
2. The "functions" a la `onpolicy_trainer`, which just called the
`OnpolicyTrainer.run`, not only introduced interface fragmentation but
also completely obfuscated the docstring and interfaces. They themselves
had no dosctring and the interface was just `*args, **kwargs`, which
makes it impossible to understand what they do and which things can be
passed without reading their implementation, then reading the docstring
of the associated class, etc. Needless to say, mypy and IDEs provide no
support with such functions. Nevertheless, they were used everywhere in
the code-base. I didn't find the sacrifices in clarity and complexity
justified just for the sake of not having to write `.run()` after
instantiating a trainer.
3. The trainers are all very similar to each other. Since I needed a
new trainer for my application, I wanted to understand their
structure. The similarity, however, was hard to discover since they were
all in separate modules and there was so much duplication. I kept
staring at the constructors for a while until I figured out that
essentially no changes to the superclass were introduced. Now they are
all in the same module and the similarities/differences between them are
much easier to grasp (in my opinion).
4. Because of (1), I had to manually change and check a lot of code,
which was very tedious and boring. This kind of work won't be necessary
in the future, since now IDEs can be used for changing signatures,
renaming args and kwargs, changing class names and so on.

I have some more reasons, but maybe the above ones are convincing
enough.
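
For users, the change mostly amounts to swapping the trainer functions for the corresponding class plus `.run()`. A hedged before/after sketch (the keyword arguments shown are illustrative, not the full signature):

```python
from tianshou.trainer import OnpolicyTrainer


def train(policy, train_collector, test_collector):
    # before: result = onpolicy_trainer(policy, train_collector, test_collector, ...)
    return OnpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=10,
        step_per_epoch=50_000,
        step_per_collect=2_000,
        repeat_per_collect=4,
        episode_per_test=10,
        batch_size=64,
    ).run()
```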

## Minor changes: improved input validation and types

I added input validation for things like `state` and `action_scaling`
(which only makes sense for continuous envs). After adding this, some
tests failed to pass the validation. In those tests I added
`action_scaling=isinstance(env.action_space, Box)`, after which they
were green again. I don't know why the tests were green before, since action
scaling doesn't make sense for discrete actions. I guess that aspect was
not tested and simply didn't crash.

I also added `Literal` in some places, in particular for
`action_bound_method`. Passing an empty string is no longer allowed;
one should pass `None` instead. Here, too, there is input
validation with clear error messages.
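
As a rough illustration of the kind of validation added (the helper below is a hedged sketch under that description, not the exact library code):

```python
from typing import Literal, Optional


def check_action_bound_method(method: Optional[Literal["clip", "tanh"]]) -> None:
    # An empty string is no longer accepted; pass None to disable bounding.
    if method not in (None, "clip", "tanh"):
        raise ValueError(
            f"Got invalid action_bound_method={method!r}, expected None, 'clip' or 'tanh'."
        )
```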

@Trinkle23897 The functional tests are green. I didn't want to fix the
formatting, since it will change in the next PR that will solve #914
anyway. I also found a whole bunch of code in `docs/_static`, which I
just deleted (shouldn't it be copied from the sources during the docs build
instead of being committed?). I also haven't adjusted the documentation yet,
which at the moment still mentions trainers of the form
`onpolicy_trainer(...)` instead of `OnpolicyTrainer(...).run()`.

## Breaking Changes

The adjustments to the trainer package introduce breaking changes, as
duplicated interfaces are deleted. However, it should be very easy for
users to adjust to them.

---------

Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
2023-08-22 09:54:46 -07:00
Zhenjie Zhao
f8808d236f
fix a problem of the atari dqn example (#861) 2023-04-30 08:44:27 -07:00
Markus Krimmel
6c6c872523
Gymnasium Integration (#789)
Changes:
- Disclaimer in README
- Replaced all occurrences of Gym with Gymnasium
- Removed code that is now dead since we no longer need to support the
old step API
- Updated type hints to only allow new step API
- Increased required version of envpool to support Gymnasium
- Increased required version of PettingZoo to support Gymnasium
- Updated `PettingZooEnv` to only use the new step API, removed hack to
also support old API
- I had to add some `# type: ignore` comments, due to new type hinting
in Gymnasium. I'm not that familiar with type hinting but I believe that
the issue is on the Gymnasium side and we are looking into it.
- Had to update `MyTestEnv` to support `options` kwarg
- Skip NNI tests because they still use OpenAI Gym
- Also allow `PettingZooEnv` in vector environment
- Updated the doc page about ReplayBuffer to also talk about the terminated and
truncated flags (see the step-API sketch below).
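
For reference, the Gymnasium step/reset API this PR commits to (a minimal sketch using the public Gymnasium interface; the environment id is just an example):

```python
import gymnasium as gym

env = gym.make("CartPole-v1")
# New reset API: returns (obs, info) and accepts the `seed` / `options` kwargs.
obs, info = env.reset(seed=0, options={})
# New step API: `terminated` and `truncated` replace the single `done` flag.
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
done = terminated or truncated
```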

Still need to do: 
- Update the Jupyter notebooks in docs
- Check the entire code base for more dead code (from compatibility
stuff)
- Check the reset functions of all environments/wrappers in code base to
make sure they use the `options` kwarg
- Someone might want to check test_env_finite.py
- Is it okay to allow `PettingZooEnv` in vector environments? Might need
to update docs?
2023-02-03 11:57:27 -08:00
Markus Krimmel
4c3791a459
Updated atari wrappers, fixed pre-commit (#781)
This PR addresses #772 (updates Atari wrappers to work with new Gym API)
and some additional issues:

- Pre-commit was using GitLab for flake8, which recently started requiring
authentication -> replaced it with GitHub
- Yapf was quietly failing in pre-commit. Changed it such that it fixes
formatting in-place
- There is an incompatibility between flake8 and yapf where yapf puts
binary operators after the line break and flake8 wants them before the
break. I added an exception for flake8.
- Also require `packaging` in setup.py

My changes shouldn't change the behaviour of the wrappers for older
versions, but please double check.
I don't know whether it's just me, but there are always some incompatibilities
between yapf and flake8 that need to be resolved manually. It might make
sense to try black instead.
2022-12-04 13:00:53 -08:00
Yi Su
662af52820
Fix Atari PPO example (#780)
- [x] I have marked all applicable categories:
    + [ ] exception-raising fix
    + [x] algorithm implementation fix
    + [ ] documentation modification
    + [ ] new feature
- [x] I have reformatted the code using `make format` (**required**)
- [x] I have checked the code using `make commit-checks` (**required**)
- [x] If applicable, I have mentioned the relevant/related issue(s)
- [x] If applicable, I have listed every item in this Pull Request
below

While trying to debug Atari PPO+LSTM, I found a significant gap between
our Atari PPO example and [CleanRL's Atari PPO w/
EnvPool](https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpoolpy).
I tried to align our implementation with CleanRL's version, mostly in
the hyperparameter choices, and got significant gains in Breakout, Qbert, and
SpaceInvaders while staying on par in other games. After this fix, I would
suggest updating our [Atari
Benchmark](https://tianshou.readthedocs.io/en/master/tutorials/benchmark.html)
PPO experiments.

A few interesting findings:

- Layer initialization helps stabilize the training and enables the use
of larger learning rates; without it, larger learning rates quickly trigger
NaN gradients (a sketch of the initialization is given below);
- ppo.py#L97-L101: this change helps training stability for reasons I do
not understand; it also increases GPU usage.
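
The layer initialization mentioned above is presumably the orthogonal initialization used in CleanRL-style PPO baselines; a hedged sketch, not the exact code from this PR:

```python
import numpy as np
import torch.nn as nn


def layer_init(layer: nn.Module, std: float = float(np.sqrt(2)), bias_const: float = 0.0) -> nn.Module:
    # Orthogonal weights with a controlled gain plus constant bias, applied to
    # each Linear/Conv layer of the actor-critic network.
    nn.init.orthogonal_(layer.weight, std)
    nn.init.constant_(layer.bias, bias_const)
    return layer
```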

Shoutout to [CleanRL](https://github.com/vwxyzjn/cleanrl) for a
well-tuned Atari PPO reference implementation!
2022-12-04 12:23:18 -08:00
Yifei Cheng
43792bf5ab
Upgrade gym (#613)
fixes some deprecation warnings due to new changes in gym version 0.23:
- use `env.np_random.integers` instead of `env.np_random.randint`
- support `seed` and `return_info` arguments for reset (addresses https://github.com/thu-ml/tianshou/issues/605)
2022-06-28 06:52:21 +08:00
Yi Su
9ce0a554dc
Add Atari SAC examples (#657)
- Add Atari (discrete) SAC examples;
- Fix a bug in Discrete SAC evaluation; default to deterministic mode.
2022-06-04 13:26:08 +08:00
Jiayi Weng
5ecea2402e
Fix save_checkpoint_fn return value (#659)
- Fix save_checkpoint_fn return value to checkpoint_path;
- Fix wrong link in doc;
- Fix an off-by-one bug in trainer iterator.
2022-06-03 01:07:07 +08:00
Jiayi Weng
109875d43d
Fix num_envs=test_num (#653)
* fix num_envs=test_num

* fix mypy
2022-05-30 12:38:47 +08:00
Michal Gregor
c87b9f49bc
Add show_progress option for trainer (#641)
- A DummyTqdm class was added to utils: it replicates the interface used by trainers, but does not show the progress bar (a rough sketch follows below);
- Added a show_progress argument to the base trainer: when show_progress == False, DummyTqdm is used in place of tqdm.
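A rough sketch of what such a drop-in stand-in can look like (an assumption about the interface, not the exact implementation):

```python
class DummyTqdm:
    """Mimics the tqdm calls made by the trainers without drawing a progress bar."""

    def __init__(self, total=None, **kwargs):
        self.total = total
        self.n = 0

    def update(self, n: int = 1) -> None:
        self.n += n

    def set_postfix(self, **kwargs) -> None:
        pass

    def set_description(self, desc=None) -> None:
        pass

    def write(self, s: str) -> None:
        print(s)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False
```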
2022-05-17 23:41:59 +08:00
Chengqi Duan
5eab7dc218
Add Atari Results (#600) 2022-04-24 20:44:54 +08:00
Alex Nikulkov
92456cdb68
Add learning rate scheduler to BasePolicy (#598) 2022-04-17 23:52:30 +08:00
Jiayi Weng
2a9c9289e5
rename save_fn to save_best_fn to avoid ambiguity (#575)
This PR also introduces `tianshou.utils.deprecation` for a unified deprecation wrapper.
2022-03-22 04:29:27 +08:00
Chengqi Duan
ad2e1eaea0 Fix WandbLogger import error in Atari examples (#562) 2022-03-08 08:38:56 -05:00
Costa Huang
df3d7f582b
Update WandbLogger implementation (#558)
* Use `global_step` as the x-axis for wandb
* Use TensorBoard SummaryWriter as the core with `wandb.init(..., sync_tensorboard=True)`
* Update all atari examples with wandb
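
A minimal usage sketch of the `sync_tensorboard` mechanism described above (project name and log path are placeholders):

```python
import wandb
from torch.utils.tensorboard import SummaryWriter

wandb.init(project="tianshou", sync_tensorboard=True)
writer = SummaryWriter("log/atari_dqn")
# Scalars written via TensorBoard are mirrored to wandb, with `global_step`
# serving as the x-axis.
writer.add_scalar("train/reward", 20.5, global_step=1000)
```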

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
2022-03-07 06:40:47 +08:00
Chengqi Duan
23fbc3b712
upgrade gym version to >=0.21, fix related CI and update examples/atari (#534)
Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
2022-02-25 07:40:33 +08:00
Yi Su
d29188ee77
update atari ppo slots (#529) 2022-02-13 04:04:21 +08:00
Yi Su
40289b8b0e
Add atari ppo example (#523)
I needed a policy gradient baseline myself and it has been requested several times (#497, #374, #440). I used https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari.py as a reference for hyper-parameters.

Note that using lr=2.5e-4 will result in an "Invalid Value" error for 2 games. The fix is to reduce the learning rate, which is why I set the default lr to 1e-4. See the discussion in https://github.com/DLR-RM/rl-baselines3-zoo/issues/156.
2022-02-11 06:45:06 +08:00
ChenDRAG
c25926dd8f
Formalize variable names (#509)
Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
2022-01-30 00:53:56 +08:00
Yi Su
a59d96d041
Add Intrinsic Curiosity Module (#503) 2022-01-15 02:43:48 +08:00
Yi Su
3592f45446
Fix critic network for Discrete CRR (#485)
- Fixes an inconsistency in the implementation of Discrete CRR. Now it uses the `Critic` class for its critic, following the conventions in other actor-critic policies;
- Updates several offline policies to use the `ActorCritic` class for their optimizers to eliminate randomness caused by parameter sharing between actor and critic;
- Adds `writer.flush()` in TensorboardLogger to ensure real-time results;
- Enables `test_collector=None` in 3 trainers to turn off testing during training;
- Updates the Atari offline results in README.md;
- Moves Atari offline RL examples to `examples/offline` and tests to `test/offline`, per review comments.
2021-11-28 23:10:28 +08:00
Jiayi Weng
098d466467
fix atari wrapper to be deterministic (#467) 2021-10-19 22:26:11 +08:00
Ayush Chaurasia
22d7bf38c8
Improve W&B logger (#441)
- rename WandBLogger -> WandbLogger
- add save_data and restore_data
- allow more input arguments for wandb init
- integrate wandb into test/modelbase/test_psrl.py and examples/atari/atari_dqn.py
- documentation update
2021-09-24 21:52:23 +08:00
Jiayi Weng
e8f8cdfa41
fix logger.write error in atari script (#444)
- fix a bug in #427: logger.write should pass a dict
- change SubprocVectorEnv to ShmemVectorEnv in atari
- increase logger interval for eps
2021-09-09 00:51:39 +08:00
n+e
fc251ab0b8
bump to v0.4.3 (#432)
* add makefile
* bump version
* add isort and yapf
* update contributing.md
* update PR template
* spelling check
2021-09-03 05:05:04 +08:00
Andriy Drozdyuk
8a5e2190f7
Add Weights and Biases Logger (#427)
- rename BasicLogger to TensorboardLogger
- refactor logger code
- add WandbLogger

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
2021-08-30 22:35:02 +08:00
Yi Su
291be08d43
Add Rainbow DQN (#386)
- add RainbowPolicy
- add `set_beta` method in prio_buffer
- add NoisyLinear in utils/network
2021-08-29 23:34:59 +08:00
Yi Su
c0bc8e00ca
Add Fully-parameterized Quantile Function (#376) 2021-06-15 11:59:02 +08:00
Yi Su
21b2b22cd7
update iqn results and reward plots (#377) 2021-06-10 09:05:25 +08:00
Yi Su
f3169b4c1f
Add Implicit Quantile Network (#371) 2021-05-29 09:44:23 +08:00
n+e
458028a326
fix docs (#373)
- fix css style error
- fix mujoco benchmark result
2021-05-23 12:43:03 +08:00
Yi Su
8f7bc65ac7
Add discrete Critic Regularized Regression (#367) 2021-05-19 13:29:56 +08:00
Yi Su
b5c3ddabfa
Add discrete Conservative Q-Learning for offline RL (#359)
Co-authored-by: Yi Su <yi.su@antgroup.com>
Co-authored-by: Yi Su <yi.su@antfin.com>
2021-05-12 09:24:48 +08:00
Ark
84f58636eb
Make trainer resumable (#350)
- specify tensorboard >= 2.5.0
- add `save_checkpoint_fn` and `resume_from_log` in trainer

Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
2021-05-06 08:53:53 +08:00
n+e
c059f98abf
fix atari_bcq (#345) 2021-04-20 22:59:21 +08:00
ChenDRAG
f22b539761
Remove reward_normalization option in offpolicy algorithm (#298)
* remove rew_norm in nstep implementation
* improve test
* remove runnable/
* various doc fix

Co-authored-by: n+e <trinkle23897@gmail.com>
2021-02-27 11:20:43 +08:00
ChenDRAG
9b61bc620c add logger (#295)
This PR focuses on refactoring the logging method to fix the bug of NaN reward and log interval. After these two PRs, the fundamental changes to tianshou/data are hopefully finished, and we can finally concentrate on building benchmarks for tianshou.

Things changed:

1. the trainer now accepts a logger (BasicLogger or LazyLogger) instead of a writer;
2. remove utils.SummaryWriter;
2021-02-24 14:48:42 +08:00
ChenDRAG
7036073649
Trainer refactor : some definition change (#293)
This PR focuses on some definition changes in the trainer to make it friendlier to use and consistent with typical usage in research papers: it changes `collect-per-step` to `step-per-collect`, adds `update-per-step` / `episode-per-collect` accordingly, and modifies the documentation.
2021-02-21 13:06:02 +08:00
ChenDRAG
150d0ec51b
Step collector implementation (#280)
This is the third PR of the 6 commits mentioned in #274, which features a refactor of Collector to fix #245. You can check #274 for more detail.

Things changed in this PR:

1. refactor the collector to be cleaner, and split out AsyncCollector to support async venvs;
2. change the buffer.add API to add(batch, buffer_ids); add several types of buffer (VectorReplayBuffer, PrioritizedVectorReplayBuffer, etc.)
3. add policy.exploration_noise(act, batch) -> act
4. small change in BasePolicy.compute_*_returns
5. move reward_metric from collector to trainer
6. fix np.asanyarray issue (different numpy versions result in different output)
7. flake8 maxlength=88
8. polish docs and fix tests

Co-authored-by: n+e <trinkle23897@gmail.com>
2021-02-19 10:33:49 +08:00
ChenDRAG
f528131da1
hotfix:fix test failure in cuda environment (#289) 2021-02-09 17:13:40 +08:00
wizardsheng
1eb6137645
Add QR-DQN algorithm (#276)
This is the PR for the QR-DQN algorithm: https://arxiv.org/abs/1710.10044

1. add QR-DQN policy in tianshou/policy/modelfree/qrdqn.py.
2. add QR-DQN net in examples/atari/atari_network.py.
3. add QR-DQN atari example in examples/atari/atari_qrdqn.py.
4. add QR-DQN statement in tianshou/policy/__init__.py.
5. add QR-DQN unit test in test/discrete/test_qrdqn.py.
6. add QR-DQN atari results in examples/atari/results/qrdqn/.
7. add compute_q_value in DQNPolicy and C51Policy to simplify the forward function.
8. move `with torch.no_grad():` from `_target_q` to BasePolicy

By running `python3 atari_qrdqn.py --task "PongNoFrameskip-v4" --batch-size 64`, the best result is 19.8 ± 0.40 in epoch 8.
2021-01-28 09:27:05 +08:00
Jialu Zhu
a511cb4779
Add offline trainer and discrete BCQ algorithm (#263)
The result needs to be tuned after the `done` issue is fixed.

Co-authored-by: n+e <trinkle23897@gmail.com>
2021-01-20 18:13:04 +08:00
ChenDRAG
a633a6a028
update utils.network (#275)
This is the first commit of 6 commits mentioned in #274, which features

1. Refactor of `Class Net` to support any form of MLP.
2. Enable type check in utils.network.
3. Relative change in docs/test/examples.
4. Move atari-related network to examples/atari/atari_network.py

Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
2021-01-20 16:54:13 +08:00
wizardsheng
c6f2648e87
Add C51 algorithm (#266)
This is the PR for the C51 algorithm: https://arxiv.org/abs/1707.06887

1. add C51 policy in tianshou/policy/modelfree/c51.py.
2. add C51 net in tianshou/utils/net/discrete.py.
3. add C51 atari example in examples/atari/atari_c51.py.
4. add C51 statement in tianshou/policy/__init__.py.
5. add C51 test in test/discrete/test_c51.py.
6. add C51 atari results in examples/atari/results/c51/.

By running "python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64", get  best_result': '20.50 ± 0.50', in epoch 9.

By running "python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1 --epoch 40", get best_reward: 407.400000 ± 31.155096 in epoch 39.
2021-01-06 10:17:45 +08:00
Trinkle23897
cd481423dc sac mujoco result (#246) 2020-11-09 16:43:55 +08:00
n+e
710966eda7
change API of train_fn and test_fn (#229)
train_fn(epoch) -> train_fn(epoch, num_env_step)
test_fn(epoch) -> test_fn(epoch, num_env_step)
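
A hedged sketch of adapting a callback to the new signature (the epsilon schedule and `set_eps` usage are illustrative):

```python
def make_train_fn(policy):
    def train_fn(epoch: int, env_step: int) -> None:
        # env_step (the new second argument) is the cumulative number of
        # environment steps, handy e.g. for annealing exploration noise.
        policy.set_eps(max(0.05, 0.5 - env_step / 1_000_000))

    return train_fn
```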
2020-09-26 16:35:37 +08:00
n+e
380e9e911d
fix atari examples (#206) 2020-09-06 23:05:33 +08:00
yingchengyang
5b49192a48
DQN Atari examples (#187)
This PR aims to provide scripts for the Atari DQN setting:
- A speedrun of PongNoFrameskip-v4 (finished, about half an hour on an i7-8750 + GTX 1060 with 1M environment steps)
- A general script for all Atari games
Since we use multiple envs for simulation, the results are slightly different from the original paper, but are considered acceptable.

It also adds another parameter, save_only_last_obs, for the replay buffer in order to save memory.

Co-authored-by: Trinkle23897 <463003665@qq.com>
2020-08-30 05:48:09 +08:00
n+e
94bfb32cc1
optimize training procedure and improve code coverage (#189)
1. add policy.eval() in all test scripts' "watch performance"
2. remove dict return support for collector preprocess_fn
3. add `__contains__` and `pop` in batch: `key in batch`, `batch.pop(key, deft)`
4. exact n_episode for a list of n_episode limitation and save fake data in cache_buffer when self.buffer is None (#184)
5. fix tensorboard logging: h-axis stands for env step instead of gradient step; add test results into tensorboard
6. add test_returns (both GAE and nstep)
7. change the type-checking order in batch.py and converter.py in order to check the most common case first
8. fix shape inconsistency for torch.Tensor in replay buffer
9. remove `**kwargs` in ReplayBuffer
10. remove default value in batch.split() and add merge_last argument (#185)
11. improve nstep efficiency
12. add max_batchsize in onpolicy algorithms
13. potential bugfix for subproc.wait
14. fix RecurrentActorProb
15. improve the code-coverage (from 90% to 95%) and remove the dead code
16. fix some incorrect type annotation

The above improvements also increase the training FPS: on my computer, the previous version runs at only ~1800 FPS and after these changes it can reach ~2050 (faster than v0.2.4.post1).
2020-08-27 12:15:18 +08:00
youkaichao
a9f9940d17
code refactor for venv (#179)
- Refactor code to remove duplicate code

- Enable async simulation for all vector envs

- Remove `collector.close` and rename `VectorEnv` to `DummyVectorEnv`

The abstraction of the vector env changed.

Prior to this PR, each vector env was almost independent.

After this PR, each env is wrapped into a worker, and vector envs differ in their worker type. In fact, users can just use `BaseVectorEnv` with different workers; I keep `SubprocVectorEnv` and `ShmemVectorEnv` for backward compatibility.

Co-authored-by: n+e <463003665@qq.com>
Co-authored-by: magicly <magicly007@gmail.com>
2020-08-19 15:00:24 +08:00