Merge branch 'refs/heads/thuml-master' into policy-train-eval
# Conflicts: # CHANGELOG.md
This commit is contained in:
commit
4e38aeb829
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
2
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -1,6 +1,6 @@
|
||||
- [ ] I have added the correct label(s) to this Pull Request or linked the relevant issue(s)
|
||||
- [ ] I have provided a description of the changes in this Pull Request
|
||||
- [ ] I have added documentation for my changes
|
||||
- [ ] I have added documentation for my changes and have listed relevant changes in CHANGELOG.md
|
||||
- [ ] If applicable, I have added tests to cover my changes.
|
||||
- [ ] I have reformatted the code using `poe format`
|
||||
- [ ] I have checked style and types with `poe lint` and `poe type-check`
|
||||
|
13
CHANGELOG.md
13
CHANGELOG.md
@ -19,11 +19,24 @@
|
||||
- New `evaluation` package for repeating the same experiment with multiple seeds and aggregating the results (important extension!).
|
||||
Launchers for parallelization currently in alpha state. #1074
|
||||
- Loggers can now restore the logged data into python by using the new `restore_logged_data` method. #1074
|
||||
- `continuous.Critic`:
|
||||
- Add flag `apply_preprocess_net_to_obs_only` to allow the
|
||||
preprocessing network to be applied to the observations only (without
|
||||
the actions concatenated), which is essential for the case where we want
|
||||
to reuse the actor's preprocessing network #1128
|
||||
- Base class for collectors: `BaseCollector` #1122
|
||||
- Collectors can now explicitly specify whether to use the policy in training or evaluation mode. #1122
|
||||
- New util context managers `in_eval_mode` and `in_train_mode` for torch modules. #1122
|
||||
- `reset` of `Collectors` now returns `obs` and `info`. #1122
|
||||
|
||||
### Fixes
|
||||
- `CriticFactoryReuseActor`: Enable the Critic flag `apply_preprocess_net_to_obs_only` for continuous critics,
|
||||
fixing the case where we want to reuse an actor's preprocessing network for the critic (affects usages
|
||||
of the experiment builder method `with_critic_factory_use_actor` with continuous environments) #1128
|
||||
- `atari_network.DQN`:
|
||||
- Fix constructor input validation #1128
|
||||
- Fix `output_dim` not being set if `features_only`=True and `output_dim_added_layer` is not None #1128
|
||||
|
||||
### Internal Improvements
|
||||
- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
|
||||
- Introduced a first iteration of a naming convention for vars in `Collector`s. #1063
|
||||
|
@ -152,7 +152,7 @@
|
||||
"id": "Lh2-hwE5Dn9I"
|
||||
},
|
||||
"source": [
|
||||
"Once we have defined the actor, the critic and the optimizer. We can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution."
|
||||
"Once we have defined the actor, the critic and the optimizer, we can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -266,3 +266,8 @@ postfix
|
||||
backend
|
||||
rliable
|
||||
hl
|
||||
v_s
|
||||
v_s_
|
||||
obs
|
||||
obs_next
|
||||
|
||||
|
@ -66,7 +66,7 @@ class DQN(NetBase[Any]):
|
||||
layer_init: Callable[[nn.Module], nn.Module] = lambda x: x,
|
||||
) -> None:
|
||||
# TODO: Add docstring
|
||||
if features_only and output_dim_added_layer is not None:
|
||||
if not features_only and output_dim_added_layer is not None:
|
||||
raise ValueError(
|
||||
"Should not provide explicit output dimension using `output_dim_added_layer` when `features_only` is true.",
|
||||
)
|
||||
@ -98,6 +98,7 @@ class DQN(NetBase[Any]):
|
||||
layer_init(nn.Linear(base_cnn_output_dim, output_dim_added_layer)),
|
||||
nn.ReLU(inplace=True),
|
||||
)
|
||||
self.output_dim = output_dim_added_layer
|
||||
else:
|
||||
self.output_dim = base_cnn_output_dim
|
||||
|
||||
|
134
poetry.lock
generated
134
poetry.lock
generated
@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "absl-py"
|
||||
@ -156,7 +156,7 @@ files = [
|
||||
name = "arch"
|
||||
version = "5.3.1"
|
||||
description = "ARCH for Python"
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "arch-5.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:75fa6f9386ecc2df81bcbf5d055a290a697482ca51e0b3459dab183d288993cb"},
|
||||
@ -3815,75 +3815,6 @@ sql-other = ["SQLAlchemy (>=1.4.36)"]
|
||||
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"]
|
||||
xml = ["lxml (>=4.8.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "pandas"
|
||||
version = "2.2.2"
|
||||
description = "Powerful data structures for data analysis, time series, and statistics"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
|
||||
{file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"},
|
||||
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
|
||||
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
|
||||
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},
|
||||
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8e5a0b00e1e56a842f922e7fae8ae4077aee4af0acb5ae3622bd4b4c30aedf99"},
|
||||
{file = "pandas-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:ddf818e4e6c7c6f4f7c8a12709696d193976b591cc7dc50588d3d1a6b5dc8772"},
|
||||
{file = "pandas-2.2.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:696039430f7a562b74fa45f540aca068ea85fa34c244d0deee539cb6d70aa288"},
|
||||
{file = "pandas-2.2.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e90497254aacacbc4ea6ae5e7a8cd75629d6ad2b30025a4a8b09aa4faf55151"},
|
||||
{file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58b84b91b0b9f4bafac2a0ac55002280c094dfc6402402332c0913a59654ab2b"},
|
||||
{file = "pandas-2.2.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d2123dc9ad6a814bcdea0f099885276b31b24f7edf40f6cdbc0912672e22eee"},
|
||||
{file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:2925720037f06e89af896c70bca73459d7e6a4be96f9de79e2d440bd499fe0db"},
|
||||
{file = "pandas-2.2.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0cace394b6ea70c01ca1595f839cf193df35d1575986e484ad35c4aeae7266c1"},
|
||||
{file = "pandas-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:873d13d177501a28b2756375d59816c365e42ed8417b41665f346289adc68d24"},
|
||||
{file = "pandas-2.2.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:9dfde2a0ddef507a631dc9dc4af6a9489d5e2e740e226ad426a05cabfbd7c8ef"},
|
||||
{file = "pandas-2.2.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e9b79011ff7a0f4b1d6da6a61aa1aa604fb312d6647de5bad20013682d1429ce"},
|
||||
{file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1cb51fe389360f3b5a4d57dbd2848a5f033350336ca3b340d1c53a1fad33bcad"},
|
||||
{file = "pandas-2.2.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eee3a87076c0756de40b05c5e9a6069c035ba43e8dd71c379e68cab2c20f16ad"},
|
||||
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:3e374f59e440d4ab45ca2fffde54b81ac3834cf5ae2cdfa69c90bc03bde04d76"},
|
||||
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
|
||||
{file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92fd6b027924a7e178ac202cfbe25e53368db90d56872d20ffae94b96c7acc57"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:640cef9aa381b60e296db324337a554aeeb883ead99dc8f6c18e81a93942f5f4"},
|
||||
{file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = {version = ">=1.23.2", markers = "python_version == \"3.11\""}
|
||||
python-dateutil = ">=2.8.2"
|
||||
pytz = ">=2020.1"
|
||||
tzdata = ">=2022.7"
|
||||
|
||||
[package.extras]
|
||||
all = ["PyQt5 (>=5.15.9)", "SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)", "beautifulsoup4 (>=4.11.2)", "bottleneck (>=1.3.6)", "dataframe-api-compat (>=0.1.7)", "fastparquet (>=2022.12.0)", "fsspec (>=2022.11.0)", "gcsfs (>=2022.11.0)", "html5lib (>=1.1)", "hypothesis (>=6.46.1)", "jinja2 (>=3.1.2)", "lxml (>=4.9.2)", "matplotlib (>=3.6.3)", "numba (>=0.56.4)", "numexpr (>=2.8.4)", "odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "pandas-gbq (>=0.19.0)", "psycopg2 (>=2.9.6)", "pyarrow (>=10.0.1)", "pymysql (>=1.0.2)", "pyreadstat (>=1.2.0)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "qtpy (>=2.3.0)", "s3fs (>=2022.11.0)", "scipy (>=1.10.0)", "tables (>=3.8.0)", "tabulate (>=0.9.0)", "xarray (>=2022.12.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)", "zstandard (>=0.19.0)"]
|
||||
aws = ["s3fs (>=2022.11.0)"]
|
||||
clipboard = ["PyQt5 (>=5.15.9)", "qtpy (>=2.3.0)"]
|
||||
compression = ["zstandard (>=0.19.0)"]
|
||||
computation = ["scipy (>=1.10.0)", "xarray (>=2022.12.0)"]
|
||||
consortium-standard = ["dataframe-api-compat (>=0.1.7)"]
|
||||
excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.1.0)", "python-calamine (>=0.1.7)", "pyxlsb (>=1.0.10)", "xlrd (>=2.0.1)", "xlsxwriter (>=3.0.5)"]
|
||||
feather = ["pyarrow (>=10.0.1)"]
|
||||
fss = ["fsspec (>=2022.11.0)"]
|
||||
gcp = ["gcsfs (>=2022.11.0)", "pandas-gbq (>=0.19.0)"]
|
||||
hdf5 = ["tables (>=3.8.0)"]
|
||||
html = ["beautifulsoup4 (>=4.11.2)", "html5lib (>=1.1)", "lxml (>=4.9.2)"]
|
||||
mysql = ["SQLAlchemy (>=2.0.0)", "pymysql (>=1.0.2)"]
|
||||
output-formatting = ["jinja2 (>=3.1.2)", "tabulate (>=0.9.0)"]
|
||||
parquet = ["pyarrow (>=10.0.1)"]
|
||||
performance = ["bottleneck (>=1.3.6)", "numba (>=0.56.4)", "numexpr (>=2.8.4)"]
|
||||
plot = ["matplotlib (>=3.6.3)"]
|
||||
postgresql = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "psycopg2 (>=2.9.6)"]
|
||||
pyarrow = ["pyarrow (>=10.0.1)"]
|
||||
spss = ["pyreadstat (>=1.2.0)"]
|
||||
sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-driver-sqlite (>=0.8.0)"]
|
||||
test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"]
|
||||
xml = ["lxml (>=4.9.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "pandocfilters"
|
||||
version = "1.5.0"
|
||||
@ -3946,7 +3877,7 @@ files = [
|
||||
name = "patsy"
|
||||
version = "0.5.6"
|
||||
description = "A Python package for describing statistical models and for building design matrices."
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "patsy-0.5.6-py2.py3-none-any.whl", hash = "sha256:19056886fd8fa71863fa32f0eb090267f21fb74be00f19f5c70b2e9d76c883c6"},
|
||||
@ -4213,7 +4144,7 @@ wcwidth = "*"
|
||||
name = "property-cached"
|
||||
version = "1.6.4"
|
||||
description = "A decorator for caching properties in classes (forked from cached-property)."
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">= 3.5"
|
||||
files = [
|
||||
{file = "property-cached-1.6.4.zip", hash = "sha256:3e9c4ef1ed3653909147510481d7df62a3cfb483461a6986a6f1dcd09b2ebb73"},
|
||||
@ -5077,7 +5008,7 @@ files = [
|
||||
name = "rliable"
|
||||
version = "1.0.8"
|
||||
description = "rliable: Reliable evaluation on reinforcement learning and machine learning benchmarks."
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = []
|
||||
develop = false
|
||||
@ -5366,7 +5297,7 @@ test = ["asv", "gmpy2", "mpmath", "pooch", "pytest", "pytest-cov", "pytest-timeo
|
||||
name = "seaborn"
|
||||
version = "0.13.2"
|
||||
description = "Statistical data visualization"
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987"},
|
||||
@ -6214,7 +6145,7 @@ tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
|
||||
name = "statsmodels"
|
||||
version = "0.14.0"
|
||||
description = "Statistical computations and models for Python"
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "statsmodels-0.14.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:16bfe0c96a53b20fa19067e3b6bd2f1d39e30d4891ea0d7bc20734a0ae95942d"},
|
||||
@ -6260,51 +6191,6 @@ build = ["cython (>=0.29.26)"]
|
||||
develop = ["colorama", "cython (>=0.29.26)", "cython (>=0.29.28,<3.0.0)", "flake8", "isort", "joblib", "matplotlib (>=3)", "oldest-supported-numpy (>=2022.4.18)", "pytest (>=7.0.1,<7.1.0)", "pytest-randomly", "pytest-xdist", "pywinpty", "setuptools-scm[toml] (>=7.0.0,<7.1.0)"]
|
||||
docs = ["ipykernel", "jupyter-client", "matplotlib", "nbconvert", "nbformat", "numpydoc", "pandas-datareader", "sphinx"]
|
||||
|
||||
[[package]]
|
||||
name = "statsmodels"
|
||||
version = "0.14.2"
|
||||
description = "Statistical computations and models for Python"
|
||||
optional = true
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "statsmodels-0.14.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df5d6f95c46f0341da6c79ee7617e025bf2b371e190a8e60af1ae9cabbdb7a97"},
|
||||
{file = "statsmodels-0.14.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a87ef21fadb445b650f327340dde703f13aec1540f3d497afb66324499dea97a"},
|
||||
{file = "statsmodels-0.14.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5827a12e3ede2b98a784476d61d6bec43011fedb64aa815f2098e0573bece257"},
|
||||
{file = "statsmodels-0.14.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10f2b7611a61adb7d596a6d239abdf1a4d5492b931b00d5ed23d32844d40e48e"},
|
||||
{file = "statsmodels-0.14.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c254c66142f1167b4c7d031cf8db55294cc62ff3280e090fc45bd10a7f5fd029"},
|
||||
{file = "statsmodels-0.14.2-cp310-cp310-win_amd64.whl", hash = "sha256:0e46e9d59293c1af4cc1f4e5248f17e7e7bc596bfce44d327c789ac27f09111b"},
|
||||
{file = "statsmodels-0.14.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:50fcb633987779e795142f51ba49fb27648d46e8a1382b32ebe8e503aaabaa9e"},
|
||||
{file = "statsmodels-0.14.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:876794068abfaeed41df71b7887000031ecf44fbfa6b50d53ccb12ebb4ab747a"},
|
||||
{file = "statsmodels-0.14.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a91f6c4943de13e3ce2e20ee3b5d26d02bd42300616a421becd53756f5deb37"},
|
||||
{file = "statsmodels-0.14.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4864a1c4615c5ea5f2e3b078a75bdedc90dd9da210a37e0738e064b419eccee2"},
|
||||
{file = "statsmodels-0.14.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:afbd92410e0df06f3d8c4e7c0e2e71f63f4969531f280fb66059e2ecdb6e0415"},
|
||||
{file = "statsmodels-0.14.2-cp311-cp311-win_amd64.whl", hash = "sha256:8e004cfad0e46ce73fe3f3812010c746f0d4cfd48e307b45c14e9e360f3d2510"},
|
||||
{file = "statsmodels-0.14.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:eb0ba1ad3627705f5ae20af6b2982f500546d43892543b36c7bca3e2f87105e7"},
|
||||
{file = "statsmodels-0.14.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:90fd2f0110b73fc3fa5a2f21c3ca99b0e22285cccf38e56b5b8fd8ce28791b0f"},
|
||||
{file = "statsmodels-0.14.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac780ad9ff552773798829a0b9c46820b0faa10e6454891f5e49a845123758ab"},
|
||||
{file = "statsmodels-0.14.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55d1742778400ae67acb04b50a2c7f5804182f8a874bd09ca397d69ed159a751"},
|
||||
{file = "statsmodels-0.14.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f870d14a587ea58a3b596aa994c2ed889cc051f9e450e887d2c83656fc6a64bf"},
|
||||
{file = "statsmodels-0.14.2-cp312-cp312-win_amd64.whl", hash = "sha256:f450fcbae214aae66bd9d2b9af48e0f8ba1cb0e8596c6ebb34e6e3f0fec6542c"},
|
||||
{file = "statsmodels-0.14.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:201c3d00929c4a67cda1fe05b098c8dcf1b1eeefa88e80a8f963a844801ed59f"},
|
||||
{file = "statsmodels-0.14.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9edefa4ce08e40bc1d67d2f79bc686ee5e238e801312b5a029ee7786448c389a"},
|
||||
{file = "statsmodels-0.14.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29c78a7601fdae1aa32104c5ebff2e0b72c26f33e870e2f94ab1bcfd927ece9b"},
|
||||
{file = "statsmodels-0.14.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f36494df7c03d63168fccee5038a62f469469ed6a4dd6eaeb9338abedcd0d5f5"},
|
||||
{file = "statsmodels-0.14.2-cp39-cp39-win_amd64.whl", hash = "sha256:8875823bdd41806dc853333cc4e1b7ef9481bad2380a999e66ea42382cf2178d"},
|
||||
{file = "statsmodels-0.14.2.tar.gz", hash = "sha256:890550147ad3a81cda24f0ba1a5c4021adc1601080bd00e191ae7cd6feecd6ad"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = ">=1.22.3"
|
||||
packaging = ">=21.3"
|
||||
pandas = ">=1.4,<2.1.0 || >2.1.0"
|
||||
patsy = ">=0.5.6"
|
||||
scipy = ">=1.8,<1.9.2 || >1.9.2"
|
||||
|
||||
[package.extras]
|
||||
build = ["cython (>=0.29.33)"]
|
||||
develop = ["colorama", "cython (>=0.29.33)", "cython (>=3.0.10,<4)", "flake8", "isort", "joblib", "matplotlib (>=3)", "pytest (>=7.3.0,<8)", "pytest-cov", "pytest-randomly", "pytest-xdist", "pywinpty", "setuptools-scm[toml] (>=8.0,<9.0)"]
|
||||
docs = ["ipykernel", "jupyter-client", "matplotlib", "nbconvert", "nbformat", "numpydoc", "pandas-datareader", "sphinx"]
|
||||
|
||||
[[package]]
|
||||
name = "swig"
|
||||
version = "4.2.0"
|
||||
@ -6530,13 +6416,13 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "tqdm"
|
||||
version = "4.66.1"
|
||||
version = "4.66.3"
|
||||
description = "Fast, Extensible Progress Meter"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"},
|
||||
{file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"},
|
||||
{file = "tqdm-4.66.3-py3-none-any.whl", hash = "sha256:4f41d54107ff9a223dca80b53efe4fb654c67efaba7f47bada3ee9d50e05bd53"},
|
||||
{file = "tqdm-4.66.3.tar.gz", hash = "sha256:23097a41eba115ba99ecae40d06444c15d1c0c698d527a01c6c8bd1c5d0647e5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -172,18 +172,19 @@ class ReplayBuffer:
|
||||
return np.array([last] if not self.done[last] and self._size else [], int)
|
||||
|
||||
def prev(self, index: int | np.ndarray) -> np.ndarray:
|
||||
"""Return the index of previous transition.
|
||||
|
||||
The index won't be modified if it is the beginning of an episode.
|
||||
"""Return the index of preceding step within the same episode if it exists.
|
||||
If it does not exist (because it is the first index within the episode),
|
||||
the index remains unmodified.
|
||||
"""
|
||||
index = (index - 1) % self._size
|
||||
index = (index - 1) % self._size # compute preceding index with wrap-around
|
||||
# end_flag will be 1 if the previous index is the last step of an episode or
|
||||
# if it is the very last index of the buffer (wrap-around case), and 0 otherwise
|
||||
end_flag = self.done[index] | (index == self.last_index[0])
|
||||
return (index + end_flag) % self._size
|
||||
|
||||
def next(self, index: int | np.ndarray) -> np.ndarray:
|
||||
"""Return the index of next transition.
|
||||
|
||||
The index won't be modified if it is the end of an episode.
|
||||
"""Return the index of next step if there is a next step within the episode.
|
||||
If there isn't a next step, the index remains unmodified.
|
||||
"""
|
||||
end_flag = self.done[index] | (index == self.last_index[0])
|
||||
return (index + (1 - end_flag)) % self._size
|
||||
|
@ -118,9 +118,12 @@ class SamplingConfig(ToStringMixin):
|
||||
replay_buffer_ignore_obs_next: bool = False
|
||||
|
||||
replay_buffer_save_only_last_obs: bool = False
|
||||
"""if True, only the most recent frame is saved when appending to experiences rather than the
|
||||
full stacked frames. This avoids duplicating observations in buffer memory. Set to False to
|
||||
save stacked frames in full.
|
||||
"""if True, for the case where the environment outputs stacked frames (e.g. because it
|
||||
is using a `FrameStack` wrapper), save only the most recent frame so as not to duplicate
|
||||
observations in buffer memory. Specifically, if the environment outputs observations `obs` with
|
||||
shape (N, ...), only obs[-1] of shape (...) will be stored.
|
||||
Frame stacking with a fixed number of frames can then be recreated at the buffer level by setting
|
||||
:attr:`replay_buffer_stack_num`.
|
||||
"""
|
||||
|
||||
replay_buffer_stack_num: int = 1
|
||||
@ -128,6 +131,9 @@ class SamplingConfig(ToStringMixin):
|
||||
the number of consecutive environment observations to stack and use as the observation input
|
||||
to the agent for each time step. Setting this to a value greater than 1 can help agents learn
|
||||
temporal aspects (e.g. velocities of moving objects for which only positions are observed).
|
||||
|
||||
If the environment already stacks frames (e.g. using a `FrameStack` wrapper), this should either not
|
||||
be used or should be used in conjunction with :attr:`replay_buffer_save_only_last_obs`.
|
||||
"""
|
||||
|
||||
@property
|
||||
|
@ -197,7 +197,11 @@ class CriticFactoryReuseActor(CriticFactory):
|
||||
last_size=last_size,
|
||||
).to(device)
|
||||
elif envs.get_type().is_continuous():
|
||||
return continuous.Critic(actor.get_preprocess_net(), device=device).to(device)
|
||||
return continuous.Critic(
|
||||
actor.get_preprocess_net(),
|
||||
device=device,
|
||||
apply_preprocess_net_to_obs_only=True,
|
||||
).to(device)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
|
@ -323,9 +323,9 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
|
||||
|
||||
:return: A :class:`~tianshou.data.Batch` which MUST have the following keys:
|
||||
|
||||
* ``act`` an numpy.ndarray or a torch.Tensor, the action over \
|
||||
* ``act`` a numpy.ndarray or a torch.Tensor, the action over \
|
||||
given batch data.
|
||||
* ``state`` a dict, an numpy.ndarray or a torch.Tensor, the \
|
||||
* ``state`` a dict, a numpy.ndarray or a torch.Tensor, the \
|
||||
internal state of the policy, ``None`` as default.
|
||||
|
||||
Other keys are user-defined. It depends on the algorithm. For example,
|
||||
@ -587,19 +587,23 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
|
||||
advantage + value, which is exactly equivalent to using :math:`TD(\lambda)`
|
||||
for estimating returns.
|
||||
|
||||
Setting `v_s_` and `v_s` to None (or all zeros) and `gae_lambda` to 1.0 calculates the
|
||||
discounted return-to-go/ Monte-Carlo return.
|
||||
|
||||
:param batch: a data batch which contains several episodes of data in
|
||||
sequential order. Mind that the end of each finished episode of batch
|
||||
should be marked by done flag, unfinished (or collecting) episodes will be
|
||||
recognized by buffer.unfinished_index().
|
||||
:param buffer: the corresponding replay buffer.
|
||||
:param numpy.ndarray indices: tell batch's location in buffer, batch is equal
|
||||
:param indices: tells the batch's location in buffer, batch is equal
|
||||
to buffer[indices].
|
||||
:param np.ndarray v_s_: the value function of all next states :math:`V(s')`.
|
||||
:param v_s_: the value function of all next states :math:`V(s')`.
|
||||
If None, it will be set to an array of 0.
|
||||
:param v_s: the value function of all current states :math:`V(s)`.
|
||||
:param gamma: the discount factor, should be in [0, 1]. Default to 0.99.
|
||||
:param v_s: the value function of all current states :math:`V(s)`. If None,
|
||||
it is set based upon `v_s_` rolled by 1.
|
||||
:param gamma: the discount factor, should be in [0, 1].
|
||||
:param gae_lambda: the parameter for Generalized Advantage Estimation,
|
||||
should be in [0, 1]. Default to 0.95.
|
||||
should be in [0, 1].
|
||||
|
||||
:return: two numpy arrays (returns, advantage) with each shape (bsz, ).
|
||||
"""
|
||||
@ -643,10 +647,10 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
|
||||
:param indices: tell batch's location in buffer
|
||||
:param function target_q_fn: a function which compute target Q value
|
||||
of "obs_next" given data buffer and wanted indices.
|
||||
:param gamma: the discount factor, should be in [0, 1]. Default to 0.99.
|
||||
:param gamma: the discount factor, should be in [0, 1].
|
||||
:param n_step: the number of estimation step, should be an int greater
|
||||
than 0. Default to 1.
|
||||
:param rew_norm: normalize the reward to Normal(0, 1), Default to False.
|
||||
than 0.
|
||||
:param rew_norm: normalize the reward to Normal(0, 1).
|
||||
TODO: passing True is not supported and will cause an error!
|
||||
:return: a Batch. The result will be stored in batch.returns as a
|
||||
torch.Tensor with the same shape as target_q_fn's return tensor.
|
||||
|
@ -15,6 +15,7 @@ from tianshou.utils.net.common import (
|
||||
TLinearLayer,
|
||||
get_output_dim,
|
||||
)
|
||||
from tianshou.utils.pickle import setstate
|
||||
|
||||
SIGMA_MIN = -20
|
||||
SIGMA_MAX = 2
|
||||
@ -109,6 +110,9 @@ class Critic(CriticBase):
|
||||
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
|
||||
:param linear_layer: use this module as linear layer.
|
||||
:param flatten_input: whether to flatten input data for the last layer.
|
||||
:param apply_preprocess_net_to_obs_only: whether to apply `preprocess_net` to the observations only (before
|
||||
concatenating with the action) - and without the observations being modified in any way beforehand.
|
||||
This allows the actor's preprocessing network to be reused for the critic.
|
||||
|
||||
For advanced usage (how to customize the network), please refer to
|
||||
:ref:`build_the_network`.
|
||||
@ -122,11 +126,13 @@ class Critic(CriticBase):
|
||||
preprocess_net_output_dim: int | None = None,
|
||||
linear_layer: TLinearLayer = nn.Linear,
|
||||
flatten_input: bool = True,
|
||||
apply_preprocess_net_to_obs_only: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.preprocess = preprocess_net
|
||||
self.output_dim = 1
|
||||
self.apply_preprocess_net_to_obs_only = apply_preprocess_net_to_obs_only
|
||||
input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
|
||||
self.last = MLP(
|
||||
input_dim,
|
||||
@ -137,6 +143,14 @@ class Critic(CriticBase):
|
||||
flatten_input=flatten_input,
|
||||
)
|
||||
|
||||
def __setstate__(self, state: dict) -> None:
|
||||
setstate(
|
||||
Critic,
|
||||
self,
|
||||
state,
|
||||
new_default_properties={"apply_preprocess_net_to_obs_only": False},
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
obs: np.ndarray | torch.Tensor,
|
||||
@ -148,7 +162,10 @@ class Critic(CriticBase):
|
||||
obs,
|
||||
device=self.device,
|
||||
dtype=torch.float32,
|
||||
).flatten(1)
|
||||
)
|
||||
if self.apply_preprocess_net_to_obs_only:
|
||||
obs, _ = self.preprocess(obs)
|
||||
obs = obs.flatten(1)
|
||||
if act is not None:
|
||||
act = torch.as_tensor(
|
||||
act,
|
||||
@ -156,8 +173,9 @@ class Critic(CriticBase):
|
||||
dtype=torch.float32,
|
||||
).flatten(1)
|
||||
obs = torch.cat([obs, act], dim=1)
|
||||
values_B, hidden_BH = self.preprocess(obs)
|
||||
return self.last(values_B)
|
||||
if not self.apply_preprocess_net_to_obs_only:
|
||||
obs, _ = self.preprocess(obs)
|
||||
return self.last(obs)
|
||||
|
||||
|
||||
class ActorProb(BaseActor):
|
||||
|
97
tianshou/utils/pickle.py
Normal file
97
tianshou/utils/pickle.py
Normal file
@ -0,0 +1,97 @@
|
||||
"""Helper functions for persistence/pickling, which have been copied from sensAI (specifically `sensai.util.pickle`)."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from copy import copy
|
||||
from typing import Any
|
||||
|
||||
|
||||
def setstate(
|
||||
cls: type,
|
||||
obj: Any,
|
||||
state: dict[str, Any],
|
||||
renamed_properties: dict[str, str] | None = None,
|
||||
new_optional_properties: list[str] | None = None,
|
||||
new_default_properties: dict[str, Any] | None = None,
|
||||
removed_properties: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Helper function for safe implementations of `__setstate__` in classes, which appropriately handles the cases where
|
||||
a parent class already implements `__setstate__` and where it does not. Call this function whenever you would actually
|
||||
like to call the super-class' implementation.
|
||||
Unfortunately, `__setstate__` is not implemented in `object`, rendering `super().__setstate__(state)` invalid in the general case.
|
||||
|
||||
:param cls: the class in which you are implementing `__setstate__`
|
||||
:param obj: the instance of `cls`
|
||||
:param state: the state dictionary
|
||||
:param renamed_properties: a mapping from old property names to new property names
|
||||
:param new_optional_properties: a list of names of new property names, which, if not present, shall be initialized with None
|
||||
:param new_default_properties: a dictionary mapping property names to their default values, which shall be added if they are not present
|
||||
:param removed_properties: a list of names of properties that are no longer being used
|
||||
"""
|
||||
# handle new/changed properties
|
||||
if renamed_properties is not None:
|
||||
for mOld, mNew in renamed_properties.items():
|
||||
if mOld in state:
|
||||
state[mNew] = state[mOld]
|
||||
del state[mOld]
|
||||
if new_optional_properties is not None:
|
||||
for mNew in new_optional_properties:
|
||||
if mNew not in state:
|
||||
state[mNew] = None
|
||||
if new_default_properties is not None:
|
||||
for mNew, mValue in new_default_properties.items():
|
||||
if mNew not in state:
|
||||
state[mNew] = mValue
|
||||
if removed_properties is not None:
|
||||
for p in removed_properties:
|
||||
if p in state:
|
||||
del state[p]
|
||||
# call super implementation, if any
|
||||
s = super(cls, obj)
|
||||
if hasattr(s, "__setstate__"):
|
||||
s.__setstate__(state)
|
||||
else:
|
||||
obj.__dict__ = state
|
||||
|
||||
|
||||
def getstate(
|
||||
cls: type,
|
||||
obj: Any,
|
||||
transient_properties: Iterable[str] | None = None,
|
||||
excluded_properties: Iterable[str] | None = None,
|
||||
override_properties: dict[str, Any] | None = None,
|
||||
excluded_default_properties: dict[str, Any] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Helper function for safe implementations of `__getstate__` in classes, which appropriately handles the cases where
|
||||
a parent class already implements `__getstate__` and where it does not. Call this function whenever you would actually
|
||||
like to call the super-class' implementation.
|
||||
Unfortunately, `__getstate__` is not implemented in `object`, rendering `super().__getstate__()` invalid in the general case.
|
||||
|
||||
:param cls: the class in which you are implementing `__getstate__`
|
||||
:param obj: the instance of `cls`
|
||||
:param transient_properties: transient properties which shall be set to None in serializations
|
||||
:param excluded_properties: properties which shall be completely removed from serializations
|
||||
:param override_properties: a mapping from property names to values specifying (new or existing) properties which are to be set;
|
||||
use this to set a fixed value for an existing property or to add a completely new property
|
||||
:param excluded_default_properties: properties which shall be completely removed from serializations, if they are set
|
||||
to the given default value
|
||||
:return: the state dictionary, which may be modified by the receiver
|
||||
"""
|
||||
s = super(cls, obj)
|
||||
d = s.__getstate__() if hasattr(s, "__getstate__") else obj.__dict__
|
||||
d = copy(d)
|
||||
if transient_properties is not None:
|
||||
for p in transient_properties:
|
||||
if p in d:
|
||||
d[p] = None
|
||||
if excluded_properties is not None:
|
||||
for p in excluded_properties:
|
||||
if p in d:
|
||||
del d[p]
|
||||
if override_properties is not None:
|
||||
for k, v in override_properties.items():
|
||||
d[k] = v
|
||||
if excluded_default_properties is not None:
|
||||
for p, v in excluded_default_properties.items():
|
||||
if p in d and d[p] == v:
|
||||
del d[p]
|
||||
return d
|
Loading…
x
Reference in New Issue
Block a user