Add calibration to CQL as in CalQL paper arXiv:2303.05479 (#915)
- [X] I have marked all applicable categories: + [ ] exception-raising fix + [ ] algorithm implementation fix + [ ] documentation modification + [X] new feature - [X] I have reformatted the code using `make format` (**required**) - [X] I have checked the code using `make commit-checks` (**required**) - [X] If applicable, I have mentioned the relevant/related issue(s) - [X] If applicable, I have listed every item in this Pull Request below
This commit is contained in:
parent
6449a43261
commit
c30b4abb8f
@ -22,41 +22,182 @@ from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="HalfCheetah-v2")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--expert-data-task", type=str, default="halfcheetah-expert-v2")
|
||||
parser.add_argument("--buffer-size", type=int, default=1000000)
|
||||
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
|
||||
parser.add_argument("--actor-lr", type=float, default=1e-4)
|
||||
parser.add_argument("--critic-lr", type=float, default=3e-4)
|
||||
parser.add_argument("--alpha", type=float, default=0.2)
|
||||
parser.add_argument("--auto-alpha", default=True, action="store_true")
|
||||
parser.add_argument("--alpha-lr", type=float, default=1e-4)
|
||||
parser.add_argument("--cql-alpha-lr", type=float, default=3e-4)
|
||||
parser.add_argument("--start-timesteps", type=int, default=10000)
|
||||
parser.add_argument("--epoch", type=int, default=200)
|
||||
parser.add_argument("--step-per-epoch", type=int, default=5000)
|
||||
parser.add_argument("--n-step", type=int, default=3)
|
||||
parser.add_argument("--batch-size", type=int, default=256)
|
||||
|
||||
parser.add_argument("--tau", type=float, default=0.005)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--cql-weight", type=float, default=1.0)
|
||||
parser.add_argument("--with-lagrange", type=bool, default=True)
|
||||
parser.add_argument("--lagrange-threshold", type=float, default=10.0)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
|
||||
parser.add_argument("--eval-freq", type=int, default=1)
|
||||
parser.add_argument("--test-num", type=int, default=10)
|
||||
parser.add_argument("--logdir", type=str, default="log")
|
||||
parser.add_argument("--render", type=float, default=1 / 35)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default="Hopper-v2",
|
||||
help="The name of the OpenAI Gym environment to train on.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
default=0,
|
||||
help="The random seed to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--expert-data-task",
|
||||
type=str,
|
||||
default="hopper-expert-v2",
|
||||
help="The name of the OpenAI Gym environment to use for expert data collection.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--buffer-size",
|
||||
type=int,
|
||||
default=1000000,
|
||||
help="The size of the replay buffer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hidden-sizes",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=[256, 256],
|
||||
help="The list of hidden sizes for the neural networks.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--actor-lr",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="The learning rate for the actor network.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--critic-lr",
|
||||
type=float,
|
||||
default=3e-4,
|
||||
help="The learning rate for the critic network.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="The weight of the entropy term in the loss function.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--auto-alpha",
|
||||
default=True,
|
||||
action="store_true",
|
||||
help="Whether to use automatic entropy tuning.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha-lr",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="The learning rate for the entropy tuning.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cql-alpha-lr",
|
||||
type=float,
|
||||
default=3e-4,
|
||||
help="The learning rate for the CQL entropy tuning.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start-timesteps",
|
||||
type=int,
|
||||
default=10000,
|
||||
help="The number of timesteps before starting to train.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=200,
|
||||
help="The number of epochs to train for.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--step-per-epoch",
|
||||
type=int,
|
||||
default=5000,
|
||||
help="The number of steps per epoch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n-step",
|
||||
type=int,
|
||||
default=3,
|
||||
help="The number of steps to use for N-step TD learning.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
default=256,
|
||||
help="The batch size for training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tau",
|
||||
type=float,
|
||||
default=0.005,
|
||||
help="The soft target update coefficient.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The temperature for the Boltzmann policy.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cql-weight",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The weight of the CQL loss term.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with-lagrange",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="Whether to use the Lagrange multiplier for CQL.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--calibrated",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="Whether to use calibration for CQL.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lagrange-threshold",
|
||||
type=float,
|
||||
default=10.0,
|
||||
help="The Lagrange multiplier threshold for CQL.",
|
||||
)
|
||||
parser.add_argument("--gamma", type=float, default=0.99, help="The discount factor")
|
||||
parser.add_argument(
|
||||
"--eval-freq",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The frequency of evaluation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-num",
|
||||
type=int,
|
||||
default=10,
|
||||
help="The number of episodes to evaluate for.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logdir",
|
||||
type=str,
|
||||
default="log",
|
||||
help="The directory to save logs to.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--render",
|
||||
type=float,
|
||||
default=1 / 35,
|
||||
help="The frequency of rendering the environment.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="cuda" if torch.cuda.is_available() else "cpu",
|
||||
help="The device to train on (cpu or cuda).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the checkpoint to resume from.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The ID of the checkpoint to resume from.",
|
||||
)
|
||||
parser.add_argument("--resume-path", type=str, default=None)
|
||||
parser.add_argument("--resume-id", type=str, default=None)
|
||||
parser.add_argument(
|
||||
"--logger",
|
||||
type=str,
|
||||
@ -145,6 +286,8 @@ def test_cql():
|
||||
critic1_optim,
|
||||
critic2,
|
||||
critic2_optim,
|
||||
calibrated=args.calibrated,
|
||||
action_space=env.action_space,
|
||||
cql_alpha_lr=args.cql_alpha_lr,
|
||||
cql_weight=args.cql_weight,
|
||||
tau=args.tau,
|
||||
|
44
poetry.lock
generated
44
poetry.lock
generated
@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "absl-py"
|
||||
@ -1238,6 +1238,16 @@ files = [
|
||||
{file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"},
|
||||
{file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"},
|
||||
{file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"},
|
||||
{file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"},
|
||||
{file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"},
|
||||
{file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"},
|
||||
{file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"},
|
||||
@ -1731,14 +1741,7 @@ files = [
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = [
|
||||
{version = ">=1.21.2", markers = "python_version >= \"3.10\""},
|
||||
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""},
|
||||
{version = ">=1.23.5", markers = "python_version >= \"3.11\""},
|
||||
{version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""},
|
||||
{version = ">=1.17.0", markers = "python_version >= \"3.7\""},
|
||||
{version = ">=1.17.3", markers = "python_version >= \"3.8\""},
|
||||
]
|
||||
numpy = {version = ">=1.23.5", markers = "python_version >= \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "optree"
|
||||
@ -1789,6 +1792,17 @@ benchmark = ["dm-tree (>=0.1,<0.2.0a0)", "jax[cpu] (>=0.4.6,<0.5.0a0)", "pandas"
|
||||
lint = ["black (>=22.6.0)", "cpplint", "doc8 (<1.0.0a0)", "flake8", "flake8-bugbear", "flake8-comprehensions", "flake8-docstrings", "flake8-pyi", "flake8-simplify", "isort (>=5.11.0)", "mypy (>=0.990)", "pre-commit", "pydocstyle", "pyenchant", "pylint[spelling] (>=2.15.0)", "ruff", "xdoctest"]
|
||||
test = ["pytest", "pytest-cov", "pytest-xdist"]
|
||||
|
||||
[[package]]
|
||||
name = "overrides"
|
||||
version = "7.4.0"
|
||||
description = "A decorator to automatically detect mismatch when overriding a method."
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "overrides-7.4.0-py3-none-any.whl", hash = "sha256:3ad24583f86d6d7a49049695efe9933e67ba62f0c7625d53c59fa832ce4b8b7d"},
|
||||
{file = "overrides-7.4.0.tar.gz", hash = "sha256:9502a3cca51f4fac40b5feca985b6703a5c1f6ad815588a7ca9e285b9dca6757"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "23.1"
|
||||
@ -2374,6 +2388,7 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"},
|
||||
{file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"},
|
||||
@ -2381,8 +2396,15 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"},
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"},
|
||||
{file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"},
|
||||
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"},
|
||||
{file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"},
|
||||
@ -2399,6 +2421,7 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"},
|
||||
{file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"},
|
||||
@ -2406,6 +2429,7 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"},
|
||||
{file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"},
|
||||
{file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},
|
||||
@ -3431,4 +3455,4 @@ pybullet = ["pybullet"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "aa8f433504ed81ea43b66f7760f0fccd3cb89895c17a1f5170bfdda67969c275"
|
||||
content-hash = "c3dfcadc09636fdc28b9350cbe7d1c65fd87d51c72a2c094ed2c2258e26d0722"
|
||||
|
@ -31,6 +31,7 @@ gymnasium = "^0.29.0"
|
||||
h5py = "^3.9.0"
|
||||
numba = "^0.57.1"
|
||||
numpy = "^1"
|
||||
overrides = "^7.4.0"
|
||||
packaging = "*"
|
||||
pettingzoo = "^1.22"
|
||||
tensorboard = "^2.5.0"
|
||||
|
@ -185,14 +185,6 @@ def test_cql(args=get_args()):
|
||||
def stop_fn(mean_rewards):
|
||||
return mean_rewards >= args.reward_threshold
|
||||
|
||||
def watch():
|
||||
policy.load_state_dict(
|
||||
torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")),
|
||||
)
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
collector.collect(n_episode=1, render=1 / 35)
|
||||
|
||||
# trainer
|
||||
trainer = OfflineTrainer(
|
||||
policy=policy,
|
||||
|
@ -418,6 +418,8 @@ class Batch(BatchProtocol):
|
||||
batch_dict = cast(Sequence[dict | BatchProtocol], batch_dict)
|
||||
self.stack_(batch_dict)
|
||||
if len(kwargs) > 0:
|
||||
# TODO: that's a rather weird pattern, is it really needed?
|
||||
# Feels like kwargs could be just merged into batch_dict in the beginning
|
||||
self.__init__(kwargs, copy=copy) # type: ignore
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Self, cast
|
||||
from typing import Any, Self, TypeVar, cast
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
@ -8,6 +8,8 @@ from tianshou.data.batch import alloc_by_keys_diff, create_value
|
||||
from tianshou.data.types import RolloutBatchProtocol
|
||||
from tianshou.data.utils.converter import from_hdf5, to_hdf5
|
||||
|
||||
TBuffer = TypeVar("TBuffer", bound="ReplayBuffer")
|
||||
|
||||
|
||||
class ReplayBuffer:
|
||||
""":class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment.
|
||||
|
@ -1,5 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, cast
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -159,10 +159,12 @@ class HERReplayBuffer(ReplayBuffer):
|
||||
future_obs = self[future_t[unique_ep_close_indices]].obs_next
|
||||
else:
|
||||
future_obs = self[self.next(future_t[unique_ep_close_indices])].obs
|
||||
future_obs = cast(BatchProtocol, future_obs)
|
||||
|
||||
# Re-assign goals and rewards via broadcast assignment
|
||||
ep_obs.desired_goal[:, her_ep_indices] = future_obs.achieved_goal[None, her_ep_indices]
|
||||
if self._save_obs_next:
|
||||
ep_obs_next = cast(BatchProtocol, ep_obs_next)
|
||||
ep_obs_next.desired_goal[:, her_ep_indices] = future_obs.achieved_goal[
|
||||
None,
|
||||
her_ep_indices,
|
||||
@ -182,7 +184,7 @@ class HERReplayBuffer(ReplayBuffer):
|
||||
assert isinstance(self._meta.obs, BatchProtocol)
|
||||
self._meta.obs[unique_ep_indices] = ep_obs
|
||||
if self._save_obs_next:
|
||||
self._meta.obs_next[unique_ep_indices] = ep_obs_next
|
||||
self._meta.obs_next[unique_ep_indices] = ep_obs_next # type: ignore
|
||||
self._meta.rew[unique_ep_indices] = ep_rew.astype(np.float32)
|
||||
|
||||
def _compute_reward(self, obs: BatchProtocol, lead_dims: int = 2) -> np.ndarray:
|
||||
|
@ -181,7 +181,7 @@ class Collector:
|
||||
info = processed_data.get("info", info)
|
||||
self.data.info[local_ids] = info # type: ignore
|
||||
|
||||
self.data.obs_next[local_ids] = obs_reset
|
||||
self.data.obs_next[local_ids] = obs_reset # type: ignore
|
||||
|
||||
def collect(
|
||||
self,
|
||||
|
@ -9,6 +9,7 @@ class RolloutBatchProtocol(BatchProtocol):
|
||||
"""Typically, the outcome of sampling from a replay buffer."""
|
||||
|
||||
obs: arr_type | BatchProtocol
|
||||
obs_next: arr_type | BatchProtocol
|
||||
act: arr_type
|
||||
rew: np.ndarray
|
||||
terminated: arr_type
|
||||
|
@ -12,6 +12,7 @@ from torch import nn
|
||||
|
||||
from tianshou.data import ReplayBuffer, to_numpy, to_torch_as
|
||||
from tianshou.data.batch import BatchProtocol
|
||||
from tianshou.data.buffer.base import TBuffer
|
||||
from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol
|
||||
from tianshou.utils import MultipleLRSchedulers
|
||||
|
||||
@ -259,6 +260,18 @@ class BasePolicy(ABC, nn.Module):
|
||||
act = (np.log(1.0 + act) - np.log(1.0 - act)) / 2.0 # type: ignore
|
||||
return act
|
||||
|
||||
def process_buffer(self, buffer: TBuffer) -> TBuffer:
    """Hook for one-time pre-processing of a replay buffer, e.g., adding new keys.

    The base implementation is a no-op: it hands ``buffer`` back unchanged.
    Policies that are trained offline at some stage (for instance, in a
    pre-training step) are the intended overriders of this hook.

    Note: the hook is used in the BaseTrainer initialization method and is
    therefore called only once. A buffer that is still empty at that point
    leaves nothing to process.

    :param buffer: the replay buffer to pre-process.
    :return: the (possibly modified) replay buffer; here, the input itself.
    """
    return buffer
|
||||
|
||||
def process_fn(
|
||||
self,
|
||||
batch: RolloutBatchProtocol,
|
||||
@ -267,7 +280,12 @@ class BasePolicy(ABC, nn.Module):
|
||||
) -> RolloutBatchProtocol:
|
||||
"""Pre-process the data from the provided replay buffer.
|
||||
|
||||
Used in :meth:`update`. Check out :ref:`process_fn` for more information.
|
||||
Meant to be overridden by subclasses. Typical usage is to add new keys to the
|
||||
batch, e.g., to add the value function of the next state. Used in :meth:`update`,
|
||||
which is usually called repeatedly during training.
|
||||
|
||||
For modifying the replay buffer only once at the beginning
|
||||
(e.g., for offline learning) see :meth:`process_buffer`.
|
||||
"""
|
||||
return batch
|
||||
|
||||
|
@ -1,11 +1,13 @@
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from overrides import override
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch
|
||||
from tianshou.data.buffer.base import TBuffer
|
||||
from tianshou.data.types import RolloutBatchProtocol
|
||||
from tianshou.policy import SACPolicy
|
||||
from tianshou.utils.net.continuous import ActorProb
|
||||
@ -45,6 +47,9 @@ class CQLPolicy(SACPolicy):
|
||||
:param float alpha_min: lower bound for clipping cql_alpha. Default to 0.0.
|
||||
:param float alpha_max: upper bound for clipping cql_alpha. Default to 1e6.
|
||||
:param float clip_grad: clip_grad for updating critic network. Default to 1.0.
|
||||
:param calibrated: calibrate Q-values as in CalQL paper arXiv:2303.05479.
|
||||
Useful for offline pre-training followed by online training,
|
||||
and also was observed to achieve better results than vanilla cql.
|
||||
:param Union[str, torch.device] device: which device to create this model on.
|
||||
Default to "cpu".
|
||||
:param lr_scheduler: a learning rate scheduler that adjusts the learning rate in
|
||||
@ -78,6 +83,7 @@ class CQLPolicy(SACPolicy):
|
||||
alpha_min: float = 0.0,
|
||||
alpha_max: float = 1e6,
|
||||
clip_grad: float = 1.0,
|
||||
calibrated: bool = True,
|
||||
device: str | torch.device = "cpu",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
@ -114,6 +120,8 @@ class CQLPolicy(SACPolicy):
|
||||
self.alpha_max = alpha_max
|
||||
self.clip_grad = clip_grad
|
||||
|
||||
self.calibrated = calibrated
|
||||
|
||||
def train(self, mode: bool = True) -> "CQLPolicy":
|
||||
"""Set the module in training mode, except for the target network."""
|
||||
self.training = mode
|
||||
@ -167,6 +175,31 @@ class CQLPolicy(SACPolicy):
|
||||
|
||||
return random_value1 - random_log_prob1, random_value2 - random_log_prob2
|
||||
|
||||
@override
def process_buffer(self, buffer: TBuffer) -> TBuffer:
    """Attach calibration returns to the buffer when calibration is enabled.

    If ``self.calibrated`` is true, returns are computed once via
    :meth:`compute_episodic_return` (with ``gamma=self._gamma`` and
    ``gae_lambda=1.0``) and stored under the ``calibration_returns`` key of
    ``buffer._meta``. Otherwise the buffer is passed through untouched.

    :param buffer: the replay buffer to pre-process. Must be a
        :class:`~tianshou.data.ReplayBuffer` when calibration is enabled.
    :return: the same buffer object; its ``_meta`` batch additionally holds
        ``calibration_returns`` if ``self.calibrated`` is true.
    :raises TypeError: if calibration is enabled and ``buffer`` is not a
        :class:`~tianshou.data.ReplayBuffer`.
    """
    if not self.calibrated:
        return buffer

    # The `_meta` rewrite below depends on ReplayBuffer internals, so reject
    # anything else explicitly (an `assert` would vanish under `python -O`).
    if not isinstance(buffer, ReplayBuffer):
        raise TypeError(
            f"Calibration requires a ReplayBuffer, got {type(buffer).__name__}",
        )

    # NOTE(review): sample(0) presumably yields every stored transition together
    # with its indices - confirm against ReplayBuffer.sample.
    batch, indices = buffer.sample(0)
    returns, _ = self.compute_episodic_return(
        batch=batch,
        buffer=buffer,
        indices=indices,
        gamma=self._gamma,
        gae_lambda=1.0,
    )
    # TODO: don't access _meta directly
    buffer._meta = cast(
        RolloutBatchProtocol,
        Batch(**buffer._meta.__dict__, calibration_returns=returns),
    )
    return buffer
|
||||
|
||||
def process_fn(
|
||||
self,
|
||||
batch: RolloutBatchProtocol,
|
||||
@ -251,6 +284,23 @@ class CQLPolicy(SACPolicy):
|
||||
]:
|
||||
value.reshape(batch_size, self.num_repeat_actions, 1)
|
||||
|
||||
if self.calibrated:
|
||||
returns = (
|
||||
batch.calibration_returns.unsqueeze(1)
|
||||
.repeat(
|
||||
(1, self.num_repeat_actions),
|
||||
)
|
||||
.view(-1, 1)
|
||||
)
|
||||
random_value1 = torch.max(random_value1, returns)
|
||||
random_value2 = torch.max(random_value2, returns)
|
||||
|
||||
current_pi_value1 = torch.max(current_pi_value1, returns)
|
||||
current_pi_value2 = torch.max(current_pi_value2, returns)
|
||||
|
||||
next_pi_value1 = torch.max(next_pi_value1, returns)
|
||||
next_pi_value2 = torch.max(next_pi_value2, returns)
|
||||
|
||||
# cat q values
|
||||
cat_q1 = torch.cat([random_value1, current_pi_value1, next_pi_value1], 1)
|
||||
cat_q2 = torch.cat([random_value2, current_pi_value2, next_pi_value2], 1)
|
||||
|
@ -211,6 +211,7 @@ class PSRLPolicy(BasePolicy):
|
||||
rew_count = np.zeros((n_s, n_a))
|
||||
for minibatch in batch.split(size=1):
|
||||
obs, act, obs_next = minibatch.obs, minibatch.act, minibatch.obs_next
|
||||
obs_next = cast(np.ndarray, obs_next)
|
||||
assert not isinstance(obs, BatchProtocol), "Observations cannot be Batches here"
|
||||
trans_count[obs, act, obs_next] += 1
|
||||
rew_sum[obs, act] += minibatch.rew
|
||||
|
@ -32,6 +32,9 @@ class BaseTrainer(ABC):
|
||||
:param train_collector: the collector used for training.
|
||||
:param test_collector: the collector used for testing. If it's None,
|
||||
then no testing will be performed.
|
||||
:param buffer: the replay buffer used for off-policy algorithms or for pre-training.
|
||||
If a policy overrides the ``process_buffer`` method, the replay buffer will
|
||||
be pre-processed before training.
|
||||
:param max_epoch: the maximum number of epochs for training. The training
|
||||
process might be finished before reaching ``max_epoch`` if ``stop_fn``
|
||||
is set.
|
||||
@ -167,6 +170,9 @@ class BaseTrainer(ABC):
|
||||
save_best_fn = save_fn
|
||||
|
||||
self.policy = policy
|
||||
|
||||
if buffer is not None:
|
||||
buffer = policy.process_buffer(buffer)
|
||||
self.buffer = buffer
|
||||
|
||||
self.train_collector = train_collector
|
||||
|
Loading…
x
Reference in New Issue
Block a user