Java学习者论坛

 找回密码
 立即注册

QQ登录

只需一步,快速开始

手机号码,快捷登录

恭喜Java学习者论坛(https://www.javaxxz.com)已经为数万Java学习者服务超过8年了!积累会员资料超过10000G+
成为本站VIP会员,下载本站10000G+会员资源,购买链接:点击进入购买VIP会员
JAVA高级面试进阶视频教程Java架构师系统进阶VIP课程

分布式高可用全栈开发微服务教程

Go语言视频零基础入门到精通

Java架构师3期(课件+源码)

Java开发全终端实战租房项目视频教程

SpringBoot2.X入门到高级使用教程

大数据培训第六期全套视频教程

深度学习(CNN RNN GAN)算法原理

Java亿级流量电商系统视频教程

互联网架构师视频教程

年薪50万Spark2.0从入门到精通

年薪50万!人工智能学习路线教程

年薪50万!大数据从入门到精通学习路线年薪50万!机器学习入门到精通视频教程
仿小米商城类app和小程序视频教程深度学习数据分析基础到实战最新黑马javaEE2.1就业课程从 0到JVM实战高手教程 MySQL入门到精通教程
查看: 1456|回复: 0

[默认分类] 机器学习算法之聚类算法Kmeans并找出最佳K值的Python实践

[复制链接]
  • TA的每日心情
    开心
    2021-12-13 21:45
  • 签到天数: 15 天

    [LV.4]偶尔看看III

    发表于 2018-3-20 13:58:56 | 显示全部楼层 |阅读模式
    Kmeans聚类算法的大概流程是:①从样本中随机找出K个样本作为中心点;
    ②求所有样本到这些样本的距离,按照最短的进行归类;
    ③求每个聚类中的样本的元素的平均值,作为新的中心点;
    ④继续②,③,知道所有样本再也无法找到新的聚类,就算完成。

    ### 一、接下来使用Numpy实现python代码,测试有效并且带注释:Kmeans.py:

    ```python
    # encoding: utf-8
    """
    Created on 2017年12月11日

    """
    import time

    from numpy import *
    from scipy.cluster.hierarchy import centroid

    import matplotlib.pyplot as plt


    # 计算欧氏距离
    def euclDistance(vector1,vector2):
        return sqrt(sum(power(vector2-vector1,2)))

    # 初始化K个中心点
    def initCentroids(dataSet,k):
        # 拿到数据集的格式 例如[[2,3,4][3,4,5]].shape = (2,3) [1,2,3] = (3,)
        numSamples,dim = dataSet.shape
        # 按照给定的shape,初始化一个数据类型和排列方式的填满0的数组
        centroids = zeros((k,dim))
        for i in range(k):
            index = int(random.uniform(0,numSamples)) #样本集随机挑一个,作为初始质心
            centroids[i,:] = dataSet[index,:]
        return centroids

    # k-means cluster
    def kmeans(dataSet,k):
        numSamples = dataSet.shape[0]
        #mat 对数组转换用于线性操作,类型变为:numpy.matrixlib.defmatrix.matrix
        # 初始化一个二维数据,第一列存储样本属于哪个聚类 第二列存储样本和中心的距离 [[0,0],[0,0]... ...]
        clusterASSMent = mat(zeros((numSamples,2)))
        clusterChanged = True
       
        ## 步骤1:初始化中心点
        centroids = initCentroids(dataSet, k)
       
        while clusterChanged:
            clusterChanged = False
            # 遍历每个样本
            for i in range(numSamples):
                minDist = 1000000000.0 #与样本点最近族群距离
                minIndex = 0 #所属族
                #步骤2 找到一个最近的中心点
                for j in range(k):
                    distance = euclDistance(centroids[j,:], dataSet[i,:]) #计算每个点到样本点的距离,找出最近的那一个样本
                    if distance  len(mark):
            print("Sorry! Your k is too large!please contact Zouxy")
            return 1
       
        for i in range(numSamples):
            markIndex = int(clusterAssment[i,0]) #每个样本所属族群
            plt.plot(dataSet[i,0],dataSet[i,1],mark[markIndex])
            
        mark = ["Dr","Db","Dg","Dk","^b","+b","sb","db","<b","pb"]
        for i in range(k):
            plt.plot(centroids[i,0],centroids[i,1],mark,markersize = 6)
            
        plt.show()
    ```

    ```python
    # encoding: utf-8
    """
    Created on 2017年12月11日

    """
    from numpy import *
    import time
    import matplotlib.pyplot as plt
    import kmeans.Kmeans as kean

    #步骤1
    print("step 1:load data...")
    dataSet = []
    fileIn = open("D:\Users\zhangjie116\Downloads\Wholesale customers data.csv")
    i = 0;
    for line in fileIn.readlines():
        i = i + 1
        lineArr = line.strip().split(",")
        if i != 1 :
            dataSet.append([float(lineArr[4]),float(lineArr[5])])
       
    #步骤2
    print("step 2:clustering")
    dataSet = mat(dataSet)
    k = 4
    centroids,clusterAssment = kean.kmeans(dataSet, k)

    #步骤3
    print("step 3:show the result...")
    kean.showCluster(dataSet, k, centroids, clusterAssment)
       
    ```

    运行结果:

    ![](http://aodi.paic.com.cn/forum.php?mod=attachment&aid=NDc2M3xjZjAxMDdiMnwxNTEzMTU3NzEyfDEzMTF8NTA0MA%3D%3D&noupdate=yes)

    ### 二、使用scikit-learn库并且使用轮廓系数找出最佳K值计算

    ```python
    # encoding: utf-8
    """
    Created on 2017年12月13日
    """
    import numpy as np
    from sklearn import cluster
    from sklearn.cluster import KMeans
    import matplotlib.pyplot as plt
    from sklearn import metrics

    #从数据集中加载数据
    dataSet = []
    fileIn = open("D:\Users\zhangjie116\Downloads\Wholesale customers data.csv")
    i = 0;
    for line in fileIn.readlines():
        i = i + 1
        lineArr = line.strip().split(",")
        if i != 1 :
            dataSet.append([float(lineArr[4]),float(lineArr[5])])
            
    #代码生成数据集
    cluster1=np.random.uniform(0.5,1.5,(2,10))
    cluster2=np.random.uniform(3.5,4.5,(2,10))
    cluster3=np.random.uniform(7.5,8.5,(2,10))
    dataSet=np.hstack((cluster1,cluster2,cluster3)).T        
    max_silhouette_coefficient = 0
    max_k = 0
    max_centroids = []
    max_labels_ = []
    numSamples = 0
    for k in range(2,10):
        #设定K
        clf = KMeans(n_clusters=k)
        #加载数据集合
        s = clf.fit(dataSet)
        #样本数量
        numSamples = len(dataSet)
        #中心点
        centroids = clf.cluster_centers_
        labels_ = clf.labels_
       
        #获取轮廓系数
        silhouette_coefficient = metrics.silhouette_score(dataSet, clf.labels_,metric="euclidean",sample_size=numSamples)
        print "k:%d ==== silhouette_coefficient:%f"%(k,silhouette_coefficient)
        #找到轮廓系数最大的K值,为效果最好的
        if max_silhouette_coefficient < silhouette_coefficient :
            max_silhouette_coefficient = silhouette_coefficient
            max_k = k
            max_centroids = centroids
            max_labels_ = labels_
       
        #获取聚类效果值
        print "k:%d ==== inertia_:%f"%(k,clf.inertia_)
       
    print "max_k:%d ==== max_silhouette_coefficient:%f"%(max_k,max_silhouette_coefficient)   

    #画出所有样例点 属于同一分类的绘制同样的颜色
    mark1 = ["or", "ob", "og", "ok", "^r", "+r", "sr", "dr", "<r", "pr"]
    for i in xrange(numSamples):
        plt.plot(dataSet[0], dataSet[1], mark1[max_labels_]) #mark[markIndex])
    mark2 = ["Dr", "Db", "Dg", "Dk", "^b", "+b", "sb", "db", "<b", "pb"]
    # 画出质点,用特殊图型
    for i in range(max_k):
        plt.plot(max_centroids[i,0], max_centroids[i,1], mark2,markersize = 12)
    plt.show()
    ```

    运行结果:

    ```
    k:2 ==== silhouette_coefficient:0.749433
    k:2 ==== inertia_:93.451706
    k:3 ==== silhouette_coefficient:0.887454
    k:3 ==== inertia_:4.294235
    k:4 ==== silhouette_coefficient:0.746558
    k:4 ==== inertia_:3.366372
    k:5 ==== silhouette_coefficient:0.579589
    k:5 ==== inertia_:2.555257
    k:6 ==== silhouette_coefficient:0.622791
    k:6 ==== inertia_:2.088957
    k:7 ==== silhouette_coefficient:0.452267
    k:7 ==== inertia_:1.602857
    k:8 ==== silhouette_coefficient:0.447269
    k:8 ==== inertia_:1.265230
    k:9 ==== silhouette_coefficient:0.454158
    k:9 ==== inertia_:0.976325
    max_k:3 ==== max_silhouette_coefficient:0.887454
    ```

    ![](http://aodi.paic.com.cn/forum.php?mod=attachment&aid=NDgwMnw0YWRjNjY0ZHwxNTEzMTU3NzEyfDEzMTF8NTA0MA%3D%3D&noupdate=yes)

    回复

    使用道具 举报

    您需要登录后才可以回帖 登录 | 立即注册

    本版积分规则

    QQ|手机版|Java学习者论坛 ( 声明:本站资料整理自互联网,用于Java学习者交流学习使用,对资料版权不负任何法律责任,若有侵权请及时联系客服屏蔽删除 )

    GMT+8, 2025-1-21 18:46 , Processed in 0.331949 second(s), 38 queries .

    Powered by Discuz! X3.4

    © 2001-2017 Comsenz Inc.

    快速回复 返回顶部 返回列表