Usage of edu.neu.ccs.pyramid.multilabel_classification.crf.SamplingPredictor in the project pyramid by cheng-li.
Taken from the class MultiLabelSynthesizer, method crfSample.
/**
 * Builds a synthetic multi-label data set whose labels come from a CMLCRF with fixed weights.
 *
 * <p>The data set has 10000 rows, 2 features and 4 classes. Every feature value is drawn with
 * {@code Sampling.doubleUniform(-1, 1)}. The label of each row is whatever
 * {@code SamplingPredictor.predict} returns for that row under the hand-set CRF (presumably a
 * random draw from the model, given the predictor's name — confirm in SamplingPredictor).
 *
 * <p>No seed is set here, so repeated calls can yield different data sets.
 *
 * @return the generated data set, with features and labels filled in
 */
public static MultiLabelClfDataSet crfSample() {
    final int numData = 10000;
    final int numClass = 4;
    final int numFeature = 2;

    // Bias-free weights of the generating CRF: row k holds the weights of class k,
    // column f the weight of feature f.
    final double[][] trueWeights = {
        {0, 10},
        {10, 10},
        {10, 0},
        {-10, -10}
    };

    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder()
            .numFeatures(numFeature)
            .numClasses(numClass)
            .numDataPoints(numData)
            .build();

    // Label combinations handed to the CRF as its support; presumably every subset of the
    // numClass classes, judging by Enumerator.enumerate — verify.
    List<MultiLabel> labelSupport = Enumerator.enumerate(numClass);
    CMLCRF generator = new CMLCRF(numClass, numFeature, labelSupport);
    for (int k = 0; k < numClass; k++) {
        for (int f = 0; f < numFeature; f++) {
            generator.getWeights().getWeightsWithoutBiasForClass(k).set(f, trueWeights[k][f]);
        }
    }

    // Fill every feature first. The labelling pass below must not be merged into this loop:
    // doing so would interleave random draws and change the generated sample.
    for (int row = 0; row < numData; row++) {
        for (int col = 0; col < numFeature; col++) {
            dataSet.setFeatureValue(row, col, Sampling.doubleUniform(-1, 1));
        }
    }

    // Label each row with the predictor's output for that row's feature vector.
    SamplingPredictor labeler = new SamplingPredictor(generator);
    for (int row = 0; row < numData; row++) {
        dataSet.setLabels(row, labeler.predict(dataSet.getRow(row)));
    }
    return dataSet;
}
Aggregations