diff --git a/README.md b/README.md index 3d536f8..2f7c2b3 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Pytorch implementation of [Mastering Diverse Domains through World Models](https://arxiv.org/abs/2301.04104v1). -![results](https://user-images.githubusercontent.com/70328564/226332682-acaef8b5-d825-4266-b4ea-6ce4b169a3a2.png) +![1](https://user-images.githubusercontent.com/70328564/227377956-4a0d7e48-22fb-4f44-aa10-e5878a5ef901.png) ## Instructions @@ -10,21 +10,26 @@ Get dependencies: ``` pip install -r requirements.txt ``` -Train the agent: +Train the agent on Walker Walk in Vision DMC: ``` -python3 dreamer.py --configs defaults --logdir $ABSOLUTEPATH_TO_SAVE_LOG +python3 dreamer.py --configs defaults --task dmc_walker_walk --logdir ~/dreamerv3-torch/logdir/dmc_walker_walk +``` +Train the agent on Alien in Atari 100K: +``` +python3 dreamer.py --configs defaults atari --task atari_alien --logdir ~/dreamerv3-torch/logdir/atari_alien ``` Monitor results: ``` -tensorboard --logdir $ABSOLUTEPATH_TO_SAVE_LOG +tensorboard --logdir ~/dreamerv3-torch/logdir ``` ## ToDo - [x] Prototyping - [x] Modify implementation details based on the author's implementation -- [ ] Evaluate on visual DMC suite -- [ ] Add state input capability and evaluate on Proprio Control Suite environment -- [ ] Add model size options and evaluate on environments which requires that (like Minecraft) +- [x] Evaluate on DMC vision +- [ ] Evaluate on Atari 100K +- [ ] Add state input capability +- [ ] Evaluate on DMC Proprio - [ ] etc. diff --git a/configs.yaml b/configs.yaml index 40b58f2..72a8167 100644 --- a/configs.yaml +++ b/configs.yaml @@ -12,6 +12,7 @@ defaults: log_every: 1e4 reset_every: 0 device: 'cuda:0' + compile: False precision: 16 debug: False expl_gifs: False @@ -63,6 +64,7 @@ defaults: reward_scale: 1.0 weight_decay: 0.0 unimix_ratio: 0.01 + action_unimix_ratio: 0.01 # Training batch_size: 16 @@ -119,6 +121,16 @@ defaults: disag_units: 400 disag_action_cond: False +visual_dmc: + +atari: + steps: 4e5 + action_repeat: 4 + actor_dist: 'onehot' + train_ratio: 1024 + imag_gradient: 'reinforce' + time_limit: 108000 + debug: debug: True @@ -127,15 +139,3 @@ debug: train_steps: 1 batch_size: 10 batch_length: 20 - -cheetah: - task: 'dmc_cheetah_run' - -pendulum: - task: 'dmc_pendulum_swingup' - -cup: - task: 'dmc_cup_catch' - -acrobot: - task: 'dmc_acrobot_swingup' diff --git a/dreamer.py b/dreamer.py index 79951b7..9bdfcd0 100644 --- a/dreamer.py +++ b/dreamer.py @@ -54,6 +54,9 @@ class Dreamer(nn.Module): self._task_behavior = models.ImagBehavior( config, self._wm, config.behavior_stop_grad ) + if config.compile: + self._wm = torch.compile(self._wm) + self._task_behavior = torch.compile(self._task_behavior) reward = lambda f, s, a: self._wm.heads["reward"](f).mean self._expl_behavior = dict( greedy=lambda: self._task_behavior, @@ -192,8 +195,8 @@ def make_env(config, logger, mode, train_eps, eval_eps): config.size, grayscale=config.grayscale, life_done=False and ("train" in mode), - sticky_actions=True, - all_actions=True, + sticky_actions=False, + all_actions=False, ) env = wrappers.OneHotAction(env) elif suite == "dmlab": diff --git a/models.py b/models.py index 7460130..3455f5a 100644 --- a/models.py +++ b/models.py @@ -241,6 +241,7 @@ class ImagBehavior(nn.Module): config.actor_max_std, config.actor_temp, outscale=1.0, + unimix_ratio=config.action_unimix_ratio, ) # action_dist -> action_disc? if config.value_head == "twohot_symlog": self.value = networks.DenseHead( diff --git a/networks.py b/networks.py index 9e2ae7b..8f360d1 100644 --- a/networks.py +++ b/networks.py @@ -514,6 +514,7 @@ class ActionHead(nn.Module): max_std=1.0, temp=0.1, outscale=1.0, + unimix_ratio=0.01, ): super(ActionHead, self).__init__() self._size = size @@ -525,6 +526,7 @@ class ActionHead(nn.Module): self._min_std = min_std self._max_std = max_std self._init_std = init_std + self._unimix_ratio = unimix_ratio self._temp = temp() if callable(temp) else temp pre_layers = [] @@ -591,7 +593,7 @@ class ActionHead(nn.Module): dist = tools.ContDist(torchd.independent.Independent(dist, 1)) elif self._dist == "onehot": x = self._dist_layer(x) - dist = tools.OneHotDist(x) + dist = tools.OneHotDist(x, unimix_ratio=self._unimix_ratio) elif self._dist == "onehot_gumble": x = self._dist_layer(x) temp = self._temp diff --git a/requirements.txt b/requirements.txt index 30f5b3d..ce85a91 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,11 @@ -torch==1.13.0 +torch==2.0.0 +torchvision==0.15.1 numpy==1.20.1 -torchvision==0.14.0 tensorboard==2.5.0 pandas==1.2.4 matplotlib==3.4.1 ruamel.yaml==0.17.4 -gym[atari]==0.18.0 +gym[atari]==0.17.0 moviepy==1.0.3 einops==0.3.0 protobuf==3.20.0