EM Algorithm in Practice

Published: 2025/1/21 · 豆豆 (collected by 生活随笔)

1. Basic EM algorithm

np.random.multivariate_normal(mean, cov, size) draws samples from a multivariate normal distribution with the given mean vector and covariance matrix.
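When HeightWeight.csv is not at hand, np.random.multivariate_normal can generate a stand-in data set. The sketch below draws a hypothetical two-class height/weight sample; the means and covariances are made-up illustrative values, not statistics from the original data.

```python
import numpy as np

np.random.seed(0)

# Hypothetical parameters for illustration only (NOT estimated from HeightWeight.csv).
# Each sample is [height in cm, weight in kg].
mean_female = [160.0, 52.0]
cov_female = [[36.0, 15.0],
              [15.0, 25.0]]
mean_male = [173.0, 66.0]
cov_male = [[49.0, 20.0],
            [20.0, 36.0]]

# Draw 200 samples per class; each call returns an array of shape (size, 2)
female = np.random.multivariate_normal(mean_female, cov_female, size=200)
male = np.random.multivariate_normal(mean_male, cov_male, size=200)

X = np.vstack((female, male))                      # shape (400, 2)
y = np.concatenate((np.zeros(200, dtype=int),      # label 0 = female
                    np.ones(200, dtype=int)))      # label 1 = male
print(X.shape, y.shape)
print('sample means:', female.mean(axis=0), male.mean(axis=0))
```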
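  • The order of the fitted GMM components is arbitrary, so matching each estimated component to a true class requires the sample labels as well as the features.
    The data in this program are the heights and weights of women and men; women are labeled 0 and men 1. From prior knowledge, men are taller than women on average, so comparing the height coordinates of the two estimated means tells us which component corresponds to women and which to men.
    The variable flag records the result: flag = 0 means women correspond to the first component.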
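  • Build a discrete two-color colormap: cmp_point = mpl.colors.ListedColormap(['#22B14C','#ED1C24'])
  • When drawing the scatter plot, pass the class values through the c parameter and the colormap through cmap so that points of different classes get different colors.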
  • The complete program (it reads ../HeightWeight.csv with the columns Height(cm), Weight(kg) and Sex):

    import numpy as np
    import pandas as pd
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    from matplotlib import patches as mpatches
    from scipy.stats import multivariate_normal
    from sklearn.mixture import GaussianMixture
    from sklearn.model_selection import train_test_split

    np.random.seed(0)
    data = pd.read_csv('../HeightWeight.csv')
    print(data.head())
    feature = data[['Height(cm)', 'Weight(kg)']]
    label = data['Sex']                      # 0 = female, 1 = male
    train_x, test_x, train_y, test_y = train_test_split(feature, label, test_size=0.3)
    print(train_x.shape, train_y.shape)

    # Model: a two-component GMM fitted without using the labels
    gmm = GaussianMixture(n_components=2, covariance_type='full', max_iter=100)
    gmm.fit(train_x)
    print('means:\n', gmm.means_)
    mu1, mu2 = gmm.means_
    cov1, cov2 = gmm.covariances_

    # Build a Gaussian from each set of estimated parameters
    norm1 = multivariate_normal(mu1, cov1)
    norm2 = multivariate_normal(mu2, cov2)

    # Density of each training sample under each component
    # (the mixture weights gmm.weights_ are not part of this comparison)
    tau1 = norm1.pdf(train_x)
    tau2 = norm2.pdf(train_x)

    # Decide which component is female, using the prior knowledge that
    # women's mean height is lower than men's
    flag = 0
    if gmm.means_[0][0] < gmm.means_[1][0]:
        c1 = tau1 > tau2                     # women are component 0
    else:
        flag = 1
        c1 = tau1 < tau2                     # women are component 1
    c2 = ~c1                                 # True = predicted male (label 1)

    # Predict the test set with the same component-to-class mapping
    tau1_test = norm1.pdf(test_x)
    tau2_test = norm2.pdf(test_x)
    if flag:
        c1_test = tau1_test < tau2_test
    else:
        c1_test = tau1_test > tau2_test
    c2_test = ~c1_test

    # Grid covering the feature range, used to color the background by predicted class
    height_min, height_max = data['Height(cm)'].min(), data['Height(cm)'].max()
    weight_min, weight_max = data['Weight(kg)'].min(), data['Weight(kg)'].max()
    x = np.linspace(height_min - 0.5, height_max + 0.5, 300)
    y = np.linspace(weight_min - 0.5, weight_max + 0.5, 300)
    xx, yy = np.meshgrid(x, y)
    grid_test = pd.DataFrame(np.stack((xx.flat, yy.flat), axis=1), columns=feature.columns)
    grid_predict = gmm.predict(grid_test)
    if flag:
        grid_predict = 1 - grid_predict      # make 0 = female, 1 = male, matching the legend

    cmp_point = mpl.colors.ListedColormap(['#22B14C', '#ED1C24'])
    cmp_bkg = mpl.colors.ListedColormap(['#B0E0E6', '#FFC0CB'])
    plt.pcolormesh(xx, yy, grid_predict.reshape(xx.shape), cmap=cmp_bkg, shading='auto')
    plt.xlabel('Height(cm)')
    plt.ylabel('Weight(kg)')
    # Training points (circles) colored by true label, test points (triangles) by predicted label
    plt.scatter(train_x['Height(cm)'], train_x['Weight(kg)'], c=train_y, marker='o', cmap=cmp_point)
    plt.scatter(test_x['Height(cm)'], test_x['Weight(kg)'], c=c2_test.astype(int), marker='^', s=60, cmap=cmp_point)
    patches = [mpatches.Patch(color='#B0E0E6', label='girl'),
               mpatches.Patch(color='#FFC0CB', label='boy')]
    plt.legend(handles=patches, fancybox=True, framealpha=0.8)
    plt.show()

    train_acc = np.mean(train_y == c2)
    test_acc = np.mean(test_y == c2_test)
    print('train acc:', train_acc)
    print('test acc:', test_acc)
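The component-matching step described in the bullets above can also be written with an index map instead of a flag. This is a variant sketch, not the original program: the CSV is replaced by hypothetical synthetic data with made-up parameters, comp_to_label is a hypothetical name, and prediction uses gmm.predict, which (unlike comparing norm1.pdf and norm2.pdf directly) also takes the mixture weights into account.

```python
import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import train_test_split

np.random.seed(0)
# Hypothetical stand-in for HeightWeight.csv: rows are [height cm, weight kg]
female = np.random.multivariate_normal([160, 52], [[36, 15], [15, 25]], size=200)
male = np.random.multivariate_normal([173, 66], [[49, 20], [20, 36]], size=200)
X = np.vstack((female, male))
y = np.concatenate((np.zeros(200, dtype=int), np.ones(200, dtype=int)))  # 0 = female, 1 = male

train_x, test_x, train_y, test_y = train_test_split(X, y, test_size=0.3, random_state=0)

gmm = GaussianMixture(n_components=2, covariance_type='full', max_iter=100, random_state=0)
gmm.fit(train_x)  # the labels are NOT used for fitting

# Prior knowledge: the component with the smaller mean height is female (label 0).
# comp_to_label[j] is the class label assigned to GMM component j.
comp_to_label = np.empty(2, dtype=int)
comp_to_label[np.argsort(gmm.means_[:, 0])] = [0, 1]

train_pred = comp_to_label[gmm.predict(train_x)]
test_pred = comp_to_label[gmm.predict(test_x)]
print('train acc:', np.mean(train_pred == train_y))
print('test acc:', np.mean(test_pred == test_y))
```

Sorting by mean height generalizes to more than two components, as long as the prior knowledge gives a clear ordering of the classes.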

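The program relies on scikit-learn, so the EM iterations themselves stay hidden. Below is a minimal from-scratch sketch of EM for a Gaussian mixture, assuming full covariance matrices, k distinct random data points as initial means, a fixed number of iterations instead of a convergence test, and a small ridge reg added to each covariance for numerical stability; em_gmm is a hypothetical helper, not part of any library. The E-step computes responsibilities γ_ij ∝ π_j·N(x_i | μ_j, Σ_j); the M-step re-estimates π, μ and Σ as responsibility-weighted averages. The data are hypothetical synthetic heights and weights.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal


def em_gmm(X, k, n_iter=100, seed=0, reg=1e-6):
    """Minimal EM for a k-component Gaussian mixture with full covariance matrices."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    # Initialization (an assumption): k distinct data points as means,
    # the overall data covariance for every component, equal weights
    mu = X[rng.choice(n, size=k, replace=False)].astype(float)
    cov = np.array([np.cov(X, rowvar=False) + reg * np.eye(d) for _ in range(k)])
    pi = np.full(k, 1.0 / k)
    history = []

    for _ in range(n_iter):
        # E-step: log(pi_j * N(x_i | mu_j, cov_j)), normalized in log space for stability
        log_p = np.column_stack([np.log(pi[j]) + multivariate_normal(mu[j], cov[j]).logpdf(X)
                                 for j in range(k)])
        log_norm = logsumexp(log_p, axis=1, keepdims=True)
        gamma = np.exp(log_p - log_norm)          # responsibilities, each row sums to 1
        history.append(log_norm.sum())            # log-likelihood of the current parameters

        # M-step: closed-form updates weighted by the responsibilities
        nk = gamma.sum(axis=0)
        pi = nk / n
        mu = (gamma.T @ X) / nk[:, None]
        for j in range(k):
            diff = X - mu[j]
            cov[j] = (gamma[:, j, None] * diff).T @ diff / nk[j] + reg * np.eye(d)

    return pi, mu, cov, np.array(history)


# Usage on hypothetical two-class height/weight data
np.random.seed(0)
female = np.random.multivariate_normal([160, 52], [[36, 15], [15, 25]], size=200)
male = np.random.multivariate_normal([173, 66], [[49, 20], [20, 36]], size=200)
X = np.vstack((female, male))

pi, mu, cov, history = em_gmm(X, k=2)
order = np.argsort(mu[:, 0])        # component with the smaller mean height first
print('weights:', pi[order])
print('means:\n', mu[order])
print('log-likelihood: first %.2f, last %.2f' % (history[0], history[-1]))
```

Each E/M pair cannot decrease the log-likelihood (up to the tiny ridge term), so history should be non-decreasing. As with the library version, the components come out in arbitrary order, hence the argsort on mean height.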
2. GMM parameters

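  • Covariance type: covariance_type is one of 'spherical', 'diag', 'tied', 'full'.
    'spherical': each component has a single variance; 'diag': each component has its own diagonal covariance matrix; 'tied': all components share one full covariance matrix; 'full' (the default): each component has its own full covariance matrix.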
  • BIC
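    $\mathrm{BIC} = k\ln(n) - 2\ln\hat{L}$
    where k is the number of free parameters in the model, n is the number of samples, and $\hat{L}$ is the maximized value of the likelihood function. Lower BIC is better. The penalty term $k\ln(n)$ grows with the number of parameters, so it discourages overly complex models; this helps avoid overfitting (the "curse of dimensionality") when the model has many parameters relative to the number of training samples.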
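scikit-learn's GaussianMixture exposes this criterion as the bic(X) method, computed as k ln(n) - 2 ln L with L the maximized likelihood. The sketch below, on hypothetical synthetic height/weight data, compares all four covariance types with 1 to 4 components, keeps the lowest BIC, and recomputes one BIC by hand from score(X), which returns the mean log-likelihood per sample. The parameter count k = (K - 1) + K·d + K·d(d+1)/2 holds for a 'full' model with K components in d dimensions.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

np.random.seed(0)
# Hypothetical two-class data: rows are [height cm, weight kg]
female = np.random.multivariate_normal([160, 52], [[36, 15], [15, 25]], size=200)
male = np.random.multivariate_normal([173, 66], [[49, 20], [20, 36]], size=200)
X = np.vstack((female, male))
n, d = X.shape

# Fit every (covariance_type, n_components) pair and record its BIC (lower is better)
results = {}
for cov_type in ('spherical', 'diag', 'tied', 'full'):
    for k in range(1, 5):
        gmm = GaussianMixture(n_components=k, covariance_type=cov_type, random_state=0).fit(X)
        results[(cov_type, k)] = gmm.bic(X)

best = min(results, key=results.get)
print('best (covariance_type, n_components):', best)

# Recompute the BIC by hand for a 2-component 'full' model:
# (K - 1) mixture weights + K*d means + K*d*(d+1)/2 covariance entries
K = 2
gmm_full = GaussianMixture(n_components=K, covariance_type='full', random_state=0).fit(X)
n_params = (K - 1) + K * d + K * d * (d + 1) // 2
log_lik = gmm_full.score(X) * n          # score() is the mean log-likelihood per sample
bic_manual = n_params * np.log(n) - 2 * log_lik
print(bic_manual, gmm_full.bic(X))
```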
3. DPGMM

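DPGMM (a Gaussian mixture with a Dirichlet-process prior on the mixture weights) is useful when the number of clusters is unknown: n_components only sets an upper bound, and the prior drives the weights of unneeded components toward zero. In scikit-learn it is BayesianGaussianMixture with weight_concentration_prior_type='dirichlet_process'; a smaller weight_concentration_prior tends to leave fewer components active.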

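    from sklearn.mixture import BayesianGaussianMixture

    # n_components is only an upper bound on the number of clusters
    dpgmm = BayesianGaussianMixture(n_components=n_components, covariance_type='full', max_iter=1000, n_init=5,
                                    weight_concentration_prior_type='dirichlet_process', weight_concentration_prior=0.1)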
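A minimal runnable sketch of the DPGMM call above on hypothetical synthetic data with two true clusters. The value n_components = 5 is an assumed upper bound, random_state=0 is added only for reproducibility, and the 0.05 weight cut-off used to count "effective" clusters is an arbitrary choice for illustration.

```python
import numpy as np
from sklearn.mixture import BayesianGaussianMixture

np.random.seed(0)
# Hypothetical data with two true clusters: rows are [height cm, weight kg]
female = np.random.multivariate_normal([160, 52], [[36, 15], [15, 25]], size=200)
male = np.random.multivariate_normal([173, 66], [[49, 20], [20, 36]], size=200)
X = np.vstack((female, male))

n_components = 5   # assumed upper bound on the number of clusters
dpgmm = BayesianGaussianMixture(n_components=n_components, covariance_type='full',
                                max_iter=1000, n_init=5,
                                weight_concentration_prior_type='dirichlet_process',
                                weight_concentration_prior=0.1, random_state=0)
dpgmm.fit(X)

# Components the data do not need end up with weights close to zero
print('weights:', np.round(dpgmm.weights_, 3))
n_effective = int(np.sum(dpgmm.weights_ > 0.05))   # 0.05 is an arbitrary cut-off
print('effective clusters:', n_effective)
print('components actually used by predict:', np.unique(dpgmm.predict(X)))
```

With a small weight_concentration_prior, most of the weight typically concentrates on as many components as the data support, while the remaining components keep near-zero weights.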
