Skip to content

Commit

Permalink
combine Consts and API modules (#650)
Browse files Browse the repository at this point in the history
* combine Consts and API modules
* @t-bltg's fixes
* Builds docs only for functions in `API` module

Co-authored-by: Mosè Giordano <[email protected]>
  • Loading branch information
simonbyrne and giordano authored Sep 30, 2022
1 parent 5a9ed4e commit 61947d9
Show file tree
Hide file tree
Showing 32 changed files with 319 additions and 322 deletions.
2 changes: 1 addition & 1 deletion docs/src/reference/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ MPI.set_errorhandler!
## Miscellaneous

```@docs
MPI.Consts.@const_ref
MPI.API.@const_ref
```
1 change: 1 addition & 0 deletions docs/src/reference/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@

```@autodocs
Modules = [MPI.API]
Order = [:function]
```
2 changes: 1 addition & 1 deletion gen/src/MPIgenerator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ module MPIgenerator
end
write(src, join(lines, "\n"))

dst = normpath(@__DIR__, "..", "..", "src", "auto_generated_api.jl")
dst = normpath(@__DIR__, "..", "..", "src", "api", "generated_api.jl")
mv(src, dst; force=true) # move the generated file to src
rm(out) # cleanup

Expand Down
63 changes: 18 additions & 45 deletions src/MPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,46 +17,31 @@ function deserialize(x)
Serialization.deserialize(s)
end

primitive type SentinelPtr Sys.WORD_SIZE
end

primitive type MPIPtr Sys.WORD_SIZE
end
@assert sizeof(MPIPtr) == sizeof(Ptr{Cvoid})
Base.cconvert(::Type{MPIPtr}, x::SentinelPtr) = x
Base.unsafe_convert(::Type{MPIPtr}, x::SentinelPtr) = reinterpret(MPIPtr, x)


function _doc_external(fname)
"""
- `$fname` man page: [OpenMPI](https://www.open-mpi.org/doc/current/man3/$fname.3.php), [MPICH](https://www.mpich.org/static/docs/latest/www3/$fname.html)
"""
end

"""
MPIError
import MPIPreferences

if MPIPreferences.binary == "MPICH_jll"
import MPICH_jll: libmpi, libmpi_handle, mpiexec
const libmpiconstants = nothing
elseif MPIPreferences.binary == "OpenMPI_jll"
import OpenMPI_jll: libmpi, libmpi_handle, mpiexec
const libmpiconstants = nothing
elseif MPIPreferences.binary == "MicrosoftMPI_jll"
import MicrosoftMPI_jll: libmpi, libmpi_handle, mpiexec
const libmpiconstants = nothing
elseif MPIPreferences.binary == "MPItrampoline_jll"
import MPItrampoline_jll: MPItrampoline_jll, libmpi, libmpi_handle, mpiexec
const libmpiconstants = MPItrampoline_jll.libload_time_mpi_constants_path
elseif MPIPreferences.binary == "system"
import MPIPreferences.System: libmpi, libmpi_handle, mpiexec
const libmpiconstants = nothing
else
error("Unknown MPI binary: $(MPIPreferences.binary)")
Error thrown when an MPI function returns an error code. The `code` field contains the MPI error code.
"""
struct MPIError <: Exception
code::Cint
end
function Base.show(io::IO, err::MPIError)
print(io, "MPIError(", err.code, "): ", error_string(err))
end


include("consts/consts.jl")
using .Consts


include("api/api.jl")
using .API
const Consts = API

# These functions are run after reading the values of the constants above)
const _mpi_load_time_hooks = Any[]
Expand All @@ -73,21 +58,9 @@ function run_load_time_hooks()
nothing
end

using MPIPreferences
include("implementations.jl")
include("error.jl")

module API
import ..libmpi, ..libmpi_handle, ..MPIPtr
import ..use_stdcall, ..MPIError, ..@mpicall, ..@mpichk
using ..Consts

for name in filter(n -> startswith(string(n), "MPI_"), names(Consts; all = true))
@eval $name = Consts.$name # signatures need types
end

include("auto_generated_api.jl")
end

include("info.jl")
include("group.jl")
include("comm.jl")
Expand Down Expand Up @@ -140,7 +113,7 @@ function __init__()

# Needs to be called after `dlopen`. Use `invokelatest` so that `cglobal`
# calls don't trigger early `dlopen`-ing of the library.
Base.invokelatest(Consts.init_consts)
Base.invokelatest(API.init_consts)

# disable UCX memory cache, since it doesn't work correctly
# https://github.com/openucx/ucx/issues/5061
Expand All @@ -157,7 +130,7 @@ function __init__()
end

if MPIPreferences.binary == "MPItrampoline_jll" && !haskey(ENV, "MPITRAMPOLINE_MPIEXEC")
ENV["MPITRAMPOLINE_MPIEXEC"] = MPItrampoline_jll.mpich_mpiexec_path
ENV["MPITRAMPOLINE_MPIEXEC"] = API.MPItrampoline_jll.mpich_mpiexec_path
end

run_load_time_hooks()
Expand Down
138 changes: 138 additions & 0 deletions src/api/api.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
module API

export MPI_Aint, MPI_Count, MPI_Offset, MPI_Status,
MPI_Comm, MPI_Datatype, MPI_Errhandler, MPI_File, MPI_Group,
MPI_Info, MPI_Message, MPI_Op, MPI_Request, MPI_Win,
libmpi, mpiexec, @mpichk, @mpicall, MPIPtr, SentinelPtr, FeatureLevelError

import MPIPreferences
using Libdl

if MPIPreferences.binary == "MPICH_jll"
import MPICH_jll: libmpi, libmpi_handle, mpiexec
const libmpiconstants = nothing
elseif MPIPreferences.binary == "OpenMPI_jll"
import OpenMPI_jll: libmpi, libmpi_handle, mpiexec
const libmpiconstants = nothing
elseif MPIPreferences.binary == "MicrosoftMPI_jll"
import MicrosoftMPI_jll: libmpi, libmpi_handle, mpiexec
const libmpiconstants = nothing
elseif MPIPreferences.binary == "MPItrampoline_jll"
import MPItrampoline_jll: MPItrampoline_jll, libmpi, libmpi_handle, mpiexec
const libmpiconstants = MPItrampoline_jll.libload_time_mpi_constants_path
elseif MPIPreferences.binary == "system"
import MPIPreferences.System: libmpi, libmpi_handle, mpiexec
const libmpiconstants = nothing
else
error("Unknown MPI binary: $(MPIPreferences.binary)")
end

import ..MPIError
const initexprs = Any[]

"""
@const_ref name T expr
Defines an constant binding
```julia
const name = Ref{T}()
```
and adds a hook to execute
```julia
name[] = expr
```
at module initialization time.
"""
macro const_ref(name, T, expr)
push!(initexprs, :($name[] = $expr))
:(const $(esc(name)) = Ref{$T}())
end

@static if MPIPreferences.abi == "MPICH"
include("mpich.jl")
elseif MPIPreferences.abi == "OpenMPI"
include("openmpi.jl")
elseif MPIPreferences.abi == "MicrosoftMPI"
include("microsoftmpi.jl")
elseif MPIPreferences.abi == "MPItrampoline"
include("mpitrampoline.jl")
elseif MPIPreferences.abi == "HPE MPT"
include("mpt.jl")
else
error("Unknown MPI ABI $(MPIPreferences.abi)")
end

primitive type SentinelPtr Sys.WORD_SIZE
end

primitive type MPIPtr Sys.WORD_SIZE
end
@assert sizeof(MPIPtr) == sizeof(Ptr{Cvoid})
Base.cconvert(::Type{MPIPtr}, x::SentinelPtr) = x
Base.unsafe_convert(::Type{MPIPtr}, x::SentinelPtr) = reinterpret(MPIPtr, x)


# Initialize the ref constants from the library.
# This is not `API.__init__`, as it should be called _after_
# `dlopen` to ensure the library is opened correctly.
@eval function init_consts()
$(Expr(:block, initexprs...))
end

const use_stdcall = startswith(basename(libmpi), "msmpi")

macro mpicall(expr)
@assert expr isa Expr && expr.head == :call && expr.args[1] == :ccall

# On unix systems we call the global symbols to allow for LD_PRELOAD interception
# It can be emulated in Windows (via Libdl.dllist), but this is not fast.
if Sys.isunix() && expr.args[2].head == :tuple &&
(VERSION v"1.5-" || expr.args[2].args[1] :(:MPI_Get_library_version))
expr.args[2] = expr.args[2].args[1]
end

# Microsoft MPI uses stdcall calling convention
# this only affects 32-bit Windows
# unfortunately we need to use ccall to call Get_library_version
# so check using library name instead
if use_stdcall
insert!(expr.args, 3, :stdcall)
end
return esc(expr)
end

"""
FeatureLevelError
Error thrown if a feature is not implemented in the current MPI backend.
"""
struct FeatureLevelError <: Exception
function_name::Symbol
min_version::VersionNumber # minimal MPI version required for this feature to be available
end
function Base.show(io::IO, err::FeatureLevelError)
print(io, "FeatureLevelError($(err.function_name)): Minimum MPI version is $(err.min_version)")
end

macro mpichk(expr, min_version=nothing)
if !isnothing(min_version) && expr.args[2].head == :tuple
fn = expr.args[2].args[1].value
if isnothing(dlsym(libmpi_handle, fn; throw_error=false))
return quote
throw(FeatureLevelError($(QuoteNode(fn)), $min_version))
end
end
end

expr = macroexpand(@__MODULE__, :(@mpicall($expr)))
# MPI_SUCCESS is defined to be 0
:((errcode = $(esc(expr))) == 0 || throw(MPIError(errcode)))
end


include("generated_api.jl")

# since this is called by invokelatest, it isn't automatically precompiled
precompile(init_consts, ())

end
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion src/buffers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ MPIPtr

struct InPlace
end
Base.cconvert(::Type{MPIPtr}, ::InPlace) = Consts.MPI_IN_PLACE[]
Base.cconvert(::Type{MPIPtr}, ::InPlace) = API.MPI_IN_PLACE[]


"""
Expand Down
4 changes: 2 additions & 2 deletions src/collective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ If only one buffer `sendrecvbuf` is used, then data is overwritten.
$(_doc_external("MPI_Alltoall"))
"""
function Alltoall!(sendbuf::UBuffer, recvbuf::UBuffer, comm::Comm)
if sendbuf.data !== Consts.MPI_IN_PLACE[] && sendbuf.nchunks !== nothing
if sendbuf.data !== API.MPI_IN_PLACE[] && sendbuf.nchunks !== nothing
@assert sendbuf.nchunks >= Comm_size(comm)
end
if recvbuf.nchunks !== nothing
Expand Down Expand Up @@ -521,7 +521,7 @@ Similar to [`Alltoall!`](@ref), except with different size chunks per process.
$(_doc_external("MPI_Alltoallv"))
"""
function Alltoallv!(sendbuf::VBuffer, recvbuf::VBuffer, comm::Comm)
if sendbuf.data !== Consts.MPI_IN_PLACE[]
if sendbuf.data !== API.MPI_IN_PLACE[]
@assert length(sendbuf.counts) >= Comm_size(comm)
end
@assert length(recvbuf.counts) >= Comm_size(comm)
Expand Down
22 changes: 11 additions & 11 deletions src/comm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,25 @@ Base.unsafe_convert(::Type{MPI_Comm}, comm::Comm) = comm.val
Base.unsafe_convert(::Type{Ptr{MPI_Comm}}, comm::Comm) = convert(Ptr{MPI_Comm}, pointer_from_objref(comm))


const COMM_NULL = Comm(Consts.MPI_COMM_NULL[])
add_load_time_hook!(() -> COMM_NULL.val = Consts.MPI_COMM_NULL[])
const COMM_NULL = Comm(API.MPI_COMM_NULL[])
add_load_time_hook!(() -> COMM_NULL.val = API.MPI_COMM_NULL[])

"""
MPI.COMM_WORLD
A communicator containing all processes with which the local rank can communicate at
initialization. In a typical "static-process" model, this will be all processes.
"""
const COMM_WORLD = Comm(Consts.MPI_COMM_WORLD[])
add_load_time_hook!(() -> COMM_WORLD.val = Consts.MPI_COMM_WORLD[])
const COMM_WORLD = Comm(API.MPI_COMM_WORLD[])
add_load_time_hook!(() -> COMM_WORLD.val = API.MPI_COMM_WORLD[])

"""
MPI.COMM_SELF
A communicator containing only the local process.
"""
const COMM_SELF = Comm(Consts.MPI_COMM_SELF[])
add_load_time_hook!(() -> COMM_SELF.val = Consts.MPI_COMM_SELF[])
const COMM_SELF = Comm(API.MPI_COMM_SELF[])
add_load_time_hook!(() -> COMM_SELF.val = API.MPI_COMM_SELF[])

Comm() = Comm(COMM_NULL.val)

Expand Down Expand Up @@ -173,7 +173,7 @@ $(_doc_external("MPI_Comm_split"))
"""
function Comm_split(comm::Comm, color::Union{Integer, Nothing}, key::Integer)
if isnothing(color)
color = Consts.MPI_UNDEFINED[]
color = API.MPI_UNDEFINED[]
end
newcomm = Comm()
API.MPI_Comm_split(comm, color, key, newcomm)
Expand All @@ -185,7 +185,7 @@ mutable struct SplitType
val::Cint
end
const COMM_TYPE_SHARED = SplitType(-1)
add_load_time_hook!(() -> COMM_TYPE_SHARED.val = Consts.MPI_COMM_TYPE_SHARED[])
add_load_time_hook!(() -> COMM_TYPE_SHARED.val = API.MPI_COMM_TYPE_SHARED[])


"""
Expand All @@ -205,7 +205,7 @@ $(_doc_external("MPI_Comm_split_type"))
"""
function Comm_split_type(comm::Comm, split_type, key::Integer; kwargs...)
if isnothing(split_type)
split_type = Consts.MPI_UNDEFINED[]
split_type = API.MPI_UNDEFINED[]
elseif split_type isa SplitType
split_type = split_type.val
end
Expand Down Expand Up @@ -276,7 +276,7 @@ The total number of available slots, or `nothing` if it is not defined. This is
This is typically dependent on the MPI implementation: for MPICH-based implementations, this is specified by the `-usize` argument. OpenMPI defines a default value based on the number of processes available.
"""
function universe_size()
ptr = unsafe_get_attr(COMM_WORLD, Consts.MPI_UNIVERSE_SIZE[])
ptr = unsafe_get_attr(COMM_WORLD, API.MPI_UNIVERSE_SIZE[])
isnothing(ptr) && return nothing
return Int(unsafe_load(Ptr{Cint}(ptr)))
end
Expand All @@ -287,7 +287,7 @@ end
The maximum value tag value for point-to-point operations.
"""
function tag_ub()
ptr = something(unsafe_get_attr(COMM_WORLD, Consts.MPI_TAG_UB[]))
ptr = something(unsafe_get_attr(COMM_WORLD, API.MPI_TAG_UB[]))
return Int(unsafe_load(Ptr{Cint}(ptr)))
end

Expand Down
Loading

0 comments on commit 61947d9

Please sign in to comment.