Jerry Chen / Aug 19 2019
Remix of Get Started by help

MAGMA binding: gesdd and its GPU interfaces

In the first blog post, we configured the gesvd wrappers and tested them.

However, there is a more important MAGMA subroutine, and it is the one we actually want: gesdd, which uses the divide-and-conquer algorithm.

Our final goal is to use MAGMA to run a GPU version of SVD on our tensors. MAGMA typically names subroutines that have GPU interfaces "*_gpu", where '*' stands for the original subroutine name.

GPU interfaces for gesdd

Unfortunately, MAGMA does not provide a native GPU interface for gesdd.

After referring to the counterpart implementation in PyTorch, we followed a similar approach in Julia. We first copy the matrix to the CPU and then feed it to the MAGMA gesdd subroutine, which transfers the CPU data to CUDA again.

Meanwhile, the other SVD subroutines have no native GPU interfaces either, so we are still discussing possible solutions to this problem in an open issue.

Different argument lists according to element types

gesdd (like gesvd) has two different argument lists: magma_cgesdd and magma_zgesdd, which compute SVDs of complex matrices, take an extra real workspace argument called rwork. (See https://icl.cs.utk.edu/projectsfiles/magma/doxygen/group__magma__gesdd.html)

So, we used multiple dispatch to implement the two different kinds of wrappers:

Shift+Enter to run
# Define `magma_gesdd!(job, A)` for the real element types Float32/Float64, once for
# host (`Matrix`) and once for device (`CuMatrix`) inputs.
#
# MAGMA has no native GPU interface for gesdd. The input is therefore copied into a
# new host `Matrix` and that copy is passed to MAGMA's CPU-interface routine.
# `Matrix{T}(A)` always allocates, so the caller's `A` is never modified despite
# the `!`.
#
# `job` selects which singular vectors are computed (LAPACK `JOBZ` convention):
#   'A'  all m columns of U and all n rows of VT
#   'S'  the first min(m, n) columns of U and rows of VT
#   'O'  U (if m >= n) or VT (if m < n) is written into the working copy of A
#   'N'  singular values only (U and VT are returned with zero columns)
# Returns `(U, S, VT)` as host arrays. For 'O', the working copy of A takes the
# place of U or VT.
# Throws `ArgumentError` for an unsupported `job` or when MAGMA reports an illegal
# argument (info < 0). Throws `ErrorException` when MAGMA reports a failure
# (info > 0).
for (gesdd, elty, relty) in ((:sgesdd, :Float32, :Float32),
                             (:dgesdd, :Float64, :Float64)),
    interface in (:Matrix, :CuMatrix)

    @eval begin
        function magma_gesdd!(job::AbstractChar, A::$interface{$elty})
            # Only these four codes have matching buffer layouts below. Any other code
            # would get the zero-column 'N' buffers, while MAGMA might still be asked
            # to write singular vectors into them.
            job in ('A', 'S', 'O', 'N') || throw(ArgumentError(
                "magma_gesdd!: job must be one of 'A', 'S', 'O' or 'N', got '$job'"))

            # Fresh host copy: moves CuMatrix data to the CPU and protects the input.
            A = Matrix{$elty}(A)

            m, n  = size(A)
            minmn = min(m, n)
            lda   = max(1, stride(A, 2))

            if job == 'A'
                U  = similar(A, $elty, (m, m))
                VT = similar(A, $elty, (n, n))
            elseif job == 'S'
                U  = similar(A, $elty, (m, minmn))
                VT = similar(A, $elty, (minmn, n))
            elseif job == 'O'
                U  = similar(A, $elty, (m, m >= n ? 0 : m))
                VT = similar(A, $elty, (n, m >= n ? n : 0))
            else  # job == 'N'
                U  = similar(A, $elty, (m, 0))
                VT = similar(A, $elty, (n, 0))
            end
            ldu  = max(1, stride(U, 2))
            ldvt = max(1, stride(VT, 2))

            S     = similar(A, $relty, minmn)
            work  = Vector{$elty}(undef, 1)
            lwork = Cint(-1)  # -1 turns the first call into a workspace-size query
            iwork = Vector{Cint}(undef, 8 * minmn)
            info  = Ref{Cint}()

            job_magma = char_to_magmaInt(job)

            # Pass 1 queries the optimal workspace size; pass 2 computes the SVD.
            for i in 1:2
                ccall((@magmafunc($gesdd), libmagma), Cint,
                      (Cint,
                       Cint, Cint,
                       Ptr{$elty}, Cint,
                       Ptr{$relty},
                       Ptr{$elty}, Cint,
                       Ptr{$elty}, Cint,
                       Ptr{$elty}, Cint,
                       Ptr{Cint}, Ptr{Cint}),
                      job_magma,
                      m, n,
                      A, lda,
                      S,
                      U, ldu,
                      VT, ldvt,
                      work, lwork,
                      iwork, info)

                # Surface MAGMA failures instead of returning uninitialized output.
                if info[] < 0
                    throw(ArgumentError(
                        "magma_gesdd!: argument $(-info[]) had an illegal value"))
                elseif info[] > 0
                    error("magma_gesdd!: SVD did not converge (info = $(info[]))")
                end

                if i == 1
                    # The optimal size comes back as a float in work[1]. Bump it with
                    # nextfloat before rounding up so a truncated value cannot
                    # under-allocate (the same workaround LinearAlgebra's gesdd! uses).
                    lwork = ceil(Cint, nextfloat(real(work[1])))
                    resize!(work, lwork)
                end
            end

            # For 'O', one set of singular vectors lives in the working copy of A.
            if job == 'O'
                return m >= n ? (A, S, VT) : (U, S, A)
            end
            return (U, S, VT)
        end
    end
end



# Define `magma_gesdd!(job, A)` for the complex element types ComplexF32/ComplexF64,
# once for host (`Matrix`) and once for device (`CuMatrix`) inputs.
#
# The complex MAGMA routines (c/z gesdd) take one extra argument compared to the
# real ones: a real-valued workspace `rwork`.
#
# MAGMA has no native GPU interface for gesdd. The input is therefore copied into a
# new host `Matrix` and that copy is passed to MAGMA's CPU-interface routine.
# `Matrix{T}(A)` always allocates, so the caller's `A` is never modified despite
# the `!`.
#
# `job` selects which singular vectors are computed (LAPACK `JOBZ` convention):
#   'A'  all m columns of U and all n rows of VT
#   'S'  the first min(m, n) columns of U and rows of VT
#   'O'  U (if m >= n) or VT (if m < n) is written into the working copy of A
#   'N'  singular values only (U and VT are returned with zero columns)
# Returns `(U, S, VT)` as host arrays; S is real (`relty`). For 'O', the working
# copy of A takes the place of U or VT.
# Throws `ArgumentError` for an unsupported `job` or when MAGMA reports an illegal
# argument (info < 0). Throws `ErrorException` when MAGMA reports a failure
# (info > 0).
for (gesdd, elty, relty) in ((:cgesdd, :ComplexF32, :Float32),
                             (:zgesdd, :ComplexF64, :Float64)),
    interface in (:Matrix, :CuMatrix)

    @eval begin
        function magma_gesdd!(job::AbstractChar, A::$interface{$elty})
            # Only these four codes have matching buffer layouts below. Any other code
            # would get the zero-column 'N' buffers, while MAGMA might still be asked
            # to write singular vectors into them.
            job in ('A', 'S', 'O', 'N') || throw(ArgumentError(
                "magma_gesdd!: job must be one of 'A', 'S', 'O' or 'N', got '$job'"))

            # Fresh host copy: moves CuMatrix data to the CPU and protects the input.
            A = Matrix{$elty}(A)

            m, n  = size(A)
            minmn = min(m, n)
            lda   = max(1, stride(A, 2))

            if job == 'A'
                U  = similar(A, $elty, (m, m))
                VT = similar(A, $elty, (n, n))
            elseif job == 'S'
                U  = similar(A, $elty, (m, minmn))
                VT = similar(A, $elty, (minmn, n))
            elseif job == 'O'
                U  = similar(A, $elty, (m, m >= n ? 0 : m))
                VT = similar(A, $elty, (n, m >= n ? n : 0))
            else  # job == 'N'
                U  = similar(A, $elty, (m, 0))
                VT = similar(A, $elty, (n, 0))
            end
            ldu  = max(1, stride(U, 2))
            ldvt = max(1, stride(VT, 2))

            S     = similar(A, $relty, minmn)
            work  = Vector{$elty}(undef, 1)
            lwork = Cint(-1)  # -1 turns the first call into a workspace-size query
            # Real workspace for the complex routines. The size uses the same formula
            # as LinearAlgebra.LAPACK.gesdd!.
            rwork = Vector{$relty}(undef, job == 'N' ?
                7 * minmn : minmn * max(5 * minmn + 7, 2 * max(m, n) + 2 * minmn + 1))
            iwork = Vector{Cint}(undef, 8 * minmn)
            info  = Ref{Cint}()

            job_magma = char_to_magmaInt(job)

            # Pass 1 queries the optimal workspace size; pass 2 computes the SVD.
            for i in 1:2
                ccall((@magmafunc($gesdd), libmagma), Cint,
                      (Cint,
                       Cint, Cint,
                       Ptr{$elty}, Cint,
                       Ptr{$relty},
                       Ptr{$elty}, Cint,
                       Ptr{$elty}, Cint,
                       Ptr{$elty}, Cint,
                       Ptr{$relty},
                       Ptr{Cint}, Ptr{Cint}),
                      job_magma,
                      m, n,
                      A, lda,
                      S,
                      U, ldu,
                      VT, ldvt,
                      work, lwork,
                      rwork,
                      iwork, info)

                # Surface MAGMA failures instead of returning uninitialized output.
                if info[] < 0
                    throw(ArgumentError(
                        "magma_gesdd!: argument $(-info[]) had an illegal value"))
                elseif info[] > 0
                    error("magma_gesdd!: SVD did not converge (info = $(info[]))")
                end

                if i == 1
                    # The optimal size comes back in the real part of work[1]. Bump it
                    # with nextfloat before rounding up so a truncated value cannot
                    # under-allocate (the same workaround LinearAlgebra's gesdd! uses).
                    lwork = ceil(Cint, nextfloat(real(work[1])))
                    resize!(work, lwork)
                end
            end

            # For 'O', one set of singular vectors lives in the working copy of A.
            if job == 'O'
                return m >= n ? (A, S, VT) : (U, S, A)
            end
            return (U, S, VT)
        end
    end
end