updated results, requirements, and torch version

NM512 2023-03-24 07:51:57 +09:00
parent 2504426164
commit 942eae10a9
6 changed files with 36 additions and 25 deletions

README.md

@@ -2,7 +2,7 @@
PyTorch implementation of [Mastering Diverse Domains through World Models](https://arxiv.org/abs/2301.04104v1).
![results](https://user-images.githubusercontent.com/70328564/226332682-acaef8b5-d825-4266-b4ea-6ce4b169a3a2.png)
![1](https://user-images.githubusercontent.com/70328564/227377956-4a0d7e48-22fb-4f44-aa10-e5878a5ef901.png)
## Instructions
@@ -10,21 +10,26 @@ Get dependencies:
```
pip install -r requirements.txt
```
Train the agent:
Train the agent on Walker Walk in Vision DMC:
```
python3 dreamer.py --configs defaults --logdir $ABSOLUTEPATH_TO_SAVE_LOG
python3 dreamer.py --configs defaults --task dmc_walker_walk --logdir ~/dreamerv3-torch/logdir/dmc_walker_walk
```
Train the agent on Alien in Atari 100K:
```
python3 dreamer.py --configs defaults atari --task atari_alien --logdir ~/dreamerv3-torch/logdir/atari_alien
```
Monitor results:
```
tensorboard --logdir $ABSOLUTEPATH_TO_SAVE_LOG
tensorboard --logdir ~/dreamerv3-torch/logdir
```
## ToDo
- [x] Prototyping
- [x] Modify implementation details based on the author's implementation
- [ ] Evaluate on visual DMC suite
- [ ] Add state input capability and evaluate on Proprio Control Suite environment
- [ ] Add model size options and evaluate on environments that require them (like Minecraft)
- [x] Evaluate on DMC vision
- [ ] Evaluate on Atari 100K
- [ ] Add state input capability
- [ ] Evaluate on DMC Proprio
- [ ] etc.
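Since this commit bumps torch to 2.0.0 and adds a `compile` option (see configs.yaml and requirements.txt below), a quick sanity check that the installed torch actually ships `torch.compile` can catch a stale environment before training; a minimal sketch:
```
import torch

# requirements.txt pins torch==2.0.0; torch.compile first shipped in 2.0.
print(torch.__version__)           # expect 2.0.0
print(hasattr(torch, "compile"))   # True on 2.x, False on 1.13
```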

configs.yaml

@@ -12,6 +12,7 @@ defaults:
log_every: 1e4
reset_every: 0
device: 'cuda:0'
compile: False
precision: 16
debug: False
expl_gifs: False
@@ -63,6 +64,7 @@ defaults:
reward_scale: 1.0
weight_decay: 0.0
unimix_ratio: 0.01
action_unimix_ratio: 0.01
# Training
batch_size: 16
@@ -119,6 +121,16 @@ defaults:
disag_units: 400
disag_action_cond: False
visual_dmc:
atari:
steps: 4e5
action_repeat: 4
actor_dist: 'onehot'
train_ratio: 1024
imag_gradient: 'reinforce'
time_limit: 108000
debug:
debug: True
@@ -127,15 +139,3 @@ debug:
train_steps: 1
batch_size: 10
batch_length: 20
cheetah:
task: 'dmc_cheetah_run'
pendulum:
task: 'dmc_pendulum_swingup'
cup:
task: 'dmc_cup_catch'
acrobot:
task: 'dmc_acrobot_swingup'
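For context on the new `atari` block: the README command `--configs defaults atari` layers the named blocks left to right, so `atari` keys override `defaults`. A minimal sketch of that merge, assuming a flat top-level configs.yaml with these block names (the repo's actual argument handling in dreamer.py may differ):
```
from ruamel.yaml import YAML

# Later-named blocks override earlier ones, mirroring `--configs defaults atari`.
yaml = YAML(typ="safe")
with open("configs.yaml") as f:
    all_configs = yaml.load(f)

merged = {}
for name in ["defaults", "atari"]:
    merged.update(all_configs[name])

print(merged["action_repeat"])  # 4, set by the atari block
print(merged["precision"])      # 16, inherited from defaults
```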

dreamer.py

@@ -54,6 +54,9 @@ class Dreamer(nn.Module):
self._task_behavior = models.ImagBehavior(
config, self._wm, config.behavior_stop_grad
)
if config.compile:
self._wm = torch.compile(self._wm)
self._task_behavior = torch.compile(self._task_behavior)
reward = lambda f, s, a: self._wm.heads["reward"](f).mean
self._expl_behavior = dict(
greedy=lambda: self._task_behavior,
@@ -192,8 +195,8 @@ def make_env(config, logger, mode, train_eps, eval_eps):
config.size,
grayscale=config.grayscale,
life_done=False and ("train" in mode),
sticky_actions=True,
all_actions=True,
sticky_actions=False,
all_actions=False,
)
env = wrappers.OneHotAction(env)
elif suite == "dmlab":
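Two changes land in this file: the world model and actor-critic are wrapped in `torch.compile` (new in torch 2.0, matching the requirements bump), and the Atari wrapper flips to `sticky_actions=False` and `all_actions=False`, following the deterministic, minimal-action-set protocol of the Atari 100K benchmark. A minimal sketch of the compile pattern on a stand-in module (illustrative, not the repo's model):
```
import torch
import torch.nn as nn

# torch.compile wraps an nn.Module and returns an optimized callable.
model = nn.Sequential(nn.Linear(8, 64), nn.ELU(), nn.Linear(64, 1))
compiled = torch.compile(model)       # requires torch >= 2.0
out = compiled(torch.randn(16, 8))    # first call traces/compiles; later calls reuse it
```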

models.py

@@ -241,6 +241,7 @@ class ImagBehavior(nn.Module):
config.actor_max_std,
config.actor_temp,
outscale=1.0,
unimix_ratio=config.action_unimix_ratio,
) # action_dist -> action_disc?
if config.value_head == "twohot_symlog":
self.value = networks.DenseHead(

networks.py

@@ -514,6 +514,7 @@ class ActionHead(nn.Module):
max_std=1.0,
temp=0.1,
outscale=1.0,
unimix_ratio=0.01,
):
super(ActionHead, self).__init__()
self._size = size
@@ -525,6 +526,7 @@
self._min_std = min_std
self._max_std = max_std
self._init_std = init_std
self._unimix_ratio = unimix_ratio
self._temp = temp() if callable(temp) else temp
pre_layers = []
@@ -591,7 +593,7 @@
dist = tools.ContDist(torchd.independent.Independent(dist, 1))
elif self._dist == "onehot":
x = self._dist_layer(x)
dist = tools.OneHotDist(x)
dist = tools.OneHotDist(x, unimix_ratio=self._unimix_ratio)
elif self._dist == "onehot_gumble":
x = self._dist_layer(x)
temp = self._temp
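The thread through models.py and networks.py is the new `action_unimix_ratio`: DreamerV3 mixes a small uniform component into the categorical action distribution so no action's probability collapses to zero, which keeps log-probabilities and entropy well behaved. A hedged sketch of that blend (the repo's `tools.OneHotDist` may implement it differently, e.g. directly on logits):
```
import torch
import torch.distributions as torchd

def unimix_probs(logits, unimix_ratio=0.01):
    # Blend 99% softmax policy with 1% uniform over actions.
    probs = torch.softmax(logits, dim=-1)
    uniform = torch.ones_like(probs) / probs.shape[-1]
    return (1.0 - unimix_ratio) * probs + unimix_ratio * uniform

dist = torchd.OneHotCategorical(probs=unimix_probs(torch.randn(6)))
action = dist.sample()  # one-hot sample; every action keeps probability >= 0.01 / 6
```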

requirements.txt

@@ -1,11 +1,11 @@
torch==1.13.0
torch==2.0.0
torchvision==0.15.1
numpy==1.20.1
torchvision==0.14.0
tensorboard==2.5.0
pandas==1.2.4
matplotlib==3.4.1
ruamel.yaml==0.17.4
gym[atari]==0.18.0
gym[atari]==0.17.0
moviepy==1.0.3
einops==0.3.0
protobuf==3.20.0