Search in sources:

Example 1 with BasicMatrix

Use of org.ojalgo.matrix.BasicMatrix in project pyramid by cheng-li.

Method covariance in class CBMInspector.

/**
 * Prints the 20 label pairs with the largest absolute correlation under the CBM mixture
 * for the given input. Equivalent to {@code covariance(CBM, vector, labelTranslator, 20)}.
 *
 * @param CBM             the trained model whose mixture components are inspected
 * @param vector          the input feature vector
 * @param labelTranslator maps internal label indices to external label names for printing
 */
public static void covariance(CBM CBM, Vector vector, LabelTranslator labelTranslator) {
    covariance(CBM, vector, labelTranslator, 20);
}

/**
 * Computes the label covariance implied by the CBM mixture for one input and prints the
 * {@code topK} label pairs with the largest absolute correlation to standard output.
 *
 * <p>For each component {@code k}, labels are modelled as independent Bernoulli variables
 * with probabilities {@code p[k][l]} (from the component's binary classifiers), weighted by
 * the component proportion {@code pi[k]}. The mixture covariance is
 * {@code sum_k pi[k] * (diag(p_k * (1 - p_k)) + p_k p_k^T) - mean mean^T}, where
 * {@code mean = sum_k pi[k] * p_k}.
 *
 * <p>Pairs whose correlation is undefined are skipped instead of being ranked. This happens
 * when either label has zero variance, for example when every component predicts exactly
 * 0 or 1 for that label, or when floating-point rounding makes a variance slightly negative.
 *
 * @param CBM             the trained model whose mixture components are inspected
 * @param vector          the input feature vector
 * @param labelTranslator maps internal label indices to external label names for printing
 * @param topK            maximum number of pairs to print; must be non-negative
 * @throws IllegalArgumentException if {@code topK} is negative
 */
public static void covariance(CBM CBM, Vector vector, LabelTranslator labelTranslator, int topK) {
    if (topK < 0) {
        throw new IllegalArgumentException("topK must be non-negative: " + topK);
    }
    int numClusters = CBM.getNumComponents();
    int numClasses = CBM.getNumClasses();
    // NOTE(review): assumes predictClassProbs returns one weight per component
    // (length >= numClusters) — confirm against the multi-class classifier.
    double[] proportions = CBM.getMultiClassClassifier().predictClassProbs(vector);
    double[][] probabilities = new double[numClusters][numClasses];
    for (int k = 0; k < numClusters; k++) {
        for (int l = 0; l < numClasses; l++) {
            probabilities[k][l] = CBM.getBinaryClassifiers()[k][l].predictClassProb(vector, 1);
        }
    }

    // Mixture mean as a column vector: mean[l] = sum_k pi[k] * p[k][l].
    Access2D.Builder<PrimitiveMatrix> meanBuilder = factory.getBuilder(numClasses, 1);
    for (int l = 0; l < numClasses; l++) {
        double sum = 0;
        for (int k = 0; k < numClusters; k++) {
            sum += proportions[k] * probabilities[k][l];
        }
        meanBuilder.set(l, 0, sum);
    }
    BasicMatrix mean = meanBuilder.build();

    // Accumulate E[y y^T] = sum_k pi[k] * (Sigma_k + mu_k mu_k^T) one component at a time.
    BasicMatrix covariance = factory.makeZero(numClasses, numClasses);
    for (int k = 0; k < numClusters; k++) {
        Access2D.Builder<PrimitiveMatrix> muBuilder = factory.getBuilder(numClasses, 1);
        // Independent Bernoulli labels within a component, so Sigma_k is diagonal.
        Access2D.Builder<PrimitiveMatrix> sigmaBuilder = factory.getBuilder(numClasses, numClasses);
        for (int l = 0; l < numClasses; l++) {
            double p = probabilities[k][l];
            muBuilder.set(l, 0, p);
            sigmaBuilder.set(l, l, p * (1 - p));
        }
        BasicMatrix muK = muBuilder.build();
        BasicMatrix sigmaK = sigmaBuilder.build();
        BasicMatrix toAdd = sigmaK.add(muK.multiply(muK.transpose())).multiply(proportions[k]);
        covariance = covariance.add(toAdd);
    }
    covariance = covariance.subtract(mean.multiply(mean.transpose()));

    // Only the strict lower triangle is reported, so compute correlations just for those pairs.
    List<Pair<String, Double>> list = new ArrayList<>();
    for (int l = 0; l < numClasses; l++) {
        double sdL = Math.sqrt(covariance.get(l, l).doubleValue());
        for (int j = 0; j < l; j++) {
            double sdJ = Math.sqrt(covariance.get(j, j).doubleValue());
            double v = covariance.get(l, j).doubleValue() / (sdL * sdJ);
            // Zero or negative variance yields NaN/Infinity. Double ordering puts NaN above
            // every real value, so such pairs would otherwise crowd out genuine correlations.
            if (!Double.isFinite(v)) {
                continue;
            }
            String s = "" + labelTranslator.toExtLabel(l) + ", " + labelTranslator.toExtLabel(j);
            list.add(new Pair<>(s, v));
        }
    }
    Comparator<Pair<String, Double>> comparator = Comparator.comparingDouble(pair -> Math.abs(pair.getSecond()));
    List<Pair<String, Double>> top = list.stream().sorted(comparator.reversed()).limit(topK).collect(Collectors.toList());
    System.out.println(top);
}
Also used : PrimitiveMatrix(org.ojalgo.matrix.PrimitiveMatrix) BasicMatrix(org.ojalgo.matrix.BasicMatrix) Access2D(org.ojalgo.access.Access2D) Pair(edu.neu.ccs.pyramid.util.Pair)

Aggregations

Pair (edu.neu.ccs.pyramid.util.Pair)1 Access2D (org.ojalgo.access.Access2D)1 BasicMatrix (org.ojalgo.matrix.BasicMatrix)1 PrimitiveMatrix (org.ojalgo.matrix.PrimitiveMatrix)1