updated result, requirements and torch version

parent 2504426164
commit 942eae10a9

README.md  (19 changed lines)
@@ -2,7 +2,7 @@
 Pytorch implementation of [Mastering Diverse Domains through World Models](https://arxiv.org/abs/2301.04104v1).
 
 [results figure updated in this commit; image not shown in this text view]
 
 ## Instructions
@@ -10,21 +10,26 @@ Get dependencies:
 ```
 pip install -r requirements.txt
 ```
-Train the agent:
+Train the agent on Walker Walk in Vision DMC:
 ```
-python3 dreamer.py --configs defaults --logdir $ABSOLUTEPATH_TO_SAVE_LOG
+python3 dreamer.py --configs defaults --task dmc_walker_walk --logdir ~/dreamerv3-torch/logdir/dmc_walker_walk
+```
+Train the agent on Alien in Atari 100K:
+```
+python3 dreamer.py --configs defaults atari --task atari_alien --logdir ~/dreamerv3-torch/logdir/atari_alien
 ```
 Monitor results:
 ```
-tensorboard --logdir $ABSOLUTEPATH_TO_SAVE_LOG
+tensorboard --logdir ~/dreamerv3-torch/logdir
 ```
 
 ## ToDo
 - [x] Prototyping
 - [x] Modify implementation details based on the author's implementation
-- [ ] Evaluate on visual DMC suite
-- [ ] Add state input capability and evaluate on Proprio Control Suite environment
-- [ ] Add model size options and evaluate on environments which requires that (like Minecraft)
+- [x] Evaluate on DMC vision
+- [ ] Evaluate on Atari 100K
+- [ ] Add state input capability
+- [ ] Evaluate on DMC Proprio
 - [ ] etc.
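The `--configs defaults atari` form selects named blocks from configs.yaml, with later blocks overriding earlier ones — presumably how the Atari command picks up the new `atari:` overrides shown in the configs.yaml diff below. A minimal, hypothetical sketch of that resolution (the loader code here is illustrative, not taken from dreamer.py):

```python
# Hypothetical sketch: merge named config blocks in the order given
# on the command line; later blocks override earlier ones.
import pathlib

import ruamel.yaml

yaml = ruamel.yaml.YAML(typ="safe", pure=True)
configs = yaml.load(pathlib.Path("configs.yaml").read_text())

merged = {}
for name in ["defaults", "atari"]:  # i.e. --configs defaults atari
    merged.update(configs[name] or {})  # empty blocks (e.g. visual_dmc) are no-ops

print(merged["actor_dist"])  # 'onehot', from the atari block
```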
configs.yaml  (24 changed lines)
@@ -12,6 +12,7 @@ defaults:
   log_every: 1e4
   reset_every: 0
   device: 'cuda:0'
+  compile: False
   precision: 16
   debug: False
   expl_gifs: False
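The new `compile` flag is consumed in the dreamer.py hunk below, where it gates wrapping the world model and task behavior in `torch.compile`.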
@@ -63,6 +64,7 @@ defaults:
   reward_scale: 1.0
   weight_decay: 0.0
   unimix_ratio: 0.01
+  action_unimix_ratio: 0.01
 
   # Training
   batch_size: 16
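`action_unimix_ratio` is the actor-side counterpart of `unimix_ratio`: the models.py hunk below passes it into `ActionHead`, and the networks.py hunks forward it to the onehot action distribution.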
@@ -119,6 +121,16 @@ defaults:
   disag_units: 400
   disag_action_cond: False
 
+visual_dmc:
+
+atari:
+  steps: 4e5
+  action_repeat: 4
+  actor_dist: 'onehot'
+  train_ratio: 1024
+  imag_gradient: 'reinforce'
+  time_limit: 108000
+
 debug:
 
   debug: True
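With `action_repeat: 4`, the `steps: 4e5` budget presumably counts raw frames, i.e. 1e5 agent decisions, matching the Atari 100K benchmark; `time_limit: 108000` frames is the standard 30-minute episode cap at 60 FPS (30 × 60 × 60 = 108,000).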
@@ -127,15 +139,3 @@ debug:
   train_steps: 1
   batch_size: 10
   batch_length: 20
-
-cheetah:
-  task: 'dmc_cheetah_run'
-
-pendulum:
-  task: 'dmc_pendulum_swingup'
-
-cup:
-  task: 'dmc_cup_catch'
-
-acrobot:
-  task: 'dmc_acrobot_swingup'
dreamer.py
@@ -54,6 +54,9 @@ class Dreamer(nn.Module):
         self._task_behavior = models.ImagBehavior(
             config, self._wm, config.behavior_stop_grad
         )
+        if config.compile:
+            self._wm = torch.compile(self._wm)
+            self._task_behavior = torch.compile(self._task_behavior)
         reward = lambda f, s, a: self._wm.heads["reward"](f).mean
         self._expl_behavior = dict(
             greedy=lambda: self._task_behavior,
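`torch.compile` is new in PyTorch 2.0, which is what forces the torch version bump in requirements.txt below; the gated call above keeps older setups working when `compile: False`. A minimal illustrative sketch of the same pattern (the module here is a stand-in, not from the repo):

```python
# Illustrative: compile a module only when the runtime supports it.
import torch
import torch.nn as nn

model = nn.Linear(8, 2)
if hasattr(torch, "compile"):  # attribute exists from PyTorch 2.0 onward
    model = torch.compile(model)  # wraps the module; forward passes get JIT-optimized
out = model(torch.randn(1, 8))  # first call triggers graph capture and compilation
```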
@@ -192,8 +195,8 @@ def make_env(config, logger, mode, train_eps, eval_eps):
             config.size,
             grayscale=config.grayscale,
             life_done=False and ("train" in mode),
-            sticky_actions=True,
-            all_actions=True,
+            sticky_actions=False,
+            all_actions=False,
         )
         env = wrappers.OneHotAction(env)
     elif suite == "dmlab":
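Turning off `sticky_actions` and `all_actions` matches the Atari 100K evaluation protocol, which presumably motivated the change: that benchmark uses deterministic (non-sticky) ALE environments and each game's minimal action set rather than the full 18-action space.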
models.py
@@ -241,6 +241,7 @@ class ImagBehavior(nn.Module):
                 config.actor_max_std,
                 config.actor_temp,
                 outscale=1.0,
+                unimix_ratio=config.action_unimix_ratio,
             )  # action_dist -> action_disc?
         if config.value_head == "twohot_symlog":
             self.value = networks.DenseHead(
networks.py
@@ -514,6 +514,7 @@ class ActionHead(nn.Module):
         max_std=1.0,
         temp=0.1,
         outscale=1.0,
+        unimix_ratio=0.01,
     ):
         super(ActionHead, self).__init__()
         self._size = size
@@ -525,6 +526,7 @@ class ActionHead(nn.Module):
         self._min_std = min_std
         self._max_std = max_std
         self._init_std = init_std
+        self._unimix_ratio = unimix_ratio
         self._temp = temp() if callable(temp) else temp
 
         pre_layers = []
@@ -591,7 +593,7 @@ class ActionHead(nn.Module):
             dist = tools.ContDist(torchd.independent.Independent(dist, 1))
         elif self._dist == "onehot":
             x = self._dist_layer(x)
-            dist = tools.OneHotDist(x)
+            dist = tools.OneHotDist(x, unimix_ratio=self._unimix_ratio)
         elif self._dist == "onehot_gumble":
             x = self._dist_layer(x)
             temp = self._temp
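The `unimix_ratio` threaded through these hunks implements DreamerV3's "unimix" trick: the onehot action distribution is mixed with a small uniform component so no action's probability can collapse to zero, which keeps log-probabilities and entropy well behaved. A standalone sketch of the mixing, as an assumed reading of what `tools.OneHotDist` does with the ratio (the real class additionally provides straight-through gradients for sampling):

```python
import torch
import torch.nn.functional as F

def unimix_probs(logits: torch.Tensor, unimix_ratio: float = 0.01) -> torch.Tensor:
    """Mix a categorical with the uniform distribution (DreamerV3 'unimix')."""
    probs = F.softmax(logits, dim=-1)
    uniform = torch.ones_like(probs) / probs.shape[-1]
    # With ratio r and K actions, every action keeps probability >= r / K.
    return (1.0 - unimix_ratio) * probs + unimix_ratio * uniform
```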
requirements.txt
@@ -1,11 +1,11 @@
-torch==1.13.0
+torch==2.0.0
+torchvision==0.15.1
 numpy==1.20.1
-torchvision==0.14.0
 tensorboard==2.5.0
 pandas==1.2.4
 matplotlib==3.4.1
 ruamel.yaml==0.17.4
-gym[atari]==0.18.0
+gym[atari]==0.17.0
 moviepy==1.0.3
 einops==0.3.0
 protobuf==3.20.0
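The torch 1.13.0 → 2.0.0 bump is required by the new `compile` option, since `torch.compile` first shipped in PyTorch 2.0, and torchvision 0.15.1 is the matching release for torch 2.0.0. The gym pin moves backward (0.18.0 → 0.17.0), presumably for compatibility with the Atari wrappers used by `make_env`.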
|
Loading…
x
Reference in New Issue
Block a user