diff --git a/stdlib/Future/src/Future.jl b/stdlib/Future/src/Future.jl index 1d70dba7c84de..746f6e149a47d 100644 --- a/stdlib/Future/src/Future.jl +++ b/stdlib/Future/src/Future.jl @@ -36,7 +36,10 @@ One such step corresponds to the generation of two `Float64` numbers. For each different value of `steps`, a large polynomial has to be generated internally. One is already pre-computed for `steps=big(10)^20`. """ -randjump(r::MersenneTwister, steps::Integer) = - Random._randjump(r, Random.DSFMT.calc_jump(steps)) +function randjump(r::MersenneTwister, steps::Integer) + j = Random._randjump(r, Random.DSFMT.calc_jump(steps)) + j.adv_jump += 2*big(steps) # convert to BigInt to prevent overflow + j +end end # module Future diff --git a/stdlib/Random/src/RNGs.jl b/stdlib/Random/src/RNGs.jl index e9b7f152ef4fe..1312fb0d6e7ed 100644 --- a/stdlib/Random/src/RNGs.jl +++ b/stdlib/Random/src/RNGs.jl @@ -85,14 +85,24 @@ mutable struct MersenneTwister <: AbstractRNG idxF::Int idxI::Int - function MersenneTwister(seed, state, vals, ints, idxF, idxI) + # counters for show + adv::Int64 # state of advance at the DSFMT_state level + adv_jump::BigInt # number of skipped Float64 values via randjump + adv_vals::Int64 # state of advance when vals is filled-up + adv_ints::Int64 # state of advance when ints is filled-up + adv_vals_pre::Int64 # state of advance when vals is filled-up before ints + adv_idxF_pre::Int # value of idxF before ints is filled-up + + function MersenneTwister(seed, state, vals, ints, idxF, idxI, + adv, adv_jump, adv_vals, adv_ints, adv_vals_pre, adv_idxF_pre) length(vals) == MT_CACHE_F && 0 <= idxF <= MT_CACHE_F || throw(DomainError((length(vals), idxF), "`length(vals)` and `idxF` must be consistent with $MT_CACHE_F")) length(ints) == MT_CACHE_I >> 4 && 0 <= idxI <= MT_CACHE_I || throw(DomainError((length(ints), idxI), "`length(ints)` and `idxI` must be consistent with $MT_CACHE_I")) - new(seed, state, vals, ints, idxF, idxI) + new(seed, state, vals, ints, idxF, idxI, + adv, adv_jump, adv_vals, adv_ints, adv_vals_pre, adv_idxF_pre) end end @@ -100,7 +110,7 @@ MersenneTwister(seed::Vector{UInt32}, state::DSFMT_state) = MersenneTwister(seed, state, Vector{Float64}(undef, MT_CACHE_F), Vector{UInt128}(undef, MT_CACHE_I >> 4), - MT_CACHE_F, 0) + MT_CACHE_F, 0, 0, 0, -1, -1, -1, -1) """ MersenneTwister(seed) @@ -147,12 +157,19 @@ function copy!(dst::MersenneTwister, src::MersenneTwister) copyto!(dst.ints, src.ints) dst.idxF = src.idxF dst.idxI = src.idxI + dst.adv = src.adv + dst.adv_jump = src.adv_jump + dst.adv_vals = src.adv_vals + dst.adv_ints = src.adv_ints + dst.adv_vals_pre = src.adv_vals_pre + dst.adv_idxF_pre = src.adv_idxF_pre dst end copy(src::MersenneTwister) = MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), copy(src.ints), - src.idxF, src.idxI) + src.idxF, src.idxI, src.adv, src.adv_jump, src.adv_vals, src.adv_ints, + src.adv_vals_pre, src.adv_idxF_pre) ==(r1::MersenneTwister, r2::MersenneTwister) = @@ -164,17 +181,47 @@ copy(src::MersenneTwister) = hash(r::MersenneTwister, h::UInt) = foldr(hash, (r.seed, r.state, r.vals, r.ints, r.idxF, r.idxI); init=h) -function fillcache_zeros!(r::MersenneTwister) - # the use of this function is not strictly necessary, but it makes - # comparing two MersenneTwister RNGs easier +function show(io::IO, rng::MersenneTwister) + # seed + seed = from_seed(rng.seed) + seed_str = seed <= typemax(Int) ? string(seed) : "0x" * string(seed, base=16) # DWIM + if rng.adv_jump == 0 && rng.adv == 0 + return print(io, "MersenneTwister($seed_str)") + end + print(io, "MersenneTwister($seed_str, (") + # state + adv = Integer[rng.adv_jump, rng.adv] + if rng.adv_vals != -1 + push!(adv, rng.adv_vals, rng.idxF) + end + if rng.adv_ints != -1 # then rng.adv_vals is always != -1 + idxI = (length(rng.ints)*16 - rng.idxI) / 8 # 8 represents one Int64 + idxI = Int(idxI) # idxI should always be an integer when using public APIs + push!(adv, + rng.adv_ints, + rng.adv_vals_pre == -1 ? 0 : rng.adv_vals_pre, + rng.adv_vals_pre == -1 ? 0 : rng.adv_idxF_pre, + idxI) + end + join(io, adv, ", ") + print(io, "))") +end + +### low level API + +function reset_caches!(r::MersenneTwister) + # zeroing the caches makes comparing two MersenneTwister RNGs easier fill!(r.vals, 0.0) fill!(r.ints, zero(UInt128)) + mt_setempty!(r) + mt_setempty!(r, UInt128) + r.adv_vals = -1 + r.adv_ints = -1 + r.adv_vals_pre = -1 + r.adv_idxF_pre = -1 r end - -### low level API - #### floats mt_avail(r::MersenneTwister) = MT_CACHE_F - r.idxF @@ -184,7 +231,8 @@ mt_setempty!(r::MersenneTwister) = r.idxF = MT_CACHE_F mt_pop!(r::MersenneTwister) = @inbounds return r.vals[r.idxF+=1] function gen_rand(r::MersenneTwister) - GC.@preserve r dsfmt_fill_array_close1_open2!(r.state, pointer(r.vals), length(r.vals)) + r.adv_vals = r.adv + GC.@preserve r fill_array!(r, pointer(r.vals), length(r.vals), CloseOpen12()) mt_setfull!(r) end @@ -212,6 +260,9 @@ mt_avail(r::MersenneTwister, ::Type{T}) where {T<:BitInteger} = r.idxI >> logsizeof(T) function mt_setfull!(r::MersenneTwister, ::Type{<:BitInteger}) + r.adv_ints = r.adv + r.adv_vals_pre = r.adv_vals + r.adv_idxF_pre = r.idxF rand!(r, r.ints) r.idxI = MT_CACHE_I end @@ -275,14 +326,18 @@ function make_seed(n::Integer) end end +# inverse of make_seed(::Integer) +from_seed(a::Vector{UInt32})::BigInt = sum(a[i] * big(2)^(32*(i-1)) for i in 1:length(a)) + + #### seed!() function seed!(r::MersenneTwister, seed::Vector{UInt32}) copyto!(resize!(r.seed, length(seed)), seed) dsfmt_init_by_array(r.state, r.seed) - mt_setempty!(r) - mt_setempty!(r, UInt128) - fillcache_zeros!(r) + reset_caches!(r) + r.adv = 0 + r.adv_jump = 0 return r end @@ -464,6 +519,10 @@ function _rand_max383!(r::MersenneTwister, A::UnsafeView{Float64}, I::FloatInter A end +function fill_array!(rng::MersenneTwister, A::Ptr{Float64}, n::Int, I) + rng.adv += n + fill_array!(rng.state, A, n, I) +end fill_array!(s::DSFMT_state, A::Ptr{Float64}, n::Int, ::CloseOpen01_64) = dsfmt_fill_array_close_open!(s, A, n) @@ -488,10 +547,10 @@ function rand!(r::MersenneTwister, A::UnsafeView{Float64}, align = Csize_t(pA) % 16 if align > 0 pA2 = pA + 16 - align - fill_array!(r.state, pA2, n2, I[]) # generate the data in-place, but shifted + fill_array!(r, pA2, n2, I[]) # generate the data in-place, but shifted unsafe_copyto!(pA, pA2, n2) # move the data to the beginning of the array else - fill_array!(r.state, pA, n2, I[]) + fill_array!(r, pA, n2, I[]) end for i=n2+1:n A[i] = rand(r, I[]) @@ -653,5 +712,12 @@ end # Old randjump methods are deprecated, the scalar version is in the Future module. -_randjump(r::MersenneTwister, jumppoly::DSFMT.GF2X) = - fillcache_zeros!(MersenneTwister(copy(r.seed), DSFMT.dsfmt_jump(r.state, jumppoly))) +function _randjump(r::MersenneTwister, jumppoly::DSFMT.GF2X) + adv = r.adv + adv_jump = r.adv_jump + s = MersenneTwister(copy(r.seed), DSFMT.dsfmt_jump(r.state, jumppoly)) + reset_caches!(s) + s.adv = adv + s.adv_jump = adv_jump + s +end diff --git a/stdlib/Random/src/Random.jl b/stdlib/Random/src/Random.jl index 5197ac1c34e7b..2cdffd6067252 100644 --- a/stdlib/Random/src/Random.jl +++ b/stdlib/Random/src/Random.jl @@ -17,7 +17,7 @@ using Base.GMP: Limb using Base: BitInteger, BitInteger_types, BitUnsigned, require_one_based_indexing import Base: copymutable, copy, copy!, ==, hash, convert, - rand, randn + rand, randn, show export rand!, randn!, randexp, randexp!, diff --git a/stdlib/Random/test/runtests.jl b/stdlib/Random/test/runtests.jl index 2aeea0f623877..7032a5228c11b 100644 --- a/stdlib/Random/test/runtests.jl +++ b/stdlib/Random/test/runtests.jl @@ -592,16 +592,16 @@ end @test_throws DomainError DSFMT.DSFMT_state(zeros(Int32, rand(0:DSFMT.JN32-1))) @test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(), - zeros(Float64, 10), zeros(UInt128, MT_CACHE_I>>4), 0, 0) + zeros(Float64, 10), zeros(UInt128, MT_CACHE_I>>4), 0, 0, 0, 0, -1, -1, -1, -1) @test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(), - zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), -1, 0) + zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), -1, 0, 0, 0, -1, -1, -1, -1) @test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(), - zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>3), 0, 0) + zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>3), 0, 0, 0, 0, -1, -1, -1, -1) @test_throws DomainError MersenneTwister(zeros(UInt32, 1), DSFMT.DSFMT_state(), - zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), 0, -1) + zeros(Float64, MT_CACHE_F), zeros(UInt128, MT_CACHE_I>>4), 0, -1, 0, 0, -1, -1, -1, -1) # seed is private to MersenneTwister let seed = rand(UInt32, 10)