
BLAS.gemm! with SubArray #15286

Closed
machiningcentre opened this issue Feb 29, 2016 · 6 comments
Labels
bug Indicates an unexpected problem or unintended behavior

Comments

@machiningcentre

BLAS.gemm! can store its result in a SubArray, but the values end up at the wrong indices.

julia> A = rand(3, 3)
3x3 Array{Float64,2}:
 0.391348  0.99617   0.127584
 0.761283  0.364923  0.796806
 0.534478  0.201001  0.114469

julia> B = rand(3, 3)
3x3 Array{Float64,2}:
 0.471382   0.0389864  0.10348
 0.376755   0.52683    0.102277
 0.0756814  0.51072    0.119012

julia> C = zeros(6, 6)
6x6 Array{Float64,2}:
 0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0

julia> sC = sub(C, 1:2:6, 1:2:6)
3x3 SubArray{Float64,2,Array{Float64,2},Tuple{StepRange{Int64,Int64},StepRange{Int64,Int64}},1}:
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0

julia> BLAS.gemm!('N', 'N', 1.0, A, B, 0.0 ,sC)
3x3 SubArray{Float64,2,Array{Float64,2},Tuple{StepRange{Int64,Int64},StepRange{Int64,Int64}},1}:
 0.569442  0.605229  0.157565
 0.336335  0.185192  0.0894885
 0.0       0.0       0.0

julia> C
6x6 Array{Float64,2}:
 0.569442  0.0  0.605229  0.0  0.157565   0.0
 0.556645  0.0  0.628877  0.0  0.21093    0.0
 0.336335  0.0  0.185192  0.0  0.0894885  0.0
 0.0       0.0  0.0       0.0  0.0        0.0
 0.0       0.0  0.0       0.0  0.0        0.0
 0.0       0.0  0.0       0.0  0.0        0.0

My versioninfo():

Julia Version 0.4.2
Commit bb73f34 (2015-12-06 21:47 UTC)
Platform Info:
  System: Linux (x86_64-linux-gnu)
  CPU: Intel(R) Xeon(R) CPU E5-1650 v3 @ 3.50GHz
  WORD_SIZE: 64
  BLAS: libopenblas (NO_AFFINITY HASWELL)
  LAPACK: liblapack.so.3
  LIBM: libopenlibm
  LLVM: libLLVM-3.3
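The misplacement follows from how BLAS addresses a matrix. It receives only a pointer to the first element and a leading dimension, and it assumes the elements of each column are adjacent in memory. The sketch below models that addressing in plain Julia and compares the linear indices of `C` that BLAS writes with the ones `sC` actually refers to. The helper names are hypothetical, and the sketch assumes the wrapper passes `stride(sC, 2)`, here 12, as the leading dimension.

```julia
# Sketch (not Julia's or OpenBLAS's actual code): which linear indices of the
# parent array a column-major BLAS routine writes when it only gets a base
# offset and a leading dimension `ld`. It assumes the m entries of each column
# are adjacent in memory, i.e. unit stride in the first dimension.
function blas_written_indices(offset::Int, ld::Int, m::Int, n::Int)
    idx = Int[]
    for j in 1:n, i in 1:m              # column-major order
        push!(idx, offset + (i - 1) + (j - 1) * ld)
    end
    return idx
end

# Linear indices of the parent that a view with these row/column ranges covers.
function view_linear_indices(rows, cols, parent_rows::Int)
    idx = Int[]
    for j in cols, i in rows
        push!(idx, i + (j - 1) * parent_rows)
    end
    return idx
end

# Convert a column-major linear index back to (row, column).
to_rowcol(k::Int, parent_rows::Int) = (rem(k - 1, parent_rows) + 1, div(k - 1, parent_rows) + 1)

nrows   = 6                                        # C = zeros(6, 6)
wanted  = view_linear_indices(1:2:6, 1:2:6, nrows) # rows/cols 1, 3, 5
ld      = 2 * nrows                                # stride(sC, 2) == 12
written = blas_written_indices(1, ld, 3, 3)        # the row step of 2 is lost

println([to_rowcol(k, nrows) for k in written])
```

The written positions are rows 1 to 3 of columns 1, 3 and 5, which is where the nonzeros appear in `C` above. Only rows 1 and 3 of those overlap the view, so `sC` shows rows 1 and 3 of the product in its first two rows and zeros in its last row.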
@timholy
Member

timholy commented Feb 29, 2016

This works if you use A_mul_B!(sC, A, B) rather than calling gemm! directly. However, that is only because we specialize 3x3 multiplication; it does not mean the general algorithm is correct.

I've verified that this bug is present on a recent 0.5 (d6bc9c9) as well.

@timholy timholy added the bug Indicates an unexpected problem or unintended behavior label Feb 29, 2016
@timholy
Member

timholy commented Feb 29, 2016

AFAICT, LAPACK/BLAS simply can't do this properly. But our own A_mul_B! should handle this correctly.
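A pure-Julia multiply can handle this case because it never reasons about memory layout: every read and write goes through ordinary indexing, so the view's own index mapping is applied. A minimal sketch, assuming Julia 1.x (where `sub` is spelled `view`); `generic_matmul!` is a hypothetical name and not Julia's actual `A_mul_B!` implementation.

```julia
using LinearAlgebra  # only needed for the `A * B` comparison at the end

# Hypothetical naive fallback: all accesses use getindex/setindex!, so a
# SubArray's row and column steps are honored automatically.
function generic_matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix)
    m, k = size(A)
    k2, n = size(B)
    (k == k2 && size(C) == (m, n)) ||
        throw(DimensionMismatch("sizes $(size(A)), $(size(B)) and $(size(C)) are incompatible"))
    for j in 1:n, i in 1:m
        acc = zero(eltype(C))
        for p in 1:k
            acc += A[i, p] * B[p, j]
        end
        C[i, j] = acc
    end
    return C
end

A  = rand(3, 3)
B  = rand(3, 3)
C  = zeros(6, 6)
sC = view(C, 1:2:6, 1:2:6)   # Julia 1.x spelling of sub(C, 1:2:6, 1:2:6)
generic_matmul!(sC, A, B)
println(sC ≈ A * B)
```

The trade-off is speed: this triple loop does no blocking or vectorization, so it is far slower than BLAS on large matrices. That is why it makes sense to call BLAS whenever the strides allow it and fall back to indexing only when they do not.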

@timholy
Member

timholy commented Mar 1, 2016

@machiningcentre, thanks for reporting this. If you're using master (i.e., julia-0.5), this is fixed now if you use A_mul_B!. It should be backported to an update of 0.4, but no promises about the schedule.

For gemm! itself, BLAS can't do what you were asking it to do. One option is to say that when people call unexported functions (especially ones ending in !), everything is At Your Own Risk, and the user has to read the relevant source code/docs to make sure the operation is supported. (At some level, there is no arguing with this viewpoint, unless Julia starts hiding functions.) Another choice would be to check for unit strides in gemm! itself. The downside is that we'd be checking twice (once in the wrapper, and again in gemm!).
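A minimal sketch of the second option, assuming Julia 1.x (`view` instead of `sub`, BLAS under `LinearAlgebra`); `checked_gemm!` is a hypothetical name, not a Julia function. It rejects any argument whose first-dimension stride is not 1 before calling `BLAS.gemm!`. A stepped column range is still accepted, because the leading dimension passed to BLAS already accounts for the spacing between columns.

```julia
using LinearAlgebra

# Hypothetical wrapper: refuse to hand BLAS a matrix whose columns are not
# contiguous in memory (first-dimension stride != 1).
function checked_gemm!(tA::Char, tB::Char, alpha::Float64,
                       A::AbstractMatrix{Float64}, B::AbstractMatrix{Float64},
                       beta::Float64, C::AbstractMatrix{Float64})
    for M in (A, B, C)
        stride(M, 1) == 1 || throw(ArgumentError(
            "BLAS needs unit stride in the first dimension, got stride $(stride(M, 1))"))
    end
    return LinearAlgebra.BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
end

A = rand(3, 3); B = rand(3, 3); C = zeros(6, 6)
ok = view(C, 1:3, 1:2:6)     # contiguous rows, stepped columns: ldc = 12 covers it
checked_gemm!('N', 'N', 1.0, A, B, 0.0, ok)
bad = view(C, 1:2:6, 1:2:6)  # stepped rows: rejected before BLAS is called
```

The extra `stride` calls are cheap, but as noted above, a caller that has already dispatched on strides would pay for the check a second time.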

However, in this particular case I just noticed that there is a third choice: see #15308.

@machiningcentre
Author

@timholy, thank you for your kind help.

@pjabardo

I had just opened issue JuliaLang/LinearAlgebra.jl#319 when yuyichao pointed me to this issue, which describes the same problem. As a possible solution, I propose defining new Union types specific to BLAS, so that calling BLAS (and LAPACK) functions on strided matrices indexed with non-unit ranges results in an error.

The code below, adapted from version 0.4.3, appears to work (calls with non-unit-stride views now throw an error). I simply changed references to Range or StepRange to UnitRange:

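# Only unit-step index types (no StepRange): views built from these keep the unit stride BLAS requires
typealias URangeIndex Union{Int, UnitRange{Int}, Colon}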

typealias BlasMatrix{T,A<:DenseArray,I<:Tuple{Vararg{URangeIndex}}}  Union{DenseArray{T,2}, SubArray{T,2,A,I}}

typealias UBlasVector{T,A<:DenseArray,I<:Tuple{Vararg{URangeIndex}}}  Union{DenseArray{T,1}, SubArray{T,1,A,I}}

typealias UVecOrMat{T} Union{UBlasVector{T}, BlasMatrix{T}}


const libblas = Base.libblas_name

import Base.LinAlg: BlasReal, BlasComplex, BlasFloat, BlasInt, DimensionMismatch, chksquare, axpy!
import Base: copy!, blasfunc



for (fname, elty) in ((:dgemv_,:Float64),
                      (:sgemv_,:Float32),
                      (:zgemv_,:Complex128),
                      (:cgemv_,:Complex64))
    @eval begin
        #SUBROUTINE DGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
             #*     .. Scalar Arguments ..
             #      DOUBLE PRECISION ALPHA,BETA
             #      INTEGER INCX,INCY,LDA,M,N
             #      CHARACTER TRANS
             #*     .. Array Arguments ..
             #      DOUBLE PRECISION A(LDA,*),X(*),Y(*)
        function gemv!(trans::Char, alpha::($elty), A::UVecOrMat{$elty}, X::StridedVector{$elty}, beta::($elty), Y::StridedVector{$elty})
            m,n = size(A,1),size(A,2)
            if trans == 'N' && (length(X) != n || length(Y) != m)
                throw(DimensionMismatch("A has dimensions $(size(A)), X has length $(length(X)) and Y has length $(length(Y))"))
            elseif trans == 'C' && (length(X) != m || length(Y) != n)
                throw(DimensionMismatch("A' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
            elseif trans == 'T' && (length(X) != m || length(Y) != n)
                throw(DimensionMismatch("A.' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
            end
            ccall(($(blasfunc(fname)), libblas), Void,
                (Ptr{UInt8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty},
                 Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
                 Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
                 &trans, &size(A,1), &size(A,2), &alpha,
                 A, &max(1,stride(A,2)), X, &stride(X,1),
                 &beta, Y, &stride(Y,1))
            Y
        end
        function gemv(trans::Char, alpha::($elty), A::BlasMatrix{$elty}, X::StridedVector{$elty})
            gemv!(trans, alpha, A, X, zero($elty), similar(X, $elty, size(A, (trans == 'N' ? 1 : 2))))
        end
        function gemv(trans::Char, A::BlasMatrix{$elty}, X::StridedVector{$elty})
            gemv!(trans, one($elty), A, X, zero($elty), similar(X, $elty, size(A, (trans == 'N' ? 1 : 2))))
        end
    end
end



A = reshape(linspace(1,100,100), 10, 10)
y = [linspace(1,3,3);]

A1 = A[1:3:7, 2:3:8]
A2 = sub(A, 1:3:7, 2:3:8)
A3 = copy(A2)

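z1 = gemv('N', 1.0, A1, y)   # A1 is an ordinary Array (a copy): works
z2 = gemv('N', 1.0, A2, y)   # A2 is a view with StepRange indices: no matching method, so this throws
z3 = gemv('N', 1.0, A3, y)   # A3 = copy(A2) is an ordinary Array: works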
hcat(z1, z2, z3)
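For comparison, the same dispatch-based restriction might look like this in Julia 1.x, where `typealias` became `const ... =` and `sub` became `view`. This is a sketch under those assumptions; `UnitIndex`, `UnitStrideMatrix` and `would_call_blas` are hypothetical names. Like the aliases above, it requires a unit step in every index, so it also rejects views such as `view(C, 1:3, 1:2:6)` that BLAS could in fact address through its leading dimension. Restricting only the first index would be the looser alternative.

```julia
# Julia 1.x analogue of URangeIndex. Inside a view, `:` becomes a Base.Slice.
const UnitIndex = Union{Int, UnitRange{Int}, Base.Slice}

# Hypothetical alias: a plain Matrix, or a 2-D view of one whose indices are
# all unit-step.
const UnitStrideMatrix{T} = Union{Matrix{T},
                                  SubArray{T,2,Matrix{T},<:Tuple{Vararg{UnitIndex}}}}

# A method restricted to the alias has no match for stepped views, so such a
# call fails with a MethodError instead of silently using the wrong elements.
would_call_blas(M::UnitStrideMatrix{Float64}) = true

C = zeros(6, 6)
println(would_call_blas(view(C, 1:3, 1:3)))
println(applicable(would_call_blas, view(C, 1:2:6, 1:2:6)))
```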

@timholy
Member

timholy commented Mar 16, 2016

See also #15308.


3 participants