diff --git a/src/slurmmanager.jl b/src/slurmmanager.jl index ff6573e..e2fab25 100644 --- a/src/slurmmanager.jl +++ b/src/slurmmanager.jl @@ -12,34 +12,38 @@ mutable struct SlurmManager <: ClusterManager srun_proc function SlurmManager(; launch_timeout=60.0, srun_post_exit_sleep=0.01) + ntasks = get_slurm_ntasks_int() + jobid = get_slurm_jobid_int() - jobid = - if "SLURM_JOB_ID" in keys(ENV) - ENV["SLURM_JOB_ID"] - elseif "SLURM_JOBID" in keys(ENV) - ENV["SLURM_JOBID"] - else - error(""" - SlurmManager must be constructed inside a slurm allocation environemnt. - SLURM_JOB_ID or SLURM_JOBID must be defined. - """) - end - - ntasks = - if "SLURM_NTASKS" in keys(ENV) - ENV["SLURM_NTASKS"] - else - error(""" - SlurmManager must be constructed inside a slurm environment with a specified number of tasks. - SLURM_NTASKS must be defined. - """) - end + new(jobid, ntasks, launch_timeout, srun_post_exit_sleep, nothing) + end +end - jobid = parse(Int, jobid) - ntasks = parse(Int, ntasks) +function get_slurm_ntasks_int() + if haskey(ENV, "SLURM_NTASKS") + ntasks_str = ENV["SLURM_NTASKS"] + else + msg = "SlurmManager must be constructed inside a Slurm allocation environment." * + "SLURM_NTASKS must be defined." + error(msg) + end + ntasks_int = parse(Int, ntasks_str)::Int + return ntasks_int +end - new(jobid, ntasks, launch_timeout, srun_post_exit_sleep, nothing) +function get_slurm_jobid_int() + if haskey(ENV, "SLURM_JOB_ID") + jobid_str = ENV["SLURM_JOB_ID"] + elseif haskey(ENV, "SLURM_JOBID") + jobid_str = ENV["SLURM_JOBID"] + else + msg = "SlurmManager must be constructed inside a Slurm allocation environment." * + "SLURM_JOB_ID or SLURM_JOBID must be defined." + error(msg) end + + jobid_int = parse(Int, jobid_str)::Int + return jobid_int end @static if Base.VERSION >= v"1.9.0" diff --git a/test/runtests.jl b/test/runtests.jl index 8772e5b..db3f4e0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,7 +6,7 @@ import Distributed import Test # Bring some names into scope, just for convenience: -using Test: @testset, @test +using Test: @testset, @test, @test_throws, @test_logs const original_JULIA_DEBUG = strip(get(ENV, "JULIA_DEBUG", "")) if isempty(original_JULIA_DEBUG) @@ -16,6 +16,10 @@ else end @testset "SlurmClusterManager.jl" begin + @testset "Unit tests" begin + include("unit.jl") + end + # test that slurm is available @test !(Sys.which("sinfo") === nothing) diff --git a/test/unit.jl b/test/unit.jl new file mode 100644 index 0000000..9d6f6d2 --- /dev/null +++ b/test/unit.jl @@ -0,0 +1,26 @@ +@testset "get_slurm_ntasks_int()" begin + x = withenv("SLURM_NTASKS" => "12") do + SlurmClusterManager.get_slurm_ntasks_int() + end + @test x == 12 + + withenv("SLURM_NTASKS" => nothing) do + @test_throws ErrorException SlurmClusterManager.get_slurm_ntasks_int() + end +end + +@testset "get_slurm_jobid_int()" begin + x = withenv("SLURM_JOB_ID" => "34", "SLURM_JOBID" => nothing) do + SlurmClusterManager.get_slurm_jobid_int() + end + @test x == 34 + + x = withenv("SLURM_JOB_ID" => nothing, "SLURM_JOBID" => "56") do + SlurmClusterManager.get_slurm_jobid_int() + end + @test x == 56 + + withenv("SLURM_JOB_ID" => nothing, "SLURM_JOBID" => nothing) do + @test_throws ErrorException SlurmClusterManager.get_slurm_jobid_int() + end +end