百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术文章 > 正文

pytorch的一个最简单的cpp扩展_pytorch cdist

itomcoil 2025-02-20 15:56 26 浏览

pytorch的一个最最简单cpp扩展

首先为什么要进行pytorch的cpp扩展?

1. 性能提升:虽然 Python 作为 PyTorch 的主要开发语言,提供了极高的开发效率和便利性,但是 Python 解释器在执行效率上相比编译型语言如 C++ 有所不足。通过 C++ 扩展,可以实现对性能敏感的代码段进行优化,直接利用 C++ 的高速运行能力,减少解释器的开销,从而提升整体应用的运行效率。

2. 集成自定义算法:对于一些复杂的自定义算法或模型组件,可能已有现成的高性能 C++ 实现,或者需要直接操作硬件(如GPU通过CUDA编程)以达到最佳性能。C++ 扩展允许直接在 PyTorch 中复用这些现有代码,无需完全重写为 Python。

3. 灵活性与定制化:C++ 扩展提供了更低层次的访问权限,开发者可以直接操作 PyTorch 的内部数据结构(如 Tensor),进行更精细的内存管理和计算优化,实现更灵活的模型定制。

4. 跨平台兼容性:C++ 是一种广泛支持的系统级编程语言,通过 C++ 扩展,可以更容易地在不同平台上(包括服务器和嵌入式设备)部署 PyTorch 模型,同时利用 C++ 的标准库和第三方库来处理平台相关的任务。

5. 与现有C/C++项目集成:对于那些已经有大量C++代码基础的项目,通过C++扩展,可以平滑地将PyTorch集成到现有的系统中,重用现有代码,减少开发工作量。

那么对pytorch进行cpp扩展的步骤包含哪些?

1. 安装依赖

首先确保你的系统中已安装了 PyTorch。此外,需要安装 pybind11,这是一个用于绑定 C++ 和 Python 的库,使得 C++ 代码能够被 Python 调用。可以通过 pip 或 conda 安装 pybind11

2. 编写 C++/CUDA 代码

根据需求编写 C++ 代码实现你的功能,这可能包括定义类或函数,尤其是实现前向传播 (forward) 和反向传播 (backward) 函数,如果是 GPU 相关的计算,还需要编写 CUDA 代码。

使用 pybind11 在 C++ 代码中创建 Python 绑定,这样 Python 就能调用这些 C++ 函数或类。

3. 编写 setup.py

创建一个 setup.py 文件,这个文件指导 Python 如何编译你的 C++ 代码,并将其作为 Python 包安装。你需要在 setup.py 中指定你的源文件、编译选项以及依赖关系等。

4. 编译与安装

运行 python setup.py install 或类似的命令来编译你的 C++ 代码,并将其作为 Python 模块安装到当前环境。这一步骤会使用 setuptools 或类似工具自动完成编译和链接过程。

5. 测试与使用

编译安装完成后,你可以在 Python 脚本中通过常规的 import 语句来导入你的 C++ 扩展,并像使用任何其他 Python 模块一样使用它。

我们来做一个最最简单的cpp函数扩展


1. c++代码


#include


torch::Tensor custom_op(torch::Tensor a,torch::Tensor b) {

auto s = torch::max(a);

return s;

}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

m.def("custom_op", &custom_op);

}


首先代码包含了PyTorch提供的扩展头文件,torch/extension.h。这个头文件内部包含了对pybind11以及其他必要的PyTorch API的引用,使得开发者能够轻松地将C++代码与PyTorch集成。

然后函数定义了一个名为custom_op的操作,它接受两个torch::Tensor类型的参数(尽管第二个参数b在这个示例中未被使用)。函数内部,它调用了torch::max()函数来计算输入张量a的最大值,并返回这个最大值。注意,实际上这个操作并没有利用到b,可能是因为示例简化的原因。

最后使用了pybind11库来将C++函数custom_op绑定到Python。PYBIND11_MODULE是一个宏,用于定义模块的入口点。TORCH_EXTENSION_NAME是一个宏,它在编译时会被替换为实际的模块名称,确保模块名称的一致性和唯一性。m.def("custom_op", &custom_op);这一行代码将C++函数custom_op注册为Python模块中的一个函数,使得Python代码可以通过这个名称调用这个C++实现的函数。


2. setup.py

from setuptools import setup, Extension

from torch.utils import cpp_extension


setup(name='custom_op',

ext_modules=[cpp_extension.CppExtension('custom_op', ['custon_op.cpp'])],

cmdclass={'build_ext': cpp_extension.BuildExtension})


setuptools 是Python用来创建和打包Python项目的库,它扩展了distutils库的功能,提供了更多的灵活性和特性。setupExtension 类便是来自此库,分别用于定义项目元数据和C扩展模块。

cpp_extension 是PyTorch提供的工具,用于简化C++扩展的构建过程,特别是当这些扩展需要与PyTorch交互时(例如,使用CUDA或Torch API)。

setup() 函数是 setuptools 的核心,用于定义项目的元数据和构建指令。这里,name='custom_op' 指定了包的名称,即将创建的Python包的名字为custom_op

ext_modules 参数告诉 setuptools 需要构建的扩展模块列表。这里只有一个扩展,即使用 cpp_extension.CppExtension 定义的。

