diff --git a/examples/NanoGPT/Project.toml b/examples/NanoGPT/Project.toml
index a4d4b40b66..cea35b0928 100644
--- a/examples/NanoGPT/Project.toml
+++ b/examples/NanoGPT/Project.toml
@@ -13,3 +13,19 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+
+[compat]
+Comonicon = "1"
+DataDeps = "0.7"
+Enzyme = "0.13.14"
+JLD2 = "0.5"
+Lux = "1.2.3"
+MLUtils = "0.4"
+NNlib = "0.9.24"
+OneHotArrays = "0.2.5"
+Optimisers = "0.4.1"
+Printf = "1.10"
+Random = "1.10"
+Reactant = "0.2.5"
+Statistics = "1.10"
+StatsBase = "0.34.3"
diff --git a/examples/NanoGPT/README.md b/examples/NanoGPT/README.md
new file mode 100644
index 0000000000..3b710a9134
--- /dev/null
+++ b/examples/NanoGPT/README.md
@@ -0,0 +1,58 @@
+# NanoGPT using Lux & Reactant
+
+## Requirements
+
+* Install [julia](https://julialang.org/)
+* In the Julia REPL, instantiate the `Project.toml` in this directory
+
+## Training
+
+To train a model, run `main.jl` with the necessary parameters.
+
+```bash
+julia --startup=no --project=examples/NanoGPT --threads=auto examples/NanoGPT/main.jl
+```
+
+## Inference
+
+To run inference with a trained model, pass the `--inference` flag and the path to a saved checkpoint.
+
+```bash
+julia --startup=no --project=examples/NanoGPT --threads=auto examples/NanoGPT/main.jl \
+    --inference \
+    --model-path=
+```
+
+## Usage
+
+```bash
+  main
+
+Usage
+
+  main [options] [flags]
+
+Options
+
+  --n-embed <64::Int>
+  --n-hidden <256::Int>
+  --n-heads <4::Int>
+  --qk-dim <16::Int>
+  --v-dim <16::Int>
+  --n-layers <6::Int>
+  --sequence-length <64::Int>
+  --batchsize <128::Int>
+  --dropout-rate <0.0::Float32>
+  --test-split <0.1::Float64>
+  --lr <0.01::Float64>
+  --epochs <100::Int>
+  --model-path <::String>
+  --seed <::Union{String, Vector{String}}>
+  --output-length <1024::Int>
+
+Flags
+
+  --inference
+  -h, --help          Print this help message.
+  --version           Print version.
+```
diff --git a/examples/NanoGPT/main.jl b/examples/NanoGPT/main.jl
index 6d6974a302..db9ebc100f 100644
--- a/examples/NanoGPT/main.jl
+++ b/examples/NanoGPT/main.jl
@@ -1,6 +1,6 @@
 # Taken from https://github.com/FluxML/model-zoo/pull/410
 using MLUtils, Lux, Random, Optimisers, Printf, Statistics, NNlib, DataDeps, StatsBase,
-    OneHotArrays
+    OneHotArrays, JLD2
 using Reactant, Enzyme
 
 using Comonicon: @main
@@ -51,42 +51,57 @@ function GPT(;
         token_embedding=Embedding(n_vocab => n_embed),
         position_embedding=Embedding(sequence_length => n_embed),
         drop=Dropout(dropout_rate),
-        blocks=ntuple(n_layers) do i
+        blocks=Chain(ntuple(n_layers) do i
             return gpt_block(; n_embed, n_hidden, qk_dim, v_dim, n_heads, dropout_rate)
-        end,
+        end...),
         ln=LayerNorm((n_embed, 1)),
         output_layer=Dense(n_embed => n_vocab)) do tokens
-        te = token_embedding(tokens)
-        pe = position_embedding(1:size(tokens, 1))
-        x = drop(te .+ pe)
-        for blk in blocks
-            x = blk(x)
-        end
-        x = ln(x)
-        x = output_layer(x)
-        @return x
+        x = drop(token_embedding(tokens) .+ position_embedding(1:size(tokens, 1)))
+        x = blocks(x)
+        @return output_layer(ln(x))
     end
 end
 
 # Use the model to generate some text.
-# function generate(model, seed, outlen)
-#     seqlen = context_length(model)
-#     if isempty(seed)
-#         seed = "_"
-#     end
-#     x = map(c -> findfirst(==(c), model.alphabet)::Int64, collect(seed))
-#     while length(x) < outlen
-#         tail = x[max(1, end-seqlen+1):end]
-#         tail = reshape(tail, length(tail), 1)
-#         y = model(tail |> device) |> cpu
-#         p = softmax(y[:,end,1])
-#         j = sample(1:length(model.alphabet), Weights(p))
-#         #j = argmax(p)
-#         #x = vcat(x, [j])
-#         push!(x, j)
-#     end
-#     String(map(j -> model.alphabet[j], x))
-# end
+function generate_text(
+    model, ps, st, seed; alphabet, output_length, sequence_length
+)
+    dev = get_device((ps, st))
+    @assert !(dev isa ReactantDevice) "Currently we don't support running inference of \
+                                       dynamically sized tensors."
+
+    seed = copy(seed)
+    seed_len = maximum(length, seed)
+    extra_letters = zeros(Int, length(seed))
+    for (i, s) in enumerate(seed)
+        if seed_len != length(s)
+            extra_letters[i] = seed_len - length(s)
+            seed[i] = "_"^extra_letters[i] * s
+        end
+    end
+    original_output_length = output_length
+    output_length += maximum(extra_letters)
+
+    st = Lux.testmode(st)
+
+    x = zeros(Int, output_length, length(seed))
+    for (i, s) in enumerate(seed), j in 1:seed_len
+        x[j, i] = findfirst(==(s[j]), alphabet)
+    end
+    for i in (seed_len + 1):output_length
+        tail = x[max(1, i - sequence_length + 1):(i - 1), :] |> dev
+        y = model(tail, ps, st)[1] |> cpu_device()
+        p = softmax(y[:, end, 1])
+        x[i, :] .= sample(1:length(alphabet), Weights(p))
+    end
+
+    res = [String(map(Base.Fix1(getindex, alphabet), x[:, i])) for i in axes(x, 2)]
+    for i in eachindex(res)
+        res[i] = res[i][(extra_letters[i] + 1):end][1:original_output_length]
+    end
+
+    return res
+end
 
 # Load data from input file, and partition into training and testing subsets.
 function get_nanogpt_data(; sequence_length, test_split)
@@ -121,32 +136,62 @@ function get_nanogpt_data(; sequence_length, test_split)
     return alphabet, Array(trainX), Array(trainY), Array(testX), Array(testY)
 end
 
-@main function train_nanogpt(;
+@main function main(;
     n_embed::Int=64, n_hidden::Int=256, n_heads::Int=4, qk_dim::Int=16, v_dim::Int=16,
     n_layers::Int=6, sequence_length::Int=64, batchsize::Int=128,
     dropout_rate::Float32=0.0f0, test_split::Float64=0.1, lr::Float64=1e-2,
-    epochs::Int=20
+    epochs::Int=100,
+    # Only inference options
+    inference::Bool=false, model_path::String="",
+    seed::Union{String, Vector{String}}=["_", "The", "Julia", "Lux.jl"],
+    output_length::Int=1024
 )
-    alphabet, trainX, trainY, testX, testY = get_nanogpt_data(; sequence_length, test_split)
-
-    @printf "[Info] Alphabet size: %d\n" length(alphabet)
-    @printf "[Info] Training size: %d sequences.\n" size(trainX, 2)
-    @printf "[Info] Testing size: %d sequences.\n\n" size(testX, 2)
-
     rng = Random.default_rng()
     Random.seed!(rng, 1234)
     dev = reactant_device()
     cdev = cpu_device()
 
+    if inference
+        @printf "[Info] Inference mode enabled.\n"
+
+        @assert !isempty(model_path) "Please provide a path to a model checkpoint."
+
+        @printf "[Info] Loading model from %s.\n" model_path
+        model_config = JLD2.load(model_path, "model_config")
+        model = GPT(; model_config...)
+        ps = JLD2.load(model_path, "parameters")
+        st = JLD2.load(model_path, "states")
+        alphabet = JLD2.load(model_path, "alphabet")
+        sequence_length = model_config.sequence_length
+
+        texts = generate_text(
+            model, ps, st, seed; alphabet, output_length, sequence_length
+        )
+
+        for (i, (text, s)) in enumerate(zip(texts, seed))
+            @printf "[Info] Seed [%d]: %s\n" i s
+            @printf "[Generated Text] %s\n\n" text
+        end
+
+        return
+    end
+
+    alphabet, trainX, trainY, testX, testY = get_nanogpt_data(; sequence_length, test_split)
+
+    @printf "[Info] Alphabet size: %d\n" length(alphabet)
+    @printf "[Info] Training size: %d sequences.\n" size(trainX, 2)
+    @printf "[Info] Testing size: %d sequences.\n\n" size(testX, 2)
+
     train_loader = DataLoader(
         (trainX, trainY); batchsize, shuffle=true, parallel=true
     ) |> dev
 
-    model = GPT(;
+    model_config = (;
         n_vocab=length(alphabet), n_embed, sequence_length, n_hidden, n_layers,
         dropout_rate, n_heads, qk_dim, v_dim
     )
+    model = GPT(; model_config...)
 
     ps, st = Lux.setup(rng, model) |> dev
     @printf "[Info] Number of parameters: %d\n" Lux.parameterlength(ps)
     @printf "[Info] Number of states: %d\n\n" Lux.statelength(st)
@@ -156,9 +201,12 @@ end
     @printf "[Info] Compiling Inference Model...\n"
     testX, testY = (testX, testY) |> dev
+    start_time = time()
     model_compiled = @compile model(testX, ps, Lux.testmode(st))
+    time_to_compile = time() - start_time
 
     best_test_loss = Inf
+    @printf "[Info] Time taken to compile inference model: %0.5fs\n" time_to_compile
 
     @printf "[Info] Starting Model Training...\n\n"
 
     loss_fn = CrossEntropyLoss(; logits=Val(true))
@@ -185,7 +233,15 @@ end
         )
         @printf "[Test] Epoch %3d\tTest Loss %.8e\n" epoch test_loss
 
-        # XXX: Also generate some text here...
+        # Generate some text here...
+        texts = generate_text(
+            model, ps |> cdev, st |> cdev, seed;
+            alphabet, output_length, sequence_length
+        )
+        for (i, (text, s)) in enumerate(zip(texts, seed))
+            @printf "[Info] Seed [%d]: %s\n" i s
+            @printf "[Generated Text] %s\n\n" text
+        end
 
         if test_loss < best_test_loss
             best_test_loss = test_loss
@@ -195,26 +251,9 @@ end
                 joinpath(@__DIR__, "nanogpt.jld2");
                 parameters=train_state.parameters |> cdev,
                 states=train_state.states |> cdev,
-                alphabet=alphabet
+                alphabet=alphabet,
+                model_config=model_config
             )
         end
     end
 end
-
-# # Load a model from a checkpoint (see `jldsave` above).
-# function load_model(filename)
-#     args = JLD2.load(filename, "args")
-#     alphabet = JLD2.load(filename, "alphabet")
-#     model = GPT(args, alphabet)
-#     model_state = JLD2.load(filename, "model_state")
-#     model = Flux.loadmodel!(model, model_state);
-#     return args, model
-# end
-
-# if true
-#     args, model = train()
-# else
-#     args, model = load_model("model-checkpoint.jld2") |> device
-# end
-
-# generate(model, "The", 50)
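For quick experimentation outside the Comonicon CLI, the checkpoint written by the `jldsave` call above can also be reloaded from a REPL. The sketch below is illustrative and not part of the patch: it assumes `main.jl` has been `include`d (Comonicon's `@main` only auto-runs when the file is executed as a script), and the checkpoint path is a hypothetical example; the JLD2 keys mirror the ones saved in the training loop.

```julia
# Illustrative REPL sketch (not part of the patch): reload a checkpoint
# produced by the training loop above and sample from it on the CPU.
using JLD2

include("main.jl")  # brings GPT, generate_text, etc. into scope

ckpt = joinpath(@__DIR__, "nanogpt.jld2")  # hypothetical checkpoint path

# Keys match the jldsave call in the training loop.
model_config = JLD2.load(ckpt, "model_config")
ps = JLD2.load(ckpt, "parameters")
st = JLD2.load(ckpt, "states")
alphabet = JLD2.load(ckpt, "alphabet")

model = GPT(; model_config...)

# generate_text asserts that ps/st are *not* on a Reactant device, so the
# plain CPU arrays loaded from JLD2 can be passed through directly.
texts = generate_text(
    model, ps, st, ["The"];
    alphabet, output_length=256, sequence_length=model_config.sequence_length
)
println(texts[1])
```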