Extracting decision rules from GradientBoostingClassifier

Question

I have gone through the questions below:

How to extract decision rules of GradientBoostingClassifier

How to extract the decision rules from a scikit-learn decision tree?

However, the above two do not solve my purpose. Below is my query:

I need to build a model in Python using GradientBoostingClassifier and implement this model on the SAS platform. To do this I need to extract the decision rules from the GradientBoostingClassifier.

Below is what I have tried so far:

Build the model on the IRIS data:

# import the most common dataset
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import export_graphviz
from io import StringIO  # sklearn.externals.six has been removed from newer scikit-learn
from IPython.display import Image
import pydotplus

X, y = load_iris(return_X_y=True)
# there are 150 observations and 4 features
print(X.shape) # (150, 4)
# let's build a small model = 5 trees with depth no more than 3
model = GradientBoostingClassifier(n_estimators=5, max_depth=3, learning_rate=1.0)
model.fit(X, y==2) # predict 2nd class vs rest, for simplicity
# we can access individual trees
trees = model.estimators_.ravel()

def plot_tree(clf):
    dot_data = StringIO()
    export_graphviz(clf, out_file=dot_data, node_ids=True,
                    filled=True, rounded=True, 
                    special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    return Image(graph.create_png())

# now we can plot the first tree
plot_tree(trees[0])

After plotting the graph, I checked the source of the graph for the 1st tree and wrote it to a text file using the code below:

with open(r"C:\Users\XXXX\Desktop\Python\input_tree.txt", "w") as wrt:
    wrt.write(export_graphviz(trees[0], out_file=None, node_ids=True,
                filled=True, rounded=True, 
                special_characters=True))

Below is the output file:

digraph Tree {
node [shape=box, style="filled, rounded", color="black", fontname=helvetica] ;
edge [fontname=helvetica] ;
0 [label=<node &#35;0<br/>X<SUB>3</SUB> &le; 1.75<br/>friedman_mse = 0.222<br/>samples = 150<br/>value = 0.0>, fillcolor="#e5813955"] ;
1 [label=<node &#35;1<br/>X<SUB>2</SUB> &le; 4.95<br/>friedman_mse = 0.046<br/>samples = 104<br/>value = -0.285>, fillcolor="#e5813945"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label=<node &#35;2<br/>X<SUB>3</SUB> &le; 1.65<br/>friedman_mse = 0.01<br/>samples = 98<br/>value = -0.323>, fillcolor="#e5813943"] ;
1 -> 2 ;
3 [label=<node &#35;3<br/>friedman_mse = 0.0<br/>samples = 97<br/>value = -1.5>, fillcolor="#e5813900"] ;
2 -> 3 ;
4 [label=<node &#35;4<br/>friedman_mse = -0.0<br/>samples = 1<br/>value = 3.0>, fillcolor="#e58139ff"] ;
2 -> 4 ;
5 [label=<node &#35;5<br/>X<SUB>3</SUB> &le; 1.55<br/>friedman_mse = 0.222<br/>samples = 6<br/>value = 0.333>, fillcolor="#e5813968"] ;
1 -> 5 ;
6 [label=<node &#35;6<br/>friedman_mse = 0.0<br/>samples = 3<br/>value = 3.0>, fillcolor="#e58139ff"] ;
5 -> 6 ;
7 [label=<node &#35;7<br/>friedman_mse = 0.222<br/>samples = 3<br/>value = 0.0>, fillcolor="#e5813955"] ;
5 -> 7 ;
8 [label=<node &#35;8<br/>X<SUB>2</SUB> &le; 4.85<br/>friedman_mse = 0.021<br/>samples = 46<br/>value = 0.645>, fillcolor="#e581397a"] ;
0 -> 8 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
9 [label=<node &#35;9<br/>X<SUB>1</SUB> &le; 3.1<br/>friedman_mse = 0.222<br/>samples = 3<br/>value = 0.333>, fillcolor="#e5813968"] ;
8 -> 9 ;
10 [label=<node &#35;10<br/>friedman_mse = 0.0<br/>samples = 2<br/>value = 3.0>, fillcolor="#e58139ff"] ;
9 -> 10 ;
11 [label=<node &#35;11<br/>friedman_mse = -0.0<br/>samples = 1<br/>value = -1.5>, fillcolor="#e5813900"] ;
9 -> 11 ;
12 [label=<node &#35;12<br/>friedman_mse = -0.0<br/>samples = 43<br/>value = 3.0>, fillcolor="#e58139ff"] ;
8 -> 12 ;
}

To extract the decision rules from the output file, I tried the Python regex code below to translate them into SAS code:

import re
with open(r"C:\Users\XXXX\Desktop\Python\input_tree.txt") as f:
    with open(r"C:\Users\XXXX\Desktop\Python\output.txt", "w") as f1:
        result0 = 'value = 0;'
        f1.write(result0)
        for line in f:
            result1 = re.sub(r'^(\d+)\s+.*<br/>([A-Z]+)<SUB>(\d+)</SUB>\s+(.+?)([-\d.]+)<br/>friedman_mse.*;$', r"if \2\3 \4 \5 then do;", line)
            result2 = re.sub(r'^(\d+).*(?!SUB).*(value\s+=)\s([-\d.]+).*;$', r"\2 value + \3; end;", result1)
            result3 = re.sub(r'^(\d+\s+->\s+\d+\s+);$', r'\1', result2)
            result4 = re.sub(r'^digraph.+|^node.+|^edge.+', '', result3)
            result5 = re.sub(r'&(\w{2});', r'\1', result4)
            result6 = re.sub(r'}', 'end;', result5)
            f1.write(result6)

Below is the SAS output from the above code:

value = 0;
if X3 le  1.75 then do;
if X2 le  4.95 then do;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
if X3 le  1.65 then do;
1 -> 2 
value = value + -1.5; end;
2 -> 3 
value = value + 3.0; end;
2 -> 4 
if X3 le  1.55 then do;
1 -> 5 
value = value + 3.0; end;
5 -> 6 
value = value + 0.0; end;
5 -> 7 
if X2 le  4.85 then do;
0 -> 8 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
if X1 le  3.1 then do;
8 -> 9 
value = value + 3.0; end;
9 -> 10 
value = value + -1.5; end;
9 -> 11 
value = value + 3.0; end;
8 -> 12 
end;

As you can see, there is a missing piece in the output file, i.e. I am not able to open/close the do-end blocks properly. For this I would need to make use of the node numbers, but I am failing to do so as I am unable to find any pattern there.

Could anyone please help me with this query?

Apart from this, can I not extract children_left, children_right and the threshold values, as with DecisionTreeClassifier in the 2nd link above? I have successfully extracted each tree of the GBM:

trees = model.estimators_.ravel()

but I did not find any useful function that I can use to extract the values and rules of each tree. Please advise whether I can use the graphviz object in a similar way to DecisionTreeClassifier.
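
For illustration, a minimal sketch of the kind of per-tree access being asked about, assuming each extracted estimator exposes the same tree_ arrays as a DecisionTreeClassifier (the snippet only illustrates the idea and is not part of the original attempt):

from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier

X, y = load_iris(return_X_y=True)
model = GradientBoostingClassifier(n_estimators=5, max_depth=3, learning_rate=1.0)
model.fit(X, y == 2)
trees = model.estimators_.ravel()

# each item is a DecisionTreeRegressor, so the same tree_ arrays as for a
# DecisionTreeClassifier should be reachable per tree
t = trees[0].tree_
print(t.children_left)    # id of the left child per node, -1 for leaves
print(t.children_right)   # id of the right child per node, -1 for leaves
print(t.feature)          # index of the split feature per node, -2 for leaves
print(t.threshold)        # split threshold per node
print(t.value.ravel())    # per-node value (the contribution added by this tree)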

Alternatively, help me with any other method that can serve my purpose.

Answer

There is no need to use the graphviz export to access the decision tree data. model.estimators_ contains all the individual classifiers that the model consists of. In the case of a GradientBoostingClassifier, this is a 2D numpy array with shape (n_estimators, n_classes), and each item is a DecisionTreeRegressor.

Each decision tree has a property tree_, and the scikit-learn example Understanding the decision tree structure shows how to get the nodes, thresholds and children out of that object.


import numpy
import pandas
from sklearn.ensemble import GradientBoostingClassifier

est = GradientBoostingClassifier(n_estimators=4)
numpy.random.seed(1)
est.fit(numpy.random.random((100, 3)), numpy.random.choice([0, 1, 2], size=(100,)))
print('s', est.estimators_.shape)

n_estimators, n_classes = est.estimators_.shape
for t in range(n_estimators):
    for c in range(n_classes):
        dtree = est.estimators_[t, c]
        print("class={}, tree={}: {}".format(c, t, dtree.tree_))

        rules = pandas.DataFrame({
            'child_left': dtree.tree_.children_left,
            'child_right': dtree.tree_.children_right,
            'feature': dtree.tree_.feature,
            'threshold': dtree.tree_.threshold,
        })
        print(rules)

This outputs something like the following for each tree:

class=0, tree=0: <sklearn.tree._tree.Tree object at 0x7f18a697f370>
   child_left  child_right  feature  threshold
0           1            2        0   0.020702
1          -1           -1       -2  -2.000000
2           3            6        1   0.879058
3           4            5        1   0.543716
4          -1           -1       -2  -2.000000
5          -1           -1       -2  -2.000000
6           7            8        0   0.292586
7          -1           -1       -2  -2.000000
8          -1           -1       -2  -2.000000
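
Going a step beyond the original answer, below is a hedged sketch of how the nested if/do ... end blocks from the question could be generated directly from these tree_ arrays by a recursive walk, instead of running regular expressions over the graphviz export. The feature names, the y == 2 binary setup and the formatting are illustrative assumptions, not part of the accepted answer:

from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier

def tree_to_sas(dtree, feature_names):
    """Emit nested SAS if/do ... end blocks for one GBM tree (sketch only)."""
    t = dtree.tree_
    lines = []

    def recurse(node, depth):
        indent = "    " * depth
        if t.children_left[node] == -1:  # leaf node: add this tree's contribution
            lines.append("{}value = value + {};".format(indent, t.value[node].item()))
        else:
            name = feature_names[t.feature[node]]
            # samples with feature <= threshold go to the left child
            lines.append("{}if {} le {} then do;".format(indent, name, t.threshold[node]))
            recurse(t.children_left[node], depth + 1)
            lines.append("{}end;".format(indent))
            lines.append("{}else do;".format(indent))
            recurse(t.children_right[node], depth + 1)
            lines.append("{}end;".format(indent))

    recurse(0, 0)
    return "\n".join(lines)

X, y = load_iris(return_X_y=True)
model = GradientBoostingClassifier(n_estimators=5, max_depth=3, learning_rate=1.0)
model.fit(X, y == 2)  # binary problem, as in the question

sas_lines = ["value = 0;"]
for i, dtree in enumerate(model.estimators_.ravel()):
    sas_lines.append("* tree {};".format(i))
    sas_lines.append(tree_to_sas(dtree, ["X0", "X1", "X2", "X3"]))
print("\n".join(sas_lines))

Because every branch is closed by its own end; as the recursion unwinds, the do-end blocks nest correctly without having to infer the structure from node numbers. A full SAS translation would also need the model's initial prediction (its init_ estimator) and the learning rate applied to each tree's contribution; the sketch ignores both since learning_rate=1.0 is used here and only the rule structure is of interest.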
