forked from google/flax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdecode.py
141 lines (124 loc) · 5.61 KB
/
decode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright 2020 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Fast decoding routines for inference from a trained language model."""
import jax
from jax import lax
from jax import random
import jax.numpy as jnp
import numpy as np
def multinomial(rng, logits):
  """Draws samples from a multinomial distribution given by logits.

  Uses inverse-CDF sampling: a single uniform variate per distribution is
  compared against the cumulative softmax probabilities of the last axis.

  Args:
    rng: A JAX PRNGKey.
    logits: array with unnormalized log-probabilities in last axis.

  Returns:
    Integer array of sampled category indices with shape `logits.shape[:-1]`
    (the category axis is reduced away).
  """
  probs = jax.nn.softmax(logits)
  cum_probs = jnp.cumsum(probs, axis=-1)
  # One variate per distribution; the trailing 1 broadcasts it against the
  # category axis of cum_probs.
  uniform_variates = jax.random.uniform(rng, logits.shape[:-1] + (1,))
  # True for each category whose cumulative mass is still below the variate;
  # the first False position is the sampled category.
  below = uniform_variates > cum_probs
  samples = jnp.argmin(below, axis=-1)
  # Float rounding can leave cum_probs[..., -1] slightly under 1.0, so a
  # variate near 1 may exceed every entry. argmin of an all-True row is 0,
  # which would wrongly pick the first category; pick the last one instead.
  return jnp.where(jnp.all(below, axis=-1), logits.shape[-1] - 1, samples)
def top_k(x, k):
  """Select the top k slices from the last dimension.

  Args:
    x: array whose last axis is searched.
    k: number of entries to keep from the last axis.

  Returns:
    A `(values, indices)` pair, each shaped `x.shape[:-1] + (k,)`: the k
    largest entries of `x` along the last axis in ascending order, and their
    positions within that axis.
  """
  axis_len = x.shape[-1]
  positions = jnp.broadcast_to(np.arange(axis_len), x.shape)
  # Sort ascending by value, carrying each value's original position along.
  ascending_vals, ascending_pos = lax.sort_key_val(x, positions)

  def keep_last_k(arr):
    # After an ascending sort the k largest entries sit at the tail.
    return lax.slice_in_dim(arr, -k, axis_len, axis=-1)

  return keep_last_k(ascending_vals), keep_last_k(ascending_pos)
def temperature_sample(prompt_inputs,
                       init_cache,
                       tokens_to_logits,
                       prng_key,
                       temperature=1.0,
                       topk=20,
                       eos_token=1):
  """Temperature sampling for language model generation.

  Decoding starts by feeding token 0 to `tokens_to_logits` and then fills
  positions 1 .. max_decode_len - 1; position 0 of each sequence is never
  written. A position whose prompt value is nonzero is copied from the prompt
  instead of sampled; a position holding 0 receives a sampled token. Once a
  batch item emits `eos_token`, all its later positions are written as 0.

  Args:
    prompt_inputs: array: [batch_size, max_decode_len] int32 sequence of tokens.
      Zero entries mark positions to be filled by sampling.
    init_cache: flax attention cache.
    tokens_to_logits: fast autoregressive decoder function taking single token
      slices ([batch_size, 1]) and cache and returning next-token logits and
      updated cache. The logits must yield one sampled token per batch item
      (presumably shaped [batch_size, vocab] -- TODO confirm against caller).
    prng_key: JAX PRNGKey.
    temperature: float: sampling temperature factor. As it approaches
      zero this becomes equivalent to greedy sampling (exactly zero divides
      the logits by zero).
    topk: integer: if nonzero only use the top-k logits to sample next token,
      if zero don't use any cutoff and sample from full logits over vocabulary.
    eos_token: int: end-of-sentence token for target vocabulary.

  Returns:
    Array of sampled sequences: [batch_size, max_decode_len]
  """
  batch_size = prompt_inputs.shape[0]
  max_decode_len = prompt_inputs.shape[1]
  end_marker = jnp.array(eos_token)
  temperature = jnp.array(temperature)

  # Initialize sampling loop state.
  # initial loop PRNGKey
  rng0 = prng_key
  # loop position counter.
  i0 = jnp.array(0)
  # per batch-item holding current token in loop.
  token0 = jnp.zeros((batch_size, 1), dtype=jnp.int32)
  # per batch-item state bit indicating if sentence has finished.
  ended0 = jnp.zeros((batch_size, 1), dtype=jnp.bool_)
  # (batch, length) array containing prefix prompt tokens for sampling loop
  # as well as the generated output of newly sampled tokens.
  sequences0 = prompt_inputs
  # Sampling loop state is stored in a simple tuple.
  sampling_loop_init_state = (i0, sequences0, init_cache, token0, ended0, rng0)

  def sampling_loop_cond_fn(state):
    """Sampling loop termination condition."""
    (i, _, _, _, ended, _) = state
    # Each step reads and writes column i + 1, so the last in-bounds step is
    # i == max_decode_len - 2. A looser bound indexes past the end of
    # `sequences`, and dynamic_update_slice would clamp the write onto the
    # last column.
    not_at_end = (i < max_decode_len - 1)
    # Have all sampled sequences reached an end marker?
    all_sequences_ended = jnp.all(ended)
    return not_at_end & (~all_sequences_ended)

  def sampling_loop_body_fn(state):
    """Sampling loop state update."""
    i, sequences, cache, cur_token, ended, rng = state
    # Split RNG for sampling.
    rng1, rng2 = random.split(rng)
    # Call fast-decoder model on current tokens to get next-position logits.
    logits, new_cache = tokens_to_logits(cur_token, cache)
    # Sample next token from logits.
    # TODO(levskaya): add top-p "nucleus" sampling option.
    if topk:
      # Get top-k logits and their indices, sample within these top-k tokens.
      topk_logits, topk_idxs = top_k(logits, topk)
      topk_token = jnp.expand_dims(multinomial(
          rng1, topk_logits / temperature).astype(jnp.int32), axis=-1)
      # Return the original indices corresponding to the sampled top-k tokens.
      next_token = jnp.squeeze(
          jnp.take_along_axis(topk_idxs, topk_token, axis=-1), axis=-1)
    else:
      next_token = multinomial(rng1, logits / temperature).astype(jnp.int32)
    # Normalize to exactly one token per batch item: [batch_size].
    next_token = jnp.reshape(next_token, (batch_size,))
    # Only use sampled tokens if we're past provided prefix tokens.
    prompt_token = sequences[:, i + 1]
    out_of_prompt = (prompt_token == 0)
    next_token = (next_token * out_of_prompt +
                  prompt_token * ~out_of_prompt)
    # If end-marker reached for batch item, only emit padding tokens.
    # Lift to [batch_size, 1] to match `ended`/`cur_token`; multiplying a
    # [batch] vector by a [batch, 1] mask would broadcast to [batch, batch]
    # and break the while_loop carry for any batch larger than 1.
    next_token_or_endpad = next_token[:, None] * ~ended
    ended |= (next_token_or_endpad == end_marker)
    # Add current sampled tokens to recorded sequences.
    new_sequences = lax.dynamic_update_slice(
        sequences, next_token_or_endpad, (0, i + 1))
    return (i + 1, new_sequences, new_cache, next_token_or_endpad, ended, rng2)

  # Run sampling loop and collect final state.
  final_state = lax.while_loop(sampling_loop_cond_fn,
                               sampling_loop_body_fn,
                               sampling_loop_init_state)
  # Pick part of the state corresponding to the sampled sequences.
  final_sequences = final_state[1]
  return final_sequences