Skip to content

Commit

Permalink
Switch to CuArray storage_type + new default S value
Browse files Browse the repository at this point in the history
  • Loading branch information
nHackel authored and dpo committed Apr 20, 2024
1 parent 9786291 commit be5e5bc
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 14 deletions.
15 changes: 2 additions & 13 deletions ext/LinearOperatorsCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,8 @@
module LinearOperatorsCUDAExt

using LinearOperators, LinearOperators.FastClosures, LinearOperators.LinearAlgebra
using LinearOperators
isdefined(Base, :get_extension) ? (using CUDA) : (using ..CUDA)

function LinearOperators.LinearOperator(
M::CuArray{T, 2, D};
symmetric = false,
hermitian = false,
S = CuArray{T, 1, D},
) where {T, D}
nrow, ncol = size(M)
prod! = @closure (res, v, α, β) -> mul!(res, M, v, α, β)
tprod! = @closure (res, u, α, β) -> mul!(res, transpose(M), u, α, β)
ctprod! = @closure (res, w, α, β) -> mul!(res, adjoint(M), w, α, β)
LinearOperators.LinearOperator{T}(nrow, ncol, symmetric, hermitian, prod!, tprod!, ctprod!, S = S)
end
LinearOperators.storage_type(::CuArray{T, 2, D}) where {T, D} = CuArray{T, 1, D}

end # module
2 changes: 1 addition & 1 deletion src/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ function LinearOperator(
M::AbstractMatrix{T};
symmetric = false,
hermitian = false,
S = Vector{T},
S = storage_type(M)
) where {T}
nrow, ncol = size(M)
prod! = @closure (res, v, α, β) -> mul!(res, M, v, α, β)
Expand Down

0 comments on commit be5e5bc

Please sign in to comment.