Usage of edu.neu.ccs.pyramid.util.Pair in the project pyramid by cheng-li.
From the class RegTreeInspector, method featureImportance:
/**
 * Computes per-feature importance for a regression tree by summing, over all
 * internal (non-leaf) nodes, the impurity reduction of the split on that feature.
 * Features never used for a split do not appear in the returned map.
 *
 * @param tree the regression tree to inspect
 * @return map from each splitting feature to its accumulated reduction
 */
public static Map<Feature, Double> featureImportance(RegressionTree tree) {
    FeatureList featureList = tree.getFeatureList();
    Map<Feature, Double> map = new HashMap<>();
    List<Node> nodes = tree.traverse();
    nodes.stream().filter(node -> !node.isLeaf()).forEach(node -> {
        Feature feature = featureList.get(node.getFeatureIndex());
        // accumulate this split's reduction into the feature's running total
        map.merge(feature, node.getReduction(), Double::sum);
    });
    return map;
}
Usage of edu.neu.ccs.pyramid.util.Pair in the project pyramid by cheng-li.
From the class DataSetUtil, method sampleData:
/**
 * Builds a sub-sample of the data set containing only the rows at the given
 * indices (in order, duplicates allowed), together with the matching rows of
 * the target distribution.
 *
 * @param dataSet            source data set
 * @param targetDistribution per-data-point class distribution, indexed like dataSet
 * @param indices            row indices (into dataSet) to keep
 * @return pair of (sampled data set, sampled target distribution rows)
 */
public static Pair<DataSet, double[][]> sampleData(DataSet dataSet, double[][] targetDistribution, List<Integer> indices) {
    int numSampled = indices.size();
    int numClasses = targetDistribution[0].length;
    double[][] sampledTargets = new double[numSampled][numClasses];
    DataSet sample = DataSetBuilder.getBuilder()
            .dense(dataSet.isDense())
            .missingValue(dataSet.hasMissingValue())
            .numDataPoints(numSampled)
            .numFeatures(dataSet.getNumFeatures())
            .build();
    for (int newIndex = 0; newIndex < numSampled; newIndex++) {
        int oldIndex = indices.get(newIndex);
        double[] targets = targetDistribution[oldIndex];
        // copy the target distribution row for this data point
        sampledTargets[newIndex] = Arrays.copyOf(targets, targets.length);
        // copy feature values; iterating non-zeroes keeps this cheap for sparse rows
        for (Vector.Element element : dataSet.getRow(oldIndex).nonZeroes()) {
            sample.setFeatureValue(newIndex, element.index(), element.get());
        }
    }
    sample.setFeatureList(dataSet.getFeatureList());
    // idTranslator deliberately not copied: sampling may produce duplicate extIds
    return new Pair<>(sample, sampledTargets);
}
Usage of edu.neu.ccs.pyramid.util.Pair in the project pyramid by cheng-li.
From the class ArffFormat, method writeMatrixFile (multi-label variant):
/**
 * Writes the multi-label data set as a sparse ARFF-style matrix file inside
 * the given directory: a NUMERIC attribute per feature, a {0,1} attribute per
 * class, then one sparse row per data point listing non-zero feature values
 * (sorted by feature index) followed by the matched labels.
 *
 * @param dataSet  multi-label data set to serialize
 * @param arffFile output directory; the matrix file is created inside it
 */
private static void writeMatrixFile(MultiLabelClfDataSet dataSet, File arffFile) {
    File matrixFile = new File(arffFile, ARFF_MATRIX_FILE_NAME);
    int numDataPoints = dataSet.getNumDataPoints();
    int numFeatures = dataSet.getNumFeatures();
    MultiLabel[] multiLabels = dataSet.getMultiLabels();
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(matrixFile))) {
        bw.write("@RELATION MATRIX" + "\n");
        for (int i = 0; i < numFeatures; i++) {
            bw.write("@ATTRIBUTE " + i + " NUMERIC" + "\n");
        }
        for (int i = 0; i < dataSet.getNumClasses(); i++) {
            bw.write("@ATTRIBUTE class " + i + " {0,1}" + "\n");
        }
        bw.write("@DATA" + "\n");
        for (int i = 0; i < numDataPoints; i++) {
            MultiLabel multiLabel = multiLabels[i];
            List<Integer> labels = multiLabel.getMatchedLabels().stream().sorted().collect(Collectors.toList());
            bw.write("{");
            Vector vector = dataSet.getRow(i);
            // only write non-zero feature values, sorted by feature index
            List<Pair<Integer, Double>> pairs = new ArrayList<>();
            for (Vector.Element element : vector.nonZeroes()) {
                pairs.add(new Pair<>(element.index(), element.get()));
            }
            Comparator<Pair<Integer, Double>> comparator = Comparator.comparing(Pair::getFirst);
            pairs.sort(comparator);
            for (Pair<Integer, Double> pair : pairs) {
                bw.write(pair.getFirst() + ":" + pair.getSecond() + " ");
            }
            // FIX: the original unconditionally read labels.get(labels.size() - 1),
            // which threw IndexOutOfBoundsException for a data point with no
            // matched labels. This loop emits the same bytes for non-empty label
            // sets ("L 1,L 1,...,L 1") and nothing for an empty one.
            for (int l = 0; l < labels.size(); l++) {
                int label = labels.get(l) + numFeatures;
                bw.write(label + " 1");
                if (l < labels.size() - 1) {
                    bw.write(",");
                }
            }
            bw.write("}" + "\n");
        }
    } catch (IOException e) {
        // best-effort write, matching the sibling overloads; the failure is
        // reported but not propagated
        e.printStackTrace();
    }
}
Usage of edu.neu.ccs.pyramid.util.Pair in the project pyramid by cheng-li.
From the class ArffFormat, method writeMatrixFile (regression variant):
/**
 * Writes the regression data set as a sparse ARFF-style matrix file inside
 * the given directory: a NUMERIC attribute per feature plus one NUMERIC class
 * attribute, then one sparse row per data point listing non-zero feature
 * values (sorted by feature index) followed by the regression label.
 *
 * @param dataSet  regression data set to serialize
 * @param arffFile output directory; the matrix file is created inside it
 */
