use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class CMLCRFTest method test1.
/**
 * Fits a {@link CMLCRF} to the spam TREC training set by running a fixed number of
 * L-BFGS iterations on a {@link CRFLoss}, with pair consideration switched on.
 *
 * After every iteration the loss value and the accuracy/overlap of the current model
 * on both the training and the test set are written to standard output.
 *
 * @throws Exception if loading the data sets or any training step fails
 */
private static void test1() throws Exception {
    File trainFile = new File(DATASETS, "spam/trec_data/train.trec");
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(trainFile, DataSetType.ML_CLF_SPARSE, true);
    File testFile = new File(DATASETS, "spam/trec_data/test.trec");
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(testFile, DataSetType.ML_CLF_SPARSE, true);

    CMLCRF cmlcrf = new CMLCRF(dataSet);
    // The loss is built before the pair flag is flipped; keep this order.
    CRFLoss crfLoss = new CRFLoss(cmlcrf, dataSet, 1);
    cmlcrf.setConsiderPair(true);

    LBFGS optimizer = new LBFGS(crfLoss);
    final int maxIterations = 5000;
    for (int iteration = 0; iteration < maxIterations; iteration++) {
        System.out.println("iter: " + iteration);
        optimizer.iterate();
        System.out.println(crfLoss.getValue());

        // Re-score both splits with the weights produced by this iteration.
        MultiLabel[] trainPredictions = cmlcrf.predict(dataSet);
        MultiLabel[] testPredictions = cmlcrf.predict(testSet);

        System.out.print("\tTrain acc: " + Accuracy.accuracy(dataSet.getMultiLabels(), trainPredictions));
        System.out.print("\tTrain overlap " + Overlap.overlap(dataSet.getMultiLabels(), trainPredictions));
        System.out.print("\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), testPredictions));
        System.out.println("\tTest overlap " + Overlap.overlap(testSet.getMultiLabels(), testPredictions));
    }
    // Not enabled: a single optimizer.optimize() run, with the terminator's absolute
    // epsilon set to 0.01, followed by one final train/test report.
}
use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class CMLCRFTest method test8.
/**
 * Trains a {@link CMLCRF} with {@link BlockwiseCD} on the data sets named in {@code config}
 * and reports its quality.
 *
 * Each of the 10000 iterations prints one line with the objective value and the
 * accuracy, overlap and F1 on the training and test sets. Afterwards the method prints
 * {@link MLMeasures} summaries for both sets, the wall-clock time of one prediction pass
 * over the test set, and the measures of an {@link InstanceF1Predictor} wrapped around the
 * model. If the {@code saveModel} setting is true, the model is serialized to
 * {@code output/modelName}.
 *
 * @throws Exception if loading data, training, evaluation or serialization fails
 */
private static void test8() throws Exception {
    System.out.println(config);

    String trainPath = config.getString("input.trainData");
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(trainPath, DataSetType.ML_CLF_SEQ_SPARSE, true);
    String testPath = config.getString("input.testData");
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(testPath, DataSetType.ML_CLF_SEQ_SPARSE, true);

    // Output settings are read before training starts.
    String output = config.getString("output");
    String modelName = config.getString("modelName");

    CMLCRF cmlcrf = new CMLCRF(trainSet);
    double l1Ratio = config.getDouble("l1Ratio");
    double regularization = config.getDouble("regularization");
    BlockwiseCD blockwiseCD = new BlockwiseCD(cmlcrf, trainSet, l1Ratio, regularization);

    final int maxIterations = 10000;
    for (int iteration = 0; iteration < maxIterations; iteration++) {
        blockwiseCD.iterate();
        MultiLabel[] trainPredictions = cmlcrf.predict(trainSet);
        MultiLabel[] testPredictions = cmlcrf.predict(testSet);

        String iterationLabel = String.format("%04d", iteration);
        System.out.print("iter: " + iterationLabel);
        String objective = String.format("%.4f", blockwiseCD.getValue());
        System.out.print("\tobjective: " + objective);

        String trainAcc = String.format("%.4f", Accuracy.accuracy(trainSet.getMultiLabels(), trainPredictions));
        System.out.print("\tTrain acc: " + trainAcc);
        String trainOverlap = String.format("%.4f", Overlap.overlap(trainSet.getMultiLabels(), trainPredictions));
        System.out.print("\tTrain overlap " + trainOverlap);
        String trainF1 = String.format("%.4f", FMeasure.f1(trainSet.getMultiLabels(), trainPredictions));
        System.out.print("\tTrain F1 " + trainF1);

        String testAcc = String.format("%.4f", Accuracy.accuracy(testSet.getMultiLabels(), testPredictions));
        System.out.print("\tTest acc: " + testAcc);
        String testOverlap = String.format("%.4f", Overlap.overlap(testSet.getMultiLabels(), testPredictions));
        System.out.print("\tTest overlap " + testOverlap);
        String testF1 = String.format("%.4f", FMeasure.f1(testSet.getMultiLabels(), testPredictions));
        System.out.println("\tTest F1 " + testF1);
    }

    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");

    MLMeasures trainMeasures = new MLMeasures(cmlcrf, trainSet);
    System.out.println("========== Train ==========\n");
    System.out.println(trainMeasures);

    System.out.println("========== Test ==========\n");
    // This prediction pass exists only to be timed; its result is not used.
    long predictionStart = System.nanoTime();
    cmlcrf.predict(testSet);
    long predictionNanos = System.nanoTime() - predictionStart;
    System.out.println("\nprediction time: " + TimeUnit.NANOSECONDS.toSeconds(predictionNanos) + " sec.");
    System.out.println(new MLMeasures(cmlcrf, testSet));
    System.out.println("\n\n");

    InstanceF1Predictor pluginF1 = new InstanceF1Predictor(cmlcrf);
    System.out.println("Plugin F1");
    System.out.println(new MLMeasures(pluginF1, testSet));

    boolean saveModel = config.getBoolean("saveModel");
    if (saveModel) {
        File outputDir = new File(output);
        outputDir.mkdirs();
        File serializeModel = new File(output, modelName);
        cmlcrf.serialize(serializeModel);
    }
}
use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class CMLCRFTest method test4.
/**
 * Fits a {@link CMLCRF} to the 20newsgroup TREC split under {@code 20newsgroup/1} by
 * running 50 L-BFGS iterations on a {@link CRFLoss}.
 *
 * After every iteration the loss value and the accuracy/overlap of the current model
 * on both the training and the test set are written to standard output.
 *
 * @throws Exception if loading the data sets or any training step fails
 */
private static void test4() throws Exception {
    File trainFile = new File(DATASETS, "20newsgroup/1/train.trec");
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(trainFile, DataSetType.ML_CLF_SPARSE, true);
    File testFile = new File(DATASETS, "20newsgroup/1/test.trec");
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(testFile, DataSetType.ML_CLF_SPARSE, true);

    CMLCRF cmlcrf = new CMLCRF(dataSet);
    CRFLoss crfLoss = new CRFLoss(cmlcrf, dataSet, 1);
    LBFGS optimizer = new LBFGS(crfLoss);

    final int maxIterations = 50;
    for (int iteration = 0; iteration < maxIterations; iteration++) {
        System.out.println("iter: " + iteration);
        optimizer.iterate();
        System.out.println(crfLoss.getValue());

        // Re-score both splits with the weights produced by this iteration.
        MultiLabel[] trainPredictions = cmlcrf.predict(dataSet);
        MultiLabel[] testPredictions = cmlcrf.predict(testSet);

        System.out.print("\tTrain acc: " + Accuracy.accuracy(dataSet.getMultiLabels(), trainPredictions));
        System.out.print("\tTrain overlap " + Overlap.overlap(dataSet.getMultiLabels(), trainPredictions));
        System.out.print("\tTest acc: " + Accuracy.accuracy(testSet.getMultiLabels(), testPredictions));
        System.out.println("\tTest overlap " + Overlap.overlap(testSet.getMultiLabels(), testPredictions));
    }
}
use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class BMSelector method selectAll.
/**
 * Encodes the given label sets as a sparse data set (one row per label set, one column
 * per class, value 1 for every matched label), hands it to
 * {@link BMSelector#selectTrainer}, and returns the selected trainer's model together
 * with its {@code gammas} array.
 *
 * @param numClasses  number of columns of the encoded data set
 * @param multiLabels label sets to encode, one per row
 * @param numClusters cluster count forwarded to {@code selectTrainer}
 * @return a pair holding the trainer's {@code BM} first and its {@code gammas} second
 */
public static Pair<BM, double[][]> selectAll(int numClasses, MultiLabel[] multiLabels, int numClusters) {
    int numRows = multiLabels.length;
    DataSet labelMatrix = DataSetBuilder.getBuilder()
            .numDataPoints(numRows)
            .numFeatures(numClasses)
            .density(Density.SPARSE_RANDOM)
            .build();

    // Mark every matched label of row r with a 1 in column `label`.
    int row = 0;
    for (MultiLabel labels : multiLabels) {
        for (int label : labels.getMatchedLabels()) {
            labelMatrix.setFeatureValue(row, label, 1);
        }
        row++;
    }

    BMTrainer selected = BMSelector.selectTrainer(labelMatrix, numClusters, 10);

    Pair<BM, double[][]> result = new Pair<>();
    result.setFirst(selected.getBm());
    result.setSecond(selected.gammas);
    return result;
}
use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
the class CBM method predictByMarginals.
/**
 * Predicts a label set by sorting the per-class probabilities returned by
 * {@code predictClassProbs} in descending order and keeping the first few classes.
 *
 * @param vector the input passed to {@code predictClassProbs}
 * @param top    the maximum number of labels to keep; a value larger than the number
 *               of scored classes is capped at that number, and a value of zero or
 *               less yields an empty prediction
 * @return a {@code MultiLabel} holding the indices of the highest-ranked classes
 */
public MultiLabel predictByMarginals(Vector vector, int top) {
    double[] probs = predictClassProbs(vector);
    int[] sortedIndices = ArgSort.argSortDescending(probs);
    // Cap the count so an oversized `top` cannot index past the end of the ranking.
    int numToKeep = Math.min(top, sortedIndices.length);
    MultiLabel prediction = new MultiLabel();
    for (int i = 0; i < numToKeep; i++) {
        prediction.addLabel(sortedIndices[i]);
    }
    return prediction;
}
Aggregations