
Example 1 with EnumeratedIntegerDistribution

Use of org.apache.commons.math3.distribution.EnumeratedIntegerDistribution in project pyramid by cheng-li.

Class Splitter, method sample.

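static Optional<SplitResult> sample(List<SplitResult> splitResults) {
    if (splitResults.isEmpty()) {
        return Optional.empty();
    }
    // Only the first element is checked, so the list appears to be ordered by
    // reduction (largest first); a zero here means no candidate split helps.
    if (splitResults.get(0).getReduction() == 0) {
        return Optional.empty();
    }
    // Normalize reductions into probabilities so that each split is drawn
    // with probability proportional to its reduction.
    double total = splitResults.stream().mapToDouble(SplitResult::getReduction).sum();
    double[] probs = splitResults.stream().mapToDouble(splitResult -> splitResult.getReduction() / total).toArray();
    // The list positions 0..n-1 form the support of the distribution.
    int[] singletons = IntStream.range(0, splitResults.size()).toArray();
    EnumeratedIntegerDistribution distribution = new EnumeratedIntegerDistribution(singletons, probs);
    int sample = distribution.sample();
    return Optional.of(splitResults.get(sample));
}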
Also used : IntStream(java.util.stream.IntStream) java.util(java.util) Logger(org.apache.logging.log4j.Logger) java.util.concurrent(java.util.concurrent) DataSet(edu.neu.ccs.pyramid.dataset.DataSet) LogManager(org.apache.logging.log4j.LogManager) Collectors(java.util.stream.Collectors) EnumeratedIntegerDistribution(org.apache.commons.math3.distribution.EnumeratedIntegerDistribution)
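Without Commons Math on the classpath, the idea behind `sample` (draw a split with probability proportional to its reduction) can be sketched with the JDK alone. `ProportionalSampler` and `sampleIndex` are hypothetical names, not part of pyramid, and the linear cumulative-sum scan (roulette-wheel selection) is one possible technique; it is not claimed to match how `EnumeratedIntegerDistribution` samples internally.

```java
import java.util.List;
import java.util.Optional;
import java.util.Random;

// Hypothetical helper (not part of pyramid): roulette-wheel selection
// over non-negative weights using a linear cumulative-sum scan.
final class ProportionalSampler {

    // Returns index i with probability weights.get(i) / sum(weights), or
    // Optional.empty() when the list is empty or every weight is zero.
    static Optional<Integer> sampleIndex(List<Double> weights, Random random) {
        double total = 0.0;
        for (double w : weights) {
            if (w < 0) {
                throw new IllegalArgumentException("weights must be non-negative");
            }
            total += w;
        }
        if (total == 0.0) {
            return Optional.empty(); // nothing has positive weight
        }
        // Draw a point in [0, total) and return the bucket it falls into.
        double target = random.nextDouble() * total;
        double cumulative = 0.0;
        int lastPositive = -1;
        for (int i = 0; i < weights.size(); i++) {
            double w = weights.get(i);
            cumulative += w;
            if (w > 0) {
                lastPositive = i;
                if (target < cumulative) {
                    return Optional.of(i);
                }
            }
        }
        // Only reached if floating-point round-off makes target equal total.
        return Optional.of(lastPositive);
    }
}
```

In `Splitter.sample` the weights would be the `getReduction()` values. Unlike that method, this sketch checks the total rather than the first element, so it does not rely on the list being sorted.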

Example 2 with EnumeratedIntegerDistribution

Use of org.apache.commons.math3.distribution.EnumeratedIntegerDistribution in project pyramid by cheng-li.

Class SamplingPrediction, method predict.

public static MultiLabel predict(double[] probabilities, List<MultiLabel> candidates) {
    // Index i is drawn with probability proportional to probabilities[i], so the
    // array must be non-negative and aligned index-by-index with candidates.
    int[] s = IntStream.range(0, probabilities.length).toArray();
    EnumeratedIntegerDistribution distribution = new EnumeratedIntegerDistribution(s, probabilities);
    int i = distribution.sample();
    return candidates.get(i);
}
Also used : EnumeratedIntegerDistribution(org.apache.commons.math3.distribution.EnumeratedIntegerDistribution)
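`predict` returns a randomly drawn candidate rather than the most probable one, so repeated calls can disagree. One way to check that the draws follow `probabilities` is to count outcomes over many draws. The sketch below does this with a JDK-only inverse-CDF sampler (a cumulative array searched by binary search); `CandidateSampler`, `sample` and `countDraws` are hypothetical names, and it is not claimed to be how Commons Math implements sampling.

```java
import java.util.Random;

// Hypothetical sampler (not part of pyramid or Commons Math): builds the
// cumulative distribution once, then draws indices by binary search.
final class CandidateSampler {
    private final double[] cumulative;
    private final Random random;

    CandidateSampler(double[] probabilities, Random random) {
        cumulative = new double[probabilities.length];
        double sum = 0.0;
        for (int i = 0; i < probabilities.length; i++) {
            if (probabilities[i] < 0) {
                throw new IllegalArgumentException("probabilities must be non-negative");
            }
            sum += probabilities[i];
            cumulative[i] = sum;
        }
        if (!(sum > 0)) {
            throw new IllegalArgumentException("probabilities must have a positive sum");
        }
        for (int i = 0; i < cumulative.length; i++) {
            cumulative[i] /= sum; // normalize; the last entry becomes exactly 1.0
        }
        this.random = random;
    }

    // Inverse-CDF draw: the smallest index whose cumulative probability exceeds u.
    int sample() {
        double u = random.nextDouble(); // uniform in [0, 1)
        int lo = 0;
        int hi = cumulative.length - 1;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (cumulative[mid] > u) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        return lo;
    }

    // Counts how often each index is drawn in the given number of draws.
    int[] countDraws(int draws) {
        int[] counts = new int[cumulative.length];
        for (int i = 0; i < draws; i++) {
            counts[sample()]++;
        }
        return counts;
    }
}
```

Dividing each count by the number of draws should approach the input probabilities; in `predict`, the drawn index would then select the corresponding entry of `candidates`.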

Example 3 with EnumeratedIntegerDistribution

Use of org.apache.commons.math3.distribution.EnumeratedIntegerDistribution in project pyramid by cheng-li.

Class MultiLabelSynthesizer, method sampleFromMix.

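/**
 * Samples a synthetic two-cluster multi-label data set.
 * Weight vector w for each (cluster C, label y) pair:
 * C0, y0: w=(0,1)
 * C0, y1: w=(1,1)
 * C1, y0: w=(1,0)
 * C1, y1: w=(1,-1)
 * @return the generated data set
 */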
public static MultiLabelClfDataSet sampleFromMix() {
    int numData = 10000;
    int numClass = 2;
    int numFeature = 2;
    int numClusters = 2;
    double[] proportions = { 0.4, 0.6 };
    int[] indices = { 0, 1 };
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder().numFeatures(numFeature).numClasses(numClass).numDataPoints(numData).build();
    // generate weights
    Vector[][] weights = new Vector[numClusters][numClass];
    for (int c = 0; c < numClusters; c++) {
        for (int l = 0; l < numClass; l++) {
            Vector vector = new DenseVector(numFeature);
            weights[c][l] = vector;
        }
    }
    weights[0][0].set(0, 0);
    weights[0][0].set(1, 1);
    weights[0][1].set(0, 1);
    weights[0][1].set(1, 1);
    weights[1][0].set(0, 1);
    weights[1][0].set(1, 0);
    weights[1][1].set(0, 1);
    weights[1][1].set(1, -1);
    // generate features
    for (int i = 0; i < numData; i++) {
        for (int j = 0; j < numFeature; j++) {
            dataSet.setFeatureValue(i, j, Sampling.doubleUniform(-1, 1));
        }
    }
    IntegerDistribution distribution = new EnumeratedIntegerDistribution(indices, proportions);
    // assign labels
    for (int i = 0; i < numData; i++) {
        int cluster = distribution.sample();
        System.out.println("cluster " + cluster);
        for (int l = 0; l < numClass; l++) {
            System.out.println("row = " + dataSet.getRow(i));
            System.out.println("weight = " + weights[cluster][l]);
            double dot = weights[cluster][l].dot(dataSet.getRow(i));
            System.out.println("dot = " + dot);
            if (dot >= 0) {
                dataSet.addLabel(i, l);
            }
        }
    }
    return dataSet;
}
Also used : EnumeratedIntegerDistribution(org.apache.commons.math3.distribution.EnumeratedIntegerDistribution) IntegerDistribution(org.apache.commons.math3.distribution.IntegerDistribution) DenseVector(org.apache.mahout.math.DenseVector) Vector(org.apache.mahout.math.Vector) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)
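The labeling rule in `sampleFromMix` is: draw a cluster c with probability `proportions[c]`, then switch on label l exactly when the dot product of `weights[c][l]` and the feature row is at least 0. A JDK-only sketch of that rule follows, with plain `double[]` arrays standing in for Mahout vectors and the debug printing omitted. `MixtureLabeler`, `labelsFor` and `sampleCluster` are hypothetical names; the weight values and proportions are copied from the method above.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical sketch of the labeling rule in sampleFromMix (not pyramid code).
final class MixtureLabeler {
    // WEIGHTS[c][l] is the weight vector for cluster c and label l.
    static final double[][][] WEIGHTS = {
        { {0, 1}, {1, 1} },  // cluster 0: y0, y1
        { {1, 0}, {1, -1} }  // cluster 1: y0, y1
    };
    static final double[] PROPORTIONS = {0.4, 0.6};

    // Labels switched on for feature vector x when it belongs to the given cluster.
    static List<Integer> labelsFor(int cluster, double[] x) {
        List<Integer> labels = new ArrayList<>();
        for (int l = 0; l < WEIGHTS[cluster].length; l++) {
            double dot = 0.0;
            for (int j = 0; j < x.length; j++) {
                dot += WEIGHTS[cluster][l][j] * x[j];
            }
            if (dot >= 0) {
                labels.add(l);
            }
        }
        return labels;
    }

    // Draws cluster 0 with probability 0.4 and cluster 1 with probability 0.6.
    static int sampleCluster(Random random) {
        return random.nextDouble() < PROPORTIONS[0] ? 0 : 1;
    }
}
```

Because the same point can receive different labels depending on its cluster, the generated labels are not a single linear function of the features, which is presumably the point of the mixture.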

Example 4 with EnumeratedIntegerDistribution

Use of org.apache.commons.math3.distribution.EnumeratedIntegerDistribution in project pyramid by cheng-li.

Class MultiLabelSynthesizer, method flipOneNonUniform.

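/**
 * Generates a multi-label data set in which label k is assigned when
 * dot(w_k, x) >= 0, then flips one label per data point: label 0 with
 * probability 0.4 and each other label with probability 0.2.
 * Weight vector w for each label:
 * y0: w=(0,1)
 * y1: w=(1,1)
 * y2: w=(1,0)
 * y3: w=(1,-1)
 * @param numData number of data points to generate
 * @return the generated data set
 */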
public static MultiLabelClfDataSet flipOneNonUniform(int numData) {
    int numClass = 4;
    int numFeature = 2;
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder().numFeatures(numFeature).numClasses(numClass).numDataPoints(numData).build();
    // generate weights
    Vector[] weights = new Vector[numClass];
    for (int k = 0; k < numClass; k++) {
        Vector vector = new DenseVector(numFeature);
        weights[k] = vector;
    }
    weights[0].set(0, 0);
    weights[0].set(1, 1);
    weights[1].set(0, 1);
    weights[1].set(1, 1);
    weights[2].set(0, 1);
    weights[2].set(1, 0);
    weights[3].set(0, 1);
    weights[3].set(1, -1);
    // generate features
    for (int i = 0; i < numData; i++) {
        for (int j = 0; j < numFeature; j++) {
            dataSet.setFeatureValue(i, j, Sampling.doubleUniform(-1, 1));
        }
    }
    // assign labels
    for (int i = 0; i < numData; i++) {
        for (int k = 0; k < numClass; k++) {
            double dot = weights[k].dot(dataSet.getRow(i));
            if (dot >= 0) {
                dataSet.addLabel(i, k);
            }
        }
    }
    int[] indices = { 0, 1, 2, 3 };
    double[] probs = { 0.4, 0.2, 0.2, 0.2 };
    IntegerDistribution distribution = new EnumeratedIntegerDistribution(indices, probs);
    // flip
    for (int i = 0; i < numData; i++) {
        int toChange = distribution.sample();
        MultiLabel label = dataSet.getMultiLabels()[i];
        if (label.matchClass(toChange)) {
            label.removeLabel(toChange);
        } else {
            label.addLabel(toChange);
        }
    }
    return dataSet;
}
Also used : EnumeratedIntegerDistribution(org.apache.commons.math3.distribution.EnumeratedIntegerDistribution) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) IntegerDistribution(org.apache.commons.math3.distribution.IntegerDistribution) DenseVector(org.apache.mahout.math.DenseVector) Vector(org.apache.mahout.math.Vector) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)
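The flip step in `flipOneNonUniform` picks one label index from the non-uniform distribution {0.4, 0.2, 0.2, 0.2} and toggles it: removed if present, added if absent. A JDK-only sketch of that step, assuming a `Set<Integer>` as a stand-in for pyramid's `MultiLabel`; `LabelFlipper`, `sampleLabelToFlip` and `flip` are hypothetical names.

```java
import java.util.Random;
import java.util.Set;

// Hypothetical sketch of the non-uniform label flip (not pyramid code).
final class LabelFlipper {
    // Cumulative form of probs = {0.4, 0.2, 0.2, 0.2}.
    private static final double[] CUMULATIVE = {0.4, 0.6, 0.8, 1.0};

    // Picks which label to flip: 0 with probability 0.4, each of 1..3 with 0.2.
    static int sampleLabelToFlip(Random random) {
        double u = random.nextDouble();
        for (int k = 0; k < CUMULATIVE.length - 1; k++) {
            if (u < CUMULATIVE[k]) {
                return k;
            }
        }
        return CUMULATIVE.length - 1;
    }

    // Toggles membership of label k: removes it if present, otherwise adds it.
    static void flip(Set<Integer> labels, int k) {
        if (!labels.remove(k)) {
            labels.add(k);
        }
    }
}
```

Since exactly one label is toggled per data point, every point ends up differing from its noise-free labels in one position, with label 0 corrupted about twice as often as any other.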

Example 5 with EnumeratedIntegerDistribution

Use of org.apache.commons.math3.distribution.EnumeratedIntegerDistribution in project pyramid by cheng-li.

Class BM, method sample.

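/**
 * Samples a vector from the mixture distribution: first draws a cluster
 * according to the mixture coefficients, then samples each dimension from
 * that cluster's component distribution.
 * @return the sampled vector
 */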
public Vector sample() {
    Vector vector = new DenseVector(dimension);
    // first sample cluster
    int[] clusters = IntStream.range(0, numClusters).toArray();
    EnumeratedIntegerDistribution enumeratedIntegerDistribution = new EnumeratedIntegerDistribution(clusters, mixtureCoefficients);
    int cluster = enumeratedIntegerDistribution.sample();
    // then sample each dimension
    for (int d = 0; d < dimension; d++) {
        vector.set(d, distributions[cluster][d].sample());
    }
    return vector;
}
Also used : EnumeratedIntegerDistribution(org.apache.commons.math3.distribution.EnumeratedIntegerDistribution) DenseVector(org.apache.mahout.math.DenseVector) Vector(org.apache.mahout.math.Vector)
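`sample()` constructs a new `EnumeratedIntegerDistribution` on every call. If `mixtureCoefficients` does not change between calls, the cluster distribution could be built once instead. The JDK-only sketch below does that, assuming (hypothetically) that each component `distributions[cluster][d]` is a Bernoulli draw with success probability `theta[cluster][d]`; the excerpt itself only shows that each component has a `sample()` method. `BernoulliMixtureSketch` and `theta` are hypothetical names.

```java
import java.util.Random;

// Hypothetical two-stage mixture sampler (not pyramid code); assumes
// Bernoulli components with success probabilities theta[c][d].
final class BernoulliMixtureSketch {
    private final double[] cumulative; // cumulative mixture coefficients, built once
    private final double[][] theta;    // theta[c][d] = P(x_d = 1 | cluster c)
    private final Random random;

    BernoulliMixtureSketch(double[] mixtureCoefficients, double[][] theta, Random random) {
        cumulative = new double[mixtureCoefficients.length];
        double sum = 0.0;
        for (int c = 0; c < mixtureCoefficients.length; c++) {
            sum += mixtureCoefficients[c];
            cumulative[c] = sum;
        }
        if (!(sum > 0)) {
            throw new IllegalArgumentException("mixture coefficients must have a positive sum");
        }
        for (int c = 0; c < cumulative.length; c++) {
            cumulative[c] /= sum; // normalize; the last entry becomes exactly 1.0
        }
        this.theta = theta;
        this.random = random;
    }

    // First draw a cluster, then draw each dimension independently from that cluster.
    double[] sample() {
        double u = random.nextDouble();
        int cluster = 0;
        while (cluster < cumulative.length - 1 && u >= cumulative[cluster]) {
            cluster++;
        }
        double[] vector = new double[theta[cluster].length];
        for (int d = 0; d < vector.length; d++) {
            vector[d] = random.nextDouble() < theta[cluster][d] ? 1.0 : 0.0;
        }
        return vector;
    }
}
```

Building the cumulative table in the constructor moves the O(numClusters) setup out of the per-sample path; the linear cluster search is kept because mixture models usually have few clusters.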

Aggregations

EnumeratedIntegerDistribution (org.apache.commons.math3.distribution.EnumeratedIntegerDistribution): 6
DenseVector (org.apache.mahout.math.DenseVector): 3
Vector (org.apache.mahout.math.Vector): 3
MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel): 2
MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet): 2
IntegerDistribution (org.apache.commons.math3.distribution.IntegerDistribution): 2
DataSet (edu.neu.ccs.pyramid.dataset.DataSet): 1
BernoulliDistribution (edu.neu.ccs.pyramid.util.BernoulliDistribution): 1
java.util (java.util): 1
ArrayList (java.util.ArrayList): 1
java.util.concurrent (java.util.concurrent): 1
Collectors (java.util.stream.Collectors): 1
IntStream (java.util.stream.IntStream): 1
LogManager (org.apache.logging.log4j.LogManager): 1
Logger (org.apache.logging.log4j.Logger): 1