OpenAI/Triton MLIR Chapter 4: ROCm-Triton Configuration
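While recently organizing some Python-based benchmark code, I reinstalled Triton on an NVIDIA GPU. The Triton GitHub repo now documents the matching LLVM commit ID and the full build details. I followed them and the installation went through without problems. On NVIDIA GPUs, the following steps are all you need: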
git clone https://github.com/openai/triton.git
cd triton
cd $HOME/llvm-project  # your clone of LLVM.
git checkout 49af6502
mkdir build
cd build
cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm"
ninja -j8

export LLVM_BUILD_DIR=$HOME/llvm-project/build
cd <triton install>
LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \
LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \
LLVM_SYSPATH=$LLVM_BUILD_DIR \
pip install -e python
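If the output shows version 3.0.0, Triton has been installed successfully. After installing Triton, be sure to install PyTorch as well. I am using CUDA 12.1, so the command below can be run as-is: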
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
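The author checks for success by looking for version 3.0.0 in the output. A minimal sketch for checking from Python which versions of triton, torch, torchvision and torchaudio are actually installed follows; it uses only the standard library's importlib.metadata. It is also a cheap way to confirm that the later pip install of torch did not replace the source-built triton, since a pip-installed torch can bring in its own triton dependency. The helper name installed_versions is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(
    packages: Iterable[str] = ("triton", "torch", "torchvision", "torchaudio"),
) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:<12} {ver or 'not installed'}")
```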
Installing and using Triton on NVIDIA GPUs is familiar territory by now. Next, let's explore how to install and configure Triton on AMD GPUs.
0x00 Software Installation
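About Triton's AMD backend: although upstream Triton supports it as a third-party backend, I still recommend the Triton fork maintained by AMD, because on the upstream main branch at the time, building with TRITON_CODEGEN_AMD_HIP_BACKEND=1 did not complete successfully. So I switched to AMD's fork, ROCmSoftwarePlatform/triton.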
Just follow its installation steps. I recommend the commands below, which I have tested myself:
git clone https://github.com/ROCmSoftwarePlatform/triton.git
cd triton
git checkout triton-mlir
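At this point the Triton source is ready to build. However, Triton's backend is built on LLVM, so to generate code that runs on the target device we also need to build LLVM. This tutorial builds LLVM by hand, although using a prebuilt LLVM is also fine. Since this Triton branch was developed against LLVM commit b1115f8c, we only need to clone LLVM, check out that commit, and build it with the complete set of commands below.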
git clone https://github.com/llvm/llvm-project
cd llvm-project
git checkout b1115f8c
mkdir build
cd build
cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm"
ninja -j8
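Because this Triton branch is pinned to a single LLVM commit, it can be worth confirming the checkout before spending a long time in ninja. A minimal sketch, assuming git is on PATH and the clone lives in llvm-project; the helper names and the PINNED_LLVM_COMMIT constant are illustrative, with the short hash taken from the text above.

```python
import subprocess

# Short hash from the text: the ROCm triton-mlir branch targets this LLVM commit.
PINNED_LLVM_COMMIT = "b1115f8c"


def commit_matches(head_sha: str, pinned: str = PINNED_LLVM_COMMIT) -> bool:
    """True if the full HEAD SHA starts with the pinned (short) hash."""
    return head_sha.strip().lower().startswith(pinned.strip().lower())


def llvm_head(repo_dir: str = "llvm-project") -> str:
    """Return the full SHA currently checked out in repo_dir."""
    result = subprocess.run(
        ["git", "-C", repo_dir, "rev-parse", "HEAD"],
        capture_output=True, text=True, check=True,
    )
    return result.stdout.strip()


# Usage, from the directory that contains llvm-project:
#   if not commit_matches(llvm_head()):
#       raise SystemExit(f"LLVM is not at {PINNED_LLVM_COMMIT}; run git checkout first")
```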
Once LLVM is fully built, add its path to your ~/.bashrc:
export PATH=/home/llvm-project/build/bin:$PATH
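After reloading ~/.bashrc, a quick sanity check is that the LLVM build's bin directory is on PATH and that tools such as mlir-opt (built because mlir is in LLVM_ENABLE_PROJECTS) and llc resolve through it. A small sketch; /home/llvm-project/build/bin is the path from the text, and both helper names are hypothetical.

```python
import os
import shutil
from typing import Dict, Iterable, Optional


def dir_on_path(directory: str, path_value: Optional[str] = None) -> bool:
    """Check whether directory is one of the entries of a PATH-style string."""
    if path_value is None:
        path_value = os.environ.get("PATH", "")
    target = os.path.normpath(directory)
    return any(os.path.normpath(p) == target for p in path_value.split(os.pathsep) if p)


def locate_tools(tools: Iterable[str] = ("mlir-opt", "llc")) -> Dict[str, Optional[str]]:
    """Resolve each tool through PATH as the shell would; None means not found."""
    return {tool: shutil.which(tool) for tool in tools}


# Usage:
#   dir_on_path("/home/llvm-project/build/bin")  # should be True after sourcing ~/.bashrc
#   locate_tools()                               # paths should point into build/bin
```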
Then go into the Triton directory you cloned at the beginning and run:
cd triton
vim CMakeLists.txt   # set: option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" ON)
mkdir build
cd build
cmake ..
make -j8
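The vim step only flips one CMake option, so it can also be scripted. A hedged sketch: it assumes the option line has the form option(TRITON_BUILD_PYTHON_MODULE "..." OFF), matching the line shown above with ON, and enable_python_module and patch_cmakelists are hypothetical helpers.

```python
import re
from pathlib import Path

# Matches e.g. option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
_OPTION_RE = re.compile(
    r'(option\(\s*TRITON_BUILD_PYTHON_MODULE\s+"[^"]*"\s+)(?:ON|OFF)(\s*\))'
)


def enable_python_module(cmake_text: str) -> str:
    """Return cmake_text with TRITON_BUILD_PYTHON_MODULE switched to ON."""
    new_text, count = _OPTION_RE.subn(r"\1ON\2", cmake_text)
    if count == 0:
        raise ValueError("TRITON_BUILD_PYTHON_MODULE option not found")
    return new_text


def patch_cmakelists(path: Path) -> None:
    """Rewrite the CMakeLists.txt at path in place."""
    path.write_text(enable_python_module(path.read_text()))


# Usage: patch_cmakelists(Path("triton/CMakeLists.txt"))
```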
If the build completes without errors, a libtriton.so file is produced in the current build directory. Next, move libtriton.so into the triton/python/triton/_C directory and add Triton's Python path to your ~/.bashrc:
export TRITON_HOME=/home/Documents/compiler/triton
export PYTHONPATH=$TRITON_HOME/python:${PYTHONPATH}
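The move-and-export step can be scripted too. A minimal sketch assuming the layout in the text (libtriton.so under triton/build, destination triton/python/triton/_C); it copies rather than moves so the build directory stays intact, and both helper names are hypothetical.

```python
import shutil
from pathlib import Path


def install_libtriton(triton_home: Path) -> Path:
    """Copy build/libtriton.so into python/triton/_C and return the new path."""
    src = triton_home / "build" / "libtriton.so"
    if not src.is_file():
        raise FileNotFoundError(f"{src} not found; did the build finish?")
    dest_dir = triton_home / "python" / "triton" / "_C"
    dest_dir.mkdir(parents=True, exist_ok=True)
    return Path(shutil.copy2(src, dest_dir / src.name))


def pythonpath_export(triton_home: Path) -> str:
    """Build the ~/.bashrc line from the text for an arbitrary TRITON_HOME."""
    return f"export PYTHONPATH={triton_home / 'python'}:${{PYTHONPATH}}"


# Usage:
#   home = Path("/home/Documents/compiler/triton")
#   install_libtriton(home)
#   print(pythonpath_export(home))  # append this line to ~/.bashrc
```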
If Google Test cannot be found during the build, install it as follows:
git clone https://github.com/google/googletest
cd googletest
cmake CMakeLists.txt
make -j8
cp ./lib/libgtest*.a /usr/lib
cd googletest
cp -a include/gtest /usr/include
If pybind11 cannot be found during the build, install it as follows:
pip install pytest
git clone https://github.com/pybind/pybind11.git
cd pybind11
mkdir build
cd build
cmake ..
make check -j8
sudo make install
For PyTorch on AMD GPUs, you must install a PyTorch build that matches your ROCm version. My machine runs ROCm 5.6, so I used the command below (for reference only):
pip3 install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/rocm5.6
You can check your ROCm version with:
dpkg -l | grep rocm
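The ROCm version reported by dpkg determines which PyTorch wheel index to use. A hedged sketch that extracts major.minor from dpkg -l output and builds an index URL in the same form as the pip3 command above. It assumes a package named rocm-core is installed (pass a different name if your output differs), the sample row in the usage comment is illustrative, and PyTorch only publishes wheels for some ROCm releases, so check that the resulting index exists.

```python
import re
from typing import Optional


def rocm_version_from_dpkg(dpkg_output: str, package: str = "rocm-core") -> Optional[str]:
    """Return 'major.minor' for package in `dpkg -l` output, or None if absent."""
    for line in dpkg_output.splitlines():
        fields = line.split()
        # dpkg -l rows: <status> <name[:arch]> <version> <architecture> <description...>
        if len(fields) >= 3 and fields[1].split(":")[0] == package:
            match = re.match(r"(\d+)\.(\d+)", fields[2])
            if match:
                return f"{match.group(1)}.{match.group(2)}"
    return None


def torch_rocm_index_url(rocm_version: str) -> str:
    """Wheel index URL in the same form as the pip3 command above."""
    return f"https://download.pytorch.org/whl/rocm{rocm_version}"


# Usage (illustrative row):
#   row = "ii  rocm-core  5.6.0.50600-67~22.04  amd64  ROCm runtime"
#   torch_rocm_index_url(rocm_version_from_dpkg(row))
```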
Keep in mind that using PyTorch on AMD GPUs is very similar to using it on NVIDIA GPUs: you still call .cuda() to place tensors on the GPU.
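Because ROCm builds of PyTorch keep the torch.cuda API, code usually tells the two platforms apart through torch.version.hip, which is exactly what the autotune search space in the GEMM example does. A small sketch; describe_backend is a hypothetical helper, and the demo part only runs if torch is importable.

```python
import importlib.util
from typing import Optional


def describe_backend(hip_version: Optional[str], cuda_version: Optional[str]) -> str:
    """Classify a PyTorch build from torch.version.hip / torch.version.cuda."""
    if hip_version is not None:
        return f"ROCm/HIP {hip_version}"
    if cuda_version is not None:
        return f"CUDA {cuda_version}"
    return "CPU-only"


if __name__ == "__main__" and importlib.util.find_spec("torch") is not None:
    import torch

    print(describe_backend(torch.version.hip, torch.version.cuda))
    if torch.cuda.is_available():
        # On a ROCm build, .cuda() places the tensor on the AMD GPU.
        x = torch.randn(2, 2).cuda()
        print(x.device)  # cuda:0, on AMD and NVIDIA alike
```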
0x01 GEMM Code Example
Once everything is built, run the code below to benchmark GEMM on an AMD GPU with both Triton and rocBLAS.
import torch
import triton
import triton.language as tl
import sys
import argparse
import pytest

# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#     meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#     provided configs
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ] if torch.version.hip is None else [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, num_warps=4, num_stages=0),
    ],
    key=['M', 'N', 'K'],
)
@triton.heuristics({
    'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0,
})
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the official Triton matmul tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    if GROUP_SIZE_M == 1:
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
    else:
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetics` section of the official Triton matmul tutorial for details.
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        if EVEN_K:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul`.
@triton.jit
def leaky_relu(x):
    return tl.where(x >= 0, x, 0.01 * x)


# %%
# We can now create a convenience wrapper function that only takes two input tensors,
# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
def matmul(a, b, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        ACTIVATION=activation  #
    )
    return c


# %%
# Unit Test
# ---------
#
# We can test our custom matrix multiplication operation against a native torch
# implementation (i.e., cuBLAS on NVIDIA, rocBLAS on AMD).
@pytest.mark.parametrize("M, N, K, in_dtype, out_dtype",
[ (*shape, in_dtype, out_dtype)
    for shape in [(128, 256, 32), (128, 16, 32), (32, 128, 64),
                  (128, 128, 64), (64, 128, 128), (32, 128, 64),
                  (64, 64, 32), (32, 32, 128), (128, 128, 64),
                  (64, 128, 128), (512, 512, 512), (1024, 1024, 1024)]
    for in_dtype, out_dtype in [('int8', 'int8'),
                                ('float16', 'float16'),
                                ('bfloat16', 'bfloat16'),
                                ('float16', 'float32'),
                                ('float32', 'float32')]]
)
def test_correctness(M, N, K, in_dtype, out_dtype):
    # NOTE: in_dtype/out_dtype are parametrized, but the body always uses float16 inputs.
    torch.manual_seed(0)
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    triton_output = matmul(a, b)
    torch_output = torch.matmul(a, b)
    print(f"triton_output={triton_output}")
    print(f"torch_output={torch_output}")
    rtol = 0 if torch.version.hip is None else 1e-2
    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
        print("Triton and Torch match")
    else:
        print("Triton and Torch differ")
    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol)


# %%
# Benchmark
# ---------
#
# Square Matrix Performance
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# We can now compare the performance of our kernel against that of cuBLAS/rocBLAS. Here we focus on
# square matrices, but feel free to arrange this script as you wish to benchmark any other matrix shape.
global verbose
verbose = False


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],  # Argument names to use as an x-axis for the plot
        x_vals=[
            (1024, 1024, 1024),
            (2048, 2048, 2048),
            (4096, 4096, 4096),
            (8192, 8192, 8192),
            (9728, 8192, 65536)
        ],  # Different possible values for `x_name`
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot
        # Possible values for `line_arg`
        line_vals=['rocblas', 'triton'],
        # Label name for the lines
        line_names=["rocBLAS", "Triton"],
        # Line styles
        styles=[('green', '-'), ('blue', '-')],
        ylabel="TFLOPS",  # Label name for the y-axis
        plot_name="matmul-performance",  # Name for the plot, used also as a file name for saving the plot.
        args={},
    ))
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'rocblas':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
        global verbose
        if verbose:
            print(f'SIZE: {M},{N},{K}   Best tuning config: ({matmul_kernel.get_best_config()})')
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


def parse_args():
    parser = argparse.ArgumentParser(
        prog="GEMM tutorial example",
        allow_abbrev=False,
    )

    parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
    args = parser.parse_args()

    return args


def main():
    # assign to a global verbose var to indicate whether print
    # best tuning config
    global verbose
    args = parse_args()
    verbose = args.v
    benchmark.run(show_plots=True, print_data=True)


if __name__ == '__main__':
    sys.exit(main())
0x10 GEMM Code Walkthrough
First comes the definition of the search space:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ] if torch.version.hip is None else [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 4, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=8, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8, 'waves_per_eu': 3}, num_warps=4, num_stages=0),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 1, 'waves_per_eu': 8}, num_warps=4, num_stages=0),
    ],
    key=['M', 'N', 'K'],
)
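Conceptually, the autotuner benchmarks every config the first time it sees a new value of the key (here M, N, K), caches the fastest one, and reuses it afterwards. The toy sketch below mimics that behavior in plain Python with a caller-supplied benchmark function that returns milliseconds. ToyAutotuner and pick_best_config are hypothetical and are not Triton's actual Autotuner, which also compiles each config and supports config pruning.

```python
from typing import Callable, Dict, Hashable, List, Sequence, Tuple

Config = Dict[str, int]


def pick_best_config(configs: Sequence[Config],
                     bench: Callable[[Config], float]) -> Tuple[Config, List[float]]:
    """Time every config with bench (milliseconds) and return the fastest one."""
    timings = [bench(cfg) for cfg in configs]
    best_index = min(range(len(configs)), key=timings.__getitem__)
    return configs[best_index], timings


class ToyAutotuner:
    """Caches the best config per key, like key=['M', 'N', 'K'] in triton.autotune."""

    def __init__(self, configs: Sequence[Config],
                 bench: Callable[[Config, Hashable], float]):
        self.configs = list(configs)
        self.bench = bench  # bench(config, key) -> runtime in ms for that problem size
        self.cache: Dict[Hashable, Config] = {}

    def best_for(self, key: Hashable) -> Config:
        if key not in self.cache:  # only an unseen (M, N, K) triggers re-tuning
            self.cache[key], _ = pick_best_config(
                self.configs, lambda cfg: self.bench(cfg, key))
        return self.cache[key]
```

The cache is why the first call for each new shape is slow and subsequent calls are not; it is also why the key should list only the arguments whose values actually change the best choice.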
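When torch.version.hip is set (that is, on a ROCm build of PyTorch), the AMD-specific search space is used. Besides the usual tuning knobs BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K and GROUP_SIZE_M, it adds a new one, waves_per_eu. The concept was unfamiliar to me at first, so I asked AMD engineers about it. In summary: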
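An AMD GPU is made up of compute units (CUs), the counterpart of streaming multiprocessors (SMs) on NVIDIA GPUs. Each CU contains 4 SIMD units (also called execution engines, or EUs). You can think of a SIMD unit as a vector execution unit with a certain number of registers and ALUs for carrying out computation. When you launch a compute grid, workgroups (the equivalent of thread blocks on NVIDIA GPUs) are scheduled to run on CUs.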
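Within a CU, wavefronts (the equivalent of warps on NVIDIA GPUs) are scheduled to run on SIMD units. This is where occupancy comes in: the number of wavefronts that can run concurrently on each SIMD unit. It depends on how many resources each wavefront needs and how many resources each SIMD unit has. The waves_per_eu parameter focuses on register usage. For example, suppose each SIMD (EU) has 512 registers.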
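If each wavefront needs 256 registers, the occupancy is 512 / 256 = 2. If we set waves_per_eu=3, however, the compiler will try to cut each wavefront's register usage to 170 (512 / 3, rounded down) so that the occupancy can reach 3. Raising waves_per_eu carries the risk of register spilling and lower performance, so increasing it may raise occupancy but will not necessarily improve performance.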
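The arithmetic in this example can be written out directly. A minimal sketch using the 512-registers-per-SIMD figure from the text; the max_waves cap of 8 is an assumed per-SIMD wavefront limit (the real limit differs between GPU generations), and real compilers also round register counts to an allocation granularity, which this sketch ignores. Both function names are hypothetical.

```python
def occupancy(regs_per_wave: int, regs_per_simd: int = 512, max_waves: int = 8) -> int:
    """Wavefronts that fit on one SIMD unit, given per-wavefront register use."""
    if regs_per_wave <= 0:
        raise ValueError("regs_per_wave must be positive")
    return min(regs_per_simd // regs_per_wave, max_waves)


def register_budget(waves_per_eu: int, regs_per_simd: int = 512) -> int:
    """Per-wavefront register ceiling the compiler aims for when waves_per_eu is set."""
    return regs_per_simd // waves_per_eu


# 256 registers per wave -> occupancy 2.
# waves_per_eu=3 -> budget 512 // 3 = 170 registers -> occupancy 512 // 170 = 3.
```

Whether squeezing into the smaller budget pays off depends on whether the kernel then spills registers, which is exactly why waves_per_eu is left to the autotuner rather than fixed.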
Next is the kernel definition itself, which is essentially the same as what you would write for an NVIDIA GPU:
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the official Triton matmul tutorial for details.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    if GROUP_SIZE_M == 1:
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
    else:
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetics` section of the official Triton matmul tutorial for details.
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the K dimension.
        # If it is out of bounds, set it to 0.
        if EVEN_K:
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        else:
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    # You can fuse arbitrary activation functions here
    # while the accumulator is still in FP32!
    if ACTIVATION == "leaky_relu":
        accumulator = leaky_relu(accumulator)
    c = accumulator.to(tl.float16)

    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
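The grouped program-id mapping at the top of the kernel is ordinary integer arithmetic, so it can be replayed on the host to see which tile of C each program computes. A sketch that mirrors the kernel's formulas in plain Python; tile_coords and launch_order are hypothetical helpers, and ceiling division stands in for tl.cdiv.

```python
from typing import List, Tuple


def cdiv(a: int, b: int) -> int:
    """Ceiling division, the host-side equivalent of tl.cdiv / triton.cdiv."""
    return -(-a // b)


def tile_coords(pid: int, M: int, N: int, BLOCK_SIZE_M: int, BLOCK_SIZE_N: int,
                GROUP_SIZE_M: int) -> Tuple[int, int]:
    """Replay the kernel's pid -> (pid_m, pid_n) mapping on the host."""
    num_pid_m = cdiv(M, BLOCK_SIZE_M)
    num_pid_n = cdiv(N, BLOCK_SIZE_N)
    if GROUP_SIZE_M == 1:
        # Plain row-major order over the tile grid.
        return pid // num_pid_n, pid % num_pid_n
    # Grouped order: sweep GROUP_SIZE_M tile rows column by column.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)  # last group may be short
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n


def launch_order(M: int, N: int, BLOCK_SIZE_M: int, BLOCK_SIZE_N: int,
                 GROUP_SIZE_M: int) -> List[Tuple[int, int]]:
    """Tiles in the order programs 0, 1, 2, ... compute them (grid size as in matmul())."""
    grid = cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N)
    return [tile_coords(p, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M)
            for p in range(grid)]
```

For example, with a 512×512 C and 128×128 tiles, GROUP_SIZE_M=1 walks the first row of tiles left to right, while GROUP_SIZE_M=2 starts with (0, 0), (1, 0), (0, 1), (1, 1): neighbouring programs share a column block of B or a row block of A, which is the L2 reuse the kernel comment refers to.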
Next is the unit test, which checks that Triton's output matches PyTorch's output within tolerance:
def test_correctness(M, N, K, in_dtype, out_dtype):
    # NOTE: in_dtype/out_dtype are parametrized, but the body always uses float16 inputs.
    torch.manual_seed(0)
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    triton_output = matmul(a, b)
    torch_output = torch.matmul(a, b)
    print(f"triton_output={triton_output}")
    print(f"torch_output={torch_output}")
    rtol = 0 if torch.version.hip is None else 1e-2
    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
        print("Triton and Torch match")
    else:
        print("Triton and Torch differ")
    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol)
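torch.allclose passes when |actual - expected| <= atol + rtol * |expected| holds element-wise, so on NVIDIA (rtol=0) the test uses a pure 0.01 absolute tolerance, while on ROCm it adds a 1% relative term. A plain-Python sketch of that rule (ignoring NaN handling; the helper name is hypothetical) shows why the relative term matters for larger outputs.

```python
from typing import Sequence


def allclose(actual: Sequence[float], expected: Sequence[float],
             atol: float = 1e-2, rtol: float = 0.0) -> bool:
    """Element-wise |a - e| <= atol + rtol * |e|, the rule torch.allclose applies."""
    if len(actual) != len(expected):
        raise ValueError("length mismatch")
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


# An output of 10.05 against a reference of 10.0 fails the NVIDIA setting (rtol=0)
# but passes the ROCm setting (rtol=1e-2), since 0.05 <= 0.01 + 0.01 * 10.0.
```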
After that, you only need to specify the GEMM sizes. The input order is still M, N, K by default, and the rest is standard boilerplate:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],  # Argument names to use as an x-axis for the plot
        x_vals=[
            (1024, 1024, 1024),
            (2048, 2048, 2048),
            (4096, 4096, 4096),
            (8192, 8192, 8192),
            (9728, 8192, 65536)
        ],  # Different possible values for `x_name`
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot
        # Possible values for `line_arg`
        line_vals=['rocblas', 'triton'],
        # Label name for the lines
        line_names=["rocBLAS", "Triton"],
        # Line styles
        styles=[('green', '-'), ('blue', '-')],
        ylabel="TFLOPS",  # Label name for the y-axis
        plot_name="matmul-performance",  # Name for the plot, used also as a file name for saving the plot.
        args={},
    ))
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'rocblas':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), quantiles=quantiles)
        global verbose
        if verbose:
            print(f'SIZE: {M},{N},{K}   Best tuning config: ({matmul_kernel.get_best_config()})')
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


def parse_args():
    parser = argparse.ArgumentParser(
        prog="GEMM tutorial example",
        allow_abbrev=False,
    )

    parser.add_argument("-v", action='store_true', default=False, help="Print out the best tuning config")
    args = parser.parse_args()

    return args


def main():
    # assign to a global verbose var to indicate whether print
    # best tuning config
    global verbose
    args = parse_args()
    verbose = args.v
    benchmark.run(show_plots=True, print_data=True)


if __name__ == '__main__':
    sys.exit(main())
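The perf lambda counts 2·M·N·K floating-point operations (one multiply and one add per inner-product term) and divides by the runtime converted to seconds. Because quantiles=[0.5, 0.2, 0.8], the three returned values are the median throughput, then the throughput at the slower 80th-percentile time, then at the faster 20th-percentile time. A worked example of the same formula (tflops is a hypothetical standalone copy of the lambda):

```python
def tflops(M: int, N: int, K: int, ms: float) -> float:
    """Same formula as `perf` in the benchmark: 2*M*N*K FLOPs over ms milliseconds."""
    return 2 * M * N * K * 1e-12 / (ms * 1e-3)


# A 4096 x 4096 x 4096 GEMM finishing in 1.0 ms runs at about 137.4 TFLOPS.
```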
We will walk through a more automated GEMM benchmarking and tuning script for AMD GPUs in a later chapter.
Editor: Liu Qing
Original title: OpenAI/Triton MLIR Chapter 4: ROCm-triton Configuration
Source: WeChat official account GiantPandaCV (WeChat ID: GiantPandaCV). Please credit the source when reposting.