Usage of `edu.neu.ccs.pyramid.dataset.MultiLabel` in the pyramid project by cheng-li.
Class `CBM`, method `predictLogAssignmentProbs`.
/**
 * Scores a batch of candidate label assignments for a single instance.
 * The distribution for {@code x} is built once with {@code computeBM(x)} and then reused
 * for every assignment, which saves computation compared with rebuilding it per assignment.
 *
 * @param x the feature vector of the instance
 * @param assignments the candidate label sets to score
 * @return an array whose entry {@code i} is {@code logProbability(assignments.get(i))}
 *         under the distribution computed for {@code x}
 */
public double[] predictLogAssignmentProbs(Vector x, List<MultiLabel> assignments) {
    final BMDistribution distribution = computeBM(x);
    final double[] logProbs = new double[assignments.size()];
    int index = 0;
    for (MultiLabel assignment : assignments) {
        logProbs[index] = distribution.logProbability(assignment);
        index++;
    }
    return logProbs;
}
Usage of `edu.neu.ccs.pyramid.dataset.MultiLabel` in the pyramid project by cheng-li.
Class `CBM`, method `predictByMarginals`.
/**
 * Predicts a label set by thresholding each label's marginal probability, i.e. Sign(E(y|x)).
 * Label {@code l} is included when {@code predictClassProbs(vector)[l]} is strictly
 * greater than 0.5.
 *
 * @param vector the feature vector of the instance
 * @return the predicted label set (empty if no marginal exceeds 0.5)
 */
MultiLabel predictByMarginals(Vector vector) {
    final double[] marginals = predictClassProbs(vector);
    final MultiLabel result = new MultiLabel();
    int label = 0;
    while (label < numLabels) {
        final boolean positive = marginals[label] > 0.5;
        if (positive) {
            result.addLabel(label);
        }
        label++;
    }
    return result;
}
Usage of `edu.neu.ccs.pyramid.dataset.MultiLabel` in the pyramid project by cheng-li.
Class `BMDistribution`, method `sample`.
/**
 * Draws label sets from this mixture of per-label Bernoulli distributions.
 * Each draw first picks a component {@code k} with weight {@code exp(logProportions[k])}.
 * It then includes each label {@code l} independently, with probability
 * {@code exp(logClassProbs[k][l][1])}.
 *
 * @param numSamples the number of label sets to draw
 * @return the sampled label sets, in draw order
 */
List<MultiLabel> sample(int numSamples) {
    // Component weights and per-label presence probabilities, converted out of log space.
    final double[] weights = Arrays.stream(logProportions).map(Math::exp).toArray();
    final double[][] presenceProbs = new double[numComponents][numLabels];
    for (int comp = 0; comp < numComponents; comp++) {
        final double[] row = presenceProbs[comp];
        for (int lab = 0; lab < numLabels; lab++) {
            // index [1] of the innermost dimension is the entry read for "label present"
            row[lab] = Math.exp(logClassProbs[comp][lab][1]);
        }
    }

    final EnumeratedIntegerDistribution componentPicker =
            new EnumeratedIntegerDistribution(IntStream.range(0, numComponents).toArray(), weights);
    final BernoulliDistribution[][] labelCoins = new BernoulliDistribution[numComponents][numLabels];
    for (int comp = 0; comp < numComponents; comp++) {
        for (int lab = 0; lab < numLabels; lab++) {
            labelCoins[comp][lab] = new BernoulliDistribution(presenceProbs[comp][lab]);
        }
    }

    final List<MultiLabel> draws = new ArrayList<>();
    int drawn = 0;
    while (drawn < numSamples) {
        final MultiLabel labels = new MultiLabel();
        final BernoulliDistribution[] coins = labelCoins[componentPicker.sample()];
        for (int lab = 0; lab < numLabels; lab++) {
            if (coins[lab].sample() == 1) {
                labels.addLabel(lab);
            }
        }
        draws.add(labels);
        drawn++;
    }
    return draws;
}
Usage of `edu.neu.ccs.pyramid.dataset.MultiLabel` in the pyramid project by cheng-li.
Class `Enumerator`, method `enumerate`.
/**
 * Enumerates all 2^numClasses label assignments over {@code numClasses} labels.
 * Entry {@code i} of the result is {@code toML(toBinary(i, numClasses))}.
 *
 * @param numClasses the number of labels; must be in [0, 30]
 * @return a list of 2^numClasses assignments, indexed by their integer code
 * @throws IllegalArgumentException if {@code numClasses} is negative, or so large that
 *         2^numClasses does not fit in an {@code int}
 */
