Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in the project pyramid by cheng-li.
From the class MultiLabelSynthesizer, method flipOneNonUniform:
/**
 * Synthesizes a 2-feature, 4-label dataset whose labels are linear separators,
 * then flips exactly one label per instance, chosen non-uniformly.
 * Label weight vectors:
 * y0: w=(0,1)
 * y1: w=(1,1)
 * y2: w=(1,0)
 * y3: w=(1,-1)
 * @param numData number of data points to generate
 * @return the synthesized multi-label dataset
 */
public static MultiLabelClfDataSet flipOneNonUniform(int numData) {
    final int numClass = 4;
    final int numFeature = 2;
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numFeatures(numFeature)
            .numClasses(numClass)
            .numDataPoints(numData)
            .build();
    // one separating hyperplane normal per label (rows match the javadoc table)
    double[][] weightValues = { { 0, 1 }, { 1, 1 }, { 1, 0 }, { 1, -1 } };
    Vector[] weights = new Vector[numClass];
    for (int k = 0; k < numClass; k++) {
        weights[k] = new DenseVector(numFeature);
        for (int j = 0; j < numFeature; j++) {
            weights[k].set(j, weightValues[k][j]);
        }
    }
    // features drawn uniformly from [-1, 1]
    for (int i = 0; i < numData; i++) {
        for (int j = 0; j < numFeature; j++) {
            dataSet.setFeatureValue(i, j, Sampling.doubleUniform(-1, 1));
        }
    }
    // a label is on iff the point lies on the non-negative side of its hyperplane
    for (int i = 0; i < numData; i++) {
        for (int k = 0; k < numClass; k++) {
            if (weights[k].dot(dataSet.getRow(i)) >= 0) {
                dataSet.addLabel(i, k);
            }
        }
    }
    // label 0 is twice as likely to be flipped as each of the others
    IntegerDistribution distribution = new EnumeratedIntegerDistribution(
            new int[] { 0, 1, 2, 3 },
            new double[] { 0.4, 0.2, 0.2, 0.2 });
    // toggle the sampled label on each instance
    for (int i = 0; i < numData; i++) {
        int toChange = distribution.sample();
        MultiLabel label = dataSet.getMultiLabels()[i];
        if (label.matchClass(toChange)) {
            label.removeLabel(toChange);
        } else {
            label.addLabel(toChange);
        }
    }
    return dataSet;
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in the project pyramid by cheng-li.
From the class Meka2Trec, method main:
/**
 * Converts multi-label classification datasets from MEKA format to TREC format.
 * This only supports multi-label classification datasets.
 * @param args args[0] is the path of a properties file listing parallel
 *             "meka" (input) and "trec" (output) paths plus dataset dimensions
 * @throws IOException if a dataset cannot be read or written
 */
public static void main(String[] args) throws IOException {
    if (args.length != 1) {
        throw new IllegalArgumentException("Please specify a properties file.");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    List<String> trecs = config.getStrings("trec");
    List<String> mekas = config.getStrings("meka");
    // fail fast with a clear message; otherwise a mismatched config throws an
    // opaque IndexOutOfBoundsException partway through the conversions
    if (trecs.size() != mekas.size()) {
        throw new IllegalArgumentException("trec and meka must list the same number of paths: "
                + trecs.size() + " vs " + mekas.size());
    }
    int numLabels = config.getInt("numLabels");
    int numFeatures = config.getInt("numFeatures");
    String dataMode = config.getString("dataMode");
    for (int i = 0; i < mekas.size(); i++) {
        System.out.println("processing on: " + trecs.get(i));
        MultiLabelClfDataSet dataSet = MekaFormat.loadMLClfDataset(mekas.get(i), numFeatures, numLabels, dataMode);
        TRECFormat.save(dataSet, trecs.get(i));
    }
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in the project pyramid by cheng-li.
From the class AdaBoostMHInspector, method analyzePrediction:
/**
 * Builds a per-data-point prediction analysis for an AdaBoostMH model:
 * true labels, predicted labels, their probabilities (when the scaling
 * supports joint assignment probabilities), and per-class score breakdowns.
 * The scaling can be binary scaling or across-class scaling.
 * @param boosting trained AdaBoostMH model
 * @param scaling probability estimator used to turn scores into probabilities
 * @param dataSet dataset containing the point to analyze
 * @param dataPointIndex internal index of the data point
 * @param classes class indices to produce score calculations and rank info for
 * @param limit passed through to decisionProcess to cap the explanation size
 * @return the populated prediction analysis
 */
public static MultiLabelPredictionAnalysis analyzePrediction(AdaBoostMH boosting, MultiLabelClassifier.ClassProbEstimator scaling, MultiLabelClfDataSet dataSet, int dataPointIndex, List<Integer> classes, int limit) {
    MultiLabelPredictionAnalysis predictionAnalysis = new MultiLabelPredictionAnalysis();
    LabelTranslator labelTranslator = dataSet.getLabelTranslator();
    IdTranslator idTranslator = dataSet.getIdTranslator();
    predictionAnalysis.setInternalId(dataPointIndex);
    predictionAnalysis.setId(idTranslator.toExtId(dataPointIndex));
    predictionAnalysis.setInternalLabels(dataSet.getMultiLabels()[dataPointIndex].getMatchedLabelsOrdered());
    // map internal label indices to external (human-readable) label names
    List<String> labels = dataSet.getMultiLabels()[dataPointIndex].getMatchedLabelsOrdered().stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setLabels(labels);
    // joint probability of the true label set; NaN when the scaling cannot
    // estimate assignment probabilities
    double probForTrueLabels = Double.NaN;
    if (scaling instanceof MultiLabelClassifier.AssignmentProbEstimator) {
        probForTrueLabels = ((MultiLabelClassifier.AssignmentProbEstimator) scaling).predictAssignmentProb(dataSet.getRow(dataPointIndex), dataSet.getMultiLabels()[dataPointIndex]);
    }
    predictionAnalysis.setProbForTrueLabels(probForTrueLabels);
    MultiLabel predictedLabels = boosting.predict(dataSet.getRow(dataPointIndex));
    List<Integer> internalPrediction = predictedLabels.getMatchedLabelsOrdered();
    predictionAnalysis.setInternalPrediction(internalPrediction);
    List<String> prediction = internalPrediction.stream().map(labelTranslator::toExtLabel).collect(Collectors.toList());
    predictionAnalysis.setPrediction(prediction);
    // joint probability of the predicted label set, same NaN convention as above
    double probForPredictedLabels = Double.NaN;
    if (scaling instanceof MultiLabelClassifier.AssignmentProbEstimator) {
        probForPredictedLabels = ((MultiLabelClassifier.AssignmentProbEstimator) scaling).predictAssignmentProb(dataSet.getRow(dataPointIndex), predictedLabels);
    }
    predictionAnalysis.setProbForPredictedLabels(probForPredictedLabels);
    // per-class decision-process breakdowns for the requested classes
    List<ClassScoreCalculation> classScoreCalculations = new ArrayList<>();
    for (int k : classes) {
        ClassScoreCalculation classScoreCalculation = decisionProcess(boosting, scaling, labelTranslator, dataSet.getRow(dataPointIndex), k, limit);
        classScoreCalculations.add(classScoreCalculation);
    }
    predictionAnalysis.setClassScoreCalculations(classScoreCalculations);
    // NOTE(review): `ranking` is computed below (including predictClassProb calls)
    // but never attached to predictionAnalysis and never returned — it looks like a
    // setter call is missing. Confirm the intended MultiLabelPredictionAnalysis
    // field before removing or wiring up this computation.
    List<MultiLabelPredictionAnalysis.ClassRankInfo> ranking = classes.stream().map(label -> {
        MultiLabelPredictionAnalysis.ClassRankInfo rankInfo = new MultiLabelPredictionAnalysis.ClassRankInfo();
        rankInfo.setClassIndex(label);
        rankInfo.setClassName(labelTranslator.toExtLabel(label));
        rankInfo.setProb(scaling.predictClassProb(dataSet.getRow(dataPointIndex), label));
        return rankInfo;
    }).collect(Collectors.toList());
    return predictionAnalysis;
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in the project pyramid by cheng-li.
From the class SparkCBMOptimizer, method updateBinaryClassifiers:
// Re-trains all (component, label) binary logistic regressions of the CBM in
// parallel on Spark: builds one BinaryTask per (k, l) pair, ships them as an RDD
// (one partition per task), and writes each trained classifier back into
// cbm.binaryClassifiers.
private void updateBinaryClassifiers() {
    if (logger.isDebugEnabled()) {
        logger.debug("start updateBinaryClassifiers");
    }
    // copy fields into effectively-final locals so the lambda below captures
    // only these (and the broadcasts), not `this`, for Spark serialization
    Classifier.ProbabilityEstimator[][] localBinaryClassifiers = cbm.binaryClassifiers;
    double[][] localGammasT = gammasT;
    Broadcast<MultiLabelClfDataSet> localDataSetBroadcast = dataSetBroadCast;
    Broadcast<double[][][]> localTargetsBroadcast = targetDisBroadCast;
    double localVariance = priorVarianceBinary;
    List<BinaryTask> binaryTaskList = new ArrayList<>();
    for (int k = 0; k < cbm.numComponents; k++) {
        for (int l = 0; l < cbm.numLabels; l++) {
            LogisticRegression logisticRegression = (LogisticRegression) localBinaryClassifiers[k][l];
            // instance weights for component k are shared across all labels l
            double[] weights = localGammasT[k];
            binaryTaskList.add(new BinaryTask(k, l, logisticRegression, weights));
        }
    }
    // one partition per task so every (component, label) pair trains concurrently
    JavaRDD<BinaryTask> binaryTaskRDD = sparkContext.parallelize(binaryTaskList, binaryTaskList.size());
    List<BinaryTaskResult> results = binaryTaskRDD.map(binaryTask -> {
        int labelIndex = binaryTask.classIndex;
        // NOTE(review): the broadcast targets (double[][][]) are indexed by
        // labelIndex in the FIRST dimension here — presumably the layout is
        // [label][dataPoint][binaryTarget]; confirm against where
        // targetDisBroadCast is built.
        return updateBinaryLogisticRegression(binaryTask.componentIndex, binaryTask.classIndex, binaryTask.logisticRegression, localDataSetBroadcast.value(), binaryTask.weights, localTargetsBroadcast.value()[labelIndex], localVariance);
    }).collect();
    // copy the trained classifiers back into the model on the driver
    for (BinaryTaskResult result : results) {
        cbm.binaryClassifiers[result.componentIndex][result.classIndex] = result.binaryClassifier;
    }
    // IntStream.range(0, cbm.numComponents).forEach(this::updateBinaryClassifiers);
    if (logger.isDebugEnabled()) {
        logger.debug("finish updateBinaryClassifiers");
    }
}
Use of edu.neu.ccs.pyramid.dataset.MultiLabelClfDataSet in the project pyramid by cheng-li.
From the class CBMTest, method test3:
// Trains a 4-component CBM (logistic-regression binary classifiers, boosted
// multi-class classifier) on the ohsumed/3 split and prints train/test
// accuracy and overlap after initialization and after each of 30 iterations.
private static void test3() throws Exception {
    MultiLabelClfDataSet trainSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/train.trec"), DataSetType.ML_CLF_SPARSE, true);
    MultiLabelClfDataSet testSet = TRECFormat.loadMultiLabelClfDataSet(new File(DATASETS, "ohsumed/3/test.trec"), DataSetType.ML_CLF_SPARSE, true);
    CBM model = CBM.getBuilder()
            .setNumClasses(trainSet.getNumClasses())
            .setNumFeatures(trainSet.getNumFeatures())
            .setNumComponents(4)
            .setBinaryClassifierType("lr")
            .setMultiClassClassifierType("boost")
            .build();
    model.setPredictMode("dynamic");
    CBMOptimizer opt = new CBMOptimizer(model, trainSet);
    opt.setPriorVarianceBinary(10);
    opt.setPriorVarianceMultiClass(10);
    CBMInitializer.initialize(model, trainSet, opt);
    model.setNumSample(100);
    System.out.println("num cluster: " + model.numComponents);
    System.out.println("after initialization");
    System.out.println("train acc = " + Accuracy.accuracy(model, trainSet));
    System.out.println("test acc = " + Accuracy.accuracy(model, testSet));
    for (int iter = 1; iter <= 30; iter++) {
        opt.iterate();
        System.out.print("iter : " + iter + "\t");
        System.out.print("objective: " + opt.getTerminator().getLastValue() + "\t");
        System.out.print("trainAcc : " + Accuracy.accuracy(model, trainSet) + "\t");
        System.out.print("trainOver: " + Overlap.overlap(model, trainSet) + "\t");
        System.out.print("testAcc : " + Accuracy.accuracy(model, testSet) + "\t");
        System.out.println("testOver : " + Overlap.overlap(model, testSet) + "\t");
    }
    System.out.println("history = " + opt.getTerminator().getHistory());
    System.out.println(model);
}
Aggregations