From 52ca4913155401180ccf66c85db8be08e86f9388 Mon Sep 17 00:00:00 2001 From: Matthias Boehm Date: Thu, 6 Feb 2025 17:27:32 +0100 Subject: [PATCH] [SYSTEMDS-3831] New builtin for vectorized simple exponential smoothing This patch introduces a new vectorized builtin function for vectorized simple exponential smoothing which largely relies on cumsumprod. --- scripts/builtin/ses.dml | 55 +++++++++++++++ .../org/apache/sysds/common/Builtins.java | 1 + .../instructions/SPInstructionParser.java | 1 - .../cp/CompressionCPInstruction.java | 1 - .../builtin/part2/BuiltinSESTest.java | 68 +++++++++++++++++++ src/test/scripts/functions/builtin/ses.dml | 26 +++++++ 6 files changed, 150 insertions(+), 2 deletions(-) create mode 100644 scripts/builtin/ses.dml create mode 100644 src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSESTest.java create mode 100644 src/test/scripts/functions/builtin/ses.dml diff --git a/scripts/builtin/ses.dml b/scripts/builtin/ses.dml new file mode 100644 index 00000000000..f4b82ad3905 --- /dev/null +++ b/scripts/builtin/ses.dml @@ -0,0 +1,55 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Builtin function for simple exponential smoothing (SES). +# +# INPUT: +# ------------------------------------------------------------------------------ +# x Time series vector [shape: n-by-1] +# h Forecasting horizon +# alpha Smoothing parameter yhat_t = alpha * x_y + (1-alpha) * yhat_t-1 +# ------------------------------------------------------------------------------ +# +# OUTPUT: +# ------------------------------------------------------------------------------ +# yhat Forecasts [shape: h-by-1] +# ------------------------------------------------------------------------------ + +m_ses = function(Matrix[Double] x, Integer h = 1, Double alpha = 0.5) + return (Matrix[Double] yhat) +{ + # check and ensure valid parameters + if(h < 1) { + print("SES: forecasting horizon should be larger one."); + h = 1; + } + if(alpha < 0 | alpha > 1) { + print("SES: smooting parameter should be in [0,1]."); + alpha = 0.5; + } + + # vectorized forecasting + # weights are 1 for first value and otherwise replicated alpha + # but to compensate alpha*x for the first, we use 1/alpha + w = rbind(as.matrix(1/alpha), matrix(1-alpha,nrow(x)-1,1)); + y = cumsumprod(cbind(alpha*x, w)); + yhat = matrix(as.scalar(y[nrow(x),1]), h, 1); +} diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index 1a7ba207b81..4ff5654de02 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -301,6 +301,7 @@ public enum Builtins { SD("sd", false), SELVARTHRESH("selectByVarThresh", true), SEQ("seq", false), + SES("ses", true), SYMMETRICDIFFERENCE("symmetricDifference", true), SHAPEXPLAINER("shapExplainer", true), SHERLOCK("sherlock", true), diff --git a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java index 5014c0ac30e..e08ef64ab85 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java @@ -39,7 +39,6 @@ import org.apache.sysds.lops.WeightedUnaryMM; import org.apache.sysds.lops.WeightedUnaryMMR; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.spark.AggregateTernarySPInstruction; import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java index 4216385b722..efc8e217771 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java @@ -22,7 +22,6 @@ import java.util.ArrayList; import java.util.List; -import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSESTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSESTest.java new file mode 100644 index 00000000000..55735495528 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSESTest.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.builtin.part2; + +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashMap; + +public class BuiltinSESTest extends AutomatedTestBase { + private final static String TEST_NAME = "ses"; + private final static String TEST_DIR = "functions/builtin/"; + private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinSESTest.class.getSimpleName() + "/"; + + private final static int rows = 200; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"y"})); + } + + @Test + public void testSES05() { + runSESTest(0.5, 199d); + } + + @Test + public void testSES077() { + runSESTest(0.77, 199.7013); + } + + @Test + public void testSES10() { + runSESTest(1.0, 200d); + } + + private void runSESTest(double alpha, double expected) { + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-args", + String.valueOf(rows), String.valueOf(alpha), output("y")}; + runTest(true, EXCEPTION_NOT_EXPECTED, null, -1); + HashMap dmlfile = readDMLMatrixFromOutputDir("y"); + Assert.assertEquals(7, dmlfile.size()); //forecast horizon 7 + Assert.assertEquals(expected, dmlfile.get(new CellIndex(1,1)), 1e-3); + } +} diff --git a/src/test/scripts/functions/builtin/ses.dml b/src/test/scripts/functions/builtin/ses.dml new file mode 100644 index 00000000000..2148854c6e5 --- /dev/null +++ b/src/test/scripts/functions/builtin/ses.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + +x = seq(1, $1); +yhat = ses(x=x, alpha=$2, h=7) +write(yhat, $3) +