欢迎访问 生活随笔!

生活随笔

当前位置: 首页 > 编程资源 > 编程问答 >内容正文

编程问答

LAD线性判别分析鸢尾花预测

发布时间:2024/3/12 编程问答 48 豆豆
生活随笔 收集整理的这篇文章主要介绍了 LAD线性判别分析鸢尾花预测 小编觉得挺不错的,现在分享给大家,帮大家做个参考.

LAD线性判别分析鸢尾花预测

文章目录

  • LAD线性判别分析鸢尾花预测
    • 数学原理
  • 代码实现

数学原理

代码实现

数据集下载链接:https://www.kaggle.com/uciml/iris/download

#!/usr/bin/env python # -*- coding: utf-8 -*- # @File : LDA.py # @Author: Gowi # @Date : 2021/3/25 # @Desc :import pandas as pd import numpy as np# 计算协方差矩阵 def Sigma(Iris, u):s = np.zeros((4, 4))for i in range(30):a = Iris[i, :] - ua = np.array([a])s = s + np.dot(a.T, a)return sdef predict(Iris_test):num1, num2, num3 = 0, 0, 0for i in range(20):acc1, acc2, acc3 = 0, 0, 0U12_test = np.dot(W12.T, Iris_test[i])U13_test = np.dot(W13.T, Iris_test[i])U23_test = np.dot(W23.T, Iris_test[i])if np.abs(U12_test - U12_1) < np.abs(U12_test - U12_2):acc1 += 1else:acc2 += 1if np.abs(U13_test - U13_1) < np.abs(U13_test - U13_2):acc1 += 1else:acc3 += 1if np.abs(U23_test - U23_1) < np.abs(U23_test - U23_2):acc2 += 1else:acc3 += 1acc = max(acc1, acc2, acc3)if acc == acc1:num1 += 1elif acc == acc2:num2 += 1else:num3 += 1return num1, num2, num3# 读取数据集 df = pd.read_csv(r"Iris.csv", header=None) # 拆分数据集 Iris1_train = df.values[1:31, 1:5] Iris2_train = df.values[51:81, 1:5] Iris3_train = df.values[101:131, 1:5] Iris1_test = df.values[31:51, 1:5] Iris2_test = df.values[81:101, 1:5] Iris3_test = df.values[131:151, 1:5] # 鸢尾花的类别 Iris1_class = 'Iris-setosa' Iris2_class = 'Iris-versicolor' Iris3_class = 'Iris-virginica' # 转换为float Iris1_train = Iris1_train.astype(np.float) Iris2_train = Iris2_train.astype(np.float) Iris3_train = Iris3_train.astype(np.float) Iris1_test = Iris1_test.astype(np.float) Iris2_test = Iris2_test.astype(np.float) Iris3_test = Iris3_test.astype(np.float) # 均值向量 u1 = np.mean(Iris1_train, axis=0) u2 = np.mean(Iris2_train, axis=0) u3 = np.mean(Iris3_train, axis=0) print("均值向量u1") print(u1) # 协方差矩阵 sigma1 = Sigma(Iris1_train, u1) sigma2 = Sigma(Iris2_train, u2) sigma3 = Sigma(Iris3_train, u3) print("类内散度矩阵sigma1") print(sigma1) # 类内散度矩阵 Sw12 = sigma1 + sigma2 Sw13 = sigma1 + sigma3 Sw23 = sigma2 + sigma2 print("类内散度矩阵Sw12") print(Sw12) # 类间散度矩阵 Sb12 = np.dot(np.array(u1 - u2), np.array(u1 - u2).T) Sb13 = np.dot(np.array(u1 - u3), np.array(u1 - u3).T) Sb23 = np.dot(np.array(u2 - u3), np.array(u2 - u3).T) # 斜率 W12 = np.dot(np.linalg.inv(Sw12), (u1 - u2)) W13 = np.dot(np.linalg.inv(Sw13), (u1 - u3)) W23 = np.dot(np.linalg.inv(Sw23), (u2 - u3)) print("斜率W12") print(W12) # 投影后的均值点 U12_1 = np.dot(W12.T, u1) U12_2 = np.dot(W12.T, u2) U13_1 = np.dot(W13.T, u1) U13_2 = np.dot(W13.T, u2) U23_1 = np.dot(W23.T, u2) U23_2 = np.dot(W23.T, u3) print("投影后的均值点U12_1") print(U12_1) # 预测Iris predict_1, _, _ = predict(Iris1_test) print("判断为" + Iris1_class + "的个数") print(predict_1) _, predict_2, _ = predict(Iris2_test) print("判断为" + Iris2_class + "的个数") print(predict_2) _, _, predict_3 = predict(Iris3_test) print("判断为" + Iris3_class + "的个数") print(predict_3) print("准确率为") print((predict_1 + predict_2 + predict_3) / 60 * 100, "%")

总结

以上是生活随笔为你收集整理的LAD线性判别分析鸢尾花预测的全部内容,希望文章能够帮你解决所遇到的问题。

如果觉得生活随笔网站内容还不错,欢迎将生活随笔推荐给好友。