Search in sources :

Example 1 with ConditionalTreeSample

Usage of com.amazon.randomcutforest.returntypes.ConditionalTreeSample in the project random-cut-forest-by-aws (by AWS).

From the class ConditionalSampleSummarizer, method summarize.

/**
 * Summarizes a non-empty list of per-tree conditional samples into a single
 * {@code ConditionalSampleSummary}: a global synopsis (weighted mean, deviation,
 * and an approximate median) computed over ALL samples, plus a small set of
 * representative "typical" points with likelihoods, obtained by clustering a
 * distance-filtered subset of the samples in the space of the missing dimensions.
 *
 * Reads instance state not visible in this block: {@code queryPoint},
 * {@code missingDimensions}, {@code centrality}, and the constants
 * {@code MAX_NUMBER_OF_TYPICAL_PER_DIMENSION}, {@code MAX_NUMBER_OF_TYPICAL_ELEMENTS},
 * {@code WEIGHT_ALLOCATION_THRESHOLD}, {@code SEPARATION_RATIO_FOR_MERGE}
 * — presumably fields/constants of ConditionalSampleSummarizer; confirm against
 * the enclosing class.
 *
 * @param alist non-empty list of conditional samples, typically one per tree
 * @return summary holding total weight, representative points and their
 *         likelihoods, the median, the coordinate-wise mean, and the deviation
 *         (exception on empty input is raised via {@code checkArgument})
 */
