Let's train a tree of depth 2 on the famous iris dataset, using all the data, and print the resulting rules with the brand new function export_text:
    from sklearn.datasets import load_iris
    from sklearn.tree import DecisionTreeClassifier, export_text

    iris = load_iris()
    X = iris['data']
    # The iris samples are stored in class order, 50 per class
    y = ['setosa']*50 + ['versicolor']*50 + ['virginica']*50

    decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
    decision_tree = decision_tree.fit(X, y)

    r = export_text(decision_tree, feature_names=iris['feature_names'])
    print(r)
    |--- petal width (cm) <= 0.80
    |   |--- class: setosa
    |--- petal width (cm) >  0.80
    |   |--- petal width (cm) <= 1.75
    |   |   |--- class: versicolor
    |   |--- petal width (cm) >  1.75
    |   |   |--- class: virginica

Reading the rules, we note that if the petal width is less than or equal to 0.80 cm (8 mm), the samples are always classified as setosa. Otherwise, they are classified as versicolor if the petal width is less than or equal to 1.75 cm, and as virginica if it is more than 1.75 cm.

We trained on all the data, so we have no held-out estimate of how well the model generalizes and it might well overfit, but it still tells us some important details about the data. Petal width is the only feature used, and we could even say that the petal width is small for setosa samples, medium for versicolor and large for virginica.
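The claim that only one feature is used can also be checked programmatically rather than by reading the printed rules. The following is a minimal sketch that retrains the same tree and inspects the fitted `tree_.feature` array and `feature_importances_`; the variable name `used_features` is introduced here for illustration only.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris["data"]
# The iris samples are stored in class order, 50 per class
y = ["setosa"] * 50 + ["versicolor"] * 50 + ["virginica"] * 50

decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2).fit(X, y)

# tree_.feature holds the feature index tested at each node;
# leaf nodes carry a negative placeholder, so we filter them out.
used_features = {
    iris["feature_names"][i] for i in decision_tree.tree_.feature if i >= 0
}
print(used_features)

# Features that never appear in a split get an importance of exactly 0.
for name, importance in zip(iris["feature_names"], decision_tree.feature_importances_):
    print(f"{name}: {importance:.2f}")
```

If the fitted tree matches the rules printed above, `used_features` should contain only the petal width and the other three importances should be zero.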
To understand how the rules separate the labels, we can also print the number of samples from each class (the class weights) at the leaves:
    r = export_text(decision_tree, feature_names=iris['feature_names'],
                    decimals=0, show_weights=True)
    print(r)
    |--- petal width (cm) <= 1
    |   |--- weights: [50, 0, 0] class: setosa
    |--- petal width (cm) >  1
    |   |--- petal width (cm) <= 2
    |   |   |--- weights: [0, 49, 5] class: versicolor
    |   |--- petal width (cm) >  2
    |   |   |--- weights: [0, 1, 45] class: virginica

Here the number of samples per class appears in square brackets. Because we passed decimals=0, the thresholds 0.80 and 1.75 are displayed rounded to 1 and 2; the tree itself is unchanged. Recalling that we have 50 samples per class, we see that all the samples labeled as setosa are correctly modelled by the tree, while 5 virginica samples end up in the versicolor leaf and 1 versicolor sample ends up in the virginica leaf.
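The same per-class counts can be cross-checked with a confusion matrix on the training data. This is a sketch under the assumption that the tree is retrained exactly as above; the names `labels` and `cm` are introduced here for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import confusion_matrix
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris["data"]
# The iris samples are stored in class order, 50 per class
y = ["setosa"] * 50 + ["versicolor"] * 50 + ["virginica"] * 50
labels = ["setosa", "versicolor", "virginica"]

decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2).fit(X, y)

# Rows are the true labels, columns are the predicted labels,
# both in the order given by `labels`.
cm = confusion_matrix(y, decision_tree.predict(X), labels=labels)
print(cm)
```

Each column of the matrix corresponds to one leaf of the tree, so the columns should line up with the weights printed by export_text, and the off-diagonal entries are the samples the tree misclassifies.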
Check out the documentation of the function export_text to discover all its capabilities.