Skip to content

Commit

Permalink
Refactor the SLURM_NTASKS and SLURM_JOB_ID functionality out into…
Browse files Browse the repository at this point in the history
… separate utility functions, and add some more unit tests to increase code coverage
  • Loading branch information
DilumAluthge committed Feb 9, 2025
1 parent 410734b commit 03df383
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 25 deletions.
52 changes: 28 additions & 24 deletions src/slurmmanager.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,38 @@ mutable struct SlurmManager <: ClusterManager
srun_proc

function SlurmManager(; launch_timeout=60.0, srun_post_exit_sleep=0.01)
ntasks = get_slurm_ntasks_int()
jobid = get_slurm_jobid_int()

jobid =
if "SLURM_JOB_ID" in keys(ENV)
ENV["SLURM_JOB_ID"]
elseif "SLURM_JOBID" in keys(ENV)
ENV["SLURM_JOBID"]
else
error("""
SlurmManager must be constructed inside a slurm allocation environemnt.
SLURM_JOB_ID or SLURM_JOBID must be defined.
""")
end

ntasks =
if "SLURM_NTASKS" in keys(ENV)
ENV["SLURM_NTASKS"]
else
error("""
SlurmManager must be constructed inside a slurm environment with a specified number of tasks.
SLURM_NTASKS must be defined.
""")
end
new(jobid, ntasks, launch_timeout, srun_post_exit_sleep, nothing)
end
end

jobid = parse(Int, jobid)
ntasks = parse(Int, ntasks)
function get_slurm_ntasks_int()
if haskey(ENV, "SLURM_NTASKS")
ntasks_str = ENV["SLURM_NTASKS"]
else
msg = "SlurmManager must be constructed inside a Slurm allocation environment." *
"SLURM_NTASKS must be defined."
error(msg)
end
ntasks_int = parse(Int, ntasks_str)::Int
return ntasks_int
end

new(jobid, ntasks, launch_timeout, srun_post_exit_sleep, nothing)
function get_slurm_jobid_int()
if haskey(ENV, "SLURM_JOB_ID")
jobid_str = ENV["SLURM_JOB_ID"]
elseif haskey(ENV, "SLURM_JOBID")
jobid_str = ENV["SLURM_JOBID"]
else
msg = "SlurmManager must be constructed inside a Slurm allocation environment." *
"SLURM_JOB_ID or SLURM_JOBID must be defined."
error(msg)
end

jobid_int = parse(Int, jobid_str)::Int
return jobid_int
end

@static if Base.VERSION >= v"1.9.0"
Expand Down
6 changes: 5 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Distributed
import Test

# Bring some names into scope, just for convenience:
using Test: @testset, @test
using Test: @testset, @test, @test_throws, @test_logs

const original_JULIA_DEBUG = strip(get(ENV, "JULIA_DEBUG", ""))
if isempty(original_JULIA_DEBUG)
Expand All @@ -16,6 +16,10 @@ else
end

@testset "SlurmClusterManager.jl" begin
@testset "Unit tests" begin
include("unit.jl")
end

# test that slurm is available
@test !(Sys.which("sinfo") === nothing)

Expand Down
26 changes: 26 additions & 0 deletions test/unit.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
@testset "get_slurm_ntasks_int()" begin
x = withenv("SLURM_NTASKS" => "12") do
SlurmClusterManager.get_slurm_ntasks_int()
end
@test x == 12

withenv("SLURM_NTASKS" => nothing) do
@test_throws ErrorException SlurmClusterManager.get_slurm_ntasks_int()
end
end

@testset "get_slurm_jobid_int()" begin
x = withenv("SLURM_JOB_ID" => "34", "SLURM_JOBID" => nothing) do
SlurmClusterManager.get_slurm_jobid_int()
end
@test x == 34

x = withenv("SLURM_JOB_ID" => nothing, "SLURM_JOBID" => "56") do
SlurmClusterManager.get_slurm_jobid_int()
end
@test x == 56

withenv("SLURM_JOB_ID" => nothing, "SLURM_JOBID" => nothing) do
@test_throws ErrorException SlurmClusterManager.get_slurm_jobid_int()
end
end

0 comments on commit 03df383

Please sign in to comment.