Skip to content

Commit

Permalink
fixup! allow undefined initial_value
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Sep 11, 2023
1 parent 154d764 commit 912257a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
18 changes: 10 additions & 8 deletions base/scopedvalues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ julia> sval[]
implementation is available from the package ScopedValues.jl.
"""
mutable struct ScopedValue{T}
const initial_value::T
ScopedValue{T}() where T = new()
ScopedValue{T}(val) where T = new{T}(val)
ScopedValue(val::T) where T = new{T}(val)
const has_default::Bool
const default::T
ScopedValue{T}() where T = new(false)
ScopedValue{T}(val) where T = new{T}(true, val)
ScopedValue(val::T) where T = new{T}(true, val)
end

Base.eltype(::ScopedValue{T}) where {T} = T
Base.isassigned(val::ScopedValue) = val.has_default

const ScopeStorage = Base.PersistentDict{ScopedValue, Any}

Expand Down Expand Up @@ -99,12 +101,12 @@ function get(val::ScopedValue{T}) where {T}
# Inline current_scope to avoid doing the type assertion twice.
scope = current_task().scope
if scope === nothing
isdefined(val, :initial_value) && return Some(val.initial_value)
isassigned(val) && return Some(val.default)
return nothing
end
scope = scope::Scope
if isdefined(val, :initial_value)
return Some(Base.get(scope.values, val, val.initial_value)::T)
if isassigned(val)
return Some(Base.get(scope.values, val, val.default)::T)
else
v = Base.get(scope.values, val, novalue)
v === novalue || return Some(v::T)
Expand All @@ -124,7 +126,7 @@ function Base.show(io::IO, val::ScopedValue)
print(io, '(')
v = get(val)
if v === nothing
print(io, "(undefined)")
print(io, "undefined")
else
show(IOContext(io, :typeinfo => eltype(val)), something(v))
end
Expand Down
10 changes: 6 additions & 4 deletions test/scopedvalues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ import Base: ScopedValues
@testset "errors" begin
@test ScopedValue{Float64}(1)[] == 1.0
@test_throws InexactError ScopedValue{Int}(1.5)
var = ScopedValue(1)
@test_throws MethodError var[] = 2
val = ScopedValue(1)
@test_throws MethodError val[] = 2
with() do
@test_throws MethodError var[] = 2
@test_throws MethodError val[] = 2
end
@test_throws MethodError ScopedValue{Int}()
val = ScopedValue{Int}()
@test_throws KeyError val[]
@test_throws MethodError ScopedValue()
end

Expand Down Expand Up @@ -61,6 +62,7 @@ import Base.Threads: @spawn
end

@testset "show" begin
@test sprint(show, ScopedValue{Int}()) == "ScopedValue{$Int}(undefined)"
@test sprint(show, sval) == "ScopedValue{$Int}(1)"
@test sprint(show, ScopedValues.current_scope()) == "nothing"
with(sval => 2.0) do
Expand Down

0 comments on commit 912257a

Please sign in to comment.