31 Commits

Author SHA1 Message Date
Dominik Jain
023b33c917 Make mypy happy 2023-10-18 20:44:18 +02:00
Dominik Jain
213e08a846 Add method get_output_dim to BaseActor 2023-10-18 20:44:17 +02:00
Dominik Jain
a161a9cf58 Improve type annotations, fix type issues and add checks 2023-10-18 20:44:17 +02:00
Dominik Jain
e0e7349b0a Add base class BaseActor with method get_preprocess_net for high-level API 2023-10-18 20:44:16 +02:00
Michael Panchenko
b900fdf6f2
Remove kwargs in policy init (#950)
Closes #947 

This removes all kwargs from all policy constructors. While doing that,
I also improved several names and added a whole lot of TODOs.

## Functional changes:

1. Added possibility to pass None as `critic2` and `critic2_optim`. In
fact, the default behavior then should cover the absolute majority of
cases
2. Added a function called `clone_optimizer` as a temporary measure to
support passing `critic2_optim=None`

## Breaking changes:

1. `action_space` is no longer optional. In fact, it already was
non-optional, as there was a ValueError in BasePolicy.init. So now
several examples were fixed to reflect that
2. `reward_normalization` removed from DDPG and children. It was never
allowed to pass it as `True` there, an error would have been raised in
`compute_n_step_reward`. Now I removed it from the interface
3. renamed `critic1` and similar to `critic`, in order to have uniform
interfaces. Note that the `critic` in DDPG was optional for the sole
reason that child classes used `critic1`. I removed this optionality
(DDPG can't do anything with `critic=None`)
4. Several renamings of fields (mostly private to public, so backwards
compatible)

## Additional changes: 
1. Removed type and default declaration from docstring. This kind of
duplication is really not necessary
2. Policy constructors are now only called using named arguments, not a
fragile mixture of positional and named as before
5. Minor beautifications in typing and code 
6. Generally shortened docstrings and made them uniform across all
policies (hopefully)

## Comment:

With these changes, several problems in tianshou's inheritance hierarchy
become more apparent. I tried highlighting them for future work.

---------

Co-authored-by: Dominik Jain <d.jain@appliedai.de>
2023-10-08 08:57:03 -07:00
Michael Panchenko
2cc34fb72b
Poetry install, remove gym, bump python (#925)
Closes #914 

Additional changes:

- Deprecate python below 11
- Remove 3rd party and throughput tests. This simplifies install and
test pipeline
- Remove gym compatibility and shimmy
- Format with 3.11 conventions. In particular, add `zip(...,
strict=True/False)` where possible

Since the additional tests and gym were complicating the CI pipeline
(flaky and dist-dependent), it didn't make sense to work on fixing the
current tests in this PR to then just delete them in the next one. So
this PR changes the build and removes these tests at the same time.
2023-09-05 14:34:23 -07:00
Michael Panchenko
600f4bbd55
Python 3.9, black + ruff formatting (#921)
Preparation for #914 and #920

Changes formatting to ruff and black. Remove python 3.8

## Additional Changes

- Removed flake8 dependencies
- Adjusted pre-commit. Now CI and Make use pre-commit, reducing the
duplication of linting calls
- Removed check-docstyle option (ruff is doing that)
- Merged format and lint. In CI the format-lint step fails if any
changes are done, so it fulfills the lint functionality.

---------

Co-authored-by: Jiayi Weng <jiayi@openai.com>
2023-08-25 14:40:56 -07:00
Michael Panchenko
07702fc007
Improved typing and reduced duplication (#912)
# Goals of the PR

The PR introduces **no changes to functionality**, apart from improved
input validation here and there. The main goals are to reduce some
complexity of the code, to improve types and IDE completions, and to
extend documentation and block comments where appropriate. Because of
the change to the trainer interfaces, many files are affected (more
details below), but still the overall changes are "small" in a certain
sense.

## Major Change 1 - BatchProtocol

**TL;DR:** One can now annotate which fields the batch is expected to
have on input params and which fields a returned batch has. Should be
useful for reading the code. getting meaningful IDE support, and
catching bugs with mypy. This annotation strategy will continue to work
if Batch is replaced by TensorDict or by something else.

**In more detail:** Batch itself has no fields and using it for
annotations is of limited informational power. Batches with fields are
not separate classes but instead instances of Batch directly, so there
is no type that could be used for annotation. Fortunately, python
`Protocol` is here for the rescue. With these changes we can now do
things like

```python
class ActionBatchProtocol(BatchProtocol):
    logits: Sequence[Union[tuple, torch.Tensor]]
    dist: torch.distributions.Distribution
    act: torch.Tensor
    state: Optional[torch.Tensor]


class RolloutBatchProtocol(BatchProtocol):
    obs: torch.Tensor
    obs_next: torch.Tensor
    info: Dict[str, Any]
    rew: torch.Tensor
    terminated: torch.Tensor
    truncated: torch.Tensor

class PGPolicy(BasePolicy):
    ...

    def forward(
        self,
        batch: RolloutBatchProtocol,
        state: Optional[Union[dict, Batch, np.ndarray]] = None,
        **kwargs: Any,
    ) -> ActionBatchProtocol:

```

The IDE and mypy are now very helpful in finding errors and in
auto-completion, whereas before the tools couldn't assist in that at
all.

## Major Change 2 - remove duplication in trainer package

**TL;DR:** There was a lot of duplication between `BaseTrainer` and its
subclasses. Even worse, it was almost-duplication. There was also
interface fragmentation through things like `onpolicy_trainer`. Now this
duplication is gone and all downstream code was adjusted.

**In more detail:** Since this change affects a lot of code, I would
like to explain why I thought it to be necessary.

1. The subclasses of `BaseTrainer` just duplicated docstrings and
constructors. What's worse, they changed the order of args there, even
turning some kwargs of BaseTrainer into args. They also had the arg
`learning_type` which was passed as kwarg to the base class and was
unused there. This made things difficult to maintain, and in fact some
errors were already present in the duplicated docstrings.
2. The "functions" a la `onpolicy_trainer`, which just called the
`OnpolicyTrainer.run`, not only introduced interface fragmentation but
also completely obfuscated the docstring and interfaces. They themselves
had no dosctring and the interface was just `*args, **kwargs`, which
makes it impossible to understand what they do and which things can be
passed without reading their implementation, then reading the docstring
of the associated class, etc. Needless to say, mypy and IDEs provide no
support with such functions. Nevertheless, they were used everywhere in
the code-base. I didn't find the sacrifices in clarity and complexity
justified just for the sake of not having to write `.run()` after
instantiating a trainer.
3. The trainers are all very similar to each other. As for my
application I needed a new trainer, I wanted to understand their
structure. The similarity, however, was hard to discover since they were
all in separate modules and there was so much duplication. I kept
staring at the constructors for a while until I figured out that
essentially no changes to the superclass were introduced. Now they are
all in the same module and the similarities/differences between them are
much easier to grasp (in my opinion)
4. Because of (1), I had to manually change and check a lot of code,
which was very tedious and boring. This kind of work won't be necessary
in the future, since now IDEs can be used for changing signatures,
renaming args and kwargs, changing class names and so on.

I have some more reasons, but maybe the above ones are convincing
enough.

## Minor changes: improved input validation and types

I added input validation for things like `state` and `action_scaling`
(which only makes sense for continuous envs). After adding this, some
tests failed to pass this validation. There I added
`action_scaling=isinstance(env.action_space, Box)`, after which tests
were green. I don't know why the tests were green before, since action
scaling doesn't make sense for discrete actions. I guess some aspect was
not tested and didn't crash.

I also added Literal in some places, in particular for
`action_bound_method`. Now it is no longer allowed to pass an empty
string, instead one should pass `None`. Also here there is input
validation with clear error messages.

@Trinkle23897 The functional tests are green. I didn't want to fix the
formatting, since it will change in the next PR that will solve #914
anyway. I also found a whole bunch of code in `docs/_static`, which I
just deleted (shouldn't it be copied from the sources during docs build
instead of committed?). I also haven't adjusted the documentation yet,
which atm still mentions the trainers of the type
`onpolicy_trainer(...)` instead of `OnpolicyTrainer(...).run()`

## Breaking Changes

The adjustments to the trainer package introduce breaking changes as
duplicated interfaces are deleted. However, it should be very easy for
users to adjust to them

---------

Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
2023-08-22 09:54:46 -07:00
Markus Krimmel
ea36dc5195
Changes to support Gym 0.26.0 (#748)
* Changes to support Gym 0.26.0

* Replace map by simpler list comprehension

* Use syntax that is compatible with python 3.7

* Format code

* Fix environment seeding in test environment, fix buffer_profile test

* Remove self.seed() from __init__

* Fix random number generation

* Fix throughput tests

* Fix tests

* Removed done field from Buffer, fixed throughput test, turned off wandb, fixed formatting, fixed type hints, allow preprocessing_fn with truncated and terminated arguments, updated docstrings

* fix lint

* fix

* fix import

* fix

* fix mypy

* pytest --ignore='test/3rd_party'

* Use correct step API in _SetAttrWrapper

* Format

* Fix mypy

* Format

* Fix pydocstyle.
2022-09-26 09:31:23 -07:00
Squeemos
e01385ea30
Change action_dim to action_shape (#602)
Noticed that in IQN and FQF there were some mismatches in the docstrings. Figured I would make a pull request to make it match.
2022-04-22 08:09:57 +08:00
ChenDRAG
c25926dd8f
Formalize variable names (#509)
Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
2022-01-30 00:53:56 +08:00
Yi Su
a59d96d041
Add Intrinsic Curiosity Module (#503) 2022-01-15 02:43:48 +08:00
Markus28
a2d76d1276
Remove reset_buffer() from reset method (#501) 2022-01-12 16:46:28 -08:00
n+e
fc251ab0b8
bump to v0.4.3 (#432)
* add makefile
* bump version
* add isort and yapf
* update contributing.md
* update PR template
* spelling check
2021-09-03 05:05:04 +08:00
Yi Su
291be08d43
Add Rainbow DQN (#386)
- add RainbowPolicy
- add `set_beta` method in prio_buffer
- add NoisyLinear in utils/network
2021-08-29 23:34:59 +08:00
Yi Su
c0bc8e00ca
Add Fully-parameterized Quantile Function (#376) 2021-06-15 11:59:02 +08:00
Yi Su
f3169b4c1f
Add Implicit Quantile Network (#371) 2021-05-29 09:44:23 +08:00
n+e
09692c84fe
fix numpy>=1.20 typing check (#323)
Change the behavior of to_numpy and to_torch: from now on, dict is automatically converted to Batch and list is automatically converted to np.ndarray (if an error occurs, raise the exception instead of converting each element in the list).
2021-03-30 16:06:03 +08:00
ChenDRAG
f528131da1
hotfix:fix test failure in cuda environment (#289) 2021-02-09 17:13:40 +08:00
ChenDRAG
a633a6a028
update utils.network (#275)
This is the first commit of 6 commits mentioned in #274, which features

1. Refactor of `Class Net` to support any form of MLP.
2. Enable type check in utils.network.
3. Relative change in docs/test/examples.
4. Move atari-related network to examples/atari/atari_network.py

Co-authored-by: Trinkle23897 <trinkle23897@gmail.com>
2021-01-20 16:54:13 +08:00
wizardsheng
c6f2648e87
Add C51 algorithm (#266)
This is the PR for C51algorithm: https://arxiv.org/abs/1707.06887

1. add C51 policy in tianshou/policy/modelfree/c51.py.
2. add C51 net in tianshou/utils/net/discrete.py.
3. add C51 atari example in examples/atari/atari_c51.py.
4. add C51 statement in tianshou/policy/__init__.py.
5. add C51 test in test/discrete/test_c51.py.
6. add C51 atari results in examples/atari/results/c51/.

By running "python3 atari_c51.py --task "PongNoFrameskip-v4" --batch-size 64", get  best_result': '20.50 ± 0.50', in epoch 9.

By running "python3 atari_c51.py --task "BreakoutNoFrameskip-v4" --n-step 1 --epoch 40", get best_reward: 407.400000 ± 31.155096 in epoch 39.
2021-01-06 10:17:45 +08:00
rocknamx
bf39b9ef7d
clarify updating state (#224)
Add an indicator(i.e. `self.learning`) of learning will be convenient for distinguishing state of policy.
Meanwhile, the state of `self.training` will be undisputed in the training stage.
Related issue: #211 

Others:
- fix a bug in DDQN: target_q could not be sampled from np.random.rand
- fix a bug in DQN atari net: it should add a ReLU before the last layer
- fix a bug in collector timing

Co-authored-by: n+e <463003665@qq.com>
2020-09-22 16:28:46 +08:00
danagi
a6ee979609
implement sac for discrete action settings (#216)
Co-authored-by: n+e <trinkle23897@cmu.edu>
2020-09-14 14:59:23 +08:00
n+e
b284ace102
type check in unit test (#200)
Fix #195: Add mypy test in .github/workflows/docs_and_lint.yml.

Also remove the out-of-the-date api
2020-09-13 19:31:50 +08:00
n+e
c91def6cbc
code format and update function signatures (#213)
Cherry-pick from #200 

- update the function signature
- format code-style
- move _compile into separate functions
- fix a bug in to_torch and to_numpy (Batch)
- remove None in action_range

In short, the code-format only contains function-signature style and `'` -> `"`. (pick up from [black](https://github.com/psf/black))
2020-09-12 15:39:01 +08:00
n+e
b86d78766b
fix docs and add docstring check (#210)
- fix broken links and out-of-the-date content
- add pydocstyle and doc8 check
- remove collector.seed and collector.render
2020-09-11 07:55:37 +08:00
n+e
94bfb32cc1
optimize training procedure and improve code coverage (#189)
1. add policy.eval() in all test scripts' "watch performance"
2. remove dict return support for collector preprocess_fn
3. add `__contains__` and `pop` in batch: `key in batch`, `batch.pop(key, deft)`
4. exact n_episode for a list of n_episode limitation and save fake data in cache_buffer when self.buffer is None (#184)
5. fix tensorboard logging: h-axis stands for env step instead of gradient step; add test results into tensorboard
6. add test_returns (both GAE and nstep)
7. change the type-checking order in batch.py and converter.py in order to meet the most often case first
8. fix shape inconsistency for torch.Tensor in replay buffer
9. remove `**kwargs` in ReplayBuffer
10. remove default value in batch.split() and add merge_last argument (#185)
11. improve nstep efficiency
12. add max_batchsize in onpolicy algorithms
13. potential bugfix for subproc.wait
14. fix RecurrentActorProb
15. improve the code-coverage (from 90% to 95%) and remove the dead code
16. fix some incorrect type annotation

The above improvement also increases the training FPS: on my computer, the previous version is only ~1800 FPS and after that, it can reach ~2050 (faster than v0.2.4.post1).
2020-08-27 12:15:18 +08:00
youkaichao
32df0567bb
use nn.Sequential in DQN (#176) 2020-08-02 15:14:44 +08:00
yingchengyang
99a1d40e85
Dueling DQN (#170)
Co-authored-by: n+e <463003665@qq.com>
2020-07-29 19:44:42 +08:00
n+e
bd9c3c7f8d
docs fix and v0.2.5 (#156)
* pre

* update docs

* update docs

* $ in bash

* size -> hidden_layer_size

* doctest

* doctest again

* filter a warning

* fix bug

* fix examples

* test fail

* test succ
2020-07-22 14:42:08 +08:00
youkaichao
e767de044b
Remove dummy net code (#123)
* remove dummy net; delete two files

* split code to have backbone and head

* rename class

* change torch.float to torch.float32

* use flatten(1) instead of view(batch, -1)

* remove dummy net in docs

* bugfix for rnn

* fix cuda error

* minor fix of docs

* do not change the example code in dqn tutorial, since it is for demonstration

Co-authored-by: Trinkle23897 <463003665@qq.com>
2020-07-09 22:57:01 +08:00