diff --git a/tests/test_models/test_informer.py b/tests/test_models/test_informer.py
index ac8c2b4..d259d93 100644
--- a/tests/test_models/test_informer.py
+++ b/tests/test_models/test_informer.py
@@ -2,8 +2,10 @@
 python -m unittest -v tests/test_models/test_informer.py
 """
 
+from typing import Any, Dict
 import unittest
 
+import numpy as np
 import tensorflow as tf
 from tensorflow.keras.layers import LayerNormalization
 
@@ -12,6 +14,8 @@ from tfts.layers.attention_layer import FullAttention, ProbAttention
 from tfts.models.informer import Decoder, DecoderLayer, DistilConv, Encoder, EncoderLayer, Informer
 
+tf.config.run_functions_eagerly(True)
+
 
 class InformerTest(unittest.TestCase):
     def test_model(self):
@@ -101,3 +105,45 @@ def test_decoder(self):
 
         y = decoder(x, memory=memory)
         self.assertEqual(y.shape, (2, 50, attention_hidden_sizes))
+
+    def test_train(self):
+        params: Dict[str, Any] = {
+            "n_encoder_layers": 1,
+            "n_decoder_layers": 1,
+            "attention_hidden_sizes": 32 * 1,
+            "num_heads": 1,
+            "attention_dropout": 0.0,
+            "ffn_hidden_sizes": 32 * 1,
+            "ffn_filter_sizes": 32 * 1,
+            "ffn_dropout": 0.0,
+            "skip_connect_circle": False,
+            "skip_connect_mean": False,
+            "prob_attention": False,
+            "distil_conv": False,
+        }
+
+        custom_params = params.copy()
+        custom_params["prob_attention"] = True
+
+        train_length = 49
+        predict_length = 10
+        n_encoder_feature = 2
+        n_decoder_feature = 3
+
+        x_train = (
+            np.random.rand(1, train_length, 1),
+            np.random.rand(1, train_length, n_encoder_feature),
+            np.random.rand(1, predict_length, n_decoder_feature),
+        )
+        y_train = np.random.rand(1, predict_length, 1)  # target: (batch, predict_length, 1)
+
+        x_valid = (
+            np.random.rand(1, train_length, 1),
+            np.random.rand(1, train_length, n_encoder_feature),
+            np.random.rand(1, predict_length, n_decoder_feature),
+        )
+        y_valid = np.random.rand(1, predict_length, 1)
+
+        model = AutoModel("Informer", predict_length=predict_length, custom_model_params=custom_params)
+        trainer = KerasTrainer(model)
+        trainer.train((x_train, y_train), (x_valid, y_valid), n_epochs=1)
diff --git a/tfts/layers/attention_layer.py b/tfts/layers/attention_layer.py
index 43f6bb4..739d8b5 100644
--- a/tfts/layers/attention_layer.py
+++ b/tfts/layers/attention_layer.py
@@ -141,28 +141,28 @@ def build(self, input_shape: Tuple[Optional[int], ...]) -> None:
         super().build(input_shape)
 
     def _prob_qk(self, q, k, sample_k, top_n):
-        B, H, L, E = k.shape
+        _, H, L, E = k.shape
         _, _, S, _ = q.shape
+        B = tf.shape(k)[0]
         k_expand = tf.broadcast_to(tf.expand_dims(k, -3), (B, H, L, S, E))
-        k_random_index = tf.random.uniform((S, sample_k), maxval=L, dtype=tf.int32)
-        k_random_index = tf.tile(k_random_index[tf.newaxis, tf.newaxis, :], [B, H, 1, 1])
-        batch_indexes = tf.tile(tf.range(B)[:, tf.newaxis, tf.newaxis, tf.newaxis], (1, H, L, k_random_index.shape[-1]))
-        head_indexes = tf.tile(tf.range(H)[tf.newaxis, :, tf.newaxis, tf.newaxis], (B, 1, L, k_random_index.shape[-1]))
-        k_indexes = tf.tile(tf.range(L)[tf.newaxis, tf.newaxis, :, tf.newaxis], (B, H, 1, k_random_index.shape[-1]))
+        indx_q_seq = tf.random.uniform((S,), maxval=L, dtype=tf.int32)
+        indx_k_seq = tf.random.uniform((sample_k,), maxval=L, dtype=tf.int32)
 
-        k_random_index = tf.stack([batch_indexes, head_indexes, k_indexes, k_random_index], axis=-1)
-        k_sample = tf.gather_nd(k_expand, k_random_index)
+        K_sample = tf.gather(k_expand, tf.range(S), axis=2)
 
-        qk_sample = tf.squeeze(tf.matmul(tf.expand_dims(q, -2), tf.transpose(k_sample, [0, 1, 2, 4, 3])))
-        m = tf.math.reduce_max(qk_sample, axis=-1) - tf.divide(tf.reduce_sum(qk_sample, axis=-1), L)
-        m_top = tf.math.top_k(m, top_n, sorted=False)[1]
-        m_top = m_top[tf.newaxis] if B == 1 else m_top
-        m_top = tf.tile(m_top, (1, 1, 1))
+        K_sample = tf.gather(K_sample, indx_q_seq, axis=2)
+        K_sample = tf.gather(K_sample, indx_k_seq, axis=3)
+
+        Q_K_sample = tf.squeeze(tf.matmul(tf.expand_dims(q, -2), tf.einsum("...ij->...ji", K_sample)))
+        M = tf.math.reduce_max(Q_K_sample, axis=-1) - tf.raw_ops.Div(x=tf.reduce_sum(Q_K_sample, axis=-1), y=L)
+        m_top = tf.math.top_k(M, top_n, sorted=False)[1]
+        m_top = m_top[tf.newaxis, tf.newaxis] if B == 1 else m_top
 
         batch_indexes = tf.tile(tf.range(B)[:, tf.newaxis, tf.newaxis], (1, H, top_n))
         head_indexes = tf.tile(tf.range(H)[tf.newaxis, :, tf.newaxis], (B, 1, top_n))
+
         idx = tf.stack([batch_indexes, head_indexes, m_top], axis=-1)
 
         q_reduce = tf.gather_nd(q, idx)
@@ -170,7 +170,8 @@ def _prob_qk(self, q, k, sample_k, top_n):
         return qk, m_top
 
     def _get_initial_context(self, v, L_Q):
-        B, H, L_V, D = v.shape
+        _, H, L_V, D = v.shape
+        B = tf.shape(v)[0]
         if not self.mask_flag:
             v_sum = tf.math.reduce_sum(v, axis=-2)
             context = tf.identity(tf.boradcast_to(tf.expand_dims(v_sum, -2), [B, H, L_Q, v_sum.shape[-1]]))
@@ -180,9 +181,10 @@ def _get_initial_context(self, v, L_Q):
         return context
 
     def _update_context(self, context_in, v, scores, index, L_Q):
-        B, H, L_V, D = v.shape
-        batch_indexes = tf.tile(tf.range(B)[:, tf.newaxis, tf.newaxis], (1, H, index.shape[-1]))
-        head_indexes = tf.tile(tf.range(H)[tf.newaxis, :, tf.newaxis], (B, 1, index.shape[-1]))
+        _, H, L_V, D = v.shape
+        B = tf.shape(v)[0]
+        batch_indexes = tf.tile(tf.range(B)[:, tf.newaxis, tf.newaxis], (1, H, tf.shape(index)[-1]))
+        head_indexes = tf.tile(tf.range(H)[tf.newaxis, :, tf.newaxis], (B, 1, tf.shape(index)[-1]))
         index = tf.stack([batch_indexes, head_indexes, index], axis=-1)
 
         if self.mask_flag:
@@ -193,18 +195,20 @@ def _update_context(self, context_in, v, scores, index, L_Q):
             context_in = tf.tensor_scatter_nd_update(context_in, index, tf.matmul(attn, v))
         return tf.convert_to_tensor(context_in)
 
+    # @tf.function
     def call(self, q, k, v, mask=None):
         """Prob attention"""
         q = self.dense_q(q)  # project the query/key/value to num_heads * units
         k = self.dense_k(k)
         v = self.dense_v(v)
 
-        B, L, D = q.shape
+        _, L, D = q.shape
+        B = tf.shape(q)[0]
         _, S, _ = k.shape
 
-        q_ = tf.reshape(q, (B, self.num_heads, L, -1))
-        k_ = tf.reshape(k, (B, self.num_heads, S, -1))
-        v_ = tf.reshape(v, (B, self.num_heads, S, -1))
+        q_ = tf.reshape(q, (-1, self.num_heads, L, self.hidden_size // self.num_heads))
+        k_ = tf.reshape(k, (-1, self.num_heads, S, self.hidden_size // self.num_heads))
+        v_ = tf.reshape(v, (-1, self.num_heads, S, self.hidden_size // self.num_heads))
 
         u_q = self.factor * np.ceil(np.log(L)).astype("int").item()
         u_k = self.factor * np.ceil(np.log(S)).astype("int").item()
diff --git a/tfts/layers/mask_layer.py b/tfts/layers/mask_layer.py
index 55802d6..c9b0d8f 100644
--- a/tfts/layers/mask_layer.py
+++ b/tfts/layers/mask_layer.py
@@ -36,7 +36,7 @@ def __init__(self, B, H, L, index, scores):
         mask_expanded = tf.broadcast_to(mask, [B, H, L, scores.shape[-1]])
         # mask specific q based on reduced Q
         mask_Q = tf.gather_nd(mask_expanded, index)
-        self._mask = tf.cast(tf.reshape(mask_Q, scores.shape), tf.bool)
+        self._mask = tf.cast(tf.reshape(mask_Q, tf.shape(scores)), tf.bool)
 
     @property
     def mask(self):
diff --git a/tfts/models/informer.py b/tfts/models/informer.py
index 087421d..434b9c3 100644
--- a/tfts/models/informer.py
+++ b/tfts/models/informer.py
@@ -147,9 +147,7 @@ def call(self, x, mask=None):
         if self.conv_layers is not None:
             for attn_layer, conv_layer in zip(self.layers, self.conv_layers):
                 x = attn_layer(x, mask)
-                # print(x.shape)
                 # x = conv_layer(x)
-                # print(x.shape)
             x = self.layers[-1](x, mask)
         else:
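The change that recurs throughout `attention_layer.py` and `mask_layer.py` is replacing statically unpacked batch sizes (`B, H, L, E = k.shape`) and static `scores.shape` with `tf.shape(...)`, so ProbAttention also works when the batch dimension is unknown at trace time. A minimal sketch of that pattern outside tfts (the `AddBias` layer below is purely illustrative and not part of the patch):

```python
import tensorflow as tf

class AddBias(tf.keras.layers.Layer):
    """Toy layer: broadcast a (1, 1, features) bias across a dynamic batch."""

    def call(self, x):
        # During symbolic tracing x.shape[0] is None and cannot be used to
        # build new shapes; tf.shape(x) yields runtime values that work in
        # both eager and graph mode -- the pattern the patch adopts.
        bias = tf.ones((1, 1, x.shape[-1]))
        return x + tf.broadcast_to(bias, tf.shape(x))

inp = tf.keras.Input(shape=(49, 32))               # batch dimension is None here
model = tf.keras.Model(inp, AddBias()(inp))
print(model(tf.random.normal((2, 49, 32))).shape)  # (2, 49, 32)
```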
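For reference, the query-selection trick that the rewritten `_prob_qk` implements (sample a few keys, score each query by the gap between its max and mean sampled dot product, keep the top-n queries) reduces to a short single-head sketch; `sparsity_topk` below is illustrative only and not part of the tfts API:

```python
import tensorflow as tf

def sparsity_topk(q, k, sample_k, top_n):
    """Toy ProbSparse query selection for one head: q, k are (length, depth)."""
    L = tf.shape(k)[0]
    key_idx = tf.random.uniform((sample_k,), maxval=L, dtype=tf.int32)
    k_sample = tf.gather(k, key_idx, axis=0)               # (sample_k, depth)
    qk_sample = tf.matmul(q, k_sample, transpose_b=True)   # (L_q, sample_k)
    # Sparsity measure M: queries whose max score stands out from their mean
    # score are the "active" ones worth computing full attention for.
    m = tf.reduce_max(qk_sample, axis=-1) - tf.reduce_sum(qk_sample, axis=-1) / tf.cast(L, q.dtype)
    return tf.math.top_k(m, top_n).indices                 # (top_n,)

q = tf.random.normal((49, 8))
k = tf.random.normal((49, 8))
print(sparsity_topk(q, k, sample_k=12, top_n=10))
```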