@article{Marton2023-op,
abstract = {Decision Trees (DTs) are commonly used for many machine
learning tasks due to their high degree of interpretability.
However, learning a DT from data is a difficult optimization
problem, as it is non-convex and non-differentiable.
Therefore, common approaches learn DTs using a greedy growth
algorithm that minimizes the impurity locally at each
internal node. Unfortunately, this greedy procedure can lead
to inaccurate trees. In this paper, we present a novel
approach for learning hard, axis-aligned DTs with gradient
descent. The proposed method uses backpropagation with a
straight-through operator on a dense DT representation to
jointly optimize all tree parameters. Our approach
outperforms existing methods on binary classification
benchmarks and achieves competitive results for multi-class
tasks. The method is available at:
https://github.com/s-marton/GradTree},
author = {Marton, Sascha and L{\"u}dtke, Stefan and Bartelt, Christian and Stuckenschmidt, Heiner},
archiveprefix = {arXiv},
eprint = {2305.03515},
keywords = {GradTree, axis-aligned decision trees, gradient descent, learning},
primaryclass = {cs.LG},
title = {{GradTree}: Learning axis-aligned decision trees with gradient descent},
year = 2023
}