Use of ml.shifu.shifu.combo.CsvFile in the project shifu by ShifuML.
From the class PmmlSpecValidationTest, method doValidation.
/**
 * Scores every record of a delimited data file with the neural-network PMML model loaded from
 * {@code pmmlPath} and compares each score with the expected value held in the record's
 * {@code scoreName} column.
 *
 * <p>A record counts as a mismatch when the absolute difference between the expected score and
 * the PMML score is greater than {@code EPS}; each mismatch is printed to standard output as
 * {@code trans_id|expected|actual}.
 *
 * <p>For REGRESSION models with several target fields, the output fields are sorted by name and
 * the last one whose name starts with {@code AbstractSpecifCreator.FINAL_RESULT} supplies the
 * score. For CLASSIFICATION models the evaluated values (multiplied by 1000) are only printed;
 * the score that gets compared stays at 0.0.
 *
 * @param pmmlPath  path handed to {@code PMMLUtils.loadPMML}
 * @param DataPath  path of the evaluation data handed to {@code CsvFile}
 * @param delimiter field delimiter handed to {@code CsvFile}
 * @param scoreName name of the column that holds the expected score
 * @return {@code true} when no record was counted as a mismatch
 * @throws Exception whatever model loading, evaluation or score parsing throws
 */
@SuppressWarnings("unchecked")
private boolean doValidation(String pmmlPath, String DataPath, String delimiter, String scoreName) throws Exception {
    PMML pmml = PMMLUtils.loadPMML(pmmlPath);
    NeuralNetworkEvaluator evaluator = new NeuralNetworkEvaluator(pmml);
    List<TargetField> targetFields = evaluator.getTargetFields();
    CsvFile evalData = new CsvFile(DataPath, delimiter, true);

    int mismatchCnt = 0;
    for (Map<String, String> record : evalData) {
        Map<FieldName, FieldValue> arguments = convertRawIntoInput(evaluator, record);
        double pmmlScore = 0.0;

        switch (evaluator.getModel().getMiningFunction()) {
            case REGRESSION: {
                // Decide the branch before evaluating, in the same order as the checks always ran.
                boolean singleTarget = targetFields.size() == 1;
                Map<FieldName, Double> scores = (Map<FieldName, Double>) evaluator.evaluate(arguments);
                if (singleTarget) {
                    pmmlScore = scores.get(new FieldName(AbstractSpecifCreator.FINAL_RESULT));
                } else {
                    List<FieldName> sortedNames = new ArrayList<FieldName>(scores.keySet());
                    Collections.sort(sortedNames, new Comparator<FieldName>() {
                        @Override
                        public int compare(FieldName left, FieldName right) {
                            return left.getValue().compareTo(right.getValue());
                        }
                    });
                    // No early exit: the last matching name in sorted order wins.
                    for (FieldName name : sortedNames) {
                        if (name.getValue().startsWith(AbstractSpecifCreator.FINAL_RESULT)) {
                            pmmlScore = scores.get(name);
                        }
                    }
                }
                break;
            }
            case CLASSIFICATION: {
                Map<FieldName, Classification<Double>> classified =
                        (Map<FieldName, Classification<Double>>) evaluator.evaluate(arguments);
                // Values are only printed here; pmmlScore is left untouched.
                for (Classification<Double> classification : classified.values()) {
                    for (Map.Entry<String, Value<Double>> entry : classification.getValues().entrySet()) {
                        System.out.println(entry.getValue().getValue() * 1000);
                    }
                }
                break;
            }
            default:
                break;
        }

        double expectScore = Double.parseDouble(record.get(scoreName));
        double difference = Math.abs(expectScore - pmmlScore);
        if (difference > EPS) {
            System.out.println(record.get("trans_id") + "|" + expectScore + "|" + pmmlScore);
            mismatchCnt++;
        }
    }
    return mismatchCnt == 0;
}
Use of ml.shifu.shifu.combo.CsvFile in the project shifu by ShifuML.
From the class IndependentTreeModelTest, method testGBTTreeEncode.
/**
 * Loads the tree model stored in the classpath resource {@code /example/encode/model0.gbt} and,
 * for every record of the sample data file, prints the codes returned by
 * {@code encode(5, ...)}.
 *
 * <p>Nothing is asserted; the codes are only written to standard output. The {@code @Test}
 * annotation below is commented out.
 *
 * @throws IOException declared for the model loading and data reading calls
 */
// @Test
public void testGBTTreeEncode() throws IOException {
    IndependentTreeModel model = IndependentTreeModel.loadFromStream(
            IndependentTreeModelTest.class.getResourceAsStream("/example/encode/model0.gbt"));
    CsvFile sampleRecords = new CsvFile("src/test/resources/example/encode/sample.data.10", "\u0007");

    for (Map<String, String> record : sampleRecords) {
        // Copy the String-valued record into the Object-valued map that is passed to encode().
        Map<String, Object> features = new HashMap<String, Object>();
        features.putAll(record);
        System.out.println(model.encode(5, features));
    }
}
Use of ml.shifu.shifu.combo.CsvFile in the project shifu by ShifuML.
From the class PMMLScoreGenTest, method genPMMLAndCompareScore.
/**
 * Exports the model as a one-bagging PMML (with the {@code IS_CONCISE} option set), scores the
 * evaluation data with it and asserts that every compared record matches the {@code mean}
 * column of the {@code evals/<evalSetName>/EvalScore} file within {@code EPS}.
 *
 * <p>Records are compared pairwise in file order; the comparison stops as soon as either the
 * evaluation data or the score file runs out of records.
 *
 * @param modelName   base name of the exported file, read from {@code pmmls/<modelName>.pmml}
 * @param evalDataSet path of the evaluation data handed to {@code CsvFile}
 * @param evalSetName name of the evaluation set directory below {@code evals}
 * @param delimiter   field delimiter of the evaluation data
 * @throws Exception whatever the export, model loading, scoring or score parsing throws
 */
private void genPMMLAndCompareScore(String modelName, String evalDataSet, String evalSetName, String delimiter) throws Exception {
    Map<String, Object> params = new HashMap<String, Object>();
    params.put(ExportModelProcessor.IS_CONCISE, true);
    ShifuCLI.exportModel(ExportModelProcessor.ONE_BAGGING_PMML_MODEL, params);

    int totalRecordCnt = 0;
    int matchRecordCnt = 0;

    CsvFile evalScoreFile = new CsvFile("evals" + File.separator + evalSetName + File.separator + "EvalScore", "|", true);
    Iterator<Map<String, String>> scoreIterator = evalScoreFile.iterator();
    // skip first line
    scoreIterator.next();

    CsvFile evalData = new CsvFile(evalDataSet, delimiter, true);
    PMML pmml = PMMLUtils.loadPMML("pmmls" + File.separator + modelName + ".pmml");
    MiningModelEvaluator evaluator = new MiningModelEvaluator(pmml);
    Iterator<Map<String, String>> iterator = evalData.iterator();

    while (iterator.hasNext() && scoreIterator.hasNext()) {
        Map<String, String> rawInput = iterator.next();
        double pmmlScore = score(evaluator, rawInput, "FinalResult");

        Map<String, String> scoreInput = scoreIterator.next();
        double evalScore = Double.parseDouble(scoreInput.get("mean"));

        totalRecordCnt++;
        if (Math.abs(evalScore - pmmlScore) < EPS) {
            matchRecordCnt++;
        }
    }

    // assertEquals instead of assertTrue(a == b): a failure then reports both counts
    // rather than a bare "expected true".
    Assert.assertEquals(matchRecordCnt, totalRecordCnt);
}
Aggregations