Original article; when reposting, please credit the source: 慢慢的回味
Link to this article: k-means clustering algorithm (cluster-reuters)
Theoretical Analysis
Computing the cluster centers
1. Randomly pick 20 of the vectors to be clustered as the centers of the 20 clusters.
2. For every point, compute its distance to each center; the point is assigned to the cluster whose center is nearest.
3. Recompute a new center for each cluster, compute the distance between the new center and the old one, and check whether the cluster has converged.
4. If every cluster has converged, or a user-specified stopping condition (such as the maximum number of iterations) is met, clustering is finished; otherwise go back to step 2 for the next round (a minimal sketch of these steps follows below).
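To make the four steps concrete, here is a minimal, self-contained sketch in plain Java. It is not Mahout code: double[] arrays stand in for Mahout Vectors, Euclidean distance stands in for the pluggable DistanceMeasure, and the class and method names are invented for illustration.

import java.util.Arrays;
import java.util.Random;

public class KMeansSketch {

  static double distance(double[] a, double[] b) {
    double sum = 0;
    for (int i = 0; i < a.length; i++) {
      double d = a[i] - b[i];
      sum += d * d;
    }
    return Math.sqrt(sum);
  }

  static double[][] cluster(double[][] points, int k, int maxIterations, double convergenceDelta) {
    Random random = new Random(42);
    // Step 1: pick k random points as the initial cluster centers.
    double[][] centers = new double[k][];
    for (int i = 0; i < k; i++) {
      centers[i] = points[random.nextInt(points.length)].clone();
    }
    for (int iteration = 0; iteration < maxIterations; iteration++) {
      // Step 2: assign every point to the nearest center.
      int[] assignment = new int[points.length];
      for (int p = 0; p < points.length; p++) {
        int best = 0;
        for (int c = 1; c < k; c++) {
          if (distance(points[p], centers[c]) < distance(points[p], centers[best])) {
            best = c;
          }
        }
        assignment[p] = best;
      }
      // Step 3: recompute each center as the mean of its assigned points.
      double[][] newCenters = new double[k][points[0].length];
      int[] counts = new int[k];
      for (int p = 0; p < points.length; p++) {
        counts[assignment[p]]++;
        for (int d = 0; d < points[p].length; d++) {
          newCenters[assignment[p]][d] += points[p][d];
        }
      }
      boolean converged = true;
      for (int c = 0; c < k; c++) {
        if (counts[c] == 0) {
          newCenters[c] = centers[c];           // empty cluster: keep the old center
        } else {
          for (int d = 0; d < newCenters[c].length; d++) {
            newCenters[c][d] /= counts[c];
          }
        }
        // Step 4: converged only if every center moved less than convergenceDelta.
        if (distance(newCenters[c], centers[c]) > convergenceDelta) {
          converged = false;
        }
      }
      centers = newCenters;
      if (converged) {
        break;
      }
    }
    return centers;
  }

  public static void main(String[] args) {
    double[][] points = {{1, 1}, {1.2, 0.8}, {8, 8}, {8.3, 7.9}, {0.9, 1.1}, {7.8, 8.2}};
    System.out.println(Arrays.deepToString(cluster(points, 2, 10, 0.001)));
  }
}

Mahout runs the same assignment and recomputation steps, but parallelized as a MapReduce job, as the code analysis below shows.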
Clustering the data
For every point, compute its distance to each center; the point belongs to the cluster whose center is nearest. This final assignment pass is what the --clustering option requests once the centers have been computed.
Code Analysis
$MAHOUT kmeans \
  -i ${WORK_DIR}/reuters-out-seqdir-sparse-kmeans/tfidf-vectors/ \
  -c ${WORK_DIR}/reuters-kmeans-clusters \
  -o ${WORK_DIR}/reuters-kmeans \
  -dm org.apache.mahout.common.distance.CosineDistanceMeasure \
  -x 10 -k 20 -ow --clustering
Here -dm selects the cosine distance measure, -x 10 caps the number of iterations at 10, -k 20 asks for 20 clusters, -ow overwrites any existing output, and --clustering assigns every vector to its final cluster after the centers have been computed. Before this step, seqdirectory and seq2sparse must be run to turn the raw Reuters text into TF-IDF vectors; see Bayesian classification (classify-20newsgroups) for details.
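For reference, the preprocessing commands in the cluster-reuters.sh example script look roughly like this (exact flags and chunk sizes may vary between Mahout versions):

$MAHOUT seqdirectory \
  -i ${WORK_DIR}/reuters-out \
  -o ${WORK_DIR}/reuters-out-seqdir \
  -c UTF-8 -chunk 64

$MAHOUT seq2sparse \
  -i ${WORK_DIR}/reuters-out-seqdir/ \
  -o ${WORK_DIR}/reuters-out-seqdir-sparse-kmeans \
  --maxDFPercent 85 --namedVector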
if (hasOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION)) {
  clusters = RandomSeedGenerator.buildRandom(getConf(), input, clusters,
      Integer.parseInt(getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION)), measure);
}
Twenty of the documents to be clustered are picked at random and used as the centers of the 20 clusters.
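Conceptually, the seed generator only has to pick k vectors from the input so that every vector has an equal chance of being chosen; one standard way to do that over a stream is reservoir sampling. The following simplified sketch illustrates the idea in plain Java. The class and method names are invented; Mahout's RandomSeedGenerator works on SequenceFiles of VectorWritable, writes the chosen vectors out as initial clusters, and its sampling details may differ.

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class RandomSeedSketch {
  static List<double[]> chooseInitialCenters(Iterable<double[]> vectors, int k, Random random) {
    List<double[]> reservoir = new ArrayList<>(k);
    int seen = 0;
    for (double[] v : vectors) {
      if (reservoir.size() < k) {
        reservoir.add(v);             // fill the reservoir with the first k vectors
      } else {
        int slot = random.nextInt(seen + 1);
        if (slot < k) {
          reservoir.set(slot, v);     // replace an existing seed with probability k / (seen + 1)
        }
      }
      seen++;
    }
    return reservoir;                 // k vectors, each chosen with equal probability
  }
}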
if (runSequential) {
  ClusterIterator.iterateSeq(conf, input, priorClustersPath, output, maxIterations);
} else {
  ClusterIterator.iterateMR(conf, input, priorClustersPath, output, maxIterations);
}

public static void iterateMR(Configuration conf, Path inPath, Path priorPath, Path outPath, int numIterations)
    throws IOException, InterruptedException, ClassNotFoundException {
  ClusteringPolicy policy = ClusterClassifier.readPolicy(priorPath);
  Path clustersOut = null;
  int iteration = 1;
  // Iterate until the iteration limit is reached or isConverged() reports convergence
  while (iteration <= numIterations) {
    conf.set(PRIOR_PATH_KEY, priorPath.toString());
    String jobName = "Cluster Iterator running iteration " + iteration + " over priorPath: " + priorPath;
    Job job = new Job(conf, jobName);
    job.setMapOutputKeyClass(IntWritable.class);
    job.setMapOutputValueClass(ClusterWritable.class);
    job.setOutputKeyClass(IntWritable.class);
    job.setOutputValueClass(ClusterWritable.class);
    job.setInputFormatClass(SequenceFileInputFormat.class);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);
    job.setMapperClass(CIMapper.class);
    job.setReducerClass(CIReducer.class);
    FileInputFormat.addInputPath(job, inPath);
    clustersOut = new Path(outPath, Cluster.CLUSTERS_DIR + iteration);
    priorPath = clustersOut;
    FileOutputFormat.setOutputPath(job, clustersOut);
    job.setJarByClass(ClusterIterator.class);
    if (!job.waitForCompletion(true)) {
      throw new InterruptedException("Cluster Iteration " + iteration + " failed processing " + priorPath);
    }
    ClusterClassifier.writePolicy(policy, clustersOut);
    FileSystem fs = FileSystem.get(outPath.toUri(), conf);
    iteration++;
    // For every cluster, compare the previous center with the newly computed one;
    // if all of them moved less than the given convergenceDelta, the clustering has converged
    if (isConverged(clustersOut, conf, fs)) {
      break;
    }
  }
  // Rename the last iteration's output directory to mark it as the final set of clusters
  Path finalClustersIn = new Path(outPath, Cluster.CLUSTERS_DIR + (iteration - 1) + Cluster.FINAL_ITERATION_SUFFIX);
  FileSystem.get(clustersOut.toUri(), conf).rename(clustersOut, finalClustersIn);
}

// The map method in CIMapper
@Override
protected void map(WritableComparable<?> key, VectorWritable value, Context context)
    throws IOException, InterruptedException {
  // Classify the current document with the ClusterClassifier
  Vector probabilities = classifier.classify(value.get());
  // policy.select keeps only the winning cluster(s); for k-means this is the single closest cluster
  Vector selections = policy.select(probabilities);
  for (Element el : selections.nonZeroes()) {
    classifier.train(el.index(), value.get(), el.get());
  }
}

// The classify method in ClusterClassifier
@Override
public Vector classify(Vector instance) {
  return policy.classify(instance, this);
}

// The classify method in AbstractClusteringPolicy
@Override
public Vector classify(Vector data, ClusterClassifier prior) {
  List<Cluster> models = prior.getModels();
  int i = 0;
  Vector pdfs = new DenseVector(models.size());
  // Score the current document against each of the 20 cluster-center models and store the results in pdfs
  for (Cluster model : models) {
    pdfs.set(i++, model.pdf(new VectorWritable(data)));
  }
  // Normalize so that the pdf values sum to 1
  return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum());
}

// The pdf method in DistanceMeasureCluster
@Override
public double pdf(VectorWritable vw) {
  return 1 / (1 + measure.distance(vw.get(), getCenter()));
}

// The distance method in CosineDistanceMeasure: based on the cosine of the angle between the two vectors
@Override
public double distance(Vector v1, Vector v2) {
  if (v1.size() != v2.size()) {
    throw new CardinalityException(v1.size(), v2.size());
  }
  double lengthSquaredv1 = v1.getLengthSquared();
  double lengthSquaredv2 = v2.getLengthSquared();
  double dotProduct = v2.dot(v1);
  double denominator = Math.sqrt(lengthSquaredv1) * Math.sqrt(lengthSquaredv2);
  // correct for floating-point rounding errors
  if (denominator < dotProduct) {
    denominator = dotProduct;
  }
  // correct for zero-vector corner case
  if (denominator == 0 && dotProduct == 0) {
    return 0;
  }
  return 1.0 - dotProduct / denominator;
}

// The train method in ClusterClassifier
public void train(int actual, Vector data, double weight) {
  models.get(actual).observe(new VectorWritable(data), weight);
}

// The observe methods in AbstractCluster: s0 accumulates the weight, s1 accumulates the weighted vector,
// and s2 accumulates the weighted element-wise square of the vector
@Override
public void observe(VectorWritable x, double weight) {
  observe(x.get(), weight);
}

public void observe(Vector x, double weight) {
  if (weight == 1.0) {
    observe(x);
  } else {
    setS0(getS0() + weight);
    Vector weightedX = x.times(weight);
    if (getS1() == null) {
      setS1(weightedX);
    } else {
      getS1().assign(weightedX, Functions.PLUS);
    }
    Vector x2 = x.times(x).times(weight);
    if (getS2() == null) {
      setS2(x2);
    } else {
      getS2().assign(x2, Functions.PLUS);
    }
  }
}

// The reduce method in CIReducer: merge the statistics of the vectors assigned to each cluster in this
// iteration and recompute the cluster center (classifier.close() calls computeParameters(), which sets
// the new center to s1 / s0)
@Override
protected void reduce(IntWritable key, Iterable<ClusterWritable> values, Context context)
    throws IOException, InterruptedException {
  Iterator<ClusterWritable> iter = values.iterator();
  Cluster first = iter.next().getValue(); // there must always be at least one
  while (iter.hasNext()) {
    Cluster cluster = iter.next().getValue();
    first.observe(cluster);
  }
  List<Cluster> models = Lists.newArrayList();
  models.add(first);
  classifier = new ClusterClassifier(models, policy);
  classifier.close();
  context.write(key, new ClusterWritable(first));
}
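To make the pdf-to-distance relationship concrete, here is a tiny self-contained illustration on plain double[] arrays. It is not Mahout code and the class name is invented; it simply reproduces the formulas distance = 1 - dot / (|v1| * |v2|) and pdf = 1 / (1 + distance): a document pointing in the same direction as a center gets pdf 1, an orthogonal one gets pdf 0.5.

public class CosinePdfExample {

  // Same formula as CosineDistanceMeasure, on plain arrays
  static double cosineDistance(double[] v1, double[] v2) {
    double dot = 0, len1 = 0, len2 = 0;
    for (int i = 0; i < v1.length; i++) {
      dot += v1[i] * v2[i];
      len1 += v1[i] * v1[i];
      len2 += v2[i] * v2[i];
    }
    double denominator = Math.sqrt(len1) * Math.sqrt(len2);
    if (denominator == 0 && dot == 0) {
      return 0;                                // zero-vector corner case
    }
    return 1.0 - dot / Math.max(denominator, dot);   // max() guards against rounding errors
  }

  public static void main(String[] args) {
    double[] doc = {1, 2, 0};
    double[] center = {2, 4, 0};               // same direction  => distance 0
    double[] other = {0, 0, 3};                // orthogonal      => distance 1
    double d1 = cosineDistance(doc, center);
    double d2 = cosineDistance(doc, other);
    // pdf = 1 / (1 + distance): the closer a document is to a center, the closer the value is to 1
    System.out.println("pdf(center) = " + 1 / (1 + d1));   // 1.0
    System.out.println("pdf(other)  = " + 1 / (1 + d2));   // 0.5
  }
}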
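The s0/s1/s2 statistics accumulated by observe() are all that is needed to recompute a center afterwards: the new center is the weighted mean s1 / s0, and s2 additionally feeds the cluster radius. Below is a rough sketch of that bookkeeping, again on plain arrays rather than Mahout Vectors and with an invented class name.

public class ClusterStatsSketch {
  double s0;          // total weight of the observed vectors
  double[] s1;        // weighted sum of the observed vectors
  double[] s2;        // weighted sum of the squared components

  void observe(double[] x, double weight) {
    if (s1 == null) {
      s1 = new double[x.length];
      s2 = new double[x.length];
    }
    s0 += weight;
    for (int i = 0; i < x.length; i++) {
      s1[i] += weight * x[i];
      s2[i] += weight * x[i] * x[i];
    }
  }

  double[] newCenter() {
    double[] center = new double[s1.length];
    for (int i = 0; i < s1.length; i++) {
      center[i] = s1[i] / s0;    // weighted mean of everything observed this iteration
    }
    return center;
  }

  public static void main(String[] args) {
    ClusterStatsSketch stats = new ClusterStatsSketch();
    stats.observe(new double[] {1, 2}, 1.0);
    stats.observe(new double[] {3, 6}, 1.0);
    System.out.println(java.util.Arrays.toString(stats.newCenter()));   // [2.0, 4.0]
  }
}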
This work is licensed under a Creative Commons Attribution 4.0 International License.