Skip to content

Commit

Permalink
Merge branch 'master' of github.com:mlr-org/mlr3
Browse files Browse the repository at this point in the history
  • Loading branch information
mllg committed Jun 2, 2020
2 parents 577fb48 + 1160724 commit bc1f7d4
Show file tree
Hide file tree
Showing 14 changed files with 97 additions and 51 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3
Title: Machine Learning in R - Next Generation
Version: 0.2.0-9000
Version: 0.3.0
Authors@R:
c(person(given = "Michel",
family = "Lang",
Expand Down
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# mlr3 0.2.0-9000
# mlr3 0.3.0

* Package `future.apply` is now imported (instead of suggested).
This is necessary to ensure reproducibility: This way exactly the same result
is calculated, independent of the parallel backend.
* Fixed a bug where prediction on new data for a task with blocking information
raised an exception (#496).
* New binding: `Task$order`.

# mlr3 0.2.0

Expand Down
4 changes: 2 additions & 2 deletions R/Resampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ Resampling = R6Class("Resampling",

if (is.null(strata)) {
if (is.null(groups)) {
instance = private$.sample(task$row_ids)
instance = private$.sample(task$row_ids, task)
} else {
private$.groups = groups
instance = private$.sample(unique(groups$group))
Expand All @@ -176,7 +176,7 @@ Resampling = R6Class("Resampling",
if (!is.null(groups)) {
stopf("Cannot combine stratification with grouping")
}
instance = private$.combine(lapply(strata$row_id, private$.sample))
instance = private$.combine(lapply(strata$row_id, private$.sample, task = task))
}

self$instance = instance
Expand Down
2 changes: 1 addition & 1 deletion R/ResamplingBootstrap.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ ResamplingBootstrap = R6Class("ResamplingBootstrap", inherit = Resampling,
),

private = list(
.sample = function(ids) {
.sample = function(ids, ...) {
pv = self$param_set$values
nr = round(length(ids) * pv$ratio)
x = factor(seq_along(ids))
Expand Down
2 changes: 1 addition & 1 deletion R/ResamplingCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ ResamplingCV = R6Class("ResamplingCV", inherit = Resampling,
),

private = list(
.sample = function(ids) {
.sample = function(ids, ...) {
data.table(
row_id = ids,
fold = shuffle(seq_along0(ids) %% as.integer(self$param_set$values$folds) + 1L),
Expand Down
2 changes: 1 addition & 1 deletion R/ResamplingHoldout.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ ResamplingHoldout = R6Class("ResamplingHoldout", inherit = Resampling,
),

private = list(
.sample = function(ids) {
.sample = function(ids, ...) {
nr = round(length(ids) * self$param_set$values$ratio)
ii = shuffle(ids, nr)
list(train = ii, test = setdiff(ids, ii))
Expand Down
2 changes: 1 addition & 1 deletion R/ResamplingInsample.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ ResamplingInsample = R6Class("ResamplingInsample", inherit = Resampling,
),

private = list(
.sample = function(ids) {
.sample = function(ids, ...) {
ids
},

Expand Down
2 changes: 1 addition & 1 deletion R/ResamplingRepeatedCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ ResamplingRepeatedCV = R6Class("ResamplingRepeatedCV", inherit = Resampling,
),

private = list(
.sample = function(ids) {
.sample = function(ids, ...) {
pv = self$param_set$values
n = length(ids)
folds = as.integer(pv$folds)
Expand Down
2 changes: 1 addition & 1 deletion R/ResamplingSubsampling.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ ResamplingSubsampling = R6Class("ResamplingSubsampling", inherit = Resampling,
),

private = list(
.sample = function(ids) {
.sample = function(ids, ...) {
pv = self$param_set$values
n = length(ids)
nr = round(n * pv$ratio)
Expand Down
84 changes: 49 additions & 35 deletions R/Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,12 @@ Task = R6Class("Task",
#' columns are filtered to only contain features with roles `"target"` and `"feature"`.
#' If invalid `rows` or `cols` are specified, an exception is raised.
#'
#' @param ordered (`logical(1)`)\cr
#' If `TRUE` (default), data is ordered according to the columns with column role `"order"`.
#'
#' @return Depending on the [DataBackend], but usually a [data.table::data.table()].
data = function(rows = NULL, cols = NULL, data_format = "data.table") {
task_data(self, private, rows, cols, data_format)
data = function(rows = NULL, cols = NULL, data_format = "data.table", ordered = TRUE) {
task_data(self, rows, cols, data_format, ordered)
},

#' @description
Expand Down Expand Up @@ -519,7 +522,6 @@ Task = R6Class("Task",
#' * `row_id` (list of `integer()`) as list column with the row ids in the respective subpopulation.
#' Returns `NULL` if there is no stratification variable.
#' See [Resampling] for more information on stratification.

strata = function(rhs) {
assert_ro_binding(rhs)
cols = private$.col_roles$stratum
Expand All @@ -536,7 +538,7 @@ Task = R6Class("Task",


#' @field groups ([data.table::data.table()])\cr
#' If the task has a column with designated role `"group"`, table with two columns:
#' If the task has a column with designated role `"group"`, a table with two columns:
#'
#' * `row_id` (`integer()`), and
#' * grouping variable `group` (`vector()`).
Expand All @@ -545,28 +547,47 @@ Task = R6Class("Task",
#' See [Resampling] for more information on grouping.
groups = function(rhs) {
assert_ro_binding(rhs)
groups = private$.col_roles$group
if (length(groups) == 0L) {
group_cols = private$.col_roles$group
if (length(group_cols) == 0L) {
return(NULL)
}
data = self$backend$data(private$.row_roles$use, c(self$backend$primary_key, groups))
data = self$backend$data(private$.row_roles$use, c(self$backend$primary_key, group_cols))
setnames(data, c("row_id", "group"))[]
},

#' @field order ([data.table::data.table()])\cr
#' If the task has at least one column with designated role `"order"`, a table with two columns:
#'
#' * `row_id` (`integer()`), and
#' * ordering vector `order` (`integer()`).
#'
#' Returns `NULL` if there is no order column.
order = function(rhs) {
  # Guard for the read-only binding (helper is defined elsewhere in the package).
  assert_ro_binding(rhs)

  order_cols = private$.col_roles$order
  if (length(order_cols) == 0L) {
    return(NULL)
  }

  row_ids = private$.row_roles$use
  data = self$backend$data(row_ids, order_cols)

  # Hand the columns to order() positionally. do.call() forwards list names as
  # argument names, so a column called "decreasing", "na.last" or "method"
  # would otherwise be taken as an option of base::order() and not as a sort key.
  sort_keys = unname(as.list(data))
  data.table(row_id = row_ids, order = do.call(order, sort_keys))
},

#' @field weights ([data.table::data.table()])\cr
#' If the task has a column with designated role `"weight"`, table with two columns:
#' If the task has a column with designated role `"weight"`, a table with two columns:
#'
#' * `row_id` (`integer()`), and
#' * observation weights `weight` (`numeric()`).
#'
#' Returns `NULL` if there is no weight column.
weights = function(rhs) {
assert_ro_binding(rhs)
weights = private$.col_roles$weight
if (length(weights) == 0L) {
weight_cols = private$.col_roles$weight
if (length(weight_cols) == 0L) {
return(NULL)
}
data = self$backend$data(private$.row_roles$use, c(self$backend$primary_key, weights))
data = self$backend$data(private$.row_roles$use, c(self$backend$primary_key, weight_cols))
setnames(data, c("row_id", "weight"))[]
}
),
Expand All @@ -584,59 +605,52 @@ Task = R6Class("Task",
)
)

task_data = function(self, private, rows = NULL, cols = NULL, data_format = "data.table", subset_active = c("rows", "cols")) {

task_data = function(self, rows = NULL, cols = NULL, data_format = "data.table", ordered = TRUE, subset_active = c("rows", "cols")) {
assert_choice(data_format, self$backend$data_formats)
row_roles = private$.row_roles
col_roles = private$.col_roles
row_roles = self$row_roles
col_roles = self$col_roles

if (is.null(rows)) {
selected_rows = row_roles$use
rows = row_roles$use
} else {
if ("rows" %in% subset_active) {
assert_subset(rows, row_roles$use)
}
if (is.double(rows)) {
rows = as.integer(rows)
}
selected_rows = rows
}

if (is.null(cols)) {
selected_cols = c(col_roles$target, col_roles$feature)
query_cols = cols = c(col_roles$target, col_roles$feature)
} else {
if ("cols" %in% subset_active) {
assert_subset(cols, c(col_roles$target, col_roles$feature))
}
selected_cols = cols
query_cols = cols
}

order = col_roles$order
if (length(order)) {
reorder_rows = length(col_roles$order) > 0L && isTRUE(ordered)
if (reorder_rows) {
if (data_format != "data.table") {
stopf("Ordering only supported for data_format 'data.table'")
}
order_cols = setdiff(order, selected_cols)
selected_cols = union(selected_cols, order_cols)
} else {
order_cols = character()
query_cols = union(query_cols, col_roles$order)
}

data = self$backend$data(rows = selected_rows, cols = selected_cols, data_format = data_format)
data = self$backend$data(rows = rows, cols = query_cols, data_format = data_format)

if (length(selected_cols) && nrow(data) != length(selected_rows)) {
stopf("DataBackend did not return the rows correctly: %i requested, %i received", length(selected_rows), nrow(data))
if (length(query_cols) && nrow(data) != length(rows)) {
stopf("DataBackend did not return the queried rows correctly: %i requested, %i received", length(rows), nrow(data))
}

if (length(selected_rows) && ncol(data) != length(selected_cols)) {
stopf("DataBackend did not return the cols correctly: %i requested, %i received", length(selected_cols), ncol(data))
# Sanity check on the backend result: every queried column must come back.
# The query may include order columns on top of `cols`, so the count reported
# as "requested" has to be that of `query_cols`, matching the condition.
if (length(rows) && ncol(data) != length(query_cols)) {
  stopf("DataBackend did not return the queried cols correctly: %i requested, %i received", length(query_cols), ncol(data))
}

if (length(order)) {
setorderv(data, order)[]
if (length(order_cols)) {
data[, (order_cols) := NULL][]
}
if (reorder_rows) {
setorderv(data, col_roles$order)[]
data = remove_named(data, setdiff(col_roles$order, cols))
}

return(data)
Expand Down
7 changes: 5 additions & 2 deletions R/TaskClassif.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,12 @@ TaskClassif = R6Class("TaskClassif",
#' Calls `$data` from parent class [Task] and ensures that levels of the target column
#' are in the right order.
#'
#' @param ordered (`logical(1)`)\cr
#' If `TRUE` (default), data is ordered according to the columns with column role `"order"`.
#'
#' @return Depending on the [DataBackend], but usually a [data.table::data.table()].
data = function(rows = NULL, cols = NULL, data_format = "data.table") {
data = task_data(self, private, rows, cols, data_format)
data = function(rows = NULL, cols = NULL, data_format = "data.table", ordered = TRUE) {
data = task_data(self, rows, cols, data_format, ordered)
fix_factor_levels(data, set_names(list(self$class_names), self$target_names))
},

Expand Down
18 changes: 15 additions & 3 deletions man/Task.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion man/TaskClassif.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions tests/testthat/test_Task.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,16 @@ test_that("Rows return ordered", {
x = task$data()
expect_integer(x$t, sorted = TRUE, any.missing = FALSE)

x = task$data(ordered = FALSE)
expect_true(is.unsorted(x$t))

x = task$data(rows = sample(nrow(data), 50))
expect_integer(x$t, sorted = TRUE, any.missing = FALSE)

tab = task$order
expect_data_table(tab, ncols = 2, nrows = task$nrow)
expect_set_equal(names(tab), c("row_id", "order"))
expect_integer(rev(tab$order), sorted = TRUE)
})

test_that("Rows return ordered with multiple order cols", {
Expand Down

0 comments on commit bc1f7d4

Please sign in to comment.