  Commons Math / MATH-1378

KMeansPlusPlusClusterer: optimize the seeding procedure by computing the sum of squared distances outside the loop


Details

    • Type: Improvement
    • Status: Open
    • Priority: Major
    • Resolution: Unresolved
    • Affects Version/s: None
    • Fix Version/s: 4.X
    • Component/s: None
    • Labels: None

    Description

      In the KMeansPlusPlusClusterer class, the chooseInitialCenters method, which implements the initial seeding of cluster centers, currently performs the following computation inside the while loop:

              while (resultSet.size() < k) {
      
                  // Sum up the squared distances for the points in pointList not
                  // already taken.
                  double distSqSum = 0.0;
      
                  for (int i = 0; i < numPoints; i++) {
                      if (!taken[i]) {
                          distSqSum += minDistSquared[i];
                      }
                  }
      
                  // ... rest of the loop body omitted for brevity
              }
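      To put a rough number on the repeated work: assuming the loop above starts with one center already in resultSet (the first center is picked uniformly at random before the loop), it runs k − 1 times over n points, and each iteration scans all n entries of taken just to rebuild the sum:

```latex
\underbrace{(k-1)}_{\text{iterations}} \times \underbrace{n}_{\text{entries scanned per sum}} = (k-1)\,n
```

      For example, n = 10^6 points and k = 100 clusters give about 9.9 × 10^7 scanned entries. These scans come on top of the sampling and distance-update passes that each iteration needs anyway, so the total cost stays O(nk); the saving is a constant factor, not an asymptotic one.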
      

      However, this sum could be computed once before the loop and then adjusted incrementally whenever the minimum distances to the set of chosen centers are updated. For example:

              // Sum up the squared distances for the points in pointList not
              // already taken.
              double distSqSum = 0.0;
      
              // There is no need to compute the sum of squared distances inside the "while" loop:
              // we can compute the initial value once and maintain it via deltas in the loop.
              for (int i = 0; i < numPoints; i++) {
                  if (!taken[i]) {
                      distSqSum += minDistSquared[i];
                  }
              }
      
              while (resultSet.size() < k) {
                  // Add one new data point as a center. Each point x is chosen with
                  // probability proportional to D(x)^2.
                  final double r = random.nextDouble() * distSqSum;
      
                  // The index of the next point to be added to the resultSet.
                  int nextPointIndex = -1;
      
                  // Sum through the squared min distances again, stopping when
                  // sum >= r.
                  double sum = 0.0;
                  for (int i = 0; i < numPoints; i++) {
                      if (!taken[i]) {
                          sum += minDistSquared[i];
                          if (sum >= r) {
                              nextPointIndex = i;
                              break;
                          }
                      }
                  }
      
                  // If it's not set to >= 0, the point wasn't found in the previous
                  // for loop, probably because distances are extremely small.  Just pick
                  // the last available point.
                  if (nextPointIndex == -1) {
                      for (int i = numPoints - 1; i >= 0; i--) {
                          if (!taken[i]) {
                              nextPointIndex = i;
                              break;
                          }
                      }
                  }
      
                  // We found one.
                  if (nextPointIndex >= 0) {
      
                      final T p = pointList.get(nextPointIndex);
      
                      resultSet.add(new CentroidCluster<T> (p));
      
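                      // Mark it as taken.
                      taken[nextPointIndex] = true;
                      // The newly taken point no longer belongs to the sum over
                      // points not yet taken, so remove its contribution.
                      distSqSum -= minDistSquared[nextPointIndex];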
      
                      if (resultSet.size() < k) {
                          // Now update elements of minDistSquared.  We only have to compute
                          // the distance to the new center to do this.
                          for (int j = 0; j < numPoints; j++) {
                              // Only have to worry about the points still not taken.
                              if (!taken[j]) {
                                  double d = distance(p, pointList.get(j));
                                  // Subtracting the old value.
                                  distSqSum -= minDistSquared[j];
                                  // Update minimum distance.
                                  minDistSquared[j] = FastMath.min(d*d, minDistSquared[j]);
                                  // Adjust the overall sum of squared distances.
                                  distSqSum += minDistSquared[j];
                              }
                          }
                      }
      
                  } else {
                      // None found --
                      // Break from the while loop to prevent
                      // an infinite loop.
                      break;
                  }
              }
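      A minimal, self-contained sketch of the same idea, assuming plain double[] points compared with Euclidean distance and a caller-supplied index for the first center. The class IncrementalSeedingSketch and its methods are hypothetical illustrations, not the Commons Math API. It computes the sum once, then applies only deltas: the chosen point's own contribution is removed when it is marked as taken, and each remaining point contributes (new minimum minus old minimum).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Hypothetical sketch (not the Commons Math API): k-means++ seeding over double[]
 * points where the sum of squared minimum distances is maintained by deltas.
 */
public class IncrementalSeedingSketch {

    final List<double[]> points;
    final boolean[] taken;
    final double[] minDistSquared;
    final List<Integer> seeds = new ArrayList<>();
    /** Running sum of minDistSquared[i] over all points not yet taken. */
    double distSqSum = 0.0;

    IncrementalSeedingSketch(List<double[]> points, int firstIndex) {
        this.points = points;
        this.taken = new boolean[points.size()];
        this.minDistSquared = new double[points.size()];
        seeds.add(firstIndex);
        taken[firstIndex] = true;
        // Initial distances and their sum: computed once, outside the selection loop.
        for (int i = 0; i < points.size(); i++) {
            if (!taken[i]) {
                double d = distance(points.get(firstIndex), points.get(i));
                minDistSquared[i] = d * d;
                distSqSum += minDistSquared[i];
            }
        }
    }

    static double distance(double[] a, double[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            s += diff * diff;
        }
        return Math.sqrt(s);
    }

    /** Picks an untaken index with probability proportional to D(x)^2; -1 if none is left. */
    int sample(Random random) {
        final double r = random.nextDouble() * distSqSum;
        double sum = 0.0;
        int lastUntaken = -1;
        for (int i = 0; i < taken.length; i++) {
            if (!taken[i]) {
                lastUntaken = i;
                sum += minDistSquared[i];
                if (sum >= r) {
                    return i;
                }
            }
        }
        return lastUntaken; // rounding fallback: the last untaken point
    }

    /** Marks a point as a center and updates distSqSum by deltas only. */
    void take(int index) {
        taken[index] = true;
        seeds.add(index);
        distSqSum -= minDistSquared[index]; // the new center leaves the sum
        final double[] center = points.get(index);
        for (int j = 0; j < taken.length; j++) {
            if (!taken[j]) {
                double d = distance(center, points.get(j));
                double updated = Math.min(d * d, minDistSquared[j]);
                distSqSum += updated - minDistSquared[j]; // delta is <= 0
                minDistSquared[j] = updated;
            }
        }
    }

    /** Full recomputation of the sum, used only to check the running value. */
    double recomputedSum() {
        double s = 0.0;
        for (int i = 0; i < taken.length; i++) {
            if (!taken[i]) {
                s += minDistSquared[i];
            }
        }
        return s;
    }

    static List<Integer> chooseSeeds(List<double[]> points, int k, Random random) {
        IncrementalSeedingSketch s =
            new IncrementalSeedingSketch(points, random.nextInt(points.size()));
        while (s.seeds.size() < k) {
            int next = s.sample(random);
            if (next < 0) {
                break; // every point is already a center
            }
            s.take(next);
        }
        return s.seeds;
    }
}
```

      Comparing distSqSum with recomputedSum() after each take is a cheap way to confirm that the running sum does not drift; tiny floating-point differences are expected because the value is updated by repeated subtraction and addition rather than summed afresh.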
      

      Attachments

        1. MATH-1378.patch (3 kB, Artem Barger)


            People

              Assignee: Unassigned
              Reporter: Artem Barger (C0rWin)
              Votes: 0
              Watchers: 1

              Dates

                Created:
                Updated: