Use of `org.apache.commons.math3.distribution.EnumeratedIntegerDistribution` in project pyramid by cheng-li.
Class `Splitter`, method `sample`.
/**
 * Randomly picks one split, each with probability proportional to its reduction.
 *
 * <p>Returns {@link Optional#empty()} when there are no candidates, or when the first
 * candidate's reduction is exactly zero.
 * NOTE(review): only the first element is inspected for the early exit; this presumably
 * relies on the list being sorted by reduction in descending order — confirm with callers.
 *
 * @param splitResults candidate splits
 * @return the sampled split, or empty if no split is worth sampling
 */
static Optional<SplitResult> sample(List<SplitResult> splitResults) {
    if (splitResults.isEmpty() || splitResults.get(0).getReduction() == 0) {
        return Optional.empty();
    }
    final int count = splitResults.size();
    final double[] reductions = new double[count];
    int pos = 0;
    for (SplitResult candidate : splitResults) {
        reductions[pos++] = candidate.getReduction();
    }
    // Summed through a DoubleStream so the (compensated) summation matches the stream-based form.
    final double sum = IntStream.range(0, count).mapToDouble(k -> reductions[k]).sum();
    final double[] weights = new double[count];
    for (int k = 0; k < count; k++) {
        weights[k] = reductions[k] / sum;
    }
    final int[] positions = IntStream.range(0, count).toArray();
    final int picked = new EnumeratedIntegerDistribution(positions, weights).sample();
    return Optional.of(splitResults.get(picked));
}
Use of `org.apache.commons.math3.distribution.EnumeratedIntegerDistribution` in project pyramid by cheng-li.
Class `SamplingPrediction`, method `predict`.
/**
 * Randomly selects one candidate, drawing index {@code i} with weight {@code probabilities[i]}.
 *
 * NOTE(review): assumes {@code probabilities.length <= candidates.size()}; a sampled index past
 * the end of {@code candidates} would make {@code get} throw.
 *
 * @param probabilities sampling weight for each candidate index
 * @param candidates    labels to choose from, aligned with {@code probabilities}
 * @return the sampled candidate
 */
public static MultiLabel predict(double[] probabilities, List<MultiLabel> candidates) {
    final int[] candidateIndices = new int[probabilities.length];
    for (int k = 0; k < candidateIndices.length; k++) {
        candidateIndices[k] = k;
    }
    final int drawn = new EnumeratedIntegerDistribution(candidateIndices, probabilities).sample();
    return candidates.get(drawn);
}
Use of `org.apache.commons.math3.distribution.EnumeratedIntegerDistribution` in project pyramid by cheng-li.
Class `MultiLabelSynthesizer`, method `sampleFromMix`.
/**
 * Samples 10000 points from a two-cluster mixture with 2 features and 2 classes.
 * Equivalent to {@code sampleFromMix(10000)}.
 *
 * <p>Per-cluster, per-class weight vectors:
 * <pre>
 * C0, y0: w=(0,1)
 * C0, y1: w=(1,1)
 * C1, y0: w=(1,0)
 * C1, y1: w=(1,-1)
 * </pre>
 *
 * @return the synthesized data set
 */
public static MultiLabelClfDataSet sampleFromMix() {
    return sampleFromMix(10000);
}

/**
 * Samples {@code numData} points from a two-cluster mixture with 2 features and 2 classes.
 *
 * <p>Every feature value is drawn with {@code Sampling.doubleUniform(-1, 1)}. For each point a
 * cluster is drawn with proportions (0.4, 0.6); class {@code l} is then added to the point
 * whenever the dot product of that cluster's weight vector for {@code l} with the point's
 * feature row is non-negative. Weight vectors:
 * <pre>
 * C0, y0: w=(0,1)
 * C0, y1: w=(1,1)
 * C1, y0: w=(1,0)
 * C1, y1: w=(1,-1)
 * </pre>
 *
 * @param numData number of data points to generate
 * @return the synthesized data set
 */
