Merge branch 'refs/heads/thuml-master' into policy-train-eval

# Conflicts:
#	CHANGELOG.md
This commit is contained in:
Michael Panchenko 2024-05-05 16:03:34 +02:00
commit 4e38aeb829
12 changed files with 186 additions and 151 deletions

@ -1,6 +1,6 @@
- [ ] I have added the correct label(s) to this Pull Request or linked the relevant issue(s) - [ ] I have added the correct label(s) to this Pull Request or linked the relevant issue(s)
- [ ] I have provided a description of the changes in this Pull Request - [ ] I have provided a description of the changes in this Pull Request
- [ ] I have added documentation for my changes - [ ] I have added documentation for my changes and have listed relevant changes in CHANGELOG.md
- [ ] If applicable, I have added tests to cover my changes. - [ ] If applicable, I have added tests to cover my changes.
- [ ] I have reformatted the code using `poe format` - [ ] I have reformatted the code using `poe format`
- [ ] I have checked style and types with `poe lint` and `poe type-check` - [ ] I have checked style and types with `poe lint` and `poe type-check`

@ -19,11 +19,24 @@
- New `evaluation` package for repeating the same experiment with multiple seeds and aggregating the results (important extension!). - New `evaluation` package for repeating the same experiment with multiple seeds and aggregating the results (important extension!).
Launchers for parallelization currently in alpha state. #1074 Launchers for parallelization currently in alpha state. #1074
- Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074 - Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074
- `continuous.Critic`:
- Add flag `apply_preprocess_net_to_obs_only` to allow the
preprocessing network to be applied to the observations only (without
the actions concatenated), which is essential for the case where we want
to reuse the actor's preprocessing network #1128
- Base class for collectors: `BaseCollector` #1122 - Base class for collectors: `BaseCollector` #1122
- Collectors can now explicitly specify whether to use the policy in training or evaluation mode. #1122 - Collectors can now explicitly specify whether to use the policy in training or evaluation mode. #1122
- New util context managers `in_eval_mode` and `in_train_mode` for torch modules. #1122 - New util context managers `in_eval_mode` and `in_train_mode` for torch modules. #1122
- `reset` of `Collectors` now returns `obs` and `info`. #1122 - `reset` of `Collectors` now returns `obs` and `info`. #1122
### Fixes
- `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics,
fixing the case where we want to reuse an actor's preprocessing network for the critic (affects usages
of the experiment builder method `with_critic_factory_use_actor` with continuous environments) #1128
- `atari_network.DQN`:
- Fix constructor input validation #1128
- Fix `output_dim` not being set if `features_only`=True and `output_dim_added_layer` is not None #1128
### Internal Improvements ### Internal Improvements
- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063 - `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
- Introduced a first iteration of a naming convention for vars in `Collector`s. #1063 - Introduced a first iteration of a naming convention for vars in `Collector`s. #1063

@ -152,7 +152,7 @@
"id": "Lh2-hwE5Dn9I" "id": "Lh2-hwE5Dn9I"
}, },
"source": [ "source": [
"Once we have defined the actor, the critic and the optimizer. We can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution." "Once we have defined the actor, the critic and the optimizer, we can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution."
] ]
}, },
{ {

@ -266,3 +266,8 @@ postfix
backend backend
rliable rliable
hl hl
v_s
v_s_
obs
obs_next

@ -66,7 +66,7 @@ class DQN(NetBase[Any]):
layer_init: Callable[[nn.Module], nn.Module] = lambda x: x, layer_init: Callable[[nn.Module], nn.Module] = lambda x: x,
) -> None: ) -> None:
# TODO: Add docstring # TODO: Add docstring
if features_only and output_dim_added_layer is not None: if not features_only and output_dim_added_layer is not None:
raise ValueError( raise ValueError(
"Should not provide explicit output dimension using `output_dim_added_layer` when `features_only` is true.", "Should not provide explicit output dimension using `output_dim_added_layer` when `features_only` is true.",
) )
@ -98,6 +98,7 @@ class DQN(NetBase[Any]):
layer_init(nn.Linear(base_cnn_output_dim, output_dim_added_layer)), layer_init(nn.Linear(base_cnn_output_dim, output_dim_added_layer)),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
) )
self.output_dim = output_dim_added_layer
else: else:
self.output_dim = base_cnn_output_dim self.output_dim = base_cnn_output_dim

