scikit学习随机森林分类器概率阈值

我用的是 sklearn RandomForestClassifier(随机森林分类器) 的预测任务。

from sklearn.ensemble import RandomForestClassifier

model = RandomForestClassifier(n_estimators=300, n_jobs=-1)
model.fit(x_train,y_train)
model.predict_proba(x_test)

有171个班级需要预测,我只想预测那些班级,其中的 predict_proba(class) 是至少90%。下面的一切都应该设置为 0.

例如,给定以下内容。

     1   2   3   4   5   6   7
0  0.0 0.0 0.1 0.9 0.0 0.0 0.0
1  0.2 0.1 0.1 0.3 0.1 0.0 0.2
2  0.1 0.1 0.1 0.1 0.1 0.4 0.1
3  1.0 0.0 0.0 0.0 0.0 0.0 0.0

我的预期输出是:

0   4
1   0
2   0   
3   1

解决方案:

你可以使用 numpy.argwhere 如下。

from sklearn.ensemble import RandomForestClassifier
import numpy as np

model = RandomForestClassifier(n_estimators=300, n_jobs=-1)
model.fit(x_train,y_train)
preds = model.predict_proba(x_test)

#preds = np.array([[0.0, 0.0, 0.1, 0.9, 0.0, 0.0, 0.0],
#                  [ 0.2, 0.1, 0.1, 0.3, 0.1, 0.0, 0.2],
#                  [ 0.1 ,0.1, 0.1, 0.1, 0.1, 0.4, 0.1],
#                  [ 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])

r = np.zeros(preds.shape[0], dtype=int)
t = np.argwhere(preds>=0.9)

r[t[:,0]] = t[:,1]+1
r
array([4, 0, 0, 1])

给TA打赏
共{{data.count}}人
人已打赏
解决方案

Django在配置STATIC_ROOT的情况下不提供静态文件服务。

2022-4-20 21:00:09

解决方案

从表中选择数值,其中数值以逗号分隔的字符串。

2022-4-20 21:00:11

0 条回复 A文章作者 M管理员
    暂无讨论,说说你的看法吧
个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索