Skip to content

Commit

Permalink
[MINOR] Change DNNLSTM to use MatrixBlockReshape
Browse files Browse the repository at this point in the history
  • Loading branch information
Baunsgaard committed Feb 5, 2025
1 parent 5ff6274 commit ee0536d
Showing 1 changed file with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,13 @@ public static long lstmGeneric(DnnParameters params) {

//store caches
ifog = ifo.append(g, true);
MatrixBlock cache_out_t = LibMatrixReorg.reshape(out, new MatrixBlock(), 1, cache_out.clen, true);
MatrixBlock cache_out_t = out.reshape( 1, cache_out.clen, true);
cache_out.leftIndexingOperations(cache_out_t, t, t,0, cache_out.clen - 1, null, MatrixObject.UpdateType.INPLACE );

MatrixBlock cache_c_t = LibMatrixReorg.reshape(c, new MatrixBlock(), 1, cache_c.clen, true);
MatrixBlock cache_c_t = c.reshape(1,cache_c.clen, true);
cache_c.leftIndexingOperations(cache_c_t, t, t,0, cache_c.clen - 1, null, MatrixObject.UpdateType.INPLACE );

MatrixBlock cache_ifog_t = LibMatrixReorg.reshape(ifog, new MatrixBlock(), 1, cache_ifog.clen, true);
MatrixBlock cache_ifog_t = ifog.reshape(1, cache_ifog.clen, true);
cache_ifog.leftIndexingOperations(cache_ifog_t, t, t,0,cache_ifog.clen - 1, null, MatrixObject.UpdateType.INPLACE );
}
return params.output.recomputeNonZeros();
Expand Down Expand Up @@ -373,9 +373,9 @@ public static long lstmBackwardGeneric(DnnParameters params) {
dout_prev = dout.slice(0, dout.rlen-1, t*M, (t+1)*M - 1).binaryOperations(plus, dout_prev);

//load and reuse cached results from forward pass for the current time step
MatrixBlock c_t = LibMatrixReorg.reshape(cache_c.slice(t, t, 0, cache_c.clen - 1), new MatrixBlock(), params.N, M, true);
MatrixBlock c_prev = t==0 ? c0 : LibMatrixReorg.reshape(cache_c.slice(t - 1, t - 1, 0, cache_c.clen - 1), new MatrixBlock(), params.N, M, true);
MatrixBlock ifog = LibMatrixReorg.reshape(cache_ifog.slice(t, t,0, cache_ifog.clen - 1), new MatrixBlock(), params.N, 4*M, true);
MatrixBlock c_t = cache_c.slice(t, t, 0, cache_c.clen - 1).reshape( params.N, M, true);
MatrixBlock c_prev = t==0 ? c0 : cache_c.slice(t - 1, t - 1, 0, cache_c.clen - 1).reshape(params.N, M, true);
MatrixBlock ifog = cache_ifog.slice(t, t,0, cache_ifog.clen - 1).reshape(params.N, 4*M, true);
MatrixBlock i = ifog.slice(0, ifog.rlen - 1, 0, M -1);
MatrixBlock f = ifog.slice(0, ifog.rlen - 1, M, 2*M -1);
MatrixBlock o = ifog.slice(0, ifog.rlen - 1, 2*M, 3*M -1);
Expand Down Expand Up @@ -422,7 +422,7 @@ public static long lstmBackwardGeneric(DnnParameters params) {

//load the current input vector and in the cached previous hidden state
MatrixBlock x_t = x.slice(0, x.rlen - 1, t*params.D , (t+1)*params.D - 1);
MatrixBlock out_prev = t==0 ? out0 : LibMatrixReorg.reshape(cache_out.slice(t - 1, t - 1, 0, cache_out.clen - 1), new MatrixBlock(), params.N, M, true);
MatrixBlock out_prev = t==0 ? out0 : cache_out.slice(t - 1, t - 1, 0, cache_out.clen - 1).reshape( params.N, M, true);

//merge mm for dx and dout_prev: input = cbind(X_t, out_prev) # shape (N, D+M)
MatrixBlock in_t = x_t.append(out_prev, true).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
Expand Down

0 comments on commit ee0536d

Please sign in to comment.