Skip to content

Commit

Permalink
Leaves GAM for user to define
Browse files Browse the repository at this point in the history
  • Loading branch information
vwmaus committed Oct 18, 2023
1 parent b211f54 commit c2b5f20
Show file tree
Hide file tree
Showing 17 changed files with 154 additions and 267 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ Depends:
stars,
ggplot2
Imports:
mgcv,
stats,
methods,
tidyr,
proxy
Suggests:
mgcv,
knitr,
rmarkdown,
testthat (>= 3.0.0)
Expand Down
8 changes: 2 additions & 6 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,15 @@
S3method(plot,twdtw_knn1)
S3method(predict,twdtw_knn1)
S3method(print,twdtw_knn1)
export(shift_dates)
export(twdtw_knn1)
import(ggplot2)
import(sf)
import(stars)
import(twdtw)
importFrom(mgcv,gam)
importFrom(mgcv,predict.gam)
importFrom(mgcv,s)
importFrom(methods,as)
importFrom(proxy,dist)
importFrom(stats,as.formula)
importFrom(stats,model.frame)
importFrom(stats,predict)
importFrom(stats,setNames)
importFrom(tidyr,nest)
importFrom(tidyr,pivot_longer)
importFrom(tidyr,pivot_wider)
Expand Down
1 change: 1 addition & 0 deletions R/prepare_time_series.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#' @return A nested tibble in wide format. Each row of the tibble corresponds to a unique 'ts_id' that maintains the order from the original stars object.
#' The nested structure contains observations (time series) for each 'ts_id', including the 'time' of each observation, and individual bands are presented as separate columns.
#'
#' @noRd
#' @keywords internal
prepare_time_series <- function(x) {

Expand Down
3 changes: 2 additions & 1 deletion R/shift_dates.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
#'
#' shift_dates(x)
#'
#' @export
#' @noRd
#' @keywords internal
shift_dates <- function(x, origin = "1970-01-01") {

# Convert the input dates to Date objects
Expand Down
151 changes: 75 additions & 76 deletions R/train.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,25 @@
#' See details in \link[twdtw]{twdtw}.
#' @param cycle_length The length of the cycle, e.g. phenological cycles. Details in \link[twdtw]{twdtw}.
#' @param time_scale Specifies the time scale for the observations. Details in \link[twdtw]{twdtw}.
#' @param smooth_fun Either NULL or a function specifying how to reduce samples of the same label.
#' Default uses Generalized Additive Models (GAM) with cubic regression splines to create a temporal pattern for each label. See details.
#' @param smooth_fun a function specifying how to create temporal patterns using the samples.
#' If not defined, it will keep all samples. Note that reducing the samples to patterns can significantly
#' improve computational time of predictions. See details.
#' @param start_column Name of the column in y that indicates the start date. Default is 'start_date'.
#' @param end_column Name of the column in y that indicates the end date. Default is 'end_date'.
#' @param label_colum Name of the column in y containing land use labels. Default is 'label'.
#' @param sampling_freq The time frequency for sampling, including the unit (e.g., '16 day').
#' If NULL, the function will infer the frequency. This parameter is only used if `smooth_fun` is provided.
#' @param resampling_freq The time frequency used to resample the time series when `smooth_fun` is given.
#' If NULL, the function will infer the frequency of observations in `x`.
#' @param ... Additional arguments passed to \link[twdtw]{twdtw}.
#'
#' @details If \code{smooth_fun} is NULL, the KNN-1 model will retain all training samples.
#' @details If \code{smooth_fun} is not provided, the KNN-1 model will retain all training samples.
#'
#' If a custom smoothing function is passed to `smooth_fun`, the function will be used to
#' resample values of samples sharing the same label (land cover class). If no function is provided,
#' the default method uses Generalized Additive Models (GAM) with cubic regression splines.
#' resample values of samples sharing the same label (land cover class).
#'
#' The custom smoothing function takes two or three numeric vectors as arguments and return a single numeric vector:
#' The custom smoothing function takes two numeric vectors as arguments and returns a model:
#' \itemize{
#' \item The first argument represents the independent variable (typically time).
#' \item The second argument represents the dependent variable (e.g., band values) corresponding to each coordinate in the first argument.
#' \item Optional. The third argument specifies the locations (e.g., times) where interpolation predictions should be made.
#' }
#' See the examples section for further clarity.
#'
Expand Down Expand Up @@ -62,11 +61,13 @@
#' dc <- split(dc, c("band"))
#'
#' # Create a knn1-twdtw model
#' m <- twdtw_knn1(x = dc,
#' y = samples,
#' cycle_length = 'year',
#' time_scale = 'day',
#' time_weight = c(steepness = 0.1, midpoint = 50))
#' m <- twdtw_knn1(
#' x = dc,
#' y = samples,
#' smooth_fun = function(x, y) gam(y ~ s(x), data = data.frame(x = x, y = y)),
#' cycle_length = 'year',
#' time_scale = 'day',
#' time_weight = c(steepness = 0.1, midpoint = 50))
#'
#' print(m)
#'
Expand All @@ -80,25 +81,25 @@
#' ggplot() +
#' geom_stars(data = lu) +
#' theme_minimal()
#'
#'
#' # Create a knn1-twdtw model with custom smoothing function
#'
#'
#'
#' # Create a knn1-twdtw model with custom smoothing function
#'
#' m <- twdtw_knn1(x = dc,
#' y = samples,
#' cycle_length = 'year',
#' time_scale = 'day',
#' time_weight = c(steepness = 0.1, midpoint = 50),
#' smooth_fun = function(x, y) tapply(y, x, mean))
#' y = samples,
#' smooth_fun = function(x, y) lm(y ~ factor(x), data = data.frame(x=x, y=y)),
#' cycle_length = 'year',
#' time_scale = 'day',
#' time_weight = c(steepness = 0.1, midpoint = 50))
#'
#' plot(m)
#'
#' }
#' @export
twdtw_knn1 <- function(x, y, time_weight, cycle_length, time_scale,
smooth_fun = approx_gam_spline, start_column = 'start_date',
end_column = 'end_date', label_colum = 'label',
sampling_freq = NULL, ...){
twdtw_knn1 <- function(x, y, smooth_fun = NULL, resampling_freq = NULL,
time_weight, cycle_length, time_scale,
start_column = 'start_date', end_column = 'end_date',
label_colum = 'label', ...){

# Check if x is a stars object with a time dimension
if (!inherits(x, "stars") || dim(x)['time'] < 1 || length(dim(x)) != 3) {
Expand Down Expand Up @@ -138,16 +139,12 @@ twdtw_knn1 <- function(x, y, time_weight, cycle_length, time_scale,
ts_data <- prepare_time_series(as.data.frame(ts_data))
ts_data$ts_id <- NULL

smooth_models <- NULL
if(!is.null(smooth_fun)) {

# Check if smooth_fun has two or three arguments
if(!length(formals(smooth_fun)) %in% c(2, 3)) {
stop("The smooth_fun function should have two or three arguments!")
}

# Determine sampling frequency
if (is.null(sampling_freq)) {
sampling_freq <- get_time_series_freq(ts_data)
if(length(formals(smooth_fun)) != c(2)) {
stop("The smooth function should have only two arguments!")
}

# Shift dates
Expand All @@ -158,42 +155,60 @@ twdtw_knn1 <- function(x, y, time_weight, cycle_length, time_scale,
ts_data <- nest(ts_data, .by = 'label', .key = "observations")

# Apply smooth function
ts_data$observations <- lapply(ts_data$observations, function(ts) {
smooth_models <- lapply(ts_data$observations, function(ts) {

# Get timeline
y_time <- ts$time
ts$time <- NULL

# Fit smooth model to each band
smooth_models <- lapply(as.list(ts), function(band) {
smooth_fun(x = as.numeric(y_time), y = band)
})

return(smooth_models)

})

names(smooth_models) <- ts_data$label

ts_data$observations <- lapply(seq_along(ts_data$observations), function(l) {

# Get timeline
ts <- ts_data$observations[[l]]
y_time <- ts$time
ts$time <- NULL

# Determine pred_time
if (length(formals(smooth_fun)) == 3) {
pred_time <- seq(min(y_time), max(y_time), by = sampling_freq)
# Determine the time points for resampling the time series
if (is.null(resampling_freq)) {
pred_time <- unique(y_time)
} else {
pred_time <- y_time
pred_time <- seq(min(y_time), max(y_time), by = resampling_freq)
}

# Wrapper function
wrapper_smooth_fun <- function(x, y, z = as.numeric(pred_time)) {
if (length(formals(smooth_fun)) == 3) {
return(as.vector(smooth_fun(x, y, z)))
smoothed_data <- sapply(smooth_models[[l]], function(m) {
# Determine target class
target_class <- class(model.frame(m)[, 2])
# Convert pred_time based on target class
pred_points <- if (target_class == "factor") {
factor(as.numeric(pred_time), levels = levels(model.frame(m)[, 2]))
} else {
return(as.vector(smooth_fun(x, y)))
as(as.numeric(pred_time), target_class)
}
}

# Apply the wrapper function to each band and bind results
smoothed_data <- sapply(as.list(ts), function(band) {
wrapper_smooth_fun(as.numeric(y_time), band)
predict(m, newdata = data.frame(x = pred_points))
})

# Bind time and smoothed data into a data frame
result_df <- data.frame(time = pred_time, smoothed_data)

return(result_df)
})

}

model <- list()
model$call <- match.call()
model$smooth_fun <- smooth_fun
model$smooth_models <- smooth_models
model$data <- ts_data
# add twdtw arguments to model
model$twdtw_args <- list(time_weight = time_weight,
Expand All @@ -212,28 +227,6 @@ twdtw_knn1 <- function(x, y, time_weight, cycle_length, time_scale,

}

#' Approximate temporal patterns using GAM with Cubic Regression Splines
#'
#' This function uses Generalized Additive Models (GAM) with cubic regression splines
#' to interpolate the provided data. It then predicts values at specified locations.
#'
#' @param x A numeric vector representing the independent variable (coordinates) of the points to be interpolated.
#' @param y A numeric vector representing the dependent variable (values) corresponding to each coordinate in `x`.
#' @param xout A numeric vector specifying the locations where interpolation predictions should be made.
#'
#' @return A numeric vector of predicted values at the `xout` locations based on the GAM cubic spline interpolation.
#'
#' @seealso \code{\link[mgcv]{bam}} and \code{\link[mgcv]{s}} for details on the GAM with cubic regression splines.
#'
#' @keywords internal
approx_gam_spline <- function(x, y, xout){
  # Fit y as a smooth function of x using a cubic regression spline basis
  # (bs = "cr") via mgcv::bam
  fitted_model <- mgcv::bam(
    formula = y ~ s(x, bs = "cr"),
    data = data.frame(x = x, y = y)
  )
  # Evaluate the fitted smooth at the requested locations; the prediction
  # is the function's return value
  predict(fitted_model, newdata = data.frame(x = xout))
}



#' Print method for objects of class twdtw_knn1
#'
#' This method provides a structured printout of the important components
Expand All @@ -254,8 +247,12 @@ print.twdtw_knn1 <- function(x, ...) {
print(x$call)

# Printing the smooth_fun, if available
cat("\nSmooth function:\n")
print(x$smooth_fun)
cat("\nAdjusted R-squared of smooth models:\n")
if(is.null(x$smooth_models)){
print(NULL)
} else {
print(sapply(x$smooth_models, function(m1) sapply(m1, function(m2) summary(m2)$adj.r.squared)))
}

# Printing the data summary
cat("\nData:\n")
Expand All @@ -281,6 +278,7 @@ print.twdtw_knn1 <- function(x, ...) {
#' pretty_arguments(formals(twdtw_knn1))
#' }
#'
#' @noRd
#' @keywords internal
pretty_arguments <- function(args) {

Expand Down Expand Up @@ -318,6 +316,7 @@ pretty_arguments <- function(args) {
#'
#' @return A difftime object representing the most common time difference between consecutive samples.
#'
#' @noRd
#' @keywords internal
get_time_series_freq <- function(x) {

Expand Down
6 changes: 3 additions & 3 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
#' @import sf
#' @import stars
#' @import ggplot2
#' @importFrom stats as.formula predict setNames
#' @importFrom mgcv gam s predict.gam
#' @importFrom stats predict model.frame
#' @importFrom methods as
#' @importFrom tidyr pivot_longer pivot_wider nest unnest
#' @importFrom proxy dist
#'
#'
NULL
26 changes: 0 additions & 26 deletions man/approx_gam_spline.Rd

This file was deleted.

19 changes: 0 additions & 19 deletions man/get_time_series_freq.Rd

This file was deleted.

Loading

0 comments on commit c2b5f20

Please sign in to comment.