Pytorch中的分布式神经网络训练 pytorch 训练神经网络
itomcoil 2024-12-15 13:58 27 浏览
随着深度学习的多项进步,复杂的网络(例如大型transformer 网络,更广更深的Resnet等)已经发展起来,从而需要了更大的内存空间。 经常,在训练这些网络时,深度学习从业人员需要使用多个GPU来有效地训练它们。 在本文中,我将向您介绍如何使用PyTorch在GPU集群上设置分布式神经网络训练。
通常,分布式训练会在有一下两种情况。
1. 在GPU之间拆分模型:如果模型太大而无法容纳在单个GPU的内存中,则需要在不同GPU之间拆分模型的各个部分。
1. 跨GPU进行批量拆分数据。当mini-batch太大而无法容纳在单个GPU的内存中时,您需要将mini-batch拆分到不同的GPU上。
跨GPU的模型拆分
跨GPU拆分模型非常简单,不需要太多代码更改。 在设置网络本身时,可以将模型的某些部分移至特定的GPU。 之后,在通过网络转发数据时,数据也需要移动到相应的GPU。 下面是执行相同操作的PyTorch代码段。
from torch import nn
class Network(nn.Module):
def __init__(self, split_gpus=False):
super().__init__()
self.module1 = ...
self.module2 = ...
self.split_gpus = split_gpus
if split_gpus: #considering only two gpus
self.module1.cuda(0)
self.module2.cuda(1)
def forward(self, x):
if self.split_gpus:
x = x.cuda(0)
x = self.module1(x)
if self.split_gpus:
x = x.cuda(1)
x = self.module2(x)
return x
跨GPU的数据拆分
有3种在GPU之间拆分批处理的方法。
· 积累梯度
· 使用nn.DataParallel
· 使用nn.DistributedDataParallel
积累梯度
在GPU之间拆分批次的最简单方法是累积梯度。 假设我们要训练的批处理大小为256,但是一个GPU内存只能容纳32个批处理大小。 我们可以执行8(= 256/32)个梯度下降迭代而无需执行优化步骤,并继续通过loss.backward()步骤添加计算出的梯度。 一旦我们累积了256个数据点的梯度,就执行优化步骤,即调用optimizer.step()。 以下是用于实现累积渐变的PyTorch代码段。
TARGET_BATCH_SIZE, BATCH_FIT_IN_MEMORY = 256, 32
accumulation_steps = int(TARGET_BATCH_SIZE / BATCH_FIT_IN_MEMORY)
network.zero_grad() # Reset gradients tensors
for i, (imgs, labels) in enumerate(dataloader):
preds = network(imgs) # Forward pass
loss = loss_function(preds, labels) # Compute loss function
loss = loss / accumulation_steps # Normalize our loss (if averaged)
loss.backward() # Backward pass
if (i+1) % accumulation_steps == 0: # Wait for several backward steps
optim.step() # Perform an optimizer step
network.zero_grad() # Reset gradients tensors
优点: 不需要多个GPU即可进行大批量训练。 即使使用单个GPU,此方法也可以进行大批量训练。
缺点: 比在多个GPU上并行训练要花费更多的时间。
使用nn.DataParallel
如果您可以访问多个GPU,则将不同的批处理拆分分配给不同的GPU,在不同的GPU上进行梯度计算,然后累积梯度以执行梯度下降是很有意义的。
多GPU下的forward和backward
基本上,给定的输入通过在批处理维度中分块在GPU之间进行分配。 在前向传递中,模型在每个设备上复制,每个副本处理批次的一部分。 在向后传递过程中,将每个副本的梯度求和以生成最终的梯度,并将其应用于主gpu(上图中的GPU-1)以更新模型权重。 在下一次迭代中,主GPU上的更新模型将再次复制到每个GPU设备上。
在PyTorch中,只需要一行就可以使用nn.DataParallel进行分布式训练。 该模型只需要包装在nn.DataParallel中。
model = torch.nn.DataParallel(model)
...
...
loss = ...
loss.backward()
优点:并行化多个GPU上的NN训练,因此与累积梯度相比,它减少了训练时间。因为代码更改很少,所以适合快速原型制作。
缺点:nn.DataParallel使用单进程多线程方法在不同的GPU上训练相同的模型。 它将主进程保留在一个GPU上,并在其他GPU上运行不同的线程。 由于python中的线程存在GIL(全局解释器锁定)问题,因此这限制了完全并行的分布式训练设置。
使用DistributedDataParallel
与nn.DataParallel不同,DistributedDataParallel在GPU上生成单独的进程进行多重处理,并利用GPU之间通信实现的完全并行性。但是,设置DistributedDataParallel管道比nn.DataParallel更复杂,需要执行以下步骤(但不一定按此顺序)。
将模型包装在torch.nn.Parallel.DistributedDataParallel中。
设置数据加载器以使用distributedSampler在所有GPU之间高效地分配样本。 Pytorch为此提供了torch.utils.data.Distributed.DistributedSampler。设置分布式后端以管理GPU的同步。 torch.distributed.initprocessgroup(backend ='nccl')。
pytorch提供了用于分布式通讯后端(nccl,gloo,mpi,tcp)。根据经验,一般情况下使用nccl可以通过GPU进行分布式训练,而使用gloo可以通过CPU进行分布式训练。在此处了解有关它们的更多信息https://pytorch.org/tutorials/intermediate/dist_tuto.html#advanced-topics
在每个GPU上启动单独的进程。同样使用torch.distributed.launch实用程序功能。假设我们在群集节点上有4个GPU,我们希望在这些GPU上用于设置分布式培训。可以使用以下shell命令来执行此操作。
python -m torch.distributed.launch --nproc_per_node=4
--nnodes=1 --node_rank=0
--master_port=1234 train.py <OTHER TRAINING ARGS>
在设置启动脚本时,我们必须在将运行主进程并用于与其他GPU通信的节点上提供一个空闲端口(在这种情况下为1234)。
以下是涵盖所有步骤的完整PyTorch要点。
import argparse
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
#prase the local_rank argument from command line for the current process
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()
#setup the distributed backend for managing the distributed training
torch.distributed.init_process_group('nccl')
#Setup the distributed sampler to split the dataset to each GPU.
dist_sampler = DistributedSampler(dataset)
dataloader = DataLoader(dataset, sampler=dist_sampler)
#set the cuda device to a GPU allocated to current process .
device = torch.device('cuda', args.local_rank)
model = model.to(device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank)
#Start training the model normally.
for inputs, labels in dataloader:
inputs = inputs.to(device)
labels = labels.to(device)
preds = model(inputs)
loss = loss_fn(preds, labels)
loss.backward()
optimizer.step()
请注意,上述实用程序调用是针对GPU集群上的单个节点的。 此外,如果要使用多节点设置,则必须在选择启动实用程序时选择一个节点作为主节点,并提供master_addr参数,如下所示。 假设我们有2个节点,每个节点有4个GPU,第一个IP地址为" 192.168.1.1"的节点是主节点。 我们必须分别在每个节点上启动启动脚本,如下所示。
在第一个节点上运行
python -m torch.distributed.launch --nproc_per_node=4
--nnodes=1 --node_rank=0
--master_addr="192.168.1.1" --master_port=1234 train.py <OTHER TRAINING ARGS>
在第二个节点上,运行
python -m torch.distributed.launch --nproc_per_node=4
--nnodes=1 --node_rank=1
--master_addr="192.168.1.1" --master_port=1234 train.py <OTHER TRAINING ARGS>
其他实用程序功能:
在评估模型或生成日志时,需要从所有GPU收集当前批次统计信息,例如损失,准确率等,并将它们在一台机器上进行整理以进行日志记录。 PyTorch提供了以下方法,用于在所有GPU之间同步变量。
1. torch.distributed.gather(inputtensor,collectlist,dst):从所有设备收集指定的inputtensor并将它们放置在collectlist中的dst设备上。
1. torch.distributed.allgather(tensorlist,inputtensor):从所有设备收集指定的inputtensor并将其放置在所有设备上的tensor_list变量中。
1. torch.distributed.reduce(inputtensor,dst,reduceop = ReduceOp.SUM):收集所有设备的input_tensor并使用指定的reduce操作(例如求和,均值等)进行缩减。最终结果放置在dst设备上。
1. torch.distributed.allreduce(inputtensor,reduce_op = ReduceOp.SUM):与reduce操作相同,但最终结果被复制到所有设备。
有关参数和方法的更多详细信息,请阅读torch.distributed软件包。 https://pytorch.org/docs/stable/distributed.html
例如,以下代码从所有GPU提取损失值,并将其减少到主设备(cuda:0)。
#In continuation with distributedDataParallel.py abovedef get_reduced_loss(loss, dest_device):
loss_tensor = loss.clone()
torch.distributed.reduce(loss_tensor, dst=dest_device)
return loss_tensorif args.local_rank==0:
loss_tensor = get_reduced_loss(loss.detach(), 0)
print(f'Current batch Loss = {loss_tensor.item()}'
优点:相同的代码设置可用于单个GPU,而无需任何代码更改。 单个GPU设置仅需要具有适当设置的启动脚本。
缺点: BatchNorm之类的层在其计算中使用了整个批次统计信息,因此无法仅使用一部分批次在每个GPU上独立进行操作。 PyTorch提供SyncBatchNorm作为BatchNorm的替换/包装模块,该模块使用跨GPU划分的整个批次计算批次统计信息。 请参阅下面的示例代码以了解SyncBatchNorm的用法。
network = .... #some network with BatchNorm layers in itsync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network)
ddp_network = nn.parallel.DistributedDataParallel(
sync_bn_network,
device_ids=[args.local_rank], output_device=args.local_rank)
总结
· 要在GPU之间拆分模型,请将模型拆分为submodules,然后将每个submodule推送到单独的GPU。
· 要在GPU上拆分批次,请使用累积梯度nn.DataParallel或nn.DistributedDataParallel。
· 为了快速进行原型制作,可以首选nn.DataParallel。
· 为了训练大型模型并利用跨多个GPU的完全并行训练,应使用nn.DistributedDataParallel。
· 在使用nn.DistributedDataParallel时,用nn.SyncBatchNorm替换或包装nn.BatchNorm层。
作者:Nilesh Vijayrania
deephub翻译组
相关推荐
- Excel新函数TEXTSPLIT太强大了,轻松搞定数据拆分!
-
我是【桃大喵学习记】,欢迎大家关注哟~,每天为你分享职场办公软件使用技巧干货!最近我把WPS软件升级到了版本号:12.1.0.15990的最新版本,最版本已经支持文本拆分函数TEXTSPLIT了,并...
- Excel超强数据拆分函数TEXTSPLIT,从入门到精通!
-
我是【桃大喵学习记】,欢迎大家关注哟~,每天为你分享职场办公软件使用技巧干货!今天跟大家分享的是Excel超强数据拆分函数TEXTSPLIT,带你从入门到精通!TEXTSPLIT函数真是太强大了,轻松...
- 看完就会用的C++17特性总结(c++11常用新特性)
-
作者:taoklin,腾讯WXG后台开发一、简单特性1.namespace嵌套C++17使我们可以更加简洁使用命名空间:2.std::variant升级版的C语言Union在C++17之前,通...
- plsql字符串分割浅谈(plsql字符集设置)
-
工作之中遇到的小问题,在此抛出问题,并给出解决方法。一方面是为了给自己留下深刻印象,另一方面给遇到相似问题的同学一个解决思路。如若其中有写的不好或者不对的地方也请不加不吝赐教,集思广益,共同进步。遇到...
- javascript如何分割字符串(javascript切割字符串)
-
javascript如何分割字符串在JavaScript中,您可以使用字符串的`split()`方法来将一个字符串分割成一个数组。`split()`方法接收一个参数,这个参数指定了分割字符串的方式。如...
- TextSplit函数的使用方法(入门+进阶+高级共八种用法10个公式)
-
在Excel和WPS新增的几十个函数中,如果按实用性+功能性排名,textsplit排第二,无函数敢排第一。因为它不仅使用简单,而且解决了以前用超复杂公式才能搞定的难题。今天小编用10个公式,让你彻底...
- Python字符串split()方法使用技巧
-
在Python中,字符串操作可谓是基础且关键的技能,而今天咱们要重点攻克的“堡垒”——split()方法,它能将看似浑然一体的字符串,按照我们的需求进行拆分,极大地便利了数据处理与文本解析工作。基本语...
- go语言中字符串常用的系统函数(golang 字符串)
-
最近由于工作比较忙,视频有段时间没有更新了,在这里跟大家说声抱歉了,我尽快抽些时间整理下视频今天就发一篇关于go语言的基础知识吧!我这我工作中用到的一些常用函数,汇总出来分享给大家,希望对...
- 无规律文本拆分,这些函数你得会(没有分隔符没规律数据拆分)
-
今天文章来源于表格学员训练营群内答疑,混合文本拆分。其实拆分不难,只要规则明确就好办。就怕规则不清晰,或者规则太多。那真是,Oh,mygod.如上图所示进行拆分,文字表达实在是有点难,所以小熊变身灵...
- Python之文本解析:字符串格式化的逆操作?
-
引言前面的文章中,提到了关于Python中字符串中的相关操作,更多地涉及到了字符串的格式化,有些地方也称为字符串插值操作,本质上,就是把多个字符串拼接在一起,以固定的格式呈现。关于字符串的操作,其实还...
- 忘记【分列】吧,TEXTSPLIT拆分文本好用100倍
-
函数TEXTSPLIT的作用是:按分隔符将字符串拆分为行或列。仅ExcelM365版本可用。基本应用将A2单元格内容按逗号拆分。=TEXTSPLIT(A2,",")第二参数设置为逗号...
- Excel365版本新函数TEXTSPLIT,专攻文本拆分
-
Excel中字符串的处理,拆分和合并是比较常见的需求。合并,当前最好用的函数非TEXTJOIN不可。拆分,Office365于2022年3月更新了一个专业函数:TEXTSPLIT语法参数:【...
- 站长在线Python精讲使用正则表达式的split()方法分割字符串详解
-
欢迎你来到站长在线的站长学堂学习Python知识,本文学习的是《在Python中使用正则表达式的split()方法分割字符串详解》。使用正则表达式分割字符串在Python中使用正则表达式的split(...
- Java中字符串分割的方法(java字符串切割方法)
-
技术背景在Java编程中,经常需要对字符串进行分割操作,例如将一个包含多个信息的字符串按照特定的分隔符拆分成多个子字符串。常见的应用场景包括解析CSV文件、处理网络请求参数等。实现步骤1.使用Str...
- 因为一个函数strtok踩坑,我被老工程师无情嘲笑了
-
在用C/C++实现字符串切割中,strtok函数经常用到,其主要作用是按照给定的字符集分隔字符串,并返回各子字符串。但是实际上,可不止有strtok(),还有strtok、strtok_s、strto...
- 一周热门
- 最近发表
- 标签列表
-
- ps像素和厘米换算 (32)
- ps图案在哪里 (33)
- super().__init__ (33)
- python 获取日期 (34)
- 0xa (36)
- super().__init__()详解 (33)
- python安装包在哪里找 (33)
- linux查看python版本信息 (35)
- python怎么改成中文 (35)
- php文件怎么在浏览器运行 (33)
- eval在python中的意思 (33)
- python安装opencv库 (35)
- python div (34)
- sticky css (33)
- python中random.randint()函数 (34)
- python去掉字符串中的指定字符 (33)
- python入门经典100题 (34)
- anaconda安装路径 (34)
- yield和return的区别 (33)
- 1到10的阶乘之和是多少 (35)
- python安装sklearn库 (33)
- dom和bom区别 (33)
- js 替换指定位置的字符 (33)
- python判断元素是否存在 (33)
- sorted key (33)