Figure 1. Decision Tree Visualization
Decision Trees with H2O
With release 3.22.0.1, H2O-3 (a.k.a. open source H2O or simply H2O) added one more member to its family of tree-based algorithms (which already included DRF, GBM, and XGBoost): Isolation Forest (random forest for unsupervised anomaly detection). There was no simple way to visualize H2O trees except the clunky (albeit reliable) method of creating a MOJO object and running a combination of Java and dot commands. That changed in 3.22.0.1 too, with the introduction of a unified Tree API that works with any of the tree-based algorithms above. Data scientists can now use powerful visualization tools in R (or Python) without producing intermediate artifacts like MOJOs and running external utilities. Please read this article by Pavel Pscheidl, who did a superb job of explaining the H2O Tree API and its S4 classes in R, before coming back here to take it a step further and visualize trees.
The Workflow: from Data to Decision Tree
Whether you are still here or came back after reading Pavel's excellent post, let's state the goal: create a single decision tree model in H2O and visualize its tree graph. With H2O there is always a choice between Python and R; the reason for choosing R here will become clear when discussing its graphical and analytical capabilities later. CART models operate on labeled data (classification and regression) and offer arguably unmatched model interpretability by means of analyzing a tree graph. In data science there is never a single way to solve a given problem, so let's define an end-to-end logical workflow from "raw" data to a visualized decision tree:
Figure 2. Workflow of tasks in this post
- R package data.table for data munging
- H2O grid for hyper-parameter search
- H2O GBM for building a single decision tree model
- H2O Tree API for tree model representation
- R package data.tree for visualization
Figure 3. Workflow of tasks in this post with implementation details
The rest of this post walks through this workflow step by step.
Titanic Dataset
The famous Titanic dataset contains information about the fate of passengers of the RMS Titanic, which sank after colliding with an iceberg. It regularly serves as toy data for blog exercises like this one. The H2O public S3 bucket holds the Titanic dataset readily available, and the package data.table makes loading it into R a fast one-liner.
Data Engineering
Passenger features from the Titanic dataset are discussed at length online, e.g. see Predicting the Survival of Titanic Passengers and Predicting Titanic Survival using Five Algorithms. To summarize, the following features were selected and engineered for the decision tree model:
- survived indicates whether the passenger survived the wreck
- boat and body leak the survival outcome, so they were dropped completely before modeling
- name and cabin are too noisy as-is and are used only to derive new features
- title is parsed from name
- cabin_type is parsed from cabin
- family_size and family_type are derived from a combination of the count features sibsp (siblings + spouse) and parch (parents + children)
- ticket and home.dest are dropped to keep the model simple
- missing values in age and fare are imputed using target encoding (mean) over groups defined by the survived, sex, and embarked columns.
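The post does this feature engineering in R with data.table. As a rough, language-neutral illustration, here is a minimal stdlib-only Python sketch of the same steps, under stated assumptions: the title regex assumes names formatted like "Braund, Mr. Owen Harris", cabin_type is assumed to be the deck letter (first character of cabin), family_size is assumed to count the passenger (sibsp + parch + 1), the family_type cut-offs are hypothetical, and rows are plain dicts with None marking a missing value.

```python
import re
from collections import defaultdict
from statistics import mean

# Assumed name format: "Surname, Title. Given names"
TITLE_RE = re.compile(r",\s*([^.]+)\.")


def parse_title(name):
    """Return the title parsed from a passenger name, or None if absent."""
    match = TITLE_RE.search(name or "")
    return match.group(1).strip() if match else None


def parse_cabin_type(cabin):
    """Assumption: cabin type is the deck letter, e.g. "C85" -> "C"."""
    return cabin.strip()[0] if cabin and cabin.strip() else None


def family_features(sibsp, parch):
    """family_size counts the passenger; family_type cut-offs are hypothetical."""
    size = sibsp + parch + 1
    if size == 1:
        kind = "single"
    elif size <= 4:
        kind = "small"
    else:
        kind = "large"
    return size, kind


def impute_group_mean(rows, column, keys=("survived", "sex", "embarked")):
    """Fill missing values in `column` with the mean over rows sharing `keys`."""
    groups = defaultdict(list)
    for row in rows:
        if row.get(column) is not None:
            groups[tuple(row.get(k) for k in keys)].append(row[column])
    means = {key: mean(values) for key, values in groups.items()}
    for row in rows:
        if row.get(column) is None:
            # The value stays missing if its group has no observed values.
            row[column] = means.get(tuple(row.get(k) for k in keys))
    return rows
```

Calling impute_group_mean(rows, "age") and then impute_group_mean(rows, "fare") mirrors the two imputations listed above.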
Starting with H2O
Creating models with H2O requires running a server process (remote or local) and a client (the h2o package in R, available from CRAN), where the latter connects and sends commands to the former. The Tree API was introduced with release 3.22.0.1 (10/26/2018), but due to CRAN policies the h2o package usually lags several versions behind (at the time of this writing CRAN hosted version 3.20.0.8). There are two ways to work around this:
- Install and run the package available from CRAN and use strict_version_check=FALSE inside h2o.connect() to communicate with a newer version running on the server
- Or install the latest version of h2o from the H2O repository, either to connect to a remote server or to both run a server locally and connect to it.
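Since client and server versions can differ, it helps to check whether a given version is recent enough for the Tree API. The following is a small, hypothetical Python helper (not part of the h2o package) that compares dotted version strings against 3.22.0.1, the release named above; it assumes every version component is numeric.

```python
def parse_version(version):
    """Turn a dotted version string like "3.22.0.1" into a tuple of ints."""
    return tuple(int(part) for part in version.strip().split("."))


# Release that introduced the Tree API, per the text above.
TREE_API_RELEASE = parse_version("3.22.0.1")


def supports_tree_api(version):
    """True if the given h2o version is at least the Tree API release."""
    # Tuples compare element by element, so (3, 20, 0, 8) < (3, 22, 0, 1).
    return parse_version(version) >= TREE_API_RELEASE
```

For example, supports_tree_api("3.20.0.8") (the CRAN version mentioned above) returns False, while supports_tree_api("3.22.0.1") returns True.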
Building Decision Tree with H2O
While H2O offers no dedicated single decision tree algorithm, there are two approaches using more general ensemble models:
- Distributed Random Forest (DRF) function h2o.randomForest() with arguments
  - ntrees = 1
  - mtries = number of features (would be determined dynamically at runtime)
  - sample_rate = 1
  - min_rows = 1
- Gradient Boosting Machine (GBM) function h2o.gbm() with arguments
  - ntrees = 1
  - min_rows = 1
  - sample_rate = 1
  - col_sample_rate = 1
Choosing the GBM option requires one less line of code (no need to compute the number of features to set mtries), so it is used in this post. Otherwise, both approaches result in the same decision tree, and the steps below are fully reproducible with h2o.randomForest() in place of h2o.gbm().
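The two single-tree recipes can be written down as parameter sets. This Python sketch collects them as dicts whose keys match the keyword arguments of H2O's Python estimators H2ORandomForestEstimator and H2OGradientBoostingEstimator; h2o itself is not imported, and the predictor list is hypothetical. max_depth is included because it is tuned in the next section. The DRF recipe is the one needing the extra step of counting predictors for mtries.

```python
def single_tree_params(predictors, max_depth):
    """Parameter sets that make DRF or GBM grow one unsampled decision tree."""
    drf = {
        "ntrees": 1,
        # The extra step: consider every predictor at each split.
        "mtries": len(predictors),
        "sample_rate": 1,
        "min_rows": 1,
        "max_depth": max_depth,
    }
    gbm = {
        "ntrees": 1,
        "min_rows": 1,
        "sample_rate": 1,
        "col_sample_rate": 1,
        "max_depth": max_depth,
    }
    return drf, gbm


# Hypothetical predictor list, for illustration only.
features = ["pclass", "sex", "age", "fare", "embarked", "title",
            "cabin_type", "family_size", "family_type"]
drf_params, gbm_params = single_tree_params(features, max_depth=5)
```

With a running cluster, these dicts could be unpacked into the corresponding estimator constructors (for example H2OGradientBoostingEstimator(**gbm_params)).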
Decision Tree Depth
When building single decision tree models, maximum tree depth stands out as the most important parameter to pick. Shallow trees tend to underfit: they fail to capture important relationships in the data and produce similar trees despite varying training data (error due to high bias). On the other hand, trees grown too deep overfit by reacting to noise and slight changes in the data (error due to high variance). Tuning the H2O model parameter max_depth, which limits decision tree depth, aims at balancing the effects of bias and variance. In R, using H2O to split the data and tune the model, then visualizing the results with ggplot to look for the right value, unfolds like this:
- split the Titanic data into training and validation sets
- define a grid search object with the parameter max_depth
- launch the grid search on GBM models and the grid object to obtain AUC values (model performance)
- plot the grid models' AUCs vs. max_depth values to determine the "inflection point" where AUC growth stops or saturates
- register the tree depth value at the inflection point to use in the final model
The resulting chart (Figure 4) points to an inflection point for maximum tree depth at 5.
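The post reads the inflection point off a ggplot chart. As a rough stand-in, this Python sketch shows a reproducible train/validation split and a simple rule for picking max_depth from validation AUCs: return the first depth after which AUC improves by less than a tolerance. Both function names and the min_gain=0.005 threshold are assumptions for illustration; auc_by_depth is assumed to map each max_depth value from the grid to its validation AUC.

```python
import random


def split_rows(rows, train_ratio=0.8, seed=42):
    """Shuffle rows reproducibly and split them into training and validation sets."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * train_ratio)
    return shuffled[:cut], shuffled[cut:]


def pick_depth(auc_by_depth, min_gain=0.005):
    """Return the smallest max_depth after which AUC grows by less than min_gain."""
    depths = sorted(auc_by_depth)
    for current, following in zip(depths, depths[1:]):
        if auc_by_depth[following] - auc_by_depth[current] < min_gain:
            return current
    # AUC kept improving across the whole grid: take the deepest tree tried.
    return depths[-1]
```

A fixed tolerance is cruder than eyeballing the chart, but it makes the choice repeatable when the grid is rerun.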
Creating Decision Tree
As evident from Figure 4, the optimal decision tree depth is 5. The next step constructs a single decision tree model in H2O and then retrieves its tree representation from the GBM model with the Tree API function h2o.getModelTree(), which creates an instance of the S4 class H2OTree that is assigned to the variable titanicH2oTree.
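H2OTree describes a tree as parallel arrays indexed by node position (for example left_children, right_children, features, thresholds and predictions), with -1 marking a missing child. To show how such a layout is navigated, this Python sketch mimics a subset of it and routes one row from the root to a leaf. It is a simplified model, not H2O's API: the class TreeArrays and function decision_path are hypothetical, only numeric splits are handled, NA directions and categorical levels are ignored, and rows with a value below the threshold are assumed to go left.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class TreeArrays:
    """Simplified H2OTree-like layout: position i in every list describes node i."""
    left_children: List[int]            # -1 if node i is a leaf
    right_children: List[int]           # -1 if node i is a leaf
    features: List[Optional[str]]       # split feature, None for leaves
    thresholds: List[Optional[float]]   # numeric split point, None for leaves
    predictions: List[Optional[float]]  # leaf value, None for split nodes


def decision_path(tree: TreeArrays, row: Dict[str, float], root: int = 0) -> List[int]:
    """Route one row from the root to a leaf and return the visited node indices."""
    path = [root]
    node = root
    # Assumption: split nodes always have both children, so -1 on the left means leaf.
    while tree.left_children[node] != -1:
        # Assumption: values below the threshold go left, all others go right.
        if row[tree.features[node]] < tree.thresholds[node]:
            node = tree.left_children[node]
        else:
            node = tree.right_children[node]
        path.append(node)
    return path


# Hypothetical two-level tree, not the Titanic model from this post.
example = TreeArrays(
    left_children=[1, -1, 3, -1, -1],
    right_children=[2, -1, 4, -1, -1],
    features=["pclass", None, "age", None, None],
    thresholds=[2.5, None, 10.0, None, None],
    predictions=[None, 0.6, None, 0.5, 0.1],
)
path = decision_path(example, {"pclass": 3, "age": 30})
leaf_prediction = example.predictions[path[-1]]
```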
At this point all the action has moved back into R, with its unparalleled access to analytical and visualization tools. So before navigating and plotting the decision tree (the final goal of this post), here is a brief introduction to networks in R.
Overview of Network Analysis in R
R offers arguably the richest functionality when it comes to analyzing and visualizing network (graph, tree) objects. Before taking it on, spend some time with a couple of comprehensive articles describing the vast landscape of available tools and approaches: Static and dynamic network visualization with R by Katya Ognyanova and Introduction to Network Analysis with R by Jesse Sadler.
To summarize, there are two commonly used packages to manage and analyze networks in R: network (part of the statnet family) and igraph (a family in itself). Each package implements a namesake class to represent network structures, so there is significant overlap between the two and they mask each other's functions. The preferred approach is to pick only one of the two: igraph appears to be more common for general-purpose applications, while network is preferred for social network and statistical analysis (my subjective assessment). While researching these packages, do not forget about the package intergraph, which seamlessly converts objects between the network and igraph classes. (This analysis stopped short of expanding into the universe of R packages hosted on Bioconductor.)
When it comes to visualizing networks, the choices quickly proliferate. Both network and igraph offer graphical functions built on the R base plotting system, but it doesn't stop there. The following packages specialize in advanced visualizations for one or both of these classes:
- ggraph
- ggnet2
- ggnetwork
- visNetwork
- DiagrammeR
- networkD3
Finally, there is the package data.tree, designed specifically to create and analyze trees in R. It fits the bill of representing and visualizing decision trees perfectly, so it became the tool of choice for this post. Still, visualizing H2O model trees could be fully reproduced with any of the network and visualization packages mentioned above.
Visualizing H2O Trees
In the previous step, the decision tree of the GBM model moved from H2O cluster memory into an H2OTree object in R by means of the Tree API. The H2OTree object, however, is specific to H2O: it contains all the necessary details about the decision tree, but not in a format understood by R packages such as data.tree.
To fill this gap, the function createDataTree(H2OTree) was created. It traverses the tree and translates it from H2OTree into data.tree, accumulating information about decision tree splits and predictions into the node and edge attributes of the tree.
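createDataTree() itself is written in R against data.tree (its source is in the gist listed in the references). As a language-neutral illustration of the same idea, this Python sketch walks H2OTree-style parallel arrays recursively and builds nested nodes: split nodes keep their feature and threshold, leaves keep their prediction, and each node carries the label of the edge leading into it. The names to_nested and leaves and the "feature < t" / "feature >= t" label format are hypothetical; only numeric splits are handled, and values below the threshold are assumed to go left.

```python
from typing import Any, Dict, List, Optional


def to_nested(left: List[int], right: List[int],
              features: List[Optional[str]], thresholds: List[Optional[float]],
              predictions: List[Optional[float]], node: int = 0,
              edge_label: str = "") -> Dict[str, Any]:
    """Recursively translate parallel node arrays into a nested node dict."""
    if left[node] == -1:  # leaf: no children in the arrays
        return {"id": node, "edge_label": edge_label,
                "prediction": predictions[node], "children": []}
    feature, threshold = features[node], thresholds[node]
    children = [
        # Edge attributes describe the branch condition taken into each child.
        to_nested(left, right, features, thresholds, predictions,
                  left[node], f"{feature} < {threshold}"),
        to_nested(left, right, features, thresholds, predictions,
                  right[node], f"{feature} >= {threshold}"),
    ]
    return {"id": node, "edge_label": edge_label, "split": feature,
            "threshold": threshold, "children": children}


def leaves(tree: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Collect leaf nodes depth-first, left to right."""
    if not tree["children"]:
        return [tree]
    return [leaf for child in tree["children"] for leaf in leaves(child)]
```

Keeping split details on the node and branch conditions on the edge is what later lets a plotting callback label edges and style leaves differently from split nodes.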
Finally, everything is lined up and ready for the final step of plotting the decision tree:
- a single decision tree model created in H2O
- its structure made available in R
- and translated into data.tree, specialized for network analysis.
Plotting data.tree objects is built around the rich functionality of the DiagrammeR package. For anything that goes beyond simple plotting, read the documentation, but also remember that for plotting, data.tree takes advantage of:
- the hierarchical nature of tree structures
- GraphViz attributes to style graph, node and edge properties
- and dynamic callback functions (in this example GetEdgeLabel(node), GetNodeShape(node), GetFontName(node)) to customize the tree's look and feel
Figure 5. H2O Decision Tree for Titanic Model Visualized in R using data.tree package
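data.tree hands plotting over to DiagrammeR, which renders through GraphViz. To illustrate what styling callbacks such as GetEdgeLabel, GetNodeShape and GetFontName accomplish, this Python sketch emits GraphViz DOT text directly from a nested node structure and asks per-node callback functions for the edge label, node shape and font. The nested-dict format (split nodes with "split", leaves with "prediction", every node with "id" and "children", children with "edge_label"), the callback names, and the chosen shapes and fonts are assumptions for illustration, not the post's R code. Label text is not escaped, so feature names containing double quotes would need extra handling.

```python
from typing import Any, Callable, Dict, List

Node = Dict[str, Any]


def edge_label(node: Node) -> str:
    """Label the edge leading into a node with its branch condition."""
    return node.get("edge_label", "")


def node_shape(node: Node) -> str:
    """Draw split nodes as diamonds and leaves as boxes (a styling choice)."""
    return "diamond" if node["children"] else "box"


def font_name(node: Node) -> str:
    """Use a bold font for leaves so predictions stand out."""
    return "Helvetica" if node["children"] else "Helvetica-Bold"


def node_text(node: Node) -> str:
    """Show the split feature on split nodes and the prediction on leaves."""
    if node["children"]:
        return str(node["split"])
    return f"{node['prediction']:.3f}"


def to_dot(root: Node,
           get_label: Callable[[Node], str] = edge_label,
           get_shape: Callable[[Node], str] = node_shape,
           get_font: Callable[[Node], str] = font_name) -> str:
    """Serialize a nested tree into GraphViz DOT, styling nodes via callbacks."""
    lines: List[str] = ["digraph tree {"]
    stack = [root]
    while stack:
        node = stack.pop()
        lines.append(f'  n{node["id"]} [label="{node_text(node)}", '
                     f'shape={get_shape(node)}, fontname="{get_font(node)}"];')
        for child in node["children"]:
            lines.append(f'  n{node["id"]} -> n{child["id"]} '
                         f'[label="{get_label(child)}"];')
            stack.append(child)
    lines.append("}")
    return "\n".join(lines)
```

Passing different callbacks to to_dot changes the look of the tree without touching the traversal, which is the same separation data.tree's dynamic styling relies on. The returned text can be rendered with GraphViz's dot tool.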
References
- Anomaly Detection with Isolation Forests using H2O
- Changes in H2O Xia (3.22.0.1) - 10/26/2018
- Visualizing H2O GBM and Random Forest MOJO Model Trees
- Inspecting Decision Trees in H2O
- Classification and Regression Trees
- Predicting the Survival of Titanic Passengers
- Predicting Titanic Survival using Five Algorithms
- Distributed Random Forest (DRF)
- Gradient Boosting Machine (GBM)
- Inflection Point
- Static and dynamic network visualization with R by Katya Ognyanova
- Introduction to Network Analysis with R by Jesse Sadler
- CRAN Graphical Models in R Task View
- Bioconductor GraphAndNetwork packages
- Introduction to data.tree
- DiagrammeR package on github
- Node, Edge, and Graph attributes for Graphviz tools
- Public GitHub gist with source code