| """
train.py
Training script for Vision-Language-Action (VLA) Policies, built on top of pretrained VLMs, trained using mixtures of
the Open-X Embodiment dataset. Performs training in native PyTorch, using Fully-Sharded Data Parallel (FSDP) to run
distributed across GPUs (and nodes). By default, assumes that CUDA toolkit is >= 11.0 (to support BF16 mixed precision).
Notes & Prerequisites:
- If you want to set a custom location for all HF / TIMM artifacts --> `export HF_HOME="<PATH>"` *before* running!
=> For example (add to end of .bashrc): `export HF_HOME="/mnt/fsx/skaramcheti/cache"`
- If you want to suppress random Tensorflow logs --> `export TF_CPP_MIN_LOG_LEVEL=3`
Run with:
- [Single Node One-GPU (Debug)] : torchrun --standalone --nnodes 1 --nproc-per-node 1 vla-scripts/train.py
- [Single Node Multi-GPU (= $K)]: torchrun --standalone --nnodes 1 --nproc-per-node $K vla-scripts/train.py
"""
import json
import os
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Tuple, Union

import draccus
import torch
import torch.distributed as dist
import yaml

from prismatic.conf import VLAConfig, VLARegistry
from prismatic.models import load, load_vla
from prismatic.overwatch import initialize_overwatch
from prismatic.training import VLAMetrics, get_train_strategy
from prismatic.util import set_global_seed
from prismatic.vla import get_vla_dataset_and_collator
from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics
# 设置合理的默认值
os.environ["TOKENIZERS_PARALLELISM"] = "false" # 禁用分词器的并行处理
# 初始化Overwatch =>> 包装`logging.Logger`
overwatch = initialize_overwatch(__name__) # 初始化日志记录工具
@dataclass # 使用dataclass装饰器定义数据类
class TrainConfig:
# fmt: off
# VLAConfig (`prismatic/conf/vla.py`); override with --vla.type `VLARegistry.<VLA>.vla_id`
vla: VLAConfig = field(
default_factory=VLAConfig.get_choice_class(VLARegistry.DINOSIGLIP_224PX_MX_OXE_MAGIC_SOUP_PLUS.vla_id)
) # VLA配置,默认使用VLARegistry.DINOSIGLIP_224PX_MX_OXE_MAGIC_SOUP_PLUS.vla_id
# 目录路径
data_root_dir: Path = Path( # Open-X数据集目录的路径
"datasets/open-x-embodiment"
)
run_root_dir: Path = Path("runs") # 存储日志和检查点的目录路径
# 恢复运行参数
pretrained_checkpoint: Optional[Path] = None # 预训练检查点的绝对路径
is_resume: bool = True # 是否继续之前的训练
resume_step: Optional[int] = None # 恢复的全局步骤
resume_epoch: Optional[int] = None # 恢复的训练周期
# 运行参数
run_id: Optional[str] = None # 用于日志记录的运行ID
run_id_note: Optional[str] = None # 用于日志记录的额外注释
save_interval: int = 2500 # 保存检查点的间隔(以步骤为单位)
image_aug: bool = False # 是否启用图像增强
seed: int = 7 # 随机种子(用于可重复性)
# HF Hub 凭证(用于任何受限模型)
hf_token: Union[str, Path] = Path(".hf_token") # 环境变量或HF Token的路径
# 跟踪参数
trackers: Tuple[str, ...] = ("jsonl", "wandb") # 初始化的跟踪器
wandb_project: str = "openvla" # W&B项目名称
wandb_entity: str = "stanford-voltron" # W&B实体名称
def __post_init__(self) -> None:
"""提升优化参数的可用性,并验证`expected_world_size`"""
self.epochs = self.vla.epochs # 设置训练周期数
self.max_steps = self.vla.max_steps # 设置最大训练步骤数
self.global_batch_size = self.vla.global_batch_size # 设置全局批次大小
self.per_device_batch_size = self.vla.per_device_batch_size # 设置每个设备的批次大小
self.learning_rate = self.vla.learning_rate # 设置学习率
self.weight_decay = self.vla.weight_decay # 设置权重衰减
self.max_grad_norm = self.vla.max_grad_norm # 设置最大梯度范数
self.lr_scheduler_type = self.vla.lr_scheduler_type # 设置学习率调度器类型
self.warmup_ratio = self.vla.warmup_ratio # 设置预热比率
self.train_strategy = self.vla.train_strategy # 设置训练策略
# [验证] 断言`expected_world_size`
assert (
self.vla.expected_world_size == overwatch.world_size()
), f"Expected World Size = {self.vla.expected_world_size} but Found {overwatch.world_size()} GPUs!" # 验证期望的世界大小是否与实际一致
# fmt: on
@draccus.wrap() # 使用draccus.wrap装饰器定义训练函数
def train(cfg: TrainConfig) -> None:
overwatch.info("OpenVLA Training :: Warming Up") # 记录训练开始的信息
# 注意 => 在`torchrun`下初始化`overwatch`会自动设置`torch.distributed`
torch.cuda.set_device(device_id := overwatch.local_rank()) # 设置CUDA设备
torch.cuda.empty_cache() # 清空CUDA缓存
# 配置唯一的运行名称和保存目录
vla_id = cfg.vla.vla_id # 获取VLA ID
cfg.run_id = (
f"{vla_id}+n{cfg.vla.expected_world_size // 8}+b{cfg.per_device_batch_size}+x{cfg.seed}"
if cfg.run_id is None
else cfg.run_id
) # 如果运行ID为空,则生成唯一的运行ID
if cfg.run_id_note is not None:
cfg.run_id += f"--{cfg.run_id_note}" # 如果有运行ID注释,则添加到运行ID中
if cfg.image_aug:
cfg.run_id += "--image_aug" # 如果启用了图像增强,则添加到运行ID中
# 开始 =>> 创建目录并设置随机性
overwatch.info('"Do or do not; there is no try."', ctx_level=1) # 记录日志信息
hf_token = cfg.hf_token.read_text().strip() if isinstance(cfg.hf_token, Path) else os.environ[cfg.hf_token] # 读取HF Token
worker_init_fn = set_global_seed(cfg.seed, get_worker_init_fn=True) # 设置全局随机种子
os.makedirs(run_dir := (cfg.run_root_dir / cfg.run_id), exist_ok=True) # 创建运行目录
os.makedirs(cfg.run_root_dir / cfg.run_id / "checkpoints", exist_ok=True) # 创建检查点目录
# 保存配置 =>> 另外保存一个JSON版本以供以后HF集成
if overwatch.is_rank_zero():
draccus.dump(cfg, open(run_dir / "config.yaml", "w")) # 保存配置到YAML文件
with open(run_dir / "config.yaml", "r") as f_yaml, open(run_dir / "config.json", "w") as f_json:
yaml_cfg = yaml.safe_load(f_yaml)
json.dump(yaml_cfg, f_json, indent=2) # 保存配置到JSON文件
# 加载VLA检查点(如果从训练中恢复)或基础VLM(从`cfg.vla.base_vlm` ID或路径)
# =>> 注意::验证所有参数在加载时都以FP32加载!
overwatch.info(f"Loading Base VLM `{cfg.vla.base_vlm}` from ID/Path") # 记录日志信息
if cfg.pretrained_checkpoint is not None:
# [验证] 预训练检查点的`step`和`epoch`应与`resume_step`和`resume_epoch`匹配
# =>> 注意::我们要求开发人员传递`resume_*`参数作为额外的健全性检查!
if cfg.is_resume:
assert int(re.search("step-(.+?)-", cfg.pretrained_checkpoint.name).group(1)) == cfg.resume_step
assert int(re.search("epoch-(.+?)-", cfg.pretrained_checkpoint.name).group(1)) == cfg.resume_epoch
vlm = load_vla(cfg.pretrained_checkpoint, hf_token=hf_token, load_for_training=True) # 加载VLA检查点
else:
vlm = load(cfg.vla.base_vlm, hf_token=hf_token, load_for_training=True) # 加载基础VLM
# [验证] 模型应为全精度!
for param in vlm.parameters():
assert param.dtype == torch.float32, f"Loaded VLM parameter not in full precision: {param}" # 验证模型参数类型
# 根据冻结与未冻结的参数确定训练“阶段”-->支持不同的微调方案!
if not cfg.vla.freeze_vision_backbone and not cfg.vla.freeze_llm_backbone:
stage = "vla-full-train" # 完全微调
elif cfg.vla.freeze_vision_backbone and not cfg.vla.freeze_llm_backbone:
stage = "vla-train" # 冻结视觉编码器
elif not cfg.vla.freeze_vision_backbone and cfg.vla.freeze_llm_backbone:
assert cfg.vla.unfreeze_last_llm_layer, "You should unfreeze at least the last layer of your LLM!"
stage = "vla-sandwich-train" # 微调视觉编码器、投影器和LLM最后一层
elif cfg.vla.freeze_vision_backbone and cfg.vla.freeze_llm_backbone:
assert cfg.vla.unfreeze_last_llm_layer, "Need to unfreeze at least last LLM layer to train!"
stage = "vla-last-layer-train" # 仅微调LLM最后一层
else:
raise ValueError(
"Weight freezing configuration not supported. VLA config has the following parameters: "
f"freeze_vision_backbone: {cfg.vla.freeze_vision_backbone}"
f"freeze_llm_backbone: {cfg.vla.freeze_llm_backbone}"
f"unfreeze_last_llm_layer: {cfg.vla.unfreeze_last_llm_layer}"
) # 如果配置不支持,则引发错误
# [显式] 调用`freeze_backbones`以提高清晰度 =>> 将准确记录哪些被冻结
overwatch.info(f"Invoking `VLM.freeze_backbones()` for `{vla_id}` => Stage: `{stage}`") # 记录日志信息
vlm.freeze_backbones(stage) # 冻结模型参数
# 打印总参数和可训练参数的数量
num_params = sum(p.numel() for p in vlm.parameters())
num_trainable_params = sum(p.numel() for p in vlm.parameters() if p.requires_grad)
overwatch.info(
f"# Parameters (in millions): {num_params / 10**6:.3f} Total, {num_trainable_params / 10**6:.3f} Trainable"
) # 记录参数数量
# 获取VLA数据集和collator
overwatch.info(f"Creating VLA Open-X Dataset with Mixture `{cfg.vla.data_mix}`") # 记录日志信息
vla_dataset, action_tokenizer, collator = get_vla_dataset_and_collator(
cfg.data_root_dir,
cfg.vla.data_mix,
image_transform=vlm.vision_backbone.get_image_transform(),
tokenizer=vlm.llm_backbone.get_tokenizer(),
prompt_builder_fn=vlm.llm_backbone.prompt_builder_fn,
default_image_resolution=vlm.vision_backbone.default_image_resolution,
shuffle_buffer_size=cfg.vla.shuffle_buffer_size,
image_aug=cfg.image_aug,
) # 获取VLA数据集和collator
# 保存数据集统计信息以便在推理时去归一化
if overwatch.is_rank_zero():
save_dataset_statistics(vla_dataset.dataset_statistics, run_dir) # 保存数据集统计信息
# 创建训练策略
overwatch.info(f"Initializing Train Strategy `{cfg.train_strategy}`") # 记录日志信息
train_strategy = get_train_strategy(
train_strategy=cfg.train_strategy,
vlm=vlm,
device_id=device_id,
stage=stage,
epochs=cfg.epochs,
max_steps=cfg.max_steps,
global_batch_size=cfg.global_batch_size,
per_device_batch_size=cfg.per_device_batch_size,
learning_rate=cfg.learning_rate,
weight_decay=cfg.weight_decay,
max_grad_norm=cfg.max_grad_norm,
lr_scheduler_type=cfg.lr_scheduler_type,
warmup_ratio=cfg.warmup_ratio,
enable_gradient_checkpointing=cfg.vla.enable_gradient_checkpointing,
enable_mixed_precision_training=cfg.vla.enable_mixed_precision_training,
reduce_in_full_precision=cfg.vla.reduce_in_full_precision,
worker_init_fn=worker_init_fn,
) # 初始化训练策略
train_strategy.run_setup(run_dir=run_dir, n_train_examples=len(vla_dataset)) # 设置训练策略
# 创建度量工具 =>> 动态跟踪,记录到指定的跟踪器(例如JSONL,Weights & Biases)
overwatch.info(f"Creating Metrics with Active Trackers => `{cfg.trackers}`") # 记录日志信息
metrics = VLAMetrics(
cfg.trackers,
cfg.run_id,
run_dir,
draccus.encode(cfg),
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
resume_step=cfg.resume_step,
resume_epoch=cfg.resume_epoch,
) # 创建度量工具
# 运行VLA训练
overwatch.info("Starting VLA Training Loop") # 记录日志信息
train_strategy.run_vla_training(
vla_dataset,
collator,
action_tokenizer,
metrics,
save_interval=cfg.save_interval,
) # 运行VLA训练
# 完成
overwatch.info("Done with Training =>> Finalizing Metrics") # 记录日志信息
metrics.finalize() # 完成度量工具
# 完成所有操作
overwatch.info("... and that's all, folks!") # 记录日志信息
dist.barrier() # 同步所有进程
dist.destroy_process_group() # 销毁进程组
if __name__ == "__main__":
train() # 如果是主模块,则运行训练函数