Skip to content

Commit

Permalink
quite faster rand(::MersenneTwister, ::Type{Int}) ... (JuliaLang#37914)
Browse files Browse the repository at this point in the history
... when the quest for superficial beauty leads to performance ...

The recent implementation of `show` for `MersenneTwister` was not
ideal, as a number of book-keeping variables had to be introduced as
fields of MT; in particular, as generating the cache for ints (using
the generic `rand!` for integer arrays) was consuming random `Float64`
numbers, 4 integers had to be shown in `show` only to reproduce the
state of the ints cache. In total 8 integers were shown.

But it is not so difficult to improve a bit, thanks to two features of
the internal cache:

1) it's 16-byte aligned, so the dSFMT low-level routine can be called
   directly on it (whereas the generic `rand!` for integers has to
   take care that the same stream is produced whatever the alignment)

2) it can be resized: dSFMT randomizing only 52 out of 64 bits,
   i.e. a bit more than 80% of the bits, the trick is to count
   the total number of needed bits, grow the array to a size
   such that dSFMT produces these needed bits, and then condense
   these bits back into a 100% randomized array of the original size

As a net result, two variables could be deleted which `show` doesn't
needs to display anymore.

Purely as a side effect, scalar generation of `Int64`/`UInt64` has
a speedup of about 1.8x, and about 1.6x for `Int128`/`UInt128`,
and about 1.3x or 1.5x for generation in a range of integers of size
with less than 64 bits, e.g. `rand(1:9)` (at least on this machine...)
  • Loading branch information
rfourquet authored Oct 17, 2020
1 parent 8078eac commit b8bc816
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 92 deletions.
18 changes: 9 additions & 9 deletions stdlib/Random/docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,22 +145,22 @@ Scalar and array methods for `Die` now work as expected:

```jldoctest Die; setup = :(Random.seed!(1))
julia> rand(Die)
Die(10)
Die(15)
julia> rand(MersenneTwister(0), Die)
Die(16)
Die(11)
julia> rand(Die, 3)
3-element Vector{Die}:
Die(18)
Die(5)
Die(20)
Die(9)
Die(4)
julia> a = Vector{Die}(undef, 3); rand!(a)
3-element Vector{Die}:
Die(11)
Die(5)
Die(20)
Die(10)
Die(15)
```

