Use 4 PCG output function variants instead of iterating 4x per fork
Using four PCG steps per fork iterates through the PCG space 4x faster,
getting back to the start after 2^62 splits instead of 2^64 (eh, who
cares?). It's also a bit slow and hard to parallelize. It also reduces
the space of possible weights for each register of SplitMix dot product
from 2^64 to 2^62, which weakens collision resistance.
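
The period argument can be checked on a scaled-down model. The sketch
below is only an illustration: it uses the toy 8-bit LCG from the test
code further down (multiplier 0x8d, increment 1), and the helper name
`splits_until_repeat` is hypothetical, not part of this commit.

```c
#include <stdint.h>

// One step of a toy 8-bit LCG (assumption: same constants as the
// scaled-down Julia test code in this commit message).
static uint8_t lcg_step(uint8_t x)
{
    return (uint8_t)(x * 0x8d + 1);
}

// Number of splits before the LCG state returns to its starting value
// when every split advances the LCG `steps` times.
static unsigned splits_until_repeat(unsigned steps)
{
    uint8_t x = 0;
    unsigned n = 0;
    do {
        for (unsigned i = 0; i < steps; i++)
            x = lcg_step(x);
        n++;
    } while (x != 0);
    return n;
}
```

With one step per split the 8-bit state comes back after 2^8 splits;
with four steps per split it comes back after 2^8 / 4 = 2^6 splits,
which is the same factor-of-four loss described above.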

This commit instead implements an approach where we use four different
PCG output functions on the same LCG state to get four sufficiently
linearly unrelated streams of pseudorandom SplitMix dot product weights
(one for each xoshiro256 state register). This should give us 256 bits
of SplitMix collision resistance, which is formidable and means that
collisions are impossible in practice.

Appealingly, this approach is also easy to vectorize. I even implemented
it with SIMD intrinsics just to be sure (code compiled but not tested):
```
#include <immintrin.h> // AVX2, plus AVX-512 VL/DQ for the 64-bit xor/mullo
#include <stdint.h>

// JL_RNG_SIZE is defined in src/julia.h (5 as of this commit)
void jl_rng_split(uint64_t dst[JL_RNG_SIZE], uint64_t src[JL_RNG_SIZE])
{
    // load and advance PCG's LCG state
    uint64_t x = src[4];
    // high spectrum multiplier from https://arxiv.org/abs/2001.05304
    src[4] = dst[4] = x * 0xd1342543de82ef95 + 1;

    // manually vectorized PCG-RXS-M-XS with four variants
    static const uint64_t a[4] = {
        0xe5f8fa077b92a8a8, // random additive offsets...
        0x7a0cd918958c124d,
        0x86222f7d388588d4,
        0xd30cbd35f2b64f52
    };
    static const uint64_t m[4] = {
        0xaef17502108ef2d9, // standard multiplier
        0xf34026eeb86766af, // random odd multipliers...
        0x38fd70ad58dd9fbb,
        0x6677f9b93ab0c04d
    };

    // a, m, src and dst are not guaranteed to be 32-byte aligned,
    // so use unaligned loads and stores
    __m256i p, s;
    p = _mm256_set1_epi64x(x);                                       // p = x
    p = _mm256_add_epi64(p, _mm256_loadu_si256((const __m256i *)a)); // p += a
    s = _mm256_srlv_epi64(p, _mm256_set1_epi64x(59));                // s = p >> 59
    s = _mm256_add_epi64(s, _mm256_set1_epi64x(5));                  // s += 5
    p = _mm256_xor_epi64(p, _mm256_srlv_epi64(p, s));                // p ^= p >> s
    p = _mm256_mullo_epi64(p, _mm256_loadu_si256((const __m256i *)m)); // p *= m
    s = _mm256_set1_epi64x(43);                                      // s = 43
    p = _mm256_xor_epi64(p, _mm256_srlv_epi64(p, s));                // p ^= p >> s

    // load, modify & store xoshiro256 state
    __m256i sv = _mm256_loadu_si256((const __m256i *)src);
    __m256i dv = _mm256_add_epi64(sv, p); // SplitMix dot product
    _mm256_storeu_si256((__m256i *)dst, dv);
}
```
I didn't end up using this because it only works on hardware with the
necessary AVX instructions (the 64-bit xor and multiply intrinsics need
AVX-512 extensions on top of AVX2), so it's not portable, but I wanted
to be sure it could be done. The committed version just uses a loop.
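
For comparison, here is a self-contained scalar sketch of the same
split. It mirrors the loop in the src/task.c diff of this commit; the
name `rng_split_sketch` is hypothetical and the local `JL_RNG_SIZE`
definition just copies the value from src/julia.h so the snippet
compiles on its own.

```c
#include <stdint.h>

#define JL_RNG_SIZE 5 // xoshiro 4 + splitmix 1 (copied from src/julia.h)

// Hypothetical stand-alone version of the split: dst receives the
// child's state, src is the parent's state (only src[4] is modified).
void rng_split_sketch(uint64_t dst[JL_RNG_SIZE], uint64_t src[JL_RNG_SIZE])
{
    static const uint64_t a[4] = { // additive offsets from the commit
        0xe5f8fa077b92a8a8ULL, 0x7a0cd918958c124dULL,
        0x86222f7d388588d4ULL, 0xd30cbd35f2b64f52ULL
    };
    static const uint64_t m[4] = { // odd multipliers from the commit
        0xaef17502108ef2d9ULL, 0xf34026eeb86766afULL,
        0x38fd70ad58dd9fbbULL, 0x6677f9b93ab0c04dULL
    };

    // load the LCG state and advance it in both parent and child
    uint64_t x = src[4];
    src[4] = dst[4] = x * 0xd1342543de82ef95ULL + 1;

    // one PCG-RXS-M-XS output variant per xoshiro256 register
    for (int i = 0; i < 4; i++) {
        uint64_t p = x + a[i];
        p ^= p >> ((p >> 59) + 5);
        p *= m[i];
        p ^= p >> 43;
        dst[i] = src[i] + p; // SplitMix dot product update
    }
}
```

Calling it twice on the same parent yields two children whose four
xoshiro256 registers all differ, because the LCG state differs between
the calls and each output variant is a bijection on 64-bit values.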

One concern with this approach is that the 256 bits of SplitMix dot
product collision resistance could actually be a mirage. Why? Because
the random weights are generated from 64 bits of LCG state. How is that
an issue? In the proof of DotMix's collision avoidance, which SplitMix
inherits, the number of possible weight values is key: the collision
probability is 1/N where N is the number of possible weight vectors. If
we consider all four xoshiro256 registers as one big dot product and
apply the proof to it, we have a problem: despite 256 bits of register,
there are only 2^64 possible weight values we can generate, so the proof
only gives us a pairwise collision probability of 1/2^64.

Another way to look at this, however, is to consider the four xoshiro256
register dot products separately: each one has a 1/2^64 collision
probability and there are four of them; as long as the chance of each
one colliding is independent, the probability of all of them colliding
together is (1/2^64)^4 = 1/2^256. Clearly there are ways to generate the
four weights that don't satisfy independence. You could use the same
weights four times, for example. Or you could use weights that are just
scaled copies of each other. Basically any linear relationship between
the weights is problematic. That's yet another reason that iterating
PCG multiple times to generate weights may not be ideal: the LCG that
drives PCG is very linear; only the output function sabotages the
linearity. If the output function being non-linear is crucial, why not
use multiple different output functions instead?
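
To make the "scaled copies" failure mode concrete, here is a toy 8-bit
sketch (not code from this commit; all names and the weight values used
to exercise it are hypothetical). When the second register's weights
are a constant multiple of the first's, every collision in register 1
is automatically a collision in register 2, so the second register adds
no collision resistance at all.

```c
#include <stdint.h>

#define N_WEIGHTS 9 // 2^9 = 512 subsets > 256 values: collisions must occur

// Toy 8-bit dot product: add w[i] for every bit i set in mask (mod 256).
static uint8_t dot8(const uint8_t w[N_WEIGHTS], unsigned mask)
{
    uint8_t acc = 0;
    for (int i = 0; i < N_WEIGHTS; i++)
        if ((mask >> i) & 1u)
            acc = (uint8_t)(acc + w[i]);
    return acc;
}

// Count pairs of distinct subsets that collide in register 1 (*single)
// and pairs that collide in registers 1 and 2 simultaneously (*joint).
static void count_collisions(const uint8_t w1[N_WEIGHTS],
                             const uint8_t w2[N_WEIGHTS],
                             unsigned *single, unsigned *joint)
{
    *single = *joint = 0;
    for (unsigned a = 0; a < (1u << N_WEIGHTS); a++)
        for (unsigned b = a + 1; b < (1u << N_WEIGHTS); b++)
            if (dot8(w1, a) == dot8(w1, b)) {
                (*single)++;
                if (dot8(w2, a) == dot8(w2, b))
                    (*joint)++;
            }
}

// Linearly related weights: dst[i] = k * src[i] (mod 256).
static void scaled_copy(uint8_t dst[N_WEIGHTS],
                        const uint8_t src[N_WEIGHTS], uint8_t k)
{
    for (int i = 0; i < N_WEIGHTS; i++)
        dst[i] = (uint8_t)(src[i] * k);
}
```

With `w2` produced by `scaled_copy`, `joint` always equals `single`,
since the second dot product is just k times the first (mod 256).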

So that's what I'm doing here: using four different variations on the
PCG output function. First we perturb the LCG state by four different
random additive constants, which moves it to four distant and unrelated
places in the state space and gives the xor shifts different bits to
work with. We also use four different multiplicative constants in the
middle of the output function: the first is the standard PCG multiplier,
so we get known-good weights for one of the registers; the rest are
random odd multipliers. A potential improvement is to look for weights
with optimal cascading behaviors, but random constants tend to be good.
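
One property worth keeping when varying the constants is that each
output variant remains a bijection on the state, so every register
still sees each weight exactly once per LCG period. A brute-force check
is not feasible at 64 bits, but it is on the scaled-down 8-bit variants
from the test code below. This is a sketch under that assumption;
`is_permutation` is a hypothetical helper.

```c
#include <stdint.h>
#include <string.h>

// Constants of the four scaled-down variants (from the Julia test code).
static const uint8_t add8[4] = {0xa0, 0x98, 0x66, 0x8d};
static const uint8_t mul8[4] = {0xd9, 0x2b, 0x19, 0x9b};

// 8-bit PCG-RXS-M-XS output function, variant i in 0..3.
static uint8_t pcg_out8(uint8_t x, int i)
{
    uint8_t p = (uint8_t)(x + add8[i]);
    p ^= (uint8_t)(p >> ((p >> 6) + 2)); // random xorshift, shift in 2..5
    p = (uint8_t)(p * mul8[i]);          // odd multiplier
    p ^= (uint8_t)(p >> 6);              // fixed xorshift
    return p;
}

// Return 1 if variant i maps the 256 states to 256 distinct outputs.
static int is_permutation(int i)
{
    uint8_t seen[256];
    memset(seen, 0, sizeof seen);
    for (unsigned x = 0; x < 256; x++) {
        uint8_t y = pcg_out8((uint8_t)x, i);
        if (seen[y])
            return 0;
        seen[y] = 1;
    }
    return 1;
}
```

All four variants pass: the additive offset, the odd multiplier and the
fixed xorshift are each invertible, and the random xorshift is
invertible because its shift amount comes from the top bits, which the
shift leaves untouched.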

Assuming our four output variants are sufficiently independent, we
should get very strong collision resistance with a pairwise collision
probability of 1/2^256. It would, however, be reassuring to have
empirical evidence that this approach actually works. To that end, I
scaled everything down to four 8-bit SplitMix dot products, counted how
many simulated task spawns it takes before we get collisions, and
compared that to a single 8-bit dot product. Here's the test code:
```
function pcg_output_rxs_m_xs_8_8(x::UInt8)
    p = x
    p += 0xa0 # random but same as below
    p ⊻= p >> ((p >> 6) + 2)
    p *= 0xd9 # standard multiplier
    p ⊻= p >> 6
end

function pcg_output_rxs_m_xs_8_32(x::UInt8)
    ntuple(4) do i
        p = x
        p += (0xa0, 0x98, 0x66, 0x8d)[i]
        p ⊻= p >> ((p >> 6) + 2)
        p *= (0xd9, 0x2b, 0x19, 0x9b)[i]
        p ⊻= p >> 6
    end
end

Base.zero(::Type{NTuple{4, UInt8}}) = (0x0, 0x0, 0x0, 0x0)

function gen_collisions(
    ::Type{T},
    rec :: Int;
    cnt :: Dict{T,Int} = Dict{T,Int}(),
    lcg :: UInt8 = zero(UInt8),
    dot :: T = zero(T),
    out :: Function = T == UInt8 ?
        pcg_output_rxs_m_xs_8_8 :
        pcg_output_rxs_m_xs_8_32
) where {T}
    if rec > 0
        h = out(lcg)
        lcg = lcg * 0x8d + 0x01
        # positional arguments must match the signature: type first, then depth
        gen_collisions(T, rec - 1; out, cnt, lcg, dot = dot)
        gen_collisions(T, rec - 1; out, cnt, lcg, dot = dot .+ h)
    else
        cnt[dot] = get(cnt, dot, 0) + 1
    end
    return cnt
end
```
With this, `gen_collisions(UInt8, 5)`, which generates 2^5 = 32 dot
products, already has collisions, whereas with four registers we have to
generate 2^20 = 1048576 dot products, `gen_collisions(NTuple{4,UInt8}, 20)`,
before we get collisions. This provides empirical evidence that the
weights generated this way are sufficiently independent and that we can
really expect 256 bits of SplitMix collision resistance.
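
For readers who prefer C, the same experiment can be sketched as
follows. This is an assumed port of the Julia test above, not code from
the commit: it enumerates the 2^rec dot products with a bit mask
instead of recursion and counts distinct values by sorting. The name
`distinct_dots` is hypothetical.

```c
#include <stdint.h>
#include <stdlib.h>

static const uint8_t add8[4] = {0xa0, 0x98, 0x66, 0x8d};
static const uint8_t mul8[4] = {0xd9, 0x2b, 0x19, 0x9b};

// 8-bit PCG-RXS-M-XS output, variant i (constants from the Julia test).
static uint8_t pcg_out8(uint8_t x, int i)
{
    uint8_t p = (uint8_t)(x + add8[i]);
    p ^= (uint8_t)(p >> ((p >> 6) + 2));
    p = (uint8_t)(p * mul8[i]);
    p ^= (uint8_t)(p >> 6);
    return p;
}

static int cmp_u32(const void *a, const void *b)
{
    uint32_t x = *(const uint32_t *)a, y = *(const uint32_t *)b;
    return (x > y) - (x < y);
}

// Enumerate the 2^rec dot products reachable after rec simulated spawns
// using regs (1 or 4) 8-bit registers; return how many are distinct.
// Returns 0 on invalid arguments or allocation failure.
static size_t distinct_dots(int rec, int regs)
{
    if (rec < 1 || rec > 24 || (regs != 1 && regs != 4))
        return 0;
    size_t n = (size_t)1 << rec;
    uint32_t *dots = malloc(n * sizeof *dots);
    if (dots == NULL)
        return 0;

    // one weight per register per spawn level, driven by the toy LCG
    uint8_t w[24][4] = {{0}};
    uint8_t lcg = 0;
    for (int d = 0; d < rec; d++) {
        for (int r = 0; r < regs; r++)
            w[d][r] = pcg_out8(lcg, r);
        lcg = (uint8_t)(lcg * 0x8d + 1);
    }

    // bit d of mask says whether the level-d weight was added
    for (size_t mask = 0; mask < n; mask++) {
        uint8_t acc[4] = {0, 0, 0, 0};
        for (int d = 0; d < rec; d++)
            if ((mask >> d) & 1u)
                for (int r = 0; r < regs; r++)
                    acc[r] = (uint8_t)(acc[r] + w[d][r]);
        dots[mask] = (uint32_t)acc[0] | ((uint32_t)acc[1] << 8) |
                     ((uint32_t)acc[2] << 16) | ((uint32_t)acc[3] << 24);
    }

    qsort(dots, n, sizeof *dots, cmp_u32);
    size_t distinct = 1;
    for (size_t i = 1; i < n; i++)
        if (dots[i] != dots[i - 1])
            distinct++;
    free(dots);
    return distinct;
}
```

A run has collisions whenever `distinct_dots(rec, regs)` is smaller
than 2^rec. Because the single-register case uses the same constants as
the first of the four variants, the four-register count can never be
lower than the single-register count for the same depth.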
StefanKarpinski committed Mar 25, 2023
1 parent 5c6bbdc commit ea965a1
Showing 6 changed files with 36 additions and 41 deletions.
1 change: 0 additions & 1 deletion base/sysimg.jl
@@ -28,7 +28,6 @@ let
task.rngState2 = 0x503e1d32781c2608
task.rngState3 = 0x3a77f7189200c20b
task.rngState4 = 0x5502376d099035ae
-task.rngState5 = 0x01dd7c407e7dcb1b

# Stdlibs sorted in dependency, then alphabetical, order by contrib/print_sorted_stdlibs.jl
# Run with the `--exclude-jlls` option to filter out all JLL packages
6 changes: 3 additions & 3 deletions src/gc.c
@@ -382,9 +382,9 @@ static void jl_gc_run_finalizers_in_list(jl_task_t *ct, arraylist_t *list) JL_NO
ct->sticky = sticky;
}

-static uint64_t finalizer_rngState[6];
+static uint64_t finalizer_rngState[JL_RNG_SIZE];

-void jl_rng_split(uint64_t dst[6], uint64_t src[6]) JL_NOTSAFEPOINT;
+void jl_rng_split(uint64_t dst[JL_RNG_SIZE], uint64_t src[JL_RNG_SIZE]) JL_NOTSAFEPOINT;

JL_DLLEXPORT void jl_gc_init_finalizer_rng_state(void)
{
@@ -413,7 +413,7 @@ static void run_finalizers(jl_task_t *ct)
jl_atomic_store_relaxed(&jl_gc_have_pending_finalizers, 0);
arraylist_new(&to_finalize, 0);

-uint64_t save_rngState[6];
+uint64_t save_rngState[JL_RNG_SIZE];
memcpy(&save_rngState[0], &ct->rngState[0], sizeof(save_rngState));
jl_rng_split(ct->rngState, finalizer_rngState);

6 changes: 2 additions & 4 deletions src/jltypes.c
@@ -2768,7 +2768,7 @@ void jl_init_types(void) JL_GC_DISABLED
NULL,
jl_any_type,
jl_emptysvec,
-jl_perm_symsvec(17,
+jl_perm_symsvec(16,
"next",
"queue",
"storage",
@@ -2781,12 +2781,11 @@ void jl_init_types(void) JL_GC_DISABLED
"rngState2",
"rngState3",
"rngState4",
-"rngState5",
"_state",
"sticky",
"_isexception",
"priority"),
-jl_svec(17,
+jl_svec(16,
jl_any_type,
jl_any_type,
jl_any_type,
@@ -2799,7 +2798,6 @@ void jl_init_types(void) JL_GC_DISABLED
jl_uint64_type,
jl_uint64_type,
jl_uint64_type,
-jl_uint64_type,
jl_uint8_type,
jl_bool_type,
jl_bool_type,
4 changes: 3 additions & 1 deletion src/julia.h
@@ -1910,6 +1910,8 @@ typedef struct _jl_handler_t {
size_t world_age;
} jl_handler_t;

+#define JL_RNG_SIZE 5 // xoshiro 4 + splitmix 1
+
typedef struct _jl_task_t {
JL_DATA_TYPE
jl_value_t *next; // invasive linked list for scheduler
@@ -1921,7 +1923,7 @@ typedef struct _jl_task_t {
jl_function_t *start;
// 4 byte padding on 32-bit systems
// uint32_t padding0;
-uint64_t rngState[6]; // xoshiro 4 + splitmix 2
+uint64_t rngState[JL_RNG_SIZE];
_Atomic(uint8_t) _state;
uint8_t sticky; // record whether this Task can be migrated to a new thread
_Atomic(uint8_t) _isexception; // set if `result` is an exception to throw or that we exited with
56 changes: 27 additions & 29 deletions src/task.c
@@ -866,36 +866,34 @@ uint64_t jl_genrandom(uint64_t rngState[4]) JL_NOTSAFEPOINT
return res;
}

-// pcg_out = pcg_output_rxs_m_xs_64_64 from
-// https://github.com/imneme/pcg-c/blob/83252d9c23df9c82ecb42210afed61a7b42402d7/include/pcg_variants.h#L188-L193
-//
-// This is the best statistical output function of the PCG family; it produces
-// statistically good output even in the case when the state and output are the
-// same size, in this case both being 64 bits.
-//
-inline uint64_t pcg_out(uint64_t x)
+void jl_rng_split(uint64_t dst[JL_RNG_SIZE], uint64_t src[JL_RNG_SIZE])
{
-int s = x >> 59;
-x ^= x >> (s + 5);
-x *= 0xaef17502108ef2d9;
-return x ^ (x >> 43);
-}
-
-const uint64_t LCG_MUL = 0xd1342543de82ef95; // https://arxiv.org/abs/2001.05304
-
-void jl_rng_split(uint64_t dst[6], uint64_t src[6]) JL_NOTSAFEPOINT
-{
-uint64_t lcg = src[4]; // load internal PCG's LCG state
-uint64_t dot = src[5] + pcg_out(lcg); // update splitmix dot product
-dst[4] = src[4] = lcg * LCG_MUL + 1; // LCG advances in both child and parent
-dst[5] = dot; // dot product modified in child only
-// use dot as a PCG state to seed the xoshiro256 registers:
-dst[0] = pcg_out(dot = dot * LCG_MUL + 1);
-dst[1] = pcg_out(dot = dot * LCG_MUL + 1);
-dst[2] = pcg_out(dot = dot * LCG_MUL + 1);
-dst[3] = pcg_out(dot = dot * LCG_MUL + 1);
-// since the PCG state and output are the same size, the outputs must all be
-// distinct, which guarantees that the xoshiro256 state cannot be all zeros
+// load and advance PCG's LCG state
+uint64_t x = src[4];
+src[4] = dst[4] = x * 0xd1342543de82ef95 + 1;
+// high spectrum multiplier from https://arxiv.org/abs/2001.05304
+
+static const uint64_t a[4] = {
+0xe5f8fa077b92a8a8, // random additive offsets...
+0x7a0cd918958c124d,
+0x86222f7d388588d4,
+0xd30cbd35f2b64f52
+};
+static const uint64_t m[4] = {
+0xaef17502108ef2d9, // standard multiplier
+0xf34026eeb86766af, // random odd multipliers...
+0x38fd70ad58dd9fbb,
+0x6677f9b93ab0c04d
+};
+
+// PCG-RXS-M-XS output with four variants
+for (int i = 0; i < 4; i++) {
+uint64_t p = x + a[i];
+p ^= p >> ((p >> 59) + 5);
+p *= m[i];
+p ^= p >> 43;
+dst[i] = src[i] + p; // SplitMix dot product
+}
}

JL_DLLEXPORT jl_task_t *jl_new_task(jl_function_t *start, jl_value_t *completion_future, size_t ssize)
4 changes: 1 addition & 3 deletions stdlib/Random/src/Xoshiro.jl
@@ -116,16 +116,14 @@ rng_native_52(::TaskLocalRNG) = UInt64
function setstate!(
x::TaskLocalRNG,
s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64, # xoshiro256 state
-s4::UInt64 = hash((s0, s1)), # splitmix weight rng state
-s5::UInt64 = hash((s2, s3)), # splitmix dot product
+s4::UInt64 = hash((s0, s1, s2, s3)), # splitmix weight rng state
)
t = current_task()
t.rngState0 = s0
t.rngState1 = s1
t.rngState2 = s2
t.rngState3 = s3
t.rngState4 = s4
-t.rngState5 = s5
x
end