public static MultiLabelClfDataSet sampleFromMix(int numData) {
    int numClass = 2;
    int numFeature = 2;
    int numClusters = 2;
    double[] proportions = { 0.4, 0.6 };
    int[] indices = { 0, 1 };
    // coefficients[c][l] holds the weight vector of class l within cluster c (table in the javadoc)
    double[][][] coefficients = {
        { { 0, 1 }, { 1, 1 } },
        { { 1, 0 }, { 1, -1 } }
    };
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numFeatures(numFeature)
            .numClasses(numClass)
            .numDataPoints(numData)
            .build();
    // generate weights
    Vector[][] weights = new Vector[numClusters][numClass];
    for (int c = 0; c < numClusters; c++) {
        for (int l = 0; l < numClass; l++) {
            Vector vector = new DenseVector(numFeature);
            for (int j = 0; j < numFeature; j++) {
                vector.set(j, coefficients[c][l][j]);
            }
            weights[c][l] = vector;
        }
    }
    // generate features
    for (int i = 0; i < numData; i++) {
        for (int j = 0; j < numFeature; j++) {
            dataSet.setFeatureValue(i, j, Sampling.doubleUniform(-1, 1));
        }
    }
    IntegerDistribution distribution = new EnumeratedIntegerDistribution(indices, proportions);
    // assign labels; intentionally silent, since this loop runs once per data point
    for (int i = 0; i < numData; i++) {
        int cluster = distribution.sample();
        Vector row = dataSet.getRow(i);
        for (int l = 0; l < numClass; l++) {
            if (weights[cluster][l].dot(row) >= 0) {
                dataSet.addLabel(i, l);
            }
        }
    }
    return dataSet;
}
Use of `org.apache.commons.math3.distribution.EnumeratedIntegerDistribution` in project pyramid by cheng-li.
Class `MultiLabelSynthesizer`, method `flipOneNonUniform`.
/**
 * Builds a 4-class, 2-feature data set, then toggles exactly one class per data point.
 *
 * <p>Every feature value is drawn with {@code Sampling.doubleUniform(-1, 1)}. Class {@code k}
 * is assigned when the dot product of its weight vector with the feature row is non-negative:
 * <pre>
 * y0: w=(0,1)
 * y1: w=(1,1)
 * y2: w=(1,0)
 * y3: w=(1,-1)
 * </pre>
 * Afterwards one class per point is drawn with probabilities (0.4, 0.2, 0.2, 0.2) and its
 * membership is flipped: removed if the point has it, added otherwise.
 *
 * @param numData number of data points to generate
 * @return the synthesized data set with one flipped label per point
 */
public static MultiLabelClfDataSet flipOneNonUniform(int numData) {
    final int numClass = 4;
    final int numFeature = 2;
    // one row per class, one column per feature (table in the javadoc)
    final double[][] coefficients = { { 0, 1 }, { 1, 1 }, { 1, 0 }, { 1, -1 } };
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numFeatures(numFeature)
            .numClasses(numClass)
            .numDataPoints(numData)
            .build();
    Vector[] weights = new Vector[numClass];
    for (int k = 0; k < numClass; k++) {
        weights[k] = new DenseVector(numFeature);
        for (int j = 0; j < numFeature; j++) {
            weights[k].set(j, coefficients[k][j]);
        }
    }
    // uniform features in [-1, 1]
    for (int row = 0; row < numData; row++) {
        for (int col = 0; col < numFeature; col++) {
            dataSet.setFeatureValue(row, col, Sampling.doubleUniform(-1, 1));
        }
    }
    // noise-free labels from the linear rules
    for (int row = 0; row < numData; row++) {
        for (int k = 0; k < numClass; k++) {
            if (weights[k].dot(dataSet.getRow(row)) >= 0) {
                dataSet.addLabel(row, k);
            }
        }
    }
    int[] classIds = { 0, 1, 2, 3 };
    double[] flipProbabilities = { 0.4, 0.2, 0.2, 0.2 };
    IntegerDistribution flipDistribution = new EnumeratedIntegerDistribution(classIds, flipProbabilities);
    // toggle exactly one sampled class per point
    for (int row = 0; row < numData; row++) {
        int flipped = flipDistribution.sample();
        MultiLabel target = dataSet.getMultiLabels()[row];
        if (!target.matchClass(flipped)) {
            target.addLabel(flipped);
            continue;
        }
        target.removeLabel(flipped);
    }
    return dataSet;
}
Use of `org.apache.commons.math3.distribution.EnumeratedIntegerDistribution` in project pyramid by cheng-li.
Class `BM`, method `sample`.
/**
 * Draws one vector from the mixture distribution.
 *
 * <p>A cluster index in {@code [0, numClusters)} is chosen first, weighted by
 * {@code mixtureCoefficients}; each coordinate {@code d} is then sampled independently from
 * {@code distributions[cluster][d]}.
 *
 * @return a new {@code DenseVector} of length {@code dimension}
 */
public Vector sample() {
    final Vector drawn = new DenseVector(dimension);
    // step 1: pick the mixture component
    final int[] componentIds = IntStream.range(0, numClusters).toArray();
    final int chosen = new EnumeratedIntegerDistribution(componentIds, mixtureCoefficients).sample();
    // step 2: fill every coordinate from the chosen component's per-dimension distribution
    int d = 0;
    while (d < dimension) {
        drawn.set(d, distributions[chosen][d].sample());
        d++;
    }
    return drawn;
}
Aggregations