Search in sources :

Example 16 with LogisticRegression

use of edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression in project pyramid by cheng-li.

Method test3 of class ClassificationSynthesizerTest.

private static void test3() {
    // Synthesize 1000 two-feature points along a "multivariate line" with noise SD 0.1.
    ClassificationSynthesizer synthesizer = ClassificationSynthesizer.getBuilder()
            .setNumDataPoints(1000)
            .setNumFeatures(2)
            .setNoiseSD(0.1)
            .build();
    // Two separate calls on the same synthesizer: the first result is used for training,
    // the second for testing (call order kept as-is in case the synthesizer is stateful).
    ClfDataSet train = synthesizer.multivarLine();
    ClfDataSet test = synthesizer.multivarLine();
    TRECFormat.save(train, new File(TMP, "line3/train.trec"));
    TRECFormat.save(test, new File(TMP, "line3/test.trec"));

    RidgeLogisticTrainer ridgeTrainer = RidgeLogisticTrainer.getBuilder()
            .setGaussianPriorVariance(1)
            .build();
    LogisticRegression model = ridgeTrainer.train(train);

    // Print training accuracy, then test accuracy, then the class-0 weight vector.
    for (ClfDataSet evaluated : new ClfDataSet[] {train, test}) {
        System.out.println(Accuracy.accuracy(model, evaluated));
    }
    System.out.println(model.getWeights().getWeightsForClass(0));
}
Also used : ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) RidgeLogisticTrainer(edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) File(java.io.File)

Example 17 with LogisticRegression

use of edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression in project pyramid by cheng-li.

Method kl_conditional of class KLDivergence.

/**
 * Computes the empirical conditional KL divergence
 * {@code sum_z p(z) * sum_y p(y|z) * (log p(y|z) - log q(y|z))}.
 *
 * <p>In this formula:
 * <ul>
 *   <li>{@code p} is the empirical distribution over {@code dataSet}.</li>
 *   <li>{@code z} is the {@link MultiLabel} built from a data point's feature row.</li>
 *   <li>{@code y} is that data point's label set.</li>
 *   <li>{@code log q(y|z)} is {@code multiLabelClassifier.predictLogAssignmentProb(z, y)}.</li>
 * </ul>
 *
 * <p>Also prints diagnostics to standard output. For every pattern {@code z} that occurs at least
 * 10 times, it prints the empirical vs. estimated {@code p(y|z)} and the label marginals above
 * 0.01. When the classifier is a {@link CBM}, it also prints the per-label logistic-regression
 * weights of its first component. CBM-specific output is skipped for other classifiers, and
 * binary classifiers that are not {@link LogisticRegression} (e.g. a {@code PriorProbClassifier})
 * are reported as skipped instead of failing with a {@code ClassCastException}.
 *
 * @param multiLabelClassifier model whose conditional assignment probabilities are evaluated
 * @param dataSet data set defining the empirical distribution
 * @return the empirical conditional KL divergence (0.0 for an empty data set)
 */
public static double kl_conditional(MultiLabelClassifier.AssignmentProbEstimator multiLabelClassifier, MultiLabelClfDataSet dataSet) {
    int numDataPoints = dataSet.getNumDataPoints();
    int numFeatures = dataSet.getNumFeatures();
    MultiLabel[] labels = dataSet.getMultiLabels();

    // q_z: count of each feature pattern z.
    // q_yz: for each z, the count of each label set y observed with it.
    Map<MultiLabel, Integer> q_z = new HashMap<>();
    Map<MultiLabel, Map<MultiLabel, Integer>> q_yz = new HashMap<>();
    for (int i = 0; i < numDataPoints; ++i) {
        MultiLabel z = new MultiLabel(dataSet.getRow(i));
        MultiLabel y = labels[i];
        q_z.merge(z, 1, Integer::sum);
        q_yz.computeIfAbsent(z, key -> new HashMap<>()).merge(y, 1, Integer::sum);
    }

    double kl = 0.0;
    for (Map.Entry<MultiLabel, Integer> zEntry : q_z.entrySet()) {
        MultiLabel z = zEntry.getKey();
        double zCount = zEntry.getValue();
        // The feature vector of z is the same for every y, so build it once per pattern.
        Vector zVector = z.toVector(numFeatures);
        double kl_y = 0.0;
        for (Map.Entry<MultiLabel, Integer> yEntry : q_yz.get(z).entrySet()) {
            double empirical_prob_yz = yEntry.getValue() / zCount;
            double log_estimated_prob_yz = multiLabelClassifier.predictLogAssignmentProb(zVector, yEntry.getKey());
            kl_y += empirical_prob_yz * (Math.log(empirical_prob_yz) - log_estimated_prob_yz);
        }
        kl += (zCount / numDataPoints) * kl_y;
    }

    printConditionalDiagnostics(multiLabelClassifier, dataSet, q_z, q_yz);
    if (multiLabelClassifier instanceof CBM) {
        printCbmLabelRegressions((CBM) multiLabelClassifier, dataSet);
    }
    System.out.println("---");
    return kl;
}

