[SYSTEMDS-3829] BERT layer forward pass
This patch introduces the forward pass of the BERT layer from the
BERT transformer architecture as a built-in.

Closes #2184
MaximilianSchreff authored and phaniarnab committed Feb 3, 2025
1 parent 344ca0b commit e022eaf
Showing 42 changed files with 647 additions and 1 deletion.
186 changes: 186 additions & 0 deletions scripts/nn/layers/bert_layer.dml
@@ -0,0 +1,186 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

source("nn/layers/affine.dml") as affine
source("nn/layers/multi_attention.dml") as attention
source("nn/layers/dropout.dml") as dropout
source("nn/layers/batch_norm1d.dml") as batch_norm
source("nn/layers/tanh.dml") as tanh
source("nn/layers/gelu.dml") as gelu

linear_tensor_forward = function(matrix[double] X, matrix[double] W, matrix[double] b, int B, int C)
return (matrix[double] out) {
/*
* Helper function for computing a linear (affine) layer on tensor input of shape (A, B*C).
*/
A = nrow(X)
C_new = ncol(W)
out = affine::forward(matrix(X, rows=A*B, cols=C), W, b)
out = matrix(out, rows=A, cols=B*C_new)
}
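# Shape note: for an input X of shape (A, B*C), e.g. A=2 samples, B=4 tokens, C=6 features,
# and weights W of shape (C, C_new), linear_tensor_forward(X, W, b, 4, 6) reshapes X to
# (A*B, C) = (8, 6), applies affine::forward once to all tokens, and reshapes the result
# back to (A, B*C_new), so the leading sample dimension A is preserved.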

layer_norm_forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta, double epsilon, int B, int C)
return (matrix[double] out, matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm) {
/*
* Helper function for computing layer norm via 1D batch norm with tensor input, of shape (A, B*C)
*/
A = nrow(X)
batch_norm_input = t(matrix(X, rows=A*B, cols=C))
# EMA matrices are unused, so zero matrices are provided as placeholders
emas_mat = matrix(0, rows=1, cols=A*B)
[batch_norm_out, unused1, unused2, cache_mean, cache_var, cache_norm] = batch_norm::forward(
batch_norm_input, t(gamma), t(beta), "train", emas_mat, emas_mat, 0.0, epsilon)
out = matrix(t(batch_norm_out), rows=A, cols=B*C)
}
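# Note: layer normalization is realized by reusing batch_norm1d on transposed data.
# Reshaping X from (A, B*C) to (A*B, C) puts one token per row; transposing to (C, A*B)
# makes each token a column, so the per-column mean and variance computed by
# batch_norm1d (each of shape (1, A*B)) are exactly the per-token layer-norm statistics.
# The transposed gamma/beta column vectors of shape (C, 1) then act as per-feature scale
# and shift. Mode "train" forces the statistics to be computed from the input rather
# than from the (unused) EMA buffers.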

forward = function(matrix[double] states,
int H, int T, int d, int I,
matrix[double] W_Q, matrix[double] b_Q,
matrix[double] W_K, matrix[double] b_K,
matrix[double] W_V, matrix[double] b_V,
matrix[double] W_context, matrix[double] b_context,
matrix[double] W_intermediate, matrix[double] b_intermediate,
matrix[double] W_out, matrix[double] b_out,
double dropout_p_attention,
double dropout_p_output,
double epsilon_ln,
matrix[double] gamma_ln1, matrix[double] beta_ln1,
matrix[double] gamma_ln2, matrix[double] beta_ln2,
string activation)
return (matrix[double] out_states, matrix[double] attention,
list[unknown] outputs,
matrix[double] dropout_mask_attention,
matrix[double] dropout_mask_output_1,
matrix[double] dropout_mask_output_2,
matrix[double] cache_mean_ln1, matrix[double] cache_var_ln1, matrix[double] cache_norm_ln1,
matrix[double] cache_mean_ln2, matrix[double] cache_var_ln2, matrix[double] cache_norm_ln2) {
/*
* Computes the forward pass for a layer of the BERT transformer architecture.
*
* Inputs (B: Batch size, T: Sequence length, D: Embedding length, H: Heads):
* - states: Hidden states, of shape (B, T*D).
* - H: Head count.
* - T: Sequence length.
* - d: Embedding length of a single token per head, with d*H = D.
* - I: Intermediate embedding length.
* - W_Q: Weights for linear query layer, of shape (D, D).
* - b_Q: Biases for linear query layer, of shape (1, D).
* - W_K: Weights for linear key layer, of shape (D, D).
* - b_K: Biases for linear key layer, of shape (1, D).
* - W_V: Weights for linear value layer, of shape (D, D).
* - b_V: Biases for linear value layer, of shape (1, D).
* - W_context: Weights for linear output layer on context, of shape (D, D).
* - b_context: Biases for linear output layer on context, of shape (1, D).
* - W_intermediate: Weights for intermediate linear layer, of shape (D, I).
* - b_intermediate: Biases for intermediate linear layer, of shape (1, I).
* - W_out: Weights for last linear output layer, of shape (I, D).
* - b_out: Biases for last linear output layer, of shape (1, D).
* - dropout_p_attention: Probability for dropout on attention.
* - dropout_p_output: Probability for dropout on output.
* - epsilon_ln: Epsilon value for layer norm.
* - gamma_ln1: Gamma params for layer norm 1, of shape (1, D).
* - beta_ln1: Beta params for layer norm 1, of shape (1, D).
* - gamma_ln2: Gamma params for layer norm 2, of shape (1, D).
* - beta_ln2: Beta params for layer norm 2, of shape (1, D).
* - activation: String specifying type of activation to use.
* Can be tanh or gelu.
*
* Outputs:
* - out_states: Token output states, of shape (B, T*D).
* - attention: Attention values for keys & queries, of shape (B, H*T*T).
* - outputs: List of relevant outputs for backward pass with following
* order/content:
* -> 1: Output of linear query layer, of shape (B, T*D).
* -> 2: Output of linear key layer, of shape (B, T*D).
* -> 3: Output of linear value layer, of shape (B, T*D).
* -> 4: Output context of attention layer, of shape (B, T*D).
* -> 5: Output attention of attention layer, of shape (B, H*T*T).
* -> 6: Output of residual pass 1, of shape (B, T*D).
* -> 7: Output of layer norm 1, of shape (B, T*D).
* -> 8: Output of intermediate linear layer, of shape (B, T*I).
* -> 9: Output of activation layer, of shape (B, T*I).
* -> 10: Output of residual pass 2, of shape (B, T*D).
* - dropout_mask_attention: Dropout mask used on the attention, of shape (B, H*T*T).
* - dropout_mask_output_1: Dropout mask used on the attention output (after W_context), of shape (B, T*D).
* - dropout_mask_output_2: Dropout mask used on the final linear output (after W_out), of shape (B, T*D).
* - cache_mean_ln1: Cached mean from layer norm 1, of shape (1, B*T).
* - cache_var_ln1: Cached variance from layer norm 1, of shape (1, B*T).
* - cache_norm_ln1: Cached normalized input from layer norm 1, of shape (D, B*T).
* - cache_mean_ln2: Cached mean from layer norm 2, of shape (1, B*T).
* - cache_var_ln2: Cached variance from layer norm 2, of shape (1, B*T).
* - cache_norm_ln2: Cached normalized input from layer norm 2, of shape (D, B*T).
*/
# Embedding dim
D = d * H

# Linear layers for Q, K, V
Q = linear_tensor_forward(states, W_Q, b_Q, T, D) # Shape (B, T*D)
K = linear_tensor_forward(states, W_K, b_K, T, D) # Shape (B, T*D)
V = linear_tensor_forward(states, W_V, b_V, T, D) # Shape (B, T*D)

# Multi-head self attention
[context, attention, dropout_mask_attention] = attention::forward(Q, K, V, H, T, d, dropout_p_attention)
# Shapes (B, T*D), (B, H*T*T), (B, H*T*T)
outputs = list(Q, K, V, context, attention)

# Linear layer on attention output (output layer)
out_states = linear_tensor_forward(context, W_context, b_context, T, D) # Shape (B, T*D)
# Dropout on output 1
dropout_mask_output_1 = matrix(0, 1, 1)
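# The 1x1 zero matrix is only a placeholder; an actual mask of shape (B, T*D) is
# produced when dropout_p_output > 0.0 (the same pattern is used for dropout 2 below).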
if (dropout_p_output > 0.0) {
[out_states, dropout_mask_output_1] = dropout::forward(out_states, dropout_p_output, -1)
}

# Residual pass 1
out_states = out_states + states # Shapes (B, T*D).
outputs = append(outputs, out_states)
# Layer norm 1 for each token
[out_states, cache_mean_ln1, cache_var_ln1, cache_norm_ln1] = layer_norm_forward(
out_states, gamma_ln1, beta_ln1, epsilon_ln, T, D)
outputs = append(outputs, out_states)

# Save out_states for residual pass
out_states_identity = out_states
# Linear layer of intermediate part
out_states = linear_tensor_forward(out_states, W_intermediate, b_intermediate, T, D) # Shape (B, T*I)
outputs = append(outputs, out_states)
# Activation
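# BERT's feed-forward block originally uses GELU; tanh is offered as an alternative.
# GELU(x) = x * Phi(x), with Phi the standard normal CDF, commonly approximated as
# 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))); the exact variant applied
# here is whatever nn/layers/gelu.dml implements.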
if (activation == "gelu") {
out_states = gelu::forward(out_states)
} else if (activation == "tanh") {
out_states = tanh::forward(out_states)
}
outputs = append(outputs, out_states)

# Final linear output layer
out_states = linear_tensor_forward(out_states, W_out, b_out, T, I) # Shape (B, T*D)
# Dropout on output 2
dropout_mask_output_2 = matrix(0, 1, 1)
if (dropout_p_output > 0.0) {
[out_states, dropout_mask_output_2] = dropout::forward(out_states, dropout_p_output, -1)
}
# Residual pass 2
out_states = out_states + out_states_identity
outputs = append(outputs, out_states)
# Layer norm 2 for each token
[out_states, cache_mean_ln2, cache_var_ln2, cache_norm_ln2] = layer_norm_forward(
out_states, gamma_ln2, beta_ln2, epsilon_ln, T, D)
}
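For orientation, a minimal usage sketch of the new built-in (not part of this commit): all dimensions and the rand-initialized parameters are illustrative only, the output variable names are chosen freely, and the relative source() path assumes the script is executed with the scripts/ directory on the source path.

source("nn/layers/bert_layer.dml") as bert_layer

# Illustrative sizes: batch B=2, sequence length T=4, heads H=2, per-head size d=3 (D=6), intermediate I=7
B = 2; T = 4; H = 2; d = 3; D = d * H; I = 7
states = rand(rows=B, cols=T*D)
W_Q = rand(rows=D, cols=D); b_Q = rand(rows=1, cols=D)
W_K = rand(rows=D, cols=D); b_K = rand(rows=1, cols=D)
W_V = rand(rows=D, cols=D); b_V = rand(rows=1, cols=D)
W_context = rand(rows=D, cols=D); b_context = rand(rows=1, cols=D)
W_intermediate = rand(rows=D, cols=I); b_intermediate = rand(rows=1, cols=I)
W_out = rand(rows=I, cols=D); b_out = rand(rows=1, cols=D)
gamma_ln1 = rand(rows=1, cols=D); beta_ln1 = rand(rows=1, cols=D)
gamma_ln2 = rand(rows=1, cols=D); beta_ln2 = rand(rows=1, cols=D)

# Dropout probabilities set to 0.0 for a deterministic forward pass
[out_states, attention, outputs, dm_att, dm_out1, dm_out2, cm1, cv1, cn1, cm2, cv2, cn2] = bert_layer::forward(
    states, H, T, d, I, W_Q, b_Q, W_K, b_K, W_V, b_V, W_context, b_context,
    W_intermediate, b_intermediate, W_out, b_out, 0.0, 0.0, 1e-12,
    gamma_ln1, beta_ln1, gamma_ln2, beta_ln2, "gelu")
print("out_states: " + nrow(out_states) + " x " + ncol(out_states)) # 2 x 24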
@@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.test.applications.nn.transformers;

import org.apache.sysds.common.Types;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Test;

public class BertLayerTest extends AutomatedTestBase {
private static final String TEST_NAME_FORWARD = "bert_layer_forward";
private static final String TEST_DIR = "applications/nn/component/";
private static final String RESOURCE_DIR = "src/test/resources/component/transformers/bert_layer/";

@Override
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME_FORWARD, new TestConfiguration(TEST_DIR, TEST_NAME_FORWARD));
}

@Test
public void testBertLayerForwardNormalTanh() {
runBertLayerTest("test1", 5, 4, 6, 2, 3, 7, "tanh", 0, TEST_NAME_FORWARD,
1e-5, true);
}

@Test
public void testBertLayerForwardNormalGelu() {
runBertLayerTest("test2", 4, 4, 8, 2, 4, 7, "gelu", 0, TEST_NAME_FORWARD,
1e-5, true);
}
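// The two forward tests above exercise the tanh and gelu activations; in both,
// embeddingDim = numHeads * perHeadEmbeddingDim (6 = 2*3 and 8 = 2*4). The helper
// below runs the bert_layer_forward test script with the given dimensions, passes
// the input and expected-output CSV files from RESOURCE_DIR as script arguments,
// and asserts that the maximum element-wise error reported by the script stays
// below the given precision.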

private void runBertLayerTest(String testSuffix, int batchSize, int seqLength, int embeddingDim, int numHeads,
int perHeadEmbeddingDim, int intermediateEmbeddingDim, String activation, int debug, String testname, double precision,
boolean isForward) {
// Set execution platform
Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);

try {
// Load test configuration
getAndLoadTestConfiguration(testname);
fullDMLScriptName = getScript();

// Program arguments
if (isForward) {
programArgs = new String[] {
"-stats", "-args",
String.valueOf(debug), String.valueOf(batchSize),
String.valueOf(seqLength), String.valueOf(embeddingDim),
String.valueOf(numHeads), String.valueOf(perHeadEmbeddingDim),
String.valueOf(intermediateEmbeddingDim), activation,
RESOURCE_DIR + "input_states_" + testSuffix + ".csv",
RESOURCE_DIR + "input_W_Q_" + testSuffix + ".csv",
RESOURCE_DIR + "input_b_Q_" + testSuffix + ".csv",
RESOURCE_DIR + "input_W_K_" + testSuffix + ".csv",
RESOURCE_DIR + "input_b_K_" + testSuffix + ".csv",
RESOURCE_DIR + "input_W_V_" + testSuffix + ".csv",
RESOURCE_DIR + "input_b_V_" + testSuffix + ".csv",
RESOURCE_DIR + "input_W_context_" + testSuffix + ".csv",
RESOURCE_DIR + "input_b_context_" + testSuffix + ".csv",
RESOURCE_DIR + "input_W_intermediate_" + testSuffix + ".csv",
RESOURCE_DIR + "input_b_intermediate_" + testSuffix + ".csv",
RESOURCE_DIR + "input_W_out_" + testSuffix + ".csv",
RESOURCE_DIR + "input_b_out_" + testSuffix + ".csv",
RESOURCE_DIR + "input_gamma_ln1_" + testSuffix + ".csv",
RESOURCE_DIR + "input_beta_ln1_" + testSuffix + ".csv",
RESOURCE_DIR + "input_gamma_ln2_" + testSuffix + ".csv",
RESOURCE_DIR + "input_beta_ln2_" + testSuffix + ".csv",
RESOURCE_DIR + "output_states_" + testSuffix + ".csv",
RESOURCE_DIR + "output_attention_" + testSuffix + ".csv",
output("states_error"),
output("attention_error"),
};
}

// Run the test
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);

// Compare results
if (isForward) {
double statesMaxError = (Double) readDMLScalarFromOutputDir("states_error").values().toArray()[0];
assert statesMaxError < precision;
double attentionMaxError = (Double) readDMLScalarFromOutputDir("attention_error").values().toArray()[0];
assert attentionMaxError < precision;
} else {
double dqueryMaxError = (Double) readDMLScalarFromOutputDir("dquery_error").values().toArray()[0];
assert dqueryMaxError < precision;
double dkeyMaxError = (Double) readDMLScalarFromOutputDir("dkey_error").values().toArray()[0];
assert dkeyMaxError < precision;
double dvalueMaxError = (Double) readDMLScalarFromOutputDir("dvalue_error").values().toArray()[0];
assert dvalueMaxError < precision;
}
} catch (Throwable ex) {
ex.printStackTrace(System.out); // Log or debug all exceptions or errors
throw new RuntimeException(ex);
} finally {
resetExecMode(platformOld);
}
}
}
@@ -119,7 +119,7 @@ private void runMultiAttentionTest(String testSuffix, int batchSize, int seqLeng
if (isForward) {
double contextMaxError = (Double) readDMLScalarFromOutputDir("context_error").values().toArray()[0];
assert contextMaxError < precision;
- double attentionMaxError = (Double) readDMLScalarFromOutputDir("context_error").values().toArray()[0];
+ double attentionMaxError = (Double) readDMLScalarFromOutputDir("attention_error").values().toArray()[0];
assert attentionMaxError < precision;
} else {
double dqueryMaxError = (Double) readDMLScalarFromOutputDir("dquery_error").values().toArray()[0];
@@ -0,0 +1,6 @@
-0.366342,-0.367447,-0.368428,0.043807,-0.098173,-0.287969
-0.352497,-0.027552,0.258298,-0.085474,-0.085857,0.018208
-0.063882,0.359021,-0.047110,0.291535,-0.336430,-0.287788
0.005280,-0.166521,-0.182245,0.113960,0.221207,-0.224734
-0.185457,0.368649,0.326457,0.196166,0.324140,-0.237889
0.153787,0.147849,-0.329904,0.144177,0.279334,0.139517
@@ -0,0 +1,8 @@
0.014873,0.127848,-0.276551,-0.306393,0.044137,0.116741,0.004873,-0.350424
0.227909,0.044769,-0.185308,0.175143,0.316675,0.265246,-0.060110,0.159592
-0.267258,-0.002632,0.285492,-0.251829,0.216273,-0.113814,-0.186207,-0.169799
-0.242719,-0.069891,-0.286925,-0.100361,-0.223521,0.000566,0.046730,-0.235940
-0.205295,0.044359,-0.025387,-0.118623,0.158570,0.182018,0.292360,-0.203683
0.247464,-0.080732,0.349749,-0.052357,-0.249925,-0.341919,-0.103351,0.203278
-0.127090,-0.002484,0.127717,0.003867,-0.149845,0.255612,-0.209903,0.187233
0.298218,0.045111,0.010010,0.291613,0.103988,-0.292361,-0.130758,0.271360
@@ -0,0 +1,6 @@
0.160208,-0.102197,-0.265960,0.082622,-0.366738,-0.382060
0.024454,-0.198863,-0.020951,0.259580,-0.193541,-0.344565
-0.199196,-0.142819,0.292245,0.386712,0.277978,-0.082808
0.193179,-0.334609,-0.041968,0.259260,-0.002646,0.223886
-0.391612,-0.086841,0.011346,0.387596,-0.202918,0.220716
-0.241971,0.087266,-0.035219,-0.029525,-0.312845,-0.393728
@@ -0,0 +1,8 @@
-0.123685,0.009826,-0.317605,-0.071714,-0.068100,-0.318219,-0.040798,0.169884
-0.289780,-0.030501,-0.167612,0.193891,-0.069418,-0.023860,-0.157829,0.124861
-0.075206,0.071553,0.240736,0.191146,-0.317261,0.310922,0.282720,-0.085021
0.075574,0.224803,-0.002292,-0.340978,-0.305271,-0.144212,-0.285705,-0.074354
-0.230328,0.334902,-0.175732,0.220540,-0.055324,0.319260,0.037938,-0.291357
-0.018144,0.224526,-0.270932,-0.276659,0.004572,0.128041,-0.074023,0.191571
0.253091,0.335668,-0.330874,-0.074745,-0.160610,-0.319068,0.252477,0.280714
-0.036345,-0.025570,-0.298402,-0.143356,0.133183,0.223692,0.098693,0.241910
@@ -0,0 +1,6 @@
-0.237054,-0.003039,-0.319334,0.147474,-0.136974,0.249731
0.285747,-0.080703,-0.213976,0.011559,-0.060456,-0.258100
-0.146751,0.051221,0.329658,-0.353792,0.004466,0.183101
0.344353,-0.093221,-0.331313,0.202237,0.336726,-0.288589
0.147626,-0.002869,-0.029315,-0.290787,0.050965,-0.173026
0.051695,0.052090,0.403855,-0.115887,0.365665,0.120075
@@ -0,0 +1,8 @@
-0.334274,0.011264,-0.080849,0.244498,0.069431,0.122792,0.029533,0.165114
0.076385,-0.237900,-0.015723,-0.263194,-0.262510,-0.129004,0.044147,-0.171997
-0.198408,-0.285785,-0.215330,0.144839,0.058866,0.134202,-0.277945,-0.292986
-0.315220,0.281811,0.119572,-0.118884,0.150589,0.235453,0.027785,-0.304028
0.310023,0.057572,0.111782,-0.170578,0.139947,-0.184608,0.244825,0.352708
-0.229602,0.293317,-0.007293,0.063514,-0.044505,0.003487,0.318592,0.224432
-0.040221,-0.118525,-0.079515,-0.183656,-0.289839,0.146194,0.207801,-0.244388
0.101291,0.104141,-0.217941,0.081460,-0.054502,0.027711,0.047377,0.138325
@@ -0,0 +1,6 @@
0.295156,0.337588,-0.196067,0.148081,-0.193161,0.357983
-0.337590,-0.119339,-0.272441,-0.136338,-0.191908,-0.265121
0.005627,-0.242375,-0.235192,-0.114084,-0.385986,-0.046443
-0.069409,-0.150986,0.234725,0.120609,0.088201,0.116961
-0.215013,-0.404635,0.216198,0.335594,-0.229102,0.013006
0.053959,0.184281,0.313339,0.111000,-0.363984,-0.274703
@@ -0,0 +1,8 @@
-0.212761,-0.063619,-0.239211,-0.241276,0.290790,0.116962,-0.199302,0.226094
0.317851,0.137643,-0.348794,-0.057128,0.289475,0.353439,-0.198385,0.026281
0.170822,0.062799,-0.283923,-0.229609,0.253087,0.183375,-0.272053,-0.166925
0.192735,0.150426,0.279120,0.245506,0.272988,0.219786,0.237351,0.324932
-0.221599,-0.120147,0.191285,-0.267289,0.314375,-0.123741,0.251352,0.144582
0.101434,0.172383,0.331709,-0.172499,-0.090532,0.169645,-0.040239,-0.268398
-0.123943,-0.246947,0.283239,-0.341565,0.155564,0.040626,-0.204596,0.338380
0.276251,0.079852,-0.315739,-0.200728,0.314991,-0.084435,0.273263,0.268479
@@ -0,0 +1,6 @@
-0.093357,-0.091816,-0.196967,0.067973,0.141788,0.168810,0.282700
-0.018155,-0.251656,0.073340,0.173885,-0.148960,0.031998,0.367878
-0.248641,0.282322,-0.212067,0.161597,0.154963,0.034102,0.239948
0.138070,-0.303910,0.094062,-0.051390,0.271878,0.050976,0.054707
0.129074,0.167245,0.080172,-0.334677,-0.213168,-0.320943,0.190658
-0.008422,-0.137275,-0.303120,-0.062933,0.004026,0.032084,-0.198604
@@ -0,0 +1,8 @@
-0.095970,0.175377,0.300088,-0.339869,0.016123,0.019002,0.112700
-0.068211,-0.323126,0.023189,-0.343534,-0.318806,0.167661,0.189790
0.033835,-0.063266,-0.235639,-0.071723,0.293212,-0.283489,0.049253
0.326977,-0.262741,-0.126673,0.237741,0.190369,-0.101691,-0.236557
0.018929,-0.150856,0.077203,-0.334631,0.351431,-0.347146,-0.274117
-0.218298,0.127383,-0.269520,0.293869,0.178619,-0.137706,-0.109077
0.018121,-0.251069,0.175649,-0.141429,-0.233370,0.076272,0.155195
0.169524,0.131425,-0.320980,0.103550,0.295070,-0.277597,0.348744