134
poetry.lock generated

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. # This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]] [[package]]
name = "absl-py" name = "absl-py"
@ -156,7 +156,7 @@ files = [
name = "arch" name = "arch"
version = "5.3.1" version = "5.3.1"
description = "ARCH for Python" description = "ARCH for Python"
optional = true optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "arch-5.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:75fa6f9386ecc2df81bcbf5d055a290a697482ca51e0b3459dab183d288993cb"}, {file = "arch-5.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:75fa6f9386ecc2df81bcbf5d055a290a697482ca51e0b3459dab183d288993cb"},
@ -3815,75 +3815,6 @@ sql-other = ["SQLAlchemy (>=1.4.36)"]
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"] test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"]
xml = ["lxml (>=4.8.0)"] xml = ["lxml (>=4.8.0)"]
[[package]]
name = "pandas"
version = "2.2.2"
description = "Powerful data structures for data analysis, time series, and statistics"
optional = false
python-versions = ">=3.9"
files = [
{file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
{file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"},
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8e5a0b00e1e56a842f922e7fae8ae4077aee4af0acb5ae3622bd4b4c30aedf99"},
{file = "pandas-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ddf818e4e6c7c6f4f7c8a12709696d193976b591cc7dc50588d3d1a6b5dc8772"},
{file = "pandas-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:696039430f7a562b74fa45f540aca068ea85fa34c244d0deee539cb6d70aa288"},
{file = "pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e90497254aacacbc4ea6ae5e7a8cd75629d6ad2b30025a4a8b09aa4faf55151"},
{file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58b84b91b0b9f4bafac2a0ac55002280c094dfc6402402332c0913a59654ab2b"},
{file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2123dc9ad6a814bcdea0f099885276b31b24f7edf40f6cdbc0912672e22eee"},
{file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2925720037f06e89af896c70bca73459d7e6a4be96f9de79e2d440bd499fe0db"},
{file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"},
{file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"},
{file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"},
{file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"},
{file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"},
{file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"},
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"},
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
{file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
{file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
{file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"},
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92fd6b027924a7e178ac202cfbe25e53368db90d56872d20ffae94b96c7acc57"},
{file = "pandas-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:640cef9aa381b60e296db324337a554aeeb883ead99dc8f6c18e81a93942f5f4"},
{file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"},
]
[package.dependencies]
numpy = {version = ">=1.23.2", markers = "python_version == \"3.11\""}
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
tzdata = ">=2022.7"
[package.extras]
all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"]
aws = ["s3fs (>=2022.11.0)"]
clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"]
compression = ["zstandard (>=0.19.0)"]
computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"]
consortium-standard = ["dataframe-api-compat (>=0.1.7)"]
excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"]
feather = ["pyarrow (>=10.0.1)"]
fss = ["fsspec (>=2022.11.0)"]
gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"]
hdf5 = ["tables (>=3.8.0)"]
html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"]
mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"]
output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"]
parquet = ["pyarrow (>=10.0.1)"]
performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"]
plot = ["matplotlib (>=3.6.3)"]
postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"]
pyarrow = ["pyarrow (>=10.0.1)"]
spss = ["pyreadstat (>=1.2.0)"]
sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"]
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
xml = ["lxml (>=4.9.2)"]
[[package]] [[package]]
name = "pandocfilters" name = "pandocfilters"
version = "1.5.0" version = "1.5.0"
@ -3946,7 +3877,7 @@ files = [
name = "patsy" name = "patsy"
version = "0.5.6" version = "0.5.6"
description = "A Python package for describing statistical models and for building design matrices." description = "A Python package for describing statistical models and for building design matrices."
optional = true optional = false
python-versions = "*" python-versions = "*"
files = [ files = [
{file = "patsy-0.5.6-py2.py3-none-any.whl", hash = "sha256:19056886fd8fa71863fa32f0eb090267f21fb74be00f19f5c70b2e9d76c883c6"}, {file = "patsy-0.5.6-py2.py3-none-any.whl", hash = "sha256:19056886fd8fa71863fa32f0eb090267f21fb74be00f19f5c70b2e9d76c883c6"},
@ -4213,7 +4144,7 @@ wcwidth = "*"
name = "property-cached" name = "property-cached"
version = "1.6.4" version = "1.6.4"
description = "A decorator for caching properties in classes (forked from cached-property)." description = "A decorator for caching properties in classes (forked from cached-property)."
optional = true optional = false
python-versions = ">= 3.5" python-versions = ">= 3.5"
files = [ files = [
{file = "property-cached-1.6.4.zip", hash = "sha256:3e9c4ef1ed3653909147510481d7df62a3cfb483461a6986a6f1dcd09b2ebb73"}, {file = "property-cached-1.6.4.zip", hash = "sha256:3e9c4ef1ed3653909147510481d7df62a3cfb483461a6986a6f1dcd09b2ebb73"},
@ -5077,7 +5008,7 @@ files = [
name = "rliable" name = "rliable"
version = "1.0.8" version = "1.0.8"
description = "rliable: Reliable evaluation on reinforcement learning and machine learning benchmarks." description = "rliable: Reliable evaluation on reinforcement learning and machine learning benchmarks."
optional = true optional = false
python-versions = "*" python-versions = "*"
files = [] files = []
develop = false develop = false
@ -5366,7 +5297,7 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo
name = "seaborn" name = "seaborn"
version = "0.13.2" version = "0.13.2"
description = "Statistical data visualization" description = "Statistical data visualization"
optional = true optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987"}, {file = "seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987"},
@ -6214,7 +6145,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
name = "statsmodels" name = "statsmodels"
version = "0.14.0" version = "0.14.0"
description = "Statistical computations and models for Python" description = "Statistical computations and models for Python"
optional = true optional = false
python-versions = ">=3.8" python-versions = ">=3.8"
files = [ files = [
{file = "statsmodels-0.14.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:16bfe0c96a53b20fa19067e3b6bd2f1d39e30d4891ea0d7bc20734a0ae95942d"}, {file = "statsmodels-0.14.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:16bfe0c96a53b20fa19067e3b6bd2f1d39e30d4891ea0d7bc20734a0ae95942d"},
@ -6260,51 +6191,6 @@ build = ["cython (>=0.29.26)"]
develop = ["colorama", "cython (>=0.29.26)", "cython (>=0.29.28,<3.0.0)", "flake8", "isort", "joblib", "matplotlib (>=3)", "oldest-supported-numpy (>=2022.4.18)", "pytest (>=7.0.1,<7.1.0)", "pytest-randomly", "pytest-xdist", "pywinpty", "setuptools-scm[toml] (>=7.0.0,<7.1.0)"] develop = ["colorama", "cython (>=0.29.26)", "cython (>=0.29.28,<3.0.0)", "flake8", "isort", "joblib", "matplotlib (>=3)", "oldest-supported-numpy (>=2022.4.18)", "pytest (>=7.0.1,<7.1.0)", "pytest-randomly", "pytest-xdist", "pywinpty", "setuptools-scm[toml] (>=7.0.0,<7.1.0)"]
docs = ["ipykernel", "jupyter-client", "matplotlib", "nbconvert", "nbformat", "numpydoc", "pandas-datareader", "sphinx"] docs = ["ipykernel", "jupyter-client", "matplotlib", "nbconvert", "nbformat", "numpydoc", "pandas-datareader", "sphinx"]
[[package]]
name = "statsmodels"
version = "0.14.2"
description = "Statistical computations and models for Python"
optional = true
python-versions = ">=3.9"
files = [
{file = "statsmodels-0.14.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df5d6f95c46f0341da6c79ee7617e025bf2b371e190a8e60af1ae9cabbdb7a97"},
{file = "statsmodels-0.14.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a87ef21fadb445b650f327340dde703f13aec1540f3d497afb66324499dea97a"},
{file = "statsmodels-0.14.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5827a12e3ede2b98a784476d61d6bec43011fedb64aa815f2098e0573bece257"},
{file = "statsmodels-0.14.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10f2b7611a61adb7d596a6d239abdf1a4d5492b931b00d5ed23d32844d40e48e"},
{file = "statsmodels-0.14.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c254c66142f1167b4c7d031cf8db55294cc62ff3280e090fc45bd10a7f5fd029"},
{file = "statsmodels-0.14.2-cp310-cp310-win_amd64.whl", hash = "sha256:0e46e9d59293c1af4cc1f4e5248f17e7e7bc596bfce44d327c789ac27f09111b"},
{file = "statsmodels-0.14.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:50fcb633987779e795142f51ba49fb27648d46e8a1382b32ebe8e503aaabaa9e"},
{file = "statsmodels-0.14.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:876794068abfaeed41df71b7887000031ecf44fbfa6b50d53ccb12ebb4ab747a"},
{file = "statsmodels-0.14.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a91f6c4943de13e3ce2e20ee3b5d26d02bd42300616a421becd53756f5deb37"},
{file = "statsmodels-0.14.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4864a1c4615c5ea5f2e3b078a75bdedc90dd9da210a37e0738e064b419eccee2"},
{file = "statsmodels-0.14.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:afbd92410e0df06f3d8c4e7c0e2e71f63f4969531f280fb66059e2ecdb6e0415"},
{file = "statsmodels-0.14.2-cp311-cp311-win_amd64.whl", hash = "sha256:8e004cfad0e46ce73fe3f3812010c746f0d4cfd48e307b45c14e9e360f3d2510"},
{file = "statsmodels-0.14.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:eb0ba1ad3627705f5ae20af6b2982f500546d43892543b36c7bca3e2f87105e7"},
{file = "statsmodels-0.14.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:90fd2f0110b73fc3fa5a2f21c3ca99b0e22285cccf38e56b5b8fd8ce28791b0f"},
{file = "statsmodels-0.14.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac780ad9ff552773798829a0b9c46820b0faa10e6454891f5e49a845123758ab"},
{file = "statsmodels-0.14.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55d1742778400ae67acb04b50a2c7f5804182f8a874bd09ca397d69ed159a751"},
{file = "statsmodels-0.14.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f870d14a587ea58a3b596aa994c2ed889cc051f9e450e887d2c83656fc6a64bf"},
{file = "statsmodels-0.14.2-cp312-cp312-win_amd64.whl", hash = "sha256:f450fcbae214aae66bd9d2b9af48e0f8ba1cb0e8596c6ebb34e6e3f0fec6542c"},
{file = "statsmodels-0.14.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:201c3d00929c4a67cda1fe05b098c8dcf1b1eeefa88e80a8f963a844801ed59f"},
{file = "statsmodels-0.14.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9edefa4ce08e40bc1d67d2f79bc686ee5e238e801312b5a029ee7786448c389a"},
{file = "statsmodels-0.14.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29c78a7601fdae1aa32104c5ebff2e0b72c26f33e870e2f94ab1bcfd927ece9b"},
{file = "statsmodels-0.14.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f36494df7c03d63168fccee5038a62f469469ed6a4dd6eaeb9338abedcd0d5f5"},
{file = "statsmodels-0.14.2-cp39-cp39-win_amd64.whl", hash = "sha256:8875823bdd41806dc853333cc4e1b7ef9481bad2380a999e66ea42382cf2178d"},
{file = "statsmodels-0.14.2.tar.gz", hash = "sha256:890550147ad3a81cda24f0ba1a5c4021adc1601080bd00e191ae7cd6feecd6ad"},
]
[package.dependencies]
numpy = ">=1.22.3"
packaging = ">=21.3"
pandas = ">=1.4,<2.1.0 || >2.1.0"
patsy = ">=0.5.6"
scipy = ">=1.8,<1.9.2 || >1.9.2"
[package.extras]
build = ["cython (>=0.29.33)"]
develop = ["colorama", "cython (>=0.29.33)", "cython (>=3.0.10,<4)", "flake8", "isort", "joblib", "matplotlib (>=3)", "pytest (>=7.3.0,<8)", "pytest-cov", "pytest-randomly", "pytest-xdist", "pywinpty", "setuptools-scm[toml] (>=8.0,<9.0)"]
docs = ["ipykernel", "jupyter-client", "matplotlib", "nbconvert", "nbformat", "numpydoc", "pandas-datareader", "sphinx"]
[[package]] [[package]]
name = "swig" name = "swig"
version = "4.2.0" version = "4.2.0"
@ -6530,13 +6416,13 @@ files = [
[[package]] [[package]]
name = "tqdm" name = "tqdm"
version = "4.66.1" version = "4.66.3"
description = "Fast, Extensible Progress Meter" description = "Fast, Extensible Progress Meter"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"}, {file = "tqdm-4.66.3-py3-none-any.whl", hash = "sha256:4f41d54107ff9a223dca80b53efe4fb654c67efaba7f47bada3ee9d50e05bd53"},
{file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"}, {file = "tqdm-4.66.3.tar.gz", hash = "sha256:23097a41eba115ba99ecae40d06444c15d1c0c698d527a01c6c8bd1c5d0647e5"},
] ]
[package.dependencies] [package.dependencies]

@ -172,18 +172,19 @@ class ReplayBuffer:
return np.array([last] if not self.done[last] and self._size else [], int) return np.array([last] if not self.done[last] and self._size else [], int)
def prev(self, index: int | np.ndarray) -> np.ndarray: def prev(self, index: int | np.ndarray) -> np.ndarray:
"""Return the index of previous transition. """Return the index of preceding step within the same episode if it exists.
If it does not exist (because it is the first index within the episode),
The index won't be modified if it is the beginning of an episode. the index remains unmodified.
""" """
index = (index - 1) % self._size index = (index - 1) % self._size # compute preceding index with wrap-around
# end_flag will be 1 if the previous index is the last step of an episode or
# if it is the very last index of the buffer (wrap-around case), and 0 otherwise
end_flag = self.done[index] | (index == self.last_index[0]) end_flag = self.done[index] | (index == self.last_index[0])
return (index + end_flag) % self._size return (index + end_flag) % self._size
def next(self, index: int | np.ndarray) -> np.ndarray: def next(self, index: int | np.ndarray) -> np.ndarray:
"""Return the index of next transition. """Return the index of next step if there is a next step within the episode.
If there isn't a next step, the index remains unmodified.
The index won't be modified if it is the end of an episode.
""" """
end_flag = self.done[index] | (index == self.last_index[0]) end_flag = self.done[index] | (index == self.last_index[0])
return (index + (1 - end_flag)) % self._size return (index + (1 - end_flag)) % self._size

@ -118,9 +118,12 @@ class SamplingConfig(ToStringMixin):
replay_buffer_ignore_obs_next: bool = False replay_buffer_ignore_obs_next: bool = False
replay_buffer_save_only_last_obs: bool = False replay_buffer_save_only_last_obs: bool = False
"""if True, only the most recent frame is saved when appending to experiences rather than the """if True, for the case where the environment outputs stacked frames (e.g. because it
full stacked frames. This avoids duplicating observations in buffer memory. Set to False to is using a `FrameStack` wrapper), save only the most recent frame so as not to duplicate
save stacked frames in full. observations in buffer memory. Specifically, if the environment outputs observations `obs` with
shape (N, ...), only obs[-1] of shape (...) will be stored.
Frame stacking with a fixed number of frames can then be recreated at the buffer level by setting
:attr:`replay_buffer_stack_num`.
""" """
replay_buffer_stack_num: int = 1 replay_buffer_stack_num: int = 1
@ -128,6 +131,9 @@ class SamplingConfig(ToStringMixin):
the number of consecutive environment observations to stack and use as the observation input the number of consecutive environment observations to stack and use as the observation input
to the agent for each time step. Setting this to a value greater than 1 can help agents learn to the agent for each time step. Setting this to a value greater than 1 can help agents learn
temporal aspects (e.g. velocities of moving objects for which only positions are observed). temporal aspects (e.g. velocities of moving objects for which only positions are observed).
If the environment already stacks frames (e.g. using a `FrameStack` wrapper), this should either not
be used or should be used in conjunction with :attr:`replay_buffer_save_only_last_obs`.
""" """
@property @property

@ -197,7 +197,11 @@ class CriticFactoryReuseActor(CriticFactory):
last_size=last_size, last_size=last_size,
).to(device) ).to(device)
elif envs.get_type().is_continuous(): elif envs.get_type().is_continuous():
return continuous.Critic(actor.get_preprocess_net(), device=device).to(device) return continuous.Critic(
actor.get_preprocess_net(),
device=device,
apply_preprocess_net_to_obs_only=True,
).to(device)
else: else:
raise ValueError raise ValueError

@ -323,9 +323,9 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
:return: A :class:`~tianshou.data.Batch` which MUST have the following keys: :return: A :class:`~tianshou.data.Batch` which MUST have the following keys:
* ``act`` an numpy.ndarray or a torch.Tensor, the action over \ * ``act`` a numpy.ndarray or a torch.Tensor, the action over \
given batch data. given batch data.
* ``state`` a dict, an numpy.ndarray or a torch.Tensor, the \ * ``state`` a dict, a numpy.ndarray or a torch.Tensor, the \
internal state of the policy, ``None`` as default. internal state of the policy, ``None`` as default.
Other keys are user-defined. It depends on the algorithm. For example, Other keys are user-defined. It depends on the algorithm. For example,
@ -587,19 +587,23 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
advantage + value, which is exactly equivalent to using :math:`TD(\lambda)` advantage + value, which is exactly equivalent to using :math:`TD(\lambda)`
for estimating returns. for estimating returns.
Setting `v_s_` and `v_s` to None (or all zeros) and `gae_lambda` to 1.0 calculates the
discounted return-to-go/ Monte-Carlo return.
:param batch: a data batch which contains several episodes of data in :param batch: a data batch which contains several episodes of data in
sequential order. Mind that the end of each finished episode of batch sequential order. Mind that the end of each finished episode of batch
should be marked by done flag, unfinished (or collecting) episodes will be should be marked by done flag, unfinished (or collecting) episodes will be
recognized by buffer.unfinished_index(). recognized by buffer.unfinished_index().
:param buffer: the corresponding replay buffer. :param buffer: the corresponding replay buffer.
:param numpy.ndarray indices: tell batch's location in buffer, batch is equal :param indices: tells the batch's location in buffer, batch is equal
to buffer[indices]. to buffer[indices].
:param np.ndarray v_s_: the value function of all next states :math:`V(s')`. :param v_s_: the value function of all next states :math:`V(s')`.
If None, it will be set to an array of 0. If None, it will be set to an array of 0.
:param v_s: the value function of all current states :math:`V(s)`. :param v_s: the value function of all current states :math:`V(s)`. If None,
:param gamma: the discount factor, should be in [0, 1]. Default to 0.99. it is set based upon `v_s_` rolled by 1.
:param gamma: the discount factor, should be in [0, 1].
:param gae_lambda: the parameter for Generalized Advantage Estimation, :param gae_lambda: the parameter for Generalized Advantage Estimation,
should be in [0, 1]. Default to 0.95. should be in [0, 1].
:return: two numpy arrays (returns, advantage) with each shape (bsz, ). :return: two numpy arrays (returns, advantage) with each shape (bsz, ).
""" """
@ -643,10 +647,10 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
:param indices: tell batch's location in buffer :param indices: tell batch's location in buffer
:param function target_q_fn: a function which compute target Q value :param function target_q_fn: a function which compute target Q value
of "obs_next" given data buffer and wanted indices. of "obs_next" given data buffer and wanted indices.
:param gamma: the discount factor, should be in [0, 1]. Default to 0.99. :param gamma: the discount factor, should be in [0, 1].
:param n_step: the number of estimation step, should be an int greater :param n_step: the number of estimation step, should be an int greater
than 0. Default to 1. than 0.
:param rew_norm: normalize the reward to Normal(0, 1), Default to False. :param rew_norm: normalize the reward to Normal(0, 1).
TODO: passing True is not supported and will cause an error! TODO: passing True is not supported and will cause an error!
:return: a Batch. The result will be stored in batch.returns as a :return: a Batch. The result will be stored in batch.returns as a
torch.Tensor with the same shape as target_q_fn's return tensor. torch.Tensor with the same shape as target_q_fn's return tensor.

@ -15,6 +15,7 @@ from tianshou.utils.net.common import (
TLinearLayer, TLinearLayer,
get_output_dim, get_output_dim,
) )
from tianshou.utils.pickle import setstate
SIGMA_MIN = -20 SIGMA_MIN = -20
SIGMA_MAX = 2 SIGMA_MAX = 2
@ -109,6 +110,9 @@ class Critic(CriticBase):
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
:param linear_layer: use this module as linear layer. :param linear_layer: use this module as linear layer.
:param flatten_input: whether to flatten input data for the last layer. :param flatten_input: whether to flatten input data for the last layer.
:param apply_preprocess_net_to_obs_only: whether to apply `preprocess_net` to the observations only (before
concatenating with the action) - and without the observations being modified in any way beforehand.
This allows the actor's preprocessing network to be reused for the critic.
For advanced usage (how to customize the network), please refer to For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`. :ref:`build_the_network`.
@ -122,11 +126,13 @@ class Critic(CriticBase):
preprocess_net_output_dim: int | None = None, preprocess_net_output_dim: int | None = None,
linear_layer: TLinearLayer = nn.Linear, linear_layer: TLinearLayer = nn.Linear,
flatten_input: bool = True, flatten_input: bool = True,
apply_preprocess_net_to_obs_only: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
self.device = device self.device = device
self.preprocess = preprocess_net self.preprocess = preprocess_net
self.output_dim = 1 self.output_dim = 1
self.apply_preprocess_net_to_obs_only = apply_preprocess_net_to_obs_only
input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
self.last = MLP( self.last = MLP(
input_dim, input_dim,
@ -137,6 +143,14 @@ class Critic(CriticBase):
flatten_input=flatten_input, flatten_input=flatten_input,
) )
def __setstate__(self, state: dict) -> None:
setstate(
Critic,
self,
state,
new_default_properties={"apply_preprocess_net_to_obs_only": False},
)
def forward( def forward(
self, self,
obs: np.ndarray | torch.Tensor, obs: np.ndarray | torch.Tensor,
@ -148,7 +162,10 @@ class Critic(CriticBase):
obs, obs,
device=self.device, device=self.device,
dtype=torch.float32, dtype=torch.float32,
).flatten(1) )
if self.apply_preprocess_net_to_obs_only:
obs, _ = self.preprocess(obs)
obs = obs.flatten(1)
if act is not None: if act is not None:
act = torch.as_tensor( act = torch.as_tensor(
act, act,
@ -156,8 +173,9 @@ class Critic(CriticBase):
dtype=torch.float32, dtype=torch.float32,
).flatten(1) ).flatten(1)
obs = torch.cat([obs, act], dim=1) obs = torch.cat([obs, act], dim=1)
values_B, hidden_BH = self.preprocess(obs) if not self.apply_preprocess_net_to_obs_only:
return self.last(values_B) obs, _ = self.preprocess(obs)
return self.last(obs)
class ActorProb(BaseActor): class ActorProb(BaseActor):

97
tianshou/utils/pickle.py Normal file

@ -0,0 +1,97 @@
"""Helper functions for persistence/pickling, which have been copied from sensAI (specifically `sensai.util.pickle`)."""
from collections.abc import Iterable
from copy import copy
from typing import Any
def setstate(
cls: type,
obj: Any,
state: dict[str, Any],
renamed_properties: dict[str, str] | None = None,
new_optional_properties: list[str] | None = None,
new_default_properties: dict[str, Any] | None = None,
removed_properties: list[str] | None = None,
) -> None:
"""Helper function for safe implementations of `__setstate__` in classes, which appropriately handles the cases where
a parent class already implements `__setstate__` and where it does not. Call this function whenever you would actually
like to call the super-class' implementation.
Unfortunately, `__setstate__` is not implemented in `object`, rendering `super().__setstate__(state)` invalid in the general case.
:param cls: the class in which you are implementing `__setstate__`
:param obj: the instance of `cls`
:param state: the state dictionary
:param renamed_properties: a mapping from old property names to new property names
:param new_optional_properties: a list of names of new property names, which, if not present, shall be initialized with None
:param new_default_properties: a dictionary mapping property names to their default values, which shall be added if they are not present
:param removed_properties: a list of names of properties that are no longer being used
"""
# handle new/changed properties
if renamed_properties is not None:
for mOld, mNew in renamed_properties.items():
if mOld in state:
state[mNew] = state[mOld]
del state[mOld]
if new_optional_properties is not None:
for mNew in new_optional_properties:
if mNew not in state:
state[mNew] = None
if new_default_properties is not None:
for mNew, mValue in new_default_properties.items():
if mNew not in state:
state[mNew] = mValue
if removed_properties is not None:
for p in removed_properties:
if p in state:
del state[p]
# call super implementation, if any
s = super(cls, obj)
if hasattr(s, "__setstate__"):
s.__setstate__(state)
else:
obj.__dict__ = state
def getstate(
cls: type,
obj: Any,
transient_properties: Iterable[str] | None = None,
excluded_properties: Iterable[str] | None = None,
override_properties: dict[str, Any] | None = None,
excluded_default_properties: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Helper function for safe implementations of `__getstate__` in classes, which appropriately handles the cases where
a parent class already implements `__getstate__` and where it does not. Call this function whenever you would actually
like to call the super-class' implementation.
Unfortunately, `__getstate__` is not implemented in `object`, rendering `super().__getstate__()` invalid in the general case.
:param cls: the class in which you are implementing `__getstate__`
:param obj: the instance of `cls`
:param transient_properties: transient properties which shall be set to None in serializations
:param excluded_properties: properties which shall be completely removed from serializations
:param override_properties: a mapping from property names to values specifying (new or existing) properties which are to be set;
use this to set a fixed value for an existing property or to add a completely new property
:param excluded_default_properties: properties which shall be completely removed from serializations, if they are set
to the given default value
:return: the state dictionary, which may be modified by the receiver
"""
s = super(cls, obj)
d = s.__getstate__() if hasattr(s, "__getstate__") else obj.__dict__
d = copy(d)
if transient_properties is not None:
for p in transient_properties:
if p in d:
d[p] = None
if excluded_properties is not None:
for p in excluded_properties:
if p in d:
del d[p]
if override_properties is not None:
for k, v in override_properties.items():
d[k] = v
if excluded_default_properties is not None:
for p, v in excluded_default_properties.items():
if p in d and d[p] == v:
del d[p]
return d