Skip to content

Commit

Permalink
krr
Browse files Browse the repository at this point in the history
  • Loading branch information
Rabab53 committed Aug 13, 2024
1 parent 5f7b4f6 commit ecaffbe
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 0 deletions.
66 changes: 66 additions & 0 deletions example/KRR.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
using Dagger
using Base: read, reinterpret
using LinearAlgebra
#using Distances

"""
    read_binary_to_matrix(filename, nrows, ncols) -> Matrix{Float32}

Read a raw binary file of `Float32` values into an `nrows × ncols` matrix.

The values are taken in the order they are stored in the file and fill the matrix in
column-major order (the same layout `reshape` produces).

# Arguments
- `filename`: path of the binary file to read.
- `nrows`, `ncols`: dimensions of the matrix stored in the file.

# Throws
- `ArgumentError` if a dimension is negative.
- `DimensionMismatch` if the file does not hold exactly `nrows * ncols` `Float32` values.
"""
function read_binary_to_matrix(filename::AbstractString, nrows::Integer, ncols::Integer)
    (nrows >= 0 && ncols >= 0) ||
        throw(ArgumentError("matrix dimensions must be non-negative, got ($nrows, $ncols)"))
    expected_bytes = nrows * ncols * sizeof(Float32)
    return open(filename, "r") do io
        # Check the size before reading so a wrong shape is reported clearly instead of
        # surfacing as a reshape/reinterpret failure after the whole file was loaded.
        actual_bytes = filesize(io)
        actual_bytes == expected_bytes || throw(DimensionMismatch(
            "file \"$filename\" holds $actual_bytes bytes, but a $nrows×$ncols " *
            "Float32 matrix requires $expected_bytes bytes"))
        # Read straight into a dense matrix: no intermediate byte buffer and no lazy
        # reinterpret/reshape wrapper that callers would have to copy afterwards.
        read!(io, Matrix{Float32}(undef, nrows, ncols))
    end
end

"""
    all_elements_less_than(matrix, threshold) -> Bool

Return `true` when every element of `matrix` is strictly less than `threshold`.

An empty `matrix` yields `true`. Any element for which `x < threshold` is `false`
(including `NaN`) makes the result `false`.

# Arguments
- `matrix`: array of real values to test.
- `threshold`: exclusive upper bound.
"""
function all_elements_less_than(matrix::AbstractArray{<:Real}, threshold::Real)
    # Predicate form of `all`: stops at the first failing element and does not
    # allocate a temporary Boolean array.
    return all(<(threshold), matrix)
end

"""
    gaussian_kernel(matrix, sigma)

Apply the Gaussian function `exp(-x^2 / (2 * sigma^2))` to every element `x` of `matrix`
and return the result as a new array of the same shape.

# Arguments
- `matrix`: array of real values (the example script passes pairwise distances).
- `sigma`: width parameter of the Gaussian.
"""
function gaussian_kernel(matrix::AbstractArray{<:Real}, sigma::Real)
    # Loop-invariant denominator, computed once outside the broadcast.
    denom = 2 * sigma^2
    # `@.` dots every operation (including the unary minus), so the whole expression
    # is evaluated in a single fused pass with one output allocation.
    return @. exp(-matrix^2 / denom)
end

# Example usage: build a Gaussian kernel matrix from a genotype file and ask Dagger for
# the precision required by each 1024×1024 tile of that kernel matrix.
filename = "/ibex/ai/home/omairyrm/gbgwas/hicmamain/scripts/genotype09.bin" # Replace with your actual file path
nrows, ncols = 1024, 30720 # Replace with the desired matrix dimensions

A = read_binary_to_matrix(filename, nrows, ncols)
println(size(A))
# Ensure `A` is a dense `Matrix{Float32}`; `convert` is a no-op when it already is one.
A = convert(Matrix{Float32}, A)
println(size(A))

# Gram matrix of the columns of A (syrk).
AA = A' * A

