|
@@ -5,9 +5,12 @@
|
|
|
|
|
|
```
|
|
|
from accelerate import Accelerator
|
|
|
+from accelerate.logging import get_logger
|
|
|
+from accelerate.utils import DistributedDataParallelKwargs
|
|
|
|
|
|
# 初始化 Accelerator
|
|
|
-accelerator = Accelerator()
|
|
|
+kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
|
|
+accelerator = Accelerator(kwargs_handlers=[kwargs])
|
|
|
|
|
|
# 移除代码中所有的 to(device) 调用,Accelerator 会自动分配 GPU;原有的 model、optimizer、dataloader 统一交给 prepare 处理
|
|
|
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
|
|
@@ -23,3 +26,4 @@ for epoch in range(10):
|
|
|
|
|
|
```
|
|
|
|
|
|
+
|