Usage of de.bwaldvogel.liblinear.Problem in project dkpro-tc (by dkpro).
Class LiblinearLoadModelConnector, method runPrediction:
/**
 * Predicts a label for every instance in {@code infile} using the loaded liblinear model.
 *
 * <p>The input is parsed with a bias of 1.0. One prediction per line is written, in input order,
 * to a UTF-8 temporary file that is scheduled for deletion when the JVM exits.
 *
 * @param infile feature file in the format read by {@code Problem.readFromFile}
 * @return the temporary file holding one prediction per line
 * @throws Exception if the input cannot be parsed or the predictions cannot be written
 */
@Override
protected File runPrediction(File infile) throws Exception {
    Problem predictionProblem = Problem.readFromFile(infile, 1.0);
    File tmp = File.createTempFile("libLinearePrediction", ".txt");
    // Register cleanup immediately so the temp file is not leaked if writing fails below.
    tmp.deleteOnExit();
    // try-with-resources (rather than a quiet close) so that a failure during the final
    // flush/close propagates instead of silently returning a truncated predictions file.
    try (BufferedWriter writer = new BufferedWriter(
            new OutputStreamWriter(new FileOutputStream(tmp), "utf-8"))) {
        for (Feature[] instance : predictionProblem.x) {
            double prediction = Linear.predict(liblinearModel, instance);
            writer.write(Double.toString(prediction));
            writer.write("\n");
        }
    }
    return tmp;
}
Usage of de.bwaldvogel.liblinear.Problem in project dkpro-tc (by dkpro).
Class LiblinearTestTask, method trainModel:
/**
 * Trains a liblinear model on this task's training file.
 *
 * <p>Solver type, cost C and stopping tolerance epsilon are taken from
 * {@code classificationArguments} via {@code LiblinearUtils}; liblinear's debug output is
 * disabled before training.
 *
 * @param aContext task context used to locate the training file
 * @return the trained liblinear {@code Model}
 * @throws Exception if the training file cannot be read or training fails
 */
@Override
protected Object trainModel(TaskContext aContext) throws Exception {
    // liblinear's bias defaults to -1; its documentation recommends 1 to get results closer to
    // libsvm. The data writer adds the bias as well, so if it is ever de-activated here it must
    // also be de-activated there.
    Problem trainingProblem = Problem.readFromFile(getTrainFile(aContext), 1.0);

    SolverType solverType = LiblinearUtils.getSolver(classificationArguments);
    double cost = LiblinearUtils.getParameterC(classificationArguments);
    double epsilon = LiblinearUtils.getParameterEpsilon(classificationArguments);

    Linear.setDebugOutput(null);
    Model trained = Linear.train(trainingProblem, new Parameter(solverType, cost, epsilon));
    return trained;
}
Usage of de.bwaldvogel.liblinear.Problem in project dkpro-tc (by dkpro).
Class LiblinearTestTask, method runPrediction:
/**
 * Classifies every test instance with {@code trainedModel} and writes prediction/gold pairs.
 *
 * <p>The output goes to {@code Constants.FILENAME_PREDICTIONS} inside the context's "" folder,
 * encoded as UTF-8. It starts with the header line {@code #PREDICTION;GOLD}, followed by one
 * {@code prediction<SEPARATOR_CHAR>gold} line per test instance, in input order.
 *
 * @param aContext task context used to locate the test file and the output folder
 * @param trainedModel the liblinear {@code Model} produced by training
 * @throws Exception if the test file cannot be parsed or the predictions cannot be written
 */