# Pairwise Euclidean distances between the columns of A:
#     d(i, j) = sqrt(‖a_i‖² + ‖a_j‖² - 2 a_i⋅a_j)
# The squared norms are the diagonal of the Gram matrix. Broadcasting the column vector
# against its adjoint yields every (i, j) combination without materialising dense
# `ncols × ncols` helper matrices. The arithmetic is accumulated in Float64.
sq_norms = Float64.(diag(AA))
# Round-off can make a squared distance slightly negative; clamp at zero so that `sqrt`
# does not throw a `DomainError`.
C = sqrt.(max.((-2 .* AA .+ sq_norms) .+ sq_norms', 0.0))
# Shift the diagonal entries by 10 before the kernel is applied.
C[diagind(C)] .+= 10
C = Float32.(C)
sigma = Float32(1e1)
println(typeof(C))

gaussian_matrix = gaussian_kernel(C, sigma)
display(gaussian_matrix)

# Partition the kernel matrix into 1024×1024 blocks for Dagger.
DA = view(gaussian_matrix, Blocks(1024, 1024));

threshold = Float32(10^-3)
MP = Dagger.adapt_precision(DA, threshold)
181 changes: 181 additions & 0 deletions src/array/adapt_precision2.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
"""
    tile_precision(A, global_norm, scalar_factor, tolerance)

Receive a tile and compute the precision required for that tile.
### Input
- `A` -- tile of size m x n
- `global_norm` -- global norm of the whole matrix
- `scalar_factor` -- factor by which the tile norm is scaled (callers in this file pass `max(mt, nt)`, the larger dimension of the tile grid)
- `tolerance` -- user-defined tolerance as required by the application
### Output
The required precision of the tile, as one of the strings `"Float8"`, `"Float16"`, `"Float32"` or `"Float64"`
"""
function tile_precision(A, global_norm, scalar_factor, tolerance)
    # Purpose: pick the lowest precision that is sufficient for tile `A` and return its
    # name as a String ("Float8", "Float16", "Float32" or "Float64").

    # Norm of the tile: square root of the summed squared norms of its elements.
    tile_norm = sqrt(mapreduce(LinearAlgebra.norm_sqr, +, A))

    # Tile norm relative to the norm of the whole matrix, scaled by `scalar_factor`.
    # Computed once and reused for every comparison below.
    scaled_norm = tile_norm * scalar_factor / global_norm

    # The bounds grow as the candidate precision increases, so testing the lowest
    # precision first returns the cheapest format that satisfies the tolerance.
    # FP8 (E4M3 and E5M2) support is planned for the near future; 0.0625 is the
    # constant used for it here.
    if scaled_norm < tolerance / 0.0625
        return "Float8"
    elseif scaled_norm < tolerance / eps(Float16)
        return "Float16"
    elseif scaled_norm < tolerance / eps(Float32)
        return "Float32"
    else
        return "Float64"
    end
end

"""
    adapt_precision(A::UpperTriangular{T,<:DArray{T,2}}, tolerance::Float64) where {T}

Disabled `UpperTriangular` variant of `adapt_precision`, kept for reference only.

NOTE: the description and the disabled code deliberately share this single string
literal. A docstring that is directly followed by another bare string literal makes Julia
fail at load time (`cannot document the following expression`). Keep the blank line after
the closing quotes so that this string is not parsed as a docstring for what follows.

It iterates over all tiles and calculates the required precision per tile based on a
formulation from Nicholas J. Higham.
### Input
- `A` -- Dagger UpperTriangular array of tiles with real values
- `tolerance` -- user-defined tolerance as required by the application

The precision map `MP` is not an argument; the disabled code creates it internally.
### Output
The required precision of each tile

Disabled implementation:

function adapt_precision(A::UpperTriangular{T,<:DArray{T,2}}, tolerance::Float64) where {T}
Ac = parent(A).chunks
mt, nt = size(Ac)
global_norm = LinearAlgebra.norm2(A)
MP = fill("T", mt, nt)
DMP = view(MP, Blocks(1, 1))
MPc = parent(DMP).chunks
for n in range(1, nt)
for m in range(1, n)
if m == n
MPc[m, n] = Dagger.@spawn tile_precision(
UpperTriangular(Ac[m, n]),
global_norm,
max(mt, nt),
tolerance)
else
MPc[m, n] = Dagger.@spawn tile_precision(
Ac[m, n],
global_norm,
max(mt, nt),
tolerance)
end
end
end
return UpperTriangular(collect(DMP))
end
"""

"""
    adapt_precision(A::LowerTriangular{T,<:DArray{T,2}}, tolerance::T) where {T}

Description of the disabled `LowerTriangular` variant that is kept in the string literal below.
It iterates over all tiles and calculates the required precision per tile based on a formulation from Nicholas J. Higham.
### Input
- `A` -- Dagger LowerTriangular array of tiles with real values
- `tolerance` -- user-defined tolerance as required by the application

The precision map `MP` is not an argument; the disabled code creates it internally.
### Output
The required precision of each tile
"""

"""
function adapt_precision(A::LowerTriangular{T,<:DArray{T,2}}, tolerance::T) where {T}
Ac = parent(A).chunks
mt, nt = size(Ac)
global_norm = LinearAlgebra.norm2(A)
MP = fill("T", mt, nt)
DMP = view(MP, Blocks(1, 1))
MPc = parent(DMP).chunks
for m in range(1, mt)
for n in range(1, m)
if m == n
MPc[m, n] = Dagger.@spawn tile_precision(
LowerTriangular(Ac[m, n]),
global_norm,
max(mt, nt),
tolerance)
else
MPc[m, n] = Dagger.@spawn tile_precision(
Ac[m, n],
global_norm,
max(mt, nt),
tolerance)
end
end
end
return LowerTriangular(collect(DMP))
end
"""

"""
    adapt_precision(A::DArray{T,2}, tolerance::T) where {T}

Iterate over all tiles of `A` and calculate the required precision per tile, based on a formulation from Nicholas J. Higham.
### Input
- `A` -- Dagger array of tiles with real values
- `tolerance` -- user-defined tolerance as required by the application, given in the element type `T` of `A`

The precision map is not an argument; it is created internally with one entry per tile.
### Output
The required precision of each tile, one entry per tile of `A`
"""
function adapt_precision(A::DArray{T,2}, tolerance::T) where {T}
    # Purpose: spawn one `tile_precision` task per tile of `A` and return the collected
    # map of required precisions.

    # Tile grid of `A`: `mt` rows and `nt` columns of chunks.
    tiles = parent(A).chunks
    mt, nt = size(tiles)

    global_norm = LinearAlgebra.norm2(A)

    # Precision map with one entry per tile, partitioned with `Blocks(1, 1)` so that
    # every entry has its own chunk slot. "T" is only the initial filler value.
    MP = fill("T", mt, nt)
    DMP = view(MP, Blocks(1, 1))
    precision_chunks = DMP.chunks

    # The same scaling factor is passed for every tile.
    scale = max(mt, nt)
    for m in 1:mt, n in 1:nt
        precision_chunks[m, n] =
            Dagger.@spawn tile_precision(tiles[m, n], global_norm, scale, tolerance)
    end

    return collect(DMP)
end

0 comments on commit ecaffbe

Please sign in to comment.