
Example 1 with ConditionalSampleSummary

Use of com.amazon.randomcutforest.returntypes.ConditionalSampleSummary in project random-cut-forest-by-aws by aws.

The class ConditionalFieldTest, method SimpleTest.

@Test
public void SimpleTest() {
    int newDimensions = 30;
    long randomSeed = 101;
    int sampleSize = 256;
    RandomCutForest newForest = RandomCutForest.builder().numberOfTrees(100).sampleSize(sampleSize)
            .dimensions(newDimensions).randomSeed(randomSeed).compact(true).boundingBoxCacheFraction(0.0).build();
    int dataSize = 2000 + 5;
    double baseMu = 0.0;
    double baseSigma = 1.0;
    double anomalyMu = 0.0;
    double anomalySigma = 1.0;
    double transitionToAnomalyProbability = 0.0;
    // ignoring the anomaly cluster for now
    double transitionToBaseProbability = 1.0;
    Random prg = new Random(0);
    NormalMixtureTestData generator = new NormalMixtureTestData(baseMu, baseSigma, anomalyMu, anomalySigma,
            transitionToAnomalyProbability, transitionToBaseProbability);
    double[][] data = generator.generateTestData(dataSize, newDimensions, 100);
    for (int i = 0; i < 2000; i++) {
        // shrink, shift at random
        for (int j = 0; j < newDimensions; j++) data[i][j] *= 0.01;
        if (prg.nextDouble() < 0.5)
            data[i][0] += 5.0;
        else
            data[i][0] -= 5.0;
        newForest.update(data[i]);
    }
    float[] queryOne = new float[newDimensions];
    float[] queryTwo = new float[newDimensions];
    queryTwo[1] = 1;
    ConditionalSampleSummary summary = newForest.getConditionalFieldSummary(queryOne, 1, new int[] { 0 }, 1);
    assert (summary.summaryPoints.length == 2);
    assert (summary.relativeLikelihood.length == 2);
    assert (Math.abs(summary.summaryPoints[0][0] - 5.0) < 0.01 || Math.abs(summary.summaryPoints[0][0] + 5.0) < 0.01);
    assert (Math.abs(summary.summaryPoints[1][0] - 5.0) < 0.01 || Math.abs(summary.summaryPoints[1][0] + 5.0) < 0.01);
    assert (summary.relativeLikelihood[0] > 0.25);
    assert (summary.relativeLikelihood[1] > 0.25);
    summary = newForest.getConditionalFieldSummary(queryTwo, 1, new int[] { 0 }, 1);
    assert (summary.summaryPoints.length == 2);
    assert (summary.relativeLikelihood.length == 2);
    assertEquals(summary.summaryPoints[0][1], 1, 1e-6);
    assertEquals(summary.summaryPoints[1][1], 1, 1e-6);
    assert (Math.abs(summary.summaryPoints[0][0] - 5.0) < 0.01 || Math.abs(summary.summaryPoints[0][0] + 5.0) < 0.01);
    assert (Math.abs(summary.summaryPoints[1][0] - 5.0) < 0.01 || Math.abs(summary.summaryPoints[1][0] + 5.0) < 0.01);
    assert (summary.relativeLikelihood[0] > 0.25);
    assert (summary.relativeLikelihood[1] > 0.25);
}
Also used : ConditionalSampleSummary(com.amazon.randomcutforest.returntypes.ConditionalSampleSummary) Random(java.util.Random) NormalMixtureTestData(com.amazon.randomcutforest.testutils.NormalMixtureTestData) Test(org.junit.jupiter.api.Test)
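
The data above is bimodal in coordinate 0 (each point is shifted by +5 or -5 with equal probability), so the summary is expected to contain two representative points with roughly equal relative likelihood. The following is a minimal sketch of consuming such a summary, reusing newForest from the test above; only the printing logic is new here.

ConditionalSampleSummary summary = newForest.getConditionalFieldSummary(new float[30], 1, new int[] { 0 }, 1.0);
for (int i = 0; i < summary.summaryPoints.length; i++) {
    // each summary point is a full-dimensional completion of the query point;
    // relativeLikelihood[i] is the fraction of total sample weight assigned to it
    System.out.printf("mode %d: coordinate 0 = %.2f, likelihood = %.2f%n", i,
            summary.summaryPoints[i][0], summary.relativeLikelihood[i]);
}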

Example 2 with ConditionalSampleSummary

Use of com.amazon.randomcutforest.returntypes.ConditionalSampleSummary in project random-cut-forest-by-aws by aws.

The class ConditionalSampleSummarizer, method summarize.

public ConditionalSampleSummary summarize(List<ConditionalTreeSample> alist) {
    checkArgument(alist.size() > 0, "incorrect call to summarize");
    /**
     * First we dedupe over the points in the point store -- it is likely, and
     * beneficial, that different trees, acting as different predictors in an
     * ensemble, predict the same point that has been seen before. This would be
     * especially true if the time decay is large -- then the whole ensemble
     * starts to behave as a sliding window.
     *
     * Note that it is possible that two different *points* predict the same
     * missing value, especially when values are repeated in time. However, that
     * check for equality of points would be expensive -- one mechanism is to
     * use a tree (much like an RCT) to test for equality. We will try not to
     * perform such a test.
     */
    double totalWeight = alist.size();
    List<ConditionalTreeSample> newList = ConditionalTreeSample.dedup(alist);
    /**
     * For centrality = 0 there will be no filtration; for centrality = 1 at
     * least half the values will be present. The sum of distance(P33) +
     * distance(P50) appears to be slightly more reasonable than 2 *
     * distance(P50). The distance-0 elements correspond to exact matches (on
     * the available fields).
     *
     * It is an open question whether the weight of such points should be
     * higher. But if one wants true dynamic adaptability, then such a choice to
     * increase the weights of exact matches would go against the dynamic,
     * sampling-based use of RCF.
     */
    newList.sort((o1, o2) -> Double.compare(o1.distance, o2.distance));
    double threshold = 0;
    double currentWeight = 0;
    int alwaysInclude = 0;
    double remainderWeight = totalWeight;
    while (alwaysInclude < newList.size() && newList.get(alwaysInclude).distance == 0) {
        remainderWeight -= newList.get(alwaysInclude).weight;
        ++alwaysInclude;
    }
    for (int j = 1; j < newList.size(); j++) {
        if ((currentWeight < remainderWeight / 3 && currentWeight + newList.get(j).weight >= remainderWeight / 3)
                || (currentWeight < remainderWeight / 2 && currentWeight + newList.get(j).weight >= remainderWeight / 2)) {
            threshold += centrality * newList.get(j).distance;
        }
        currentWeight += newList.get(j).weight;
    }
    threshold += (1 - centrality) * newList.get(newList.size() - 1).distance;
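    // at this point threshold = centrality * (distance at the one-third weight
    // quantile + distance at the median weight quantile) + (1 - centrality) *
    // (maximum distance), where the quantiles are taken over the deduped list
    // with the weight of exact (distance 0) matches excluded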
    int num = 0;
    while (num < newList.size() && newList.get(num).distance <= threshold) {
        ++num;
    }
    /**
     * In the sequel we will create a global synopsis as well as a local one;
     * the filtering based on thresholds will apply to the local one (points).
     */
    float[] coordMean = new float[queryPoint.length];
    double[] coordSqSum = new double[queryPoint.length];
    Center center = new Center(missingDimensions.length);
    ProjectedPoint[] points = new ProjectedPoint[num];
    for (int j = 0; j < newList.size(); j++) {
        ConditionalTreeSample e = newList.get(j);
        float[] values = new float[missingDimensions.length];
        for (int i = 0; i < missingDimensions.length; i++) {
            values[i] = e.leafPoint[missingDimensions[i]];
        }
        center.add(values, e.weight);
        for (int i = 0; i < coordMean.length; i++) {
            // weight unchanged; the global statistics use the original weights
            coordMean[i] += e.leafPoint[i] * e.weight;
            coordSqSum[i] += e.leafPoint[i] * e.leafPoint[i] * e.weight;
        }
        if (j < num) {
            // weight is changed for clustering,
            // based on the distance of the sample from the query point
            double weight = (e.distance <= threshold) ? e.weight : e.weight * threshold / e.distance;
            points[j] = new ProjectedPoint(values, weight);
        }
    }
    // we compute p50 over the entire set
    float[] median = Arrays.copyOf(queryPoint, queryPoint.length);
    center.recompute();
    for (int y = 0; y < missingDimensions.length; y++) {
        median[missingDimensions[y]] = (float) center.coordinate[y];
    }
    // we compute deviation over the entire set, using original weights and no
    // filters
    float[] deviation = new float[queryPoint.length];
    for (int j = 0; j < coordMean.length; j++) {
        coordMean[j] = coordMean[j] / (float) totalWeight;
        deviation[j] = (float) sqrt(max(0, coordSqSum[j] / totalWeight - coordMean[j] * coordMean[j]));
    }
    /**
     * We now seed the centers according to a farthest-point heuristic; such a
     * heuristic is used in clustering algorithms such as CURE:
     * https://en.wikipedia.org/wiki/CURE_algorithm to represent a cluster
     * using multiple points (a multi-centroid approach). In that algorithm the
     * closest pairs are merged -- the notion of "closest" will be altered here.
     *
     * The first step is initialization to twice the final maximum number of
     * clusters.
     */
    ArrayList<Center> centers = new ArrayList<>();
    centers.add(new Center(center.coordinate));
    int maxAllowed = min(center.coordinate.length * MAX_NUMBER_OF_TYPICAL_PER_DIMENSION, MAX_NUMBER_OF_TYPICAL_ELEMENTS);
    for (int k = 0; k < 2 * maxAllowed; k++) {
        double maxDist = 0;
        int maxIndex = -1;
        for (int j = 0; j < points.length; j++) {
            double minDist = Double.MAX_VALUE;
            for (int i = 0; i < centers.size(); i++) {
                minDist = min(minDist, distance(points[j], centers.get(i)));
            }
            if (minDist > maxDist) {
                maxDist = minDist;
                maxIndex = j;
            }
        }
        if (maxDist == 0) {
            break;
        } else {
            centers.add(new Center(Arrays.copyOf(points[maxIndex].coordinate, points[maxIndex].coordinate.length)));
        }
    }
    /**
     * We will now prune the number of clusters iteratively; the first step is
     * the assignment of points, and the next step is choosing the optimum
     * centers given that assignment.
     */
    double measure = 10;
    do {
        for (int i = 0; i < centers.size(); i++) {
            centers.get(i).reset();
        }
        double maxDist = 0;
        for (int j = 0; j < points.length; j++) {
            double[] dist = new double[centers.size()];
            Arrays.fill(dist, Double.MAX_VALUE);
            double minDist = Double.MAX_VALUE;
            for (int i = 0; i < centers.size(); i++) {
                dist[i] = distance(points[j], centers.get(i));
                minDist = min(minDist, dist[i]);
            }
            if (minDist == 0) {
                for (int i = 0; i < centers.size(); i++) {
                    if (dist[i] == 0) {
                        centers.get(i).add(points[j].coordinate, points[j].weight);
                    }
                }
            } else {
                maxDist = max(maxDist, minDist);
                double sum = 0;
                for (int i = 0; i < centers.size(); i++) {
                    if (dist[i] <= WEIGHT_ALLOCATION_THRESHOLD * minDist) {
                        // setting up harmonic mean
                        sum += minDist / dist[i];
                    }
                }
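                // sum normalizes the inverse-distance shares below, so the
                // point's full weight is distributed across the qualifying
                // centers in proportion to 1 / dist[i]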
                for (int i = 0; i < centers.size(); i++) {
                    if (dist[i] == 0) {
                        centers.get(i).add(points[j].coordinate, points[j].weight);
                    } else if (dist[i] <= WEIGHT_ALLOCATION_THRESHOLD * minDist) {
                        // harmonic mean
                        centers.get(i).add(points[j].coordinate, points[j].weight * minDist / (dist[i] * sum));
                    }
                }
            }
        }
        for (int i = 0; i < centers.size(); i++) {
            centers.get(i).recompute();
        }
        /**
         * We now find the "closest" pair and merge them; the smaller-weight cluster
         * is merged into the larger-weight cluster because of L1 errors.
         */
        int first = -1;
        int second = -1;
        measure = 0;
        for (int i = 0; i < centers.size(); i++) {
            for (int j = i + 1; j < centers.size(); j++) {
                double dist = distance(centers.get(i), centers.get(j));
                double tempMeasure = (centers.get(i).radius() + centers.get(j).radius()) / dist;
                if (measure < tempMeasure) {
                    first = i;
                    second = j;
                    measure = tempMeasure;
                }
            }
        }
        if (measure >= SEPARATION_RATIO_FOR_MERGE) {
            if (centers.get(first).weight < centers.get(second).weight) {
                centers.remove(first);
            } else {
                centers.remove(second);
            }
        } else if (centers.size() > maxAllowed) {
            // not well separated; remove the smallest-weight cluster center
            centers.sort((o1, o2) -> Double.compare(o1.weight, o2.weight));
            centers.remove(0);
        }
    } while (centers.size() > maxAllowed || measure >= SEPARATION_RATIO_FOR_MERGE);
    // sort in decreasing weight
    centers.sort((o1, o2) -> Double.compare(o2.weight, o1.weight));
    float[][] pointList = new float[centers.size()][];
    float[] likelihood = new float[centers.size()];
    for (int i = 0; i < centers.size(); i++) {
        pointList[i] = Arrays.copyOf(queryPoint, queryPoint.length);
        for (int j = 0; j < missingDimensions.length; j++) {
            pointList[i][missingDimensions[j]] = centers.get(i).coordinate[j];
        }
        likelihood[i] = (float) (centers.get(i).weight / totalWeight);
    }
    return new ConditionalSampleSummary(totalWeight, pointList, likelihood, median, coordMean, deviation);
}
Also used : Math.sqrt(java.lang.Math.sqrt) Math.max(java.lang.Math.max) Math.min(java.lang.Math.min) Arrays(java.util.Arrays) ArrayList(java.util.ArrayList) List(java.util.List) CommonUtils.checkArgument(com.amazon.randomcutforest.CommonUtils.checkArgument) ConditionalTreeSample(com.amazon.randomcutforest.returntypes.ConditionalTreeSample) ConditionalSampleSummary(com.amazon.randomcutforest.returntypes.ConditionalSampleSummary)
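
The seeding loop above is a farthest-point traversal in the style of the Gonzalez k-center heuristic referenced via CURE. Below is a self-contained sketch of the same idea on plain float[] points; the class FarthestPointSketch, the method seedCenters, and the unweighted L1 distance (suggested by the merge comment's mention of L1 errors) are illustrative assumptions, not the library's actual distance() or API.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class FarthestPointSketch {

    // plain L1 distance; the summarizer's distance() also accounts for weights
    static double l1Distance(float[] a, float[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            sum += Math.abs(a[i] - b[i]);
        }
        return sum;
    }

    // pick up to maxCenters seeds: start from a given initial center, then
    // repeatedly add the point farthest from all centers chosen so far
    static List<float[]> seedCenters(float[][] points, float[] initial, int maxCenters) {
        List<float[]> centers = new ArrayList<>();
        centers.add(Arrays.copyOf(initial, initial.length));
        while (centers.size() < maxCenters) {
            double maxDist = 0;
            int maxIndex = -1;
            for (int j = 0; j < points.length; j++) {
                double minDist = Double.MAX_VALUE;
                for (float[] c : centers) {
                    minDist = Math.min(minDist, l1Distance(points[j], c));
                }
                if (minDist > maxDist) {
                    maxDist = minDist;
                    maxIndex = j;
                }
            }
            if (maxDist == 0) {
                // every point coincides with some center; no farther point remains
                break;
            }
            centers.add(Arrays.copyOf(points[maxIndex], points[maxIndex].length));
        }
        return centers;
    }
}

As in the summarize method above, a caller can seed to twice the intended number of centers and then prune, which presumably mitigates the heuristic's tendency to pick outliers first.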

Example 3 with ConditionalSampleSummary

Use of com.amazon.randomcutforest.returntypes.ConditionalSampleSummary in project random-cut-forest-by-aws by aws.

The class RandomCutForest, method getConditionalFieldSummary.

public ConditionalSampleSummary getConditionalFieldSummary(float[] point, int numberOfMissingValues, int[] missingIndexes, double centrality) {
    checkArgument(numberOfMissingValues >= 0, "numberOfMissingValues cannot be negative");
    checkNotNull(missingIndexes, "missingIndexes must not be null");
    checkArgument(numberOfMissingValues <= missingIndexes.length, "numberOfMissingValues must be less than or equal to missingIndexes.length");
    checkArgument(centrality >= 0 && centrality <= 1, "centrality needs to be in range [0,1]");
    checkArgument(point != null, "point cannot be null");
    if (!isOutputReady()) {
        return new ConditionalSampleSummary(dimensions);
    }
    int[] liftedIndices = transformIndices(missingIndexes, point.length);
    ConditionalSampleSummarizer summarizer = new ConditionalSampleSummarizer(liftedIndices, transformToShingledPoint(point), centrality);
    return summarizer.summarize(getConditionalField(point, numberOfMissingValues, missingIndexes, centrality));
}
Also used : ConditionalSampleSummary(com.amazon.randomcutforest.returntypes.ConditionalSampleSummary) ConditionalSampleSummarizer(com.amazon.randomcutforest.imputation.ConditionalSampleSummarizer)
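
A minimal end-to-end sketch of using this method to impute a single missing field follows; trainingData is a hypothetical stream of 3-dimensional observations, and the builder settings loosely mirror the test in Example 1.

RandomCutForest forest = RandomCutForest.builder().numberOfTrees(50).sampleSize(256)
        .dimensions(3).randomSeed(42L).build();
for (double[] observation : trainingData) {
    // hypothetical data stream
    forest.update(observation);
}
// treat coordinate 2 as missing; centrality = 1 filters to the most central samples
float[] query = new float[] { 1.0f, 2.0f, 0.0f };
ConditionalSampleSummary summary = forest.getConditionalFieldSummary(query, 1, new int[] { 2 }, 1.0);
// summary points are sorted by decreasing weight, so index 0 is the most likely completion
float imputed = summary.summaryPoints[0][2];

If the forest is not yet ready to produce output, the guard above returns an empty ConditionalSampleSummary of the forest's dimension instead of throwing.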

Aggregations

ConditionalSampleSummary (com.amazon.randomcutforest.returntypes.ConditionalSampleSummary)3 CommonUtils.checkArgument (com.amazon.randomcutforest.CommonUtils.checkArgument)1 ConditionalSampleSummarizer (com.amazon.randomcutforest.imputation.ConditionalSampleSummarizer)1 ConditionalTreeSample (com.amazon.randomcutforest.returntypes.ConditionalTreeSample)1 NormalMixtureTestData (com.amazon.randomcutforest.testutils.NormalMixtureTestData)1 Math.max (java.lang.Math.max)1 Math.min (java.lang.Math.min)1 Math.sqrt (java.lang.Math.sqrt)1 ArrayList (java.util.ArrayList)1 Arrays (java.util.Arrays)1 List (java.util.List)1 Random (java.util.Random)1 Test (org.junit.jupiter.api.Test)1