'custom_op' 是扩展模块的名称,它应该与你希望在Python中导入的模块名相匹配。

['custom_op.cpp'] 是源文件列表,指明了C++源代码文件的名称,这里假设有一个名为custom_op.cpp的文件,它包含了C++实现的逻辑。

cmdclass 参数允许覆盖或扩展默认的构建命令。这里,我们指定了 build_ext 命令应使用 cpp_extension.BuildExtension 类来执行。这个类是cpp_extension模块提供的,专为处理包含PyTorch C++扩展的构建过程设计,它知道如何正确地调用nvcc(对于CUDA代码)或其他C++编译器,并处理与PyTorch库的链接。


3. 编译安装

python setup.py install


4. 测试

import torch

import custom_op


a = torch.rand(2,3)

b = torch.rand(3,4)


print(a)

print(b)

print(custom_op.custom_op(a,b))


最后直接运行就可以验证扩展函数的正确性了

相关推荐

MySQL修改密码_mysql怎么改密码忘了怎么办

拥有原来的用户名账户的密码mysqladmin-uroot-ppassword"test123"Enterpassword:【输入原来的密码】忘记原来root密码第一...

数据库密码配置项都不加密?心也太大了吧!

先看一份典型的配置文件...省略...##配置MySQL数据库连接spring.datasource.driver-class-name=com.mysql.jdbc.Driverspr...

Linux基础知识_linux基础入门知识

系统目录结构/bin:命令和应用程序。/boot:这里存放的是启动Linux时使用的一些核心文件,包括一些连接文件以及镜像文件。/dev:dev是Device(设备)的缩写,该目录...

MySQL密码重置_mysql密码重置教程

之前由于修改MySQL加密模式为mysql_native_password时操作失误,导致无法登陆MySQL数据库,后来摸索了一下,对MySQL数据库密码进行重置后顺利解决,步骤如下:1.先停止MyS...

Mysql8忘记密码/重置密码_mysql密码忘了怎么办?

Mysql8忘记密码/重置密码UBUNTU下Mysql8忘记密码/重置密码步骤如下:先说下大概步骤:修改配置文件,使得用空密码可以进入mysql。然后置当前root用户为空密码。再次修改配置文件,不能...

MySQL忘记密码怎么办?Windows环境下MySQL密码重置图文教程

有不少小白在使用Windows进行搭建主机的时候,安装了一些环境后,其中有MySQL设置后,然后不少马大哈忘记了MySQL的密码,导致在一些程序安装及配置的时候无法进行。这个时候怎么办呢?重置密码呗?...

10种常见的MySQL错误,你可中招?_mysql常见错误提示及解决方法

【51CTO.com快译】如果未能对MySQL8进行恰当的配置,您非但可能遇到无法顺利访问、或调用MySQL的窘境,而且还可能给真实的应用生产环境带来巨大的影响。本文列举了十种MySQL...

Mysql解压版安装过程_mysql解压版安装步骤

Mysql是目前软件开发中使用最多的关系型数据库,具体安装步骤如下:第一步:Mysql官网下载最新版(mysql解压版(mysql-5.7.17-winx64)),Mysql官方下载地址为:https...

MySQL Root密码重置指南:Windows新手友好教程

如果你忘记了MySQLroot密码,请按照以下简单步骤进行重置。你需要准备的工具:已安装的MySQL以管理员身份访问命令提示符一点复制粘贴的能力分步操作指南1.创建密码重置文件以管理员...

安卓手机基于python3搜索引擎_python调用安卓so库

环境:安卓手机手机品牌:vivox9s4G运行内存手机软件:utermux环境安装:1.java环境的安装2.redis环境的安装aptinstallredis3.elasticsearch环...

Python 包管理 3 - poetry_python community包

Poetry是一款现代化的Python依赖管理和打包工具。它通过一个pyproject.toml文件来统一管理你的项目依赖、配置和元数据,并用一个poetry.lock文件来锁定所有依赖的精...

Python web在线服务生产环境真实部署方案,可直接用

各位志同道合的朋友大家好,我是一个一直在一线互联网踩坑十余年的编码爱好者,现在将我们的各种经验以及架构实战分享出来,如果大家喜欢,就关注我,一起将技术学深学透,我会每一篇分享结束都会预告下一专题最近经...

官方玩梗:Python 3.14(πthon)稳定版发布,正式支持自由线程

IT之家10月7日消息,当地时间10月7日,Python软件基金会宣布Python3.14.0正式发布,也就是用户期待已久的圆周率(约3.14)版本,再加上谐音梗可戏称为π...

第一篇:如何使用 uv 创建 Python 虚拟环境

想象一下,你有一个使用Python3.10的后端应用程序,系统全局安装了a2.1、b2.2和c2.3这些包。一切运行正常,直到你开始一个新项目,它也使用Python3.10,但需要...

我用 Python 写了个自动整理下载目录的工具

经常用电脑的一定会遇到这种情况:每天我们都在从浏览器、微信、钉钉里下各种文件,什么截图、合同、安装包、临时文档,全都堆在下载文件夹里。起初还想着“过两天再整理”,结果一放就是好几年。结果某天想找一个发...