Fitting the Highly Adaptive Lasso with hal9001

Nima Hejazi and Jeremy Coyle

2021-01-21

Introduction

The highly adaptive Lasso (HAL) is a flexible machine learning algorithm that nonparametrically estimates a function from the available data. It does this by embedding the input observations and covariates in an extremely high-dimensional space, i.e., by generating a large set of basis functions from the data. For an input data matrix of \(n\) observations and \(d\) covariates, up to \(n \cdot (2^d - 1)\) basis functions are generated: one for each observation and each of the \(2^d - 1\) non-empty subsets of covariates. The Lasso is then used to select a subset of these basis functions. The hal9001 R package provides an efficient implementation of this routine. It relies on the glmnet R package for compatibility with the canonical Lasso implementation, and also provides a faster custom C++ routine for fitting the Lasso to an input matrix composed of indicator functions. For detailed theoretical descriptions of the highly adaptive Lasso and its optimality properties, see Benkeser and van der Laan (2016), (???), and van der Laan (2017).
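A small base-R sketch makes the construction concrete. For each non-empty subset \(s\) of covariates and each observation \(i\), the basis function is the product of indicators \(I(x_j \ge X_{ij})\) over \(j \in s\). This matches the `I(j >= c)` terms that appear in the model summary later in this vignette. The helper `make_hal_basis()` is hypothetical and written for clarity, not speed. hal9001's own basis construction is different and much faster, so treat this only as an illustration of the counting argument.

```r
# Hypothetical helper (not part of hal9001): build the indicator basis
# phi_{s,i}(x) = prod_{j in s} I(x_j >= X[i, j]) for every non-empty
# subset s of the d covariates and every observation i.
make_hal_basis <- function(X) {
  n <- nrow(X)
  d <- ncol(X)
  # all 2^d - 1 non-empty subsets of the covariate indices 1..d
  subsets <- unlist(
    lapply(seq_len(d), function(k) combn(d, k, simplify = FALSE)),
    recursive = FALSE
  )
  cols <- vector("list", length(subsets) * n)
  idx <- 0
  for (s in subsets) {
    for (i in seq_len(n)) {
      # knot point: observation i's values on the covariates in s
      knot <- X[i, s]
      ge <- sweep(X[, s, drop = FALSE], 2, knot, FUN = ">=")
      idx <- idx + 1
      # 1 when every covariate in s is at or above its knot value, else 0
      cols[[idx]] <- as.numeric(rowSums(ge) == length(s))
    }
  }
  do.call(cbind, cols)
}

set.seed(1)
X <- matrix(rnorm(5 * 3), nrow = 5, ncol = 3)
H <- make_hal_basis(X)
dim(H)  # 5 rows and 5 * (2^3 - 1) = 35 columns

# columns that are identical on these data carry the same information;
# count how many distinct columns remain after dropping duplicates
n_unique <- sum(!duplicated(t(H)))
n_unique
```

Duplicate columns arise naturally. For example, the knot at the smallest value of any single covariate produces an all-ones column. This is why the fit below reports a separate `remove_duplicates` stage.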


Preliminaries

# simulation constants
set.seed(467392)
n_obs <- 200
n_covars <- 3

# make some training data
x <- replicate(n_covars, rnorm(n_obs))
y <- sin(x[, 1]) + sin(x[, 2]) + rnorm(n_obs, mean = 0, sd = 0.2)

# make some testing data
test_x <- replicate(n_covars, rnorm(n_obs))
test_y <- sin(test_x[, 1]) + sin(test_x[, 2]) + rnorm(n_obs, mean = 0, sd = 0.2)

Let’s look at the simulated data:

head(x)
##             [,1]       [,2]       [,3]
## [1,]  2.44102981 -0.6441252 -0.4632021
## [2,] -1.21932335 -0.9481608  2.6358511
## [3,] -0.40613567  0.4337314 -0.2226760
## [4,] -1.09760477 -1.5845711 -1.0496038
## [5,]  0.23710498  0.1261754  1.4717507
## [6,]  0.06810091 -0.2623992 -0.7534596
head(y)
## [1]  0.31596199 -1.74749349 -0.08198272 -2.13963686  0.42902938 -0.12824651

Using the Highly Adaptive Lasso

library(hal9001)
## Loading required package: Rcpp
## hal9001 v0.2.7: The Scalable Highly Adaptive Lasso

Fitting the model: glmnet

HAL uses the popular glmnet R package for the lasso step:

## [1] "Without your space helmet, Dave. You're going to find that rather difficult."
##                   user.self sys.self elapsed user.child sys.child
## enumerate_basis       0.003    0.000   0.002          0         0
## design_matrix         0.005    0.000   0.006          0         0
## reduce_basis          0.000    0.000   0.000          0         0
## remove_duplicates     0.007    0.000   0.007          0         0
## lasso                 0.945    0.009   0.953          0         0
## total                 0.961    0.009   0.969          0         0
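The timings show that nearly all of the fitting time (0.953 of 0.969 seconds elapsed) is spent in the lasso step. For a continuous outcome under the Gaussian family, the standard glmnet lasso problem over the basis functions \(\phi_1, \dots, \phi_p\) can be written as follows. Here \(p\) is the number of basis functions retained after duplicates are removed, and \(\lambda \ge 0\) is a penalty parameter, typically selected by cross-validation. This is a sketch of the textbook formulation; the exact options hal9001 passes to glmnet are not shown in this vignette.

```latex
(\hat{\beta}_0, \hat{\beta}) = \operatorname*{arg\,min}_{\beta_0, \beta}
  \frac{1}{2n} \sum_{i=1}^{n}
  \Big( Y_i - \beta_0 - \sum_{j=1}^{p} \beta_j \phi_j(X_i) \Big)^2
  + \lambda \sum_{j=1}^{p} \lvert \beta_j \rvert,
\qquad
\hat{f}(x) = \hat{\beta}_0 + \sum_{j=1}^{p} \hat{\beta}_j \phi_j(x).
```

Basis functions whose \(\hat{\beta}_j\) is exactly zero drop out of \(\hat{f}\). The \(\ell_1\) penalty is what makes most of the generated basis functions vanish from the final fit.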

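The raw fit object can be examined directly, but its slots are usually large, which makes quick inspection difficult. Instead, we recommend the summary method, which provides an interpretable table of the basis functions with non-zero coefficients. In the term column, I(j >= c) is the indicator that covariate j (the j-th column of x) is at least c. A product of such indicators is an interaction basis function. Terms joined by OR are basis functions that take identical values on the training data and therefore share a single coefficient.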
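To read a row of the summary table, consider I(2 >= 1.3571)*I(3 >= 1.4599). It equals 1 when the second covariate is at least 1.3571 and the third is at least 1.4599, and 0 otherwise. Because every basis function is 0 or 1, the fitted value at a point is the intercept plus the coefficients of the terms that equal 1 there. The cut-points are printed rounded to four decimals; for example, I(2 >= -0.6441) corresponds to x[1, 2] = -0.6441252 in the head(x) output above. The sketch below uses a hypothetical helper, `eval_term()`, and only three rows of the table (the intercept and rows 2 and 11), purely to illustrate the arithmetic. It is not the fitted HAL prediction, which uses every term.

```r
# Hypothetical helper (not part of hal9001): evaluate a product of indicators,
# e.g. I(2 >= 1.3571)*I(3 >= 1.4599) means x[, 2] >= 1.3571 AND x[, 3] >= 1.4599
eval_term <- function(X, cols, cutoffs) {
  ge <- sweep(X[, cols, drop = FALSE], 2, cutoffs, FUN = ">=")
  as.numeric(rowSums(ge) == length(cols))
}

# Intercept plus two terms copied from the summary table (rows 1, 2 and 11).
# Using only these three rows is illustrative and NOT the model's prediction.
partial_prediction <- function(X) {
  -7.812231e-01 +
    1.968808e-01 * eval_term(X, cols = 1, cutoffs = -0.1725) +
    1.247913e-01 * eval_term(X, cols = c(2, 3), cutoffs = c(1.3571, 1.4599))
}

new_x <- rbind(
  c(0, 2, 2),   # satisfies both terms
  c(-1, 2, 0)   # satisfies neither term
)
partial_prediction(new_x)  # approximately -0.4596 and -0.7812
```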

##              coef
##  1: -7.812231e-01
##  2:  1.968808e-01
##  3:  1.632766e-01
##  4:  1.483639e-01
##  5:  1.462592e-01
##  6:  1.456650e-01
##  7:  1.441887e-01
##  8:  1.367758e-01
##  9:  1.348308e-01
## 10:  1.324568e-01
## 11:  1.247913e-01
## 12:  1.171875e-01
## 13:  1.146518e-01
## 14:  1.125166e-01
## 15:  1.120689e-01
## 16:  1.098491e-01
## 17:  1.038959e-01
## 18:  1.000336e-01
## 19:  8.276983e-02
## 20:  7.750062e-02
## 21:  7.317027e-02
## 22:  7.296611e-02
## 23:  6.535017e-02
## 24:  6.344607e-02
## 25:  6.300968e-02
## 26:  5.772281e-02
## 27:  5.638586e-02
## 28:  5.585372e-02
## 29:  5.433344e-02
## 30:  5.086507e-02
## 31:  5.041330e-02
## 32:  4.925794e-02
## 33:  4.846195e-02
## 34:  4.137366e-02
## 35:  3.869325e-02
## 36:  3.812788e-02
## 37:  3.652213e-02
## 38:  3.631055e-02
## 39:  3.597019e-02
## 40:  3.377831e-02
## 41:  3.351076e-02
## 42:  2.921166e-02
## 43:  2.851651e-02
## 44:  2.676665e-02
## 45:  2.459752e-02
## 46:  2.402039e-02
## 47:  2.344251e-02
## 48:  2.154657e-02
## 49:  1.971008e-02
## 50:  1.873740e-02
## 51:  1.842638e-02
## 52:  1.760974e-02
## 53:  1.755002e-02
## 54:  1.557752e-02
## 55:  1.429619e-02
## 56:  1.267729e-02
## 57:  1.142536e-02
## 58:  7.771376e-03
## 59:  3.178810e-03
## 60:  2.607689e-03
## 61:  2.534992e-03
## 62:  1.790010e-03
## 63:  1.617822e-03
## 64:  1.424163e-03
## 65:  2.166444e-04
## 66:  8.432367e-05
## 67: -8.648258e-05
## 68: -2.281031e-03
## 69: -7.607974e-03
## 70: -2.909367e-02
## 71: -3.439369e-02
## 72: -3.944700e-02
## 73: -4.035840e-02
## 74: -9.726592e-02
## 75: -1.056429e-01
## 76: -1.466408e-01
## 77: -2.054773e-01
## 78: -7.314642e-01
##              coef
##     term
##  1: Intercept
##  2: I(1 >= -0.1725)
##  3: I(1 >= 0.3807)
##  4: I(2 >= 0.119)
##  5: I(1 >= -0.5685)
##  6: I(2 >= -0.1015)
##  7: I(1 >= 0.1386)
##  8: I(2 >= 1.0291)
##  9: I(2 >= -0.6441)
## 10: I(2 >= 0.6786)
## 11: I(2 >= 1.3571)*I(3 >= 1.4599)
## 12: I(1 >= 0.4752)
## 13: I(2 >= -0.9482)
## 14: I(1 >= -0.5113)
## 15: I(2 >= -0.1938)
## 16: I(2 >= -0.782)
## 17: I(2 >= -0.9921)
## 18: I(1 >= 0.8989)*I(3 >= -1.0482)
## 19: I(2 >= -0.313)
## 20: I(1 >= 0.0911)*I(2 >= 0.3539)
## 21: I(1 >= 0.0345)
## 22: I(1 >= 0.0855)*I(2 >= 0.4067)
## 23: I(1 >= -0.2045)
## 24: I(1 >= 0.634)
## 25: I(2 >= -0.6718)
## 26: I(2 >= 0.6959)
## 27: I(1 >= -0.4061)
## 28: I(1 >= 1.182)*I(2 >= 0.5866)*I(3 >= -0.5114)
## 29: I(1 >= -0.8398)
## 30: I(2 >= 0.1844)
## 31: I(1 >= 0.747)
## 32: I(1 >= -0.908)
## 33: I(2 >= 0.3545)
## 34: I(1 >= -0.1378)*I(3 >= -1.5644)
## 35: I(2 >= 0.6488)*I(3 >= -1.0124)
## 36: I(2 >= -0.7863)
## 37: I(2 >= 1.2184)*I(3 >= -0.6318)
## 38: I(2 >= 0.0232)
## 39: I(2 >= 0.2092)
## 40: I(2 >= 0.4512)
## 41: I(2 >= 0.2844)
## 42: I(1 >= -0.4766)
## 43: I(2 >= 0.8281)
## 44: I(2 >= 0.1737)*I(3 >= 1.8577)
## 45: I(1 >= 1.1353)*I(3 >= -1.0124)
## 46: I(2 >= -1.2687)
## 47: I(1 >= 0.747)*I(3 >= -1.753)
## 48: I(2 >= 0.1737)
## 49: I(1 >= -0.3937)*I(2 >= -0.782)
## 50: I(1 >= 0.3645)
## 51: I(1 >= 1.2006)*I(3 >= -1.011)  OR  I(1 >= 1.2006)*I(2 >= -1.1474)*I(3 >= -1.011)
## 52: I(1 >= -0.8452)*I(3 >= -0.628)
## 53: I(1 >= -0.3879)
## 54: I(1 >= 0.1485)*I(2 >= 0.9348)
## 55: I(1 >= 1.2523)*I(2 >= -0.0958)*I(3 >= -0.6158)
## 56: I(2 >= -0.3809)
## 57: I(2 >= 0.0444)
## 58: I(2 >= 0.3237)
## 59: I(1 >= 0.1302)
## 60: I(1 >= 0.747)*I(2 >= -0.1954)*I(3 >= -1.753)
## 61: I(1 >= 0.1921)
## 62: I(2 >= 0.738)
## 63: I(1 >= 0.634)*I(3 >= -0.8881)
## 64: I(1 >= -0.8398)*I(3 >= -1.5209)
## 65: I(1 >= -0.3937)
## 66: I(1 >= -0.4836)
## 67: I(2 >= -2.1625)*I(3 >= 0.3336)
## 68: I(1 >= -0.7038)*I(3 >= 0.3925)
## 69: I(2 >= -2.1828)
## 70: I(1 >= 1.7813)*I(2 >= 0.4792)*I(3 >= -0.9682)  OR  I(1 >= 2.0998)*I(2 >= 0.3074)*I(3 >= -1.3211)
## 71: I(1 >= 1.8165)*I(3 >= 0.7568)  OR  I(1 >= 1.8717)*I(3 >= 0.4161)  OR  I(1 >= 1.9814)*I(3 >= 0.0149)  OR  I(1 >= 2.2918)*I(3 >= 0.0013)  OR  I(1 >= 1.5647)*I(2 >= 1.3684)*I(3 >= 0.3794)  OR  I(1 >= 1.8165)*I(2 >= 0.3545)*I(3 >= 0.7568)  OR  I(1 >= 1.8717)*I(2 >= 0.2858)*I(3 >= 0.4161)  OR  I(1 >= 1.9814)*I(2 >= -0.7717)*I(3 >= 0.0149)  OR  I(1 >= 2.2918)*I(2 >= -1.0968)*I(3 >= 0.0013)
## 72: I(1 >= -2.1579)*I(3 >= 0.3054)
## 73: I(1 >= -2.3015)*I(2 >= 2.1782)
## 74: I(2 >= 2.0022)
## 75: I(2 >= 2.1818)
## 76:                                                                                                                                                                                                                                                             I(1 >= 2.441)  OR  I(1 >= 2.441)*I(2 >= -0.6441)  OR  I(1 >= 2.441)*I(3 >= -0.4632)  OR  I(1 >= 2.441)*I(2 >= -0.6441)*I(3 >= -0.4632)
## 77:                                                                                                                                                                                                                                                                                                                                                                                    I(2 >= -2.3508)
## 78:                                                                                                                                                                                                                                                                                                                                                                                    I(1 >= -3.0511)
##                                                                                                                                                                                                                                                                                                                                                                                                   term

Reducing basis functions

As described in Benkeser and van der Laan (2016), the HAL algorithm operates by first constructing a set of basis functions and subsequently fitting a Lasso model with this set of basis functions as the design matrix. Two approaches are considered for reducing this set of basis functions:

1. Removing duplicated basis functions (done by default in the fit_hal function).
2. Removing basis functions that equal one for only a small fraction of the observations; a good rule of thumb is to let this minimum fraction scale with \(\frac{1}{\sqrt{n}}\).
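To make the two reduction steps concrete, here is a minimal base-R sketch on data simulated as in the preliminaries. It is not hal9001's internal implementation: the helper make_main_term_basis is hypothetical and builds only main-term (single-covariate) indicators I(x_j >= c), with one knot c per observed value, whereas HAL also includes interaction terms. The filtering step is applied before duplicate removal, mirroring the order of the steps in the timing table printed by fit_hal.

```r
# Hypothetical illustration (not hal9001's internal code): main-term
# indicator basis functions only; HAL itself also uses interactions.
set.seed(467392)
n_obs <- 200
x <- replicate(3, rnorm(n_obs))

# One basis function per (covariate, observed knot) pair: I(x_j >= knot).
make_main_term_basis <- function(x) {
  cols <- lapply(seq_len(ncol(x)), function(j) {
    # outer() compares every observation (rows) with every knot (columns)
    1 * outer(x[, j], x[, j], ">=")
  })
  do.call(cbind, cols)
}

x_basis <- make_main_term_basis(x)  # n_obs rows, 3 * n_obs columns

# Drop basis functions equal to 1 for too small a share of observations,
# using the 1 / sqrt(n) rule of thumb as the minimum share.
reduce_crit <- 1 / sqrt(n_obs)
keep <- colMeans(x_basis) >= reduce_crit
x_basis_reduced <- x_basis[, keep, drop = FALSE]

# Drop basis functions whose columns duplicate an earlier column.
x_basis_reduced <- x_basis_reduced[, !duplicated(t(x_basis_reduced)), drop = FALSE]

dim(x_basis)
dim(x_basis_reduced)
```

With continuous covariates, the column for the smallest observed value of each covariate is all ones, so duplicate removal drops at least two of those three identical columns, while the filter removes, for each covariate, the knots that at most 14 of the 200 observations reach.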

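The second of these two options may be invoked by passing a cutoff to the reduce_basis argument of the fit_hal function; here the cutoff is \(1/\sqrt{n}\), following the rule of thumb above. The output below shows a HAL 9000 quote printed by fit_hal, followed by the time (in seconds) spent on each step of the fitting procedure: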

## [1] "Dave, although you took very thorough precautions in the pod against my hearing you, I could see your lips move."
##                   user.self sys.self elapsed user.child sys.child
## enumerate_basis       0.002    0.000   0.002          0         0
## design_matrix         0.005    0.000   0.005          0         0
## reduce_basis          0.005    0.000   0.005          0         0
## remove_duplicates     0.002    0.000   0.002          0         0
## lasso                 0.736    0.005   0.740          0         0
## total                 0.751    0.005   0.755          0         0
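A call along the following lines produces output of this kind. This is a sketch assuming the fit_hal interface of hal9001 v0.2.7 (the X, Y, and reduce_basis arguments, and the times element of the fitted object shown above); the object name hal_fit_reduced is illustrative, the fit is skipped when hal9001 is not installed, and timings and fitted values may differ across package versions.

```r
# Sketch assuming the hal9001 v0.2.7 fit_hal interface used in this vignette.
set.seed(467392)
n_obs <- 200
x <- replicate(3, rnorm(n_obs))
y <- sin(x[, 1]) + sin(x[, 2]) + rnorm(n_obs, mean = 0, sd = 0.2)

# Minimum proportion of ones a basis function needs to be kept:
# 1 / sqrt(200), roughly 0.0707 (7.07% of observations)
reduce_crit <- 1 / sqrt(n_obs)

if (requireNamespace("hal9001", quietly = TRUE)) {
  hal_fit_reduced <- hal9001::fit_hal(X = x, Y = y, reduce_basis = reduce_crit)
  # Time spent on each fitting step (NULL if this element is absent)
  print(hal_fit_reduced$times)
}
```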

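In the above, all basis functions that equal one for fewer than 7.0710678% (that is, \(1/\sqrt{200}\)) of the observations are removed prior to the Lasso step of fitting the HAL regression. The estimated coefficients and their corresponding basis functions (terms) appear below; because of the console width, the coef and term columns are printed as separate blocks: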

##              coef
##  1: -8.509040e-01
##  2:  2.148854e-01
##  3:  1.901724e-01
##  4:  1.672545e-01
##  5:  1.650474e-01
##  6:  1.531422e-01
##  7:  1.510468e-01
##  8:  1.402116e-01
##  9:  1.326720e-01
## 10:  1.321485e-01
## 11:  1.249675e-01
## 12:  1.115810e-01
## 13:  1.114218e-01
## 14:  1.058269e-01
## 15:  9.755956e-02
## 16:  9.358887e-02
## 17:  7.869787e-02
## 18:  7.592093e-02
## 19:  7.429796e-02
## 20:  7.047377e-02
## 21:  6.739943e-02
## 22:  6.696098e-02
## 23:  6.191745e-02
## 24:  6.106494e-02
## 25:  5.956340e-02
## 26:  5.925512e-02
## 27:  5.917987e-02
## 28:  5.909718e-02
## 29:  5.404921e-02
## 30:  4.790203e-02
## 31:  4.695250e-02
## 32:  4.542445e-02
## 33:  4.247331e-02
## 34:  4.069736e-02
## 35:  4.002579e-02
## 36:  3.625673e-02
## 37:  3.592245e-02
## 38:  3.224930e-02
## 39:  3.203616e-02
## 40:  2.771001e-02
## 41:  2.653015e-02
## 42:  2.596375e-02
## 43:  2.559650e-02
## 44:  2.301804e-02
## 45:  2.256363e-02
## 46:  2.048626e-02
## 47:  1.949827e-02
## 48:  1.753155e-02
## 49:  1.373297e-02
## 50:  1.317963e-02
## 51:  1.029828e-02
## 52:  9.586481e-03
## 53:  9.319139e-03
## 54:  8.921289e-03
## 55:  6.506866e-03
## 56:  5.956689e-03
## 57:  4.749919e-03
## 58:  3.716205e-03
## 59:  1.188659e-03
## 60:  4.492954e-04
## 61:  3.206061e-04
## 62:  8.100191e-05
## 63:  6.515244e-05
## 64:  2.994393e-05
## 65:  1.740718e-05
## 66: -1.165468e-02
## 67: -1.663354e-02
## 68: -2.146260e-02
## 69: -9.731628e-02
## 70: -2.042386e-01
## 71: -6.717951e-01
##              coef
##                                                                                   term
##  1:                                                                          Intercept
##  2:                                                                    I(1 >= -0.1725)
##  3:                                                                    I(2 >= -0.1015)
##  4:                                                                     I(1 >= 0.3807)
##  5:                                                                    I(1 >= -0.5685)
##  6:                                                                      I(2 >= 0.119)
##  7:                                                                     I(2 >= 0.6786)
##  8:                                                                     I(1 >= 0.1386)
##  9:                                                                     I(2 >= 1.0291)
## 10:                                                                    I(2 >= -0.6441)
## 11:                                                                     I(1 >= 0.4752)
## 12:                                                                    I(2 >= -0.9482)
## 13:                                                                     I(2 >= -0.782)
## 14:                                                                    I(2 >= -0.9921)
## 15:                                                                     I(2 >= -0.313)
## 16:                                                                    I(1 >= -0.5113)
## 17:                                                      I(1 >= 0.0911)*I(2 >= 0.3539)
## 18:                                                     I(1 >= 0.8989)*I(3 >= -1.0482)
## 19:                                                     I(2 >= 1.2184)*I(3 >= -0.6318)
## 20:                                                      I(1 >= 0.0855)*I(2 >= 0.4067)
## 21:                                                                    I(2 >= -0.6718)
## 22:                                                                      I(1 >= 0.747)
## 23:                                                                    I(1 >= -0.4061)
## 24:                                                                      I(1 >= 0.634)
## 25:                                                                    I(1 >= -0.8398)
## 26:                                                                    I(1 >= -0.2045)
## 27:                                                                    I(2 >= -0.1938)
## 28:                                                                     I(1 >= -0.908)
## 29:                                                                     I(1 >= 0.0345)
## 30:                                                                     I(2 >= 0.1844)
## 31:                                                    I(1 >= -0.1378)*I(3 >= -1.5644)
## 32:                                                                     I(2 >= 0.8281)
## 33:                                                                     I(2 >= 0.3545)
## 34:                                                                     I(2 >= 0.6959)
## 35:                                                                     I(2 >= 0.1737)
## 36:                                                                     I(2 >= 0.0232)
## 37:                                                                     I(2 >= 0.4512)
## 38:                                                                    I(2 >= -0.7863)
## 39:                                                                    I(2 >= -1.2687)
## 40:                                                                     I(2 >= 0.2092)
## 41:                                                                     I(2 >= 0.3237)
## 42:                                                     I(1 >= -0.1388)*I(2 >= 0.9292)
## 43:                                                     I(1 >= 1.1353)*I(3 >= -1.0124)
## 44:                                                                    I(1 >= -0.4766)
## 45:                                                     I(1 >= -0.3937)*I(2 >= -0.782)
## 46:                                                                    I(1 >= -0.3879)
## 47:                                                                     I(2 >= 0.2844)
## 48:                                                     I(1 >= 0.6882)*I(3 >= -0.7632)
## 49:                                                       I(1 >= 0.747)*I(3 >= -1.753)
## 50:                                                     I(1 >= -0.8452)*I(3 >= -0.628)
## 51:                                                                     I(2 >= 0.0444)
## 52:   I(1 >= 1.2006)*I(3 >= -1.011)  OR  I(1 >= 1.2006)*I(2 >= -1.1474)*I(3 >= -1.011)
## 53:                                                                     I(2 >= 1.3571)
## 54:                                                                     I(1 >= 0.1921)
## 55:                                                                     I(1 >= 1.1153)
## 56:                                                     I(2 >= 0.5866)*I(3 >= -0.5114)
## 57:                                                     I(1 >= 1.1153)*I(2 >= -0.3809)
## 58: I(1 >= 1.1072)*I(3 >= -1.4661)  OR  I(1 >= 1.1072)*I(2 >= -1.3281)*I(3 >= -1.4661)
## 59:                                                                    I(1 >= -0.4836)
## 60:                                                                     I(1 >= 0.0954)
## 61:                                                                    I(1 >= -0.3937)
## 62:                                                                     I(1 >= 0.5926)
## 63:                                                                     I(1 >= 0.1884)
## 64:                                                                     I(1 >= 0.1023)
## 65:                                                                     I(2 >= 0.3466)
## 66:                                                                      I(3 >= 0.505)
## 67:                                                                    I(2 >= -2.1828)
## 68:                                                     I(1 >= -2.1579)*I(3 >= 0.3054)
## 69:                                                                     I(2 >= 1.4561)
## 70:                                                                    I(2 >= -2.3508)
## 71:                                                                    I(1 >= -3.0511)
##                                                                                   term
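Each term in the summary above is an indicator basis function: I(j >= c) equals 1 when the j-th covariate (column j of x) is at least c, a product of indicators is an interaction, and terms joined by OR are duplicated basis functions that produce the same column and therefore share a single coefficient. The helper eval_hal_term below is hypothetical (not a hal9001 function) and evaluates a printed term on the simulated data; because the printed cutoffs are rounded to four decimals, the resulting share of ones may differ slightly from that of the basis function used in the fit.

```r
# Hypothetical helper (not part of hal9001): evaluate a printed HAL term
# such as I(1 >= -0.1378)*I(3 >= -1.5644) on a covariate matrix.
set.seed(467392)
n_obs <- 200
x <- replicate(3, rnorm(n_obs))

eval_hal_term <- function(x, cols, cutoffs) {
  stopifnot(length(cols) == length(cutoffs))
  # Start from a column of ones and multiply in one indicator per covariate
  out <- rep(1, nrow(x))
  for (k in seq_along(cols)) {
    out <- out * (x[, cols[k]] >= cutoffs[k])
  }
  out
}

# Term 31 of the reduced fit: I(1 >= -0.1378)*I(3 >= -1.5644)
term_31 <- eval_hal_term(x, cols = c(1, 3), cutoffs = c(-0.1378, -1.5644))

# Share of observations for which this basis function equals 1, next to
# the minimum share 1 / sqrt(n_obs) required by the reduced fit
mean(term_31)
1 / sqrt(n_obs)
```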

Obtaining model predictions

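Predictions on new data are obtained by calling predict on the fitted HAL model with the new_data argument. Below, the first value is the mean squared error of the predictions on the training data and the second is the mean squared error on the test data. The test error is large because, in the preliminaries, test_y is generated from x rather than from test_x, so the test outcomes carry no information about test_x.

## [1] 0.02493478
## [1] 1.543119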
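The following is a minimal sketch of computing training and test mean squared errors, assuming the hal9001 interface used in this vignette (fit_hal, and its predict method with a new_data argument). As a deliberate assumption, it generates test_y from test_x rather than from x; in that case the test error would be expected to be much smaller than the second value printed above. The names mse, train_mse, and test_mse are illustrative, and the fit is skipped when hal9001 is not installed.

```r
# Sketch assuming the hal9001 fit_hal/predict interface used in this vignette.
set.seed(467392)
n_obs <- 200
n_covars <- 3
x <- replicate(n_covars, rnorm(n_obs))
y <- sin(x[, 1]) + sin(x[, 2]) + rnorm(n_obs, mean = 0, sd = 0.2)
test_x <- replicate(n_covars, rnorm(n_obs))
# Assumption: test outcomes are generated from test_x (the preliminaries use x)
test_y <- sin(test_x[, 1]) + sin(test_x[, 2]) + rnorm(n_obs, mean = 0, sd = 0.2)

# Mean squared error of predictions against observed outcomes
mse <- function(preds, y) {
  mean((preds - y)^2)
}

if (requireNamespace("hal9001", quietly = TRUE)) {
  hal_fit <- hal9001::fit_hal(X = x, Y = y)
  train_mse <- mse(predict(hal_fit, new_data = x), y)
  test_mse <- mse(predict(hal_fit, new_data = test_x), test_y)
  print(c(train = train_mse, test = test_mse))
}
```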

References

Benkeser, David, and Mark J. van der Laan. 2016. “The Highly Adaptive Lasso Estimator.” In 2016 IEEE International Conference on Data Science and Advanced Analytics (DSAA). IEEE. https://doi.org/10.1109/dsaa.2016.93.

van der Laan, Mark J. 2017. “Finite Sample Inference for Targeted Learning.” https://arxiv.org/abs/1708.09502.