This repository has been archived by the owner on Dec 15, 2023. It is now read-only.
generated from minerllabs/basalt_2022_competition_submission_template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
xirl_config.py
102 lines (74 loc) · 2.98 KB
/
xirl_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import minerl
from dataclasses import dataclass, field
import os
from typing import List
import gym
# Action space read from a MakeWaterfall environment created at import time;
# XIRLConfig exposes it as its `action_space` class attribute.
# NOTE(review): presumably `import minerl` is what registers the MineRLBasalt*
# ids with gym, and the waterfall action space is valid for every task listed
# below -- confirm.
ACTION_SPACE = gym.make("MineRLBasaltMakeWaterfall-v0").action_space
# Known tasks. XIRLConfig.enabled_tasks refers to a task by its index in this
# list. "dataset_dir" is joined onto MINERL_DATA_ROOT to locate the task's
# dataset and is also used as the key of the environment-id -> logit mapping.
TASKS = [
{"name": "BuildHouse", "dataset_dir": "MineRLBasaltBuildVillageHouse-v0"},
{"name": "AnimalPen", "dataset_dir": "MineRLBasaltCreateVillageAnimalPen-v0"},
{"name": "FindCave", "dataset_dir": "MineRLBasaltFindCave-v0"},
{"name": "BuildWaterfall", "dataset_dir": "MineRLBasaltMakeWaterfall-v0"},
]
# Root data directory; taken from the MINERL_DATA_ROOT environment variable,
# falling back to "data/" when the variable is unset.
MINERL_DATA_ROOT = os.getenv("MINERL_DATA_ROOT", "data/")
# Directory under the data root that the VPT model and weights paths are
# resolved against.
VPT_MODELS_ROOT = os.path.join(MINERL_DATA_ROOT, "VPT-models/")
@dataclass
class XIRLConfig:
    """Settings for the XIRL setup.

    Holds the VPT model/weights file names, the list of enabled tasks, and
    the training and FMC hyperparameters as dataclass fields. File-system
    paths and discriminator logit indices are derived from those fields by
    the read-only properties below.
    """

    # File names are resolved relative to VPT_MODELS_ROOT (see `model_path`
    # and `weights_path`).
    # Alternative models: "foundation-model-1x.model", "foundation-model-3x.model".
    model_filename: str = "foundation-model-2x.model"
    # Alternative weights: "foundation-model-1x.weights".
    weights_filename: str = "rl-from-early-game-2x.weights"

    # Enable tasks according to the indices of the `TASKS` list
    # (0="BuildHouse", etc.). The ordering of these tasks determines the
    # ordering of the responsibility of the discriminator logits.
    # Use [0, 1, 2, 3] to enable all 4 tasks.
    enabled_tasks: List[int] = field(default_factory=lambda: [3])  # waterfall only.

    # When this is true, FMC will not be run and the discriminator will only
    # try to discriminate between the enabled tasks.
    disable_fmc_detection: bool = False

    verbose: bool = False
    force_cpu: bool = False

    num_frames_per_pair: int = 128
    num_frames_per_trajectory_to_load: int = 128
    temperature: float = 0.1
    data_workers: int = 4

    batch_size: int = 32
    embed_batch_size: int = 64
    learning_rate: float = 0.00008
    consistency_loss_coeff: float = 0.0  # if 0, consistency loss is ignored.

    num_walkers: int = 128
    fmc_steps: int = 16
    unroll_steps: int = 4
    fmc_random_policy: bool = False

    use_wandb: bool = False

    # Deliberately left without a type annotation: this makes it a plain class
    # attribute rather than a dataclass field, so it is neither an __init__
    # argument nor part of the dict returned by `asdict()`.
    action_space = ACTION_SPACE

    @property
    def fmc_logit(self) -> int:
        """Index of the logit dedicated to FMC.

        Task logits occupy indices ``0 .. len(enabled_tasks) - 1``; the FMC
        logit is the one right after them.
        """
        # NOTE(review): with disable_fmc_detection=True this value equals
        # num_discriminator_classes, i.e. one past the last class index.
        # Presumably callers do not use it in that mode -- confirm.
        return len(self.enabled_tasks)

    @property
    def num_discriminator_classes(self) -> int:
        """Number of class labels: one per enabled task, plus one for FMC
        unless FMC detection is disabled."""
        task_count = len(self.enabled_tasks)
        if self.disable_fmc_detection:
            return task_count
        # plus 1 because FMC is a class label
        return task_count + 1

    @property
    def model_path(self) -> str:
        """Path of the model file inside VPT_MODELS_ROOT."""
        return os.path.join(VPT_MODELS_ROOT, self.model_filename)

    @property
    def weights_path(self) -> str:
        """Path of the weights file inside VPT_MODELS_ROOT."""
        return os.path.join(VPT_MODELS_ROOT, self.weights_filename)

    @property
    def dataset_paths(self) -> List[str]:
        """Dataset directories of the enabled tasks, in `enabled_tasks` order."""
        return [
            os.path.join(MINERL_DATA_ROOT, TASKS[task_id]["dataset_dir"])
            for task_id in self.enabled_tasks
        ]

    @property
    def environment_id_to_task_logit(self) -> dict:
        """Map each enabled task's "dataset_dir" string to its logit index.

        The logit index is the task's position within `enabled_tasks`, not
        its index in `TASKS`.
        """
        return {
            TASKS[task_id]["dataset_dir"]: logit
            for logit, task_id in enumerate(self.enabled_tasks)
        }

    def asdict(self) -> dict:
        """Return the instance attributes as a new plain dict.

        The `enabled_tasks` list is copied so that mutating the returned
        dict cannot change this config. Class attributes such as
        `action_space` are not included.
        """
        return {**self.__dict__, "enabled_tasks": list(self.enabled_tasks)}