Use of edu.neu.ccs.pyramid.util.Translator in the project pyramid by cheng-li.
Class DataSetUtil, method toMultiClass.
/**
 * Converts a multi-label data set into a single-label (multi-class) data set
 * in which every {@link MultiLabel} returned by
 * {@link DataSetUtil#gatherMultiLabels} becomes one class.
 *
 * <p>The result has the same number of data points and features as the input,
 * and inherits its dense / missing-value settings and its feature list. Each
 * row's class is the index the translator assigns to that row's multi-label.
 *
 * <p>NOTE(review): the external label names are taken from the gathered list in
 * list order, so they line up with the translator indices only if
 * {@code gatherMultiLabels} yields each multi-label once - confirm against
 * that helper.
 *
 * @param dataSet the multi-label data set to convert
 * @return a pair of the converted data set and the translator between
 *         multi-labels and class indices
 */
public static Pair<ClfDataSet, Translator<MultiLabel>> toMultiClass(MultiLabelClfDataSet dataSet) {
    // Index every multi-label found in the data; the index becomes the class id.
    List<MultiLabel> labelCombinations = DataSetUtil.gatherMultiLabels(dataSet);
    Translator<MultiLabel> translator = new Translator<>();
    translator.addAll(labelCombinations);

    int rowCount = dataSet.getNumDataPoints();
    ClfDataSet converted = ClfDataSetBuilder.getBuilder()
            .numDataPoints(rowCount)
            .numFeatures(dataSet.getNumFeatures())
            .dense(dataSet.isDense())
            .missingValue(dataSet.hasMissingValue())
            .numClasses(translator.size())
            .build();

    for (int row = 0; row < rowCount; row++) {
        // Only the non-zero entries of the source row are written to the target.
        for (Vector.Element entry : dataSet.getRow(row).nonZeroes()) {
            converted.setFeatureValue(row, entry.index(), entry.get());
        }
        converted.setLabel(row, translator.getIndex(dataSet.getMultiLabels()[row]));
    }

    // Human-readable class names are the string forms of the multi-labels.
    List<String> classNames = labelCombinations.stream()
            .map(MultiLabel::toString)
            .collect(Collectors.toList());
    converted.setLabelTranslator(new LabelTranslator(classNames));
    converted.setFeatureList(dataSet.getFeatureList());

    return new Pair<>(converted, translator);
}
Aggregations