Add parallel and shuffle support to eachobs and DataLoader #82

Merged
11 commits merged on May 27, 2022
1 change: 1 addition & 0 deletions src/MLUtils.jl
@@ -34,6 +34,7 @@ include("eachobs.jl")
export eachobs

include("parallel.jl")
include("reshuffle.jl")

include("dataloader.jl")
export DataLoader
100 changes: 85 additions & 15 deletions src/eachobs.jl
@@ -1,21 +1,29 @@
"""
    eachobs(data; [buffer, batchsize, partial, parallel, shuffle])

Return an iterator over the observations in `data`.

# Arguments

- `data`. The data to be iterated over. The data type has to be supported by
  [`numobs`](@ref) and [`getobs`](@ref).
- `buffer`. If `buffer=true` and supported by the type of `data`,
  a buffer will be allocated and reused for memory efficiency.
  You can also pass a preallocated object to `buffer`.
- `batchsize`. If less than 0, iterates over individual observations.
  Otherwise, each iteration (except possibly the last) yields a mini-batch
  containing `batchsize` observations.
- `partial`. This argument is used only when `batchsize > 0`.
  If `partial=false` and the number of observations is not divisible by the batchsize,
  then the last mini-batch is dropped.
- `parallel = false`. Whether to load data in parallel using worker threads. This greatly
  speeds up data loading, roughly by a factor of the number of available threads, and requires
  starting Julia with multiple threads. Check `Threads.nthreads()` to see the number of
  available threads. **Passing `parallel = true` breaks ordering guarantees.**
- `shuffle = false`. Whether to shuffle the observations before iterating. Unlike
  wrapping the data container with `shuffleobs(data)`, `shuffle = true` ensures
  that the observations are shuffled anew every time you start iterating over
  `eachobs`.

See also [`numobs`](@ref), [`getobs`](@ref).

@@ -42,18 +50,80 @@ for (x, y) in eachobs((X, Y))
end
```
"""
function eachobs(
        data;
        buffer = false,
        parallel = false,
        shuffle = false,
        batchsize::Int = -1,
        partial::Bool = true,
        executor = _default_executor())
    if batchsize > 0
        data = BatchView(data; batchsize, partial)
    end

    iter = if parallel
        eachobsparallel(data; buffer, executor)
    else
        if buffer === false
            EachObs(data)
        elseif buffer === true
            EachObsBuffer(data)
        else
            EachObsBuffer(data, buffer)
        end
    end

    if shuffle
        iter = ReshuffleIter(iter)
    end
    return iter
end


# Internal

"""
    EachObs(data)

Create an iterator over observations in data container `data`.

This is an internal function. Use `eachobs(data)` instead.
"""
struct EachObs{T}
    data::T
end
Base.length(iter::EachObs) = numobs(iter.data)
Base.eltype(iter::EachObs) = eltype(iter.data)

function Base.iterate(iter::EachObs, i::Int = 1)
    i > numobs(iter) && return nothing
    return getobs(iter.data, i), i + 1
end
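The iterator above relies only on the `numobs`/`getobs` interface. As a rough illustration, a custom container along these lines would work with `EachObs` (the `ColumnData` type is made up for this sketch, not part of the package):

```julia
using MLUtils

# Hypothetical container: each observation is one column of a matrix.
struct ColumnData
    mat::Matrix{Float64}
end

MLUtils.numobs(d::ColumnData) = size(d.mat, 2)
MLUtils.getobs(d::ColumnData, i) = d.mat[:, i]

# eachobs(ColumnData(rand(3, 5))) would then yield the five columns in turn.
```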


"""
    EachObsBuffer(data)

Create a buffered iterator over observations in data container `data`.
Buffering only works if [`getobs!`](@ref) is implemented for `data`.

This is an internal function. Use `eachobs(data, buffer = true)` instead.
"""
struct EachObsBuffer{T, B}
    data::T
    buffer::B
end
EachObsBuffer(data) = EachObsBuffer(data, getobs(data, 1))
Base.length(iter::EachObsBuffer) = numobs(iter.data)
Base.eltype(iter::EachObsBuffer) = eltype(iter.data)

function Base.iterate(iter::EachObsBuffer)
    obs = getobs!(iter.buffer, iter.data, 1)
    return obs, (obs, 2)
end

function Base.iterate(iter::EachObsBuffer, (buffer, i))
    i > numobs(iter) && return nothing
    return getobs!(buffer, iter.data, i), (buffer, i + 1)
end
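Taken together, the public `eachobs` entry point composes these pieces. A minimal usage sketch, assuming the small in-memory arrays `X` and `Y` below (illustrative only, not part of the diff):

```julia
using MLUtils

X = rand(Float32, 3, 100)   # 3 features, 100 observations
Y = rand(1:2, 100)          # one label per observation

# Reshuffled mini-batches of 10, loaded by worker threads.
# Note: `parallel = true` needs Julia started with several threads
# and gives no ordering guarantees.
for (x, y) in eachobs((X, Y); batchsize = 10, shuffle = true, parallel = true)
    # x is a 3×10 slice of X, y the matching 10 labels
end
```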
19 changes: 12 additions & 7 deletions src/parallel.jl
@@ -18,6 +18,7 @@ to the number of physical CPU cores.
this if you need the additional performance and `getobs!` is implemented for
`data`. Setting `buffer = true` means that when using the iterator, an
observation is only valid for the current loop iteration.
You can also pass in a preallocated `buffer = getobs(data, 1)`.
- `executor = Folds.ThreadedEx()`: task scheduler
You may specify a different task scheduler which can
be any `Folds.Executor`.
@@ -30,19 +31,23 @@ function eachobsparallel(
        executor::Executor = _default_executor(),
        buffer = false,
        channelsize = Threads.nthreads())
    if buffer === false
        return _eachobsparallel_unbuffered(data, executor; channelsize)
    elseif buffer === true
        return _eachobsparallel_buffered(data, executor; channelsize)
    else
        return _eachobsparallel_buffered(data, executor; channelsize, buffer)
    end
end


function _eachobsparallel_buffered(
        data,
        executor;
        buffer = getobs(data, 1),
        channelsize = Threads.nthreads())
    buffers = [buffer]
    foreach(_ -> push!(buffers, deepcopy(buffer)), 1:channelsize)

    # This ensures the `Loader` will take from the `RingBuffer`s result
    # channel, and that a new results channel is created on repeated
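For orientation, a hedged sketch of how the `buffer` and `executor` options documented above can be reached through `eachobs`. The array `X` is an assumption for the example, and `Folds.ThreadedEx()` simply spells out the default scheduler explicitly:

```julia
using MLUtils, Folds

X = rand(Float32, 3, 100)

# Preallocate one observation; worker threads reuse deep copies of it.
# Requires Julia started with multiple threads.
buf = getobs(X, 1)
for x in eachobs(X; parallel = true, buffer = buf, executor = Folds.ThreadedEx())
    # `x` is only valid for the current iteration when buffering is on;
    # copy it if it needs to outlive the loop body.
end
```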
36 changes: 36 additions & 0 deletions src/reshuffle.jl
@@ -0,0 +1,36 @@

"""
    ReshuffleIter(iter)

Wrap a data iterator `iter` created by [`eachobs`](@ref) so that it is
shuffled anew every time it is iterated over.

This is an internal function. Use `eachobs(data; shuffle = true)` instead.
"""
struct ReshuffleIter{T}
    iter::T
end
Base.length(iter::ReshuffleIter) = length(iter.iter)
Base.eltype(iter::ReshuffleIter) = eltype(iter.iter)

function Base.iterate(re::ReshuffleIter)
    iter = reshuffle(re.iter)
    el, state = iterate(iter)
    return el, (iter, state)
end
function Base.iterate(::ReshuffleIter, (iter, state))
    ret = iterate(iter, state)
    isnothing(ret) && return ret
    el, state = ret
    return el, (iter, state)
end

reshuffle(iter::EachObs) = EachObs(shuffleobs(iter.data))
reshuffle(iter::EachObsBuffer) = EachObsBuffer(shuffleobs(iter.data))
reshuffle(iter::MLUtils.Loader) = MLUtils.Loader(
    iter.f,
    collect(shuffleobs(iter.argiter)),  # shuffles the indices
    iter.executor,
    iter.channelsize,
    iter.setup_channel,
)
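Because the reshuffle happens in the zero-argument `iterate` method, every fresh pass over the wrapper draws a new order. A small illustration of that contract, with an integer range standing in for a real data container:

```julia
using MLUtils

iter = eachobs(1:10; shuffle = true)   # internally wrapped in ReshuffleIter

first_pass  = collect(iter)
second_pass = collect(iter)
# The two passes almost surely differ, because each `collect` restarts
# iteration and triggers a fresh shuffle. With `eachobs(shuffleobs(1:10))`
# the order would instead be fixed once and then repeated.
```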
29 changes: 29 additions & 0 deletions test/eachobs.jl
@@ -35,4 +35,33 @@
            @test x == X[:,2i-1:2i]
        end
    end

    @testset "shuffled" begin
        # does not reshuffle on iteration
        shuffled = eachobs(shuffleobs(1:50))
        @test collect(shuffled) == collect(shuffled)

        # does reshuffle
        reshuffled = eachobs(1:50, shuffle = true)
        @test collect(reshuffled) != collect(reshuffled)

        reshuffled = eachobs(1:50, shuffle = true, buffer = true)
        @test collect(reshuffled) != collect(reshuffled)

        reshuffled = eachobs(1:50, shuffle = true, parallel = true)
        @test collect(reshuffled) != collect(reshuffled)

        reshuffled = eachobs(1:50, shuffle = true, buffer = true, parallel = true)
        @test collect(reshuffled) != collect(reshuffled)
    end
    @testset "Argument combinations" begin
        for batchsize ∈ (-1, 2), buffer ∈ (true, false, getobs(X, 1)),
                parallel ∈ (true, false), shuffle ∈ (true, false), partial ∈ (true, false)
            if !(buffer isa Bool) && batchsize > 0
                buffer = getobs(BatchView(X; batchsize), 1)
            end
            iter = eachobs(X; batchsize, shuffle, buffer, parallel, partial)
            @test_nowarn for _ in iter end
        end
    end
end