
Predicts outcomes with a multi-task or transfer learning regression model.

Usage

# S3 method for class 'sparselink'
predict(object, newx, weight = NULL, ...)

Arguments

object

object of class "sparselink" (as returned by the function sparselink)

newx

features: for multi-task learning, a matrix with \(n\) rows (samples) and \(p\) columns (variables); for transfer learning, a list of \(q\) matrices, where the \(k\)-th matrix has \(n_k\) rows (samples) and \(p\) columns (variables), for \(k\) in \(1,\ldots,q\)
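The two layouts accepted by newx can be illustrated in base R. This is a minimal sketch: the object names newx_multi and newx_trans, the dimensions, and the random entries are arbitrary placeholders. In practice, the \(p\) columns should correspond to the variables used when fitting the model.

```r
# Illustrative shapes only: random features, dimensions chosen arbitrarily.
set.seed(1)
p <- 5  # number of variables, shared by all targets/datasets

# Multi-task learning: a single matrix with n rows and p columns;
# all q targets are predicted for the same n samples.
n <- 10
newx_multi <- matrix(rnorm(n * p), nrow = n, ncol = p)

# Transfer learning: a list of q matrices, one per dataset;
# the row counts n_k may differ, but every matrix has the same p columns.
n_k <- c(8, 12, 6)
newx_trans <- lapply(n_k, function(nk) matrix(rnorm(nk * p), nrow = nk, ncol = p))

dim(newx_multi)                   # rows and columns of the multi-task input
sapply(newx_trans, nrow)          # n_k for each dataset
unique(sapply(newx_trans, ncol))  # a single value: p
```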

weight

hyperparameters for scaling the external and internal weights: numeric vector of length 2, where the first entry scales the external weights (prior coefficients from source data) and the second entry scales the internal weights (prior coefficients from target data); the selected values must be among the candidate values; default: NULL (use the cross-validated weights)

...

(not applicable)

Value

Returns predicted values or, for binary outcomes, predicted probabilities. The output is a list of \(q\) column vectors, where the \(k\)-th vector has length \(n_k\), for \(k\) in \(1,\ldots,q\) (in multi-task learning, \(n_k=n\) for all \(k\)). Each vector corresponds to one target (multi-task learning) or one dataset (transfer learning).
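Because the result is a list rather than a single matrix, performance is naturally computed element-wise, one target or dataset at a time. A minimal sketch, assuming hypothetical objects y_hat (standing in for the list returned by predict) and y_obs (observed outcomes stored as a list in the same order); the mean squared error is used here purely as an illustrative metric.

```r
# Hypothetical stand-ins: y_hat mimics the returned list (one column vector
# per target/dataset), y_obs holds the corresponding observed outcomes.
y_hat <- list(matrix(c(1.0, 2.0, 3.0), ncol = 1),
              matrix(c(0.5, 0.5), ncol = 1))
y_obs <- list(c(1.5, 2.0, 2.0),
              c(0.0, 1.0))

# Mean squared error for each target/dataset, computed element-wise;
# as.vector() drops the column-vector dimension before subtracting.
mse <- mapply(function(pred, obs) mean((as.vector(pred) - obs)^2),
              y_hat, y_obs)
mse
```

For predicted probabilities with 0/1 outcomes, the same computation gives the Brier score.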

References

Armin Rauschenberger, Petr N. Nazarov, and Enrico Glaab (2025). "Estimating sparse regression models in multi-task learning and transfer learning through adaptive penalisation". Under revision. https://hdl.handle.net/10993/63425

See also

Use sparselink to fit the model and coef to extract coefficients.

Examples

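library(sparselink)
# simulate training and test data
family <- "gaussian"
type <- "multiple" # try "multiple" (multi-task learning) or "transfer" (transfer learning)
if (type == "multiple") {
  data <- sim_data_multi(family = family)
} else if (type == "transfer") {
  data <- sim_data_trans(family = family)
}
# fit the model on the training data
object <- sparselink(x = data$X_train, y = data$y_train, family = family)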
#> mode: multi-target learning, alpha.init=0.95 (elastic net), alpha=1 (lasso)
#> Warning: Option grouped=FALSE enforced in cv.glmnet, since < 3 observations per fold
#> Warning: Option grouped=FALSE enforced in cv.glmnet, since < 3 observations per fold
#> Warning: Option grouped=FALSE enforced in cv.glmnet, since < 3 observations per fold
# predict outcomes for the test samples (list with one element per target or dataset)
y_hat <- predict(object = object, newx = data$X_test)