sl3 Learners
This guide describes the process of implementing a learner class for
a new machine learning algorithm. By writing a learner class for your
favorite machine learning algorithm, you will be able to use it in all
the places you could otherwise use any other sl3 learners, including
Pipelines, Stacks, and Super Learner. We have done our best to
streamline the process of creating new sl3 learners.
Before diving into defining a new learner, it will likely be helpful
to read some background material. If you haven’t already read it, the
“Modern Machine Learning in R” vignette is a good introduction to the
sl3 package and its underlying architecture. The R6 documentation will
help you understand how R6 classes are defined. In addition, the help
files for sl3_Task and Lrnr_base are good resources for how those
objects can be used. If you’re interested in defining learners that fit
sub-learners, reading the documentation of the delayed package will be
helpful.
In the following sections, we introduce and review a template for a
new sl3 learner, describing the sections that can be used to define
your new learner. This is followed by a discussion of the important
task of documenting and testing your new learner. Finally, we conclude
by explaining how you can add your learner to sl3 so that others may
make use of it.
sl3 provides a template of a learner for use in defining new learners.
You can make a copy of the template to work on by invoking
write_learner_template.
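A minimal sketch of such a call follows. It assumes that sl3 is installed and that write_learner_template takes the destination file path as its argument; the file name Lrnr_my_learner.R is a hypothetical placeholder, written to a temporary directory so that nothing in your project is overwritten.

```r
# Sketch only: assumes the sl3 package is installed.
# The destination file name below is a hypothetical placeholder.
template_path <- file.path(tempdir(), "Lrnr_my_learner.R")

if (requireNamespace("sl3", quietly = TRUE)) {
  # copy the learner template that ships with sl3 to template_path
  sl3::write_learner_template(template_path)
}
```

In practice you would likely point the path at the R/ directory of the package you are working in.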
Let’s take a look at that template:
##' Template of a \code{sl3} Learner.
##'
##' This is a template for defining a new learner.
##' This can be copied to a new file using \code{\link{write_learner_template}}.
##' The remainder of this documentation is an example of how you might write
##' documentation for your new learner.
##'
##' This learner uses \code{\link[my_package]{my_ml_fun}} from \code{my_package}
##' to fit my favorite machine learning algorithm.
##'
##' @docType class
##'
##' @importFrom R6 R6Class
##'
##' @export
##'
##' @keywords data
##'
##' @return A learner object inheriting from \code{\link{Lrnr_base}} with
##' methods for training and prediction. For a full list of learner
##' functionality, see the complete documentation of \code{\link{Lrnr_base}}.
##'
##' @format An \code{\link[R6]{R6Class}} object inheriting from
##' \code{\link{Lrnr_base}}.
##'
##' @family Learners
##'
##' @section Parameters:
##' - \code{param_1="default_1"}: This parameter does something and is not
##' already specified in the task.
##' - \code{param_2="default_2"}: This parameter does something else and is
##' not already specified in the task.
##' - \code{...}: Other parameters passed directly to
##' \code{\link[my_package]{my_ml_fun}}. See its documentation for details.
##' Also, any additional parameters that can be considered by
##' \code{\link{Lrnr_base}}.
##'
##' @examples
##' include an example here, e.g. see \code{\link{Lrnr_ranger}}'s example.
##'
Lrnr_template <- R6Class(
  classname = "Lrnr_template",
  inherit = Lrnr_base,
  portable = TRUE,
  class = TRUE,
  # Above, you should change Lrnr_template (in both the object name and the classname argument)
  # to a name that indicates what your learner does
  public = list(
    # you can define default parameter values here
    # if possible, your learner should define defaults for all required parameters
    initialize = function(param_1 = "default_1", param_2 = "default_2", ...) {
      # this captures all parameters to initialize and saves them as self$params
      params <- args_to_list()
      super$initialize(params = params, ...)
    },
    # you can define public functions that allow your learner to do special things here
    # for instance glm learner might return prediction standard errors
    special_function = function(arg_1) {
    }
  ),
  private = list(
    # list properties your learner supports here.
    # Use sl3_list_properties() for a list of options
    .properties = c(""),
    # .train takes task data and returns a fit object that can be used to generate predictions
    .train = function(task) {
      # generate an argument list from the parameters that were
      # captured when your learner was initialized.
      # this allows users to pass arguments directly to your ml function
      args <- self$params
      # get outcome variable type
      # preferring learner$params$outcome_type first, then task$outcome_type
      outcome_type <- self$get_outcome_type(task)
      # should pass something on to your learner indicating outcome_type
      # e.g. family or objective
      # add task data to the argument list
      # what these arguments are called depends on the learner you are wrapping
      args$x <- as.matrix(task$X_intercept)
      args$y <- outcome_type$format(task$Y)
      # only add arguments on weights and offset
      # if those were specified when the task was generated
      if (task$has_node("weights")) {
        args$weights <- task$weights
      }
      if (task$has_node("offset")) {
        args$offset <- task$offset
      }
      if (task$has_node("id")) {
        args$id <- task$id
      }
      # call a function that fits your algorithm
      # with the argument list you constructed
      fit_object <- call_with_args(my_ml_fun, args)
      # return the fit object, which will be stored
      # in a learner object and returned from the call
      # to learner$train
      return(fit_object)
    },
    # .predict takes a task and returns predictions from that task
    .predict = function(task = NULL) {
      # these fields are available for use when generating predictions
      self$training_task
      self$training_outcome_type
      self$fit_object
      predictions <- predict(self$fit_object, task$X)
      return(predictions)
    },
    # list any packages required for your learner here.
    .required_packages = c("my_package")
  )
)
The template has comments indicating where details specific to the learner you’re trying to implement should be filled in. In the next section, we will discuss those details further.
At the top of the template, we define an object Lrnr_template and set
classname = "Lrnr_template". You should modify these to match the name
of your new learner, which should also match the name of the
corresponding R file. Note that the name should be prefixed by Lrnr_
and use snake_case.
public$initialize
This function defines the constructor for your learner, and it stores
the arguments (if any) provided when a user calls
make_learner(Lrnr_your_learner, ...). You can also provide default
parameter values, just as the template does with
param_1 = "default_1" and param_2 = "default_2". All parameters used by
your newly defined learners should have defaults whenever possible.
This will allow users to use your learner without having to figure out
what reasonable parameter values might be. Parameter values should be
documented; see the section below on documentation for details.
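The idea of a constructor that fills in defaults and records every argument it receives can be illustrated with a toy example in base R. This is not sl3 code and does not use args_to_list; the function new_toy_learner is a hypothetical stand-in meant only to show the pattern.

```r
# Toy illustration in base R (not sl3 code): a constructor with defaults that
# records every argument it receives, similar in spirit to storing the
# constructor arguments as self$params.
new_toy_learner <- function(param_1 = "default_1", param_2 = "default_2", ...) {
  # named arguments (with defaults filled in) plus anything passed via `...`
  params <- c(as.list(environment()), list(...))
  list(params = params)
}

default_learner <- new_toy_learner()
custom_learner <- new_toy_learner(param_2 = "custom", extra = 10)

default_learner$params$param_1 # "default_1"
custom_learner$params$param_2 # "custom"
custom_learner$params$extra # 10
```

Because every parameter has a default, new_toy_learner() works with no arguments at all, which is the behavior recommended above for real learners.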
public$special_function
You can of course define functions for things only your learner can
do. These should be public functions like the special_function defined
in the example. These should be documented; see the section below on
documentation for details.
private$.properties
This field defines properties supported by your learner. This may
include different outcome types that are supported, offsets and weights,
amongst many other possibilities. To see a list of all properties
supported/used by at least one learner, you may invoke
sl3_list_properties():
## [1] "binomial" "categorical" "continuous" "cv"
## [5] "density" "h2o" "ids" "importance"
## [9] "offset" "preprocessing" "sampling" "screener"
## [13] "timeseries" "weights" "wrapper"
private$.required_packages
This field defines other R packages required for your learner to work properly. These will be loaded when an object of your new learner class is initialized.
If you’ve used sl3 before, you may have noticed that while users are
instructed to use learner$train, learner$predict, and learner$chain, to
train, generate predictions, and generate a chained task for a given
learner object, respectively, the template does not implement these
methods. Instead, the template implements private methods called
.train, .predict, and .chain. The specifics of these methods are
explained below; however, it is helpful to first understand how the two
sets of methods are related. At the risk of complicating things
further, it is worth noting that there is actually a third set of
methods (learner$base_train, learner$base_predict, and
learner$base_chain) of which you may not be aware.
So, what happens when a user calls learner$train? That method generates
a delayed object using the delayed_learner_train function, and then
computes that delayed object. In turn, delayed_learner_train defines a
delayed computation that calls base_train, a user-facing function that
can be used to train tasks without using the facilities of the delayed
package. base_train validates the user input, and in turn calls
private$.train. When private$.train returns a fit_object, base_train
takes that fit object, generates a learner fit object, and returns it
to the user.
Each call to learner$train involves three separate training methods:
1. learner$train: trains a learner in a manner that can be parallelized
using delayed, which calls ...
2. ... the user-facing learner$base_train that validates user input,
and which calls ...
3. ... the internal private$.train, which does the actual work of
fitting the learner and returning the fit object.
The logic in the user-facing learner$train and learner$base_train is
defined in the Lrnr_base base class and is shared across all learners.
As such, these methods need not be reimplemented in individual
learners. By contrast, private$.train contains the behavior that is
specific to each individual learner and should be reimplemented at the
level of each individual learner. Since learner$base_train does not use
delayed, it may be helpful to use it when debugging the training code
in a new learner. The program flow used for prediction and chaining is
analogous.
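The layering described above can be sketched with a toy analogue in base R. This is not how sl3 implements these methods: the three functions below are hypothetical stand-ins, and the deferred step is a plain closure rather than an object from the delayed package.

```r
# Toy analogue in base R (not sl3 code) of the three layers of training.
# All names here are hypothetical stand-ins.

# innermost layer: learner-specific fitting logic (the role of private$.train)
.train <- function(task) {
  mean(task$Y)
}

# middle layer: validate input, call .train, and wrap the result
# (the role of learner$base_train)
base_train <- function(task) {
  stopifnot(is.list(task), is.numeric(task$Y))
  fit_object <- .train(task)
  list(fit_object = fit_object, is_trained = TRUE)
}

# outer layer: describe the work as a deferred computation, then run it
# (the role of learner$train, which uses the delayed package in sl3)
train <- function(task) {
  deferred <- function() base_train(task)
  deferred()
}

task <- list(Y = c(1, 2, 3, 6))
fit <- train(task)
fit$fit_object # 3

# calling the middle layer directly skips the deferred step, which can make
# debugging simpler
identical(base_train(task), fit) # TRUE
```

Only the innermost function contains anything specific to the algorithm, which is why a new learner only needs to supply that layer.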
private$.train
This is the main training function, which takes in a task and returns
a fit_object that contains all information needed to generate
predictions. The fit object should not contain more data than is
absolutely necessary, as including excess information will create
needless inefficiencies. Many learner functions (like glm) store one or
more copies of their training data; this uses unnecessary memory and
will hurt learner performance for large sample sizes. Thus, these
copies of the data should be removed from the fit object before it is
returned. You may make use of true_obj_size to estimate the size of
your fit_object. For most learners, fit_object size should not grow
linearly with training sample size. If it does, and this is unexpected,
please try to reduce the size of the fit_object.
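As an illustration of removing stored data, the sketch below fits stats::glm on simulated data, drops three components that hold copies of the training data, and checks that predictions on new data are unchanged. The simulated data and the choice of components to drop are assumptions made for this example; other fitting functions store their data differently. The base R function object.size is used here as a rough size estimate in place of true_obj_size.

```r
# Illustration with stats::glm (base R): drop stored copies of the training
# data and confirm that predictions on new data are unchanged.
set.seed(1)
n <- 5000
train_data <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
train_data$y <- rbinom(n, 1, plogis(0.5 * train_data$x1 - 0.25 * train_data$x2))

fit <- glm(y ~ x1 + x2, data = train_data, family = binomial())

# remove components that hold copies of the training data
slim_fit <- fit
slim_fit$data <- NULL
slim_fit$model <- NULL
slim_fit$y <- NULL

new_data <- data.frame(x1 = c(-1, 0, 1), x2 = c(0.5, 0, -0.5))
all.equal(
  predict(fit, newdata = new_data, type = "response"),
  predict(slim_fit, newdata = new_data, type = "response")
)

# rough size comparison before and after slimming
print(object.size(fit), units = "Kb")
print(object.size(slim_fit), units = "Kb")
```

The slimmed fit still carries components whose size depends on the sample size (residuals and fitted values, for example), so this is a starting point rather than a complete reduction.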
Most of the time, the learner you are implementing will be fit using
a function that already exists elsewhere. We’ve built some tools to
facilitate passing parameter values directly to such functions. The
private$.train function in the template uses a common pattern: it
builds up an argument list starting with the parameter values and using
data from the task, and it then uses call_with_args to call my_ml_fun
with that argument list. It’s not required that learners use this
pattern, but it will be helpful in the common case where the learner is
simply wrapping an underlying my_ml_fun.
By default, call_with_args will pass only those arguments in the
argument list that are matched by the definition of the function that
it is calling. This allows the learner to silently drop irrelevant
parameters from the call to my_ml_fun. Some learners either capture
important arguments using dot arguments (...) or pass important
arguments through such dot arguments on to a secondary function. Both
of these cases can be handled using the other_valid and keep_all
options to call_with_args. The former allows you to list other valid
arguments and the latter disables argument filtering altogether.
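The filtering behavior described above can be sketched in base R. The function call_with_matched_args below is a hypothetical illustration of the idea and is not the sl3 implementation of call_with_args; its other_valid and keep_all arguments only mimic the options named in the text, and my_ml_fun is a stand-in for the function being wrapped.

```r
# Hypothetical base R sketch of argument filtering; this is NOT the sl3
# implementation of call_with_args, only an illustration of the idea.
call_with_matched_args <- function(fun, args, other_valid = character(0),
                                   keep_all = FALSE) {
  if (!keep_all) {
    # keep arguments named in the function definition, plus any names
    # explicitly declared valid by the caller
    valid_names <- union(names(formals(fun)), other_valid)
    args <- args[names(args) %in% valid_names]
  }
  do.call(fun, args)
}

# a stand-in for my_ml_fun that collects whatever arrives through `...`
my_ml_fun <- function(x, y, ...) {
  list(n = length(y), extras = list(...))
}

args <- list(x = 1:3, y = c(0, 1, 0), param_1 = "default_1", tol = 1e-4)

# default: param_1 and tol are dropped because my_ml_fun does not name them
fit_default <- call_with_matched_args(my_ml_fun, args)
names(fit_default$extras) # NULL

# other_valid: tol is allowed through `...`; param_1 is still dropped
fit_other <- call_with_matched_args(my_ml_fun, args, other_valid = "tol")
names(fit_other$extras) # "tol"

# keep_all: no filtering, so every argument is passed
fit_all <- call_with_matched_args(my_ml_fun, args, keep_all = TRUE)
names(fit_all$extras) # "param_1" "tol"
```

Filtering by default means a learner can keep sl3-level parameters in its parameter list without the wrapped function failing on unknown arguments.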
private$.predict
This is the main prediction function. It takes in a task and generates
predictions for that task using the fit_object. If those predictions
are 1-dimensional, they will be coerced to a vector by base_predict.
private$.chain
This is the main chaining function. It takes in a task and generates
a chained task (based on the input task) using the given fit_object. If
this method is not implemented, your learner will use the default
chaining behavior, which is to return a new task where the covariates
are defined as your learner’s predictions for the current task.
Generally speaking, the above sections will be all that’s required
for implementing a new learner in the sl3 framework. In some cases, it
may be desirable to define learners that have “sub-learners” or other
learners on which they depend. Examples of such learners are Stack,
Pipeline, Lrnr_cv, and Lrnr_sl. In order to parallelize the fitting of
these sub-learners using delayed, these learners implement a
specialized private$.train_sublearners method that calls
delayed_learner_train on their sub-learners, returning a single delayed
object that, when evaluated, returns all relevant fit objects from
these sub-learners. The result of that call is then passed as a second
argument to their private$.train method, which now has the function
prototype private$.train(task, trained_sublearners). Learners defined
in such a manner usually train much faster, since their sub-learners
can be fit in parallel; the predict and chain methods are not currently
parallelized in this way, although this is subject to change in the
future.
If, like these learners, your learner depends on sub-learners, you have two options:
1. Implement private$.train as discussed above, being sure to call
sublearner$base_train and not sublearner$train, to avoid nesting calls
to delayed, which may result in sub-optimal performance.
2. Implement private$.train_sublearners(task) and
private$.train(task, trained_sublearners), to parallelize sub-learners
using delayed. We suggest reviewing the implementations of Stack,
Pipeline, Lrnr_cv, and Lrnr_sl to get a better understanding of how to
implement parallelized sub-learners.
In either case, you should be careful to call sublearner$base_predict
and sublearner$base_chain, instead of sublearner$predict and
sublearner$chain, except in the context of the
private$.train_sublearners function, where you should use
delayed_learner_fit_predict and delayed_learner_fit_chain.
If you want other people to be able to use your learner, you will need to document and provide unit tests for it. The above template has example documentation, written in the roxygen format. Most importantly, you should describe what your learner does, reference any external code it uses, and document any parameters and public methods defined by it.
It’s also important to test your learner. You should write unit tests
to verify that your learner can train and predict on new data, and, if
applicable, generate a chained task. It might also be a good idea to
use the risk function in sl3 to verify your learner’s performance on a
sample dataset. That way, if you change your learner and performance
drops, you know something may have gone wrong.
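A minimal sketch of such a check is shown below. It assumes that sl3 is installed and uses its cpp_imputed example data; Lrnr_glm stands in for your new learner, and the covariates are an arbitrary choice made for illustration. The checks use stopifnot so the sketch needs no testing package, but the same assertions could be placed in whatever unit testing framework your project uses.

```r
# Sketch of basic learner checks; assumes the sl3 package is installed.
# Lrnr_glm stands in for your new learner, and the covariate choice is
# arbitrary.
if (requireNamespace("sl3", quietly = TRUE)) {
  library(sl3)
  data(cpp_imputed)

  task <- sl3_Task$new(
    cpp_imputed,
    covariates = c("apgar1", "apgar5", "parity", "gagebrth", "mage"),
    outcome = "haz"
  )

  learner <- Lrnr_glm$new()
  fit <- learner$train(task)
  predictions <- fit$predict(task)

  # expect one finite prediction per observation
  stopifnot(
    length(predictions) == nrow(cpp_imputed),
    all(is.finite(predictions))
  )
}
```

A fuller test would also predict on a task built from held-out data and, where the learner supports it, exercise chaining.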
sl3
Once you’ve implemented your new learner (and made sure that it has
quality documentation and unit tests), please consider adding it to the
sl3 project. This will make it possible for other sl3 users to use and
build on your work. Make sure to add any R packages listed in
.required_packages to the Suggests: field of the DESCRIPTION file of
the sl3 package. Once this is done, please submit a Pull Request to the
sl3 package on GitHub to request that your learner be added. If you’ve
never made a “Pull Request” before, see this helpful guide:
https://yangsu.github.io/pull-request-tutorial/.
From the sl3 team, thanks for your interest in extending sl3!