/** Prints empirical vs. estimated p(y|z) and label marginals for every frequent pattern z. */
private static void printConditionalDiagnostics(MultiLabelClassifier.AssignmentProbEstimator multiLabelClassifier,
                                                MultiLabelClfDataSet dataSet,
                                                Map<MultiLabel, Integer> q_z,
                                                Map<MultiLabel, Map<MultiLabel, Integer>> q_yz) {
    final int occurThreshold = 10;
    final double marginalThreshold = 0.01;
    int numFeatures = dataSet.getNumFeatures();
    for (Map.Entry<MultiLabel, Integer> zEntry : q_z.entrySet()) {
        int zCount = zEntry.getValue();
        if (zCount < occurThreshold) {
            // Rare patterns are never reported, so skip their predictions entirely.
            continue;
        }
        MultiLabel z = zEntry.getKey();
        Vector zVector = z.toVector(numFeatures);
        String zText = z.toStringWithExtLabels(dataSet.getLabelTranslator());
        // NOTE(review): sized by numFeatures but filled via matchClass(i) and printed through the
        // label translator, i.e. this assumes features and labels share one index space — confirm.
        double[] empiricalMarginals = new double[numFeatures];
        for (Map.Entry<MultiLabel, Integer> yEntry : q_yz.get(z).entrySet()) {
            MultiLabel y = yEntry.getKey();
            int yCount = yEntry.getValue();
            double estimated_prob_yz = multiLabelClassifier.predictAssignmentProb(zVector, y);
            double empirical_prob_yz = (double) yCount / zCount;
            System.out.println("#z:" + zCount + ",z=" + zText + "->{" + y.toStringWithExtLabels(dataSet.getLabelTranslator()) + "},#y:" + yCount + ",p_y|z_empirical:" + empirical_prob_yz + ",p_y|z_estimated:" + estimated_prob_yz);
            for (int i = 0; i < numFeatures; i++) {
                if (y.matchClass(i)) {
                    empiricalMarginals[i] += yCount;
                }
            }
        }
        double estimated_prob_zz = multiLabelClassifier.predictAssignmentProb(zVector, z);
        System.out.println("p(y=z|z)=" + estimated_prob_zz);
        // Estimated marginals require CBM.predictClassProbs; other models skip this section.
        if (multiLabelClassifier instanceof CBM) {
            System.out.println("p_y|z_estimated marginals are: ");
            double[] estimatedMarginals = ((CBM) multiLabelClassifier).predictClassProbs(zVector);
            printSortedAboveThreshold(estimatedMarginals, marginalThreshold, dataSet);
        }
        System.out.println("p_y|z_empirical marginals are: ");
        for (int i = 0; i < numFeatures; i++) {
            empiricalMarginals[i] /= zCount;
        }
        printSortedAboveThreshold(empiricalMarginals, marginalThreshold, dataSet);
    }
}

/** Prints "label:value" for every entry above {@code threshold}, in descending order of value. */
private static void printSortedAboveThreshold(double[] values, double threshold, MultiLabelClfDataSet dataSet) {
    int[] order = ArgSort.argSortDescending(values);
    for (int index : order) {
        if (values[index] > threshold) {
            System.out.println(dataSet.getLabelTranslator().toExtLabel(index) + ":" + values[index]);
        }
    }
}

/** Prints bias and sorted weights of each label's binary classifier in the first CBM component. */
private static void printCbmLabelRegressions(CBM cbm, MultiLabelClfDataSet dataSet) {
    System.out.println("LRs for each label:");
    Classifier.ProbabilityEstimator[] estimators = cbm.getBinaryClassifiers()[0];
    for (int i = 0; i < estimators.length; i++) {
        System.out.println("LR:" + dataSet.getLabelTranslator().toExtLabel(i));
        // Binary classifiers may be replaced by e.g. a PriorProbClassifier, which has no weights.
        if (!(estimators[i] instanceof LogisticRegression)) {
            System.out.println("skipped: not a LogisticRegression");
            continue;
        }
        LogisticRegression lr = (LogisticRegression) estimators[i];
        Vector weightVec = lr.getWeights().getWeightsWithoutBiasForClass(1);
        double[] weights = new double[weightVec.size()];
        for (int j = 0; j < weights.length; j++) {
            weights[j] = weightVec.get(j);
        }
        System.out.println("bias:" + lr.getWeights().getBiasForClass(1));
        int[] order = ArgSort.argSortDescending(weights);
        for (int index : order) {
            System.out.println(dataSet.getLabelTranslator().toExtLabel(index) + ":" + weights[index]);
        }
    }
}
Also used : MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) HashMap(java.util.HashMap) CBM(edu.neu.ccs.pyramid.multilabel_classification.cbm.CBM) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) HashMap(java.util.HashMap) Map(java.util.Map) Vector(org.apache.mahout.math.Vector)

Example 18 with LogisticRegression

use of edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression in project pyramid by cheng-li.

Method updateBinaryClassifiers of class SparkCBMOptimizer.

private void updateBinaryClassifiers() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateBinaryClassifiers");
    }
    // Snapshot fields into locals. The Spark closure below refers to these locals
    // (broadcast handles and the prior variance) rather than reading the fields.
    final Classifier.ProbabilityEstimator[][] currentClassifiers = cbm.binaryClassifiers;
    final double[][] gammasPerComponent = gammasT;
    final Broadcast<MultiLabelClfDataSet> dataBroadcast = dataSetBroadCast;
    final Broadcast<double[][][]> targetsBroadcast = targetDisBroadCast;
    final double variance = priorVarianceBinary;

    // One task per (component, label) pair, in component-major order.
    final List<BinaryTask> tasks = new ArrayList<>();
    for (int component = 0; component < cbm.numComponents; component++) {
        for (int label = 0; label < cbm.numLabels; label++) {
            LogisticRegression current = (LogisticRegression) currentClassifiers[component][label];
            tasks.add(new BinaryTask(component, label, current, gammasPerComponent[component]));
        }
    }

    // Run updateBinaryLogisticRegression for every task through Spark, requesting one slice
    // per task, and collect the results on the driver.
    final JavaRDD<BinaryTask> taskRdd = sparkContext.parallelize(tasks, tasks.size());
    final List<BinaryTaskResult> results = taskRdd
            .map(task -> updateBinaryLogisticRegression(
                    task.componentIndex,
                    task.classIndex,
                    task.logisticRegression,
                    dataBroadcast.value(),
                    task.weights,
                    targetsBroadcast.value()[task.classIndex],
                    variance))
            .collect();

    // Store each returned classifier back into its (component, label) slot.
    for (final BinaryTaskResult updated : results) {
        cbm.binaryClassifiers[updated.componentIndex][updated.classIndex] = updated.binaryClassifier;
    }
    if (logger.isDebugEnabled()) {
        logger.debug("finish updateBinaryClassifiers");
    }
}
Also used : IntStream(java.util.stream.IntStream) JavaSparkContext(org.apache.spark.api.java.JavaSparkContext) ArrayList(java.util.ArrayList) Classifier(edu.neu.ccs.pyramid.classification.Classifier) Terminator(edu.neu.ccs.pyramid.optimization.Terminator) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) RidgeLogisticOptimizer(edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer) MultiLabel(edu.neu.ccs.pyramid.dataset.MultiLabel) RegTreeFactory(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory) LKBoost(edu.neu.ccs.pyramid.classification.lkboost.LKBoost) LogisticLoss(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticLoss) JavaRDD(org.apache.spark.api.java.JavaRDD) Broadcast(org.apache.spark.broadcast.Broadcast) LKBoostOptimizer(edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer) RegTreeConfig(edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig) Serializable(java.io.Serializable) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) KLDivergence(edu.neu.ccs.pyramid.eval.KLDivergence) List(java.util.List) Logger(org.apache.logging.log4j.Logger) ElasticNetLogisticTrainer(edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer) Entropy(edu.neu.ccs.pyramid.eval.Entropy) Vector(org.apache.mahout.math.Vector) LogManager(org.apache.logging.log4j.LogManager) LKBOutputCalculator(edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator) ArrayList(java.util.ArrayList) Classifier(edu.neu.ccs.pyramid.classification.Classifier) MultiLabelClfDataSet(edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression)

