Example use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in the project pyramid by cheng-li.
From the class AugmentedLRLossTest, method main.
/**
 * Sets the root logger to DEBUG, loads the dense "scene" train and test sets, and runs
 * 100 L-BFGS iterations on an {@code AugmentedLRLoss}, printing the loss value after each one.
 *
 * @param args not used
 * @throws Exception propagated from data loading or optimization
 */
public static void main(String[] args) throws Exception {
    // Raise the root logger level and push the change to the live loggers.
    LoggerContext loggerContext = (LoggerContext) LogManager.getContext(false);
    LoggerConfig rootLoggerConfig =
            loggerContext.getConfiguration().getLoggerConfig(LogManager.ROOT_LOGGER_NAME);
    rootLoggerConfig.setLevel(Level.DEBUG);
    loggerContext.updateLoggers();

    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "scene/train"), DataSetType.ML_CLF_DENSE, true);
    // The test split is loaded here but not referenced again in this method.
    MultiLabelClfDataSet heldOutSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "scene/test"), DataSetType.ML_CLF_DENSE, true);

    AugmentedLR model = new AugmentedLR(trainSet.getNumFeatures(), 1);

    // One row per data point, one column, every entry set to 1.
    double[][] gammas = new double[trainSet.getNumDataPoints()][1];
    for (double[] row : gammas) {
        row[0] = 1;
    }

    AugmentedLRLoss loss = new AugmentedLRLoss(trainSet, 0, gammas, model, 1, 1);
    LBFGS lbfgs = new LBFGS(loss);
    int remaining = 100;
    while (remaining-- > 0) {
        lbfgs.iterate();
        System.out.println(loss.getValue());
    }
}
Example use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in the project pyramid by cheng-li.
From the class CBMInitializerTest, method test1.
/**
 * Loads the sparse "flags" training set, builds a two-component CBM that uses "lr" for both
 * the binary and the multi-class classifiers, and runs {@code CBMInitializer.initialize} on it.
 *
 * @throws Exception propagated from data loading or initialization
 */
private static void test1() throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "/flags/train"), DataSetType.ML_CLF_SPARSE, true);
    int numClusters = 2;
    // NOTE(review): the earlier version declared softmaxVariance and logitVariance (both 1000)
    // but never passed them anywhere, so they were removed as dead locals. If prior variances
    // were meant to be applied, set them on the optimizer before initializing - confirm intent.
    CBM cbm = CBM.getBuilder()
            .setNumClasses(trainSet.getNumClasses())
            .setNumFeatures(trainSet.getNumFeatures())
            .setNumComponents(numClusters)
            .setBinaryClassifierType("lr")
            .setMultiClassClassifierType("lr")
            .build();
    CBMOptimizer optimizer = new CBMOptimizer(cbm, trainSet);
    CBMInitializer.initialize(cbm, trainSet, optimizer);
}
Example use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in the project pyramid by cheng-li.
From the class CBMInspectorTest, method test1.
/**
 * Deserializes a CBM from {@code TMP/model}, prints its accuracy on the sparse meka_imdb test
 * set, and then reports every data point whose prediction equals the true label, differs from
 * the marginal-based prediction, and whose marginal-based prediction has at least one matched
 * label. For each such point the covariance inspection is invoked.
 *
 * @throws Exception propagated from data loading, deserialization or inspection
 */
private static void test1() throws Exception {
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "meka_imdb/1/data_sets/test"), DataSetType.ML_CLF_SPARSE, true);
    CBM model = (CBM) Serialization.deserialize(new File(TMP, "model"));
    System.out.println(Accuracy.accuracy(model, testSet));

    final int numPoints = testSet.getNumDataPoints();
    for (int idx = 0; idx < numPoints; idx++) {
        MultiLabel truth = testSet.getMultiLabels()[idx];
        MultiLabel predicted = model.predict(testSet.getRow(idx));
        MultiLabel byMarginals = model.predictByMarginals(testSet.getRow(idx));

        // Skip unless: prediction is correct, differs from the marginal-based one,
        // and the marginal-based one is non-empty.
        if (!predicted.equals(truth)
                || predicted.equals(byMarginals)
                || byMarginals.getMatchedLabels().size() <= 0) {
            continue;
        }

        System.out.println("==============================");
        System.out.println("data point " + idx);
        System.out.println("prediction = " + predicted);
        System.out.println("expectation = " + byMarginals);
        CBMInspector.covariance(model, testSet.getRow(idx), testSet.getLabelTranslator());
    }
}
Example use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in the project pyramid by cheng-li.
From the class CBMTest, method test2.
/**
 * Trains a 4-component boosted CBM on a synthetic 2-feature, 4-class data set (1000 points)
 * and prints the objective plus train/test accuracy for 3 optimizer iterations, followed by
 * the model itself. The test set (100 points) is generated the same way as the training set.
 *
 * @throws Exception propagated from model initialization or optimization
 */
