Usage of edu.neu.ccs.pyramid.dataset.ClfDataSet in the pyramid project (by cheng-li):
the mergeClfData method of class Merger.
/**
 * Reads two TREC-format classification data sets, concatenates them with
 * {@code DataSetUtil.concatenateByRow} and writes the result in TREC format.
 *
 * <p>Config keys read: {@code input.data1}, {@code input.data2}, {@code output.data}.
 * All three keys are resolved before any data is loaded.
 *
 * @param config job configuration holding the input and output paths
 * @throws Exception if a key is missing or loading/saving fails
 */
private static void mergeClfData(Config config) throws Exception {
    final String firstPath = config.getString("input.data1");
    final String secondPath = config.getString("input.data2");
    final String mergedPath = config.getString("output.data");

    // Both inputs are loaded with the dense representation.
    final ClfDataSet first = TRECFormat.loadClfDataSet(firstPath, DataSetType.CLF_DENSE, true);
    final ClfDataSet second = TRECFormat.loadClfDataSet(secondPath, DataSetType.CLF_DENSE, true);

    // presumably the rows of `second` are appended after those of `first` — confirm in DataSetUtil
    TRECFormat.save(DataSetUtil.concatenateByRow(first, second), mergedPath);
}
Usage of edu.neu.ccs.pyramid.dataset.ClfDataSet in the pyramid project (by cheng-li):
the main method of class Trec2LibSvm.
/**
 * Converts TREC-format classification data sets to LibSVM format.
 *
 * <p>The config file named by {@code args[0]} must provide two lists of equal length:
 * {@code trec} (input paths) and {@code libSVM} (output paths). The i-th input is
 * written to the i-th output. Inputs are loaded as sparse data.
 *
 * @param args {@code args[0]} is the path of the config file
 * @throws IllegalArgumentException if no config path is given or the two lists differ in length
 * @throws Exception if loading the config, reading an input or writing an output fails
 */
public static void main(String[] args) throws Exception {
    if (args.length < 1) {
        throw new IllegalArgumentException("Usage: Trec2LibSvm <config-file>");
    }
    Config config = new Config(args[0]);
    System.out.println(config);
    List<String> trecs = config.getStrings("trec");
    List<String> libSVMs = config.getStrings("libSVM");
    // Validate up front. Otherwise a short output list fails midway, after some files were
    // already converted, and a long one has its extra entries silently ignored.
    if (trecs.size() != libSVMs.size()) {
        throw new IllegalArgumentException("config lists 'trec' and 'libSVM' must have the same length, but got "
                + trecs.size() + " and " + libSVMs.size());
    }
    for (int i = 0; i < trecs.size(); i++) {
        // Report progress before the (potentially slow) load, not after it.
        System.out.println(i + " -- Translating on trecs: " + trecs.get(i));
        ClfDataSet trecDataset = TRECFormat.loadClfDataSet(new File(trecs.get(i)), DataSetType.CLF_SPARSE, false);
        LibSvmFormat.save(trecDataset, libSVMs.get(i));
    }
}
Usage of edu.neu.ccs.pyramid.dataset.ClfDataSet in the pyramid project (by cheng-li):
the main method of class Trec2Matlab.
/**
 * Exports a TREC-format classification data set as a plain-text sparse triplet file.
 *
 * <p>For every entry yielded by {@code nonZeroes()} on each row, one line
 * {@code row<TAB>column<TAB>value} is written, with 1-based row and column indices
 * (presumably for loading into MATLAB, e.g. via {@code spconvert} — confirm with consumers).
 *
 * <p>Config keys read: {@code input.trecFile}, {@code output.matlabFile}. Missing parent
 * directories of the output file are created.
 *
 * @param args {@code args[0]} is the path of the config file
 * @throws IllegalArgumentException if no config path is given
 * @throws Exception if loading the config, reading the input or writing the output fails
 */
