use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.
the class BMSelector method selectAll.
/**
 * Encodes the given label sets as a sparse indicator data set, runs
 * {@code BMSelector.selectTrainer} on it and returns the resulting mixture
 * together with the trainer's {@code gammas} array.
 *
 * @param numClasses  number of feature columns of the indicator data set
 * @param multiLabels label sets; entry {@code i} becomes row {@code i}
 * @param numClusters forwarded unchanged to {@code BMSelector.selectTrainer}
 * @return a pair holding {@code trainer.getBm()} first and {@code trainer.gammas} second
 */
public static Pair<BM, double[][]> selectAll(int numClasses, MultiLabel[] multiLabels, int numClusters) {
    DataSet indicatorMatrix = DataSetBuilder.getBuilder()
            .numDataPoints(multiLabels.length)
            .numFeatures(numClasses)
            .density(Density.SPARSE_RANDOM)
            .build();

    // Row r gets a 1 in every column whose index is a matched label of multiLabels[r].
    int row = 0;
    for (MultiLabel labelSet : multiLabels) {
        for (int matched : labelSet.getMatchedLabels()) {
            indicatorMatrix.setFeatureValue(row, matched, 1);
        }
        row++;
    }

    // NOTE(review): the meaning of the literal 10 is defined by selectTrainer, not visible here.
    BMTrainer fitted = BMSelector.selectTrainer(indicatorMatrix, numClusters, 10);
    return new Pair<>(fitted.getBm(), fitted.gammas);
}
use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.
the class ArffFormat method writeMatrixFile.
/**
 * Writes the data set as a sparse matrix file named {@code ARFF_MATRIX_FILE_NAME}
 * inside {@code arffFile} (which is used as the parent path).
 *
 * <p>The file starts with a relation line, one numeric attribute line per feature,
 * a class attribute line listing {@code 0 .. numClasses-1}, and a data marker.
 * Each data point is then written as one line:
 * <code>{idx:value idx:value ... numFeatures label}</code>, where only the
 * non-zero entries of the row appear, ordered by ascending feature index.
 *
 * <p>An {@link IOException} is not propagated; its stack trace is printed instead.
 *
 * @param dataSet  the data set to serialise
 * @param arffFile parent path under which the matrix file is created
 */