private static void test2() throws Exception {
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numFeatures(2).numClasses(4).numDataPoints(1000).build();
    BernoulliDistribution bernoulliDistribution = new BernoulliDistribution(0.5);
    fillWithNoisyBitLabels(dataSet, bernoulliDistribution);

    MultiLabelClfDataSet testSet = MLClfDataSetBuilder.getBuilder()
            .numFeatures(2).numClasses(4).numDataPoints(100).build();
    fillWithNoisyBitLabels(testSet, bernoulliDistribution);

    int numComponents = 4;
    CBM cbm = CBM.getBuilder()
            .setNumClasses(dataSet.getNumClasses())
            .setNumFeatures(dataSet.getNumFeatures())
            .setNumComponents(numComponents)
            .setBinaryClassifierType("boost")
            .setMultiClassClassifierType("boost")
            .build();
    cbm.setPredictMode("dynamic");

    CBMOptimizer optimizer = new CBMOptimizer(cbm, dataSet);
    optimizer.setPriorVarianceBinary(10);
    optimizer.setPriorVarianceMultiClass(10);
    CBMInitializer.initialize(cbm, dataSet, optimizer);

    for (int i = 0; i < 3; i++) {
        optimizer.iterate();
        System.out.print("i: " + i + "\t");
        System.out.print("objective: " + optimizer.getTerminator().getLastValue() + "\t");
        System.out.print("trainAcc: " + Accuracy.accuracy(cbm, dataSet) + "\t");
        System.out.println("testAcc: " + Accuracy.accuracy(cbm, testSet));
    }
    System.out.println(cbm.toString());
}

/**
 * Fills every feature of every data point with a sampled bit and adds one label per feature
 * derived from that bit, where the bit used for the label is flipped with probability 0.1.
 * Feature 0 maps to labels 0/1; any other feature maps to labels 2/3.
 *
 * @param data the data set to populate in place
 * @param bitSource distribution each feature bit is sampled from
 */
private static void fillWithNoisyBitLabels(MultiLabelClfDataSet data, BernoulliDistribution bitSource) {
    for (int n = 0; n < data.getNumDataPoints(); n++) {
        for (int m = 0; m < data.getNumFeatures(); m++) {
            int bit = bitSource.sample();
            // The feature keeps the sampled bit; only the label sees the (possibly) flipped bit.
            int labelBit = Math.random() < 0.1 ? 1 - bit : bit;
            data.setFeatureValue(n, m, bit);
            int lowLabel = (m == 0) ? 0 : 2;
            data.addLabel(n, labelBit == 0 ? lowLabel : lowLabel + 1);
        }
    }
}
Example use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in the project pyramid by cheng-li.
From the class CMLCRFTest, method test5.
/**
 * Trains a {@code CMLCRF} on the sparse ohsumed split in two L-BFGS phases and prints the loss
 * plus train/test accuracy and overlap after every iteration: 5 iterations after
 * {@code setConsiderPair(false)}, then 50 iterations after {@code setConsiderPair(true)}
 * using a fresh loss and optimizer.
 *
 * @throws Exception propagated from data loading or optimization
 */
private static void test5() throws Exception {
    MultiLabelClfDataSet dataSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(
            new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    CMLCRF cmlcrf = new CMLCRF(dataSet);

    // Phase 1: pairs not considered.
    // NOTE(review): in both phases the loss is constructed before setConsiderPair is called;
    // that order is kept as-is - confirm CRFLoss does not capture the flag at construction.
    CRFLoss crfLoss = new CRFLoss(cmlcrf, dataSet, 1);
    cmlcrf.setConsiderPair(false);
    LBFGS optimizer = new LBFGS(crfLoss);
    for (int i = 0; i < 5; i++) {
        System.out.println("iter: " + i);
        optimizer.iterate();
        System.out.println(crfLoss.getValue());
        printTrainTestMetrics(cmlcrf, dataSet, testSet);
    }

    // Phase 2: pairs considered, continuing from the model state left by phase 1.
    CRFLoss crfLoss2 = new CRFLoss(cmlcrf, dataSet, 1);
    cmlcrf.setConsiderPair(true);
    LBFGS optimizer2 = new LBFGS(crfLoss2);
    for (int i = 0; i < 50; i++) {
        System.out.println("consider pairs");
        System.out.println("iter: " + i);
        optimizer2.iterate();
        System.out.println(crfLoss2.getValue());
        printTrainTestMetrics(cmlcrf, dataSet, testSet);
    }
}

/**
 * Predicts on both data sets with the given model and prints accuracy and overlap for each
 * on a single tab-separated line.
 *
 * @param model the CRF used for prediction
 * @param train the training set
 * @param test the test set
 */
private static void printTrainTestMetrics(CMLCRF model, MultiLabelClfDataSet train, MultiLabelClfDataSet test) {
    MultiLabel[] predTrain = model.predict(train);
    MultiLabel[] predTest = model.predict(test);
    System.out.print("\tTrain acc: " + Accuracy.accuracy(train.getMultiLabels(), predTrain));
    System.out.print("\tTrain overlap " + Overlap.overlap(train.getMultiLabels(), predTrain));
    System.out.print("\tTest acc: " + Accuracy.accuracy(test.getMultiLabels(), predTest));
    System.out.println("\tTest overlap " + Overlap.overlap(test.getMultiLabels(), predTest));
}
Aggregations