Example 19 with LogisticRegression

use of edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression in project pyramid by cheng-li.

Method updateBinaryLogisticRegression of class SparseCBMOptimzer.

private void updateBinaryLogisticRegression(int componentIndex, int labelIndex) {
    StopWatch timer = new StopWatch();
    timer.start();
    double positives = effectivePositives(componentIndex, labelIndex);
    StringBuilder report = new StringBuilder()
            .append("for component ").append(componentIndex)
            .append(", label ").append(labelIndex)
            .append(", effective positives = ").append(positives);

    // With at most one effective positive, replace the classifier with a constant
    // prior-probability classifier instead of fitting a logistic regression.
    if (positives <= 1) {
        double positiveProb = prior(componentIndex, labelIndex);
        cbm.binaryClassifiers[componentIndex][labelIndex] =
                new PriorProbClassifier(new double[] { 1 - positiveProb, positiveProb });
        report.append(", skip, use prior = ").append(positiveProb)
                .append(", time spent = ").append(timer);
        System.out.println(report.toString());
        return;
    }

    // Start a fresh logistic regression when there is no classifier yet, or when the
    // slot currently holds a prior-probability fallback.
    Object existing = cbm.binaryClassifiers[componentIndex][labelIndex];
    if (existing == null || existing instanceof PriorProbClassifier) {
        cbm.binaryClassifiers[componentIndex][labelIndex] = new LogisticRegression(2, dataSet.getNumFeatures());
    }

    int[] binaryLabels = DataSetUtil.toBinaryLabels(dataSet.getMultiLabels(), labelIndex);
    // Final constructor argument is false; the original author's note says "no parallelism".
    RidgeLogisticOptimizer optimizer = new RidgeLogisticOptimizer(
            (LogisticRegression) cbm.binaryClassifiers[componentIndex][labelIndex],
            dataSet, binaryLabels, activeGammas, priorVarianceBinary, false);
    // TODO maximum iterations: currently capped by numBinaryUpdates.
    optimizer.getOptimizer().getTerminator().setMaxIteration(numBinaryUpdates);
    optimizer.optimize();

    report.append(", time spent = ").append(timer);
    System.out.println(report.toString());
}
Also used : PriorProbClassifier(edu.neu.ccs.pyramid.classification.PriorProbClassifier) LogisticRegression(edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression) RidgeLogisticOptimizer(edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer) StopWatch(org.apache.commons.lang3.time.StopWatch)

Aggregations

LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression)19 ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet)7 File (java.io.File)7 RidgeLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer)6 ElasticNetLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer)5 RidgeLogisticOptimizer (edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticOptimizer)5 Vector (org.apache.mahout.math.Vector)5 PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier)3 Classifier (edu.neu.ccs.pyramid.classification.Classifier)2 MultiLabel (edu.neu.ccs.pyramid.dataset.MultiLabel)2 ArrayList (java.util.ArrayList)2 List (java.util.List)2 IntStream (java.util.stream.IntStream)2 StopWatch (org.apache.commons.lang3.time.StopWatch)2 DenseVector (org.apache.mahout.math.DenseVector)2 ClassProbability (edu.neu.ccs.pyramid.classification.ClassProbability)1 PredictionAnalysis (edu.neu.ccs.pyramid.classification.PredictionAnalysis)1 LKBOutputCalculator (edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator)1 LKBoost (edu.neu.ccs.pyramid.classification.lkboost.LKBoost)1 LKBoostOptimizer (edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer)1