Skip to content

Commit

Permalink
feat options for initial parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
AparicioJohan committed Dec 16, 2024
1 parent 52c7a82 commit c095491
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 37 deletions.
12 changes: 6 additions & 6 deletions R/01_read.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ explorer <- function(data, x, y, id, metadata) {
}
x <- names(select(data, {{ x }}))
y <- names(select(data, {{ y }}))
.keep <- names(select(data, {{ metadata }}))
metadata <- names(select(data, {{ metadata }}))
for (i in y) {
class_trait <- data[[i]] |> class()
if (!class_trait %in% c("numeric", "integer")) {
Expand All @@ -62,14 +62,14 @@ explorer <- function(data, x, y, id, metadata) {
)
}
}
check_metadata(data, .keep)
check_metadata(data, metadata)
data <- data |>
select(all_of(c(id, .keep, x, y))) |>
select(all_of(c(id, metadata, x, y))) |>
mutate(uid = .data[[id]], .keep = "unused", .before = 0) |>
rename(x = all_of(x))
resum <- summarize_metadata(data, cols = c("uid", "x", .keep))
resum <- summarize_metadata(data, cols = c("uid", "x", metadata))
dt_long <- data |>
select(uid, all_of(.keep), x, all_of(y)) |>
select(uid, all_of(metadata), x, all_of(y)) |>
pivot_longer(all_of(y), names_to = "var", values_to = "y") |>
relocate(x, .after = var)
summ_vars <- dt_long |>
Expand Down Expand Up @@ -104,7 +104,7 @@ explorer <- function(data, x, y, id, metadata) {
summ_metadata = resum,
locals_min_max = max_min,
dt_long = dt_long,
metadata = .keep,
metadata = metadata,
x_var = x
)
class(out) <- "explorer"
Expand Down
69 changes: 44 additions & 25 deletions R/02_modeler.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
#' @param keep The names of the columns in `data` to keep across the analysis.
#' @param fn A string specifying the name of the function to be used for the curve fitting. Default is \code{"fn_linear_sat"}.
#' @param parameters Can be a named numeric vector specifying the initial values for the parameters to be optimized,
#' or a data frame with columns \code{uid}, and the initial parameter values for each group id. Used for providing specific
#' initial values per group id. Default is \code{NULL}.
#' If a data frame it needs to have a column \code{uid}, and the initial parameter values for each group id.
#' If list it needs to be a named list with the initial parameters and numeric or string values can be used (e.g. list(k = "max(y)", t1 = 40)).
#' Used for providing specific initial values per group id. Default is \code{NULL}.
#' @param lower Numeric vector specifying the lower bounds for the parameters. Default is \code{-Inf} for all parameters.
#' @param upper Numeric vector specifying the upper bounds for the parameters. Default is \code{Inf} for all parameters.
#' @param fixed_params Can be a list or data frame. If data frame it needs columns \code{uid}, and the fixed parameter values for each group id.
Expand Down Expand Up @@ -100,7 +101,7 @@ modeler <- function(data,
if (!inherits(x, "explorer")) {
stop("The object should be of class 'explorer'.")
}
.keep <- x$metadata
metadata <- x$metadata
variable <- unique(x$summ_vars$var)
if (length(variable) != 1) stop("Only single response is allowed.")
# Validate options
Expand Down Expand Up @@ -139,11 +140,11 @@ modeler <- function(data,
droplevels()
if (max_as_last) {
dt <- dt |>
group_by(uid, across(all_of(.keep))) |>
group_by(uid, across(all_of(metadata))) |>
mutate(max = max(y, na.rm = TRUE), pos = x[which.max(y)]) |>
mutate(y = ifelse(x <= pos, y, max)) |>
select(-max, -pos) |>
ungroup() # max_as_last(dt, .keep = .keep)
ungroup()
}
if (check_negative) {
dt <- mutate(dt, y = ifelse(y < 0, 0, y))
Expand All @@ -168,14 +169,14 @@ modeler <- function(data,
if (!all(nam_fix_params %in% args)) {
stop("All fixed_params must be in:", fn)
}
if (length(args) - length(nam_fix_params) <= 1 ) {
if (length(args) - length(nam_fix_params) <= 1) {
stop("More than one parameter needs to be free.")
}
}
# Validate initial values
if (is.null(parameters)) {
stop("Initial parameters need to be provided.")
} else if (is.numeric(parameters)) {
} else if (is.numeric(parameters)) { # Numeric Vector
if (!sum(names(parameters) %in% args) == length(args)) {
stop("names of parameters have to be in: ", fn)
}
Expand All @@ -186,7 +187,7 @@ modeler <- function(data,
pivot_longer(cols = -c(uid), names_to = "coef") |>
nest_by(uid, .key = "initials") |>
mutate(initials = list(pull(initials, value, coef)))
} else if ("data.frame" %in% class(parameters)) {
} else if ("data.frame" %in% class(parameters)) { # Data.frame
nam_ini_vals <- colnames(parameters)
if (!"uid" %in% nam_ini_vals) {
stop("parameters should contain columns 'uid'.")
Expand All @@ -198,6 +199,29 @@ modeler <- function(data,
pivot_longer(cols = -c(uid), names_to = "coef") |>
nest_by(uid, .key = "initials") |>
mutate(initials = list(pull(initials, value, coef)))
} else if ("list" %in% class(parameters)) { # List
if (!sum(names(parameters) %in% args) == length(args)) {
stop("parameters should have the same parameters as the function: ", fn)
}
init <- dt |>
select(uid, x, y) |>
group_by(uid)
for (j in names(parameters)) {
str <- parameters[[j]]
if ("numeric" %in% class(str)) {
express <- str
} else if ("character" %in% class(str)) {
express <- rlang::parse_expr(str)
}
init <- mutate(init, "{j}" := !!express)
}
init <- init |>
ungroup() |>
select(uid, all_of(names(parameters))) |>
unique.data.frame() |>
pivot_longer(cols = -c(uid), names_to = "coef") |>
nest_by(uid, .key = "initials") |>
mutate(initials = list(pull(initials, value, coef)))
}
# Merging with fixed parameters
if (!is.null(fixed_params)) {
Expand All @@ -206,11 +230,6 @@ modeler <- function(data,
pivot_longer(cols = -c(uid), names_to = "coef") |>
nest_by(uid, .key = "fx_params") |>
mutate(fx_params = list(pull(fx_params, value, coef)))
init <- init |>
full_join(fixed, by = c("uid")) |>
mutate(
initials = list(initials[!names(initials) %in% names(fixed_params)])
)
} else if ("list" %in% class(fixed_params)) {
fixed <- dt |>
select(uid, x, y) |>
Expand All @@ -231,12 +250,12 @@ modeler <- function(data,
pivot_longer(cols = -c(uid), names_to = "coef") |>
nest_by(uid, .key = "fx_params") |>
mutate(fx_params = list(pull(fx_params, value, coef)))
init <- init |>
full_join(fixed, by = c("uid")) |>
mutate(
initials = list(initials[!names(initials) %in% names(fixed_params)])
)
}
init <- init |>
full_join(fixed, by = c("uid")) |>
mutate(
initials = list(initials[!names(initials) %in% names(fixed_params)])
)
} else {
fixed <- dt |>
select(uid) |>
Expand All @@ -251,7 +270,7 @@ modeler <- function(data,
fixed <- droplevels(filter(fixed, uid %in% subset))
}
dt_nest <- dt |>
nest_by(uid, across(all_of(.keep))) |>
nest_by(uid, across(all_of(metadata))) |>
full_join(init, by = c("uid"))
if (nrow(dt_nest) == 0) {
stop("Check the ids for which you are filtering.")
Expand Down Expand Up @@ -296,7 +315,7 @@ modeler <- function(data,
upper = upper,
trace = trace,
control = control,
.keep = .keep
metadata = metadata
)
}
end_time <- Sys.time()
Expand Down Expand Up @@ -340,7 +359,7 @@ modeler <- function(data,
execution = end_time - init_time,
response = variable,
x_var = x$x_var,
keep = .keep,
keep = metadata,
fun = fn,
parallel = list("parallel" = parallel, "workers" = workers),
fit = objt
Expand Down Expand Up @@ -392,7 +411,7 @@ modeler <- function(data,
lower,
upper,
control,
.keep,
metadata,
trace) {
dt <- data[data$uid == id, ]
initials <- unlist(dt$initials)
Expand All @@ -415,7 +434,7 @@ modeler <- function(data,
)
# metadata
rr <- cbind(
dt[, c("uid", .keep)],
dt[, c("uid", metadata)],
kkopt |>
tibble::rownames_to_column(var = "method") |>
dplyr::rename(sse = value) |>
Expand Down Expand Up @@ -459,9 +478,9 @@ modeler <- function(data,
}

#' @noRd
max_as_last <- function(data, .keep) {
max_as_last <- function(data, metadata) {
dt_can <- data |>
group_by(uid, across(all_of(.keep))) |>
group_by(uid, across(all_of(metadata))) |>
mutate(
loc_max_at = paste(local_min_max(y, x)$days_max, collapse = "_"),
loc_max = as.numeric(local_min_max(y, x)$days_max[1])
Expand Down
6 changes: 3 additions & 3 deletions R/utils_S3_plots.R
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ plot.explorer <- function(x,
base_size = 13,
return_gg = FALSE,
add_avg = FALSE, ...) {
.keep <- x$metadata
metadata <- x$metadata
colours <- c("#db4437", "white", "#4285f4")
flt <- x$summ_vars |>
filter(`miss%` <= 0.2) |> # & SD > 0
Expand All @@ -399,7 +399,7 @@ plot.explorer <- function(x,
}
var_by_x <- data |>
pivot_wider(names_from = var, values_from = y) |>
select(-c(uid, all_of(.keep))) |>
select(-c(uid, all_of(metadata))) |>
nest_by(x) |>
mutate(
mat = list(
Expand Down Expand Up @@ -453,7 +453,7 @@ plot.explorer <- function(x,
if (type == "x_by_var" || type == 2) {
x_by_var <- data |>
pivot_wider(names_from = x, values_from = y) |>
select(-c(uid, all_of(.keep))) |>
select(-c(uid, all_of(metadata))) |>
nest_by(var) |>
mutate(
mat = list(
Expand Down
2 changes: 1 addition & 1 deletion man/dot-fitter_curve.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions man/modeler.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit c095491

Please sign in to comment.