-
Notifications
You must be signed in to change notification settings - Fork 481
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SYSTEMDS-3829] BERT layer forward pass
This patch introduces the forward pass of the BERT layer from the BERT transformer architecture as a built-in. Closes #2184
- Loading branch information
1 parent
344ca0b
commit e022eaf
Showing
42 changed files
with
647 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
#------------------------------------------------------------- | ||
# | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
# | ||
#------------------------------------------------------------- | ||
|
||
source("nn/layers/affine.dml") as affine | ||
source("nn/layers/multi_attention.dml") as attention | ||
source("nn/layers/dropout.dml") as dropout | ||
source("nn/layers/batch_norm1d.dml") as batch_norm | ||
source("nn/layers/tanh.dml") as tanh | ||
source("nn/layers/gelu.dml") as gelu | ||
|
||
linear_tensor_forward = function(matrix[double] X, matrix[double] W, matrix[double] b, int B, int C)
  return (matrix[double] out) {
  /*
   * Helper function applying an affine (linear) layer to a tensor input
   * of shape (A, B*C): the input is flattened to (A*B, C), passed through
   * the affine layer, and reshaped back to (A, B*C_new), where C_new is
   * the output dimension of the weight matrix W.
   *
   * Inputs:
   * - X: Input tensor, of shape (A, B*C).
   * - W: Weights, of shape (C, C_new).
   * - b: Biases, of shape (1, C_new).
   * - B: Second tensor dimension.
   * - C: Third tensor dimension (input feature length).
   *
   * Outputs:
   * - out: Result tensor, of shape (A, B*C_new).
   */
  n_rows = nrow(X)
  n_cols_out = ncol(W)
  X_flat = matrix(X, rows=n_rows*B, cols=C)
  out_flat = affine::forward(X_flat, W, b)
  out = matrix(out_flat, rows=n_rows, cols=B*n_cols_out)
}
|
||
layer_norm_forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta, double epsilon, int B, int C)
  return (matrix[double] out, matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm) {
  /*
   * Helper function for computing layer norm via 1D batch norm with tensor
   * input, of shape (A, B*C). Each length-C token embedding is normalized
   * independently: the input is reshaped to (A*B, C) and transposed so that
   * batch norm's per-feature statistics become per-token statistics.
   *
   * Inputs:
   * - X: Input tensor, of shape (A, B*C).
   * - gamma: Scale parameters, of shape (1, C).
   * - beta: Shift parameters, of shape (1, C).
   * - epsilon: Small constant for numerical stability.
   * - B: Second tensor dimension.
   * - C: Third tensor dimension (embedding length to normalize over).
   *
   * Outputs:
   * - out: Normalized tensor, of shape (A, B*C).
   * - cache_mean: Cached per-token means, for the backward pass.
   * - cache_var: Cached per-token variances, for the backward pass.
   * - cache_norm: Cached normalized values, for the backward pass.
   */
  A = nrow(X)
  # Transpose so each of the A*B tokens becomes a batch-norm "feature" column
  batch_norm_input = t(matrix(X, rows=A*B, cols=C))
  # EMA matrices are unused in this context and thus empty matrices will be provided
  emas_mat = matrix(0, rows=1, cols=A*B)
  [batch_norm_out, unused1, unused2, cache_mean, cache_var, cache_norm] = batch_norm::forward(
    batch_norm_input, t(gamma), t(beta), "train", emas_mat, emas_mat, 0.0, epsilon)
  # Undo the transpose/reshape to restore the original (A, B*C) layout
  out = matrix(t(batch_norm_out), rows=A, cols=B*C)
}
|
||
forward = function(matrix[double] states,
      int H, int T, int d, int I,
      matrix[double] W_Q, matrix[double] b_Q,
      matrix[double] W_K, matrix[double] b_K,
      matrix[double] W_V, matrix[double] b_V,
      matrix[double] W_context, matrix[double] b_context,
      matrix[double] W_intermediate, matrix[double] b_intermediate,
      matrix[double] W_out, matrix[double] b_out,
      double dropout_p_attention,
      double dropout_p_output,
      double epsilon_ln,
      matrix[double] gamma_ln1, matrix[double] beta_ln1,
      matrix[double] gamma_ln2, matrix[double] beta_ln2,
      string activation)
    return (matrix[double] out_states, matrix[double] attention,
      list[unknown] outputs,
      matrix[double] dropout_mask_attention,
      matrix[double] dropout_mask_output_1,
      matrix[double] dropout_mask_output_2,
      matrix[double] cache_mean_ln1, matrix[double] cache_var_ln1, matrix[double] cache_norm_ln1,
      matrix[double] cache_mean_ln2, matrix[double] cache_var_ln2, matrix[double] cache_norm_ln2) {
  /*
   * Computes the forward pass for a layer of the BERT transformer
   * architecture: multi-head self-attention, an output projection with
   * residual connection and layer norm, then a position-wise feed-forward
   * network with residual connection and layer norm.
   *
   * Inputs (B: Batch size, T: Sequence length, D: Embedding length, H: Heads):
   * - states: Hidden states, of shape (B, T*D).
   * - H: Head count.
   * - T: Sequence length.
   * - d: Embedding length of single token per head with d*H = D.
   * - I: Intermediate embedding length.
   * - W_Q: Weights for linear query layer, of shape (D, D).
   * - b_Q: Biases for linear query layer, of shape (1, D).
   * - W_K: Weights for linear key layer, of shape (D, D).
   * - b_K: Biases for linear key layer, of shape (1, D).
   * - W_V: Weights for linear value layer, of shape (D, D).
   * - b_V: Biases for linear value layer, of shape (1, D).
   * - W_context: Weights for linear output layer on context, of shape (D, D).
   * - b_context: Biases for linear output layer on context, of shape (1, D).
   * - W_intermediate: Weights for intermediate linear layer, of shape (D, I).
   * - b_intermediate: Biases for intermediate linear layer, of shape (1, I).
   * - W_out: Weights for last linear output layer, of shape (I, D).
   * - b_out: Biases for last linear output layer, of shape (1, D).
   * - dropout_p_attention: Probability for dropout on attention.
   * - dropout_p_output: Probability for dropout on output.
   * - epsilon_ln: Epsilon value for layer norm.
   * - gamma_ln1: Gamma params for layer norm 1, of shape (1, D).
   * - beta_ln1: Beta params for layer norm 1, of shape (1, D).
   * - gamma_ln2: Gamma params for layer norm 2, of shape (1, D).
   * - beta_ln2: Beta params for layer norm 2, of shape (1, D).
   * - activation: String specifying type of activation to use.
   *     Can be tanh or gelu. Any other value leaves the intermediate
   *     output unchanged (identity).
   *
   * Outputs:
   * - out_states: Token output states, of shape (B, T*D).
   * - attention: Attention values for keys & queries, of shape (B, H*T*T).
   * - outputs: List of relevant outputs for backward pass with following
   *     order/content:
   *   -> 1: Output of linear query layer, of shape (B, T*D).
   *   -> 2: Output of linear key layer, of shape (B, T*D).
   *   -> 3: Output of linear value layer, of shape (B, T*D).
   *   -> 4: Output context of attention layer, of shape (B, T*D).
   *   -> 5: Output attention of attention layer, of shape (B, H*T*T).
   *   -> 6: Output of residual pass 1, of shape (B, T*D).
   *   -> 7: Output of layer norm 1, of shape (B, T*D).
   *   -> 8: Output of intermediate linear layer, of shape (B, T*I).
   *   -> 9: Output of activation layer, of shape (B, T*I).
   *   -> 10: Output of residual pass 2, of shape (B, T*D).
   * - dropout_mask_attention: Dropout mask used on attention, of shape (B, H*T*T).
   * - dropout_mask_output_1: Dropout mask used on attention output, of shape
   *     (B, T*D); a 1x1 zero placeholder when dropout_p_output == 0.
   * - dropout_mask_output_2: Dropout mask used on feed-forward output, of shape
   *     (B, T*D); a 1x1 zero placeholder when dropout_p_output == 0.
   * - cache_mean_ln1: Cached mean from layer norm 1, of shape (1, B*T).
   * - cache_var_ln1: Cached variance from layer norm 1, of shape (1, B*T).
   * - cache_norm_ln1: Cached normalized values from layer norm 1, of shape (1, B*T).
   * - cache_mean_ln2: Cached mean from layer norm 2, of shape (1, B*T).
   * - cache_var_ln2: Cached variance from layer norm 2, of shape (1, B*T).
   * - cache_norm_ln2: Cached normalized values from layer norm 2, of shape (1, B*T).
   */
  # Embedding dim
  D = d * H

  # Linear layers for Q, K, V
  Q = linear_tensor_forward(states, W_Q, b_Q, T, D)  # Shape (B, T*D)
  K = linear_tensor_forward(states, W_K, b_K, T, D)  # Shape (B, T*D)
  V = linear_tensor_forward(states, W_V, b_V, T, D)  # Shape (B, T*D)

  # Multi-head self attention
  # NOTE(review): the output variable 'attention' shares its name with the
  # sourced namespace alias 'attention' — confirm DML resolves namespace
  # calls independently of local variables.
  [context, attention, dropout_mask_attention] = attention::forward(Q, K, V, H, T, d, dropout_p_attention)
  # Shapes (B, T*D), (B, H*T*T), (B, H*T*T)
  outputs = list(Q, K, V, context, attention)

  # Linear layer on attention output (output layer)
  out_states = linear_tensor_forward(context, W_context, b_context, T, D)  # Shape (B, T*D)
  # Dropout on output 1 (mask stays the 1x1 placeholder when dropout is disabled)
  dropout_mask_output_1 = matrix(0, 1, 1)
  if (dropout_p_output > 0.0) {
    [out_states, dropout_mask_output_1] = dropout::forward(out_states, dropout_p_output, -1)
  }

  # Residual pass 1
  out_states = out_states + states  # Shapes (B, T*D).
  outputs = append(outputs, out_states)
  # Layer norm 1 for each token
  [out_states, cache_mean_ln1, cache_var_ln1, cache_norm_ln1] = layer_norm_forward(
    out_states, gamma_ln1, beta_ln1, epsilon_ln, T, D)
  outputs = append(outputs, out_states)

  # Save out_states for residual pass
  out_states_identity = out_states
  # Linear layer of intermediate part
  out_states = linear_tensor_forward(out_states, W_intermediate, b_intermediate, T, D)  # Shape (B, T*I)
  outputs = append(outputs, out_states)
  # Activation (identity if 'activation' is neither "gelu" nor "tanh")
  if (activation == "gelu") {
    out_states = gelu::forward(out_states)
  } else if (activation == "tanh") {
    out_states = tanh::forward(out_states)
  }
  outputs = append(outputs, out_states)

  # Final linear output layer (input width is I here, hence W_out is (I, D))
  out_states = linear_tensor_forward(out_states, W_out, b_out, T, I)  # Shape (B, T*D)
  # Dropout on output 2
  dropout_mask_output_2 = matrix(0, 1, 1)
  if (dropout_p_output > 0.0) {
    [out_states, dropout_mask_output_2] = dropout::forward(out_states, dropout_p_output, -1)
  }
  # Residual pass 2
  out_states = out_states + out_states_identity
  outputs = append(outputs, out_states)
  # Layer norm 2 for each token
  [out_states, cache_mean_ln2, cache_var_ln2, cache_norm_ln2] = layer_norm_forward(
    out_states, gamma_ln2, beta_ln2, epsilon_ln, T, D)
}
117 changes: 117 additions & 0 deletions
117
src/test/java/org/apache/sysds/test/applications/nn/transformers/BertLayerTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
package org.apache.sysds.test.applications.nn.transformers; | ||
|
||
import org.apache.sysds.common.Types; | ||
import org.apache.sysds.test.AutomatedTestBase; | ||
import org.apache.sysds.test.TestConfiguration; | ||
import org.apache.sysds.test.TestUtils; | ||
import org.junit.Test; | ||
|
||
public class BertLayerTest extends AutomatedTestBase{ | ||
private static final String TEST_NAME_FORWARD = "bert_layer_forward"; | ||
private static final String TEST_DIR = "applications/nn/component/"; | ||
private static final String RESOURCE_DIR = "src/test/resources/component/transformers/bert_layer/"; | ||
|
||
@Override | ||
public void setUp() { | ||
TestUtils.clearAssertionInformation(); | ||
addTestConfiguration(TEST_NAME_FORWARD, new TestConfiguration(TEST_DIR, TEST_NAME_FORWARD)); | ||
} | ||
|
||
@Test | ||
public void testBertLayerForwardNormalTanh() { | ||
runBertLayerTest("test1", 5, 4, 6, 2, 3, 7, "tanh", 0, TEST_NAME_FORWARD, | ||
1e-5, true); | ||
} | ||
|
||
@Test | ||
public void testBertLayerForwardNormalGelu() { | ||
runBertLayerTest("test2", 4, 4, 8, 2, 4, 7, "gelu", 0, TEST_NAME_FORWARD, | ||
1e-5, true); | ||
} | ||
|
||
private void runBertLayerTest(String testSuffix, int batchSize, int seqLength, int embeddingDim, int numHeads, | ||
int perHeadEmbeddingDim, int intermediateEmbeddingDim, String activation, int debug, String testname, double precision, | ||
boolean isForward) { | ||
// Set execution platform | ||
Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); | ||
|
||
try { | ||
// Load test configuration | ||
getAndLoadTestConfiguration(testname); | ||
fullDMLScriptName = getScript(); | ||
|
||
// Program arguments | ||
if (isForward) { | ||
programArgs = new String[] { | ||
"-stats", "-args", | ||
String.valueOf(debug), String.valueOf(batchSize), | ||
String.valueOf(seqLength), String.valueOf(embeddingDim), | ||
String.valueOf(numHeads), String.valueOf(perHeadEmbeddingDim), | ||
String.valueOf(intermediateEmbeddingDim), activation, | ||
RESOURCE_DIR + "input_states_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "input_W_Q_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "input_b_Q_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "input_W_K_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "input_b_K_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "input_W_V_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "input_b_V_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "input_W_context_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "input_b_context_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "input_W_intermediate_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "input_b_intermediate_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "input_W_out_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "input_b_out_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "input_gamma_ln1_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "input_beta_ln1_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "input_gamma_ln2_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "input_beta_ln2_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "output_states_" + testSuffix + ".csv", | ||
RESOURCE_DIR + "output_attention_" + testSuffix + ".csv", | ||
output("states_error"), | ||
output("attention_error"), | ||
}; | ||
} | ||
|
||
// Run the test | ||
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1); | ||
|
||
// Compare results | ||
if (isForward) { | ||
double statesMaxError = (Double) readDMLScalarFromOutputDir("states_error").values().toArray()[0]; | ||
assert statesMaxError < precision; | ||
double attentionMaxError = (Double) readDMLScalarFromOutputDir("attention_error").values().toArray()[0]; | ||
assert attentionMaxError < precision; | ||
} else { | ||
double dqueryMaxError = (Double) readDMLScalarFromOutputDir("dquery_error").values().toArray()[0]; | ||
assert dqueryMaxError < precision; | ||
double dkeyMaxError = (Double) readDMLScalarFromOutputDir("dkey_error").values().toArray()[0]; | ||
assert dkeyMaxError < precision; | ||
double dvalueMaxError = (Double) readDMLScalarFromOutputDir("dvalue_error").values().toArray()[0]; | ||
assert dvalueMaxError < precision; | ||
} | ||
} catch (Throwable ex) { | ||
ex.printStackTrace(System.out); // Log or debug all exceptions or errors | ||
throw new RuntimeException(ex); | ||
} finally { | ||
resetExecMode(platformOld); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
6 changes: 6 additions & 0 deletions
6
src/test/resources/component/transformers/bert_layer/input_W_K_test1.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
-0.366342,-0.367447,-0.368428,0.043807,-0.098173,-0.287969 | ||
-0.352497,-0.027552,0.258298,-0.085474,-0.085857,0.018208 | ||
-0.063882,0.359021,-0.047110,0.291535,-0.336430,-0.287788 | ||
0.005280,-0.166521,-0.182245,0.113960,0.221207,-0.224734 | ||
-0.185457,0.368649,0.326457,0.196166,0.324140,-0.237889 | ||
0.153787,0.147849,-0.329904,0.144177,0.279334,0.139517 |
8 changes: 8 additions & 0 deletions
8
src/test/resources/component/transformers/bert_layer/input_W_K_test2.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
0.014873,0.127848,-0.276551,-0.306393,0.044137,0.116741,0.004873,-0.350424 | ||
0.227909,0.044769,-0.185308,0.175143,0.316675,0.265246,-0.060110,0.159592 | ||
-0.267258,-0.002632,0.285492,-0.251829,0.216273,-0.113814,-0.186207,-0.169799 | ||
-0.242719,-0.069891,-0.286925,-0.100361,-0.223521,0.000566,0.046730,-0.235940 | ||
-0.205295,0.044359,-0.025387,-0.118623,0.158570,0.182018,0.292360,-0.203683 | ||
0.247464,-0.080732,0.349749,-0.052357,-0.249925,-0.341919,-0.103351,0.203278 | ||
-0.127090,-0.002484,0.127717,0.003867,-0.149845,0.255612,-0.209903,0.187233 | ||
0.298218,0.045111,0.010010,0.291613,0.103988,-0.292361,-0.130758,0.271360 |
6 changes: 6 additions & 0 deletions
6
src/test/resources/component/transformers/bert_layer/input_W_Q_test1.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
0.160208,-0.102197,-0.265960,0.082622,-0.366738,-0.382060 | ||
0.024454,-0.198863,-0.020951,0.259580,-0.193541,-0.344565 | ||
-0.199196,-0.142819,0.292245,0.386712,0.277978,-0.082808 | ||
0.193179,-0.334609,-0.041968,0.259260,-0.002646,0.223886 | ||
-0.391612,-0.086841,0.011346,0.387596,-0.202918,0.220716 | ||
-0.241971,0.087266,-0.035219,-0.029525,-0.312845,-0.393728 |
8 changes: 8 additions & 0 deletions
8
src/test/resources/component/transformers/bert_layer/input_W_Q_test2.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
-0.123685,0.009826,-0.317605,-0.071714,-0.068100,-0.318219,-0.040798,0.169884 | ||
-0.289780,-0.030501,-0.167612,0.193891,-0.069418,-0.023860,-0.157829,0.124861 | ||
-0.075206,0.071553,0.240736,0.191146,-0.317261,0.310922,0.282720,-0.085021 | ||
0.075574,0.224803,-0.002292,-0.340978,-0.305271,-0.144212,-0.285705,-0.074354 | ||
-0.230328,0.334902,-0.175732,0.220540,-0.055324,0.319260,0.037938,-0.291357 | ||
-0.018144,0.224526,-0.270932,-0.276659,0.004572,0.128041,-0.074023,0.191571 | ||
0.253091,0.335668,-0.330874,-0.074745,-0.160610,-0.319068,0.252477,0.280714 | ||
-0.036345,-0.025570,-0.298402,-0.143356,0.133183,0.223692,0.098693,0.241910 |
6 changes: 6 additions & 0 deletions
6
src/test/resources/component/transformers/bert_layer/input_W_V_test1.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
-0.237054,-0.003039,-0.319334,0.147474,-0.136974,0.249731 | ||
0.285747,-0.080703,-0.213976,0.011559,-0.060456,-0.258100 | ||
-0.146751,0.051221,0.329658,-0.353792,0.004466,0.183101 | ||
0.344353,-0.093221,-0.331313,0.202237,0.336726,-0.288589 | ||
0.147626,-0.002869,-0.029315,-0.290787,0.050965,-0.173026 | ||
0.051695,0.052090,0.403855,-0.115887,0.365665,0.120075 |
8 changes: 8 additions & 0 deletions
8
src/test/resources/component/transformers/bert_layer/input_W_V_test2.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
-0.334274,0.011264,-0.080849,0.244498,0.069431,0.122792,0.029533,0.165114 | ||
0.076385,-0.237900,-0.015723,-0.263194,-0.262510,-0.129004,0.044147,-0.171997 | ||
-0.198408,-0.285785,-0.215330,0.144839,0.058866,0.134202,-0.277945,-0.292986 | ||
-0.315220,0.281811,0.119572,-0.118884,0.150589,0.235453,0.027785,-0.304028 | ||
0.310023,0.057572,0.111782,-0.170578,0.139947,-0.184608,0.244825,0.352708 | ||
-0.229602,0.293317,-0.007293,0.063514,-0.044505,0.003487,0.318592,0.224432 | ||
-0.040221,-0.118525,-0.079515,-0.183656,-0.289839,0.146194,0.207801,-0.244388 | ||
0.101291,0.104141,-0.217941,0.081460,-0.054502,0.027711,0.047377,0.138325 |
6 changes: 6 additions & 0 deletions
6
src/test/resources/component/transformers/bert_layer/input_W_context_test1.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
0.295156,0.337588,-0.196067,0.148081,-0.193161,0.357983 | ||
-0.337590,-0.119339,-0.272441,-0.136338,-0.191908,-0.265121 | ||
0.005627,-0.242375,-0.235192,-0.114084,-0.385986,-0.046443 | ||
-0.069409,-0.150986,0.234725,0.120609,0.088201,0.116961 | ||
-0.215013,-0.404635,0.216198,0.335594,-0.229102,0.013006 | ||
0.053959,0.184281,0.313339,0.111000,-0.363984,-0.274703 |
8 changes: 8 additions & 0 deletions
8
src/test/resources/component/transformers/bert_layer/input_W_context_test2.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
-0.212761,-0.063619,-0.239211,-0.241276,0.290790,0.116962,-0.199302,0.226094 | ||
0.317851,0.137643,-0.348794,-0.057128,0.289475,0.353439,-0.198385,0.026281 | ||
0.170822,0.062799,-0.283923,-0.229609,0.253087,0.183375,-0.272053,-0.166925 | ||
0.192735,0.150426,0.279120,0.245506,0.272988,0.219786,0.237351,0.324932 | ||
-0.221599,-0.120147,0.191285,-0.267289,0.314375,-0.123741,0.251352,0.144582 | ||
0.101434,0.172383,0.331709,-0.172499,-0.090532,0.169645,-0.040239,-0.268398 | ||
-0.123943,-0.246947,0.283239,-0.341565,0.155564,0.040626,-0.204596,0.338380 | ||
0.276251,0.079852,-0.315739,-0.200728,0.314991,-0.084435,0.273263,0.268479 |
6 changes: 6 additions & 0 deletions
6
src/test/resources/component/transformers/bert_layer/input_W_intermediate_test1.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
-0.093357,-0.091816,-0.196967,0.067973,0.141788,0.168810,0.282700 | ||
-0.018155,-0.251656,0.073340,0.173885,-0.148960,0.031998,0.367878 | ||
-0.248641,0.282322,-0.212067,0.161597,0.154963,0.034102,0.239948 | ||
0.138070,-0.303910,0.094062,-0.051390,0.271878,0.050976,0.054707 | ||
0.129074,0.167245,0.080172,-0.334677,-0.213168,-0.320943,0.190658 | ||
-0.008422,-0.137275,-0.303120,-0.062933,0.004026,0.032084,-0.198604 |
8 changes: 8 additions & 0 deletions
8
src/test/resources/component/transformers/bert_layer/input_W_intermediate_test2.csv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
-0.095970,0.175377,0.300088,-0.339869,0.016123,0.019002,0.112700 | ||
-0.068211,-0.323126,0.023189,-0.343534,-0.318806,0.167661,0.189790 | ||
0.033835,-0.063266,-0.235639,-0.071723,0.293212,-0.283489,0.049253 | ||
0.326977,-0.262741,-0.126673,0.237741,0.190369,-0.101691,-0.236557 | ||
0.018929,-0.150856,0.077203,-0.334631,0.351431,-0.347146,-0.274117 | ||
-0.218298,0.127383,-0.269520,0.293869,0.178619,-0.137706,-0.109077 | ||
0.018121,-0.251069,0.175649,-0.141429,-0.233370,0.076272,0.155195 | ||
0.169524,0.131425,-0.320980,0.103550,0.295070,-0.277597,0.348744 |
Oops, something went wrong.