星驰编程网

免费编程资源分享平台_编程教程_代码示例_开发技术文章

【Python机器学习系列】建立XGBoost模型预测小麦品种(源码)

这是我的第344篇原创文章。

一、引言

对于表格数据,一套完整的机器学习建模流程如下:

针对不同的数据集,有些步骤不适用,其中橘红色框为必要步骤,欢迎大家关注翻看我之前的一些相关文章。前面我介绍了机器学习模型的二分类任务和回归任务接下来做一下机器学习的多分类系列由于本系列案例数据质量较高,有些步骤跳过了,跳过的步骤将单独出文章总结!在Python中,可以使用Scikit-learn库来构建XGBoost分类模型进行多分类预测,本文以预测小麦品种为例,对这个过程做一个简要解读。

二、实现过程

2.1 准备数据

data = pd.read_csv(r'data.csv')
df = pd.DataFrame(data)
print(df.head())

df:

2.2 提取目标变量

target = 'Type'
features = df.columns.drop(target)
print(data["Type"].value_counts()) # 顺便查看一下样本是否平衡

2.3 划分数据集

# df = shuffle(df)
X_train, X_test, y_train, y_test = train_test_split(df[features], df[target], test_size=0.2, random_state=0)

2.4 标签编码

le = LabelEncoder()
y_train = le.fit_transform(y_train)

XGBoost分类模型的标签需要从0开始。

2.5 模型的构建

model = LGBMClassifier()

2.6 模型的训练

model.fit(X_train, y_train)

2.7 模型的推理

y_test = le.transform(y_test)
y_pred = model.predict(X_test)
y_scores = model.predict_proba(X_test)
print(y_pred)

2.8 模型的评价

acc = accuracy_score(y_test, y_pred) # 准确率acc
print(f"acc: \n{acc}")
cm = confusion_matrix(y_test, y_pred) # 混淆矩阵
print(f"cm: \n{cm}")
cr = classification_report(y_test, y_pred) # 分类报告
print(f"cr:  \n{cr}")

结果:

混淆矩阵:

ROC:

作者简介: 读研期间发表6篇SCI数据算法相关论文,目前在某研究院从事数据算法相关研究工作,结合自身科研实践经历持续分享关于Python、数据分析、特征工程、机器学习、深度学习、人工智能系列基础知识与案例。关注gzh:数据杂坛,获取数据和源码学习更多内容。

原文链接:

【Python机器学习系列】建立XGBoost模型预测小麦品种(案例+源码)

控制面板
您好,欢迎到访网站!
  查看权限
网站分类
最新留言