private static void writeMatrixFile(ClfDataSet dataSet, File arffFile) {
    File target = new File(arffFile, ARFF_MATRIX_FILE_NAME);
    int rowCount = dataSet.getNumDataPoints();
    int featureCount = dataSet.getNumFeatures();
    int[] classOfRow = dataSet.getLabels();
    Comparator<Pair<Integer, Double>> byFeatureIndex = Comparator.comparing(Pair::getFirst);

    try (BufferedWriter out = new BufferedWriter(new FileWriter(target))) {
        // Header section.
        StringBuilder header = new StringBuilder();
        header.append("@RELATION MATRIX").append("\n");
        for (int f = 0; f < featureCount; f++) {
            header.append("@ATTRIBUTE ").append(f).append(" NUMERIC").append("\n");
        }
        header.append("@ATTRIBUTE class {0");
        int classCount = dataSet.getNumClasses();
        for (int c = 1; c < classCount; c++) {
            header.append(",").append(c);
        }
        header.append("}").append("\n");
        header.append("@DATA").append("\n");
        out.write(header.toString());

        // Data section: one line per data point, non-zero entries only.
        for (int r = 0; r < rowCount; r++) {
            List<Pair<Integer, Double>> entries = new ArrayList<>();
            for (Vector.Element nonZero : dataSet.getRow(r).nonZeroes()) {
                entries.add(new Pair<>(nonZero.index(), nonZero.get()));
            }
            // nonZeroes() order is not relied upon; sort by feature index explicitly.
            entries.sort(byFeatureIndex);

            StringBuilder line = new StringBuilder("{");
            for (Pair<Integer, Double> entry : entries) {
                line.append(entry.getFirst()).append(":").append(entry.getSecond()).append(" ");
            }
            // The class label is stored after the features, under index featureCount.
            line.append(featureCount).append(" ").append(classOfRow[r]).append("}").append("\n");
            out.write(line.toString());
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}
use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.
the class CBMInspector method covariance.
/**
 * Computes the label covariance and correlation implied by the mixture for one
 * input vector and prints the 20 label pairs with the largest absolute correlation.
 *
 * <p>With component weights {@code pi_k} (from the multi-class classifier) and
 * per-component label probabilities {@code p_kl} (from the binary classifiers), the code uses
 * <pre>
 *   mean       = sum_k pi_k * mu_k                       (mu_k = column vector of p_kl)
 *   covariance = sum_k pi_k * (Sigma_k + mu_k mu_k^T) - mean mean^T
 *   Sigma_k    = diag(p_kl * (1 - p_kl))
 *   corr(l, j) = cov(l, j) / (sqrt(cov(l, l)) * sqrt(cov(j, j)))
 * </pre>
 * A label whose variance is zero yields non-finite correlation entries; these are not filtered.
 *
 * @param CBM             the mixture model to inspect
 * @param vector          the feature vector at which the model is evaluated
 * @param labelTranslator used to render label indices in the printed pair names
 */
public static void covariance(CBM CBM, Vector vector, LabelTranslator labelTranslator) {
    final int componentCount = CBM.getNumComponents();
    final int labelCount = CBM.getNumClasses();
    final double[] weights = CBM.getMultiClassClassifier().predictClassProbs(vector);

    // labelProbs[k][l] = probability of class 1 from binary classifier (k, l).
    final double[][] labelProbs = new double[componentCount][labelCount];
    for (int k = 0; k < componentCount; k++) {
        double[] rowOfK = labelProbs[k];
        for (int l = 0; l < labelCount; l++) {
            rowOfK[l] = CBM.getBinaryClassifiers()[k][l].predictClassProb(vector, 1);
        }
    }

    // Mixture mean as a (labelCount x 1) column vector.
    Access2D.Builder<PrimitiveMatrix> meanCells = factory.getBuilder(labelCount, 1);
    for (int l = 0; l < labelCount; l++) {
        double weighted = 0;
        for (int k = 0; k < componentCount; k++) {
            weighted += weights[k] * labelProbs[k][l];
        }
        meanCells.set(l, 0, weighted);
    }
    BasicMatrix mixtureMean = meanCells.build();

    // Accumulate sum_k pi_k * (Sigma_k + mu_k mu_k^T), one component at a time.
    BasicMatrix mixtureCov = factory.makeZero(labelCount, labelCount);
    for (int k = 0; k < componentCount; k++) {
        Access2D.Builder<PrimitiveMatrix> muCells = factory.getBuilder(labelCount, 1);
        Access2D.Builder<PrimitiveMatrix> sigmaCells = factory.getBuilder(labelCount, labelCount);
        for (int l = 0; l < labelCount; l++) {
            double p = labelProbs[k][l];
            muCells.set(l, 0, p);
            // Diagonal entry: the product p * (1 - p); off-diagonal cells are never set.
            sigmaCells.set(l, l, p * (1 - p));
        }
        BasicMatrix mu = muCells.build();
        BasicMatrix sigma = sigmaCells.build();
        BasicMatrix secondMoment = sigma.add(mu.multiply(mu.transpose()));
        mixtureCov = mixtureCov.add(secondMoment.multiply(weights[k]));
    }
    mixtureCov = mixtureCov.subtract(mixtureMean.multiply(mixtureMean.transpose()));

    // Normalise by the per-label standard deviations to obtain correlations.
    final double[] stdDev = new double[labelCount];
    for (int l = 0; l < labelCount; l++) {
        stdDev[l] = Math.sqrt(mixtureCov.get(l, l).doubleValue());
    }
    Access2D.Builder<PrimitiveMatrix> corrCells = factory.getBuilder(labelCount, labelCount);
    for (int l = 0; l < labelCount; l++) {
        for (int j = 0; j < labelCount; j++) {
            corrCells.set(l, j, mixtureCov.get(l, j).doubleValue() / (stdDev[l] * stdDev[j]));
        }
    }
    BasicMatrix correlation = corrCells.build();

    // Collect each unordered pair once (strictly lower triangle).
    List<Pair<String, Double>> labelPairs = new ArrayList<>();
    for (int l = 0; l < labelCount; l++) {
        for (int j = 0; j < l; j++) {
            String pairName = labelTranslator.toExtLabel(l) + ", " + labelTranslator.toExtLabel(j);
            labelPairs.add(new Pair<>(pairName, correlation.get(l, j).doubleValue()));
        }
    }

    // Strongest absolute correlation first; keep at most 20.
    Comparator<Pair<String, Double>> byMagnitude = Comparator.comparing(entry -> Math.abs(entry.getSecond()));
    labelPairs.sort(byMagnitude.reversed());
    List<Pair<String, Double>> strongest = new ArrayList<>(labelPairs.subList(0, Math.min(20, labelPairs.size())));
    System.out.println(strongest);
}
use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.
the class GeneralF1Predictor method showSupportPrediction.
/**
 * Builds an {@code Analysis} record comparing a prediction with the truth under
 * a joint distribution given over a list of label combinations.
 *
 * <p>The record contains: the expected F1 of the prediction and of the truth
 * (via {@code expectedF1}), the actual instance F1 between truth and prediction,
 * the KL divergence between a one-hot vector at the truth's position and
 * {@code probs}, and a textual listing of every combination whose probability
 * exceeds 0.01, ordered by decreasing probability.
 *
 * @param combinations the support: candidate label combinations
 * @param probs        probability of each combination, index-aligned with {@code combinations}
 * @param truth        the true label set
 * @param prediction   the predicted label set
 * @param numClasses   number of classes, forwarded to the F1 computations
 * @return the populated analysis record
 */
public static Analysis showSupportPrediction(List<MultiLabel> combinations, double[] probs, MultiLabel truth, MultiLabel prediction, int numClasses) {
    final int supportSize = combinations.size();

    // Position of the first combination equal to the truth.
    // NOTE(review): stays 0 when the truth is not in the support — confirm that is intended.
    int truthPosition = 0;
    for (int pos = 0; pos < supportSize; pos++) {
        if (combinations.get(pos).equals(truth)) {
            truthPosition = pos;
            break;
        }
    }

    // One-hot distribution at the truth's position.
    double[] oneHotTruth = new double[supportSize];
    oneHotTruth[truthPosition] = 1;
    double divergence = KLDivergence.kl(oneHotTruth, probs);

    // Keep only combinations with probability above 0.01, then order them high to low.
    List<Pair<MultiLabel, Double>> likely = new ArrayList<>();
    for (int pos = 0; pos < supportSize; pos++) {
        MultiLabel combination = combinations.get(pos);
        double probability = probs[pos];
        if (probability > 0.01) {
            likely.add(new Pair<>(combination, probability));
        }
    }
    Comparator<Pair<MultiLabel, Double>> byProbability = Comparator.comparing(Pair::getSecond);
    likely.sort(byProbability.reversed());

    double f1OfPrediction = expectedF1(combinations, probs, prediction, numClasses);
    double f1OfTruth = expectedF1(combinations, probs, truth, numClasses);
    double observedF1 = new InstanceAverage(numClasses, truth, prediction).getF1();

    // Rendered as "combination:probability, " per entry (trailing separator included).
    StringBuilder rendered = new StringBuilder();
    for (Pair<MultiLabel, Double> entry : likely) {
        rendered.append(entry.getFirst()).append(":").append(entry.getSecond()).append(", ");
    }

    Analysis result = new Analysis();
    result.expectedF1Prediction = f1OfPrediction;
    result.expectedF1Truth = f1OfTruth;
    result.actualF1 = observedF1;
    result.kl = divergence;
    result.prediction = prediction;
    result.truth = truth;
    result.joint = rendered.toString();
    return result;
}
use of edu.neu.ccs.pyramid.util.Pair in project pyramid by cheng-li.
the class ClusterLabels method getCluster.
/**
 * Turns cluster {@code k} of the mixture into at most 20 word frequencies.
 *
 * <p>Each dimension name is paired with {@code getP()} of the cluster's
 * distribution for that dimension. The pairs are ordered by decreasing value,
 * those with a value not greater than 0 are dropped, and the first 20 survivors
 * are kept. Each survivor's frequency is {@code (int) (p * 200 / sum)}, where
 * {@code sum} is the total over the kept values.
 *
 * @param bm the mixture supplying the names and the per-cluster distributions
 * @param k  index of the cluster
 * @return the word frequencies, highest value first
 */
private static List<WordFrequency> getCluster(BM bm, int k) throws Exception {
    BernoulliDistribution[][] distributions = bm.getDistributions();

    List<Pair<String, Double>> weightedNames = new ArrayList<>();
    int dimension = bm.getDimension();
    for (int d = 0; d < dimension; d++) {
        weightedNames.add(new Pair<>(bm.getNames().get(d), distributions[k][d].getP()));
    }

    Comparator<Pair<String, Double>> byValue = Comparator.comparing(Pair::getSecond);
    weightedNames.sort(byValue.reversed());

    // Materialise the top entries once; both the normaliser and the output use them.
    List<Pair<String, Double>> top = weightedNames.stream()
            .filter(entry -> entry.getSecond() > 0)
            .limit(20)
            .collect(Collectors.toList());
    double total = top.stream().mapToDouble(Pair::getSecond).sum();

    List<WordFrequency> frequencies = new ArrayList<>();
    for (Pair<String, Double> entry : top) {
        int scaled = (int) (entry.getSecond() * 200 / total);
        frequencies.add(new WordFrequency(entry.getFirst(), scaled));
    }
    return frequencies;
}
Aggregations