

# accelerate

Built on PyTorch, accelerate makes distributed training and mixed-precision training easier, improving model training efficiency. Upstream repository: https://github.com/huggingface/accelerate

## Usage

```python
import torch
import torch.nn.functional as F

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs

logger = get_logger(__name__)  # process-aware logger for distributed runs

# Initialize the Accelerator. find_unused_parameters=True lets DDP tolerate
# parameters that receive no gradient in a given forward pass.
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[kwargs])

# Your usual model, optimizer, and dataloader (toy stand-ins here).
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 10),
                                         torch.randint(0, 2, (64,)))
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)

# Remove every explicit .to(device) call from your code: prepare() places the
# model, optimizer, and dataloader on the right device(s) automatically.
model, optimizer, train_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader
)

# Training loop
for epoch in range(10):
    for source, targets in train_dataloader:
        optimizer.zero_grad()
        output = model(source)
        loss = F.cross_entropy(output, targets)
        # accelerator.backward() replaces loss.backward() so gradient scaling
        # and cross-process synchronization are handled for you.
        accelerator.backward(loss)
        optimizer.step()
```
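
The tagline also promises mixed-precision training. A minimal sketch of enabling it, assuming a recent accelerate release where `Accelerator` accepts a `mixed_precision` argument; everything else in the training loop above stays the same:

```python
from accelerate import Accelerator

# Request fp16 mixed precision ("bf16" is also accepted on supported
# hardware, and "no" disables it). The forward pass then runs under
# autocast and accelerator.backward() applies loss scaling, so the
# training loop needs no other changes.
accelerator = Accelerator(mixed_precision="fp16")
```

For multi-GPU or multi-node runs the script is likewise unchanged; it is started through the accelerate CLI, typically `accelerate config` once to describe the hardware, then `accelerate launch train.py` (where `train.py` stands in for your own training script).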