public static void main(String[] args) throws Exception {
    if (args.length < 1) {
        throw new IllegalArgumentException("Usage: Trec2Matlab <config-file>");
    }
    Config config = new Config(args[0]);
    File trecFile = new File(config.getString("input.trecFile"));
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(trecFile, DataSetType.CLF_SPARSE, false);
    File matlabFile = new File(config.getString("output.matlabFile"));
    // getParentFile() returns null for a bare name such as "out.txt". Resolving to an
    // absolute path first yields the working directory instead of an NPE. A filesystem
    // root still has no parent, hence the null check.
    File parentDir = matlabFile.getAbsoluteFile().getParentFile();
    if (parentDir != null) {
        parentDir.mkdirs();
    }
    try (BufferedWriter bw = new BufferedWriter(new FileWriter(matlabFile))) {
        for (int i = 0; i < dataSet.getNumDataPoints(); i++) {
            Vector vector = dataSet.getRow(i);
            for (Vector.Element element : vector.nonZeroes()) {
                // Shift the 0-based row/column indices to 1-based.
                bw.write((i + 1) + "\t" + (element.index() + 1) + "\t" + element.get());
                bw.newLine();
            }
        }
    }
}
Usage of edu.neu.ccs.pyramid.dataset.ClfDataSet in the pyramid project (by cheng-li):
the test4 method of class RidgeLogisticOptimizerTest.
/**
 * Manual experiment: runs 20 {@code LBFGS} iterations of a {@code RidgeLogisticOptimizer}
 * on 20newsgroup split 1, using per-instance weights.
 *
 * <p>Each training point is randomly given weight 0 (probability about 0.1, via
 * {@code Math.random()}) or 1, so runs are not reproducible. Train/test accuracy is printed
 * once after initialization. After that, only the iteration number and the stopwatch are
 * printed on each iteration.
 */
private static void test4() throws Exception {
    File trainFile = new File(DATASETS, "20newsgroup/1/train.trec");
    File testFile = new File(DATASETS, "20newsgroup/1/test.trec");
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet testSet = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);

    // Passed straight to the optimizer; presumably the prior variance — confirm in RidgeLogisticOptimizer.
    double variance = 1000;
    LogisticRegression logisticRegression = new LogisticRegression(dataSet.getNumClasses(), dataSet.getNumFeatures());

    double[] weights = new double[dataSet.getNumDataPoints()];
    for (int point = 0; point < weights.length; point++) {
        weights[point] = (Math.random() < 0.1) ? 0 : 1;
    }

    RidgeLogisticOptimizer optimizer = new RidgeLogisticOptimizer(logisticRegression, dataSet, dataSet.getLabels(), weights, variance, true);
    System.out.println("after initialization");
    System.out.println("train acc = " + Accuracy.accuracy(logisticRegression, dataSet));
    System.out.println("test acc = " + Accuracy.accuracy(logisticRegression, testSet));

    StopWatch stopWatch = new StopWatch();
    stopWatch.start();
    for (int iteration = 0; iteration < 20; iteration++) {
        // Assumes the underlying optimizer is an LBFGS instance (the cast fails otherwise).
        ((LBFGS) optimizer.getOptimizer()).iterate();
        System.out.println("after iteration " + iteration);
        System.out.println(stopWatch);
    }
}
Usage of edu.neu.ccs.pyramid.dataset.ClfDataSet in the pyramid project (by cheng-li):
the test1 method of class RidgeLogisticOptimizerTest.
/**
 * Manual experiment: trains a logistic regression with {@code RidgeLogisticOptimizer}
 * on 20newsgroup split 1, using terminator mode {@code STANDARD} and at most 10000 iterations.
 *
 * <p>Train/test accuracy is printed before and after {@code optimize()}. The terminator's
 * history and the trained model are printed at the end.
 */
private static void test1() throws Exception {
    File trainFile = new File(DATASETS, "20newsgroup/1/train.trec");
    File testFile = new File(DATASETS, "20newsgroup/1/test.trec");
    ClfDataSet dataSet = TRECFormat.loadClfDataSet(trainFile, DataSetType.CLF_SPARSE, true);
    ClfDataSet testSet = TRECFormat.loadClfDataSet(testFile, DataSetType.CLF_SPARSE, true);

    // Passed straight to the optimizer; presumably the prior variance — confirm in RidgeLogisticOptimizer.
    double variance = 1000;
    LogisticRegression logisticRegression = new LogisticRegression(dataSet.getNumClasses(), dataSet.getNumFeatures());
    RidgeLogisticOptimizer optimizer = new RidgeLogisticOptimizer(logisticRegression, dataSet, variance, true);
    optimizer.getOptimizer()
            .getTerminator()
            .setMaxIteration(10000)
            .setMode(Terminator.Mode.STANDARD);

    System.out.println("after initialization");
    System.out.println("train acc = " + Accuracy.accuracy(logisticRegression, dataSet));
    System.out.println("test acc = " + Accuracy.accuracy(logisticRegression, testSet));

    optimizer.optimize();

    System.out.println("after training");
    System.out.println("train acc = " + Accuracy.accuracy(logisticRegression, dataSet));
    System.out.println("test acc = " + Accuracy.accuracy(logisticRegression, testSet));
    System.out.println(optimizer.getOptimizer().getTerminator().getHistory());
    System.out.println(logisticRegression);
}
Aggregations