MAGMA binding: gesdd and its GPU interfaces
In the first blog post we configured the gesvd wrappers and tested them.
However, the subroutine we actually want from MAGMA is gesdd, which uses the divide-and-conquer algorithm.
Our final goal is to use MAGMA to run SVD on the GPU for our tensors. Typically, MAGMA names the subroutines with GPU interfaces "*_gpu", where '*' stands for the original name of the subroutine.
GPU interfaces for gesdd
Unfortunately, MAGMA does not provide native GPU interfaces for gesdd.
After referring to the corresponding implementation in PyTorch, we followed a similar approach in Julia: we first copy the matrix to the CPU and then pass it to the MAGMA gesdd subroutine, which transfers the CPU data to the GPU again.
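To make the round trip concrete, here is a minimal sketch of the staging pattern. The helper name host_roundtrip_gesdd is ours for illustration only and is not part of the binding, and LAPACK.gesdd! from Julia's LinearAlgebra stands in for the MAGMA call so that the sketch runs without a GPU.

```julia
using LinearAlgebra

# Hypothetical helper, for illustration only (not part of the MAGMA binding).
# LAPACK.gesdd! is a stand-in for the MAGMA CPU-interface routine.
function host_roundtrip_gesdd(job::AbstractChar,
                              A::AbstractMatrix{T}) where {T<:LinearAlgebra.BlasFloat}
    # The Matrix constructor always copies, so device (or any other) data
    # ends up in a fresh CPU array and the caller's A is left untouched.
    A_host = Matrix{T}(A)
    # gesdd! overwrites A_host and returns the factors (U, S, VT).
    U, S, VT = LAPACK.gesdd!(job, A_host)
    return U, S, VT
end

A = Float64[1 2; 3 4; 5 6]
U, S, VT = host_roundtrip_gesdd('S', A)
println(size(U), " ", length(S), " ", size(VT))
```

With a CuMatrix as input, the same Matrix{T}(A) call is what performs the device-to-host copy in our wrappers.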
Meanwhile, the other svd subroutines have no native GPU interfaces either, so we are still discussing the possible solutions to this problem in an open issue.
Different argument lists according to element types
gesdd (like gesvd) has two different argument lists: magma_cgesdd and magma_zgesdd, which compute the SVD of complex matrices, take an extra argument called rwork. (See https://icl.cs.utk.edu/projectsfiles/magma/doxygen/group__magma__gesdd.html)
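As a small worked example, the length of rwork that our complex wrapper allocates depends on the job and on the matrix shape. The function name rwork_length is hypothetical; the formula is the one used in the complex wrapper in this post.

```julia
# Hypothetical helper: number of real workspace elements (rwork) that the
# complex gesdd wrapper in this post allocates for an m-by-n matrix.
function rwork_length(job::AbstractChar, m::Integer, n::Integer)
    minmn = min(m, n)
    if job == 'N'
        # No singular vectors requested: a small linear workspace suffices.
        return 7 * minmn
    else
        return minmn * max(5 * minmn + 7, 2 * max(m, n) + 2 * minmn + 1)
    end
end

println(rwork_length('N', 4, 3))  # 7 * 3
println(rwork_length('S', 4, 3))  # 3 * max(22, 15)
```

The real-valued routines (magma_sgesdd and magma_dgesdd) take no such argument, which is why the two families need separate wrappers.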
So, we used multiple dispatch to implement the two different kinds of wrappers:
    # Real element types: no rwork argument.
    for (gesvd, gesdd, elty, relty) in ((:sgesvd, :sgesdd, :Float32, :Float32),
                                        (:dgesvd, :dgesdd, :Float64, :Float64)),
        interface in (:Matrix, :CuMatrix)

        # C symbol of the routine, e.g. :magma_sgesdd
        magma_name = QuoteNode(Symbol(:magma_, gesdd))
        @eval begin
            function magma_gesdd!(job::AbstractChar, A::$interface{$elty})
                # Stage the input on the CPU; MAGMA moves it to the GPU again.
                A = Matrix{$elty}(A)
                m, n = size(A)
                minmn = min(m, n)
                lda = max(1, stride(A, 2))

                # Output shapes depend on the job.
                if job == 'A'
                    U = similar(A, $elty, (m, m))
                    VT = similar(A, $elty, (n, n))
                elseif job == 'S'
                    U = similar(A, $elty, (m, minmn))
                    VT = similar(A, $elty, (minmn, n))
                elseif job == 'O'
                    U = similar(A, $elty, (m, m >= n ? 0 : m))
                    VT = similar(A, $elty, (n, m >= n ? n : 0))
                else
                    U = similar(A, $elty, (m, 0))
                    VT = similar(A, $elty, (n, 0))
                end
                ldu = max(1, stride(U, 2))
                ldvt = max(1, stride(VT, 2))

                S = similar(A, $relty, minmn)
                work = Vector{$elty}(undef, 1)
                lwork = Cint(-1)
                iwork = Vector{Cint}(undef, 8*minmn)
                info = Ref{Cint}()
                job_magma = char_to_magmaInt(job)

                # First pass (lwork == -1) queries the workspace size,
                # second pass computes the SVD.
                for i in 1:2
                    ccall(($magma_name, libmagma), Cint,
                          (Cint, Cint, Cint, Ptr{$elty}, Cint, Ptr{$relty},
                           Ptr{$elty}, Cint, Ptr{$elty}, Cint,
                           Ptr{$elty}, Cint, Ptr{Cint}, Ptr{Cint}),
                          job_magma, m, n, A, lda, S, U, ldu, VT, ldvt,
                          work, lwork, iwork, info)
                    if i == 1
                        lwork = ceil(Cint, nextfloat(real(work[1])))
                        resize!(work, lwork)
                    end
                end

                # Select the outputs according to the job: for 'O' one
                # factor has been written into A.
                if job == 'O'
                    if m >= n
                        return (A, S, VT)
                    else
                        return (U, S, A)
                    end
                end
                return (U, S, VT)
            end
        end
    end

    # Complex element types: extra rwork argument.
    for (gesvd, gesdd, elty, relty) in ((:cgesvd, :cgesdd, :ComplexF32, :Float32),
                                        (:zgesvd, :zgesdd, :ComplexF64, :Float64)),
        interface in (:Matrix, :CuMatrix)

        # C symbol of the routine, e.g. :magma_cgesdd
        magma_name = QuoteNode(Symbol(:magma_, gesdd))
        @eval begin
            function magma_gesdd!(job::AbstractChar, A::$interface{$elty})
                # Stage the input on the CPU; MAGMA moves it to the GPU again.
                A = Matrix{$elty}(A)
                m, n = size(A)
                minmn = min(m, n)
                lda = max(1, stride(A, 2))

                # Output shapes depend on the job.
                if job == 'A'
                    U = similar(A, $elty, (m, m))
                    VT = similar(A, $elty, (n, n))
                elseif job == 'S'
                    U = similar(A, $elty, (m, minmn))
                    VT = similar(A, $elty, (minmn, n))
                elseif job == 'O'
                    U = similar(A, $elty, (m, m >= n ? 0 : m))
                    VT = similar(A, $elty, (n, m >= n ? n : 0))
                else
                    U = similar(A, $elty, (m, 0))
                    VT = similar(A, $elty, (n, 0))
                end
                ldu = max(1, stride(U, 2))
                ldvt = max(1, stride(VT, 2))

                S = similar(A, $relty, minmn)
                work = Vector{$elty}(undef, 1)
                lwork = Cint(-1)
                rwork = Vector{$relty}(undef, job == 'N' ? 7*minmn :
                            minmn*max(5*minmn+7, 2*max(m,n)+2*minmn+1))
                iwork = Vector{Cint}(undef, 8*minmn)
                info = Ref{Cint}()
                job_magma = char_to_magmaInt(job)

                # First pass (lwork == -1) queries the workspace size,
                # second pass computes the SVD.
                for i in 1:2
                    ccall(($magma_name, libmagma), Cint,
                          (Cint, Cint, Cint, Ptr{$elty}, Cint, Ptr{$relty},
                           Ptr{$elty}, Cint, Ptr{$elty}, Cint,
                           Ptr{$elty}, Cint, Ptr{$relty}, Ptr{Cint}, Ptr{Cint}),
                          job_magma, m, n, A, lda, S, U, ldu, VT, ldvt,
                          work, lwork, rwork, iwork, info)
                    if i == 1
                        lwork = ceil(Cint, nextfloat(real(work[1])))
                        resize!(work, lwork)
                    end
                end

                # Select the outputs according to the job: for 'O' one
                # factor has been written into A.
                if job == 'O'
                    if m >= n
                        return (A, S, VT)
                    else
                        return (U, S, A)
                    end
                end
                return (U, S, VT)
            end
        end
    end
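The two-pass loop in both wrappers is the usual LAPACK-style workspace query: the first call passes lwork = -1, so the routine only reports the workspace size it wants in work[1], and the second call does the actual computation with a buffer of that size. The sketch below mimics the protocol in pure Julia so that it runs without MAGMA. Both mock_routine! and its size formula are made up for illustration; only the calling pattern mirrors our wrappers.

```julia
# Made-up stand-in for a LAPACK-style routine (hypothetical, pure Julia).
# With lwork == -1 it only writes the workspace size it wants into work[1].
function mock_routine!(work::Vector{Float64}, lwork::Integer, n::Integer)
    wanted = 3n + 7                   # arbitrary "optimal" size for this mock
    if lwork == -1
        work[1] = wanted
        return 0
    end
    lwork >= wanted || return -1      # reject an undersized workspace
    return 0                          # a real routine would compute here
end

function call_with_workspace_query(n::Integer)
    work = Vector{Float64}(undef, 1)
    lwork = -1
    info = 0
    for pass in 1:2
        info = mock_routine!(work, lwork, n)
        if pass == 1
            # Same rounding as in the wrappers: never round the size down.
            lwork = ceil(Int, nextfloat(real(work[1])))
            resize!(work, lwork)
        end
    end
    return info, lwork
end

info, lwork = call_with_workspace_query(5)
println("info = ", info, ", lwork = ", lwork)
```

Rounding up after nextfloat may allocate one element more than requested, which is harmless; rounding down would risk an undersized buffer.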