@Override
protected void runPrediction(TaskContext aContext, Object trainedModel) throws Exception {
    Model model = (Model) trainedModel;
    File fileTest = getTestFile(aContext);
    File predFolder = aContext.getFolder("", AccessMode.READWRITE);
    File predictionsFile = new File(predFolder, Constants.FILENAME_PREDICTIONS);
    // try-with-resources: the writer was previously leaked (and its buffer never flushed)
    // whenever parsing or prediction threw before the explicit close().
    try (BufferedWriter writer = new BufferedWriter(
            new OutputStreamWriter(new FileOutputStream(predictionsFile), "utf-8"))) {
        writer.append("#PREDICTION;GOLD" + "\n");
        Problem test = Problem.readFromFile(fileTest, 1.0);
        Feature[][] testInstances = test.x;
        for (int i = 0; i < testInstances.length; i++) {
            double prediction = Linear.predict(model, testInstances[i]);
            // Double.toString yields the same text as the deprecated new Double(...) did.
            writer.write(Double.toString(prediction) + SEPARATOR_CHAR + Double.toString(test.y[i]));
            writer.write("\n");
        }
    }
}
Usage of de.bwaldvogel.liblinear.Problem in project dkpro-tc (by dkpro).
Class LiblinearSerializeModelConnector, method trainModel:
/**
 * Trains a liblinear model on {@code fileTrain} and saves it to {@code MODEL_CLASSIFIER} in
 * {@code outputFolder}.
 *
 * <p>Solver type, cost C and stopping tolerance epsilon come from {@code classificationArguments};
 * the training data is parsed with a bias of 1.0, and liblinear's debug output is disabled.
 *
 * @param aContext task context (not read by this method)
 * @param fileTrain training data file in the format read by {@code Problem.readFromFile}
 * @throws Exception if reading, training or saving the model fails
 */
@Override
protected void trainModel(TaskContext aContext, File fileTrain) throws Exception {
    SolverType solverType = LiblinearUtils.getSolver(classificationArguments);
    double cost = LiblinearUtils.getParameterC(classificationArguments);
    double epsilon = LiblinearUtils.getParameterEpsilon(classificationArguments);
    Linear.setDebugOutput(null);
    Parameter trainingParams = new Parameter(solverType, cost, epsilon);

    Problem trainingProblem = Problem.readFromFile(fileTrain, 1.0);
    Linear.train(trainingProblem, trainingParams).save(new File(outputFolder, MODEL_CLASSIFIER));
}
Usage of de.bwaldvogel.liblinear.Problem in project dkpro-tc (by dkpro).
Class LiblinearDataWriterTest, method dataWriterTest:
/**
 * Writes two single-label instances with {@code LibsvmDataFormatWriter} and checks that
 * liblinear's {@code Problem.readFromFile} parses the result with the expected instance count,
 * column count and labels.
 */
@Test
public void dataWriterTest() throws Exception {
    // Both instances carry the same two numeric features, listed in different orders.
    List<Feature> firstFeatures = new ArrayList<>();
    firstFeatures.add(new Feature("feature1", 1.0, FeatureType.NUMERIC));
    firstFeatures.add(new Feature("feature2", 0.0, FeatureType.NUMERIC));

    List<Feature> secondFeatures = new ArrayList<>();
    secondFeatures.add(new Feature("feature2", 0.5, FeatureType.NUMERIC));
    secondFeatures.add(new Feature("feature1", 0.5, FeatureType.NUMERIC));

    List<Instance> instances = new ArrayList<>();
    instances.add(new Instance(firstFeatures, "0"));
    instances.add(new Instance(secondFeatures, "1"));

    File outputDirectory = folder.newFolder();
    String featureNames = "feature1\n" + "feature2\n";
    FileUtils.writeStringToFile(new File(outputDirectory, Constants.FILENAME_FEATURES), featureNames, "utf-8");
    File classifierFile = new File(outputDirectory, Constants.FILENAME_DATA_IN_CLASSIFIER_FORMAT);

    LibsvmDataFormatWriter dataWriter = new LibsvmDataFormatWriter();
    dataWriter.init(outputDirectory, false, Constants.LM_SINGLE_LABEL, false, new String[] { "0", "1" });
    dataWriter.writeClassifierFormat(instances);

    Problem parsed = Problem.readFromFile(classifierFile, 1.0);
    assertEquals(2, parsed.l);
    // NOTE(review): 4 columns presumably includes the bias column added by readFromFile — confirm.
    assertEquals(4, parsed.n);
    assertEquals(0.0, parsed.y[0], 0.00001);
    assertEquals(1.0, parsed.y[1], 0.00001);
}
Aggregations