updated result, requirements and torch version
This commit is contained in:
parent
2504426164
commit
942eae10a9
19
README.md
19
README.md
@ -2,7 +2,7 @@
|
||||
Pytorch implementation of [Mastering Diverse Domains through World Models](https://arxiv.org/abs/2301.04104v1).
|
||||
|
||||
|
||||

|
||||

|
||||
|
||||
## Instructions
|
||||
|
||||
@ -10,21 +10,26 @@ Get dependencies:
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
Train the agent:
|
||||
Train the agent on Walker Walk in Vision DMC:
|
||||
```
|
||||
python3 dreamer.py --configs defaults --logdir $ABSOLUTEPATH_TO_SAVE_LOG
|
||||
python3 dreamer.py --configs defaults --task dmc_walker_walk --logdir ~/dreamerv3-torch/logdir/dmc_walker_walk
|
||||
```
|
||||
Train the agent on Alien in Atari 100K:
|
||||
```
|
||||
python3 dreamer.py --configs defaults atari --task atari_alien --logdir ~/dreamerv3-torch/logdir/atari_alien
|
||||
```
|
||||
Monitor results:
|
||||
```
|
||||
tensorboard --logdir $ABSOLUTEPATH_TO_SAVE_LOG
|
||||
tensorboard --logdir ~/dreamerv3-torch/logdir
|
||||
```
|
||||
|
||||
## ToDo
|
||||
- [x] Prototyping
|
||||
- [x] Modify implementation details based on the author's implementation
|
||||
- [ ] Evaluate on visual DMC suite
|
||||
- [ ] Add state input capability and evaluate on Proprio Control Suite environment
|
||||
- [ ] Add model size options and evaluate on environments which requires that (like Minecraft)
|
||||
- [x] Evaluate on DMC vision
|
||||
- [ ] Evaluate on Atari 100K
|
||||
- [ ] Add state input capability
|
||||
- [ ] Evaluate on DMC Proprio
|
||||
- [ ] etc.
|
||||
|
||||
|
||||
|
24
configs.yaml
24
configs.yaml
@ -12,6 +12,7 @@ defaults:
|
||||
log_every: 1e4
|
||||
reset_every: 0
|
||||
device: 'cuda:0'
|
||||
compile: False
|
||||
precision: 16
|
||||
debug: False
|
||||
expl_gifs: False
|
||||
@ -63,6 +64,7 @@ defaults:
|
||||
reward_scale: 1.0
|
||||
weight_decay: 0.0
|
||||
unimix_ratio: 0.01
|
||||
action_unimix_ratio: 0.01
|
||||
|
||||
# Training
|
||||
batch_size: 16
|
||||
@ -119,6 +121,16 @@ defaults:
|
||||
disag_units: 400
|
||||
disag_action_cond: False
|
||||
|
||||
visual_dmc:
|
||||
|
||||
atari:
|
||||
steps: 4e5
|
||||
action_repeat: 4
|
||||
actor_dist: 'onehot'
|
||||
train_ratio: 1024
|
||||
imag_gradient: 'reinforce'
|
||||
time_limit: 108000
|
||||
|
||||
debug:
|
||||
|
||||
debug: True
|
||||
@ -127,15 +139,3 @@ debug:
|
||||
train_steps: 1
|
||||
batch_size: 10
|
||||
batch_length: 20
|
||||
|
||||
cheetah:
|
||||
task: 'dmc_cheetah_run'
|
||||
|
||||
pendulum:
|
||||
task: 'dmc_pendulum_swingup'
|
||||
|
||||
cup:
|
||||
task: 'dmc_cup_catch'
|
||||
|
||||
acrobot:
|
||||
task: 'dmc_acrobot_swingup'
|
||||
|
@ -54,6 +54,9 @@ class Dreamer(nn.Module):
|
||||
self._task_behavior = models.ImagBehavior(
|
||||
config, self._wm, config.behavior_stop_grad
|
||||
)
|
||||
if config.compile:
|
||||
self._wm = torch.compile(self._wm)
|
||||
self._task_behavior = torch.compile(self._task_behavior)
|
||||
reward = lambda f, s, a: self._wm.heads["reward"](f).mean
|
||||
self._expl_behavior = dict(
|
||||
greedy=lambda: self._task_behavior,
|
||||
@ -192,8 +195,8 @@ def make_env(config, logger, mode, train_eps, eval_eps):
|
||||
config.size,
|
||||
grayscale=config.grayscale,
|
||||
life_done=False and ("train" in mode),
|
||||
sticky_actions=True,
|
||||
all_actions=True,
|
||||
sticky_actions=False,
|
||||
all_actions=False,
|
||||
)
|
||||
env = wrappers.OneHotAction(env)
|
||||
elif suite == "dmlab":
|
||||
|
@ -241,6 +241,7 @@ class ImagBehavior(nn.Module):
|
||||
config.actor_max_std,
|
||||
config.actor_temp,
|
||||
outscale=1.0,
|
||||
unimix_ratio=config.action_unimix_ratio,
|
||||
) # action_dist -> action_disc?
|
||||
if config.value_head == "twohot_symlog":
|
||||
self.value = networks.DenseHead(
|
||||
|
@ -514,6 +514,7 @@ class ActionHead(nn.Module):
|
||||
max_std=1.0,
|
||||
temp=0.1,
|
||||
outscale=1.0,
|
||||
unimix_ratio=0.01,
|
||||
):
|
||||
super(ActionHead, self).__init__()
|
||||
self._size = size
|
||||
@ -525,6 +526,7 @@ class ActionHead(nn.Module):
|
||||
self._min_std = min_std
|
||||
self._max_std = max_std
|
||||
self._init_std = init_std
|
||||
self._unimix_ratio = unimix_ratio
|
||||
self._temp = temp() if callable(temp) else temp
|
||||
|
||||
pre_layers = []
|
||||
@ -591,7 +593,7 @@ class ActionHead(nn.Module):
|
||||
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
|
||||
elif self._dist == "onehot":
|
||||
x = self._dist_layer(x)
|
||||
dist = tools.OneHotDist(x)
|
||||
dist = tools.OneHotDist(x, unimix_ratio=self._unimix_ratio)
|
||||
elif self._dist == "onehot_gumble":
|
||||
x = self._dist_layer(x)
|
||||
temp = self._temp
|
||||
|
@ -1,11 +1,11 @@
|
||||
torch==1.13.0
|
||||
torch==2.0.0
|
||||
torchvision==0.15.1
|
||||
numpy==1.20.1
|
||||
torchvision==0.14.0
|
||||
tensorboard==2.5.0
|
||||
pandas==1.2.4
|
||||
matplotlib==3.4.1
|
||||
ruamel.yaml==0.17.4
|
||||
gym[atari]==0.18.0
|
||||
gym[atari]==0.17.0
|
||||
moviepy==1.0.3
|
||||
einops==0.3.0
|
||||
protobuf==3.20.0
|
||||
|
Loading…
x
Reference in New Issue
Block a user