Search in sources:

Example 6 with ClfDataSet

Use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

The class Merger, method mergeClfData.

/**
 * Concatenates two TREC-format classification datasets row-wise and saves the result.
 *
 * <p>Reads the config keys {@code input.data1}, {@code input.data2} and {@code output.data}.
 * Both inputs are loaded as {@link DataSetType#CLF_DENSE}. The rows of {@code input.data1}
 * come first in the merged dataset, followed by those of {@code input.data2}.
 *
 * @param config configuration supplying the two input paths and the output path
 * @throws Exception if loading, concatenating or saving fails
 */
private static void mergeClfData(Config config) throws Exception {
    // Resolve every path up front so a missing key fails before any I/O happens.
    String firstPath = config.getString("input.data1");
    String secondPath = config.getString("input.data2");
    String mergedPath = config.getString("output.data");

    ClfDataSet first = TRECFormat.loadClfDataSet(firstPath, DataSetType.CLF_DENSE, true);
    ClfDataSet second = TRECFormat.loadClfDataSet(secondPath, DataSetType.CLF_DENSE, true);
    TRECFormat.save(DataSetUtil.concatenateByRow(first, second), mergedPath);
}
Also used : ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet)

Example 7 with ClfDataSet

Use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

The class Trec2LibSvm, method main.

/**
 * Converts TREC-format classification datasets to LibSVM format.
 *
 * <p>The configuration file named by {@code args[0]} provides two lists:
 * <ul>
 *   <li>{@code trec}: the input paths;</li>
 *   <li>{@code libSVM}: the output paths.</li>
 * </ul>
 * The i-th input is loaded as a sparse dataset and written to the i-th output path.
 *
 * @param args {@code args[0]} is the path of the configuration file
 * @throws IllegalArgumentException if the {@code trec} and {@code libSVM} lists differ in length
 * @throws Exception if loading the config or any dataset, or saving any output, fails
 */
public static void main(String[] args) throws Exception {
    Config config = new Config(args[0]);
    System.out.println(config);
    List<String> trecs = config.getStrings("trec");
    List<String> libSVMs = config.getStrings("libSVM");
    // Validate before converting anything. With a shorter output list the loop would crash
    // partway through (after writing some files); with a longer one, extra outputs would
    // silently never be produced.
    if (trecs.size() != libSVMs.size()) {
        throw new IllegalArgumentException("config lists 'trec' (" + trecs.size()
                + " entries) and 'libSVM' (" + libSVMs.size()
                + " entries) must have the same length");
    }
    for (int i = 0; i < trecs.size(); i++) {
        ClfDataSet trecDataset = TRECFormat.loadClfDataSet(new File(trecs.get(i)), DataSetType.CLF_SPARSE, false);
        System.out.println(i + " -- Translating on trecs: " + trecs.get(i));
        LibSvmFormat.save(trecDataset, libSVMs.get(i));
    }
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) File(java.io.File)

Example 8 with ClfDataSet

Use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

The class Trec2Matlab, method main.

/**
 * Exports a TREC-format dataset as a tab-separated sparse triplet file.
 *
 * <p>The configuration file named by {@code args[0]} provides:
 * <ul>
 *   <li>{@code input.trecFile}: the dataset to read, loaded as a sparse dataset;</li>
 *   <li>{@code output.matlabFile}: the file to write.</li>
 * </ul>
 *
 * <p>The output has one line per non-zero entry, of the form
 * {@code row<TAB>column<TAB>value}. Row and column indices are shifted to be 1-based;
 * presumably this is for MATLAB's 1-based indexing, e.g. {@code spconvert}, but that is
 * unconfirmed. Missing parent directories of the output file are created.
 *
 * @param args {@code args[0]} is the path of the configuration file
 * @throws Exception if loading the config or the dataset fails, or writing the output fails
 */
public static void main(String[] args) throws Exception {
    Config config = new Config(args[0]);
    File trecFile = new File(config.getString("input.trecFile"));
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(trecFile, DataSetType.CLF_SPARSE, false);
    File matlabFile = new File(config.getString("output.matlabFile"));
    // getParentFile() returns null when the path has no directory component (e.g. "out.txt"),
    // so only create directories when there is a parent to create.
    File parentDir = matlabFile.getParentFile();
    if (parentDir != null) {
        parentDir.mkdirs();
    }
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(matlabFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            Vector vector = dataSet.getRow(i);
            for (Vector.Element element : vector.nonZeroes()) {
                int j = element.index();
                double value = element.get();
                // Shift the 0-based indices to 1-based.
                bw.write(String.valueOf(i + 1));
                bw.write("\t");
                bw.write(String.valueOf(j + 1));
                bw.write("\t");
                bw.write(String.valueOf(value));
                bw.newLine();
            }
        }
    }
}
Also used : Config(edu.neu.ccs.pyramid.configuration.Config) ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) FileWriter(java.io.FileWriter) File(java.io.File) Vector(org.apache.mahout.math.Vector) BufferedWriter(java.io.BufferedWriter)

Example 9 with ClfDataSet

Use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

The class RidgeLogisticOptimizerTest, method test4.

/**
 * Manual check of {@code RidgeLogisticOptimizer} with per-instance weights on 20newsgroup/1.
 *
 * <p>Each training point is given weight 0 with probability 0.1 (chosen via
 * {@code Math.random()}) and weight 1 otherwise. The method then:
 * <ul>
 *   <li>prints train/test accuracy of the freshly initialised model;</li>
 *   <li>runs 20 LBFGS iterations, printing the elapsed time after each.</li>
 * </ul>
 * Earlier versions also pointed at the imdb/3 and spam/trec_data corpora, which share the same
 * train/test layout.
 */