public static List<MultiLabel> enumerate(int numClasses) {
    // 1 << 31 overflows int; the former (int) Math.pow(2, n) silently saturated to
    // Integer.MAX_VALUE instead, and returned an empty list for negative n.
    if (numClasses < 0 || numClasses > 30) {
        throw new IllegalArgumentException("numClasses must be in [0, 30], got " + numClasses);
    }
    int size = 1 << numClasses;
    List<MultiLabel> list = new ArrayList<>(size);
    for (int i = 0; i < size; i++) {
        list.add(toML(toBinary(i, numClasses)));
    }
    return list;
}
Usage of `edu.neu.ccs.pyramid.dataset.MultiLabel` in the pyramid project by cheng-li.
Class `LossMatrixGenerator`, method `matrix`.
/**
 * Builds the 2^n x 2^n loss matrix over all label assignments for {@code n} labels.
 * Row {@code i} uses {@code toML(toBinary(i, n))} as the true labels, and column {@code j}
 * uses {@code toML(toBinary(j, n))} as the prediction. Entry (i, j) is the selected loss,
 * computed by {@link InstanceAverage} over the one-instance {@link MLConfusionMatrix}.
 *
 * Supported loss names (case-insensitive, locale-independent):
 * "hamming" ({@code getHammingLoss() * n}), and "overlap", "accuracy", "precision",
 * "recall", "f1" (each as {@code 1 - score}).
 *
 * @param n the number of labels; must be in [0, 30]
 * @param lossName the name of the loss to tabulate
 * @return the dense loss matrix
 * @throws IllegalArgumentException if {@code lossName} is not supported, or {@code n} is
 *         out of range
 */
public static Matrix matrix(int n, String lossName) {
    if (n < 0 || n > 30) {
        throw new IllegalArgumentException("n must be in [0, 30], got " + n);
    }
    // Locale.ROOT: default-locale lowercasing maps 'I' to dotless 'ı' under e.g. a Turkish
    // locale, which would make "HAMMING"/"PRECISION" fall through to "unknown loss".
    String loss = lossName.toLowerCase(java.util.Locale.ROOT);
    int size = 1 << n;

    // Each assignment is decoded once, instead of twice per cell.
    // NOTE(review): assumes MLConfusionMatrix does not mutate the MultiLabels it is given.
    MultiLabel[] assignments = new MultiLabel[size];
    for (int i = 0; i < size; i++) {
        assignments[i] = toML(toBinary(i, n));
    }

    double[][] matrixBuilder = new double[size][size];
    for (int i = 0; i < size; i++) {
        MultiLabel[] trueLabels = { assignments[i] };
        for (int j = 0; j < size; j++) {
            MultiLabel[] predicted = { assignments[j] };
            MLConfusionMatrix mlConfusionMatrix = new MLConfusionMatrix(n, trueLabels, predicted);
            matrixBuilder[i][j] = lossOf(new InstanceAverage(mlConfusionMatrix), loss, n);
        }
    }
    return new DenseMatrix(matrixBuilder);
}

/** Evaluates the loss named by the already-lowercased {@code loss} for one confusion summary. */
private static double lossOf(InstanceAverage instanceAverage, String loss, int n) {
    switch (loss) {
        case "hamming":
            return instanceAverage.getHammingLoss() * n;
        case "overlap":
            return 1 - instanceAverage.getOverlap();
        case "accuracy":
            return 1 - instanceAverage.getAccuracy();
        case "precision":
            return 1 - instanceAverage.getPrecision();
        case "recall":
            return 1 - instanceAverage.getRecall();
        case "f1":
            return 1 - instanceAverage.getF1();
        default:
            throw new IllegalArgumentException("unknown loss: " + loss);
    }
}
Aggregations