Fixes gam segfault from C stack overflow
vwmaus committed Oct 17, 2023
1 parent e48ea07 commit 563112e
Showing 15 changed files with 225 additions and 73 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
@@ -27,6 +27,7 @@ examples_x64
^doc$
^Meta$

^\.vscode$
vignettes.awk
_pkgdown.yml
^_pkgdown\.yml$
1 change: 1 addition & 0 deletions .gitignore
@@ -34,6 +34,7 @@ revdep/
CRAN-SUBMISSION

# Other files
.vscode
src/symbols.rds
*.o
*.so
6 changes: 3 additions & 3 deletions DESCRIPTION
@@ -2,7 +2,7 @@ Package: dtwSat
Type: Package
Title: Time-Weighted Dynamic Time Warping for Satellite Image Time Series Analysis
Version: 1.0-1
-Date: 2023-09-25
+Date: 2023-10-18
Authors@R:
c(person(given = "Victor",
family = "Maus",
@@ -32,8 +32,8 @@ Description: Provides a robust approach to land use mapping using multi-dimensio
while also requiring minimal training sets. The package includes tools for training the 1-NN-TWDTW model,
visualizing temporal patterns, producing land use maps, and visualizing the results.
License: GPL (>= 3)
-URL: https://github.com/vwmaus/dtwSat/
-BugReports: https://github.com/vwmaus/dtwSat/issues/
+URL: https://github.com/r-spatial/dtwSat/
+BugReports: https://github.com/r-spatial/dtwSat/issues/
Maintainer: Victor Maus <[email protected]>
VignetteBuilder:
knitr
2 changes: 1 addition & 1 deletion R/prepare_time_series.R
@@ -9,7 +9,7 @@
#' @return A nested tibble in wide format. Each row of the tibble corresponds to a unique 'ts_id' that maintains the order from the original stars object.
#' The nested structure contains observations (time series) for each 'ts_id', including the 'time' of each observation, and individual bands are presented as separate columns.
#'
#'
#' @keywords internal
prepare_time_series <- function(x) {

# Remove the 'geom' column if it exists
136 changes: 98 additions & 38 deletions R/train.R
@@ -1,37 +1,48 @@
#'
#' Train a KNN-1 TWDTW model with optional GAM resampling
#' Train a KNN-1 TWDTW model
#'
#' This function prepares a KNN-1 model with the Time-Weighted Dynamic Time Warping (TWDTW) algorithm.
#' If a formula is provided, the training samples are resampled using Generalized Additive Models (GAM).
#'
#' @param x A three-dimensional stars object (x, y, time) with bands as attributes.
#' @param y An sf object with the coordinates of the training points.
#' @param time_weight A numeric vector with length two (steepness and midpoint of logistic weight) or a function.
#' See details in \link[twdtw]{twdtw}.
#' @param cycle_length The length of the cycle, e.g. phenological cycles. Details in \link[twdtw]{twdtw}.
#' @param time_scale Specifies the time scale for the observations. Details in \link[twdtw]{twdtw}.
#' @param formula Either NULL or a formula to reduce samples of the same label using Generalized Additive Models (GAM).
#' Default is \code{band ~ s(time)}. See details.
#' @param smooth_fun Either NULL or a function specifying how to reduce samples of the same label.
#' Default uses Generalized Additive Models (GAM) with cubic regression splines to create a temporal pattern for each label. See details.
#' @param start_column Name of the column in y that indicates the start date. Default is 'start_date'.
#' @param end_column Name of the column in y that indicates the end date. Default is 'end_date'.
#' @param label_colum Name of the column in y containing land use labels. Default is 'label'.
#' @param sampling_freq The time frequency for sampling, including the unit (e.g., '16 day').
#' If NULL, the function will infer the frequency. This parameter is only used if a formula is provided.
#' @param ... Additional arguments passed to the \link[mgcv]{gam} function and to \link[twdtw]{twdtw} function.
#' If NULL, the function will infer the frequency. This parameter is only used if `smooth_fun` is provided.
#' @param ... Additional arguments passed to \link[twdtw]{twdtw}.
#'
#' @details If \code{formula} is NULL, the KNN-1 model will retain all training samples. If a formula is passed (e.g., \code{band ~ \link[mgcv]{s}(time)}),
#' then samples of the same label (land cover class) will be resampled using GAM.
#' Resampling can significantly reduce prediction processing time.
#' @details If \code{smooth_fun} is NULL, the KNN-1 model will retain all training samples.
#'
#' If a custom smoothing function is passed to `smooth_fun`, the function will be used to
#' resample values of samples sharing the same label (land cover class). If no function is provided,
#' the default method uses Generalized Additive Models (GAM) with cubic regression splines.
#'
#' The custom smoothing function takes two or three numeric vectors as arguments and returns a single numeric vector:
#' \itemize{
#' \item The first argument represents the independent variable (typically time).
#' \item The second argument represents the dependent variable (e.g., band values) corresponding to each coordinate in the first argument.
#' \item Optional. The third argument specifies the locations (e.g., times) where interpolation predictions should be made.
#' }
#' See the examples section for further clarity.
#'
#' Smoothing the samples can significantly reduce prediction time with the `twdtw_knn1` model.
#'
#' @return A 'twdtw_knn1' model containing the trained model information and the data used.
#'
#' @examples
#' \dontrun{
#'
#' # Read training samples
#' samples_path <-
#' system.file("mato_grosso_brazil/samples.gpkg", package = "dtwSat")
#'
#' samples_path <-
#' system.file("mato_grosso_brazil/samples.gpkg", package = "dtwSat")
#'
#' samples <- st_read(samples_path, quiet = TRUE)
#'
#' # Get satellite image time series files
@@ -55,8 +66,7 @@
#' y = samples,
#' cycle_length = 'year',
#' time_scale = 'day',
#' time_weight = c(steepness = 0.1, midpoint = 50),
#' formula = band ~ s(time))
#' time_weight = c(steepness = 0.1, midpoint = 50))
#'
#' print(m)
#'
@@ -70,11 +80,23 @@
#' ggplot() +
#' geom_stars(data = lu) +
#' theme_minimal()
#'
#'
#' # Create a knn1-twdtw model with custom smoothing function
#'
#' m <- twdtw_knn1(x = dc,
#' y = samples,
#' cycle_length = 'year',
#' time_scale = 'day',
#' time_weight = c(steepness = 0.1, midpoint = 50),
#' smooth_fun = function(x, y) tapply(y, x, mean))
#'
#' plot(m)
#'
#' }
#' @export
twdtw_knn1 <- function(x, y, time_weight, cycle_length, time_scale,
formula = NULL, start_column = 'start_date',
smooth_fun = approx_gam_spline, start_column = 'start_date',
end_column = 'end_date', label_colum = 'label',
sampling_freq = NULL, ...){

@@ -116,11 +138,11 @@ twdtw_knn1 <- function(x, y, time_weight, cycle_length, time_scale,
ts_data <- prepare_time_series(as.data.frame(ts_data))
ts_data$ts_id <- NULL

if(!is.null(formula)) {
if(!is.null(smooth_fun)) {

# Check if formula has two
if(length(all.vars(formula)) != 2) {
stop("The formula should have only one predictor!")
# Check if smooth_fun has two or three arguments
if(!length(formals(smooth_fun)) %in% c(2, 3)) {
stop("The smooth_fun function should have two or three arguments!")
}

# Determine sampling frequency
@@ -135,29 +157,43 @@ twdtw_knn1 <- function(x, y, time_weight, cycle_length, time_scale,
ts_data <- unnest(ts_data, cols = 'observations')
ts_data <- nest(ts_data, .by = 'label', .key = "observations")

# Define GAM function
gam_fun <- function(band, t, pred_t, formula, ...){
df <- setNames(list(band, as.numeric(t)), all.vars(formula))
pred_t[[all.vars(formula)[2]]] <- as.numeric(pred_t[[all.vars(formula)[2]]])
fit <- mgcv::gam(data = df, formula = formula, ...)
predict(fit, newdata = pred_t)
}

# Apply GAM function
ts_data$observations <- lapply(ts_data$observations, function(ts){
# Apply smooth function
ts_data$observations <- lapply(ts_data$observations, function(ts) {
y_time <- ts$time
ts$time <- NULL
pred_time <- setNames(list(seq(min(y_time), max(y_time), by = sampling_freq)), all.vars(formula)[2])
cbind(pred_time, as.data.frame(sapply(as.list(ts), function(band) {
gam_fun(band, y_time, pred_time, formula, ...)
})))

# Determine pred_time
if (length(formals(smooth_fun)) == 3) {
pred_time <- seq(min(y_time), max(y_time), by = sampling_freq)
} else {
pred_time <- y_time
}

# Wrapper function
wrapper_smooth_fun <- function(x, y, z = as.numeric(pred_time)) {
if (length(formals(smooth_fun)) == 3) {
return(as.vector(smooth_fun(x, y, z)))
} else {
return(as.vector(smooth_fun(x, y)))
}
}

# Apply the wrapper function to each band and bind results
smoothed_data <- sapply(as.list(ts), function(band) {
wrapper_smooth_fun(as.numeric(y_time), band)
})

# Bind time and smoothed data into a data frame
result_df <- data.frame(time = pred_time, smoothed_data)

return(result_df)
})

}

model <- list()
model$call <- match.call()
model$formula <- formula
model$smooth_fun <- smooth_fun
model$data <- ts_data
# add twdtw arguments to model
model$twdtw_args <- list(time_weight = time_weight,
@@ -176,14 +212,36 @@ twdtw_knn1 <- function(x, y, time_weight, cycle_length, time_scale,

}
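
The arity check and the wrapper in the hunk above amount to a small dispatch rule: a three-argument `smooth_fun` is asked for predictions at new locations, while a two-argument one is called with only the observed times and values. The base R sketch below restates that rule in isolation. `apply_smoother`, `mean_by_time`, and `linear_interp` are hypothetical names chosen for illustration and are not part of dtwSat.

```r
# Hypothetical helper (not part of dtwSat): call a smoother according to its arity.
apply_smoother <- function(smooth_fun, x, y, xout) {
  n_args <- length(formals(smooth_fun))
  if (!n_args %in% c(2, 3)) {
    stop("The smooth_fun function should have two or three arguments!")
  }
  if (n_args == 3) {
    as.vector(smooth_fun(x, y, xout))  # predictions at the xout locations
  } else {
    as.vector(smooth_fun(x, y))        # values at the input locations
  }
}

# Two observations per time step, e.g. two samples sharing one label.
x <- c(1, 1, 2, 2, 3, 3)
y <- c(10, 12, 20, 22, 30, 32)

# Two-argument smoother: per-time mean, repeated so the output keeps length(y).
mean_by_time <- function(x, y) stats::ave(y, x, FUN = mean)

# Three-argument smoother: linear interpolation at the requested locations.
linear_interp <- function(x, y, xout) {
  stats::approx(x, y, xout = xout, ties = mean)$y
}

res_two <- apply_smoother(mean_by_time, x, y, xout = c(1.5, 2.5))
res_three <- apply_smoother(linear_interp, x, y, xout = c(1.5, 2.5))
```

In the hunk, `pred_time` is set to `y_time` when the smoother takes two arguments, so a two-argument function presumably has to return one value per input observation. `ave()` is used here for that reason.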

#' Approximate temporal patterns using GAM with Cubic Regression Splines
#'
#' This function uses Generalized Additive Models (GAM) with cubic regression splines
#' to interpolate the provided data. It then predicts values at specified locations.
#'
#' @param x A numeric vector representing the independent variable (coordinates) of the points to be interpolated.
#' @param y A numeric vector representing the dependent variable (values) corresponding to each coordinate in `x`.
#' @param xout A numeric vector specifying the locations where interpolation predictions should be made.
#'
#' @return A numeric vector of predicted values at the `xout` locations based on the GAM cubic spline interpolation.
#'
#' @seealso \code{\link[mgcv]{bam}} and \code{\link[mgcv]{s}} for details on the GAM with cubic regression splines.
#'
#' @keywords internal
approx_gam_spline <- function(x, y, xout){
df <- data.frame(x = x, y = y)
gam_fit <- mgcv::bam(data = df, formula = y ~ s(x, bs = "cr"))
predict(gam_fit, newdata = data.frame(x = xout))
}
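
Going by the documented contract, any function with the same `(x, y, xout)` signature as `approx_gam_spline` should be usable as `smooth_fun`. As an illustration only, the sketch below swaps the GAM fit for base R's `stats::smooth.spline`. This is a different smoother and is not expected to reproduce the GAM results. `smooth_spline_fun` is a hypothetical name, and the synthetic series is made up for the example.

```r
# Hypothetical three-argument smoother (not part of dtwSat).
# stats::smooth.spline is swapped in for the mgcv fit used by approx_gam_spline.
smooth_spline_fun <- function(x, y, xout) {
  fit <- stats::smooth.spline(x = x, y = y)
  stats::predict(fit, x = xout)$y
}

# Synthetic seasonal signal observed every 16 days (illustrative data only).
set.seed(1)
x <- seq(0, 352, by = 16)
y <- sin(2 * pi * x / 365) + rnorm(length(x), sd = 0.05)

# Predict on a denser grid inside the observed range, as the diff does with pred_time.
xout <- seq(min(x), max(x), by = 8)
pred <- smooth_spline_fun(x, y, xout)
```

Such a function could then be passed as `smooth_fun = smooth_spline_fun` in the call to `twdtw_knn1`.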



#' Print method for objects of class twdtw_knn1
#'
#' This method provides a structured printout of the important components
#' of a `twdtw_knn1` object.
#'
#' @param x An object of class `twdtw_knn1`.
#' @param ... ignored
#'
#'
#' @return Invisible `twdtw_knn1` object.
#'
#' @export
@@ -195,9 +253,9 @@ print.twdtw_knn1 <- function(x, ...) {
cat("Call:\n")
print(x$call)

# Printing the formula, if available
cat("\nFormula:\n")
print(x$formula)
# Printing the smooth_fun, if available
cat("\nSmooth function:\n")
print(x$smooth_fun)

# Printing the data summary
cat("\nData:\n")
@@ -223,6 +281,7 @@ print.twdtw_knn1 <- function(x, ...) {
#' pretty_arguments(formals(twdtw_knn1))
#' }
#'
#' @keywords internal
pretty_arguments <- function(args) {

if (is.null(args)) {
@@ -259,6 +318,7 @@ pretty_arguments <- function(args) {
#'
#' @return A difftime object representing the most common time difference between consecutive samples.
#'
#' @keywords internal
get_time_series_freq <- function(x) {

# Extract the time dimension
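
The body of `get_time_series_freq` is not shown in this hunk. Going only by its documented return value (the most common time difference between consecutive samples), the idea can be sketched in base R as follows. `most_common_gap` is a hypothetical stand-in, not the package's implementation, and the dates are invented for the example.

```r
# Hypothetical stand-in (not the dtwSat implementation):
# return the most frequent gap between consecutive dates as a difftime.
most_common_gap <- function(dates) {
  gaps <- as.numeric(diff(sort(unique(dates))), units = "days")
  counts <- table(gaps)
  as.difftime(as.numeric(names(counts)[which.max(counts)]), units = "days")
}

# Mostly 16-day spacing with one missing acquisition (illustrative data only).
dates <- as.Date("2020-01-01") + c(0, 16, 32, 48, 80)
freq <- most_common_gap(dates)

# A regular prediction grid at the inferred frequency, in the '16 day' style
# that the sampling_freq documentation mentions.
grid <- seq(min(dates), max(dates), by = paste(as.numeric(freq), "day"))
```

Taking the mode of the gaps rather than their mean keeps a single missing acquisition from skewing the inferred frequency.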
4 changes: 2 additions & 2 deletions README.md
@@ -2,7 +2,7 @@

<!-- badges: start -->
[![License](https://img.shields.io/badge/license-GPL%20%28%3E=%202%29-brightgreen.svg?style=flat)](https://www.gnu.org/licenses/gpl-3.0.html)
-[![R-CMD-check](https://github.com/vwmaus/dtwSat/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/vwmaus/dtwSat/actions/workflows/R-CMD-check.yaml)
+[![R-CMD-check](https://github.com/r-spatial/dtwSat/actions/workflows/R-CMD-check.yaml/badge.svg)](https://github.com/r-spatial/dtwSat/actions/workflows/R-CMD-check.yaml)
[![Coverage Status](https://img.shields.io/codecov/c/github/vwmaus/dtwSat/main.svg)](https://app.codecov.io/gh/vwmaus/dtwSat)
[![CRAN](https://www.r-pkg.org/badges/version/dtwSat)](https://cran.r-project.org/package=dtwSat)
[![Downloads](https://cranlogs.r-pkg.org/badges/dtwSat?color=brightgreen)](https://www.r-pkg.org/pkg/dtwSat)
@@ -33,7 +33,7 @@ install.packages("dtwSat")
Alternatively, you can install the development version from GitHub:

``` r
-devtools::install_github("vwmaus/dtwSat")
+devtools::install_github("r-spatial/dtwSat")
```

After installation, you can read the vignette for a quick start guide:
26 changes: 26 additions & 0 deletions man/approx_gam_spline.Rd


1 change: 1 addition & 0 deletions man/get_time_series_freq.Rd


18 changes: 14 additions & 4 deletions man/plot.twdtw_knn1.Rd

