NeuroTree - A differentiable tree operator for tabular data

Overview

This work introduces NeuroTree, a differentiable binary tree operator adapted for tabular data.

  • Addresses the greediness of traditional trees: all nodes and leaves are learned simultaneously. This provides the ability to learn an optimal configuration across all of the tree levels.

The notion also extends to the collection of trees, which are learned simultaneously.

  • Extends the notions of forest/bagging and boosting.

    • Although the predictions from all of the trees forming a NeuroTree operator are averaged, each tree's prediction is tuned simultaneously. This is different from boosting (e.g. XGBoost), where each tree is learned sequentially on the residual from the previous trees. Also, unlike random forest and bagging, trees aren't learned in isolation but tuned collaboratively, resulting in predictions that account for the predictions of all the other trees.
  • General operator compatible with composition.

    • Allows integration within Flux's Chain like other standard operators from NNlib. Composition is also illustrated through the built-in StackTree layer, a residual composition of multiple NeuroTree building blocks.
  • Compatible with general-purpose machine learning frameworks.

    • MLJ integration

Architecture

To introduce the implementation of a NeuroTree, we first return to the architecture of a basic decision tree.

[Figure: decision-tree]

The above is a binary decision tree of depth 2.

Highlighted in green is the decision path taken for a given sample. It goes through a number of binary decisions equal to the tree depth, resulting in the path node1 → node3 → leaf3.

One way to view the role of the decision nodes is that they provide the index of the leaf prediction to fetch (index 3 in the figure). This indexing view is applicable because node routing relies on hard conditions: each is either true or false.

An alternative perspective, which we adopt here, is that the tree nodes collectively provide a weight for each leaf. A tree prediction is then the sum of the leaf values weighted by those leaf weights. In regular decision trees, since all conditions are binary, the leaf weights take the form of a mask. In the above example, the mask is [0, 0, 1, 0].

By relaxing these hard conditions into soft ones, the mask takes the form of a probability vector over the leaves, where ∑(leaf_weights) = 1 and each leaf_weight element lies in [0, 1].
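As a small illustration of the two views, the sketch below computes a depth-2 tree prediction as a dot product between leaf values and leaf weights, first with the hard mask from the example and then with a soft version. The leaf values and the node probabilities `p1`, `p2`, `p3` are made-up numbers, and treating each node output as the probability of going left is an assumption chosen to match the forward pass shown later on this page.

```julia
using LinearAlgebra: dot

# Illustrative leaf values for a depth-2 tree (4 leaves); arbitrary numbers.
leaf_values = [1.0, 2.0, 3.0, 4.0]

# Hard routing: the path node1 → node3 → leaf3 selects a single leaf.
hard_mask = [0.0, 0.0, 1.0, 0.0]
hard_pred = dot(hard_mask, leaf_values)

# Soft routing: each node outputs the probability of going left (assumed convention).
# p1, p2, p3 are assumed values for node1, node2, node3.
p1, p2, p3 = 0.2, 0.6, 0.9
soft_weights = [p1 * p2, p1 * (1 - p2), (1 - p1) * p3, (1 - p1) * (1 - p3)]
soft_pred = dot(soft_weights, leaf_values)
```

With these numbers most of the weight still lands on leaf3, so the soft prediction stays close to the hard one while remaining a smooth function of the node probabilities.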

The following illustrates how a basic decision tree is represented as a single differentiable tree within NeuroTree:

[Figure: decision-tree]

Node weights

To derive how a NeuroTree performs those soft decisions, we first break down how the traditional hard decisions are taken. A node's split actually relies on 2 binary conditions:

1. Selection of the feature on which to perform the condition

2. Selection of the condition's threshold value

In NeuroTree, these hard decisions are translated into soft, differentiable ones.

A NeuroTree operator acts as a collection of complete binary trees, i.e. trees without any pruned node. To be differentiable, and hence trainable with gradient-based methods such as Adam, each tree path implements a soft decision rather than a hard one as in a traditional decision tree.
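One common way to relax the two hard choices is a softmax over the features followed by a sigmoid around the threshold. The sketch below illustrates that idea for a single node. The names `feature_logits`, `threshold` and `scale` are hypothetical, and this is an assumed parameterization for illustration, not necessarily the one NeuroTree uses.

```julia
# Numerically stable softmax over a vector.
function softmax(z)
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

sigmoid(z) = 1 / (1 + exp(-z))

# Soft split: probability that sample `x` is routed to the left child.
function soft_node_weight(x, feature_logits, threshold; scale=1.0)
    feature_probs = softmax(feature_logits)          # soft feature selection
    selected = sum(feature_probs .* x)               # blended feature value
    return sigmoid(scale * (threshold - selected))   # soft version of `selected < threshold`
end

x = [0.5, -1.0, 2.0]
nw = soft_node_weight(x, [0.0, 0.0, 5.0], 1.0)
```

As the logits concentrate on one feature and `scale` grows, the output approaches the 0/1 outcome of a hard split on that feature.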

Leaf weights

Computing the leaf weights consists of accumulating the weights through each tree branch. This is the more technically challenging part, as the computation cannot be expressed as a matrix multiplication, unlike other common operators such as Dense, Conv or MultiHeadAttention / Transformer. Performing probability accumulation through a tree index naturally leads to in-place element-wise operations, which are notoriously unfriendly to automatic differentiation engines. Since NeuroTree was intended to integrate with the Flux.jl ecosystem, Zygote.jl acts as the underlying AD. The approach used was to manually implement the backward / adjoint of the terminal leaf function and to instruct the AD to use that custom rule rather than attempt to differentiate a non-AD-compliant function.
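To make the custom-rule idea concrete, here is a minimal sketch of how a hand-written adjoint can be registered through `ChainRulesCore.rrule`, which Zygote.jl picks up. It uses a deliberately tiny, hypothetical function (`stump_weights`, the two leaf weights of a depth-1 tree) instead of the real leaf function, so it shows the mechanism only and is not NeuroTree's actual rule.

```julia
using ChainRulesCore

# Hypothetical in-place helper: fills `lw` with the two leaf weights of a depth-1 tree.
function stump_weights!(lw, nw)
    lw[1] = nw[1]        # weight of the left leaf
    lw[2] = 1 - nw[1]    # weight of the right leaf
    return lw
end

stump_weights(nw) = stump_weights!(similar(nw, 2), nw)

# Custom reverse rule: the AD calls this instead of tracing the in-place code.
function ChainRulesCore.rrule(::typeof(stump_weights), nw)
    lw = stump_weights(nw)
    function stump_weights_pullback(ȳ)
        ȳ = unthunk(ȳ)
        Δnw = [ȳ[1] - ȳ[2]]    # d lw[1]/d nw = 1, d lw[2]/d nw = -1
        return NoTangent(), Δnw
    end
    return lw, stump_weights_pullback
end

lw, pullback = ChainRulesCore.rrule(stump_weights, [0.3])
```

Calling `pullback` with an upstream gradient returns a tuple whose second entry is the gradient with respect to `nw`.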

Below are the algorithm and actual implementation of the forward and backward functions that compute the leaf weights. For brevity, the loops over each observation of the batch and each tree are omitted. Parallelism, both on CPU and GPU, is obtained through parallelization over the tree and batch dimensions.

Forward

julia
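function leaf_weights!(nw)
    # cw: cumulative weights over the full tree (nodes followed by leaves); the root starts at 1
    cw = ones(eltype(nw), 2 * size(nw, 1) + 1)

    # children of node i>>1 sit at positions i (left) and i+1 (right)
    for i = 2:2:size(cw, 1)
        cw[i] = cw[i>>1] * nw[i>>1]
        cw[i+1] = cw[i>>1] * (1 - nw[i>>1])
    end

    # the leaf weights are the trailing entries of cw
    lw = cw[size(nw, 1)+1:size(cw, 1)]
    return (cw, lw)
end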

Backward

julia
function Δ_leaf_weights!(Δnw, ȳ, cw, nw, max_depth, node_offset)
    
    for i in axes(nw, 1)        
        depth = floor(Int, log2(i)) # current depth level - starting at 0
        step = 2^(max_depth - depth) # iteration length
        leaf_offset = step * (i - 2^depth) # offset on the leaf row

        for j = (1+leaf_offset):(step÷2+leaf_offset)
            k = j + node_offset # move from leaf position to full tree position 
            Δnw[i] += ȳ[j] * cw[k] / nw[i]
        end
    
        for j = (1+leaf_offset+step÷2):(step+leaf_offset)
            k = j + node_offset
            Δnw[i] -= ȳ[j] * cw[k] / (1 - nw[i])
        end
    end

    return nothing
end
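A quick way to gain confidence in a hand-written adjoint is to compare it against finite differences. The sketch below does this for a single depth-2 tree and a single sample, using compact single-tree copies of the forward and backward passes. The values of `nw` and `ȳ` are arbitrary, and passing `node_offset = length(nw)` is an assumption based on how the leaves follow the nodes in `cw`.

```julia
# Compact single-tree, single-sample copies of the forward / backward passes.
function leaf_weights(nw)
    cw = ones(eltype(nw), 2 * length(nw) + 1)
    for i in 2:2:length(cw)
        cw[i] = cw[i>>1] * nw[i>>1]
        cw[i+1] = cw[i>>1] * (1 - nw[i>>1])
    end
    return cw, cw[length(nw)+1:end]
end

function Δ_leaf_weights!(Δnw, ȳ, cw, nw, max_depth, node_offset)
    for i in eachindex(nw)
        depth = floor(Int, log2(i))
        step = 2^(max_depth - depth)
        leaf_offset = step * (i - 2^depth)
        for j in (1+leaf_offset):(step÷2+leaf_offset)        # leaves under the left child
            Δnw[i] += ȳ[j] * cw[j+node_offset] / nw[i]
        end
        for j in (1+leaf_offset+step÷2):(step+leaf_offset)   # leaves under the right child
            Δnw[i] -= ȳ[j] * cw[j+node_offset] / (1 - nw[i])
        end
    end
    return nothing
end

# Depth-2 tree: 3 nodes, 4 leaves. Values are arbitrary.
nw = [0.2, 0.6, 0.9]
ȳ = [1.0, -2.0, 0.5, 3.0]    # upstream gradient w.r.t. the leaf weights
cw, lw = leaf_weights(nw)

Δnw = zeros(3)
Δ_leaf_weights!(Δnw, ȳ, cw, nw, 2, length(nw))

# Central finite differences of the scalar sum(ȳ .* lw) w.r.t. each node weight.
loss(nw) = sum(ȳ .* leaf_weights(nw)[2])
ϵ = 1e-6
fd = map(1:3) do i
    e = zeros(3)
    e[i] = ϵ
    (loss(nw .+ e) - loss(nw .- e)) / (2ϵ)
end
```

Note that the backward pass divides by `nw[i]` and `1 - nw[i]`, so this check assumes node weights strictly inside (0, 1).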

Tree prediction

Composability

  • StackTree

  • General operator: Chain neurotree with MLP

Benchmarks

For each dataset and algorithm, the following methodology is followed:

  • Data is split into three parts: train, eval and test

  • A random grid of 16 hyper-parameter configurations is generated

  • For each parameter configuration, a model is trained on the train data until the evaluation metric tracked against the eval data stops improving (early stopping)

  • The trained model is evaluated against the test data

  • The metrics presented below are the ones obtained on the test data for the model that generated the best eval metric.
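The selection logic in the steps above can be sketched as follows. The search space, its parameter names, and the `fit_and_score` stub are all hypothetical: the stub stands in for "train with early stopping, then score" and returns dummy numbers only so that the selection step is runnable. It does not reproduce the MLBenchmarks.jl code.

```julia
using Random

# Hypothetical search space; names and ranges are illustrative only.
space = (depth = 3:6, ntrees = [16, 32, 64], lr = [1e-3, 3e-3, 1e-2])

rng = Random.MersenneTwister(123)
grid = [map(v -> rand(rng, v), space) for _ in 1:16]   # 16 random configurations

# Stub for "train with early stopping, then score"; returns dummy eval / test metrics.
# A real run would fit a model on train data and track the metric on eval data here.
fit_and_score(cfg) = (eval = cfg.lr * cfg.depth, test = cfg.lr * cfg.depth + 0.01)

results = [fit_and_score(cfg) for cfg in grid]

# Keep the configuration with the best eval metric (lower is better, e.g. mse or logloss)
# and report the test metric of that same model.
best = argmin(r -> r.eval, results)
reported_test_metric = best.test
```

Selecting on the eval metric and reporting the test metric keeps the test data out of the hyper-parameter search.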

Source code available at MLBenchmarks.jl.

For performance assessment, benchmarks are run on the following selection of common tabular datasets:

  • Year: mean squared error regression

  • MSRank: ranking problem with mean squared error regression

  • YahooRank: ranking problem with mean squared error regression

  • Higgs: 2-level classification with logistic regression

  • Boston Housing: mean squared error regression

  • Titanic: 2-level classification with logistic regression

Comparison is performed against the following algorithms, considered as state of the art on classification tasks:

  • EvoTrees

  • XGBoost

  • LightGBM

  • CatBoost

Boston

| model_type | train_time | mse | gini |
|---|---|---|---|
| neurotrees | 12.8 | 18.9 | 0.947 |
| evotrees | 0.206 | 19.7 | 0.927 |
| xgboost | 0.0648 | 19.4 | 0.935 |
| lightgbm | 0.865 | 25.4 | 0.926 |
| catboost | 0.0511 | 13.9 | 0.946 |

Titanic

| model_type | train_time | logloss | accuracy |
|---|---|---|---|
| neurotrees | 7.58 | 0.407 | 0.828 |
| evotrees | 0.673 | 0.382 | 0.828 |
| xgboost | 0.0379 | 0.375 | 0.821 |
| lightgbm | 0.615 | 0.390 | 0.836 |
| catboost | 0.0326 | 0.388 | 0.836 |

Year

| model_type | train_time | mse | gini |
|---|---|---|---|
| neurotrees | 280.0 | 76.4 | 0.652 |
| evotrees | 18.6 | 80.1 | 0.627 |
| xgboost | 17.2 | 80.2 | 0.626 |
| lightgbm | 8.11 | 80.3 | 0.624 |
| catboost | 80.0 | 79.2 | 0.635 |

MSRank

| model_type | train_time | mse | ndcg |
|---|---|---|---|
| neurotrees | 39.1 | 0.578 | 0.462 |
| evotrees | 37.0 | 0.554 | 0.504 |
| xgboost | 12.5 | 0.554 | 0.503 |
| lightgbm | 37.5 | 0.553 | 0.503 |
| catboost | 15.1 | 0.558 | 0.497 |

Yahoo

| model_type | train_time | mse | ndcg |
|---|---|---|---|
| neurotrees | 417.0 | 0.584 | 0.781 |
| evotrees | 687.0 | 0.545 | 0.797 |
| xgboost | 120.0 | 0.547 | 0.798 |
| lightgbm | 244.0 | 0.540 | 0.796 |
| catboost | 161.0 | 0.561 | 0.794 |

Higgs

| model_type | train_time | logloss | accuracy |
|---|---|---|---|
| neurotrees | 12300.0 | 0.452 | 0.781 |
| evotrees | 2620.0 | 0.464 | 0.776 |
| xgboost | 1390.0 | 0.462 | 0.776 |
| lightgbm | 1330.0 | 0.461 | 0.779 |
| catboost | 7180.0 | 0.464 | 0.775 |
