Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
Class CMLCRFElasticNet, method calEmpiricalCountForLabelPair:
/**
 * Counts the data points whose labels realise one label-pair indicator feature.
 * <p>
 * The pair parameter is located by subtracting {@code numWeightsForFeatures} from
 * {@code parameterIndex}; that offset is mapped to two labels through
 * {@code parameterToL1}/{@code parameterToL2}, and the offset modulo 4 selects the
 * required assignment: 0 = neither label, 1 = only l1, 2 = only l2, 3 = both.
 *
 * @param parameterIndex global index of a label-pair parameter
 * @return number of data points (as a double) whose labels match the selected assignment
 */
private double calEmpiricalCountForLabelPair(int parameterIndex) {
    int pairOffset = parameterIndex - numWeightsForFeatures;
    int firstLabel = parameterToL1[pairOffset];
    int secondLabel = parameterToL2[pairOffset];
    int assignment = pairOffset % 4;
    // bit 0 of the assignment is the required state of firstLabel, bit 1 that of secondLabel
    boolean firstRequired = (assignment & 1) != 0;
    boolean secondRequired = (assignment & 2) != 0;
    double matches = 0.0;
    for (int point = 0; point < dataSet.getNumDataPoints(); point++) {
        MultiLabel observed = dataSet.getMultiLabels()[point];
        if (assignment < 0 || assignment > 3) {
            throw new RuntimeException("feature case :" + assignment + " failed.");
        }
        boolean hasFirst = observed.matchClass(firstLabel);
        boolean hasSecond = observed.matchClass(secondLabel);
        if (hasFirst == firstRequired && hasSecond == secondRequired) {
            matches += 1.0;
        }
    }
    return matches;
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
Class CMLCRF, method computeLabelPartScore:
/**
 * Computes the part of the score that depends only on the labels of one supported combination.
 * <p>
 * Pair weights start right after the feature weights and come in groups of four, one group per
 * label pair (l1, l2) with l1 &lt; l2. Within a group, exactly one feature function returns 1:
 * offset 0 when neither label is present, 1 when only l1 is, 2 when only l2 is, 3 when both are.
 *
 * @param labelComIndex index into {@code supportCombinations}
 * @return the sum of the selected pair weights over all label pairs
 */
double computeLabelPartScore(int labelComIndex) {
    MultiLabel combination = supportCombinations.get(labelComIndex);
    double total = 0;
    int groupStart = this.weights.getNumWeightsForFeatures();
    boolean[] present = new boolean[numClasses];
    for (int matched : combination.getMatchedLabels()) {
        present[matched] = true;
    }
    for (int first = 0; first < numClasses; first++) {
        for (int second = first + 1; second < numClasses; second++) {
            int offset = (present[first] ? 1 : 0) + (present[second] ? 2 : 0);
            total += this.weights.getWeightForIndex(groupStart + offset);
            groupStart += 4;
        }
    }
    return total;
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
Class KLLoss, method initTargMarginals:
/**
 * Derives per-label target marginals for one data point from its target joint distribution.
 * <p>
 * For every supported label combination, that combination's probability is added to the
 * marginal of each label it contains. Values are accumulated with {@code +=}.
 * NOTE(review): this presumably assumes {@code targetMarginals[dataPoint]} starts at zero;
 * confirm with the caller.
 *
 * @param dataPoint index of the data point whose marginals are filled
 */
private void initTargMarginals(int dataPoint) {
    double[] jointProbs = targetDistribution[dataPoint];
    int numCombinations = jointProbs.length;
    for (int comb = 0; comb < numCombinations; comb++) {
        MultiLabel combination = supportedCombinations.get(comb);
        double p = jointProbs[comb];
        for (int label : combination.getMatchedLabels()) {
            targetMarginals[dataPoint][label] += p;
        }
    }
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
Class CMLCRFTest, method test2:
/**
 * Trains (or loads) a CMLCRF on dense TREC data and reports accuracy and overlap.
 * <p>
 * With {@code train.warmStart} set, a serialized model is loaded from
 * {@code output/modelName}. Otherwise a fresh model is trained on the CRF loss for
 * {@code numRounds} rounds, with LBFGS or plain gradient descent depending on
 * {@code isLBFGS}, and train/test metrics are printed after every round. Final metrics
 * are always printed, and the model is serialized when {@code saveModel} is set.
 *
 * @throws Exception as declared; presumably from data loading, (de)serialization or training
 */
public static void test2() throws Exception {
    System.out.println(config);
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_DENSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_DENSE, true);
    double gaussianVariance = config.getDouble("gaussianVariance");
    // loading or save model infos.
    String output = config.getString("output");
    String modelName = config.getString("modelName");
    CMLCRF cmlcrf;
    if (config.getBoolean("train.warmStart")) {
        cmlcrf = CMLCRF.deserialize(new File(output, modelName));
        System.out.println("loading model:");
        System.out.println(cmlcrf);
    } else {
        cmlcrf = new CMLCRF(trainSet);
        CRFLoss crfLoss = new CRFLoss(cmlcrf, trainSet, gaussianVariance);
        if (config.getBoolean("isLBFGS")) {
            LBFGS optimizer = new LBFGS(crfLoss);
            optimizer.getTerminator().setAbsoluteEpsilon(0.1);
            for (int i = 0; i < config.getInt("numRounds"); i++) {
                optimizer.iterate();
                printIterAccOverlap(i, cmlcrf, trainSet, testSet);
            }
        } else {
            GradientDescent optimizer = new GradientDescent(crfLoss);
            for (int i = 0; i < config.getInt("numRounds"); i++) {
                optimizer.iterate();
                printIterAccOverlap(i, cmlcrf, trainSet, testSet);
            }
        }
    }
    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");
    MultiLabel[] predTrain = cmlcrf.predict(trainSet);
    MultiLabel[] predTest = cmlcrf.predict(testSet);
    printAccOverlap("", trainSet, testSet, predTrain, predTest);
    if (config.getBoolean("saveModel")) {
        (new File(output)).mkdirs();
        File serializeModel = new File(output, modelName);
        cmlcrf.serialize(serializeModel);
    }
}

/**
 * Predicts both data sets with the current model and prints one progress line
 * ("iter: NNNN" followed by train/test accuracy and overlap) for the given round.
 */
private static void printIterAccOverlap(int iter, CMLCRF cmlcrf, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet testSet) {
    MultiLabel[] predTrain = cmlcrf.predict(trainSet);
    MultiLabel[] predTest = cmlcrf.predict(testSet);
    System.out.print("iter: " + String.format("%04d", iter));
    printAccOverlap("\t", trainSet, testSet, predTrain, predTest);
}

/**
 * Prints train/test accuracy and overlap on one line, terminated by a newline.
 * {@code leadingSeparator} is printed immediately before "Train acc: ".
 */
private static void printAccOverlap(String leadingSeparator, MultiLabelClfDataSet trainSet, MultiLabelClfDataSet testSet,
                                    MultiLabel[] predTrain, MultiLabel[] predTest) {
    System.out.print(leadingSeparator + "Train acc: " + String.format("%.4f", Accuracy.accuracy(trainSet.getMultiLabels(), predTrain)));
    System.out.print("\tTrain overlap " + String.format("%.4f", Overlap.overlap(trainSet.getMultiLabels(), predTrain)));
    System.out.print("\tTest acc: " + String.format("%.4f", Accuracy.accuracy(testSet.getMultiLabels(), predTest)));
    System.out.println("\tTest overlap " + String.format("%.4f", Overlap.overlap(testSet.getMultiLabels(), predTest)));
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabel in project pyramid by cheng-li.
Class CMLCRFTest, method test7:
/**
 * Obtains a CMLCRF on sequential-sparse TREC data, then prints train/test measures, the
 * prediction time and plugin-F1 measures, and optionally serializes the model.
 * <p>
 * {@code train.warmStart} selects how the model is obtained:
 * "true" loads a serialized model from {@code output/modelName};
 * "auto" loads it and retrains it with elastic net;
 * "false" trains a fresh model (honouring {@code considerLabelPair}) with elastic net.
 *
 * @throws IllegalArgumentException if {@code train.warmStart} is none of "true", "auto", "false"
 * @throws Exception as declared; presumably from data loading, (de)serialization or training
 */
private static void test7() throws Exception {
    System.out.println(config);
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.trainData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(config.getString("input.testData"), DataSetType.ML_CLF_SEQ_SPARSE, true);
    // loading or save model infos.
    String output = config.getString("output");
    String modelName = config.getString("modelName");
    String warmStart = config.getString("train.warmStart");
    CMLCRF cmlcrf;
    // Constant-first equals is null-safe; a missing or unknown value reaches the final else.
    if ("true".equals(warmStart)) {
        cmlcrf = CMLCRF.deserialize(new File(output, modelName));
        System.out.println("loading model:");
        System.out.println(cmlcrf);
    } else if ("auto".equals(warmStart)) {
        cmlcrf = CMLCRF.deserialize(new File(output, modelName));
        System.out.println("retrain model:");
        CMLCRFElasticNet cmlcrfElasticNet = new CMLCRFElasticNet(cmlcrf, trainSet, config.getDouble("l1Ratio"), config.getDouble("regularization"));
        train(cmlcrfElasticNet, cmlcrf, trainSet, testSet, config);
    } else if ("false".equals(warmStart)) {
        cmlcrf = new CMLCRF(trainSet);
        cmlcrf.setConsiderPair(config.getBoolean("considerLabelPair"));
        CMLCRFElasticNet cmlcrfElasticNet = new CMLCRFElasticNet(cmlcrf, trainSet, config.getDouble("l1Ratio"), config.getDouble("regularization"));
        train(cmlcrfElasticNet, cmlcrf, trainSet, testSet, config);
    } else {
        // Previously an unrecognised value left cmlcrf null and failed later with an opaque NPE.
        throw new IllegalArgumentException("train.warmStart must be one of true, auto, false; got: " + warmStart);
    }
    System.out.println();
    System.out.println();
    System.out.println("--------------------------------Results-----------------------------\n");
    MLMeasures measures = new MLMeasures(cmlcrf, trainSet);
    System.out.println("========== Train ==========\n");
    System.out.println(measures);
    System.out.println("========== Test ==========\n");
    // The predictions here are only used to time prediction on the test set.
    long startTimePred = System.nanoTime();
    MultiLabel[] preds = cmlcrf.predict(testSet);
    long stopTimePred = System.nanoTime();
    long predTime = stopTimePred - startTimePred;
    System.out.println("\nprediction time: " + TimeUnit.NANOSECONDS.toSeconds(predTime) + " sec.");
    System.out.println(new MLMeasures(cmlcrf, testSet));
    System.out.println("\n\n");
    InstanceF1Predictor pluginF1 = new InstanceF1Predictor(cmlcrf);
    System.out.println("Plugin F1");
    System.out.println(new MLMeasures(pluginF1, testSet));
    if (config.getBoolean("saveModel")) {
        (new File(output)).mkdirs();
        File serializeModel = new File(output, modelName);
        cmlcrf.serialize(serializeModel);
    }
}
Aggregations