The k-means Clustering Algorithm (cluster-reuters)


Theoretical Analysis

Computing the Cluster Centers

1. Randomly select 20 of the vectors to be clustered as the initial centers of the 20 clusters.
2. For every point, compute its distance to each center; the point is assigned to the cluster whose center is nearest.
3. Recompute each cluster's center, and measure the distance between the new center and the old one to decide whether that cluster has converged.
4. If every cluster has converged, or a user-specified limit (such as the maximum number of iterations) is reached, clustering is finished; otherwise the next round starts again from step 2. A minimal sketch of the whole loop follows this list.
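
To make the four steps concrete, here is a minimal, self-contained k-means sketch in plain Java. It uses Euclidean distance and hypothetical class and method names; it is illustrative only, not Mahout's implementation (which, in this example, uses cosine distance and runs as MapReduce jobs):

    import java.util.Random;

    /* Minimal k-means sketch; illustrative only, not Mahout's implementation. */
    public class SimpleKMeans {

      static double distance(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
          double d = a[i] - b[i];
          sum += d * d;
        }
        return Math.sqrt(sum);
      }

      static double[][] cluster(double[][] points, int k, int maxIterations, double delta) {
        Random rnd = new Random(42);
        double[][] centers = new double[k][];
        // Step 1: pick k random points as the initial centers.
        for (int i = 0; i < k; i++) {
          centers[i] = points[rnd.nextInt(points.length)].clone();
        }
        for (int iter = 0; iter < maxIterations; iter++) {
          double[][] sums = new double[k][points[0].length];
          int[] counts = new int[k];
          // Step 2: assign each point to its nearest center.
          for (double[] p : points) {
            int best = 0;
            for (int c = 1; c < k; c++) {
              if (distance(p, centers[c]) < distance(p, centers[best])) {
                best = c;
              }
            }
            counts[best]++;
            for (int d = 0; d < p.length; d++) {
              sums[best][d] += p[d];
            }
          }
          // Step 3: recompute each center and measure how far it moved.
          boolean converged = true;
          for (int c = 0; c < k; c++) {
            if (counts[c] == 0) {
              continue; // leave an empty cluster's center in place
            }
            double[] newCenter = new double[centers[c].length];
            for (int d = 0; d < newCenter.length; d++) {
              newCenter[d] = sums[c][d] / counts[c];
            }
            if (distance(newCenter, centers[c]) > delta) {
              converged = false;
            }
            centers[c] = newCenter;
          }
          // Step 4: stop once every center moved by no more than delta.
          if (converged) {
            break;
          }
        }
        return centers;
      }
    }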

Clustering the Data

For every point, compute its distance to each center; the point belongs to the cluster whose center is nearest. (This final assignment pass over all vectors is what the --clustering flag in the command below enables.)

Code Analysis

  $MAHOUT kmeans \
    -i ${WORK_DIR}/reuters-out-seqdir-sparse-kmeans/tfidf-vectors/ \
    -c ${WORK_DIR}/reuters-kmeans-clusters \
    -o ${WORK_DIR}/reuters-kmeans \
    -dm org.apache.mahout.common.distance.CosineDistanceMeasure \
    -x 10 -k 20 -ow --clustering

Here -k 20 asks for 20 clusters, -x 10 caps the number of iterations at 10, -dm selects the cosine distance measure, -ow overwrites any previous output, and --clustering assigns every vector to its final cluster once the iterations finish. Before this step, seqdirectory and seq2sparse must be run to turn the raw Reuters articles into TF-IDF vectors; see the Bayes classification article (classify-20newsgroups) for details.
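
For reference, the corresponding steps in Mahout's cluster-reuters.sh look roughly like this (options quoted from memory, so treat them as a sketch and check the script for the exact flags):

  $MAHOUT seqdirectory \
    -i ${WORK_DIR}/reuters-out \
    -o ${WORK_DIR}/reuters-out-seqdir \
    -c UTF-8 -chunk 64

  $MAHOUT seq2sparse \
    -i ${WORK_DIR}/reuters-out-seqdir/ \
    -o ${WORK_DIR}/reuters-out-seqdir-sparse-kmeans \
    --maxDFPercent 85 --namedVector

The tfidf-vectors directory produced by seq2sparse is exactly the -i input of the kmeans command above.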

    if (hasOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION)) {
      clusters = RandomSeedGenerator.buildRandom(getConf(), input, clusters,
          Integer.parseInt(getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION)), measure);
    }

Because -k 20 is given, RandomSeedGenerator.buildRandom randomly selects 20 of the documents to be clustered as the initial centers of the 20 clusters.
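
The selection in RandomSeedGenerator is essentially reservoir sampling: keep the first k vectors, then let each later vector replace a kept one with decreasing probability, so every input vector is equally likely to be chosen. A self-contained sketch of that idea (plain Java with a hypothetical SeedSampler class; Mahout additionally wraps each seed in a Kluster and writes it to the clusters path as a SequenceFile):

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Random;

    public class SeedSampler {
      /* Reservoir-sample k seeds from a stream of vectors. */
      static <V> List<V> sampleSeeds(Iterable<V> vectors, int k, Random rnd) {
        List<V> reservoir = new ArrayList<>(k);
        int seen = 0;
        for (V v : vectors) {
          if (reservoir.size() < k) {
            reservoir.add(v);              // keep the first k vectors unconditionally
          } else {
            int j = rnd.nextInt(seen + 1); // uniform in [0, seen]
            if (j < k) {
              reservoir.set(j, v);         // replace with probability k/(seen+1)
            }
          }
          seen++;
        }
        return reservoir;
      }
    }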

    if (runSequential) {
      ClusterIterator.iterateSeq(conf, input, priorClustersPath, output, maxIterations);
    } else {
      ClusterIterator.iterateMR(conf, input, priorClustersPath, output, maxIterations);
    }
 
  public static void iterateMR(Configuration conf, Path inPath, Path priorPath, Path outPath, int numIterations)
    throws IOException, InterruptedException, ClassNotFoundException {
    ClusteringPolicy policy = ClusterClassifier.readPolicy(priorPath);
    Path clustersOut = null;
    int iteration = 1;
    /* Iterate until numIterations is reached, or until isConverged reports convergence */
    while (iteration <= numIterations) {
      conf.set(PRIOR_PATH_KEY, priorPath.toString());
 
      String jobName = "Cluster Iterator running iteration " + iteration + " over priorPath: " + priorPath;
      Job job = new Job(conf, jobName);
      job.setMapOutputKeyClass(IntWritable.class);
      job.setMapOutputValueClass(ClusterWritable.class);
      job.setOutputKeyClass(IntWritable.class);
      job.setOutputValueClass(ClusterWritable.class);
 
      job.setInputFormatClass(SequenceFileInputFormat.class);
      job.setOutputFormatClass(SequenceFileOutputFormat.class);
      job.setMapperClass(CIMapper.class);
      job.setReducerClass(CIReducer.class);
 
      FileInputFormat.addInputPath(job, inPath);
      clustersOut = new Path(outPath, Cluster.CLUSTERS_DIR + iteration);
      priorPath = clustersOut;
      FileOutputFormat.setOutputPath(job, clustersOut);
 
      job.setJarByClass(ClusterIterator.class);
      if (!job.waitForCompletion(true)) {
        throw new InterruptedException("Cluster Iteration " + iteration + " failed processing " + priorPath);
      }
      ClusterClassifier.writePolicy(policy, clustersOut);
      FileSystem fs = FileSystem.get(outPath.toUri(), conf);
      iteration++;
      /* For each cluster, compute the distance between the previous center and the newly computed one; if every distance is below the given convergenceDelta, this clustering run has converged */
      if (isConverged(clustersOut, conf, fs)) {
        break;
      }
    }
    Path finalClustersIn = new Path(outPath, Cluster.CLUSTERS_DIR + (iteration - 1) + Cluster.FINAL_ITERATION_SUFFIX);
    FileSystem.get(clustersOut.toUri(), conf).rename(clustersOut, finalClustersIn);
  }
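
isConverged is not shown above; it walks the part files under the freshly written clusters-N directory and reports convergence only if every cluster stored there has its converged flag set. A sketch close to Mahout's implementation (the package paths are for Mahout 0.x and should be treated as an assumption):

    import java.io.IOException;
    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.FileStatus;
    import org.apache.hadoop.fs.FileSystem;
    import org.apache.hadoop.fs.Path;
    import com.google.common.io.Closeables;
    import org.apache.mahout.clustering.iterator.ClusterWritable;
    import org.apache.mahout.common.iterator.sequencefile.PathFilters;
    import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator;

    private static boolean isConverged(Path filePath, Configuration conf, FileSystem fs) throws IOException {
      for (FileStatus part : fs.listStatus(filePath, PathFilters.partFilter())) {
        SequenceFileValueIterator<ClusterWritable> iterator =
            new SequenceFileValueIterator<ClusterWritable>(part.getPath(), true, conf);
        while (iterator.hasNext()) {
          ClusterWritable value = iterator.next();
          if (!value.getValue().isConverged()) {
            Closeables.close(iterator, true);
            return false; // one unconverged cluster is enough to keep iterating
          }
        }
      }
      return true;
    }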
 
  /* The map method in CIMapper */
  @Override
  protected void map(WritableComparable<?> key, VectorWritable value, Context context) throws IOException,
      InterruptedException {
    /* Classify the current document with the ClusterClassifier */
    Vector probabilities = classifier.classify(value.get());
    Vector selections = policy.select(probabilities);
    for (Element el : selections.nonZeroes()) {
      classifier.train(el.index(), value.get(), el.get());
    }
  }
 
  /* The classify method in ClusterClassifier */
  @Override
  public Vector classify(Vector instance) {
    return policy.classify(instance, this);
  }
 
  /* The classify method in AbstractClusteringPolicy */
  @Override
  public Vector classify(Vector data, ClusterClassifier prior) {
    List<Cluster> models = prior.getModels();
    int i = 0;
    Vector pdfs = new DenseVector(models.size());
    /* Score the current document against each of the 20 cluster models and collect the results in pdfs */
    for (Cluster model : models) {
      pdfs.set(i++, model.pdf(new VectorWritable(data)));
    }
    return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum());
  }
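
As a concrete example: if a document sits at distances 0.0, 1.0, and 3.0 from three centers, pdf returns 1/(1+d) = 1.0, 0.5, and 0.25; zSum is 1.75, so classify returns the normalized vector (0.571, 0.286, 0.143). The nearest center always receives the largest weight, and for k-means the policy's select simply picks that maximum.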
 
  /* The pdf method in DistanceMeasureCluster */
  @Override
  public double pdf(VectorWritable vw) {
    return 1 / (1 + measure.distance(vw.get(), getCenter()));
  }
 
  /* The distance method in CosineDistanceMeasure: one minus the cosine of the angle between the two vectors */
  @Override
  public double distance(Vector v1, Vector v2) {
    if (v1.size() != v2.size()) {
      throw new CardinalityException(v1.size(), v2.size());
    }
    double lengthSquaredv1 = v1.getLengthSquared();
    double lengthSquaredv2 = v2.getLengthSquared();
 
    double dotProduct = v2.dot(v1);
    double denominator = Math.sqrt(lengthSquaredv1) * Math.sqrt(lengthSquaredv2);
 
    // correct for floating-point rounding errors
    if (denominator < dotProduct) {
      denominator = dotProduct;
    }
 
    // correct for zero-vector corner case
    if (denominator == 0 && dotProduct == 0) {
      return 0;
    }
 
    return 1.0 - dotProduct / denominator;
  }
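
To sanity-check the formula, take v1 = (1, 0) and v2 = (1, 1): the dot product is 1 and the norms are 1 and √2, so the similarity is 1/√2 ≈ 0.707 and the distance is ≈ 0.293. A small self-contained demo using the Mahout classes shown above (assuming mahout-math and mahout-core are on the classpath):

    import org.apache.mahout.common.distance.CosineDistanceMeasure;
    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.Vector;

    public class CosineDemo {
      public static void main(String[] args) {
        Vector v1 = new DenseVector(new double[] {1, 0});
        Vector v2 = new DenseVector(new double[] {1, 1});
        double d = new CosineDistanceMeasure().distance(v1, v2);
        System.out.println(d);           // ~0.2929, i.e. 1 - 1/sqrt(2)
        System.out.println(1 / (1 + d)); // pdf ~0.7735; closer vectors score higher
      }
    }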
 
  /* The train method in ClusterClassifier */
  public void train(int actual, Vector data, double weight) {
    models.get(actual).observe(new VectorWritable(data), weight);
  }
 
  /* The observe method in AbstractCluster: adds weight to the s0 count, accumulates the weighted vector into s1, and the weighted element-wise square into s2 */
  @Override
  public void observe(VectorWritable x, double weight) {
    observe(x.get(), weight);
  }
 
  public void observe(Vector x, double weight) {
    if (weight == 1.0) {
      observe(x);
    } else {
      setS0(getS0() + weight);
      Vector weightedX = x.times(weight);
      if (getS1() == null) {
        setS1(weightedX);
      } else {
        getS1().assign(weightedX, Functions.PLUS);
      }
      Vector x2 = x.times(x).times(weight);
      if (getS2() == null) {
        setS2(x2);
      } else {
        getS2().assign(x2, Functions.PLUS);
      }
    }
  }
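
With these statistics in hand, the new center is just the weighted mean S1/S0, and a per-component radius can be derived from S2. A simplified sketch of the recomputation, close in spirit to AbstractCluster's computeParameters (SquareRootFunction comes from org.apache.mahout.math.function; the exact reset behaviour here is an assumption):

    public void computeParameters() {
      if (getS0() == 0) {
        return; // no points were observed for this cluster
      }
      setCenter(getS1().divide(getS0())); // new center = weighted mean of the observed vectors
      if (getS0() > 1) {
        // per-component std deviation: sqrt(S0*S2 - S1^2) / S0
        setRadius(getS2().times(getS0()).minus(getS1().times(getS1()))
            .assign(new SquareRootFunction()).divide(getS0()));
      }
      // reset the statistics for the next iteration
      setS0(0);
      setS1(null);
      setS2(null);
    }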
 
  /* The reduce method in CIReducer: merges the statistics of all vectors assigned to this cluster in the current iteration, from which the new center is computed */
  @Override
  protected void reduce(IntWritable key, Iterable<ClusterWritable> values, Context context) throws IOException,
      InterruptedException {
    Iterator<ClusterWritable> iter = values.iterator();
    Cluster first = iter.next().getValue(); // there must always be at least one
    while (iter.hasNext()) {
      Cluster cluster = iter.next().getValue();
      first.observe(cluster);
    }
    List<Cluster> models = Lists.newArrayList();
    models.add(first);
    classifier = new ClusterClassifier(models, policy);
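    // close() calls computeParameters on each model, turning the accumulated s0/s1/s2 into the new center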
    classifier.close();
    context.write(key, new ClusterWritable(first));
  }
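
Once the renamed clusters-N-final directory exists, the result can be inspected with clusterdump; the invocation used by cluster-reuters.sh is roughly the following (flags from memory, verify against the script):

  $MAHOUT clusterdump \
    -i ${WORK_DIR}/reuters-kmeans/clusters-*-final \
    -o ${WORK_DIR}/reuters-kmeans/clusterdump \
    -d ${WORK_DIR}/reuters-out-seqdir-sparse-kmeans/dictionary.file-0 \
    -dt sequencefile -b 100 -n 20

This prints the top terms of each of the 20 clusters, which is the quickest way to see what each cluster is about.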
