Skip to content

Commit

Permalink
[SYSTEMDS-3831] New builtin for vectorized simple exponential smoothing
Browse files Browse the repository at this point in the history
This patch introduces a new vectorized builtin function for vectorized
simple exponential smoothing which largely relies on cumsumprod.
  • Loading branch information
mboehm7 committed Feb 6, 2025
1 parent bea9c96 commit 52ca491
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 2 deletions.
55 changes: 55 additions & 0 deletions scripts/builtin/ses.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

# Builtin function for simple exponential smoothing (SES).
#
# INPUT:
# ------------------------------------------------------------------------------
# x Time series vector [shape: n-by-1]
# h Forecasting horizon
# alpha Smoothing parameter yhat_t = alpha * x_y + (1-alpha) * yhat_t-1
# ------------------------------------------------------------------------------
#
# OUTPUT:
# ------------------------------------------------------------------------------
# yhat Forecasts [shape: h-by-1]
# ------------------------------------------------------------------------------

m_ses = function(Matrix[Double] x, Integer h = 1, Double alpha = 0.5)
return (Matrix[Double] yhat)
{
# check and ensure valid parameters
if(h < 1) {
print("SES: forecasting horizon should be larger one.");
h = 1;
}
if(alpha < 0 | alpha > 1) {
print("SES: smooting parameter should be in [0,1].");
alpha = 0.5;
}

# vectorized forecasting
# weights are 1 for first value and otherwise replicated alpha
# but to compensate alpha*x for the first, we use 1/alpha
w = rbind(as.matrix(1/alpha), matrix(1-alpha,nrow(x)-1,1));
y = cumsumprod(cbind(alpha*x, w));
yhat = matrix(as.scalar(y[nrow(x),1]), h, 1);
}
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/common/Builtins.java
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ public enum Builtins {
SD("sd", false),
SELVARTHRESH("selectByVarThresh", true),
SEQ("seq", false),
SES("ses", true),
SYMMETRICDIFFERENCE("symmetricDifference", true),
SHAPEXPLAINER("shapExplainer", true),
SHERLOCK("sherlock", true),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import org.apache.sysds.lops.WeightedUnaryMM;
import org.apache.sysds.lops.WeightedUnaryMMR;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.AggregateTernarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.test.functions.builtin.part2;

import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.junit.Assert;
import org.junit.Test;

import java.util.HashMap;

public class BuiltinSESTest extends AutomatedTestBase {
private final static String TEST_NAME = "ses";
private final static String TEST_DIR = "functions/builtin/";
private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinSESTest.class.getSimpleName() + "/";

private final static int rows = 200;

@Override
public void setUp() {
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"y"}));
}

@Test
public void testSES05() {
runSESTest(0.5, 199d);
}

@Test
public void testSES077() {
runSESTest(0.77, 199.7013);
}

@Test
public void testSES10() {
runSESTest(1.0, 200d);
}

private void runSESTest(double alpha, double expected) {
loadTestConfiguration(getTestConfiguration(TEST_NAME));
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-args",
String.valueOf(rows), String.valueOf(alpha), output("y")};
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("y");
Assert.assertEquals(7, dmlfile.size()); //forecast horizon 7
Assert.assertEquals(expected, dmlfile.get(new CellIndex(1,1)), 1e-3);
}
}
26 changes: 26 additions & 0 deletions src/test/scripts/functions/builtin/ses.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------


x = seq(1, $1);
yhat = ses(x=x, alpha=$2, h=7)
write(yhat, $3)

0 comments on commit 52ca491

Please sign in to comment.