This repository has been archived by the owner on Jun 14, 2023. It is now read-only.

Again Error for Unusual Split of Data into Test and Train Sets #260

Open
larry77 opened this issue Aug 5, 2022 · 1 comment


larry77 commented Aug 5, 2022

Dear All,
I am going back to tidymodels and have dusted off the example discussed at https://github.com/tidymodels/tidymodels.org/issues/198 .
Please have a look at the reprex below.

I now get a new error/warning, which I am sure I did not see in the past. It seems to be caused by the test set consisting of only one data point, but I am not 100% sure of this, nor do I know what to do if that is the case (I really need this unusual data split).
I have two questions:

  1. can anyone tell me what to do to fix my code?
  2. given that my data has a strong time component (df_ini has a "year" column), I think I should use a different resampling technique (see https://www.tmwr.org/resampling.html#rolling ).
    It is probably a one-liner, but I am having trouble with it. Can anyone show me how to implement rolling forecasting origin resampling in my example (once it has been fixed), e.g. non-cumulative, with an analysis set of eight samples (eight years) and an assessment set of two (two years)?

Thanks a lot!

library(tidymodels)

tidymodels_prefer() 

df_ini <- structure(list(year = c(1998, 2002, 2004, 2005, 2006, 2007, 2008, 
2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018), 
    capital_n1132g_lag_1 = c(3446.5, 4091.1, 3655.1, 3633.3, 
    3616.2, 3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 4005.6, 
    3718.6, 3467.9, 4214.2, 4237.4, 4450.2), capital_n117g_lag_1 = c(4920.9, 
    7810.6, 8560.3, 8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 
    11554.3, 11849.9, 13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 
    17647.1, 18273.8), capital_n11mg_lag_1 = c(16846, 19605, 
    19381.2, 19433.5, 20051.6, 20569.8, 22646.1, 23674.5, 21200.6, 
    20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 25019.2, 27608.2, 
    29790.1), employment_be_lag_1 = c(2834.42, 2839.72, 2765.53, 
    2731.08, 2709.59, 2708.39, 2774.06, 2795.6, 2703.36, 2668.1, 
    2705.1, 2731.67, 2727.16, 2725.66, 2735.69, 2750.52, 2782.9
    ), employment_c_lag_1 = c(2612.76, 2623.69, 2552.89, 2518.57, 
    2496.98, 2499.54, 2558.88, 2578, 2483.97, 2447.65, 2483.1, 
    2507.41, 2500.94, 2499.6, 2511.75, 2523.97, 2555.48), employment_j_lag_1 = c(292.93, 
    389.2, 389.45, 387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 
    392.18, 410.85, 419.75, 427.59, 438.96, 440.33, 460.84, 473.4
    ), employment_k_lag_1 = c(505.33, 507.12, 510.25, 504.63, 
    515.39, 523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 534.98, 
    524.13, 518.89, 511.57, 505.32, 496.41), employment_mn_lag_1 = c(945.59, 
    1217.96, 1289.55, 1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 
    1704.65, 1762.55, 1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 
    2021.51, 2109.71), employment_oq_lag_1 = c(3065.87, 3191.75, 
    3280.36, 3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 
    3759.23, 3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 
    4171.72), employment_total_lag_1 = c(14509.58, 15127.99, 
    15212.11, 15307.28, 15491.61, 15762.92, 16050.92, 16356.53, 
    16269.97, 16392.87, 16647.79, 16820.66, 16879.06, 17039.6, 
    17142.13, 17365.32, 17650.21), gdp_b1gq_lag_1 = c(187849.7, 
    220525, 231862.5, 242348.3, 254075, 267824.4, 283978, 293761.9, 
    288044.1, 295896.6, 310128.6, 318653.1, 323910.2, 333146.1, 
    344269.3, 357608, 369341.3), gdp_p3_lag_1 = c(139695.2, 161175.8, 
    169405.6, 176316.4, 185871.1, 194102, 200944.4, 208857.1, 
    213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 243860.6, 
    249404.3, 257166.5, 265900.2), gdp_p61_lag_1 = c(50117.6, 
    71948.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2, 113368.1, 
    91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 126109.3, 
    129183.6, 131524, 140057.8), gdp_p62_lag_1 = c(19441, 26444.4, 
    28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8, 38781.9, 
    39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 55885.5, 
    59584.7), price_index_lag_1 = c(1.2, 2.3, 1.3, 2, 2.1, 1.7, 
    2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 0.8, 1, 2.2), value_be_lag_1 = c(40533.1, 
    48207.1, 48673.2, 50737.6, 52955.2, 56872.4, 60864.9, 61029, 
    56837.8, 58433.6, 61443, 63655.1, 64132.3, 65542.6, 67495.4, 
    71152.6, 72698.8), value_c_lag_1 = c(33441.8, 40446.6, 40467.4, 
    42014.6, 44229, 47735.5, 51552.4, 51165.9, 47129.7, 48759.3, 
    51467.7, 53234.6, 53431.4, 55169, 57458.7, 60962.8, 62196
    ), value_j_lag_1 = c(5483.7, 7326.1, 7934.1, 7756.1, 8134.2, 
    8378.8, 8532.3, 8740, 8493.9, 8518.9, 9217.1, 9405.1, 9802.1, 
    10361.4, 10695.4, 11455.3, 11720.6), value_k_lag_1 = c(9210.6, 
    9977.3, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7, 13205.2, 
    12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 13482.9, 
    13236.4, 13744.1), value_mn_lag_1 = c(10444, 14061.4, 15706.6, 
    16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2, 
    24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6
    ), value_oq_lag_1 = c(29902.7, 34179.2, 36126.8, 37329.6, 
    38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 
    49381.5, 50261.7, 51624.3, 53715, 55926.4, 57637.1), value_total_lag_1 = c(167323.4, 
    197076.7, 207247.6, 216098.3, 225888.1, 239076, 253604.6, 
    262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 297230.1, 
    307037.7, 318952.7, 329396.1), capital_n1132g_lag_2 = c(3599.2, 
    3996.9, 3638.4, 3655.1, 3633.3, 3616.2, 3450.7, 3596.8, 3867.2, 
    3372.5, 3722.9, 3808.5, 4005.6, 3718.6, 3467.9, 4214.2, 4237.4
    ), capital_n117g_lag_2 = c(4636.2, 7008.5, 8369.6, 8560.3, 
    8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 11554.3, 11849.9, 
    13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 17647.1), capital_n11mg_lag_2 = c(17181.5, 
    19677.8, 18749.6, 19381.2, 19433.5, 20051.6, 20569.8, 22646.1, 
    23674.5, 21200.6, 20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 
    25019.2, 27608.2), employment_be_lag_2 = c(2870.33, 2840.19, 
    2775.22, 2765.53, 2731.08, 2709.59, 2708.39, 2774.06, 2795.6, 
    2703.36, 2668.1, 2705.1, 2731.67, 2727.16, 2725.66, 2735.69, 
    2750.52), employment_c_lag_2 = c(2626.2, 2621.08, 2562.53, 
    2552.89, 2518.57, 2496.98, 2499.54, 2558.88, 2578, 2483.97, 
    2447.65, 2483.1, 2507.41, 2500.94, 2499.6, 2511.75, 2523.97
    ), employment_j_lag_2 = c(275.08, 374.56, 400.75, 389.45, 
    387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 392.18, 410.85, 
    419.75, 427.59, 438.96, 440.33, 460.84), employment_k_lag_2 = c(500.9, 
    505.13, 502.42, 510.25, 504.63, 515.39, 523.45, 536.6, 550.14, 
    546.68, 539.96, 536.58, 534.98, 524.13, 518.89, 511.57, 505.32
    ), employment_mn_lag_2 = c(904.38, 1143.78, 1248.01, 1289.55, 
    1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 1704.65, 1762.55, 
    1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 2021.51), employment_oq_lag_2 = c(3028.85, 
    3162.77, 3241.36, 3280.36, 3317.09, 3401.65, 3476.63, 3508.01, 
    3577.75, 3683.85, 3759.23, 3798.35, 3850.17, 3877.24, 3924.06, 
    4002.74, 4095.59), employment_total_lag_2 = c(14404.29, 15019.87, 
    15113.52, 15212.11, 15307.28, 15491.61, 15762.92, 16050.92, 
    16356.53, 16269.97, 16392.87, 16647.79, 16820.66, 16879.06, 
    17039.6, 17142.13, 17365.32), gdp_b1gq_lag_2 = c(186928.7, 
    213606.4, 226735.3, 231862.5, 242348.3, 254075, 267824.4, 
    283978, 293761.9, 288044.1, 295896.6, 310128.6, 318653.1, 
    323910.2, 333146.1, 344269.3, 357608), gdp_p3_lag_2 = c(140335.8, 
    156117.3, 164107.8, 169405.6, 176316.4, 185871.1, 194102, 
    200944.4, 208857.1, 213630.1, 218947.2, 227250.8, 233638.1, 
    238329.3, 243860.6, 249404.3, 257166.5), gdp_p61_lag_2 = c(44541.4, 
    67701.6, 74691.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2, 
    113368.1, 91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 
    126109.3, 129183.6, 131524), gdp_p62_lag_2 = c(19504.2, 24888.9, 
    28063.4, 28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8, 
    38781.9, 39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 
    55885.5), value_be_lag_2 = c(40076.7, 46109.4, 47967.1, 48673.2, 
    50737.6, 52955.2, 56872.4, 60864.9, 61029, 56837.8, 58433.6, 
    61443, 63655.1, 64132.3, 65542.6, 67495.4, 71152.6), value_c_lag_2 = c(32955.4, 
    38908.4, 40192.9, 40467.4, 42014.6, 44229, 47735.5, 51552.4, 
    51165.9, 47129.7, 48759.3, 51467.7, 53234.6, 53431.4, 55169, 
    57458.7, 60962.8), value_j_lag_2 = c(5576.8, 6313.9, 7737.1, 
    7934.1, 7756.1, 8134.2, 8378.8, 8532.3, 8740, 8493.9, 8518.9, 
    9217.1, 9405.1, 9802.1, 10361.4, 10695.4, 11455.3), value_k_lag_2 = c(9191, 
    10458, 10225.2, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7, 
    13205.2, 12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 
    13482.9, 13236.4), value_mn_lag_2 = c(10092, 12942.5, 15074, 
    15706.6, 16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 
    23255.2, 24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7
    ), value_oq_lag_2 = c(30224.3, 33251.5, 35065.6, 36126.8, 
    37329.6, 38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 
    47980.9, 49381.5, 50261.7, 51624.3, 53715, 55926.4), value_total_lag_2 = c(167141.8, 
    190624.9, 202353.5, 207247.6, 216098.3, 225888.1, 239076, 
    253604.6, 262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 
    297230.1, 307037.7, 318952.7), berd = c(2146.085, 3130.884, 
    3556.479, 4207.669, 4448.676, 4845.861, 5232.63, 5092.902, 
    5520.422, 5692.841, 6540.457, 6778.42, 7324.679, 7498.488, 
    7824.51, 7888.444, 8461.72)), row.names = c(NA, -17L), class = c("tbl_df", 
"tbl", "data.frame"))





set.seed(1234)  ## to make the results reproducible

## I need a custom split of my dataset: the test set consists of only the
## most recent observation, and all the rest is the training set.
## See https://github.com/tidymodels/rsample/issues/158


indices <-
  list(analysis   = seq(nrow(df_ini)-1), 
       assessment = nrow(df_ini)
       )

df_split <- make_splits(indices, df_ini)


## df_split <- initial_split(df_ini) ## with the default splitting,
## ## the code works

df_train <- training(df_split)
df_test <- testing(df_split)

folded_data <- vfold_cv(df_train, v = 3)



glmnet_recipe <- 
  recipe(formula = berd ~ ., data = df_train) |> 
  update_role(year, new_role = "ID") |> 
  step_zv(all_predictors()) |> 
  step_normalize(all_predictors(), -all_nominal())

glmnet_spec <- 
  linear_reg(penalty = tune(), mixture = tune()) |> 
  set_mode("regression") |> 
  set_engine("glmnet") 

glmnet_workflow <- 
  workflow() |> 
  add_recipe(glmnet_recipe) |> 
  add_model(glmnet_spec) 




glmnet_grid <- tidyr::crossing(
  penalty = 10^seq(-6, -1, length.out = 20),
  mixture = c(0.05, 0.2, 0.4, 0.6, 0.8, 1)
)

glmnet_tune <- 
  tune_grid(
    glmnet_workflow,
    resamples = folded_data,
    grid = glmnet_grid,
    control = control_grid(save_pred = TRUE)
  )

print(collect_metrics(glmnet_tune))
#> # A tibble: 240 × 8
#>       penalty mixture .metric .estimator    mean     n std_err .config          
#>         <dbl>   <dbl> <chr>   <chr>        <dbl> <int>   <dbl> <chr>            
#>  1 0.000001      0.05 rmse    standard   375.        3 48.9    Preprocessor1_Mo…
#>  2 0.000001      0.05 rsq     standard     0.929     3  0.0420 Preprocessor1_Mo…
#>  3 0.00000183    0.05 rmse    standard   375.        3 48.9    Preprocessor1_Mo…
#>  4 0.00000183    0.05 rsq     standard     0.929     3  0.0420 Preprocessor1_Mo…
#>  5 0.00000336    0.05 rmse    standard   375.        3 48.9    Preprocessor1_Mo…
#>  6 0.00000336    0.05 rsq     standard     0.929     3  0.0420 Preprocessor1_Mo…
#>  7 0.00000616    0.05 rmse    standard   375.        3 48.9    Preprocessor1_Mo…
#>  8 0.00000616    0.05 rsq     standard     0.929     3  0.0420 Preprocessor1_Mo…
#>  9 0.0000113     0.05 rmse    standard   375.        3 48.9    Preprocessor1_Mo…
#> 10 0.0000113     0.05 rsq     standard     0.929     3  0.0420 Preprocessor1_Mo…
#> # … with 230 more rows

print(show_best(glmnet_tune, metric = "rmse"))
#> # A tibble: 5 × 8
#>      penalty mixture .metric .estimator  mean     n std_err .config             
#>        <dbl>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 0.000001      0.05 rmse    standard    375.     3    48.9 Preprocessor1_Model…
#> 2 0.00000183    0.05 rmse    standard    375.     3    48.9 Preprocessor1_Model…
#> 3 0.00000336    0.05 rmse    standard    375.     3    48.9 Preprocessor1_Model…
#> 4 0.00000616    0.05 rmse    standard    375.     3    48.9 Preprocessor1_Model…
#> 5 0.0000113     0.05 rmse    standard    375.     3    48.9 Preprocessor1_Model…

best_net <- select_best(glmnet_tune, metric = "rmse")


final_net <- finalize_workflow(
  glmnet_workflow,
  best_net
)


final_res_net <- last_fit(final_net, df_split)
#> ! train/test split: internal: A correlation computation is required, but the inputs are size zero or o...


print(final_res_net)
#> # Resampling results
#> # Manual resampling 
#> # A tibble: 1 × 6
#>   splits         id               .metrics .notes   .predictions     .workflow 
#>   <list>         <chr>            <list>   <list>   <list>           <list>    
#> 1 <split [16/1]> train/test split <tibble> <tibble> <tibble [1 × 4]> <workflow>
#> 
#> There were issues with some computations:
#> 
#>   - Warning(s) x1: A correlation computation is required, but the inputs are size ze...
#> 
#> Run `show_notes(.Last.tune.result)` for more information.

final_fit <- final_res_net |>
  collect_predictions()



show_notes(.Last.tune.result)
#> unique notes:
#> ────────────────────────────────────────────────────────────────────────────────
#> A correlation computation is required, but the inputs are size zero or one and the standard deviation cannot be computed. `NA` will be returned.


sessionInfo()
#> R version 4.2.1 (2022-06-23)
#> Platform: x86_64-pc-linux-gnu (64-bit)
#> Running under: Debian GNU/Linux 11 (bullseye)
#> 
#> Matrix products: default
#> BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.9.0
#> LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.9.0
#> 
#> locale:
#>  [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C              
#>  [3] LC_TIME=en_GB.UTF-8        LC_COLLATE=en_GB.UTF-8    
#>  [5] LC_MONETARY=en_GB.UTF-8    LC_MESSAGES=en_GB.UTF-8   
#>  [7] LC_PAPER=en_GB.UTF-8       LC_NAME=C                 
#>  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
#> [11] LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C       
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#>  [1] glmnet_4.1-4       Matrix_1.4-1       yardstick_1.0.0    workflowsets_1.0.0
#>  [5] workflows_1.0.0    tune_1.0.0         tidyr_1.2.0        tibble_3.1.7      
#>  [9] rsample_1.0.0      recipes_1.0.1      purrr_0.3.4        parsnip_1.0.0     
#> [13] modeldata_1.0.0    infer_1.0.2        ggplot2_3.3.6      dplyr_1.0.9       
#> [17] dials_1.0.0        scales_1.2.0       broom_1.0.0        tidymodels_1.0.0  
#> 
#> loaded via a namespace (and not attached):
#>  [1] splines_4.2.1      foreach_1.5.2      prodlim_2019.11.13 assertthat_0.2.1  
#>  [5] conflicted_1.1.0   highr_0.9          GPfit_1.0-8        yaml_2.3.5        
#>  [9] globals_0.15.1     ipred_0.9-13       pillar_1.7.0       backports_1.4.1   
#> [13] lattice_0.20-45    glue_1.6.2         digest_0.6.29      hardhat_1.2.0     
#> [17] colorspace_2.0-3   htmltools_0.5.2    timeDate_3043.102  pkgconfig_2.0.3   
#> [21] lhs_1.1.5          DiceDesign_1.9     listenv_0.8.0      gower_1.0.0       
#> [25] lava_1.6.10        generics_0.1.2     ellipsis_0.3.2     cachem_1.0.6      
#> [29] withr_2.5.0        furrr_0.3.0        nnet_7.3-17        cli_3.3.0         
#> [33] survival_3.3-1     magrittr_2.0.3     crayon_1.5.1       memoise_2.0.1     
#> [37] evaluate_0.15      fs_1.5.2           future_1.26.1      fansi_1.0.3       
#> [41] parallelly_1.32.0  MASS_7.3-57        class_7.3-20       tools_4.2.1       
#> [45] lifecycle_1.0.1    stringr_1.4.0      munsell_0.5.0      reprex_2.0.1      
#> [49] compiler_4.2.1     rlang_1.0.4        grid_4.2.1         iterators_1.0.14  
#> [53] rmarkdown_2.14     gtable_0.3.0       codetools_0.2-18   DBI_1.1.3         
#> [57] R6_2.5.1           lubridate_1.8.0    knitr_1.39         fastmap_1.1.0     
#> [61] future.apply_1.9.0 utf8_1.2.2         shape_1.4.6        stringi_1.7.6     
#> [65] parallel_4.2.1     Rcpp_1.0.8.3       vctrs_0.4.1        rpart_4.1.16      
#> [69] tidyselect_1.1.2   xfun_0.31

Created on 2022-08-06 by the reprex package (v2.0.1)
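
A note on the warning in the reprex above: the split is [16/1], so the assessment set holds a single row, and `rsq` (one of the default regression metrics) relies on a correlation that cannot be computed from one point. The prediction itself is still produced. Below is a minimal sketch of one possible workaround, assuming it is acceptable to report only metrics that are defined for a single observation: pass a custom metric set to `last_fit()`. The names `toy`, `toy_split` and `lm_wf` are hypothetical stand-ins for `df_ini`, `df_split` and `final_net`, and a plain `lm` engine replaces glmnet only to keep the example small.

```r
library(tidymodels)

set.seed(1234)

# Hypothetical toy data standing in for df_ini (17 rows, one predictor)
toy <- tibble(x = 1:17, y = 2 * x + rnorm(17))

# Same kind of custom split as in the reprex: the last row is the assessment set
toy_split <- make_splits(
  list(analysis = seq(nrow(toy) - 1), assessment = nrow(toy)),
  toy
)

# Small workflow; linear_reg() uses the "lm" engine by default
lm_wf <- workflow() |>
  add_formula(y ~ x) |>
  add_model(linear_reg())

# Request only metrics that do not need a correlation (no rsq)
res <- last_fit(lm_wf, toy_split, metrics = metric_set(rmse, mae))

collect_metrics(res)
```

In the reprex, the corresponding call would be `last_fit(final_net, df_split, metrics = metric_set(rmse, mae))`. With a single assessment row, `rmse` and `mae` both reduce to the absolute error of that one prediction.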


larry77 commented Aug 7, 2022

I should add that another user reported that the above script works on his platform with tidymodels 0.2.0, so I believe the issue is not the script per se or glmnet.
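
On question 2 of the original post, the rolling forecasting origin resampling can be sketched with `rsample::rolling_origin()`. This is only an illustration under stated assumptions: `toy_train` is a hypothetical stand-in for the 16-row `df_train`, and its rows are assumed to be sorted by time. Note that `rolling_origin()` counts rows, not calendar years, so with the gaps in the real `year` column (1998, 2002, 2004, ...) "eight years" really means "eight consecutive observations".

```r
library(rsample)

# Hypothetical stand-in for df_train: 16 rows, already ordered by time
toy_train <- data.frame(year = 2002:2017, y = seq_len(16))

# Non-cumulative windows: 8 rows for analysis, the next 2 rows for assessment
rolling_data <- rolling_origin(
  toy_train,
  initial    = 8,
  assess     = 2,
  cumulative = FALSE
)

# Number of resamples and the years covered by the first one
nrow(rolling_data)
analysis(rolling_data$splits[[1]])$year
assessment(rolling_data$splits[[1]])$year
```

In the reprex, the resulting object would take the place of `folded_data` in `tune_grid(..., resamples = rolling_data)`. With 16 rows this setup gives seven overlapping windows, each shifted forward by one row.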
