介绍
本文介绍了PyTorch中单机多卡的训练方法。
背景
单机多卡基本上是调参侠最为常见的炼丹方式,但是当前关于如何进行单机多卡训练的介绍文档并不多,而且有些文档严重过时,例如PyTorch官网的一份文档的示例代码还停留在PyTorch0.3版本。
Naive方法
比较Naive的方法是只用DistributedDataParallel
包住model,但在最新的PyTorch版本中,会有警告提示,如下所示:
/home/lab/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/distributed.py:364: UserWarning: Single-Process Multi-GPU is not the recommended mode for DDP. In this mode, each DDP instance operates on multiple devices and creates multiple module replicas within one process. The overhead of scatter/gather and GIL contention in every forward pass can slow down training. Please consider using one DDP instance per device or per module replica by explicitly setting device_ids or CUDA_VISIBLE_DEVICES.
显然,这样做会降低性能。
本文主要参考资料
本文代码主要来源于detectron2和SlowFast,在其实现的多机多卡代码中,删去了多机代码,仅保留了多卡代码。
本文还参考了PyTorch官方文档ddp_tutorial。
实现思想
现在官方推荐的实现方式为在各个device,即GPU上,各创建一个进程,使用DistributedDataParallel
来包裹模型。因此需要注意的地方有
- 基本初始化代码
- 分配数据
- 数据在多卡之间同步
- 抑制输出
基本初始化代码
基本初始化代码推荐参阅上文给出的链接,即detectron2和SlowFast的实现方式,官方文档的实现方式不够优雅,需要手动管理多个进程,不建议采用。具体代码见下文。
分配数据
因为有多个进程,而且dataloader之间都是独立的,因此如果不统一安排,则在一个epoch内,必然会造成多次训练同一个数据。而且如果是测试集,则必须在每个卡上均测试一整个epoch。因此需要使用DistributedSampler
来统一管理数据分配。
数据在多卡间同步
因为是多卡训练,因此数据在多个卡之间必然是不同的。DistributedDataParallel
会自动帮我们同步模型参数,但是有些数据需要手动同步,主要是需要保存在日志中的数据,例如loss,在各个卡中,loss是不相同的,因此如果我们要保存loss,最好取各张卡中loss的均值。推荐阅读dist_tuto以了解各个API。例如dist.all_reduce(tensor)
可以将各个卡中名为tensor
的变量相加,并输出到各个卡中。
抑制输出
在训练过程中,我们显然会用到多种输出,例如保存日志,保存模型以及tensorboard的输出。
对于保存模型,可以在每次需要保存时检查是否是主进程来实现。对于tensorboard,可以在非主进程中不创建SummaryWriter。对于保存日志,可以通过修改内置print函数,使其无任何输出,从而避免多次输出。
示例代码
#!/usr/bin/env python
import builtins
import logging
import torch as t
import torch.distributed as dist
from torch.nn.functional import mse_loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision.models import resnet50
from torch.optim import SGD
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from tqdm import tqdm
# 创建Dataset时,尽可能减小CPU负担,这样测试性能时可以保证是GPU的计算能力/通信为主要瓶颈。
class DummyDataset(Dataset):
def __getitem__(self, item):
img = t.rand(3, 224, 224)
label = int(img.mean() < 0.5)
return img, label
def __len__(self):
return 1000
# 判断是否为主进程
def is_master_proc(num_gpus=8):
if t.distributed.is_initialized():
return dist.get_rank() % num_gpus == 0
else:
return True
def setup_logging():
if is_master_proc():
logging.basicConfig(filename='test.log', filemode='w', level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler())
logging.info('logging started.')
else:
def print_none(*args, **kwargs):
pass
# 将内置print函数变为一个空函数,从而使非主进程的进程不会输出。
builtins.print = print_none
def train(cfg):
setup_logging()
dataset = DummyDataset()
sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, cfg.batch_size, sampler=sampler, num_workers=4, pin_memory=True)
# 因为每个进程初始化时已经指定了GPU,所以此处不需要明确指出需要将模型迁移到哪个GPU。
model = resnet50(num_classes=1).cuda()
cur_device = t.cuda.current_device()
if cfg.num_gpus > 1:
# 使用DDP包裹模型,可以自动帮我们同步参数。
model = DDP(model, device_ids=[cur_device], output_device=cur_device)
optim = SGD(model.parameters(), 0.001)
# 如果用到tqdm,需要在非主进程的进程进行抑制。
for i, (data, label) in tqdm(enumerate(dataloader), disable=not is_master_proc()):
data = data.to(cur_device)
label = label.to(t.float32).reshape(-1, 1).to(cur_device)
pred = model(data)
optim.zero_grad()
loss = mse_loss(pred, label)
loss.backward()
optim.step()
logging.info('finish train for one epoch.')
logging.info(f'last loss before reduce:{loss.item()}')
dist.all_reduce(loss)
# 可以看到,单卡loss和其他卡上的loss不同,如果涉及到计算准确率等,需要先同步其他卡的结果,然后进行统计。
logging.info(f'last loss after reduce:{loss.item()}')
def run(local_rank, func, cfg):
t.distributed.init_process_group(
backend='nccl',
init_method='tcp://127.0.0.1:9999',
world_size=cfg.num_gpus,
rank=local_rank,
)
t.cuda.set_device(local_rank)
func(cfg)
def launch_job(cfg, func, daemon=False):
if cfg.num_gpus > 1:
t.multiprocessing.spawn(
run,
nprocs=cfg.num_gpus,
args=(
func,
cfg,
),
daemon=daemon,
)
else:
func(cfg)
class Cfg:
def __init__(self):
self.num_gpus = 2
self.batch_size = 12
if __name__ == '__main__':
cfg = Cfg()
launch_job(cfg, train)