use of edu.neu.ccs.pyramid.feature.FeatureList in project pyramid by cheng-li.
the class DataSetUtil method concatenateByColumn.
/**
 * Builds a new data set whose columns are the columns of {@code dataSet1}
 * followed by the columns of {@code dataSet2}.
 *
 * <p>The number of data points, number of classes, density, missing-value flag,
 * labels, label translator and id translator are all taken from {@code dataSet1};
 * {@code dataSet2} contributes only its feature values and its feature descriptors.
 *
 * @param dataSet1 supplies the leading columns and all row-level metadata
 * @param dataSet2 supplies the trailing columns
 * @return the combined data set
 */
public static MultiLabelClfDataSet concatenateByColumn(MultiLabelClfDataSet dataSet1, MultiLabelClfDataSet dataSet2) {
    int rowCount = dataSet1.getNumDataPoints();
    int leftWidth = dataSet1.getNumFeatures();
    int rightWidth = dataSet2.getNumFeatures();

    MultiLabelClfDataSet combined = MLClfDataSetBuilder.getBuilder()
            .numDataPoints(rowCount)
            .numFeatures(leftWidth + rightWidth)
            .numClasses(dataSet1.getNumClasses())
            .density(dataSet1.density())
            .missingValue(dataSet1.hasMissingValue())
            .build();

    // Columns of dataSet1 keep their original positions.
    for (int col = 0; col < leftWidth; col++) {
        for (Vector.Element entry : dataSet1.getColumn(col).nonZeroes()) {
            combined.setFeatureValue(entry.index(), col, entry.get());
        }
    }

    // Columns of dataSet2 are shifted right by the width of dataSet1.
    for (int col = 0; col < rightWidth; col++) {
        int targetCol = leftWidth + col;
        for (Vector.Element entry : dataSet2.getColumn(col).nonZeroes()) {
            combined.setFeatureValue(entry.index(), targetCol, entry.get());
        }
    }

    // Labels come from dataSet1 only.
    MultiLabel[] sourceLabels = dataSet1.getMultiLabels();
    for (int row = 0; row < rowCount; row++) {
        combined.setLabels(row, sourceLabels[row]);
    }

    // Feature descriptors are appended in the same order as the columns.
    FeatureList mergedFeatures = new FeatureList();
    for (Feature leftFeature : dataSet1.getFeatureList().getAll()) {
        mergedFeatures.add(leftFeature);
    }
    for (Feature rightFeature : dataSet2.getFeatureList().getAll()) {
        mergedFeatures.add(rightFeature);
    }
    combined.setFeatureList(mergedFeatures);

    combined.setLabelTranslator(dataSet1.getLabelTranslator());
    combined.setIdTranslator(dataSet1.getIdTranslator());
    return combined;
}
use of edu.neu.ccs.pyramid.feature.FeatureList in project pyramid by cheng-li.
the class DataSetUtil method sampleFeatures.
/**
 * Creates a new data set that contains only the selected feature columns of
 * {@code dataSet}.
 *
 * <p>Column {@code j} of the result holds the values of column
 * {@code columnsToKeep.get(j)} of the input. The result is dense when the input
 * is dense and sparse otherwise. Labels, the label translator and the id
 * translator are carried over from the input.
 *
 * <p>NOTE(review): the Feature objects placed in the new feature list are the
 * ones obtained from {@code dataSet.getFeatureList().getAll()}, and
 * {@code setIndex} is called on them here. If that list does not hold copies,
 * the input data set's features get re-indexed as well -- confirm against
 * FeatureList.
 *
 * @param dataSet the data set to take columns from
 * @param columnsToKeep indices of the input columns to keep, in output order
 * @return a data set restricted to the selected columns
 */
public static MultiLabelClfDataSet sampleFeatures(MultiLabelClfDataSet dataSet, List<Integer> columnsToKeep) {
    boolean missingValue = dataSet.hasMissingValue();
    int numClasses = dataSet.getNumClasses();

    // Preserve the storage layout (dense vs. sparse) of the input.
    MultiLabelClfDataSet trimmed = dataSet.isDense()
            ? new DenseMLClfDataSet(dataSet.getNumDataPoints(), columnsToKeep.size(), missingValue, numClasses)
            : new SparseMLClfDataSet(dataSet.getNumDataPoints(), columnsToKeep.size(), missingValue, numClasses);

    // Copy the non-zero entries of every kept column into its new position.
    for (int newCol = 0; newCol < trimmed.getNumFeatures(); newCol++) {
        int sourceCol = columnsToKeep.get(newCol);
        for (Vector.Element entry : dataSet.getColumn(sourceCol).nonZeroes()) {
            trimmed.setFeatureValue(entry.index(), newCol, entry.get());
        }
    }

    // Copy the labels row by row.
    MultiLabel[] sourceLabels = dataSet.getMultiLabels();
    for (int row = 0; row < trimmed.getNumDataPoints(); row++) {
        trimmed.addLabels(row, sourceLabels[row].getMatchedLabels());
    }

    // Translators are shared with the input as-is.
    trimmed.setLabelTranslator(dataSet.getLabelTranslator());
    trimmed.setIdTranslator(dataSet.getIdTranslator());

    // Pick the descriptors of the kept columns and renumber them 0..n-1.
    List<Feature> allFeatures = dataSet.getFeatureList().getAll();
    List<Feature> keptFeatures = columnsToKeep.stream()
            .map(allFeatures::get)
            .collect(Collectors.toList());
    int nextIndex = 0;
    for (Feature kept : keptFeatures) {
        kept.setIndex(nextIndex);
        nextIndex++;
    }
    trimmed.setFeatureList(new FeatureList(keptFeatures));
    return trimmed;
}
use of edu.neu.ccs.pyramid.feature.FeatureList in project pyramid by cheng-li.
the class MLLogisticRegressionInspector method topFeatures.
/**
 * Collects the features with the largest positive weights for one class.
 *
 * <p>Only features whose weight (as returned by
 * {@code getWeightsWithoutBiasForClass}) is strictly greater than zero are
 * considered; they are ordered by decreasing weight and at most {@code limit}
 * of them are returned.
 *
 * @param logisticRegression the model to inspect
 * @param classIndex index of the class whose weights are examined
 * @param limit maximum number of features to return
 * @return a TopFeatures holding the selected features, the class index and the
 *         external class name obtained from the model's label translator
 */
