tensorflow注册自己实现的Op

编程入门 行业动态 更新时间:2024-10-06 18:32:37

<a href="https://www.elefans.com/category/jswz/34/1769423.html">tensorflow注册自己实现的Op</a>

tensorflow注册自己实现的Op

目的是云端算法中执行LSTM部分计算过程的加速,即用cu文件编译出so,用此so中的LSTM类或函数替代tf.LSTMCell进行运算。
整个项目见Github,流程见博客,博主也刚入门cuda,欢迎留言探讨~

1. 源代码编译tensorflow

因为我们要对 TensorFlow 库进行修改,所以需要用源码编译的方式重新安装 TensorFlow。官方文档的步骤写得很清楚,这里就不再赘述了。

2. 注册OP流程:

  1. 定义 Op 的接口,即按规则写好cc文件

  2. 为 Op 实现 kernel,即你自己的.cu文件

  3. 编译出 so,即 BUILD.sh 脚本。上述三个文件见下文;建议先阅读官方文档,再来看例子会豁然开朗

3. 例子

  1. 按上述规则准备好 cc 文件(即 BUILD.sh 中编译的 cuda_lstm_forward 源文件)
#include <stdio.h>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/framework/allocator.h"
#include "fsmn_forward.h"
#include <cstddef>
#include <iostream>
#include <algorithm>
using namespace tensorflow;
void LSTMTest(int miniBatch, int seqLength, int inputSize, int hiddenSize, int outSize,float* input, float* h_data_in, float *c_data_in, float* weight_i, float* weight_h, float*bias_data_in, float* w_i_diag_in,float* w_f_diag_in, float* w_o_diag_in, float* proj_kernel_in, float* h_data_out, float* c_data_out, float* output,bool use_peepholes = true, float cell_clip = 0.0, float proj_clip = 0.0);REGISTER_OP("CudaLstmForward")
.Input("input: float32")
.Input("cdata_in: float32")
.Input("hdata_in: float32")
.Input("weight_i: float32")
.Input("weight_h: float32")
.Input("bias: float32")
.Input("w_i_diag: float32")
.Input("w_f_diag: float32")
.Input("w_o_diag: float32")
.Input("proj_kernel: float32")
.Output("cdata_out: float32")
.Output("hdata_out: float32")
.Output("output: float32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) {c->set_output(0, c->Matrix(c->Dim(c->input(1), 0), c->Dim(c->input(1), 1)));c->set_output(1, c->Matrix(c->Dim(c->input(1), 0), c->Dim(c->input(2), 1)));c->set_output(2, c->Matrix(c->Dim(c->input(0), 0), c->Dim(c->input(2), 1)));return Status::OK();
});class CudaLstmForwardOp : public OpKernel {
public:explicit CudaLstmForwardOp(OpKernelConstruction* ctx) : OpKernel(ctx){}void Compute(OpKernelContext* ctx) override {//printf("begin....");Tensor input_data= ctx->mutable_input(0, true);Tensor cdata_in= ctx->mutable_input(1, true);Tensor hdata_in= ctx->mutable_input(2, true);//printf("input....");OP_REQUIRES(ctx, input_data.shape().dims() == 2,errors::InvalidArgument("input data is not a 2-Tensor"));OP_REQUIRES(ctx, hdata_in.shape().dims() == 2,errors::InvalidArgument("hdata in is not a 2-Tensor"));OP_REQUIRES(ctx, cdata_in.shape().dims() == 2,errors::InvalidArgument("cdata in is not a 2-Tensor"));//printf("weight....");Tensor weight_i= ctx->mutable_input(3, true);Tensor weight_h= ctx->mutable_input(4, true);Tensor bias= ctx->mutable_input(5, true);Tensor wi_diag= ctx->mutable_input(6, true);Tensor wf_diaf= ctx->mutable_input(7, true);Tensor wo_diag= ctx->mutable_input(8, true);Tensor proj= ctx->mutable_input(9, true);auto inputdata_t = input_data.tensor<float, 2>();auto cdatain_t = cdata_in.tensor<float, 2>();auto hdatain_t = hdata_in.tensor<float, 2>();auto weighti_t = weight_i.tensor<float, 2>();auto weighth_t = weight_h.tensor<float, 2>();auto bias_t = bias.tensor<float, 1>();auto widiag_t = wi_diag.tensor<float, 1>();auto wfdiag_t = wf_diaf.tensor<float, 1>();auto wodiag_t = wo_diag.tensor<float, 1>();auto proj_t = proj.tensor<float, 2>();const auto &acti_shape = input_data.shape();int seq_batch = acti_shape.dim_size(0);int inputsize = acti_shape.dim_size(1);const auto &acth_shape = cdata_in.shape();int batch = acth_shape.dim_size(0);int hiddensize = acth_shape.dim_size(1);const auto &actc_shape = hdata_in.shape();int outputsize = actc_shape.dim_size(1);int length = seq_batch/batch;// Create an state out tensorTensor *state_outc = nullptr;TensorShape indice_shape({batch, hiddensize});OP_REQUIRES_OK(ctx, ctx->allocate_output("cdata_out", indice_shape, &state_outc));auto statec_t = state_outc->tensor<float, 2>();// Create an state out tensorTensor 
*state_outh = nullptr;TensorShape indice_shape1({batch, outputsize});OP_REQUIRES_OK(ctx, ctx->allocate_output("hdata_out", indice_shape1, &state_outh));auto stateh_t = state_outh->tensor<float, 2>();// Create an output tensorTensor *out_put = nullptr;TensorShape indice_shape2({seq_batch, outputsize});OP_REQUIRES_OK(ctx, ctx->allocate_output("output", indice_shape2, &out_put));auto out_t = out_put->tensor<float, 2>();// 执行计算操作LSTMTest(batch, length, inputsize, hiddensize, outputsize,inputdata_t.data(), cdatain_t.data(), hdatain_t.data(),weighti_t.data(), weighth_t.data(), bias_t.data(),widiag_t.data(), wfdiag_t.data(), wodiag_t.data(), proj_t.data(),statec_t.data(), stateh_t.data(), out_t.data(),true, 0.0, 50.0);}private:}; //class CudaLstmForward end
REGISTER_KERNEL_BUILDER(Name("CudaLstmForward").Device(::tensorflow::DEVICE_CPU), CudaLstmForwardOp);
REGISTER_KERNEL_BUILDER(Name("CudaLstmForward").Device(DEVICE_GPU), CudaLstmForwardOp);

00_lstm.cu(Op 的 kernel 实现)

extern "C" void LSTMTest(int miniBatch, int seqLength, int inputSize, int hiddenSize, int outSize,float* input, float* c_data_in, float *h_data_in, float* weight_i, float* weight_h, float*bias_data_in, float* w_i_diag_in,float* w_f_diag_in, float* w_o_diag_in, float* proj_kernel_in, float* c_data_out, float* h_data_out, float* output,bool use_peepholes = true, float cell_clip = 0.0, float proj_clip = 0.0){static int layer0_size = (7 + 320 + 2 * miniBatch + 4 * miniBatch) * 1536 + (2 * miniBatch + miniBatch * 4) * 320 + (miniBatch * 4) * 320 + (320 + 320 + miniBatch * 4 + miniBatch) * 4 * 1536;static int layer1_size = (7 + 320 + 2 * miniBatch + 4 * miniBatch) * 1536 + (2 * miniBatch + miniBatch * 4) * 320 + (miniBatch * 4) * 320 + (320 + 320 + miniBatch * 4 + miniBatch) * 4 * 1536;static int layer2_size = (7 + 448 + 2 * miniBatch + 4 * miniBatch) * 1536 + (2 * miniBatch + miniBatch * 4) * 448 + (miniBatch * 4) * 320 + (320 + 448 + miniBatch * 4 + miniBatch) * 4 * 1536;static int layer3_size = (7 + 448 + 2 * miniBatch + 4 * miniBatch) * 1536 + (2 * miniBatch + miniBatch * 4) * 448 + (miniBatch * 4) * 448 + (448 + 448 + miniBatch * 4 + miniBatch) * 4 * 1536;static float *w_diag_bias_proj;static float *init_pointer;static int all_layer_size = layer0_size + layer1_size + layer2_size + layer3_size;static int flag = 0;float alpha = 1.f;float beta  = 0.f;int input_depth = inputSize;int gateSize = hiddenSize * 4;int h_depth;int numElements = miniBatch * hiddenSize;if(use_peepholes == true)h_depth = outSize;elseh_depth = hiddenSize;int w_diag_bias_proj_size = (7 + h_depth) * hiddenSize;int h_op_data_size = 2 * miniBatch * h_depth + miniBatch * seqLength * h_depth;int input_T_size = miniBatch * seqLength * input_depth + (input_depth + h_depth) * gateSize;int c_o_data_size = 2 * miniBatch * hiddenSize + miniBatch * seqLength * hiddenSize;int tmp_i_size = miniBatch * seqLength * gateSize;// int tmp_h_size = miniBatch * gateSize;// printf("the seqLength is: %d, inputSize: %d, 
input_depth: %d, hiddenSize: %d, outSize: %d\n", seqLength, inputSize, input_depth, hiddenSize, outSize);cudaErrCheck(cudaGetLastError());if(flag == 0){cudaErrCheck(cudaMalloc((void**)&w_diag_bias_proj, all_layer_size * sizeof(float)));init_pointer = w_diag_bias_proj;}if(flag == 1)w_diag_bias_proj = w_diag_bias_proj + layer0_size;else if(flag == 2)w_diag_bias_proj = w_diag_bias_proj + layer1_size;else if(flag == 3)w_diag_bias_proj = w_diag_bias_proj + layer2_size;flag++;printf("flag: %d\n", flag);//b = a + size(a);float *input_T = w_diag_bias_proj + w_diag_bias_proj_size;//c = b + size(b);float *h_op_data = input_T + input_T_size;float *c_o_data = h_op_data + h_op_data_size;float *tmp_i = c_o_data + c_o_data_size;float *tmp_h = tmp_i + tmp_i_size;cudaStream_t stream_i, stream_h;cudaErrCheck(cudaStreamCreate(&stream_i));cudaErrCheck(cudaStreamCreate(&stream_h));bool stream_i_flag = true;cudaErrCheck(cudaMemcpyAsync(input_T, input, miniBatch * input_depth * seqLength * sizeof(float), cudaMemcpyHostToDevice, stream_i));cudaErrCheck(cudaMemcpyAsync(input_T + miniBatch * input_depth * seqLength, weight_i, input_depth * gateSize * sizeof(float), cudaMemcpyHostToDevice, stream_i));cudaErrCheck(cudaMemcpyAsync(input_T + miniBatch * input_depth * seqLength + input_depth * gateSize, weight_h, h_depth * gateSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));// printf("*************************%d\n", hiddenSize);cudaErrCheck(cudaMemcpyAsync(h_op_data, h_data_in, h_depth * miniBatch * sizeof(float), cudaMemcpyHostToDevice, stream_h));cudaErrCheck(cudaMemcpyAsync(c_o_data, c_data_in, numElements * sizeof(float), cudaMemcpyHostToDevice, stream_h));// printf("i_data up and i_data_beforeProj down and the seqLength is%d\n", seqLength);cudaErrCheck(cudaMemcpyAsync(w_diag_bias_proj, w_i_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));cudaErrCheck(cudaMemcpyAsync(w_diag_bias_proj + hiddenSize, w_f_diag_in, hiddenSize * sizeof(float), 
cudaMemcpyHostToDevice, stream_h));cudaErrCheck(cudaMemcpyAsync(w_diag_bias_proj + 2 * hiddenSize, w_o_diag_in, hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));cudaErrCheck(cudaMemcpyAsync(w_diag_bias_proj + 3 * hiddenSize , bias_data_in, gateSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));cudaErrCheck(cudaMemcpyAsync(w_diag_bias_proj + 7 * hiddenSize, proj_kernel_in, h_depth * hiddenSize * sizeof(float), cudaMemcpyHostToDevice, stream_h));cudaErrCheck(cudaGetLastError());// cudaDeviceSynchronize();// Need a cuBLAS handle.cublasHandle_t handle;cublasErrCheck(cublasCreate(&handle));cublasErrCheck(cublasSetStream(handle, stream_i));cublasErrCheck(cublasSgemm(handle,CUBLAS_OP_N, CUBLAS_OP_N,gateSize, miniBatch * seqLength, input_depth,&alpha,input_T + miniBatch * input_depth * seqLength,gateSize,input_T,input_depth,&beta,tmp_i,gateSize));cudaErrCheck(cudaGetLastError());for(int i = 0; i < seqLength; ++i){// cudaEventRecord(event1, 0);cublasErrCheck(cublasSetStream(handle, stream_h));cublasErrCheck(cublasSgemm(handle,CUBLAS_OP_N, CUBLAS_OP_N,gateSize, miniBatch, h_depth,&alpha,input_T + miniBatch * input_depth * seqLength + input_depth * gateSize,gateSize,h_op_data,h_depth ,&beta,tmp_h,gateSize));dim3 blockDim;dim3 gridDim;blockDim.x = 256;gridDim.x = (miniBatch * hiddenSize + blockDim.x - 1) / blockDim.x;if(stream_i_flag == true)cudaErrCheck(cudaStreamSynchronize(stream_i));elementWise_fp <<< gridDim, blockDim, 0 , stream_h >>>(hiddenSize, miniBatch,tmp_h,tmp_i + i * miniBatch * gateSize,w_diag_bias_proj + 3 * hiddenSize,NULL,h_op_data + miniBatch * h_depth,c_o_data + 2 * numElements + i * miniBatch * hiddenSize,c_o_data,c_o_data + numElements,false,w_diag_bias_proj,w_diag_bias_proj + hiddenSize,w_diag_bias_proj + 2 * hiddenSize,use_peepholes,h_depth,cell_clip);if(stream_i_flag == true){cudaErrCheck(cudaStreamDestroy(stream_i));stream_i_flag = false;}cudaErrCheck(cudaGetLastError());if(use_peepholes != 
0){cublasErrCheck(cublasSgemm(handle,CUBLAS_OP_N, CUBLAS_OP_N,h_depth, miniBatch, hiddenSize,&alpha,w_diag_bias_proj + 7 * hiddenSize,h_depth,c_o_data + 2 * numElements + i * miniBatch * hiddenSize,hiddenSize,&beta,h_op_data + 2 * miniBatch * h_depth + i * miniBatch * h_depth,h_depth));if(proj_clip != 0){// printf("in proj_clip\n");dim3 blockDim;dim3 gridDim;blockDim.x = 256;gridDim.x = (h_depth * miniBatch + blockDim.x - 1) / blockDim.x;clip_by_value <<< gridDim, blockDim, 0, stream_h >>>(h_op_data + 2 * miniBatch * h_depth + i * h_depth * miniBatch, proj_clip, miniBatch * h_depth);}//h_data和i_data保持同步}cudaErrCheck(cudaMemcpy(h_op_data + miniBatch * h_depth, h_op_data + 2 * miniBatch * h_depth + i * miniBatch * h_depth, miniBatch * h_depth * sizeof(float), cudaMemcpyDeviceToDevice));cudaErrCheck(cudaMemcpy(h_op_data, h_op_data + miniBatch * h_depth, miniBatch * h_depth * sizeof(float), cudaMemcpyDeviceToDevice));cudaErrCheck(cudaMemcpy(c_o_data, c_o_data + numElements, miniBatch * hiddenSize * sizeof(float), cudaMemcpyDeviceToDevice));cudaErrCheck(cudaGetLastError());cudaErrCheck(cudaGetLastError());}cudaErrCheck(cudaMemcpy(h_data_out, h_op_data + miniBatch * h_depth, miniBatch * h_depth * sizeof(float), cudaMemcpyDeviceToHost));cudaErrCheck(cudaMemcpy(c_data_out, c_o_data + numElements, miniBatch * hiddenSize * sizeof(float), cudaMemcpyDeviceToHost));cudaErrCheck(cudaMemcpy(output, h_op_data + 2 * miniBatch * h_depth, seqLength * miniBatch * h_depth * sizeof(float), cudaMemcpyDeviceToHost));if(flag == 4){cudaErrCheck(cudaFree(init_pointer));flag = 0;}cudaErrCheck(cudaStreamDestroy(stream_h));}  int main(int argc, char* argv[]) {int seqLength;int numLayers;int hiddenSize;int miniBatch; bool use_peepholes;int num_proj;float cell_clip = 0.0;float proj_clip = 0.0;if (argc == 5) {seqLength = atoi(argv[1]);numLayers =  atoi(argv[2]);hiddenSize =  atoi(argv[3]);miniBatch =  atoi(argv[4]);   }else if (argc == 1) {printf("Running with default settings\n");seqLength = 
2;numLayers = 1;hiddenSize = 1536;miniBatch = 100;use_peepholes = true;num_proj = 320;cell_clip = 0.00;proj_clip = 50.00;}else {printf("Usage: ./LSTM <seqLength> <numLayers> <hiddenSize> <miniBatch>\n");return 1;      }printf("seqLength %d, numLayers %d, num_proj %d, miniBatch %d\n", seqLength, numLayers, num_proj, miniBatch);  int outSize = num_proj;int numRuns = 4;  float totalTime = 0.f;int input_depth = 320;float* input = init_Matrix(miniBatch * input_depth * seqLength);float* h_data_in = init_Matrix_zeros(miniBatch * num_proj);float* c_data_in = init_Matrix_zeros(miniBatch * hiddenSize);float* weight_i = init_Matrix(input_depth* hiddenSize * 4);float* weight_h = init_Matrix(num_proj* hiddenSize * 4);float* bias_data_in = init_Matrix_zeros(hiddenSize * 4);float* w_i_diag_in = init_Matrix(hiddenSize);float* w_f_diag_in = init_Matrix(hiddenSize);float* w_o_diag_in = init_Matrix(hiddenSize);float* proj_kernel_in = init_Matrix(hiddenSize * num_proj);float* h_data_out = init_Matrix_zeros(miniBatch * num_proj);float* c_data_out = init_Matrix_zeros(miniBatch * hiddenSize);float* output = init_Matrix_zeros(miniBatch * seqLength * num_proj);for (int run = 0; run < numRuns; run++) {LSTMTest(miniBatch, seqLength, input_depth, hiddenSize, outSize,input, c_data_in, h_data_in, weight_i, weight_h, bias_data_in, w_i_diag_in,w_f_diag_in, w_o_diag_in, proj_kernel_in, c_data_out, h_data_out, output, use_peepholes, cell_clip, proj_clip);}printf("Runtime %fms\n", totalTime / numRuns);return time < 0;
}
  2. 混合编译 .c/.cpp 与 .cu 文件
  • 即在cpp里使用cu文件,编译cpp时将编译好的cuda库链接进来

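分别编译:先用 nvcc -c 00_lstm.cu -o 00_lstm.o 和 g++ -c 01_cpptest_cuda_lstm.cpp -o 01_cpptest_cuda_lstm.o 得到两个目标文件,再链接:g++ -o test 00_lstm.o 01_cpptest_cuda_lstm.o -lcudart -L/usr/local/cuda/lib64 -lcublas -lcurand -L/home/resources/yxwang/cuda-10.0/lib64/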

静态库: nvcc -lib 00_lstm.cu -o lib00_lstm.a

链接静态库(静态库要放在依赖它的目标文件之后,并显式链接 CUDA 运行时和 cuBLAS):g++ -o test 01_cpptest_cuda_lstm.o -L. -l00_lstm -L/usr/local/cuda/lib64 -lcudart -lcublas -lcurand

动态库(BUILD.sh):

# Build from inside the source-built TensorFlow tree:
#   pwd = ~/tensorflow/tensorflow/core/user_ops
# Environment: run this script with the CPU build of TF 1.5.0 [py27tf15];
# run the Python / pure-CUDA tests with the GPU build of TF 1.4.0 [py27tf15gpu];
# [py27tf15s] is the CPU build of TF 1.14.
TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
#g++ -std=c++11 -shared -o cuda_lstm_forward.so -c cuda_lstm_forward -ltestcu -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2  -ltensorflow_framework -L /home/resources/yxwang/cuda-10.0/lib64/ -lcublas -lcurand

# Step 1: build 00_lstm.cu into lib00_lstm.so.
nvcc -o lib00_lstm.so -shared -Xcompiler -fPIC 00_lstm.cu -lcublas -lcurand -L /home/resources/yxwang/cuda-10.0/lib64/

# Step 2: build the op source into cuda_lstm_forward.so, linking lib00_lstm.so
# via -l00_lstm -L. (-l<name> links a library, -L<dir> adds a search directory).
# The rpath lets the dynamic loader find lib00_lstm.so next to the op library
# when TensorFlow loads it, without setting LD_LIBRARY_PATH.
g++ -std=c++11 -shared cuda_lstm_forward.cc -o cuda_lstm_forward.so -fPIC "${TF_CFLAGS[@]}" "${TF_LFLAGS[@]}" -O2 -ltensorflow_framework -l00_lstm -L. -Wl,-rpath,'$ORIGIN' -lcublas -lcurand -L/home/resources/yxwang/cuda-10.0/lib64/

# Standalone test build:
# nvcc -g -G 00_lstm.cu -o 00_lstm -L/home/resources/yxwang/cuda-10.0/lib64/ -arch=sm_52 -DPERFOPTS=31 -lcublas -lcurand

或写Makefile, : 后为依赖项,从下往上看

# Dependencies follow each ':'; read the rules bottom-up.
# all and cpp are phony targets; cpp builds the test program.
.PHONY: all cpp

all: cpp

cpp: 01_cpptest_cuda_lstm

# Test program; relinks when its source or the CUDA library changes.
01_cpptest_cuda_lstm: 01_cpptest_cuda_lstm.cpp lib00_lstm.so
	g++ 01_cpptest_cuda_lstm.cpp -o 01_cpptest_cuda_lstm /home/resources/yxwang/cuda-10.0/lib64/libcublas.so -l00_lstm -L.

# Shared library built from the CUDA implementation.
lib00_lstm.so: 00_lstm.cu
	nvcc -o lib00_lstm.so -shared -Xcompiler -fPIC 00_lstm.cu -lcublas -lcurand -L /home/resources/yxwang/cuda-10.0/lib64/

更多推荐

tensorflow注册自己实现的Op

本文发布于:2024-02-27 18:29:10,感谢您对本站的认可!
本文链接:https://www.elefans.com/category/jswz/34/1765509.html
版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。
本文标签:tensorflow   Op

发布评论

评论列表 (有 0 条评论)
草根站长

>www.elefans.com

编程频道|电子爱好者 - 技术资讯及电子产品介绍!