-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsubmit.R
79 lines (52 loc) · 1.9 KB
/
submit.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!/usr/bin/Rscript
# Write a submission CSV from a restored model's predictions on the test data.
#
# TODO: add docopt to scripts

# Make the Python helpers callable from R. load_test_data(), used below, is
# presumably defined in this module -- confirm in python/load_test_data.py.
reticulate::source_python("python/load_test_data.py")

# Disabled command-line handling, kept for reference. The usage check is
# placed after `args` is read so the snippet works if it is re-enabled.
#
# args <- commandArgs(TRUE)
# if (length(args) < 2)
#   stop("Usage: eval.R RUNDIR CSVFILE")
#
# rundir <- normalizePath(file.path("~/internal/runs/", args[[1]]))
# csvfile <- normalizePath(args[[2]])
#
# Previously used run directories:
# rundir <- normalizePath("~/internal/runs/2020-02-22T13-49-35.272Z")
# rundir <- normalizePath("~/internal/runs/2020-02-17T18-53-46.997Z")

# Directory handed to restore_model(); currently the working directory.
rundir <- "./"

# Path of the CSV written at the end of the script.
# NOTE(review): timestamp() is not defined in this file. If it resolves to
# utils::timestamp(), the file name will carry the "##------ ... ------##"
# decoration (spaces and colons included) -- confirm which one is intended.
csvfile <- paste0("submissions/", timestamp(), ".csv")

# restore_model() is defined outside this file.
model <- restore_model(rundir)

# Earlier experiment, kept for reference:
# testfiles <- list('data/data-raw/test_image_data_0.parquet')
npa <- load_test_data()

# TF1-compat Keras session; predict() uses it to evaluate the model output.
tf_backend <- tf$compat$v1$keras$backend
sess <- tf_backend$get_session()
#' Predict the three label components for a batch of test inputs.
#'
#' Reads the globals `sess` and `model` that the script defines before this
#' function is called.
#' NOTE(review): defining `predict` at top level masks the `stats::predict()`
#' generic for the rest of the script.
#'
#' @param input Test data accepted by `as_tensor()`; it is cast to
#'   `tf$float32` and passed through `layer_expand_dims()` before being
#'   handed to `model`.
#' @return A tibble with the columns `grapheme_root`, `consonant_diacritic`
#'   and `vowel_diacritic`.
predict <- function(input) {
  input <- input %>%
    as_tensor(dtype = tf$float32) %>%
    layer_expand_dims()

  # Evaluate the model's output for the whole input in the session.
  raw_probs <- sess$run(model(input))

  # argmax() is defined outside this file; presumably it reduces each model
  # output to predicted class indices -- TODO confirm.
  pred_cols <- raw_probs %>%
    argmax() %>%
    listarrays::bind_as_cols()

  # Name the columns before converting: calling as_tibble() on a matrix
  # without column names is deprecated since tibble 2.0.0 and warns.
  # NOTE(review): this order assumes the model emits root, consonant, vowel
  # in that sequence -- confirm against the model definition.
  colnames(pred_cols) <- c(
    "grapheme_root", "consonant_diacritic", "vowel_diacritic"
  )
  preds <- tibble::as_tibble(pred_cols)

  tidyr::unnest(preds, cols = names(preds))
}
# Run the model on the test data and number the rows (rowid starts at 1).
preds <- npa %>%
  predict() %>%
  tibble::rowid_to_column(var = "rowid")

# Submission ids are zero-based: "Test_<idx>_<component>".
idx <- preds$rowid - 1L

# Each input row yields three submission rows, emitted in this order.
components <- c("consonant_diacritic", "grapheme_root", "vowel_diacritic")

# Build the long table in one vectorised step instead of one tibble per
# input row. rbind() stacks one matrix row per component; c() then reads the
# matrix column by column, which interleaves the three values of each input
# row in `components` order. The targets stay numeric throughout.
targets <- c(do.call(rbind, as.list(preds[components])))

csv_preds <- tibble::tibble(
  # sprintf() returns character(0) for zero input rows, so an empty test set
  # still produces a well-formed, header-only CSV.
  row_id = sprintf(
    "Test_%d_%s",
    rep(idx, each = length(components)),
    rep(components, times = length(idx))
  ),
  target = as.double(targets)
)

# write_csv() does not create missing directories, so make sure the target
# directory exists before writing.
dir.create(dirname(csvfile), showWarnings = FALSE, recursive = TRUE)
readr::write_csv(csv_preds, csvfile)
message("Success!")