
I have a regression problem with around 1 million samples, 400 features (some not very meaningful and/or redundant), and 1 target variable. I have been trying hard to design a neural network architecture that beats linear regression (with regularization) on validation data, but so far everything I get is roughly in line with linear regression and not meaningfully better.

I'm mostly focusing on feed-forward networks with some added complexity, such as (a sketch of one such variant follows the list):

  • Multiple dense layers (with varying choices for number of units, normalization, activation, dropout, etc.).
  • With or without residual connections.
  • Parallel networks combined through a gating mechanism that depends on a subset of the features, somewhat like stratified linear regression.
  • Train an autoencoder first to reduce input dimensionality, then do all of the above.
  • Briefly tried CNN and other more complex layers but that is not my focus.
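
For concreteness, here is a minimal sketch of one such variant using R's keras interface: a single residual block with batch normalization and dropout. The unit counts, dropout rate, and optimizer are illustrative placeholders, not the exact setup I used.

library(keras)

# Sketch of one variant: dense layers with a residual connection
# (unit counts, dropout rate, and optimizer are illustrative placeholders)
inputs <- layer_input(shape = c(400))          # 400 input features
trunk  <- inputs %>%
  layer_dense(units = 256, activation = "relu") %>%
  layer_batch_normalization() %>%
  layer_dropout(rate = 0.2)
branch <- trunk %>%
  layer_dense(units = 256, activation = "relu") %>%
  layer_dropout(rate = 0.2)
outputs <- layer_add(list(trunk, branch)) %>%  # residual connection
  layer_dense(units = 1)                       # single regression target

model <- keras_model(inputs, outputs)
model %>% compile(optimizer = "adam", loss = "mse")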

They all perform about the same as linear regression, if not worse. What is a bit strange to me is why it is so difficult to beat linear regression; I would have thought that adding even a small nonlinearity to a baseline linear model should help at least a little.

My question is: what would an advanced ML practitioner try other than the above? Any pointers to standard approaches or out-of-the-box ideas are highly appreciated.

  • An advanced ML practitioner might consider whether the extra effort to beat what is essentially the baseline model is really worth it. – Commented Dec 28, 2023 at 0:29
  • I'd give AutoGluon Tabular a try, and if there are no improvements over the linear regression, I'd listen to DemetriPananos's advice. – Commented Dec 28, 2023 at 3:56
  • Thanks @dipetkov, that's a great suggestion. – Commented Dec 28, 2023 at 12:34
  • Yes, any improvement, even a small one, on top of a linear model would have a direct impact on the final product. – Commented Dec 28, 2023 at 12:36
  • To maximize performance on tabular data, ML people often use boosted trees, so the question is a bit tricky to answer. What makes your setting especially difficult for an ordinary neural network is the high feature count: usually the first layer has more nodes than inputs, say 1000 in your case, which already gives far too many parameters to fit given the comparably small number of observations. – Commented Dec 31, 2023 at 13:49

3 Answers


The "ubiquity" of machine learning model superiority over regular linear models is a myth. Often they can provide more predictive power, but this is context-dependent, they come at the costs of an extreme amount of required data to beat typical regression, and there is an extreme time sink related to things like data collection and wrangling. Despite recent popularity, AI has it's own issues. A recent pre-print shows that these models are going to be more and more susceptible to degeneration without sophisticated watermarking of data. So to summarize, ML and AI are excellent tools, but they're not silver bullets.

We can think of this in a very simple case with the following data, where we fit a typical OLS regression in R and a smoothing spline to the same data. As you may already be aware, splines offer a lot of flexibility in fitting (specifically for nonlinear patterns) and are relatively painless to implement in simple cases. In theory, the spline should do a better job of fitting the data.

#### Sim and Fit Data ####
set.seed(123)
x <- rnorm(100)
y <- (2 * x) + rnorm(100)
fit1 <- lm(y ~ x)
fit2 <- smooth.spline(x, y)

#### Plot Fits ####
par(mfrow = c(1, 2))
plot(x, y, main = "OLS Regression")
abline(fit1, col = "red")
plot(x, y, main = "Spline Regression")
lines(fit2, col = "blue")

However, we can see quite clearly from the plots that there is essentially no difference in the fits, because the data are pretty linear; implementing splines doesn't deliver the improvement we had hoped for.

[Figure: OLS regression fit (left) and smoothing-spline fit (right) on the simulated data; the two fits are nearly indistinguishable.]
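
To put a rough number on that, we can compare the in-sample mean squared errors of the two fits; for this linear simulation they should come out very close (exact values depend on the seed).

#### Compare In-Sample Fit ####
mse_ols    <- mean((y - fitted(fit1))^2)
mse_spline <- mean((y - predict(fit2, x)$y)^2)
c(OLS = mse_ols, Spline = mse_spline)   # expected to be very similar here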

I specifically highlight this case because you mentioned the following:

What's a bit strange to me is why is it so difficult to beat linear regression, I would think adding a small nonlinearity to a baseline linear model should at least help a little.

Preceded by this:

They all perform about the same as linear regression if not worse.

Overfitting with a nonlinear regression can lead to disastrously poor fits, so your findings are not surprising. Consider this badly overfit polynomial regression, which tries to pass through as many points in the plot as possible.

#### Extreme Polynomial ####
fit3 <- lm(y ~ poly(x, 20))
par(mfrow = c(1, 1))
plot(x, y, main = "Polynomial Regression")
newdata <- data.frame(x = seq(min(x), max(x), length.out = 200))
pred <- predict(fit3, newdata = newdata)
lines(newdata$x, pred, col = "red")

We now have a model that makes no sense and has extremely poor predictive power:

[Figure: the degree-20 polynomial fit oscillating through the simulated data points.]
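
To make "poor predictive power" concrete, we can score both models on fresh data drawn from the same process. Exact numbers vary with the seed, but the degree-20 fit should do far worse, especially near the edges of the training range.

#### Held-Out Comparison ####
set.seed(456)
x_new <- rnorm(100)
y_new <- (2 * x_new) + rnorm(100)
mse_linear <- mean((y_new - predict(fit1, data.frame(x = x_new)))^2)
mse_poly   <- mean((y_new - predict(fit3, data.frame(x = x_new)))^2)
c(Linear = mse_linear, Poly20 = mse_poly)   # the polynomial typically blows up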

Scaling this up to many dimensions, as with the several hundred features in your data, doesn't change the underlying principle, and this is just as true for machine learning as it is for vanilla regression. The question then becomes, as pointed out in the comments, why one should even bother. If we can very clearly fit the data well with simpler methods, is it worth expending the time and effort? One of the nice things linear regression will almost always have over ML is that the fit is far more interpretable and actionable. Sacrificing that for a fancy CNN isn't always ideal if you can get away with the "old" methods.
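
For instance, with the toy example above, the linear fit's parameters are directly readable and come with uncertainty estimates attached:

#### Interpretability ####
summary(fit1)$coefficients   # estimate, std. error, t value, p value per term
confint(fit1)                # 95% confidence intervals for intercept and slope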

  • Not all neural networks are generative (as in the Self-Consuming Generative Models Go MAD paper); the OP seems to be training non-generative NNs (I'd guess for learning purposes). – Commented Dec 28, 2023 at 2:23
  • Yes, I meant to simply note that sophistication of models ≠ usefulness of models. – Commented Dec 28, 2023 at 2:26
  • Thanks a lot, this is a great comment and very helpful, but I think the underlying assumption in the example is that the underlying relationship is linear. Indeed, if that's the case, a nonlinear approach wouldn't make sense. My motivation was more for a case where there is domain intuition that the relationship is not necessarily linear. In that case, what should we try before declaring (with reasonable certainty) that the intuition is incorrect? – Commented Dec 28, 2023 at 12:30
  • There is no universal answer for that, as it will depend largely on the type of data, the modeling technique, and the utility of such a model. It seems you have exhausted many options and linear regression still wins out. I would say that if a handful of ML methods aren't beating linear methods, then stick with the linear models; the motivation to do otherwise isn't clear to me. As for the nonlinearity issue, my point wasn't to assume the data were linear, but rather that overcomplicating the fitting (such as fitting splines to clearly linear data) can unnecessarily complicate things that aren't complex. – Commented Dec 28, 2023 at 13:10
  • What I dislike about your answer: the key to a reasonable linear model is feature preprocessing: where to put a log, where to add an interaction, how to deal with high-cardinality categoricals, etc. To do this for 400 features is basically impossible. A well-built "modern" ML model (usually not a neural net) will tend to deal with this automatically and thus often performs substantially better than a poorly built linear model. Again: 400 features almost imply that the linear model is poorly built. – Commented Dec 31, 2023 at 13:58

I would skip neural networks; there is no established training procedure: try SGD, if that fails try Adam, try different learning rates, and so on.

Instead, use linear regression with, for example, splines and regularisation, and/or XGBoost. If neither performs better, then there is no point investigating neural nets.
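
A rough sketch of both suggestions in R, assuming the predictors sit in a data frame dat with target y; the variable names x1, x2, x3, the spline degrees of freedom, and the boosting parameters are placeholders:

library(splines)   # natural spline basis for a few continuous features
library(glmnet)    # regularised (elastic-net) linear regression
library(xgboost)   # gradient-boosted trees

# Regularised linear model with spline-expanded features (x1, x2, x3 are placeholders)
X_lin  <- model.matrix(~ ns(x1, df = 4) + ns(x2, df = 4) + x3, data = dat)[, -1]
cv_lin <- cv.glmnet(X_lin, dat$y, alpha = 0.5)   # penalty strength chosen by CV

# Gradient-boosted trees on the raw feature matrix
X_all <- as.matrix(dat[, setdiff(names(dat), "y")])
bst   <- xgboost(data = X_all, label = dat$y,
                 nrounds = 500, max_depth = 6, eta = 0.05,
                 objective = "reg:squarederror", verbose = 0)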


To avoid overfitting, it might be helpful to try different regularization techniques (L1, L2, BIC, ...) and different strengths of each.
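
For instance, with keras in R, L1/L2 weight penalties can be added per layer and a few strengths compared on the validation loss; the architecture and penalty values below are placeholders, not a recommendation:

library(keras)

# Placeholder helper: the same small network, varying the L2 penalty strength
build_model <- function(l2_strength) {
  keras_model_sequential() %>%
    layer_dense(units = 128, activation = "relu", input_shape = c(400),
                kernel_regularizer = regularizer_l2(l2_strength)) %>%
    layer_dropout(rate = 0.2) %>%
    layer_dense(units = 1) %>%
    compile(optimizer = "adam", loss = "mse")
}

# Compare a few penalty strengths on held-out data (x_train, y_train are placeholders)
for (l2 in c(1e-5, 1e-4, 1e-3)) {
  model <- build_model(l2)
  # model %>% fit(x_train, y_train, validation_split = 0.2, epochs = 20)
}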

The other thing that comes to mind is the learning rate (try different step sizes or a scheduler), along with checking for convergence and overfitting in the training and cross-validation loss curves.

(Also, see if it helps to standardize the data if you haven't done so already; this is especially important for the regularization to work properly.)
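
A rough sketch of those last two points together, assuming keras in R; x_train, x_val, y_train, y_val, and model are placeholders from earlier in the pipeline:

library(keras)

# Standardize using training-set statistics only, then reuse them for validation
x_train_s <- scale(x_train)
x_val_s   <- scale(x_val,
                   center = attr(x_train_s, "scaled:center"),
                   scale  = attr(x_train_s, "scaled:scale"))

# Reduce the learning rate when the validation loss plateaus, then inspect the curves
history <- model %>% fit(
  x_train_s, y_train,
  epochs = 50,
  validation_data = list(x_val_s, y_val),
  callbacks = list(callback_reduce_lr_on_plateau(monitor = "val_loss",
                                                 factor = 0.5, patience = 3))
)
plot(history)   # training vs. validation loss: check convergence and overfitting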