public ConditionalSampleSummary summarize(List<ConditionalTreeSample> alist) {
    checkArgument(alist.size() > 0, "incorrect call to summarize");
    /**
     * first we dedupe over the points in the pointStore -- it is likely, and
     * beneficial, that different trees acting as different predictors in an
     * ensemble predict the same point that has been seen before. This would be
     * especially true if the time decay is large -- then the whole ensemble
     * starts to behave as a sliding window.
     *
     * note that it is possible that two different *points* predict the same missing
     * value, especially when values are repeated in time. however that check of
     * equality of points would be expensive -- and one mechanism is to use a tree
     * (much like an RCT) to test for equality. We will try to not perform such a
     * test.
     */
    double totalWeight = alist.size();
    List<ConditionalTreeSample> newList = ConditionalTreeSample.dedup(alist);
    /**
     * for centrality = 0 there will be no filtration; for centrality = 1 at least
     * half the values will be present -- the sum of distance(P33) + distance(P50)
     * appears to be slightly more reasonable than 2 * distance(P50). the distance-0
     * elements correspond to exact matches (on the available fields)
     *
     * it is an open question whether the weight of such points should be higher.
     * But if one wants true dynamic adaptability then such a choice to increase
     * weights of exact matches would go against the dynamic sampling based use of
     * RCF.
     */
    // sort by ascending distance from the query so percentile cut-offs can be
    // found with a single weighted sweep
    newList.sort((o1, o2) -> Double.compare(o1.distance, o2.distance));
    double threshold = 0;
    double currentWeight = 0;
    int alwaysInclude = 0;
    double remainderWeight = totalWeight;
    // exact matches (distance == 0) are always kept; their weight is excluded
    // from the remainder over which the percentile thresholds are computed
    while (alwaysInclude < newList.size() && newList.get(alwaysInclude).distance == 0) {
        remainderWeight -= newList.get(alwaysInclude).weight;
        ++alwaysInclude;
    }
    // accumulate weight in distance order; when the running weight crosses the
    // 1/3 or 1/2 mark of the remainder, add centrality * distance at that point
    // (so threshold ~ centrality * (distance(P33) + distance(P50)))
    // NOTE(review): this loop starts at j = 1, so element 0's weight is never
    // added to currentWeight -- possibly intentional (element 0 can never be a
    // crossing point) but looks like an off-by-one; TODO confirm against upstream.
    for (int j = 1; j < newList.size(); j++) {
        if ((currentWeight < remainderWeight / 3 && currentWeight + newList.get(j).weight >= remainderWeight / 3) || (currentWeight < remainderWeight / 2 && currentWeight + newList.get(j).weight >= remainderWeight / 2)) {
            threshold += centrality * newList.get(j).distance;
        }
        currentWeight += newList.get(j).weight;
    }
    // blend in the maximum distance so that centrality = 0 disables filtering
    // (threshold becomes the largest distance and every sample passes)
    threshold += (1 - centrality) * newList.get(newList.size() - 1).distance;
    // num = count of samples within the threshold; only these feed the clustering
    int num = 0;
    while (num < newList.size() && newList.get(num).distance <= threshold) {
        ++num;
    }
    /**
     * in the sequel we will create a global synopsis as well as a local one; the
     * filtering based on thresholds will apply to the local one (points)
     */
    float[] coordMean = new float[queryPoint.length];
    double[] coordSqSum = new double[queryPoint.length];
    Center center = new Center(missingDimensions.length);
    ProjectedPoint[] points = new ProjectedPoint[num];
    for (int j = 0; j < newList.size(); j++) {
        ConditionalTreeSample e = newList.get(j);
        // project the leaf point onto the missing dimensions only
        float[] values = new float[missingDimensions.length];
        for (int i = 0; i < missingDimensions.length; i++) {
            values[i] = e.leafPoint[missingDimensions[i]];
        }
        center.add(values, e.weight);
        for (int i = 0; i < coordMean.length; i++) {
            // weight unchanged here: the global mean/deviation use original weights
            coordMean[i] += e.leafPoint[i] * e.weight;
            coordSqSum[i] += e.leafPoint[i] * e.leafPoint[i] * e.weight;
        }
        if (j < num) {
            // weight is changed for clustering,
            // based on the distance of the sample from the query point
            // (samples beyond the threshold are down-weighted proportionally)
            double weight = (e.distance <= threshold) ? e.weight : e.weight * threshold / e.distance;
            points[j] = new ProjectedPoint(values, weight);
        }
    }
    // we compute p50 over the entire set; coordinates outside missingDimensions
    // are copied verbatim from the query point
    float[] median = Arrays.copyOf(queryPoint, queryPoint.length);
    center.recompute();
    for (int y = 0; y < missingDimensions.length; y++) {
        median[missingDimensions[y]] = (float) center.coordinate[y];
    }
    // we compute deviation over the entire set, using original weights and no
    // filters (sqrt of E[x^2] - E[x]^2, clamped at 0 against rounding)
    float[] deviation = new float[queryPoint.length];
    for (int j = 0; j < coordMean.length; j++) {
        coordMean[j] = coordMean[j] / (float) totalWeight;
        deviation[j] = (float) sqrt(max(0, coordSqSum[j] / totalWeight - coordMean[j] * coordMean[j]));
    }
    /**
     * we now seed the centers according to a farthest point heuristic; such a
     * heuristic is used in clustering algorithms such as CURE:
     * https://en.wikipedia.org/wiki/CURE_algorithm to represent a cluster using
     * multiple points (multi-centroid approach). In that algorithm closest pairs
     * are merged -- the notion of closest will be altered here
     *
     * the first step is initialization to twice the final maximum number of
     * clusters
     */
    ArrayList<Center> centers = new ArrayList<>();
    centers.add(new Center(center.coordinate));
    int maxAllowed = min(center.coordinate.length * MAX_NUMBER_OF_TYPICAL_PER_DIMENSION, MAX_NUMBER_OF_TYPICAL_ELEMENTS);
    for (int k = 0; k < 2 * maxAllowed; k++) {
        // pick the point farthest from its nearest existing center as a new seed
        double maxDist = 0;
        int maxIndex = -1;
        for (int j = 0; j < points.length; j++) {
            double minDist = Double.MAX_VALUE;
            for (int i = 0; i < centers.size(); i++) {
                minDist = min(minDist, distance(points[j], centers.get(i)));
            }
            if (minDist > maxDist) {
                maxDist = minDist;
                maxIndex = j;
            }
        }
        if (maxDist == 0) {
            // every point coincides with some center; no more distinct seeds exist
            break;
        } else {
            centers.add(new Center(Arrays.copyOf(points[maxIndex].coordinate, points[maxIndex].coordinate.length)));
        }
    }
    /**
     * we will now prune the number of clusters iteratively; the first step will be
     * assignment of points, the next step would be choosing the optimum centers
     * given the assignment
     */
    double measure = 10;
    do {
        for (int i = 0; i < centers.size(); i++) {
            centers.get(i).reset();
        }
        double maxDist = 0;
        // soft assignment: each point's weight is split among all centers within
        // WEIGHT_ALLOCATION_THRESHOLD of its nearest center, proportionally to
        // minDist/dist (a harmonic-mean style allocation)
        for (int j = 0; j < points.length; j++) {
            double[] dist = new double[centers.size()];
            Arrays.fill(dist, Double.MAX_VALUE);
            double minDist = Double.MAX_VALUE;
            for (int i = 0; i < centers.size(); i++) {
                dist[i] = distance(points[j], centers.get(i));
                minDist = min(minDist, dist[i]);
            }
            if (minDist == 0) {
                // exact coincidence: give full weight to every coinciding center
                for (int i = 0; i < centers.size(); i++) {
                    if (dist[i] == 0) {
                        centers.get(i).add(points[j].coordinate, points[j].weight);
                    }
                }
            } else {
                maxDist = max(maxDist, minDist);
                double sum = 0;
                for (int i = 0; i < centers.size(); i++) {
                    if (dist[i] <= WEIGHT_ALLOCATION_THRESHOLD * minDist) {
                        // setting up harmonic mean normalizer
                        sum += minDist / dist[i];
                    }
                }
                for (int i = 0; i < centers.size(); i++) {
                    if (dist[i] == 0) {
                        centers.get(i).add(points[j].coordinate, points[j].weight);
                    } else if (dist[i] <= WEIGHT_ALLOCATION_THRESHOLD * minDist) {
                        // harmonic mean share of the point's weight
                        centers.get(i).add(points[j].coordinate, points[j].weight * minDist / (dist[i] * sum));
                    }
                }
            }
        }
        // recompute each center from its (soft) assignment
        for (int i = 0; i < centers.size(); i++) {
            centers.get(i).recompute();
        }
        /**
         * we now find the "closest" pair and merge them; the smaller weight cluster is
         * merged into the larger weight cluster because of L1 errors. "closest" is
         * measured by overlap: (radius_i + radius_j) / pairwise distance
         */
        int first = -1;
        int second = -1;
        measure = 0;
        for (int i = 0; i < centers.size(); i++) {
            for (int j = i + 1; j < centers.size(); j++) {
                double dist = distance(centers.get(i), centers.get(j));
                double tempMeasure = (centers.get(i).radius() + centers.get(j).radius()) / dist;
                if (measure < tempMeasure) {
                    first = i;
                    second = j;
                    measure = tempMeasure;
                }
            }
        }
        if (measure >= SEPARATION_RATIO_FOR_MERGE) {
            // overlapping pair: drop the lighter of the two; its points will be
            // re-absorbed by the survivor on the next assignment pass
            if (centers.get(first).weight < centers.get(second).weight) {
                centers.remove(first);
            } else {
                centers.remove(second);
            }
        } else if (centers.size() > maxAllowed) {
            // not well separated, remove small weight cluster centers
            centers.sort((o1, o2) -> Double.compare(o1.weight, o2.weight));
            centers.remove(0);
        }
    } while (centers.size() > maxAllowed || measure >= SEPARATION_RATIO_FOR_MERGE);
    // sort in decreasing weight so likelihood[0] corresponds to the most typical
    centers.sort((o1, o2) -> Double.compare(o2.weight, o1.weight));
    float[][] pointList = new float[centers.size()][];
    float[] likelihood = new float[centers.size()];
    for (int i = 0; i < centers.size(); i++) {
        // embed each cluster center back into full-dimensional space, keeping
        // the observed coordinates of the query point
        pointList[i] = Arrays.copyOf(queryPoint, queryPoint.length);
        for (int j = 0; j < missingDimensions.length; j++) {
            pointList[i][missingDimensions[j]] = centers.get(i).coordinate[j];
        }
        likelihood[i] = (float) (centers.get(i).weight / totalWeight);
    }
    return new ConditionalSampleSummary(totalWeight, pointList, likelihood, median, coordMean, deviation);
}
Also used : Math.sqrt(java.lang.Math.sqrt) Arrays(java.util.Arrays) List(java.util.List) ConditionalTreeSample(com.amazon.randomcutforest.returntypes.ConditionalTreeSample) CommonUtils.checkArgument(com.amazon.randomcutforest.CommonUtils.checkArgument) ConditionalSampleSummary(com.amazon.randomcutforest.returntypes.ConditionalSampleSummary) Math.max(java.lang.Math.max) Math.min(java.lang.Math.min) ArrayList(java.util.ArrayList) ConditionalSampleSummary(com.amazon.randomcutforest.returntypes.ConditionalSampleSummary) ArrayList(java.util.ArrayList) ConditionalTreeSample(com.amazon.randomcutforest.returntypes.ConditionalTreeSample)

Aggregations

CommonUtils.checkArgument (com.amazon.randomcutforest.CommonUtils.checkArgument)1 ConditionalSampleSummary (com.amazon.randomcutforest.returntypes.ConditionalSampleSummary)1 ConditionalTreeSample (com.amazon.randomcutforest.returntypes.ConditionalTreeSample)1 Math.max (java.lang.Math.max)1 Math.min (java.lang.Math.min)1 Math.sqrt (java.lang.Math.sqrt)1 ArrayList (java.util.ArrayList)1 Arrays (java.util.Arrays)1 List (java.util.List)1