Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add create_hvector (support MPI_Type_create_hvector) #635

Merged
merged 4 commits into from
Sep 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/reference/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ MPI.to_type
MPI.Types.extent
MPI.Types.create_contiguous
MPI.Types.create_vector
MPI.Types.create_hvector
MPI.Types.create_subarray
MPI.Types.create_struct
MPI.Types.create_resized
Expand Down
31 changes: 30 additions & 1 deletion src/datatypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,40 @@ function create_vector(count::Integer, blocklength::Integer, stride::Integer, ol
end
function create_vector!(newtype::Datatype, count::Integer, blocklength::Integer, stride::Integer, oldtype::Datatype)
# int MPI_Type_vector(int count, int blocklength, int stride,
# MPI_Datatype oldtype, MPI_Datatype *newtype)
# MPI_Datatype oldtype, MPI_Datatype *newtype)
API.MPI_Type_vector(count, blocklength, stride, oldtype, newtype)
return newtype
end

"""
MPI.Types.create_hvector(count::Integer, blocklength::Integer, stride::Integer, oldtype::MPI.Datatype)

Create a derived [`Datatype`](@ref) that replicates `oldtype` into locations that
consist of equally spaced (bytes) blocks.

Note that [`MPI.Types.commit!`](@ref) must be used before the datatype can be used for
communication.

# Example

```julia
datatype = MPI.Types.create_hvector(3, 2, 5, MPI.Datatype(Int64))
MPI.Types.commit!(datatype)
```

# External links
$(_doc_external("MPI_Type_create_hvector"))
"""
function create_hvector(count::Integer, blocklength::Integer, stride::Integer, oldtype::Datatype)
finalizer(free, create_hvector!(Datatype(), count, blocklength, stride, oldtype))
end
function create_hvector!(newtype::Datatype, count::Integer, blocklength::Integer, stride::Integer, oldtype::Datatype)
# int MPI_Type_create_hvector(int count, int blocklength, MPI_Aint stride,
# MPI_Datatype oldtype, MPI_Datatype *newtype)
API.MPI_Type_create_hvector(count, blocklength, MPI_Aint(stride), oldtype, newtype)
return newtype
end

"""
MPI.Types.create_subarray(sizes, subsizes, offset, oldtype::Datatype;
rowmajor=false)
Expand Down
11 changes: 10 additions & 1 deletion test/test_datatype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,23 +157,32 @@ end
# create_vector
RowType = MPI.Types.create_vector(8, 1, 8, MPI.DOUBLE)
@test typeof(RowType) == MPI.Datatype
@test MPI.Types.size(RowType) == 8MPI.Types.size(MPI.DOUBLE)

# create_hvector
DType = MPI.Types.create_hvector(8, 1, 8, MPI.DOUBLE)
@test typeof(DType) == MPI.Datatype
t-bltg marked this conversation as resolved.
Show resolved Hide resolved
@test MPI.Types.size(DType) == 8MPI.Types.size(MPI.DOUBLE)

# create_subarray
SubMatrixType = MPI.Types.create_subarray((8, 8), (4, 4), (0, 0), MPI.INT64_T)
@test typeof(SubMatrixType) == MPI.Datatype
@test MPI.Types.size(SubMatrixType) == 4^2 * MPI.Types.size(MPI.INT64_T)

# create_struct + _resized
oldtypes = MPI.Datatype.([Float32, Char, Float64])
len = [4, 1, 1]
disp = Vector{Int}(undef, 3)
disp[1] = 0
disp[2] = disp[1] + 4 * sizeof(Float32)
disp[2] = disp[1] + 4sizeof(Float32)
disp[3] = disp[2] + sizeof(Char)

tmp = MPI.Types.create_struct(len, disp, oldtypes)
@test typeof(tmp) == MPI.Datatype
@test MPI.Types.size(tmp) == sum(len .* MPI.Types.size.(oldtypes))
ParticleMPI = MPI.Types.create_resized(tmp, 0, sizeof(Particle))
@test typeof(ParticleMPI) == MPI.Datatype
@test MPI.Types.size(ParticleMPI) == sum(len .* MPI.Types.size.(oldtypes))
end


Expand Down