pytorch 添加c++实现的自定义op

pytorch已经基本实现了常见的各种op,然而,当想实现一个pytorch中没有的op时,有两种方式。一种方式是这个op可以由pytorch中已有的op进行组合而成,因此只需要使用python接口进行组合就可以了。反之,就必须使用c++或者cuda实现该op,然后添加到pytorch中。本文将介绍添加c++实现的自定义op,因为我还不会cuda : (

本文介绍使用python的setuptools将c++实现的op添加到pytorch中。首先要用c++实现定义的op。比如想实现一个op为 z=3x-y 。头文件为my_op.h

#include <torch/extension.h> //这一句是无论要实现任何op都必须添加的
#include <vector>

//前向传播
torch::Tensor my_op_forward(const torch::Tensor& x, const torch::Tensor& y);
//反向传播
std::vector<torch::Tensor> my_op_backward(const torch::Tensor& gradOutput);

源文件为my_op.cpp

#include "my_op.h"

torch::Tensor my_op_forward(const torch::Tensor& x,                             const torch::Tensor& y) {     
     AT_ASSERTM(x.sizes() == y.sizes(), "x must be the same size as y");
     torch::Tensor z = torch::zeros(x.sizes());
     z = 3 * x - y;
     return z; } 

std::vector<torch::Tensor> my_op_backward(const torch::Tensor& gradOutput) {
     torch::Tensor gradOutputX = 3 * gradOutput * torch::ones(gradOutput.sizes());
     torch::Tensor gradOutputY = -1 * gradOutput * torch::ones(gradOutput.sizes());
     return {gradOutputX, gradOutputY}; } 

// pybind11 绑定 
PYBIND11_MODULE(my_op_api, m) {
     m.def("forward", &my_op_forward, "MY_OP forward");
     m.def("backward", &my_op_backward, "MY_OP backward"); 
} 

其中最后的PYBIND11_MODULE是用来将C++函数绑定到python上的。其中第一个参数my_op_api为要生成的python模块名,以后import my_op_api就可以调用该op了。第二个参数固定为m
函数体中的两个语句分别是绑定前向传播与反向传播到实现的两个函数上。

然后编写setup.py,用来构建pytorch的c++扩展。

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(name='my_op_api',
      version='0.l',
      ext_modules=[CppExtension('my_op_api', sources=['my_op.cpp'], extra_compile_args=['-std=c++11'])],
      cmdclass={'build_ext':BuildExtension})

其中,setup中的name以及CppExtension中第一个参数(也是name)要和PYBIND11_MODULE里设的模块名保持一致,这里都是my_op_api。CppExtension中的extra_compile_args=[‘-std=c++11’]是趟坑发现的,不加的话gcc可能会报n多错(pytorch是用c++11编译的,因此这里用gcc编译的时候也要使用c++11)

然后运行python setup.py install,如果没有问题的话,就生成了所需的python模块。可以从输出的信息看到该模块在所在python环境下的site-packages文件夹下,以.egg结尾。另外,在当前目录下会有3个文件夹生成,build、dist、my_op_api.egg-info,其中dist下也有.egg文件,可以发布到其它python环境。

然后就可以在python中import my_op_api进行调用扩展的op了。这里需要注意一点的是,在import 自定义的op之前,必须先import torch。 但是,这样的op和我们日常使用的还是不太一样,这时需要将它包装为pytorch中的函数和模块,以便我们像使用其它模块一样使用自定义的op。要包装为模块,首先包装成函数。

包装成函数,需要继承torch.autograd.Function。然后包装成模块,需要继承torch.nn.Module

import torch
from torch.autograd import Function
from torch.nn import Module
import my_op_api

class MyOpFunction(Function):
    @staticmethod
    def forward(ctx, x, y):
        #如果有一些信息,需要在梯度反向传播时用到,可以使用ctx.save_for_backward()进行保存
        return my_op_api.forward(x, y)
    @staticmethod
    def backward(ctx, gradOutput):
        #如果在forward中保存了信息,可以使用ctx.saved_tensors取回
        grad_x, grad_y = my_op_api.backward(gradOutput)
        return grad_x, grad_y

class MyOpModule(Module):
    def __init__(self):
        super(MyOpModule, self).__init__()
    def forward(self, input_x, input_y):#只需要定义forward的函数就可以了
        return MyOpFunction.apply(input_x, input_y)