BLAS.gemm! with SubArray #15286
Comments
This works if you use
I've verified that this bug is present on a recent 0.5 (d6bc9c9) as well.
AFAICT,
Check stride on preallocated output for matmul (fixes #15286)
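The commit title above names the fix: check the stride of the preallocated output before handing it to BLAS. Below is a minimal sketch of that kind of guard. It assumes current Julia 1.x (`view` in place of the 0.4-era `sub`) and `Float64` operands. `safe_gemm!` is a hypothetical name, not the code from the actual commit. The function calls `BLAS.gemm!` only when every operand has `stride(X, 1) == 1`, and otherwise falls back to a plain loop.

```julia
using LinearAlgebra

# Hypothetical guard (not the code from the actual commit): BLAS describes a
# matrix by a base pointer plus a leading dimension, so it can only be used
# when consecutive rows of each operand are adjacent in memory.
function safe_gemm!(C::AbstractMatrix{Float64}, A::AbstractMatrix{Float64},
                    B::AbstractMatrix{Float64})
    m, k = size(A)
    k2, n = size(B)
    (k == k2 && size(C) == (m, n)) ||
        throw(DimensionMismatch("A is $(size(A)), B is $(size(B)), C is $(size(C))"))
    if all(X -> X isa StridedMatrix && stride(X, 1) == 1, (A, B, C))
        # Fast path: C = 1.0 * A * B + 0.0 * C computed by BLAS
        BLAS.gemm!('N', 'N', 1.0, A, B, 0.0, C)
    else
        # Fallback: a plain triple loop that indexes through each view,
        # so it is correct for any stride
        for j in 1:n, i in 1:m
            acc = 0.0
            for p in 1:k
                acc += A[i, p] * B[p, j]
            end
            C[i, j] = acc
        end
    end
    return C
end

# Usage with a non-unit-stride output view, as in the report
A = reshape(collect(1.0:100.0), 10, 10)
Av = view(A, 1:3:7, 2:3:8)     # stride(Av, 1) == 3
out = zeros(7, 7)
Cv = view(out, 1:3:7, 1:3:7)   # stride(Cv, 1) == 3
safe_gemm!(Cv, Av, Av)         # takes the loop fallback
```

A runtime check keeps a single method for all strided inputs. The trade-off is that the slow path is chosen silently rather than reported as an error.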
@machiningcentre, thanks for reporting this. If you're using master (i.e., julia-0.5), this is now fixed if you use … For … However, in this particular case I just noticed that there is a third choice: see #15308.
@timholy, thank you for your kind help.
I just opened issue JuliaLang/LinearAlgebra.jl#319, and then yuyichao made me aware of this issue, which describes the same problem. As a possible solution, I suggest defining new Union types specific to BLAS. With them, calling BLAS (and LAPACK) functions on strided matrices indexed with non-unit ranges results in an error. The code below, adapted from version 0.4.3, appears to work (it throws an error). I simply changed references to
typealias URangeIndex Union{Int, UnitRange{Int}, Colon}
typealias BlasMatrix{T,A<:DenseArray,I<:Tuple{Vararg{URangeIndex}}} Union{DenseArray{T,2}, SubArray{T,2,A,I}}
typealias UBlasVector{T,A<:DenseArray,I<:Tuple{Vararg{URangeIndex}}} Union{DenseArray{T,1}, SubArray{T,1,A,I}}
typealias UVecOrMat{T} Union{UBlasVector{T}, BlasMatrix{T}}
const libblas = Base.libblas_name
import Base.LinAlg: BlasReal, BlasComplex, BlasFloat, BlasInt, DimensionMismatch, chksquare, axpy!
import Base: copy!, blasfunc
for (fname, elty) in ((:dgemv_,:Float64),
(:sgemv_,:Float32),
(:zgemv_,:Complex128),
(:cgemv_,:Complex64))
@eval begin
#SUBROUTINE DGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
#* .. Scalar Arguments ..
# DOUBLE PRECISION ALPHA,BETA
# INTEGER INCX,INCY,LDA,M,N
# CHARACTER TRANS
#* .. Array Arguments ..
# DOUBLE PRECISION A(LDA,*),X(*),Y(*)
function gemv!(trans::Char, alpha::($elty), A::UVecOrMat{$elty}, X::StridedVector{$elty}, beta::($elty), Y::StridedVector{$elty})
m,n = size(A,1),size(A,2)
if trans == 'N' && (length(X) != n || length(Y) != m)
throw(DimensionMismatch("A has dimensions $(size(A)), X has length $(length(X)) and Y has length $(length(Y))"))
elseif trans == 'C' && (length(X) != m || length(Y) != n)
throw(DimensionMismatch("A' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
elseif trans == 'T' && (length(X) != m || length(Y) != n)
throw(DimensionMismatch("A.' has dimensions $n, $m, X has length $(length(X)) and Y has length $(length(Y))"))
end
ccall(($(blasfunc(fname)), libblas), Void,
(Ptr{UInt8}, Ptr{BlasInt}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt},
Ptr{$elty}, Ptr{$elty}, Ptr{BlasInt}),
&trans, &size(A,1), &size(A,2), &alpha,
A, &max(1,stride(A,2)), X, &stride(X,1),
&beta, Y, &stride(Y,1))
Y
end
function gemv(trans::Char, alpha::($elty), A::BlasMatrix{$elty}, X::StridedVector{$elty})
gemv!(trans, alpha, A, X, zero($elty), similar(X, $elty, size(A, (trans == 'N' ? 1 : 2))))
end
function gemv(trans::Char, A::BlasMatrix{$elty}, X::StridedVector{$elty})
gemv!(trans, one($elty), A, X, zero($elty), similar(X, $elty, size(A, (trans == 'N' ? 1 : 2))))
end
end
end
A = reshape(linspace(1,100,100), 10, 10)
y = [linspace(1,3,3);]
A1 = A[1:3:7, 2:3:8]
A2 = sub(A, 1:3:7, 2:3:8)
A3 = copy(A2)
z1 = gemv('N', 1.0, A1, y)
z2 = gemv('N', 1.0, A2, y)
z3 = gemv('N', 1.0, A3, y)
hcat(z1, z2, z3)
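The proposal above targets Julia 0.4 (`typealias`, `sub`). Here is a rough sketch of the same dispatch idea in current Julia 1.x, where `const` aliases replace `typealias`, `view` replaces `sub`, and a `:` index becomes a `Base.Slice`. `URangeIndex` and `BlasMatrix` mirror the names in the comment, and `blas_gemv` is a hypothetical wrapper. This is an illustration only, not LinearAlgebra's actual dispatch.

```julia
using LinearAlgebra

# Hypothetical 1.x analogue of the proposed aliases: a plain Matrix, or a
# view of one whose indices are all Int, unit ranges, or `:` (Base.Slice).
const URangeIndex = Union{Int, UnitRange{Int}, Base.Slice}
const BlasMatrix{T} = Union{Matrix{T},
                            SubArray{T,2,Matrix{T},<:Tuple{Vararg{URangeIndex}}}}

# Only a BlasMatrix method exists, so a step-range view raises a
# MethodError instead of reaching BLAS with an unrepresentable stride.
blas_gemv(A::BlasMatrix{Float64}, x::Vector{Float64}) = BLAS.gemv('N', 1.0, A, x)

A = reshape(collect(1.0:100.0), 10, 10)
y = [1.0, 2.0, 3.0]
z = blas_gemv(view(A, 1:3, 2:4), y)      # OK: unit ranges
# blas_gemv(view(A, 1:3:7, 2:3:8), y)    # MethodError: step ranges excluded
```

Restricting by type turns the bug into a loud error at dispatch time. The cost is that perfectly BLAS-compatible views, such as a step range in the second dimension only, are also rejected.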
See also #15308.
BLAS.gemm! can store its result in a SubArray, but the indices of the stored values seem to be wrong.
My versioninfo()
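One likely explanation for the wrong indices is that BLAS receives only a base pointer plus a leading dimension, and it assumes that elements within a column are adjacent. The sketch below assumes current Julia 1.x (`view` in place of `sub`). It compares the strides of a unit-range view with those of a step-range view. It then uses a hypothetical helper, `blas_offset`, to compute which element column-major addressing would actually read.

```julia
A = reshape(collect(1.0:100.0), 10, 10)   # A[k] == k, column-major

V1 = view(A, 1:3, 2:4)       # unit ranges
V2 = view(A, 1:3:7, 2:3:8)   # step ranges, as in this report

println(strides(V1))  # (1, 10): rows adjacent, BLAS can address this
println(strides(V2))  # (3, 30): rows 3 elements apart

# Hypothetical helper: column-major offset of element (i, j) from the base,
# the only addressing BLAS knows (lda = leading dimension).
blas_offset(i, j, lda) = (i - 1) + (j - 1) * lda

base_idx = LinearIndices(A)[1, 2]   # V2[1, 1] is A[1, 2], linear index 11
lda = stride(V2, 2)                 # 30
blas_read = A[base_idx + blas_offset(2, 1, lda)]

# BLAS-style addressing reads A[2, 2] (12.0) where V2[2, 1] is A[4, 2] (14.0)
println((blas_read, V2[2, 1]))      # (12.0, 14.0)
```

Passing `stride(X, 2)` as the leading dimension still finds the start of each column correctly. The step of 3 within a column is lost, however, which is consistent with results landing at the wrong indices.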