#### A simple sampler without pre-computed data
Expand All @@ -173,13 +173,13 @@ In order to define random generation out of objects of type `S`, the following m
julia> Random.rand(rng::AbstractRNG, d::Random.SamplerTrivial{Die}) = rand(rng, 1:d[].nsides);
julia> rand(Die(4))
2
3
julia> rand(Die(4), 3)
3-element Vector{Any}:
1
4
2
1
1
```

Given a collection type `S`, it's currently assumed that if `rand(::S)` is defined, an object of type `eltype(S)` will be produced. In the last example, a `Vector{Any}` is produced; the reason is that `eltype(Die) == Any`. The remedy is to define `Base.eltype(::Type{Die}) = Int`.
Expand Down
126 changes: 69 additions & 57 deletions stdlib/Random/src/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,27 +90,25 @@ mutable struct MersenneTwister <: AbstractRNG
adv_jump::BigInt # number of skipped Float64 values via randjump
adv_vals::Int64 # state of advance when vals is filled-up
adv_ints::Int64 # state of advance when ints is filled-up
adv_vals_pre::Int64 # state of advance when vals is filled-up before ints
adv_idxF_pre::Int # value of idxF before ints is filled-up

function MersenneTwister(seed, state, vals, ints, idxF, idxI,
adv, adv_jump, adv_vals, adv_ints, adv_vals_pre, adv_idxF_pre)
adv, adv_jump, adv_vals, adv_ints)
length(vals) == MT_CACHE_F && 0 <= idxF <= MT_CACHE_F ||
throw(DomainError((length(vals), idxF),
"`length(vals)` and `idxF` must be consistent with $MT_CACHE_F"))
length(ints) == MT_CACHE_I >> 4 && 0 <= idxI <= MT_CACHE_I ||
throw(DomainError((length(ints), idxI),
"`length(ints)` and `idxI` must be consistent with $MT_CACHE_I"))
new(seed, state, vals, ints, idxF, idxI,
adv, adv_jump, adv_vals, adv_ints, adv_vals_pre, adv_idxF_pre)
adv, adv_jump, adv_vals, adv_ints)
end
end

MersenneTwister(seed::Vector{UInt32}, state::DSFMT_state) =
MersenneTwister(seed, state,
Vector{Float64}(undef, MT_CACHE_F),
Vector{UInt128}(undef, MT_CACHE_I >> 4),
MT_CACHE_F, 0, 0, 0, -1, -1, -1, -1)
MT_CACHE_F, 0, 0, 0, -1, -1)

"""
MersenneTwister(seed)
Expand Down Expand Up @@ -161,15 +159,12 @@ function copy!(dst::MersenneTwister, src::MersenneTwister)
dst.adv_jump = src.adv_jump
dst.adv_vals = src.adv_vals
dst.adv_ints = src.adv_ints
dst.adv_vals_pre = src.adv_vals_pre
dst.adv_idxF_pre = src.adv_idxF_pre
dst
end

copy(src::MersenneTwister) =
MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), copy(src.ints),
src.idxF, src.idxI, src.adv, src.adv_jump, src.adv_vals, src.adv_ints,
src.adv_vals_pre, src.adv_idxF_pre)
src.idxF, src.idxI, src.adv, src.adv_jump, src.adv_vals, src.adv_ints)


==(r1::MersenneTwister, r2::MersenneTwister) =
Expand All @@ -191,17 +186,18 @@ function show(io::IO, rng::MersenneTwister)
print(io, "MersenneTwister($seed_str, (")
# state
adv = Integer[rng.adv_jump, rng.adv]
if rng.adv_vals != -1
push!(adv, rng.adv_vals, rng.idxF)
if rng.adv_vals != -1 || rng.adv_ints != -1
if rng.adv_vals == -1
@assert rng.idxF == MT_CACHE_F
push!(adv, 0, 0) # "(0, 0)" is nicer on the eyes than (-1, 1002)
else
push!(adv, rng.adv_vals, rng.idxF)
end
end
if rng.adv_ints != -1 # then rng.adv_vals is always != -1
if rng.adv_ints != -1
idxI = (length(rng.ints)*16 - rng.idxI) / 8 # 8 represents one Int64
idxI = Int(idxI) # idxI should always be an integer when using public APIs
push!(adv,
rng.adv_ints,
rng.adv_vals_pre == -1 ? 0 : rng.adv_vals_pre,
rng.adv_vals_pre == -1 ? 0 : rng.adv_idxF_pre,
idxI)
push!(adv, rng.adv_ints, idxI)
end
join(io, adv, ", ")
print(io, "))")
Expand All @@ -217,8 +213,6 @@ function reset_caches!(r::MersenneTwister)
mt_setempty!(r, UInt128)
r.adv_vals = -1
r.adv_ints = -1
r.adv_vals_pre = -1
r.adv_idxF_pre = -1
r
end

Expand Down Expand Up @@ -261,9 +255,33 @@ mt_avail(r::MersenneTwister, ::Type{T}) where {T<:BitInteger} =

function mt_setfull!(r::MersenneTwister, ::Type{<:BitInteger})
r.adv_ints = r.adv
r.adv_vals_pre = r.adv_vals
r.adv_idxF_pre = r.idxF
rand!(r, r.ints)
ints = r.ints

@assert length(ints) == 501
# dSFMT natively randomizes 52 out of 64 bits of each UInt64 words,
# i.e. 12 bits are missing;
# by generating 5 words == 5*52 == 260 bits, we can fully
# randomize 4 UInt64 = 256 bits; IOW, at the array level, we must
# randomize ceil(501*1.25) = 627 UInt128 words (with 2*52 bits each),
# which we then condense into fully randomized 501 UInt128 words

len = 501 + 126 # 126 == ceil(501 / 4)
resize!(ints, len)
p = pointer(ints) # must be *after* resize!
GC.@preserve r fill_array!(r, Ptr{Float64}(p), len*2, CloseOpen12_64())

k = 501
n = 0
@inbounds while n != 500
u = ints[k+=1]
ints[n+=1] ⊻= u << 48
ints[n+=1] ⊻= u << 36
ints[n+=1] ⊻= u << 24
ints[n+=1] ⊻= u << 12
end
@assert k == len - 1
@inbounds ints[501] ⊻= ints[len] << 48
resize!(ints, 501)
r.idxI = MT_CACHE_I
end

Expand Down Expand Up @@ -740,17 +758,17 @@ jump!(r::MersenneTwister, steps::Integer) = copy!(r, jump(r, steps))
# parameters in the tuples are:
# 1: .adv_jump (jump steps)
# 2: .adv (number of generated floats at the DSFMT_state level since seeding, besides jumps)
# 3, 4: .adv_vals, .idxF (counters to reconstruct the float chache, optional if 5-8 not shown))
# 5-8: .adv_ints, .adv_vals_pre, .adv_idxF_pre, .idxI (counters to reconstruct the integer chache, optional)
# 3, 4: .adv_vals, .idxF (counters to reconstruct the float chache, optional if 5-6 not shown))
# 5, 6: .adv_ints, .idxI (counters to reconstruct the integer chache, optional)

Random.MersenneTwister(seed::Union{Integer,Vector{UInt32}}, advance::NTuple{8,Integer}) =
Random.MersenneTwister(seed::Union{Integer,Vector{UInt32}}, advance::NTuple{6,Integer}) =
advance!(MersenneTwister(seed), advance...)

Random.MersenneTwister(seed::Union{Integer,Vector{UInt32}}, advance::NTuple{4,Integer}) =
MersenneTwister(seed, (advance..., -1, -1, -1, -1))
MersenneTwister(seed, (advance..., 0, 0))

Random.MersenneTwister(seed::Union{Integer,Vector{UInt32}}, advance::NTuple{2,Integer}) =
MersenneTwister(seed, (advance..., 0, 0, -1, -1, -1, -1))
MersenneTwister(seed, (advance..., 0, 0, 0, 0))

# advances raw state (per fill_array!) of r by n steps (Float64 values)
function _advance_n!(r::MersenneTwister, n::Int64, work::Vector{Float64})
Expand All @@ -775,24 +793,10 @@ function _advance_to!(r::MersenneTwister, adv::Int64, work)
end

function _advance_F!(r::MersenneTwister, adv_vals, idxF, work)
if adv_vals == idxF == 0
# this case happens only when integer cache was generated before float cache
# then (0, 0) is printed instead of (-1, MT_CACHE_F) which is somewhat confusing;
# in this case, nothing to do, the float cache mustn't be filled
if r.adv_vals == -1 && r.idxF == MT_CACHE_F
return
else
throw(DomainError(n, "can't advance $r to the specified state"))
end
end
if r.adv_vals != adv_vals
_advance_to!(r, adv_vals, work)
gen_rand(r)
@assert r.adv_vals == adv_vals
end # otherwise, advancing was done automatically while generating the integer cache

_advance_to!(r, adv_vals, work)
gen_rand(r)
@assert r.adv_vals == adv_vals
r.idxF = idxF
nothing
end

function _advance_I!(r::MersenneTwister, adv_ints, idxI, work)
Expand All @@ -802,26 +806,34 @@ function _advance_I!(r::MersenneTwister, adv_ints, idxI, work)
r.idxI = 16*length(r.ints) - 8*idxI
end

function advance!(r::MersenneTwister, adv_jump, adv, adv_vals, idxF,
adv_ints, adv_vals_pre, adv_idxF_pre, idxI)
function advance!(r::MersenneTwister, adv_jump, adv, adv_vals, idxF, adv_ints, idxI)
adv_jump = BigInt(adv_jump)
adv, adv_vals, adv_ints, adv_vals_pre = Int64.((adv, adv_vals, adv_ints, adv_vals_pre))
idxF, adv_idxF_pre, idxI = Int.((idxF, adv_idxF_pre, idxI))
adv, adv_vals, adv_ints = Int64.((adv, adv_vals, adv_ints))
idxF, idxI = Int.((idxF, idxI))

ms = dsfmt_get_min_array_size() % Int
work = sizehint!(Vector{Float64}(), 2ms)
jump!(r, adv_jump)
if adv_vals_pre != -1
_advance_F!(r, adv_vals_pre, adv_idxF_pre, work)
_advance_I!(r, adv_ints, idxI, work)

@assert r.adv_vals_pre == adv_vals_pre ||
r.adv_vals_pre == -1 && adv_vals_pre == 0
@assert r.adv_idxF_pre == adv_idxF_pre ||
r.adv_idxF_pre == 1002 && adv_idxF_pre == 0
adv_jump != 0 && jump!(r, adv_jump)
advF = (adv_vals, idxF) != (0, 0)
advI = (adv_ints, idxI) != (0, 0)

if advI && advF
@assert adv_vals != adv_ints
if adv_vals < adv_ints
_advance_F!(r, adv_vals, idxF, work)
_advance_I!(r, adv_ints, idxI, work)
else
_advance_I!(r, adv_ints, idxI, work)
_advance_F!(r, adv_vals, idxF, work)
end
elseif advF
_advance_F!(r, adv_vals, idxF, work)
elseif advI
_advance_I!(r, adv_ints, idxI, work)
else
@assert adv == 0
end
_advance_F!(r, adv_vals, idxF, work)
_advance_to!(r, adv, work)
r
end
14 changes: 7 additions & 7 deletions stdlib/Random/src/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ julia> rng = MersenneTwister(1234);
julia> bitrand(rng, 10)
10-element BitVector:
0
1
1
1
1
0
0
0
1
0
0
0
1
1
```
"""
Expand All @@ -53,13 +53,13 @@ number generator, see [Random Numbers](@ref).
# Examples
```jldoctest
julia> Random.seed!(3); randstring()
"4zSHdXlw"
"Y7m62wOj"
julia> randstring(MersenneTwister(3), 'a':'z', 6)
"bzlhqn"
"ocucay"
julia> randstring("ACGT")
"AGGACATT"
"ATTTGCGT"
```
!!! note
Expand Down
36 changes: 20 additions & 16 deletions stdlib/Random/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ let A = zeros(2, 2)
end
let A = zeros(2, 2)
@test_throws ArgumentError rand!(MersenneTwister(0), A, 5)
@test rand(MersenneTwister(0), Int64, 1) == [2118291759721269919]
@test rand(MersenneTwister(0), Int64, 1) == [-3433174948434291912]
end
let A = zeros(Int64, 2, 2)
rand!(MersenneTwister(0), A)
Expand Down Expand Up @@ -253,16 +253,20 @@ let mt = MersenneTwister(0)
end

Random.seed!(mt, 0)
Aend = Any[]
Bend = Any[]
for (i,T) in enumerate([Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Int128, Float16, Float32])
A = Vector{T}(undef, 16)
B = Vector{T}(undef, 31)
rand!(mt, A)
rand!(mt, B)
@test A[end] == Any[21, 0x7b, 17385, 0x3086, -1574090021, 0xadcb4460, 6797283068698303107, 0xc8e6453e139271f3,
69855512850528774484795047199183096941, Float16(0.16895), 0.21086597f0][i]
@test B[end] == Any[49, 0x65, -3725, 0x719d, 814246081, 0xdf61843a, 2120308604158549401, 0xcb28c236e9c0f608,
61881313582466480231846019869039259750, Float16(0.38672), 0.20027375f0][i]
push!(Aend, A[end])
push!(Bend, B[end])
end
@test Aend == Any[21, 0x7b, 17385, 0x3086, -1574090021, 0xadcb4460, 6797283068698303107, 0x68a9f9865393cfd6,
33687499368208574024854346399216845930, Float16(0.7744), 0.97259974f0]
@test Bend == Any[49, 0x65, -3725, 0x719d, 814246081, 0xdf61843a, -3433174948434291912, 0xd461716f27c91500,
-85900088726243933988214632401750448432, Float16(0.10645), 0.13879478f0]

Random.seed!(mt, 0)
AF64 = Vector{Float64}(undef, Random.dsfmt_get_min_array_size()-1)
Expand Down Expand Up @@ -592,16 +596,16 @@ end
@test_throws DomainError DSFMT.DSFMT_state(zeros(Int32, rand(0:DSFMT.JN32-1)))

@test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(),
zeros(Float64, 10), zeros(UInt128, MT_CACHE_I>>4), 0, 0, 0, 0, -1, -1, -1, -1)
zeros(Float64, 10), zeros(UInt128, MT_CACHE_I>>4), 0, 0, 0, 0, -1, -1)

@test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(),
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), -1, 0, 0, 0, -1, -1, -1, -1)
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), -1, 0, 0, 0, -1, -1)

@test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(),
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>3), 0, 0, 0, 0, -1, -1, -1, -1)
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>3), 0, 0, 0, 0, -1, -1)

@test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(),
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), 0, -1, 0, 0, -1, -1, -1, -1)
zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), 0, -1, 0, 0, -1, -1)

# seed is private to MersenneTwister
let seed = rand(UInt32, 10)
Expand Down Expand Up @@ -842,19 +846,19 @@ end
@test m == MersenneTwister(123, (200000000000000000000, 0))
rand(m)
@test string(m) == "MersenneTwister(123, (200000000000000000000, 1002, 0, 1))"

@test m == MersenneTwister(123, (200000000000000000000, 1002, 0, 1))
rand(m, Int64)
@test string(m) == "MersenneTwister(123, (200000000000000000000, 2002, 0, 255, 1002, 0, 1, 1))"
@test m == MersenneTwister(123, (200000000000000000000, 2002, 0, 255, 1002, 0, 1, 1))
@test string(m) == "MersenneTwister(123, (200000000000000000000, 2256, 0, 1, 1002, 1))"
@test m == MersenneTwister(123, (200000000000000000000, 2256, 0, 1, 1002, 1))

m = MersenneTwister(0x0ecfd77f89dcd508caa37a17ebb7556b)
@test string(m) == "MersenneTwister(0xecfd77f89dcd508caa37a17ebb7556b)"
rand(m, Int64)
@test string(m) == "MersenneTwister(0xecfd77f89dcd508caa37a17ebb7556b, (0, 2002, 1000, 254, 0, 0, 0, 1))"
@test m == MersenneTwister(0xecfd77f89dcd508caa37a17ebb7556b, (0, 2002, 1000, 254, 0, 0, 0, 1))
@test string(m) == "MersenneTwister(0xecfd77f89dcd508caa37a17ebb7556b, (0, 1254, 0, 0, 0, 1))"
@test m == MersenneTwister(0xecfd77f89dcd508caa37a17ebb7556b, (0, 1254, 0, 0, 0, 1))

# test when floats advancing is done by initializing ints, and (few) floats are then generated
m = MersenneTwister(0); rand(m, Int64); rand(m)
@test string(m) == "MersenneTwister(0, (0, 2002, 1000, 255, 0, 0, 0, 1))"
@test m == MersenneTwister(0, (0, 2002, 1000, 255, 0, 0, 0, 1))
@test string(m) == "MersenneTwister(0, (0, 2256, 1254, 1, 0, 1))"
@test m == MersenneTwister(0, (0, 2256, 1254, 1, 0, 1))
end
Loading

0 comments on commit b8bc816

Please sign in to comment.