33 Commits

Author SHA1 Message Date
Dominik Jain
4b270eaa2d Add documentation, improve structure of 'module' package 2023-10-18 20:44:18 +02:00
Dominik Jain
97e21b5ddf Remove obsolete mixin, improve class names 2023-10-18 20:44:18 +02:00
Dominik Jain
fc695a5394 Use logging to report trainer epoch status 2023-10-18 20:44:18 +02:00
Dominik Jain
023b33c917 Make mypy happy 2023-10-18 20:44:18 +02:00
Dominik Jain
76e870207d Improve persistence handling
* Add persistence/restoration of Experiment instance
* Add file logging in experiment
* Allow all persistence/logging to be disabled
* Disable persistence in tests
2023-10-18 20:44:18 +02:00
Dominik Jain
f6d49774a2 Reify policy persistence, introducing Wold representation 2023-10-18 20:44:17 +02:00
Dominik Jain
686fd555b0 Extend tests, fixing some default behaviour 2023-10-18 20:44:17 +02:00
Dominik Jain
a8a367c42d Support IQN in high-level API
* Add example atari_iqn_hl
* Factor out trainer callbacks to new module atari_callbacks
* Extract base class for DQN-based agent factories
* Improved module factory interface design, achieving higher generality
2023-10-18 20:44:17 +02:00
Dominik Jain
c7d0b6b4b2 Simplify agent factories by making better use of base classes 2023-10-18 20:44:17 +02:00
Dominik Jain
799beb79b4 Support discrete SAC in high-level API
* Changed machanism for reusing actor's preprocessing module in critics
  to avoid special handling in AgentFactory implementations, improving
  separation of concerns:
    - Added CriticFactoryReuseActor as the new critic factory
    - Added ActorFactoryTransientStorageDecorator to pass on the actor
      data
    - Added helper classes ActorFuture, ActorFutureProviderProtocol
* Add example atari_sac_hl
2023-10-18 20:44:17 +02:00
Dominik Jain
17ef4dd5eb Support REDQ in high-level API
* Implement example mujoco_redq_hl
* Add abstraction CriticEnsembleFactory with default implementations
  to suit REDQ
* Fix type annotation of linear_layer in Net, MLP, Critic
  (was incompatible with REDQ usage)
2023-10-18 20:44:17 +02:00
Dominik Jain
7af836bd6a Support TRPO in high-level API and add example mujoco_trpo_hl 2023-10-18 20:44:17 +02:00
Dominik Jain
383a4a6083 Support NPG in high-level API and add example mujoco_npg_hl 2023-10-18 20:44:17 +02:00
Dominik Jain
1bb52a6a5c Simplify critic/agent with optimizer generation
After adding a function to create ModuleOpt instances directly from
AgentFactory and CriticFactory,
  * several mixins for AgentFactories are no longer needed (deleted)
  * additional abstractions for ModuleOptFactories are no longer needed (deleted)
2023-10-18 20:44:17 +02:00
Dominik Jain
6bb3abb2f0 Support PG/Reinforce in high-level API
* Add example mujoco_reinforce_hl
* Extended functionality of ActorFactory to support creation of ModuleOpt
2023-10-18 20:44:17 +02:00
Dominik Jain
a161a9cf58 Improve type annotations, fix type issues and add checks 2023-10-18 20:44:17 +02:00
Dominik Jain
d269063e6a Remove 'RL' prefix from class names 2023-10-18 20:44:17 +02:00
Dominik Jain
b54fcd12cb Change high-level DQN interface to expect an actor instead of a critic,
because that is what is functionally required
2023-10-18 20:44:16 +02:00
Dominik Jain
1cba589bd4 Add DQN support in high-level API
* Allow to specify trainer callbacks (train_fn, test_fn, stop_fn)
  in high-level API, adding the necessary abstractions and pass-on
  mechanisms
* Add example atari_dqn_hl
2023-10-18 20:44:16 +02:00
Dominik Jain
9f0a410bb1 Log full experiment configuration, adding string representations to relevant classes 2023-10-18 20:44:16 +02:00
Dominik Jain
2671580c6c Add DDPG high-level API and MuJoCo example 2023-10-18 20:44:16 +02:00
Dominik Jain
6b6d9ea609 Add support for discrete PPO
* Refactored module `module` (split into submodules)
* Basic support for discrete environments
* Implement Atari env. factory
* Implement DQN-based actor factory
* Implement notion of reusing agent preprocessing network for critic
* Add example atari_ppo_hl
2023-10-18 20:44:16 +02:00
Dominik Jain
cd79cf8661 Add A2C high-level API
* Add common based class for A2C and PPO agent factories
* Add default for dist_fn parameter, adding corresponding factories
* Add example mujoco_a2c_hl
2023-10-18 20:44:16 +02:00
Dominik Jain
d4e604b46e Move parameter transformation directly into parameter objects,
achieving greater separation of concerns and improved maintainability
2023-10-18 20:44:16 +02:00
Dominik Jain
e993425aa1 Add high-level API support for TD3
* Created mixins for agent factories to reduce code duplication
 * Further factorised params & mixins for experiment factories
 * Additional parameter abstractions
 * Implement high-level MuJoCo TD3 example
2023-10-18 20:44:16 +02:00
Dominik Jain
367778d37f Improve high-level policy parametrisation
Policy objects are now parametrised by converting the parameter
dataclass instances to kwargs, using some injectable conversions
along the way
2023-10-18 20:44:16 +02:00
Dominik Jain
37dc07e487 Add high-level experiment builder interface 2023-10-18 20:44:05 +02:00
Dominik Jain
3fd60f9e70 Unify PPO configuration objects, use experiment-specific configuration
in mujoco_ppo_hl
2023-10-09 13:02:29 +02:00
Dominik Jain
8ec42009cb Move RLSamplingConfig to separate module config, fixing cyclic import 2023-10-09 13:02:23 +02:00
Dominik Jain
d26b8cb40c Use experiment-specific config in mujoco_sac_hl, adding auto-alpha 2023-10-09 13:02:18 +02:00
Dominik Jain
997b520580 Refactoring, dropping package config 2023-10-09 13:02:07 +02:00
Dominik Jain
316eb3c579 Add SAC high-level interface 2023-10-09 13:02:01 +02:00
Dominik Jain
16ed5fd2a5 Initial high-level interfaces, demonstrated in mujoco_ppo_hl 2023-10-09 13:01:35 +02:00