LDA线性判别分析鸢尾花预测
生活随笔
收集整理的这篇文章主要介绍了
LDA线性判别分析鸢尾花预测
小编觉得挺不错的，现在分享给大家，帮大家做个参考。
LDA线性判别分析鸢尾花预测
文章目录
- LDA线性判别分析鸢尾花预测
- 数学原理
- 代码实现
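数学原理
对两类样本 $X_a$、$X_b$（均值向量分别为 $\mu_a$、$\mu_b$），Fisher LDA 寻找投影方向 $w$，使投影后类间距离尽量大、类内散度尽量小：
$$J(w)=\frac{w^\top S_b\,w}{w^\top S_w\,w},\qquad S_b=(\mu_a-\mu_b)(\mu_a-\mu_b)^\top,\qquad S_w=\Sigma_a+\Sigma_b,\qquad \Sigma_k=\sum_{x\in X_k}(x-\mu_k)(x-\mu_k)^\top$$
其最优解为 $w\propto S_w^{-1}(\mu_a-\mu_b)$。三分类时对每一对类别分别求出一个 $w$，把测试样本投影后归入投影均值更近的那一类，再由三组结果投票决定最终类别。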
代码实现
数据集下载链接:https://www.kaggle.com/uciml/iris/download
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @File : LDA.py
# @Author: Gowi
# @Date : 2021/3/25
# @Desc : Pairwise (one-vs-one) Fisher LDA classifier for the Iris dataset.
import pandas as pd
import numpy as np


def Sigma(Iris, u):
    """Return the scatter matrix of the rows of *Iris* about the mean *u*.

    Computes sum_i (x_i - u)^T (x_i - u) over every row x_i of *Iris*.
    All rows are used. An earlier version summed exactly the first 30 rows,
    which silently dropped extra samples and raised IndexError on fewer.

    Args:
        Iris: 2-D array with one sample per row.
        u: mean vector with one entry per column of *Iris*.

    Returns:
        Square (n_features, n_features) scatter matrix. For an empty *Iris*
        it is all zeros.
    """
    centered = np.asarray(Iris, dtype=float) - u
    return centered.T @ centered


def predict(Iris_test):
    """Classify every row of *Iris_test* by one-vs-one voting.

    Each sample is projected onto the three pairwise Fisher directions
    W12, W13 and W23. These are module globals assigned by the script
    section below. For each pair, the sample votes for the class whose
    projected mean (U12_1/U12_2, U13_1/U13_2, U23_1/U23_2) is nearer.
    On a tie (e.g. a 1-1-1 vote cycle), the lowest class index wins.

    Args:
        Iris_test: 2-D array with one sample per row. Every row is
            classified; an earlier version only looked at the first 20.

    Returns:
        Tuple (num1, num2, num3): the number of samples assigned to
        class 1, 2 and 3 respectively.
    """
    counts = [0, 0, 0]
    # (direction, projected mean of class a, of class b, index a, index b)
    pairs = (
        (W12, U12_1, U12_2, 0, 1),
        (W13, U13_1, U13_2, 0, 2),
        (W23, U23_1, U23_2, 1, 2),
    )
    for x in Iris_test:
        votes = [0, 0, 0]
        for w, mean_a, mean_b, a, b in pairs:
            proj = np.dot(w, x)
            if np.abs(proj - mean_a) < np.abs(proj - mean_b):
                votes[a] += 1
            else:
                votes[b] += 1
        # list.index returns the first maximum: same class1 > class2 > class3
        # tie-break order as the original if/elif chain.
        counts[votes.index(max(votes))] += 1
    return tuple(counts)


if __name__ == "__main__":
    # Read the dataset. header=None keeps the CSV header line as row 0, so
    # every data slice starts at row 1. Columns 1:5 are taken as the four
    # features; this assumes the Kaggle layout (Id, 4 features, Species) and
    # rows grouped by class in blocks of 50 -- confirm against the file.
    df = pd.read_csv(r"Iris.csv", header=None)

    # Split: first 30 samples of each class for training, last 20 for test.
    # (np.float was removed in NumPy 1.24; the builtin float is equivalent.)
    Iris1_train = df.values[1:31, 1:5].astype(float)
    Iris2_train = df.values[51:81, 1:5].astype(float)
    Iris3_train = df.values[101:131, 1:5].astype(float)
    Iris1_test = df.values[31:51, 1:5].astype(float)
    Iris2_test = df.values[81:101, 1:5].astype(float)
    Iris3_test = df.values[131:151, 1:5].astype(float)

    # Iris class names (used only in the printed report).
    Iris1_class = 'Iris-setosa'
    Iris2_class = 'Iris-versicolor'
    Iris3_class = 'Iris-virginica'

    # Per-class mean vectors.
    u1 = np.mean(Iris1_train, axis=0)
    u2 = np.mean(Iris2_train, axis=0)
    u3 = np.mean(Iris3_train, axis=0)
    print("均值向量u1")
    print(u1)

    # Per-class scatter matrices.
    sigma1 = Sigma(Iris1_train, u1)
    sigma2 = Sigma(Iris2_train, u2)
    sigma3 = Sigma(Iris3_train, u3)
    print("类内散度矩阵sigma1")
    print(sigma1)

    # Pairwise within-class scatter matrices.
    Sw12 = sigma1 + sigma2
    Sw13 = sigma1 + sigma3
    Sw23 = sigma2 + sigma3  # was sigma2 + sigma2
    print("类内散度矩阵Sw12")
    print(Sw12)

    # Pairwise between-class scatter (u_a - u_b)(u_a - u_b)^T. np.dot on the
    # 1-D difference vectors gave a scalar, so np.outer is used. These are not
    # needed for the Fisher direction below and are kept for reference.
    Sb12 = np.outer(u1 - u2, u1 - u2)
    Sb13 = np.outer(u1 - u3, u1 - u3)
    Sb23 = np.outer(u2 - u3, u2 - u3)

    # Fisher projection directions w = Sw^-1 (u_a - u_b). solve() is used
    # instead of forming the explicit inverse.
    W12 = np.linalg.solve(Sw12, u1 - u2)
    W13 = np.linalg.solve(Sw13, u1 - u3)
    W23 = np.linalg.solve(Sw23, u2 - u3)
    print("斜率W12")
    print(W12)

    # Projected class means for each pair.
    U12_1 = np.dot(W12, u1)
    U12_2 = np.dot(W12, u2)
    U13_1 = np.dot(W13, u1)
    U13_2 = np.dot(W13, u3)  # was u2: the 1-vs-3 direction needs class 3's mean
    U23_1 = np.dot(W23, u2)
    U23_2 = np.dot(W23, u3)
    print("投影后的均值点U12_1")
    print(U12_1)

    # Predict each test set and count how many land in their true class.
    predict_1, _, _ = predict(Iris1_test)
    print("判断为" + Iris1_class + "的个数")
    print(predict_1)
    _, predict_2, _ = predict(Iris2_test)
    print("判断为" + Iris2_class + "的个数")
    print(predict_2)
    _, _, predict_3 = predict(Iris3_test)
    print("判断为" + Iris3_class + "的个数")
    print(predict_3)

    # Accuracy over the actual number of test samples (60 for the split above).
    n_test = len(Iris1_test) + len(Iris2_test) + len(Iris3_test)
    print("准确率为")
    print((predict_1 + predict_2 + predict_3) / n_test * 100, "%")
以上是生活随笔为你收集整理的LDA线性判别分析鸢尾花预测的全部内容，希望文章能够帮你解决所遇到的问题。
- 上一篇: 损失函数MSELoss和CELoss
- 下一篇: [高项]已知风险VS未知风险