diff --git a/Dockerfile b/Dockerfile index 58ca0a7..d07ec0b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,10 +6,10 @@ # # 2. Start training: # docker build -f Dockerfile -t img . && \ -# docker run -it --rm --gpus all -v $PWD:/workspace img \ +# docker run -it --rm --gpus all -v $PWD:/workspace -u $(id -u):$(id -g) img \ # sh xvfb_run.sh python3 dreamer.py \ -# --logdir "./logdir/dmc_walker_walk" \ -# --configs dmc_vision --task dmc_walker_walk +# --configs dmc_vision --task dmc_walker_walk \ +# --logdir "./logdir/dmc_walker_walk" # # 3. See results: # tensorboard --logdir ~/logdir @@ -34,11 +34,11 @@ ENV NUMBA_CACHE_DIR=/tmp # dmc setup RUN pip3 install tensorboard RUN pip3 install gym==0.19.0 -RUN pip3 install dm_control +RUN pip3 install dm_control==1.0.9 RUN pip3 install moviepy # crafter setup -RUN pip3 install crafter +RUN pip3 install crafter==1.8.0 # atari setup RUN pip3 install atari-py==0.2.9 diff --git a/README.md b/README.md index 4e37bf5..919617a 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,9 @@ Pytorch implementation of [Mastering Diverse Domains through World Models](https ## Instructions -Get dependencies: +### Method 1: Manual + +Get dependencies with python 3.9: ``` pip install -r requirements.txt ``` @@ -15,6 +17,9 @@ Monitor results: ``` tensorboard --logdir ./logdir ``` +### Method 2: Docker + +Please refer to the Dockerfile for the instructions, as they are included within. ## Benchmarks So far, the following benchmarks can be used for testing. diff --git a/exploration.py b/exploration.py index afe2580..4938fba 100644 --- a/exploration.py +++ b/exploration.py @@ -58,9 +58,10 @@ class Plan2Explore(nn.Module): "feat": config.dyn_stoch + config.dyn_deter, }[self._config.disag_target] kw = dict( - inp_dim=feat_size + (config.num_actions - if config.disag_action_cond - else 0), # pytorch version + inp_dim=feat_size + + ( + config.num_actions if config.disag_action_cond else 0 + ), # pytorch version shape=size, layers=config.disag_layers, units=config.disag_units, diff --git a/requirements.txt b/requirements.txt index 930f0f5..bee6f73 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ setuptools==60.0.0 torch==2.0.0 torchvision==0.15.1 -tensorboard==2.10.0 pandas==1.2.4 matplotlib==3.5.0 ruamel.yaml==0.17.4 @@ -16,6 +15,7 @@ atari-py==0.2.9 crafter==1.8.0 opencv-python==4.7.0.72 numpy==1.21.0 +tensorboard # minerl==0.4.4 # This was needed for minerl # conda install -c conda-forge openjdk=8