@@ -5,9 +5,12 @@
```
from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs
# 初始化 Accelerator
-accelerator = Accelerator()
+kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)  # DDP option; presumably needed because some parameters get no gradient in a step - confirm against the model
+accelerator = Accelerator(kwargs_handlers=[kwargs])  # pass the DDP kwargs to Accelerator via kwargs_handlers
# 移除代码中所有的 to(device),会自动分配gpu,原有的model交给 prepare 处理
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
@@ -23,3 +26,4 @@ for epoch in range(10):
+