Skip to content

Commit ca5d92d

Browse files
authored
Update set_global_shift_decrease! to shift weights rather than recomputing them (#112)
* Update `set_global_shift_decrease!` to shift weights rather than recomputing them, taking advantage of the fact that the shift is always down (note: this includes a bug) and add comments about invariants and a TODO
* Add test
* Inline the only remaining call-site of `recompute_range`

Note: this does include some regressions, though the gestalt is positive and 5B′ is an improvement.
1 parent abb46c5 commit ca5d92d

File tree

2 files changed

+46
-32
lines changed

2 files changed

+46
-32
lines changed

src/DynamicDiscreteSamplers.jl

+39-32
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ function _set_from_zero!(m::Memory, v::Float64, i::Int)
318318
# update group total weight and total weight
319319
significand = 0x8000000000000000 | uv << 11
320320
weight_index = _convert(Int, exponent + 4)
321-
significand_sum = update_significand_sum(m, weight_index, significand)
321+
significand_sum = update_significand_sum(m, weight_index, significand) # Temporarily break the "weights are accurately computed" invariant
322322

323323
if m[4] == 0 # if we were empty, set global shift (m[3]) so that m[4] will become ~2^40.
324324
m[3] = -24 - exponent
@@ -339,16 +339,18 @@ function _set_from_zero!(m::Memory, v::Float64, i::Int)
339339
# Base.top_set_bit(significand_sum)+signed(exponent) + signed(m[3]) == 48
340340
# signed(m[3]) == 48 - Base.top_set_bit(significand_sum) - signed(exponent)
341341
m3 = 48 - Base.top_set_bit(significand_sum) - exponent
342+
# The "weights are accurately computed" invariant is broken for weight_index, but the "sum(weights) == m[4]" invariant still holds
343+
# set_global_shift_decrease! will do something wrong to weight_index, but preserve the "sum(weights) == m[4]" invariant.
342344
set_global_shift_decrease!(m, m3) # TODO for perf: special case all call sites to this function to take advantage of known shift direction and/or magnitude; also try outlining
343345
shift = signed(exponent + m3)
344346
end
345347
weight = _convert(UInt64, significand_sum << shift) + 1
346348

347349
old_weight = m[weight_index]
348-
m[weight_index] = weight
349-
m4 = m[4]
350+
m[weight_index] = weight # The "weights are accurately computed" invariant is now restored
351+
m4 = m[4] # The "sum(weights) == m[4]" invariant is broken
350352
m4 -= old_weight
351-
m4, o = Base.add_with_overflow(m4, weight)
353+
m4, o = Base.add_with_overflow(m4, weight) # The "sum(weights) == m4" invariant now holds, though the computation overflows
352354
if o
353355
# If weights overflow (>2^64) then shift down by 16 bits
354356
m3 = m[3]-0x10
@@ -491,8 +493,30 @@ function set_global_shift_increase!(m::Memory, m2, m3::UInt64, m4) # Increase sh
491493
i <= -signed(m3)-122+4
492494
So for -signed(m3)-118 < i, we could need to adjust the ith weight
493495
=#
494-
recompute_range = max(5, -signed(m3)-117):m2 # TODO It would be possible to scale this range with length (m[1]) in which case testing could be stricter and performance could be (marginally) better, though not in large cases so possibly not worth doing at all)
495-
m[4] = recompute_weights!(m, m3, m4, recompute_range)
496+
r0 = max(5, -signed(m3)-117)
497+
r1 = m2 # TODO It would be possible to scale this range with length (m[1]) in which case testing could be stricter and performance could be (marginally) better, though not in large cases so possibly not worth doing at all)
498+
499+
# shift = signed(i-4+m3)
500+
# weight = significand_sum == 0 ? 0 : UInt64(significand_sum << shift) + 1
501+
# shift < -64; the low 64 bits are shifted off.
502+
# i < -60-signed(m3); the low 64 bits are shifted off.
503+
504+
checkbounds(m, r0:2r1+2042)
505+
@inbounds for i in r0:min(r1, -61-signed(m3))
506+
significand_sum_lo = m[_convert(Int, 2i+2041)]
507+
significand_sum_hi = m[_convert(Int, 2i+2042)]
508+
significand_sum_lo == significand_sum_hi == 0 && continue # in this case, the weight was and still is zero
509+
shift = signed(i-4+m3) + 64
510+
m4 += update_weight!(m, i, significand_sum_hi << shift)
511+
end
512+
@inbounds for i in max(r0,-60-signed(m3)):r1
513+
significand_sum = get_significand_sum(m, i)
514+
significand_sum == 0 && continue # in this case, the weight was and still is zero
515+
shift = signed(i-4+m3)
516+
m4 += update_weight!(m, i, significand_sum << shift)
517+
end
518+
519+
m[4] = m4
496520
end
497521

498522
function set_global_shift_decrease!(m::Memory, m3::UInt64, m4=m[4]) # Decrease shift, on insertion of elements
@@ -503,7 +527,7 @@ function set_global_shift_decrease!(m::Memory, m3::UInt64, m4=m[4]) # Decrease s
503527
# In the case of adding a giant element, call this first, then add the element.
504528
# In any case, this only adjusts elements at or before m[2]
505529
# from the first index that previously could have had a weight > 1 to min(m[2], the first index that can't have a weight > 1) (never empty), set weights to 1 or 0
506-
# from the first index that could have a weight > 1 to m[2] (possibly empty), recompute weights.
530+
# from the first index that could have a weight > 1 to m[2] (possibly empty), shift weights by delta.
507531
m2 = signed(m[2])
508532
i1 = -signed(m3)-117 # see above, this is the first index that could have weight > 1 (anything after this will have weight 1 or 0)
509533
i1_old = -signed(m3_old)-117 # anything before this is already weight 1 or 0
@@ -520,35 +544,18 @@ function set_global_shift_decrease!(m::Memory, m3::UInt64, m4=m[4]) # Decrease s
520544
m[i] = weight
521545
m4 += weight-old_weight
522546
end
523-
m4 = recompute_weights!(m, m3, m4, recompute_range)
547+
548+
delta = m3_old-m3
549+
checkbounds(m, recompute_range)
550+
@inbounds for i in recompute_range
551+
old_weight = m[i]
552+
old_weight <= 1 && continue # in this case, the weight was and still is 0 or 1
553+
m4 += update_weight!(m, i, (old_weight-1) >> delta)
554+
end
524555

525556
m[4] = m4
526557
end
527558

528-
function recompute_weights!(m::Memory{UInt64}, m3::UInt64, m4::UInt64, range::UnitRange{Int64})
529-
isempty(range) && return m4
530-
r0,r1 = extrema(range)
531-
# shift = signed(i-4+m3)
532-
# weight = significand_sum == 0 ? 0 : UInt64(significand_sum << shift) + 1
533-
# shift < -64; the low 64 bits are shifted off.
534-
# i < -60-signed(m3); the low 64 bits are shifted off.
535-
536-
checkbounds(m, r0:2r1+2042)
537-
@inbounds for i in r0:min(r1, -61-signed(m3))
538-
significand_sum_lo = m[_convert(Int, 2i+2041)]
539-
significand_sum_hi = m[_convert(Int, 2i+2042)]
540-
significand_sum_lo == significand_sum_hi == 0 && continue # in this case, the weight was and still is zero
541-
shift = signed(i-4+m3) + 64
542-
m4 += update_weight!(m, i, significand_sum_hi << shift)
543-
end
544-
@inbounds for i in max(r0,-60-signed(m3)):r1
545-
significand_sum = get_significand_sum(m, i)
546-
significand_sum == 0 && continue # in this case, the weight was and still is zero
547-
shift = signed(i-4+m3)
548-
m4 += update_weight!(m, i, significand_sum << shift)
549-
end
550-
m4
551-
end
552559
Base.@propagate_inbounds function update_weight!(m::Memory{UInt64}, i, shifted_significand_sum)
553560
weight = _convert(UInt64, shifted_significand_sum) + 1
554561
old_weight = m[i]

test/weights.jl

+7
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,13 @@ w[1] = 0.95
240240
w[2] = 6.41e14
241241
verify(w.m)
242242

243+
# This test catches a bug that was not revealed by the RNG tests below
244+
w = DynamicDiscreteSamplers.FixedSizeWeights(3);
245+
w[1] = 1.5
246+
w[2] = prevfloat(1.5)
247+
w[3] = 2^25
248+
verify(w.m)
249+
243250
# This test catches a bug that was not revealed by the RNG tests below.
244251
# The final line is calibrated to have about a 50% fail rate on that bug
245252
# and run in about 3 seconds:

0 commit comments

Comments
 (0)