private static void writeMatrixFile(RegDataSet dataSet, File arffFile) {
    File matrixFile = new File(arffFile, ARFF_MATRIX_FILE_NAME);
    int numDataPoints = dataSet.getNumDataPoints();
    int numFeatures = dataSet.getNumFeatures();
    double[] labels = dataSet.getLabels();
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(matrixFile))) {
        bw.write("@RELATION MATRIX" + "\n");
        for (int j = 0; j < numFeatures; j++) {
            bw.write("@ATTRIBUTE " + j + " NUMERIC" + "\n");
        }
        bw.write("@ATTRIBUTE class NUMERIC" + "\n");
        bw.write("@DATA" + "\n");
        for (int row = 0; row < numDataPoints; row++) {
            double label = labels[row];
            bw.write("{");
            // collect only the non-zero feature values of this row
            List<Pair<Integer, Double>> nonZeros = new ArrayList<>();
            for (Vector.Element element : dataSet.getRow(row).nonZeroes()) {
                nonZeros.add(new Pair<>(element.index(), element.get()));
            }
            // emit them in ascending feature-index order
            Comparator<Pair<Integer, Double>> byIndex = Comparator.comparing(Pair::getFirst);
            nonZeros.sort(byIndex);
            for (Pair<Integer, Double> pair : nonZeros) {
                bw.write(pair.getFirst() + ":" + pair.getSecond() + " ");
            }
            // the label occupies the extra attribute slot after all features
            bw.write(numFeatures + " " + label + "}" + "\n");
        }
    } catch (IOException e) {
        e.printStackTrace();
    }
}
Usage of edu.neu.ccs.pyramid.util.Pair in the project pyramid by cheng-li.
From the class TrustRegionNewtonOptimizer, method trcg:
/**
 * Trust-region conjugate gradient: approximately solves H s = -g subject to
 * ||s|| <= delta, where H is applied implicitly via {@code loss.Hv}.
 * NOTE(review): the variable names (std, sts, dtd, cgtol) suggest this mirrors
 * the trcg routine of LIBLINEAR's trust region Newton method — confirm against
 * that reference before restructuring.
 *
 * @param delta trust region radius (input)
 * @param g gradient at the current point (input)
 * @return pair (s, r): s is the approximate Newton step, r the final CG residual
 */
private Pair<Vector, Vector> trcg(double delta, Vector g) {
    int numColumns = loss.getNumColumns();
    double one = 1;
    Vector d = new DenseVector(numColumns);   // conjugate search direction
    Vector Hd = new DenseVector(numColumns);  // Hessian-vector product H*d
    double rTr, rnewTrnew, cgtol;
    Vector s = new DenseVector(numColumns);   // accumulated step
    Vector r = new DenseVector(numColumns);   // residual
    Pair<Vector, Vector> result = new Pair<>();
    // initialization: s = 0, r = -g, d = r
    for (int i = 0; i < numColumns; i++) {
        s.set(i, 0);
        r.set(i, -g.get(i));
        d.set(i, r.get(i));
    }
    // stopping tolerance relative to the gradient norm
    cgtol = 0.1 * g.norm(2);
    rTr = r.dot(r);
    while (true) {
        if (r.norm(2) <= cgtol) {
            break;
        }
        loss.Hv(d, Hd);
        double alpha = rTr / d.dot(Hd);
        daxpy(alpha, d, s); // s += alpha * d
        if (s.norm(2) > delta) {
            // the step left the trust region: undo it, then move exactly to
            // the boundary along d and stop
            alpha = -alpha;
            daxpy(alpha, d, s); // s -= alpha_old * d (back to the previous s)
            double std = s.dot(d);
            double sts = s.dot(s);
            double dtd = d.dot(d);
            double dsq = delta * delta;
            // positive root of ||s + alpha*d||^2 = delta^2; the two branches
            // are algebraically equivalent forms chosen by the sign of std,
            // presumably for numerical stability — confirm against reference
            double rad = Math.sqrt(std * std + dtd * (dsq - sts));
            if (std >= 0)
                alpha = (dsq - sts) / (std + rad);
            else
                alpha = (rad - std) / dtd;
            daxpy(alpha, d, s); // s += alpha * d: now on the boundary
            alpha = -alpha;
            daxpy(alpha, Hd, r); // r -= alpha * Hd: keep residual consistent with s
            break;
        }
        alpha = -alpha;
        daxpy(alpha, Hd, r); // r -= alpha_old * Hd
        rnewTrnew = r.dot(r);
        double beta = rnewTrnew / rTr;
        scale(beta, d);    // d = beta * d
        daxpy(one, r, d);  // d += r: next conjugate direction
        rTr = rnewTrnew;
    }
    result.setFirst(s);
    result.setSecond(r);
    return result;
}
Aggregations (end of collected usages).