(via bootstrap aggregation, in ten minutes or less)
If \(Y_1, \ldots,Y_n\) are independent RVs with variance \(\sigma^2\) then
\[ \text{Var}(\bar Y) = \text{Var}\left(\frac{1}{n}\sum_{i=1}^{n} Y_i\right) = \frac{\sigma^2}{n}.\] This suggests that we could reduce the sensitivity of our tree-based predictions to the particular values in our dataset by fitting a tree to each of many training data sets and averaging their predictions.
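A quick empirical check of this variance reduction (a minimal sketch in R; the values of \(\sigma\), \(n\) and the number of replications are arbitrary choices):

```r
# Empirical check that Var(Y-bar) = sigma^2 / n for independent draws
set.seed(1234)                                   # arbitrary seed
sigma <- 2                                       # arbitrary standard deviation
for (n in c(5, 50, 500)) {
  y_bars <- replicate(1e4, mean(rnorm(n, mean = 0, sd = sigma)))
  cat("n =", n,
      "| empirical Var(Y-bar):", signif(var(y_bars), 3),
      "| sigma^2 / n:", signif(sigma^2 / n, 3), "\n")
}
```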
This is our first example of an ensemble model; such ensembles are common in weather and hazard modelling.
Issue: we only have one set of training data.
Bootstrap Aggregation
\[ \hat f(x) = \frac{1}{B}\sum_{b=1}^{B}\hat f_b(x),\] where \(\hat f_b(x)\) is the prediction from a tree fit to the \(b\)-th bootstrapped data set.
Claim: This has a stabilising effect, reducing sensitivity to the particular data values we observe.
Fit a classification tree \(\hat f_1(x)\) to the first bootstrapped data set.
Fit a classification tree \(\hat f_2(x)\) to the second bootstrapped data set.
Continue until a classification tree \(\hat f_B(x)\) has been fit to the \(B\)-th bootstrapped data set.
Take the point-wise modal class over all trees as the prediction (see the sketch below).
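A minimal sketch of this bagging procedure, using rpart and the palmerpenguins data that appear in the session info below; the number of resamples \(B\), the predictors used and the seed are arbitrary choices:

```r
library(rpart)
library(palmerpenguins)

# Bagging by hand: one classification tree per bootstrap resample,
# prediction is the point-wise modal class across the B trees.
set.seed(42)                                  # arbitrary seed
penguins_cc <- na.omit(penguins)              # drop rows with missing values
B <- 100                                      # arbitrary number of resamples

bagged_preds <- sapply(seq_len(B), function(b) {
  rows   <- sample(nrow(penguins_cc), replace = TRUE)     # bootstrap resample
  tree_b <- rpart(species ~ bill_length_mm + flipper_length_mm + body_mass_g,
                  data = penguins_cc[rows, ], method = "class")
  as.character(predict(tree_b, newdata = penguins_cc, type = "class"))
})

# Modal class for each observation across all B trees
modal_class <- apply(bagged_preds, 1, function(x) names(which.max(table(x))))
table(predicted = modal_class, actual = penguins_cc$species)
```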
Suppose we have one very strong predictor and several moderately strong predictors. Then most bagged trees will use the strong predictor for their top split, so the trees (and their predictions) will be highly correlated.
Random forests: exactly the same as bagging, but using only a random subset of \(m \ll p\) predictors to determine each split.
Seems like throwing away information, but can help us to reduce the dependence between the trees that we are aggregating.
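For illustration, one widely used implementation is the randomForest package (an assumption here: it is not among the packages attached in the session info below). A minimal sketch, again on the penguins data, with ntree and mtry chosen arbitrarily:

```r
# install.packages("randomForest")            # if not already installed
library(randomForest)
library(palmerpenguins)

penguins_cc <- na.omit(penguins)

rf_fit <- randomForest(
  species ~ bill_length_mm + bill_depth_mm + flipper_length_mm + body_mass_g,
  data  = penguins_cc,
  ntree = 500,    # B, the number of bootstrapped trees
  mtry  = 2       # m, the predictors considered at each split (here m < p = 4)
)
rf_fit            # prints the out-of-bag error estimate and confusion matrix
```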
Generalising what we had before:
\[ \text{Var}(\bar Y) = \frac{\sigma^2}{n} + \underset{\text{minimise this}}{\underbrace{\frac{2}{n^2} \sum_{i<j} \text{Cov}(Y_i, Y_j)}}.\]
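A small simulation illustrating this covariance term (a sketch with arbitrary values of \(n\), \(\sigma\) and a common correlation \(\rho\) between all pairs):

```r
# Positive correlation between the Y's inflates Var(Y-bar) above sigma^2 / n
set.seed(2024)                           # arbitrary seed
n <- 20; sigma <- 1; rho <- 0.5          # arbitrary illustrative values

# Equicorrelated Y's built from a shared component plus independent noise
y_bar <- replicate(1e4, {
  common <- rnorm(1, sd = sigma * sqrt(rho))
  mean(common + rnorm(n, sd = sigma * sqrt(1 - rho)))
})

var(y_bar)                                               # empirical Var(Y-bar)
sigma^2 / n                                              # independence-only term
sigma^2 / n + (2 / n^2) * choose(n, 2) * rho * sigma^2   # with covariance term
```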
Bioinformatics & multicollinearity: many genes are expressed together; random subsetting allows us to “spread” the influence on the outcome across these genes.
Random forests are an ensemble model, averaging the prediction of multiple regression or classification trees.
Bootstrap aggregation reduces variance of the resulting predictions by averaging over multiple data sets that we might have seen.
Variance is further reduced by considering a random subset of predictors at each split, in order to decorrelate the trees we are aggregating.
Lab: Comparison of single tree, bagging and random forest for student grade prediction.
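A sketch of what that comparison might look like, assuming a hypothetical data frame grades_df with a numeric outcome final_grade (names chosen for illustration only; the lab uses its own data):

```r
library(rpart)
library(randomForest)                      # assumption: not attached below

p <- ncol(grades_df) - 1                   # number of predictors (hypothetical data)

single_tree <- rpart(final_grade ~ ., data = grades_df)
bagged      <- randomForest(final_grade ~ ., data = grades_df, mtry = p)
forest      <- randomForest(final_grade ~ ., data = grades_df,
                            mtry = floor(p / 3))     # m < p decorrelates trees

# Compare mean squared prediction error (ideally on held-out data)
mean((predict(single_tree, grades_df) - grades_df$final_grade)^2)
mean((predict(bagged,      grades_df) - grades_df$final_grade)^2)
mean((predict(forest,      grades_df) - grades_df$final_grade)^2)
```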
R version 4.3.3 (2024-02-29)
Platform: x86_64-apple-darwin20 (64-bit)
locale: en_US.UTF-8||en_US.UTF-8||en_US.UTF-8||C||en_US.UTF-8||en_US.UTF-8
attached base packages: stats, graphics, grDevices, utils, datasets, methods and base
other attached packages: DescTools(v.0.99.58), rpart.plot(v.3.1.2), rpart(v.4.1.23), glue(v.1.8.0), ggplot2(v.3.5.1), dplyr(v.1.1.4) and palmerpenguins(v.0.1.1)
loaded via a namespace (and not attached): gld(v.2.6.6), gtable(v.0.3.5), xfun(v.0.43), lattice(v.0.22-5), vctrs(v.0.6.5), tools(v.4.3.3), generics(v.0.1.3), tibble(v.3.2.1), proxy(v.0.4-27), fansi(v.1.0.6), pkgconfig(v.2.0.3), Matrix(v.1.6-1.1), data.table(v.1.15.4), readxl(v.1.4.3), assertthat(v.0.2.1), lifecycle(v.1.0.4), rootSolve(v.1.8.2.4), compiler(v.4.3.3), farver(v.2.1.2), stringr(v.1.5.1), Exact(v.3.3), munsell(v.0.5.1), htmltools(v.0.5.8.1), class(v.7.3-22), yaml(v.2.3.8), pillar(v.1.9.0), crayon(v.1.5.3), MASS(v.7.3-60.0.1), boot(v.1.3-29), tidyselect(v.1.2.1), digest(v.0.6.35), mvtnorm(v.1.2-5), stringi(v.1.8.4), pander(v.0.6.5), purrr(v.1.0.2), showtextdb(v.3.0), labeling(v.0.4.3), forcats(v.1.0.0), zvplot(v.0.0.0.9000), fastmap(v.1.1.1), grid(v.4.3.3), colorspace(v.2.1-1), lmom(v.3.2), expm(v.1.0-0), cli(v.3.6.3), magrittr(v.2.0.3), emo(v.0.0.0.9000), utf8(v.1.2.4), e1071(v.1.7-14), withr(v.3.0.1), scales(v.1.3.0), showtext(v.0.9-7), lubridate(v.1.9.3), timechange(v.0.3.0), rmarkdown(v.2.26), httr(v.1.4.7), sysfonts(v.0.8.9), cellranger(v.1.1.0), hms(v.1.1.3), evaluate(v.0.23), knitr(v.1.45), haven(v.2.5.4), rlang(v.1.1.4), Rcpp(v.1.0.12), rstudioapi(v.0.16.0), jsonlite(v.1.8.8) and R6(v.2.5.1)
November 2024 - Zak Varty