sklearn GMM BIC 模型选择

    xiaoxiao2025-02-06  10

    # Model selection for Gaussian mixture models with BIC.
    #
    # BIC combines the likelihood with the number of parameters and the sample
    # size; the model with the smallest BIC is selected.
    #   * np.inf is positive infinity (the older alias np.infty was removed in
    #     NumPy 2.0), used here as the initial "best so far" value.
    #   * Calling .bic(X) on a fitted mixture model returns its BIC directly.
    #   * itertools.cycle builds an endlessly repeating iterator; zip stops at
    #     the shortest input, so zipping with a cycle is safe.
    #   * The "*" marker is placed at a convex combination of the extreme
    #     scores: bic.min() * 0.97 + 0.03 * bic.max().
    #
    # Below is an example that uses BIC to choose a GMM.
import itertools

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from scipy import linalg
from sklearn import mixture


def _full_covariances(model):
    """Return one full (n_features x n_features) covariance per component.

    ``covariances_`` has a different shape for every ``covariance_type``;
    the eigen-decomposition used for plotting needs a square matrix per
    component, so the compact forms are expanded here.
    """
    cov = model.covariances_
    n_features = model.means_.shape[1]
    if model.covariance_type == "full":
        return cov
    if model.covariance_type == "tied":
        # A single matrix shared by every component.
        return np.repeat(cov[np.newaxis], model.n_components, axis=0)
    if model.covariance_type == "diag":
        return np.array([np.diag(c) for c in cov])
    # "spherical": a single variance per component.
    return np.array([c * np.eye(n_features) for c in cov])


# Generate the sample: two components, one stretched and one isotropic.
n_samples = 500
np.random.seed(0)
C = np.array([[0, -0.1], [1.7, 0.4]])
X = np.r_[
    np.dot(np.random.randn(n_samples, 2), C),
    0.7 * np.random.randn(n_samples, 2) + np.array([-6, 3]),
]

# Fit every (covariance type, number of components) pair and keep the model
# with the lowest BIC.
lowest_bic = np.inf
bic = []
n_components_range = range(1, 7)
cv_types = ['spherical', 'tied', 'diag', 'full']
for cv_type in cv_types:
    for n_components in n_components_range:
        gmm = mixture.GaussianMixture(
            n_components=n_components, covariance_type=cv_type
        )
        gmm.fit(X)
        bic.append(gmm.bic(X))
        if bic[-1] < lowest_bic:
            lowest_bic = bic[-1]
            best_gmm = gmm

bic = np.array(bic)
color_iter = itertools.cycle(['k', 'r', 'g', 'b', 'c', 'm', 'y'])
clf = best_gmm
bars = []

# Top panel: BIC scores, one group of bars per covariance type.
spl = plt.subplot(2, 1, 1)
n_per_type = len(n_components_range)
for i, (cv_type, color) in enumerate(zip(cv_types, color_iter)):
    xpos = np.array(n_components_range) + 0.2 * (i - 2)
    bars.append(
        plt.bar(
            xpos,
            bic[i * n_per_type:(i + 1) * n_per_type],
            width=.2,
            color=color,
        )
    )
plt.xticks(n_components_range)
plt.ylim([bic.min() * 1.01 - .01 * bic.max(), bic.max()])
plt.title('BIC score per model')
# bic is laid out type-major, so argmin % n_per_type is the component index
# and argmin // n_per_type is the covariance-type index of the best model.
xpos = (
    np.mod(bic.argmin(), n_per_type)
    + .65
    + .2 * np.floor(bic.argmin() / n_per_type)
)
plt.text(xpos, bic.min() * 0.97 + 0.03 * bic.max(), "*", fontsize=14)
spl.set_xlabel("Number of components")
spl.legend([b[0] for b in bars], cv_types)

# Bottom panel: the data coloured by predicted component, with one ellipse
# per component derived from its covariance.
splot = plt.subplot(2, 1, 2)
Y_ = clf.predict(X)
for i, (mean, covar, color) in enumerate(
    zip(clf.means_, _full_covariances(clf), color_iter)
):
    v, w = linalg.eigh(covar)
    if not np.any(Y_ == i):
        # No point was assigned to this component; nothing to draw.
        continue
    plt.scatter(X[Y_ == i, 0], X[Y_ == i, 1], .8, color=color)

    angle = np.arctan2(w[0][1], w[0][0])
    angle = 180 * angle / np.pi  # Ellipse expects degrees
    v *= 4
    # angle is keyword-only in current Matplotlib releases.
    ell = mpl.patches.Ellipse(mean, v[0], v[1], angle=180 + angle, color=color)
    ell.set_clip_box(splot.bbox)
    ell.set_alpha(.5)
    splot.add_artist(ell)

plt.xlim(-10, 10)
plt.ylim(-3, 6)
plt.xticks(())
plt.yticks(())
plt.title("Selected GMM: full model, 2 components")
plt.subplots_adjust(hspace=.35, bottom=.02)
plt.show()
    转载请注明原文地址: https://ju.6miu.com/read-1296159.html
    最新回复(0)