public static TopFeatures topFeatures(MLLogisticRegression logisticRegression, int classIndex, int limit) {
    FeatureList modelFeatures = logisticRegression.getFeatureList();
    Vector classWeights = logisticRegression.getWeights().getWeightsWithoutBiasForClass(classIndex);

    Comparator<FeatureUtility> ascendingByWeight = Comparator.comparing(FeatureUtility::getUtility);
    Comparator<FeatureUtility> descendingByWeight = ascendingByWeight.reversed();

    List<Feature> positiveFeatures = IntStream.range(0, classWeights.size())
            // pair every feature with its weight for this class
            .mapToObj(idx -> new FeatureUtility(modelFeatures.get(idx)).setUtility(classWeights.get(idx)))
            // keep strictly positive weights only
            .filter(scored -> scored.getUtility() > 0)
            .sorted(descendingByWeight)
            .map(FeatureUtility::getFeature)
            .limit(limit)
            .collect(Collectors.toList());

    TopFeatures result = new TopFeatures();
    result.setTopFeatures(positiveFeatures);
    result.setClassIndex(classIndex);
    result.setClassName(logisticRegression.getLabelTranslator().toExtLabel(classIndex));
    return result;
}
use of edu.neu.ccs.pyramid.feature.FeatureList in project pyramid by cheng-li.
the class MekaFormat method loadMLClfDatasetPre.
/**
 * Reads the sparse data rows of a file into a multi-label data set.
 *
 * <p>Only lines that start with <code>{</code> and end with <code>}</code> are treated as data
 * rows; every other line is skipped. Inside the braces a row is a comma-separated
 * list of {@code "index value"} entries (index and value separated by a single
 * space). An index that is a key of {@code labelMap} is a label and is added to
 * the row when its value equals {@code 1.0}; an index that is a key of
 * {@code featureMap} is a feature and is stored in column
 * {@code index - numClasses}.
 *
 * <p>The reader is opened in a try-with-resources statement so the file handle
 * is released even when a row fails to parse.
 *
 * @param file the file to read
 * @param numClasses number of labels; also the offset between attribute index and feature column
 * @param numFeatures number of feature columns
 * @param numData number of data points to allocate
 * @param labelMap attribute index (as a string) to label name
 * @param featureMap attribute index (as a string) to feature name
 * @return the populated data set
 * @throws IOException if the file cannot be opened or read
 * @throws RuntimeException if a row contains an index found in neither map
 */
private static MultiLabelClfDataSet loadMLClfDatasetPre(File file, int numClasses, int numFeatures, int numData, Map<String, String> labelMap, Map<String, String> featureMap) throws IOException {
    MultiLabelClfDataSet dataSet = MLClfDataSetBuilder.getBuilder().numDataPoints(numData).numClasses(numClasses).numFeatures(numFeatures).build();
    // set features
    List<Feature> featureList = new LinkedList<>();
    for (int m = 0; m < numFeatures; m++) {
        // feature attributes are numbered after the label attributes in the file
        String featureIndex = Integer.toString(m + numClasses);
        String featureName = featureMap.get(featureIndex);
        Feature feature = new Feature();
        feature.setIndex(m);
        feature.setName(featureName);
        featureList.add(feature);
    }
    dataSet.setFeatureList(new FeatureList(featureList));
    // set Label
    Map<Integer, String> labelIndexMap = new HashMap<>();
    for (Map.Entry<String, String> entry : labelMap.entrySet()) {
        labelIndexMap.put(Integer.parseInt(entry.getKey()), entry.getValue());
    }
    LabelTranslator labelTranslator = new LabelTranslator(labelIndexMap);
    dataSet.setLabelTranslator(labelTranslator);
    // create feature matrix
    try (BufferedReader br = new BufferedReader(new FileReader(file))) {
        String line;
        int dataCount = 0;
        while ((line = br.readLine()) != null) {
            if ((line.startsWith("{")) && (line.endsWith("}"))) {
                // strip the surrounding braces
                line = line.substring(1, line.length() - 1);
                String[] indexValues = line.split(",");
                for (String indexValue : indexValues) {
                    String[] indexValuePair = indexValue.split(" ");
                    String index = indexValuePair[0];
                    String value = indexValuePair[1];
                    if (labelMap.containsKey(index)) {
                        double valueDouble = Double.parseDouble(value);
                        // a label is present only when its value is exactly 1.0
                        if (valueDouble == 1.0) {
                            dataSet.addLabel(dataCount, Integer.parseInt(index));
                        }
                    } else if (featureMap.containsKey(index)) {
                        double valueDouble = Double.parseDouble(value);
                        int indexInt = Integer.parseInt(index);
                        dataSet.setFeatureValue(dataCount, indexInt - numClasses, valueDouble);
                    } else {
                        throw new RuntimeException("Index not found in the line: " + line);
                    }
                }
                // only brace-delimited rows advance the data point counter
                dataCount++;
            }
        }
    }
    return dataSet;
}
Aggregations