private static void test4() throws Exception {
    File trainFile = new File(DATASETS, "20newsgroup/1/train.trec");
    File testFile = new File(DATASETS, "20newsgroup/1/test.trec");
    ClfDataSet trainSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet testSet = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);

    double variance = 1000;
    LogisticRegression model = new LogisticRegression(trainSet.getNumClasses(), trainSet.getNumFeatures());

    // Random 0/1 weight per training point: roughly 10% get weight 0.
    double[] instanceWeights = new double[trainSet.getNumDataPoints()];
    for (int idx = 0; idx < instanceWeights.length; idx++) {
        instanceWeights[idx] = (Math.random() < 0.1) ? 0 : 1;
    }

    RidgeLogisticOptimizer optimizer = new RidgeLogisticOptimizer(model, trainSet, trainSet.getLabels(), instanceWeights, variance, true);
    System.out.println("after initialization");
    System.out.println("train acc = " + Accuracy.accuracy(model, trainSet));
    System.out.println("test acc = " + Accuracy.accuracy(model, testSet));

    StopWatch timer = new StopWatch();
    timer.start();
    for (int iter = 0; iter < 20; iter++) {
        // NOTE(review): the cast assumes getOptimizer() yields an LBFGS instance here.
        ((LBFGS) optimizer.getOptimizer()).iterate();
        System.out.println("after iteration " + iter);
        System.out.println(timer);
    }
}
Also used : ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) File(java.io.File) StopWatch(org.apache.commons.lang3.time.StopWatch)

Example 10 with ClfDataSet

Use of edu.neu.ccs.pyramid.dataset.ClfDataSet in project pyramid by cheng-li.

The class RidgeLogisticOptimizerTest, method test1.

/**
 * Manual check of {@code RidgeLogisticOptimizer} on 20newsgroup/1 with a full optimisation run.
 *
 * <p>The method does the following:
 * <ul>
 *   <li>caps the optimiser at 10000 iterations in {@code Terminator.Mode.STANDARD};</li>
 *   <li>prints train/test accuracy before and after {@code optimize()};</li>
 *   <li>prints the terminator history and the trained model.</li>
 * </ul>
 * Earlier versions also pointed at the imdb/3 and spam/trec_data corpora, which share the same
 * train/test layout.
 */
private static void test1() throws Exception {
    File trainFile = new File(DATASETS, "20newsgroup/1/train.trec");
    File testFile = new File(DATASETS, "20newsgroup/1/test.trec");
    ClfDataSet trainSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet testSet = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);

    double variance = 1000;
    LogisticRegression model = new LogisticRegression(trainSet.getNumClasses(), trainSet.getNumFeatures());
    RidgeLogisticOptimizer optimizer = new RidgeLogisticOptimizer(model, trainSet, variance, true);
    optimizer.getOptimizer().getTerminator().setMaxIteration(10000).setMode(Terminator.Mode.STANDARD);

    System.out.println("after initialization");
    System.out.println("train acc = " + Accuracy.accuracy(model, trainSet));
    System.out.println("test acc = " + Accuracy.accuracy(model, testSet));

    optimizer.optimize();

    System.out.println("after training");
    System.out.println("train acc = " + Accuracy.accuracy(model, trainSet));
    System.out.println("test acc = " + Accuracy.accuracy(model, testSet));
    System.out.println(optimizer.getOptimizer().getTerminator().getHistory());
    System.out.println(model);
}
Also used : ClfDataSet(edu.neu.ccs.pyramid.dataset.ClfDataSet) File(java.io.File)

Aggregations

ClfDataSet (edu.neu.ccs.pyramid.dataset.ClfDataSet)35 File (java.io.File)31 StopWatch (org.apache.commons.lang3.time.StopWatch)8 LogisticRegression (edu.neu.ccs.pyramid.classification.logistic_regression.LogisticRegression)7 RidgeLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.RidgeLogisticTrainer)6 LKBoost (edu.neu.ccs.pyramid.classification.lkboost.LKBoost)2 Config (edu.neu.ccs.pyramid.configuration.Config)2 DataSetType (edu.neu.ccs.pyramid.dataset.DataSetType)2 ConjugateGradientDescent (edu.neu.ccs.pyramid.optimization.ConjugateGradientDescent)2 RegTreeConfig (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeConfig)2 RegTreeFactory (edu.neu.ccs.pyramid.regression.regression_tree.RegTreeFactory)2 PriorProbClassifier (edu.neu.ccs.pyramid.classification.PriorProbClassifier)1 LKBOutputCalculator (edu.neu.ccs.pyramid.classification.lkboost.LKBOutputCalculator)1 LKBoostOptimizer (edu.neu.ccs.pyramid.classification.lkboost.LKBoostOptimizer)1 ElasticNetLogisticTrainer (edu.neu.ccs.pyramid.classification.logistic_regression.ElasticNetLogisticTrainer)1 MultiLabelClfDataSet (edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet)1 GradientDescent (edu.neu.ccs.pyramid.optimization.GradientDescent)1 LBFGS (edu.neu.ccs.pyramid.optimization.LBFGS)1 BufferedWriter (java.io.BufferedWriter)1 FileWriter (java.io.FileWriter)1