Usage of org.ojalgo.matrix.BasicMatrix in project pyramid by cheng-li.
In class CBMInspector, method covariance.
/**
 * Prints the 20 label pairs whose predictions are most strongly correlated, ranked by absolute
 * correlation, under the CBM mixture for a single input.
 *
 * <p>Equivalent to {@code covariance(CBM, vector, labelTranslator, 20)}.
 *
 * @param CBM the trained mixture model to inspect
 * @param vector the input instance to evaluate
 * @param labelTranslator maps internal label indices to external label names for printing
 */
public static void covariance(CBM CBM, Vector vector, LabelTranslator labelTranslator) {
    covariance(CBM, vector, labelTranslator, 20);
}

/**
 * Prints the {@code topK} label pairs whose predictions are most strongly correlated, ranked by
 * absolute correlation, under the CBM mixture for a single input.
 *
 * <p>The mixture weights are the multi-class classifier's class probabilities for {@code vector}.
 * Within component k, label l is treated as an independent Bernoulli variable with success
 * probability {@code p[k][l]}, taken from the per-component binary classifier. The mixture
 * covariance is {@code sum_k pi_k (Sigma_k + mu_k mu_k^T) - m m^T}, and it is normalised into a
 * correlation for every unordered label pair.
 *
 * <p>Pairs whose correlation is not finite are skipped. This happens when either label has
 * (numerically) zero or negative variance. Such values are meaningless, and because
 * {@link Double#compareTo} orders NaN above all numbers, they would otherwise push real
 * correlations out of the top of the list.
 *
 * @param CBM the trained mixture model to inspect
 * @param vector the input instance to evaluate
 * @param labelTranslator maps internal label indices to external label names for printing
 * @param topK maximum number of pairs to print; must be non-negative
 * @throws IllegalArgumentException if {@code topK} is negative
 */
public static void covariance(CBM CBM, Vector vector, LabelTranslator labelTranslator, int topK) {
    if (topK < 0) {
        throw new IllegalArgumentException("topK must be non-negative, got " + topK);
    }
    int numClusters = CBM.getNumComponents();
    int numClasses = CBM.getNumClasses();
    // Mixture weights pi_k for this instance.
    double[] proportions = CBM.getMultiClassClassifier().predictClassProbs(vector);
    // probabilities[k][l] = probability that label l is positive under component k.
    double[][] probabilities = new double[numClusters][numClasses];
    for (int k = 0; k < numClusters; k++) {
        for (int l = 0; l < numClasses; l++) {
            probabilities[k][l] = CBM.getBinaryClassifiers()[k][l].predictClassProb(vector, 1);
        }
    }
    BasicMatrix covariance = mixtureCovariance(proportions, probabilities, numClasses);

    // Only the strict lower triangle (j < l) is needed: each unordered pair appears once.
    List<Pair<String, Double>> list = new ArrayList<>();
    for (int l = 0; l < numClasses; l++) {
        double stdL = Math.sqrt(covariance.get(l, l).doubleValue());
        for (int j = 0; j < l; j++) {
            double v = covariance.get(l, j).doubleValue()
                    / (stdL * Math.sqrt(covariance.get(j, j).doubleValue()));
            if (!Double.isFinite(v)) {
                // Zero or negative (rounding) variance: correlation is undefined, so skip the pair.
                continue;
            }
            String s = "" + labelTranslator.toExtLabel(l) + ", " + labelTranslator.toExtLabel(j);
            list.add(new Pair<>(s, v));
        }
    }
    Comparator<Pair<String, Double>> comparator = Comparator.comparing(pair -> Math.abs(pair.getSecond()));
    List<Pair<String, Double>> top = list.stream()
            .sorted(comparator.reversed())
            .limit(topK)
            .collect(Collectors.toList());
    System.out.println(top);
}

/**
 * Builds the covariance matrix of the label vector under the mixture.
 *
 * <p>The result is {@code sum_k pi_k (Sigma_k + mu_k mu_k^T) - m m^T}, where
 * {@code mu_k = probabilities[k]}, {@code Sigma_k = diag(p (1 - p))} and
 * {@code m = sum_k pi_k mu_k}.
 *
 * @param proportions mixture weights pi_k, indexed by component
 * @param probabilities per-component label probabilities, {@code [component][label]}
 * @param numClasses number of labels (dimension of the returned square matrix)
 * @return the {@code numClasses x numClasses} covariance matrix
 */
private static BasicMatrix mixtureCovariance(double[] proportions, double[][] probabilities, int numClasses) {
    int numClusters = probabilities.length;

    // Column vector m: the mixture-weighted mean probability of each label.
    Access2D.Builder<PrimitiveMatrix> meanBuilder = factory.getBuilder(numClasses, 1);
    for (int l = 0; l < numClasses; l++) {
        double sum = 0;
        for (int k = 0; k < numClusters; k++) {
            sum += proportions[k] * probabilities[k][l];
        }
        meanBuilder.set(l, 0, sum);
    }
    BasicMatrix mean = meanBuilder.build();

    BasicMatrix covariance = factory.makeZero(numClasses, numClasses);
    for (int k = 0; k < numClusters; k++) {
        Access2D.Builder<PrimitiveMatrix> muBuilder = factory.getBuilder(numClasses, 1);
        Access2D.Builder<PrimitiveMatrix> sigmaBuilder = factory.getBuilder(numClasses, numClasses);
        for (int l = 0; l < numClasses; l++) {
            double p = probabilities[k][l];
            muBuilder.set(l, 0, p);
            // Labels are independent within a component, so Sigma_k is diagonal.
            sigmaBuilder.set(l, l, p * (1 - p));
        }
        BasicMatrix muK = muBuilder.build();
        BasicMatrix sigmaK = sigmaBuilder.build();
        covariance = covariance.add(sigmaK.add(muK.multiply(muK.transpose())).multiply(proportions[k]));
    }
    return covariance.subtract(mean.multiply(mean.transpose()));
}
Aggregations