首页 > 编程知识 正文

用逻辑回归实现鸢尾花分类,使用逻辑回归实现鸢尾花分类

时间:2023-05-06 04:02:47 阅读:262115 作者:621

基于逻辑回归模型对鸢尾花数据集进行分类 理论知识

        不做过多赘述,相关知识有:指数分布族、GLM建模(分布函数+连接函数,对于本例来说是二项分布+sigmoid函数)、最大似然函数、交叉熵函数(评估逻辑回归模型的目标函数)。该分类问题关注的是通过已知的概率结果来推算出未知参数。对未知参数做自变量的对数似然函数求导,解得极值处的方程并代入连接函数,根据分布类型,即可推导出sigmoid函数,这便是该模型的来历。由这个函数我们便能把线性模型映射到0~1来解决概率问题。

直接上代码 import numpy as npimport seaborn as snsfrom pandas import read_csvfrom pandas.plotting import scatter_matrixfrom matplotlib import pyplotfrom sklearn.model_selection import train_test_splitfrom sklearn.linear_model import LogisticRegressionfrom sklearn.metrics import accuracy_scorefrom sklearn import metrics#导入鸢尾花数据集filename = 'iris.csv'names = ['separ-length', 'separ-width','petal-length','petal-width','class']dataset = read_csv(filename, names=names)#查看数据def display(data): # 显示数据维度 print('数据维度:行 %s,列 %sn'%data.shape) # 查看数据的前十行 print(data.head(20)) # 统计描述数据信息 print(data.describe()) # 分类分布情况 print(data.groupby('class').size())# 可视化统计数据def picture(data): # 箱线图 箱线图用来展示属性和中位值的离散程度 data.plot(kind='box', subplots=True,layout=(2,2), sharex=False,sharey=False) # 直方图 直方图显示每一个特征属性的分布情况 data.hist() # 散点矩阵图 散点矩阵图用来展示每个属性之间的影响关系 scatter_matrix(data) #绘图 pyplot.show()#逻辑回归模型训练与预测def LR_train(data): array = data.values #提取数值 X = array[:, 0:4] #提取样本特征 Y = array[:,4] #提取标签 validation_size = 0.8 #八成数据训练 两成数据评估 seed = 0 #随机数种子为零 # 划分训练集和测试集 X_train, X_validation, Y_train,Y_validation = train_test_split(X, Y, test_size = validation_size,random_state=seed) #构造逻辑回归模型并调用 LR = LogisticRegression(max_iter=10000) LR.fit(X_train, Y_train) y_pred= LR.predict(X_validation) print("模型精度:{:.2f}".format(np.mean(y_pred==Y_validation))) print("模型精度:{:.2f}".format(LR.score(X_validation,Y_validation))) X_new = np.array([[5.8,3.1,5.0,1.7]]) #预测目标 prediction = LR.predict(X_new) print("预测的目标类别是:{}".format(prediction)) #查看混淆矩阵(预测值和真实值的各类情况统计矩阵) confusion_matrix_result=metrics.confusion_matrix(y_pred,Y_validation) print('The confusion matrix result:n',confusion_matrix_result) #利用热力图对于结果进行可视化 pyplot.figure(figsize=(8,6)) sns.heatmap(confusion_matrix_result,annot=True,cmap='Blues') pyplot.xlabel('Predictedlabels') pyplot.ylabel('Truelabels') pyplot.show() if __name__ == '__main__': display(dataset) picture(dataset) LR_train(dataset) 一些问题

可能是数据集只有150个的原因,当我的训练集与测试集数量二八开时,模型进行三分类精度能达到100%。当数量比变成八二开时,精确度依然可以达到90%

版权声明:该文观点仅代表作者本人。处理文章:请发送邮件至 三1五14八八95#扣扣.com 举报,一经查实,本站将立刻删除。