hal9001
The highly adaptive Lasso (HAL) is a flexible machine learning algorithm that nonparametrically estimates a function from available data by embedding a set of input observations and covariates in an extremely high-dimensional space (i.e., generating indicator basis functions from the available data). For an input data matrix of \(n\) observations and \(d\) covariates, the number of basis functions generated is approximately \(n \cdot 2^{d - 1}\). The Lasso is then employed to select a sparse set of basis functions from among the full set generated. The hal9001 R package provides an efficient implementation of this routine, relying on the glmnet R package for compatibility with the canonical Lasso implementation while also providing a faster custom C++ routine for applying the Lasso to an input matrix composed of indicator functions. Consult Benkeser and van der Laan (2016) and van der Laan (2017) for detailed theoretical descriptions of the highly adaptive Lasso and its various optimality properties.
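To make the basis construction concrete, here is a minimal base-R sketch of the zero-order indicator bases underlying HAL for a single covariate (the names toy_x, knots, and basis_design are purely illustrative, not part of the package): one knot is placed at each observed value, yielding basis functions of the form I(x >= knot); with \(d\) covariates, products of such indicators across subsets of the covariates complete the set.
# toy illustration: indicator bases for one covariate, one knot per observation
toy_x <- c(0.2, 1.5, -0.7)
knots <- sort(toy_x)
basis_design <- outer(toy_x, knots, FUN = ">=") * 1
colnames(basis_design) <- paste0("I(x >= ", knots, ")")
basis_design
##      I(x >= -0.7) I(x >= 0.2) I(x >= 1.5)
## [1,]            1           1           0
## [2,]            1           1           1
## [3,]            1           0           0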
# simulation constants
set.seed(467392)
n_obs <- 200
n_covars <- 3
# make some training data
x <- replicate(n_covars, rnorm(n_obs))
y <- sin(x[, 1]) + sin(x[, 2]) + rnorm(n_obs, mean = 0, sd = 0.2)
# make some testing data
test_x <- replicate(n_covars, rnorm(n_obs))
test_y <- sin(test_x[, 1]) + sin(test_x[, 2]) + rnorm(n_obs, mean = 0, sd = 0.2)
Let’s look at the simulated data:
head(x)
## [,1] [,2] [,3]
## [1,] 2.44102981 -0.6441252 -0.4632021
## [2,] -1.21932335 -0.9481608 2.6358511
## [3,] -0.40613567 0.4337314 -0.2226760
## [4,] -1.09760477 -1.5845711 -1.0496038
## [5,] 0.23710498 0.1261754 1.4717507
## [6,] 0.06810091 -0.2623992 -0.7534596
head(y)
## [1] 0.31596199 -1.74749349 -0.08198272 -2.13963686 0.42902938 -0.12824651
library(hal9001)
## Loading required package: Rcpp
## hal9001 v0.2.7: The Scalable Highly Adaptive Lasso
HAL uses the popular glmnet R package for the Lasso step:
hal_fit <- fit_hal(X = x, Y = y)
## [1] "Without your space helmet, Dave. You're going to find that rather difficult."
hal_fit$times
## user.self sys.self elapsed user.child sys.child
## enumerate_basis 0.003 0.000 0.002 0 0
## design_matrix 0.005 0.000 0.006 0 0
## reduce_basis 0.000 0.000 0.000 0 0
## remove_duplicates 0.007 0.000 0.007 0 0
## lasso 0.945 0.009 0.953 0 0
## total 0.961 0.009 0.969 0 0
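The size of the design matrix grows quickly with the allowed degree of interactions among covariates. Where memory or run time is a concern, this degree can be capped; a minimal sketch, assuming the max_degree argument of fit_hal (the object name hal_fit_deg2 is illustrative):
# restrict the generated bases to main terms and two-way interactions
hal_fit_deg2 <- fit_hal(X = x, Y = y, max_degree = 2)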
While the raw output object may be examined, it has (usually large) slots that make quick examination challenging. Instead, we recommend the summary method, which provides an interpretable table of the basis functions assigned non-zero coefficients. In the table, each term is an indicator basis function: for example, I(1 >= -0.1725) is the indicator that the first covariate exceeds the knot -0.1725, and products of indicators denote interaction bases.
summary(hal_fit)
## coef
## 1: -7.812231e-01
## 2: 1.968808e-01
## 3: 1.632766e-01
## 4: 1.483639e-01
## 5: 1.462592e-01
## 6: 1.456650e-01
## 7: 1.441887e-01
## 8: 1.367758e-01
## 9: 1.348308e-01
## 10: 1.324568e-01
## 11: 1.247913e-01
## 12: 1.171875e-01
## 13: 1.146518e-01
## 14: 1.125166e-01
## 15: 1.120689e-01
## 16: 1.098491e-01
## 17: 1.038959e-01
## 18: 1.000336e-01
## 19: 8.276983e-02
## 20: 7.750062e-02
## 21: 7.317027e-02
## 22: 7.296611e-02
## 23: 6.535017e-02
## 24: 6.344607e-02
## 25: 6.300968e-02
## 26: 5.772281e-02
## 27: 5.638586e-02
## 28: 5.585372e-02
## 29: 5.433344e-02
## 30: 5.086507e-02
## 31: 5.041330e-02
## 32: 4.925794e-02
## 33: 4.846195e-02
## 34: 4.137366e-02
## 35: 3.869325e-02
## 36: 3.812788e-02
## 37: 3.652213e-02
## 38: 3.631055e-02
## 39: 3.597019e-02
## 40: 3.377831e-02
## 41: 3.351076e-02
## 42: 2.921166e-02
## 43: 2.851651e-02
## 44: 2.676665e-02
## 45: 2.459752e-02
## 46: 2.402039e-02
## 47: 2.344251e-02
## 48: 2.154657e-02
## 49: 1.971008e-02
## 50: 1.873740e-02
## 51: 1.842638e-02
## 52: 1.760974e-02
## 53: 1.755002e-02
## 54: 1.557752e-02
## 55: 1.429619e-02
## 56: 1.267729e-02
## 57: 1.142536e-02
## 58: 7.771376e-03
## 59: 3.178810e-03
## 60: 2.607689e-03
## 61: 2.534992e-03
## 62: 1.790010e-03
## 63: 1.617822e-03
## 64: 1.424163e-03
## 65: 2.166444e-04
## 66: 8.432367e-05
## 67: -8.648258e-05
## 68: -2.281031e-03
## 69: -7.607974e-03
## 70: -2.909367e-02
## 71: -3.439369e-02
## 72: -3.944700e-02
## 73: -4.035840e-02
## 74: -9.726592e-02
## 75: -1.056429e-01
## 76: -1.466408e-01
## 77: -2.054773e-01
## 78: -7.314642e-01
## coef
## term
## 1: Intercept
## 2: I(1 >= -0.1725)
## 3: I(1 >= 0.3807)
## 4: I(2 >= 0.119)
## 5: I(1 >= -0.5685)
## 6: I(2 >= -0.1015)
## 7: I(1 >= 0.1386)
## 8: I(2 >= 1.0291)
## 9: I(2 >= -0.6441)
## 10: I(2 >= 0.6786)
## 11: I(2 >= 1.3571)*I(3 >= 1.4599)
## 12: I(1 >= 0.4752)
## 13: I(2 >= -0.9482)
## 14: I(1 >= -0.5113)
## 15: I(2 >= -0.1938)
## 16: I(2 >= -0.782)
## 17: I(2 >= -0.9921)
## 18: I(1 >= 0.8989)*I(3 >= -1.0482)
## 19: I(2 >= -0.313)
## 20: I(1 >= 0.0911)*I(2 >= 0.3539)
## 21: I(1 >= 0.0345)
## 22: I(1 >= 0.0855)*I(2 >= 0.4067)
## 23: I(1 >= -0.2045)
## 24: I(1 >= 0.634)
## 25: I(2 >= -0.6718)
## 26: I(2 >= 0.6959)
## 27: I(1 >= -0.4061)
## 28: I(1 >= 1.182)*I(2 >= 0.5866)*I(3 >= -0.5114)
## 29: I(1 >= -0.8398)
## 30: I(2 >= 0.1844)
## 31: I(1 >= 0.747)
## 32: I(1 >= -0.908)
## 33: I(2 >= 0.3545)
## 34: I(1 >= -0.1378)*I(3 >= -1.5644)
## 35: I(2 >= 0.6488)*I(3 >= -1.0124)
## 36: I(2 >= -0.7863)
## 37: I(2 >= 1.2184)*I(3 >= -0.6318)
## 38: I(2 >= 0.0232)
## 39: I(2 >= 0.2092)
## 40: I(2 >= 0.4512)
## 41: I(2 >= 0.2844)
## 42: I(1 >= -0.4766)
## 43: I(2 >= 0.8281)
## 44: I(2 >= 0.1737)*I(3 >= 1.8577)
## 45: I(1 >= 1.1353)*I(3 >= -1.0124)
## 46: I(2 >= -1.2687)
## 47: I(1 >= 0.747)*I(3 >= -1.753)
## 48: I(2 >= 0.1737)
## 49: I(1 >= -0.3937)*I(2 >= -0.782)
## 50: I(1 >= 0.3645)
## 51: I(1 >= 1.2006)*I(3 >= -1.011) OR I(1 >= 1.2006)*I(2 >= -1.1474)*I(3 >= -1.011)
## 52: I(1 >= -0.8452)*I(3 >= -0.628)
## 53: I(1 >= -0.3879)
## 54: I(1 >= 0.1485)*I(2 >= 0.9348)
## 55: I(1 >= 1.2523)*I(2 >= -0.0958)*I(3 >= -0.6158)
## 56: I(2 >= -0.3809)
## 57: I(2 >= 0.0444)
## 58: I(2 >= 0.3237)
## 59: I(1 >= 0.1302)
## 60: I(1 >= 0.747)*I(2 >= -0.1954)*I(3 >= -1.753)
## 61: I(1 >= 0.1921)
## 62: I(2 >= 0.738)
## 63: I(1 >= 0.634)*I(3 >= -0.8881)
## 64: I(1 >= -0.8398)*I(3 >= -1.5209)
## 65: I(1 >= -0.3937)
## 66: I(1 >= -0.4836)
## 67: I(2 >= -2.1625)*I(3 >= 0.3336)
## 68: I(1 >= -0.7038)*I(3 >= 0.3925)
## 69: I(2 >= -2.1828)
## 70: I(1 >= 1.7813)*I(2 >= 0.4792)*I(3 >= -0.9682) OR I(1 >= 2.0998)*I(2 >= 0.3074)*I(3 >= -1.3211)
## 71: I(1 >= 1.8165)*I(3 >= 0.7568) OR I(1 >= 1.8717)*I(3 >= 0.4161) OR I(1 >= 1.9814)*I(3 >= 0.0149) OR I(1 >= 2.2918)*I(3 >= 0.0013) OR I(1 >= 1.5647)*I(2 >= 1.3684)*I(3 >= 0.3794) OR I(1 >= 1.8165)*I(2 >= 0.3545)*I(3 >= 0.7568) OR I(1 >= 1.8717)*I(2 >= 0.2858)*I(3 >= 0.4161) OR I(1 >= 1.9814)*I(2 >= -0.7717)*I(3 >= 0.0149) OR I(1 >= 2.2918)*I(2 >= -1.0968)*I(3 >= 0.0013)
## 72: I(1 >= -2.1579)*I(3 >= 0.3054)
## 73: I(1 >= -2.3015)*I(2 >= 2.1782)
## 74: I(2 >= 2.0022)
## 75: I(2 >= 2.1818)
## 76: I(1 >= 2.441) OR I(1 >= 2.441)*I(2 >= -0.6441) OR I(1 >= 2.441)*I(3 >= -0.4632) OR I(1 >= 2.441)*I(2 >= -0.6441)*I(3 >= -0.4632)
## 77: I(2 >= -2.3508)
## 78: I(1 >= -3.0511)
## term
As described in Benkeser and van der Laan (2016), the HAL algorithm operates by first constructing a set of basis functions and subsequently fitting a Lasso model with this set of basis functions as the design matrix. Several approaches are considered for reducing this set of basis functions:
1. Removing duplicated basis functions (done by default in the fit_hal function), and
2. Removing basis functions that correspond to only a small set of observations; a good rule of thumb is to scale the cutoff with \(\frac{1}{\sqrt{n}}\).
The second of these options may be invoked by specifying the reduce_basis argument of the fit_hal function:
# fit HAL again, removing basis functions matched by very few observations
hal_fit_reduced <- fit_hal(X = x, Y = y, reduce_basis = 1 / sqrt(n_obs))
## [1] "Dave, although you took very thorough precautions in the pod against my hearing you, I could see your lips move."
hal_fit_reduced$times
## user.self sys.self elapsed user.child sys.child
## enumerate_basis 0.002 0.000 0.002 0 0
## design_matrix 0.005 0.000 0.005 0 0
## reduce_basis 0.005 0.000 0.005 0 0
## remove_duplicates 0.002 0.000 0.002 0 0
## lasso 0.736 0.005 0.740 0 0
## total 0.751 0.005 0.755 0 0
In the above, all basis functions whose indicator condition is met by fewer than 7.0710678% of observations are automatically removed prior to the Lasso step of fitting the HAL regression.
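That percentage is simply the \(\frac{1}{\sqrt{n}}\) cutoff evaluated at \(n = 200\):
1 / sqrt(n_obs)
## [1] 0.07071068
The summary of the reduced fit appears below:
summary(hal_fit_reduced)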
## coef
## 1: -8.509040e-01
## 2: 2.148854e-01
## 3: 1.901724e-01
## 4: 1.672545e-01
## 5: 1.650474e-01
## 6: 1.531422e-01
## 7: 1.510468e-01
## 8: 1.402116e-01
## 9: 1.326720e-01
## 10: 1.321485e-01
## 11: 1.249675e-01
## 12: 1.115810e-01
## 13: 1.114218e-01
## 14: 1.058269e-01
## 15: 9.755956e-02
## 16: 9.358887e-02
## 17: 7.869787e-02
## 18: 7.592093e-02
## 19: 7.429796e-02
## 20: 7.047377e-02
## 21: 6.739943e-02
## 22: 6.696098e-02
## 23: 6.191745e-02
## 24: 6.106494e-02
## 25: 5.956340e-02
## 26: 5.925512e-02
## 27: 5.917987e-02
## 28: 5.909718e-02
## 29: 5.404921e-02
## 30: 4.790203e-02
## 31: 4.695250e-02
## 32: 4.542445e-02
## 33: 4.247331e-02
## 34: 4.069736e-02
## 35: 4.002579e-02
## 36: 3.625673e-02
## 37: 3.592245e-02
## 38: 3.224930e-02
## 39: 3.203616e-02
## 40: 2.771001e-02
## 41: 2.653015e-02
## 42: 2.596375e-02
## 43: 2.559650e-02
## 44: 2.301804e-02
## 45: 2.256363e-02
## 46: 2.048626e-02
## 47: 1.949827e-02
## 48: 1.753155e-02
## 49: 1.373297e-02
## 50: 1.317963e-02
## 51: 1.029828e-02
## 52: 9.586481e-03
## 53: 9.319139e-03
## 54: 8.921289e-03
## 55: 6.506866e-03
## 56: 5.956689e-03
## 57: 4.749919e-03
## 58: 3.716205e-03
## 59: 1.188659e-03
## 60: 4.492954e-04
## 61: 3.206061e-04
## 62: 8.100191e-05
## 63: 6.515244e-05
## 64: 2.994393e-05
## 65: 1.740718e-05
## 66: -1.165468e-02
## 67: -1.663354e-02
## 68: -2.146260e-02
## 69: -9.731628e-02
## 70: -2.042386e-01
## 71: -6.717951e-01
## coef
## term
## 1: Intercept
## 2: I(1 >= -0.1725)
## 3: I(2 >= -0.1015)
## 4: I(1 >= 0.3807)
## 5: I(1 >= -0.5685)
## 6: I(2 >= 0.119)
## 7: I(2 >= 0.6786)
## 8: I(1 >= 0.1386)
## 9: I(2 >= 1.0291)
## 10: I(2 >= -0.6441)
## 11: I(1 >= 0.4752)
## 12: I(2 >= -0.9482)
## 13: I(2 >= -0.782)
## 14: I(2 >= -0.9921)
## 15: I(2 >= -0.313)
## 16: I(1 >= -0.5113)
## 17: I(1 >= 0.0911)*I(2 >= 0.3539)
## 18: I(1 >= 0.8989)*I(3 >= -1.0482)
## 19: I(2 >= 1.2184)*I(3 >= -0.6318)
## 20: I(1 >= 0.0855)*I(2 >= 0.4067)
## 21: I(2 >= -0.6718)
## 22: I(1 >= 0.747)
## 23: I(1 >= -0.4061)
## 24: I(1 >= 0.634)
## 25: I(1 >= -0.8398)
## 26: I(1 >= -0.2045)
## 27: I(2 >= -0.1938)
## 28: I(1 >= -0.908)
## 29: I(1 >= 0.0345)
## 30: I(2 >= 0.1844)
## 31: I(1 >= -0.1378)*I(3 >= -1.5644)
## 32: I(2 >= 0.8281)
## 33: I(2 >= 0.3545)
## 34: I(2 >= 0.6959)
## 35: I(2 >= 0.1737)
## 36: I(2 >= 0.0232)
## 37: I(2 >= 0.4512)
## 38: I(2 >= -0.7863)
## 39: I(2 >= -1.2687)
## 40: I(2 >= 0.2092)
## 41: I(2 >= 0.3237)
## 42: I(1 >= -0.1388)*I(2 >= 0.9292)
## 43: I(1 >= 1.1353)*I(3 >= -1.0124)
## 44: I(1 >= -0.4766)
## 45: I(1 >= -0.3937)*I(2 >= -0.782)
## 46: I(1 >= -0.3879)
## 47: I(2 >= 0.2844)
## 48: I(1 >= 0.6882)*I(3 >= -0.7632)
## 49: I(1 >= 0.747)*I(3 >= -1.753)
## 50: I(1 >= -0.8452)*I(3 >= -0.628)
## 51: I(2 >= 0.0444)
## 52: I(1 >= 1.2006)*I(3 >= -1.011) OR I(1 >= 1.2006)*I(2 >= -1.1474)*I(3 >= -1.011)
## 53: I(2 >= 1.3571)
## 54: I(1 >= 0.1921)
## 55: I(1 >= 1.1153)
## 56: I(2 >= 0.5866)*I(3 >= -0.5114)
## 57: I(1 >= 1.1153)*I(2 >= -0.3809)
## 58: I(1 >= 1.1072)*I(3 >= -1.4661) OR I(1 >= 1.1072)*I(2 >= -1.3281)*I(3 >= -1.4661)
## 59: I(1 >= -0.4836)
## 60: I(1 >= 0.0954)
## 61: I(1 >= -0.3937)
## 62: I(1 >= 0.5926)
## 63: I(1 >= 0.1884)
## 64: I(1 >= 0.1023)
## 65: I(2 >= 0.3466)
## 66: I(3 >= 0.505)
## 67: I(2 >= -2.1828)
## 68: I(1 >= -2.1579)*I(3 >= 0.3054)
## 69: I(2 >= 1.4561)
## 70: I(2 >= -2.3508)
## 71: I(1 >= -3.0511)
## term
# training sample prediction for HAL vs HAL9000
mse <- function(preds, y) {
mean((preds - y)^2)
}
preds_hal <- predict(object = hal_fit, new_data = x)
mse_hal <- mse(preds = preds_hal, y = y)
mse_hal
## [1] 0.02493478
# out-of-sample prediction on the testing data
oob_hal <- predict(object = hal_fit, new_data = test_x)
oob_hal_mse <- mse(preds = oob_hal, y = test_y)
oob_hal_mse
## [1] 1.543119
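Because the outcome is simulated, we can also compare the out-of-sample predictions to the noiseless regression function \(\sin(x_1) + \sin(x_2)\) directly; this quick check (plain base R on the objects above, with the illustrative name true_test_mean) isolates estimation error from the irreducible noise:
# error of HAL predictions relative to the true conditional mean
true_test_mean <- sin(test_x[, 1]) + sin(test_x[, 2])
mse(preds = oob_hal, y = true_test_mean)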
Benkeser, David, and Mark J. van der Laan. 2016. “The Highly Adaptive Lasso Estimator.” In 2016 IEEE International Conference on Data Science and Advanced Analytics (DSAA). IEEE. https://doi.org/10.1109/dsaa.2016.93.
van der Laan, Mark J. 2017. “Finite Sample Inference for Targeted Learning.” https://arxiv.org/abs/1708.09502.