Carefully move some internal methods of recurrent.py to a dedicated utils file #178

Open · wants to merge 12 commits into base: v2
10 changes: 10 additions & 0 deletions sonnet/src/BUILD
@@ -366,6 +366,15 @@ snt_py_test(
],
)

snt_py_library(
name = "recurrent_internals",
srcs = ["recurrent_internals.py"],
deps = [
# pip: tensorflow
# pip: tree
],
)

snt_py_library(
name = "recurrent",
srcs = ["recurrent.py"],
@@ -377,6 +386,7 @@ snt_py_library(
":once",
":types",
":utils",
":recurrent_internals"
# pip: six
# pip: tensorflow
# pip: tree
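The new sonnet/src/recurrent_internals.py file is not shown in this part of the diff. Judging only from the BUILD rule above and the call sites in recurrent.py below, a minimal sketch of the surface it would need to expose might look like the following (names and signatures are inferred from the moved helpers, not copied from the actual file):

# recurrent_internals.py -- hypothetical skeleton, inferred from the
# call sites in recurrent.py; the real module in this PR may differ.
import collections

import tensorflow as tf
import tree

# State carried between LSTM steps (previously defined in recurrent.py).
LSTMState = collections.namedtuple("LSTMState", ["hidden", "cell"])


def unstack_input_sequence(input_sequence):
  """Unstacks the input sequence into a nest of `tf.TensorArray`s."""
  ...


def rnn_step(core, input_tas, sequence_length, t, prev_outputs, prev_state):
  """Performs a single RNN step, optionally handling variable lengths."""
  ...


def lstm_fn(inputs, prev_state, w_i, w_h, b, projection=None):
  """Computes one step of an LSTM."""
  ...


def check_inputs_dtype(inputs, expected_dtype):
  """Raises `TypeError` unless `inputs` has `expected_dtype`."""
  ...


def specialize_per_device(api_name, specializations, default):
  """Creates a `tf.function` that dispatches to a per-device implementation."""
  ...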
196 changes: 16 additions & 180 deletions sonnet/src/recurrent.py
@@ -20,9 +20,7 @@
from __future__ import print_function

import abc
import collections
import functools
import uuid

import six
from sonnet.src import base
Expand All @@ -32,19 +30,13 @@
from sonnet.src import once
from sonnet.src import types
from sonnet.src import utils
from sonnet.src import recurrent_internals

import tensorflow.compat.v1 as tf1
import tensorflow as tf
import tree

from typing import Optional, Sequence, Text, Tuple, Union

# pylint: disable=g-direct-tensorflow-import
# Required for specializing `UnrolledLSTM` per device.
from tensorflow.python import context as context_lib
from tensorflow.python.eager import function as function_lib
# pylint: enable=g-direct-tensorflow-import


@six.add_metaclass(abc.ABCMeta)
class RNNCore(base.Module):
@@ -260,7 +252,7 @@ def static_unroll(
ValueError: If ``input_sequence`` is empty or its leading dimension is
not known statically.
"""
num_steps, input_tas = _unstack_input_sequence(input_sequence)
num_steps, input_tas = recurrent_internals.unstack_input_sequence(input_sequence)
if not isinstance(num_steps, six.integer_types):
raise ValueError(
"input_sequence must have a statically known number of time steps")
@@ -269,7 +261,7 @@ def static_unroll(
state = initial_state
output_accs = None
for t in six.moves.range(num_steps):
outputs, state = _rnn_step(
outputs, state = recurrent_internals.rnn_step(
core,
input_tas,
sequence_length,
@@ -360,10 +352,10 @@ def dynamic_unroll(
Raises:
ValueError: If ``input_sequence`` is empty.
"""
num_steps, input_tas = _unstack_input_sequence(input_sequence)
num_steps, input_tas = recurrent_internals.unstack_input_sequence(input_sequence)

# Unroll the first time step separately to infer outputs structure.
outputs, state = _rnn_step(
outputs, state = recurrent_internals.rnn_step(
core,
input_tas,
sequence_length,
@@ -380,7 +372,7 @@ def dynamic_unroll(
parallel_iterations=parallel_iterations,
swap_memory=swap_memory,
maximum_iterations=num_steps - 1)
outputs, state = _rnn_step(
outputs, state = recurrent_internals.rnn_step(
core,
input_tas,
sequence_length,
@@ -394,77 +386,6 @@ def dynamic_unroll(
return output_sequence, state


def _unstack_input_sequence(input_sequence):
r"""Unstacks the input sequence into a nest of :tf:`TensorArray`\ s.

This allows traversing the input sequence using :tf:`TensorArray.read`
instead of a slice, avoiding O(sliced tensor) slice gradient
computation during the backwards pass.

Args:
input_sequence: See :func:`dynamic_unroll` or :func:`static_unroll`.

Returns:
num_steps: Number of steps in the input sequence.
input_tas: An arbitrarily nested structure of :tf:`TensorArray`\ s of
size ``num_steps``.

Raises:
ValueError: If tensors in ``input_sequence`` have inconsistent number
of steps or the number of steps is 0.
"""
flat_input_sequence = tree.flatten(input_sequence)
all_num_steps = {i.shape[0] for i in flat_input_sequence}
if len(all_num_steps) > 1:
raise ValueError(
"input_sequence tensors must have consistent number of time steps")
[num_steps] = all_num_steps
if num_steps == 0:
raise ValueError("input_sequence must have at least a single time step")
elif num_steps is None:
# Number of steps is not known statically, fall back to dynamic shape.
num_steps = tf.shape(flat_input_sequence[0])[0]
# TODO(b/141910613): uncomment when the bug is fixed.
# for i in flat_input_sequence[1:]:
# tf.debugging.assert_equal(
# tf.shape(i)[0], num_steps,
# "input_sequence tensors must have consistent number of time steps")

input_tas = tree.map_structure(
lambda i: tf.TensorArray(i.dtype, num_steps).unstack(i), input_sequence)
return num_steps, input_tas
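As a standalone illustration of the pattern this helper implements (shapes below are arbitrary and chosen only for the example):

import tensorflow as tf

inputs = tf.random.normal([100, 8, 32])  # [num_steps, batch_size, features]
input_ta = tf.TensorArray(inputs.dtype, size=100).unstack(inputs)

x_t = input_ta.read(3)  # per-step read, shape [8, 32]
# Reading from the TensorArray avoids the O(sliced tensor) gradient that a
# plain slice such as `inputs[3]` would incur on every step of the unrolled
# loop during the backwards pass.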


def _safe_where(condition, x, y): # pylint: disable=g-doc-args
"""`tf.where` which allows scalar inputs."""
if x.shape.rank == 0:
# This is to match the `tf.nn.*_rnn` behavior. In general, we might
# want to branch on `tf.reduce_all(condition)`.
return y
# TODO(tomhennigan) Broadcasting with SelectV2 is currently broken.
return tf1.where(condition, x, y)


def _rnn_step(core, input_tas, sequence_length, t, prev_outputs, prev_state):
"""Performs a single RNN step optionally accounting for variable length."""
outputs, state = core(
tree.map_structure(lambda i: i.read(t), input_tas), prev_state)

if prev_outputs is None:
assert t == 0
prev_outputs = tree.map_structure(tf.zeros_like, outputs)

# TODO(slebedev): do not go into this block if t < min_len.
if sequence_length is not None:
# Selectively propagate outputs/state to the not-yet-finished
# sequences.
maybe_propagate = functools.partial(_safe_where, t >= sequence_length)
outputs = tree.map_structure(maybe_propagate, prev_outputs, outputs)
state = tree.map_structure(maybe_propagate, prev_state, state)

return outputs, state


class VanillaRNN(RNNCore):
"""Basic fully-connected RNN core.

@@ -545,7 +466,7 @@ def initial_state(self, batch_size: int) -> tf.Tensor:

@once.once
def _initialize(self, inputs: tf.Tensor):
dtype = _check_inputs_dtype(inputs, self._dtype)
dtype = recurrent_internals.check_inputs_dtype(inputs, self._dtype)
self._b = tf.Variable(self._b_init([self._hidden_size], dtype), name="b")


@@ -743,7 +664,7 @@ def deep_rnn_with_residual_connections(
name=name)


LSTMState = collections.namedtuple("LSTMState", ["hidden", "cell"])
LSTMState = recurrent_internals.LSTMState


class LSTM(RNNCore):
@@ -840,7 +761,7 @@ def __init__(self,
def __call__(self, inputs, prev_state):
"""See base class."""
self._initialize(inputs)
return _lstm_fn(inputs, prev_state, self._w_i, self._w_h, self.b,
return recurrent_internals.lstm_fn(inputs, prev_state, self._w_i, self._w_h, self.b,
self.projection)

def initial_state(self, batch_size: int) -> LSTMState:
@@ -861,7 +782,7 @@ def hidden_to_hidden(self):
def _initialize(self, inputs):
utils.assert_rank(inputs, 2)
input_size = inputs.shape[1]
dtype = _check_inputs_dtype(inputs, self._dtype)
dtype = recurrent_internals.check_inputs_dtype(inputs, self._dtype)

w_i_init = self._w_i_init or initializers.TruncatedNormal(
stddev=1.0 / tf.sqrt(tf.cast(input_size, dtype)))
@@ -890,25 +811,6 @@ def _initialize(self, inputs):
name="projection")


def _lstm_fn(inputs, prev_state, w_i, w_h, b, projection=None):
"""Compute one step of an LSTM."""
gates_x = tf.matmul(inputs, w_i)
gates_h = tf.matmul(prev_state.hidden, w_h)
gates = gates_x + gates_h + b

# i = input, f = forget, g = cell updates, o = output.
i, f, g, o = tf.split(gates, num_or_size_splits=4, axis=1)

next_cell = tf.sigmoid(f) * prev_state.cell
next_cell += tf.sigmoid(i) * tf.tanh(g)
next_hidden = tf.sigmoid(o) * tf.tanh(next_cell)

if projection is not None:
next_hidden = tf.matmul(next_hidden, projection)

return next_hidden, LSTMState(hidden=next_hidden, cell=next_cell)
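For reference, the update computed by the moved `lstm_fn` (with optional projection matrix P) is the standard LSTM step, restated here from the code above:

\begin{aligned}
  [i, f, g, o] &= x_t W_i + h_{t-1} W_h + b \\
  c_t &= \sigma(f) \odot c_{t-1} + \sigma(i) \odot \tanh(g) \\
  h_t &= \sigma(o) \odot \tanh(c_t), \qquad h_t \leftarrow h_t P \;\text{ if a projection is used.}
\end{aligned}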


class UnrolledLSTM(UnrolledRNN):
"""Unrolled long short-term memory (LSTM).

@@ -975,7 +877,7 @@ def hidden_to_hidden(self):
def _initialize(self, input_sequence):
utils.assert_rank(input_sequence, 3) # [num_steps, batch_size, input_size].
input_size = input_sequence.shape[2]
dtype = _check_inputs_dtype(input_sequence, self._dtype)
dtype = recurrent_internals.check_inputs_dtype(input_sequence, self._dtype)

w_i_init = self._w_i_init or initializers.TruncatedNormal(
stddev=1.0 / tf.sqrt(tf.cast(input_size, dtype)))
@@ -992,69 +894,10 @@ def _initialize(self, input_sequence):
self.b = tf.Variable(tf.concat([b_i, b_f, b_g, b_o], axis=0), name="b")


# TODO(b/133740216): consider upstreaming into TensorFlow.
def _specialize_per_device(api_name, specializations, default):
"""Create a :tf:`function` specialized per-device.

Args:
api_name: Name of the function, e.g. ``"lstm"``.
specializations: A mapping from device type (e.g. ``"CPU"`` or ``"TPU"``) to
a Python function with a specialized implementation for that device.
default: Default device type to use (typically, ``"CPU"``).

Returns:
A :tf:`function` which when called dispatches to the specialization
for the current device.
"""
# Cached to avoid redundant ``ModuleWrapper.__getattribute__`` calls.
list_logical_devices = tf.config.experimental.list_logical_devices

def wrapper(*args, **kwargs):
"""Specialized {}.

In eager mode the specialization is chosen based on the current
device context or, if no device context is active, on availability
of a GPU.

In graph mode (inside tf.function) the choice is delegated to the
implementation selector pass in Grappler.

Args:
*args: Positional arguments to pass to the chosen specialization.
**kwargs: Keyword arguments to pass to the chosen specialization.
""".format(api_name)
ctx = context_lib.context()
if ctx.executing_eagerly():
device = ctx.device_spec.device_type
if device is None:
# Soft-placement will never implicitly place an op on a TPU, so
# we only need to consider CPU/GPU.
device = "GPU" if list_logical_devices("GPU") else "CPU"

specialization = specializations.get(device) or specializations[default]
return specialization(*args, **kwargs)

# Implementation selector requires a globally unique name for each
# .register() call.
unique_api_name = "{}_{}".format(api_name, uuid.uuid4())
functions = {}
for device, specialization in specializations.items():
functions[device] = function_lib.defun_with_attributes(
specialization,
attributes={
"api_implements": unique_api_name,
"api_preferred_device": device
})
function_lib.register(functions[device], *args, **kwargs)
return functions[default](*args, **kwargs)

return wrapper


def _fallback_unrolled_lstm(input_sequence, initial_state, w_i, w_h, b):
"""Fallback version of :class:`UnrolledLSTM` which works on any device."""
return dynamic_unroll(
functools.partial(_lstm_fn, w_i=w_i, w_h=w_h, b=b), input_sequence,
functools.partial(recurrent_internals.lstm_fn, w_i=w_i, w_h=w_h, b=b), input_sequence,
initial_state)


@@ -1105,7 +948,7 @@ def _cudnn_unrolled_lstm(input_sequence, initial_state, w_i, w_h, b):
if hasattr(tf.raw_ops, "BlockLSTMV2"):
_unrolled_lstm_impls["CPU"] = _block_unrolled_lstm

_specialized_unrolled_lstm = _specialize_per_device(
_specialized_unrolled_lstm = recurrent_internals.specialize_per_device(
"snt_unrolled_lstm", specializations=_unrolled_lstm_impls, default="TPU")


@@ -1349,7 +1192,7 @@ def initial_state(self, batch_size):

@once.once
def _initialize(self, inputs):
dtype = _check_inputs_dtype(inputs, self._dtype)
dtype = recurrent_internals.check_inputs_dtype(inputs, self._dtype)
b_i, b_f, b_g, b_o = tf.split(
self._b_init([4 * self._output_channels], dtype), num_or_size_splits=4)
b_f += self._forget_bias
@@ -1602,7 +1445,7 @@ def hidden_to_hidden(self):
def _initialize(self, inputs):
utils.assert_rank(inputs, 2)
input_size = inputs.shape[1]
dtype = _check_inputs_dtype(inputs, self._dtype)
dtype = recurrent_internals.check_inputs_dtype(inputs, self._dtype)
self._w_i = tf.Variable(
self._w_i_init([input_size, 3 * self._hidden_size], dtype), name="w_i")
self._w_h = tf.Variable(
@@ -1711,17 +1554,10 @@ def initial_state(self, batch_size):
def _initialize(self, inputs):
utils.assert_rank(inputs, 3) # [num_steps, batch_size, input_size].
input_size = inputs.shape[2]
dtype = _check_inputs_dtype(inputs, self._dtype)
dtype = recurrent_internals.check_inputs_dtype(inputs, self._dtype)
self._w_i = tf.Variable(
self._w_i_init([input_size, 3 * self._hidden_size], dtype), name="w_i")
self._w_h = tf.Variable(
self._w_h_init([self._hidden_size, 3 * self._hidden_size], dtype),
name="w_h")
self.b = tf.Variable(self._b_init([3 * self._hidden_size], dtype), name="b")


def _check_inputs_dtype(inputs, expected_dtype):
if inputs.dtype is not expected_dtype:
raise TypeError("inputs must have dtype {!r}, got {!r}".format(
expected_dtype, inputs.dtype))
return expected_dtype
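Since only private helpers move, the public API of sonnet.src.recurrent is unchanged. A minimal sketch of typical Sonnet 2 usage (assuming the usual snt.LSTM and snt.dynamic_unroll entry points) that behaves identically before and after this change:

import sonnet as snt
import tensorflow as tf

core = snt.LSTM(hidden_size=16)
input_sequence = tf.random.normal([10, 4, 8])  # [num_steps, batch_size, input_size]
initial_state = core.initial_state(batch_size=4)
output_sequence, final_state = snt.dynamic_unroll(
    core, input_sequence, initial_state)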