This week, we review the LassoNet method, from its motivation to its theoretical justification.
"LassoNet: Neural Networks with Feature Sparsity" is a recent paper (to appear in JMLR) by Ismael Lemhadri, Feng Ruan, Louis Abraham and Rob Tibshirani. In this work, the authors develop a principled method to train neural networks with learned weights that are sparse with respect to their inputs.
Why the LassoNet?
- In supervised learning, training a model that is sparse with respect to its inputs is appealing: it lets the practitioner identify the covariates that are most relevant to predicting the response variable. Meanwhile, neural networks are extremely flexible models that perform very well on an ever-wider range of supervised learning tasks.
- The LassoNet is a meaningful step towards combining these two strengths: feature selection with neural networks.
- In principle, the LassoNet could be useful to Sisu as an improvement to our fact selection procedure.
- Today, our statistical models for fact selection are mostly Generalized Linear Models. On some datasets, such linear models may not be flexible enough to capture the metric of interest, which makes neural networks appealing.
- Unfortunately, in its present form, the LassoNet is not suited to our setting. Proper training of LassoNet models must proceed "from dense to sparse": one must first train an unpenalized neural network that includes all input covariates, and only then train sparser networks that include a subset of them. At Sisu, it is not uncommon to work with hundreds of thousands of input covariates, and training a dense network is infeasible in such settings.
The LassoNet method blends three ingredients:
- The addition of a "skip connection" from the inputs to the output of whichever neural network architecture is chosen for a given task.
- The addition of a Lasso penalty and infinity-norm constraints to the problem's objective function. The Lasso penalty enforces sparsity in the skip connection parameters, while the infinity-norm constraints impose a "hierarchy" over parameters, with the skip connection parameters at the top: an input feature can feed the hidden layers only if its skip connection weight is nonzero.
- The use of a novel and principled learning algorithm, inspired by stochastic proximal gradient descent and with proven optimization guarantees.
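To make the hierarchy and the learning algorithm more concrete, here is a minimal NumPy sketch of the proximal step at the heart of training, as we read Algorithm 2 ("Hier-Prox") in the paper. For each input feature j, it takes the skip connection weight θ_j and the column W_j of first-layer weights attached to that feature. It then applies Lasso shrinkage with threshold λ and returns the closest pair satisfying the hierarchy constraint ‖W_j‖∞ ≤ M·|θ_j|. Once θ_j reaches zero, every first-layer weight for feature j is forced to zero, which is how the network drops that input. The function and variable names are ours, the per-feature loop favors readability over speed, and we have not checked this sketch against the authors' released code.

```python
import numpy as np


def soft_threshold(x, lam):
    """S_lam(x) = sign(x) * max(|x| - lam, 0)."""
    return np.sign(x) * np.maximum(np.abs(x) - lam, 0.0)


def hier_prox(theta, W, lam, M):
    """Sketch of the hierarchical proximal operator (our reading of Hier-Prox).

    theta: shape (d,), skip connection weights, one per input feature.
    W:     shape (K, d), first hidden layer weights; column j belongs to feature j.
    lam:   Lasso threshold (in training, lambda times the learning rate).
    M:     hierarchy multiplier, M > 0.
    Returns (theta_new, W_new) with max_k |W_new[k, j]| <= M * |theta_new[j]|.
    """
    theta = np.asarray(theta, dtype=float)
    W = np.asarray(W, dtype=float)
    K, d = W.shape
    theta_new = np.empty_like(theta)
    W_new = np.empty_like(W)
    for j in range(d):
        abs_w = np.abs(W[:, j])
        sorted_w = np.sort(abs_w)[::-1]  # |W_(1)| >= ... >= |W_(K)|
        # Partial sums of the m largest |W| entries, for m = 0..K.
        partial = np.concatenate(([0.0], np.cumsum(sorted_w)))
        m = np.arange(K + 1)
        # Candidate clipping levels w_m, one per number m of clipped entries.
        w_m = M / (1.0 + m * M**2) * soft_threshold(np.abs(theta[j]) + M * partial, lam)
        # Bracket |W_(m+1)| <= w_m <= |W_(m)|, with |W_(0)| = inf and |W_(K+1)| = 0.
        upper = np.concatenate(([np.inf], sorted_w))
        lower = np.concatenate((sorted_w, [0.0]))
        valid = (lower <= w_m) & (w_m <= upper)
        if valid.any():
            m_tilde = int(np.argmax(valid))  # first m whose bracket holds
        else:
            # Guard against floating-point ties at a breakpoint:
            # take the m whose bracket is violated the least.
            violation = np.maximum(lower - w_m, 0.0) + np.maximum(w_m - upper, 0.0)
            m_tilde = int(np.argmin(violation))
        w = w_m[m_tilde]
        # Our choice: if theta_j == 0 both signs are optimal, so use +1 to keep
        # the constraint satisfied when w > 0.
        s = np.sign(theta[j]) if theta[j] != 0 else 1.0
        theta_new[j] = s * w / M
        # Clip the first-layer weights of feature j to magnitude at most w.
        W_new[:, j] = np.sign(W[:, j]) * np.minimum(w, abs_w)
    return theta_new, W_new
```

For example, with θ = 1, W = (3, 0.5), λ = 0.5 and M = 1, the update returns θ̃ = 1.75 and W̃ = (1.75, 0.5): the large hidden-layer weight is clipped, and the skip weight grows so that the clipped weight remains admissible. In a proximal gradient training loop, this step would follow each gradient step on the unpenalized loss, with λ scaled by the learning rate.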
If you like applying these kinds of methods to practical ML problems, join our team.