Usage of ml.shifu.shifu.container.obj.ModelConfig in the shifu project (ShifuML).
Class ModelDataEncodeProcessor, method updateModel:
/**
 * Prepares the encode sub-model rooted at {@code encodeRefModel}.
 *
 * <p>The sub-model's ModelConfig is loaded and renamed to {@code encodeRefModel}. One
 * categorical column name ({@code tree_vars_0}, {@code tree_vars_1}, ...) is then written per
 * encoded variable to the categorical-column file inside that directory. If the config has no
 * file name set, it defaults to {@code categorical.column.names} and the config is updated.
 * Finally the config is saved back.
 *
 * @param encodeRefModel name of the encode sub-model; also used as the directory path the
 *     categorical-column file is written under
 * @throws IOException if loading, writing the column file, or saving the config fails
 */
private void updateModel(String encodeRefModel) throws IOException {
    ModelConfig encodeModel = loadSubModelConfig(encodeRefModel);
    encodeModel.setModelSetName(encodeRefModel);

    // NOTE(review): the count is the PRODUCT of getTrees().get(i).size() over every i, and is 1
    // when getTrees() is empty. If each entry is one bag of trees, a sum (and 0 for no trees)
    // may be what is intended -- confirm against how the encoded data is produced.
    int encodedVarCount = 1;
    for (int treeIdx = 0; treeIdx < this.treeModel.getTrees().size(); treeIdx++) {
        encodedVarCount *= this.treeModel.getTrees().get(treeIdx).size();
    }

    String catVarFileName = encodeModel.getDataSet().getCategoricalColumnNameFile();
    if (StringUtils.isBlank(catVarFileName)) {
        // No file configured yet: fall back to a default name and record it in the config.
        catVarFileName = "categorical.column.names";
        encodeModel.getDataSet().setCategoricalColumnNameFile(catVarFileName);
    }

    File catVarFile = new File(encodeRefModel + File.separator + catVarFileName);
    FileUtils.writeLines(catVarFile, buildTreeVarNames(encodedVarCount));
    saveModelConfig(encodeRefModel, encodeModel);
}

/** Returns {@code tree_vars_0 .. tree_vars_(count-1)}; the list is empty when count is not positive. */
private static List<String> buildTreeVarNames(int count) {
    List<String> names = new ArrayList<String>();
    int idx = 0;
    while (idx < count) {
        names.add("tree_vars_" + idx);
        idx++;
    }
    return names;
}
Usage of ml.shifu.shifu.container.obj.ModelConfig in the shifu project (ShifuML).
Class ValidationConductorTest, method testPartershipModel:
// @Test -- disabled: reads a developer-local data set under /Users/zhanhu/temp, so it only
// runs on a machine that has those files.
/**
 * Runs a single-column validation for every good candidate column of a local model.
 *
 * <p>Loads ModelConfig/ColumnConfig from a local directory and keeps the columns accepted by
 * {@code CommonUtils.isGoodCandidate}. Every line of one data part file is added via
 * {@code addNormalizedRecordIntoTrainDataSet}. For each data column, a
 * {@link ValidationConductor} is run with only that column in the working set, and the
 * resulting error is printed.
 *
 * @throws IOException if any of the local config or data files cannot be read
 */
public void testPartershipModel() throws IOException {
    String baseDir = "/Users/zhanhu/temp/partnership_varselect/";
    ModelConfig modelConfig = CommonUtils.loadModelConfig(baseDir + "ModelConfig.json", RawSourceData.SourceType.LOCAL);
    List<ColumnConfig> columnConfigList = CommonUtils.loadColumnConfigList(baseDir + "ColumnConfig.json", RawSourceData.SourceType.LOCAL);

    List<Integer> columnIdList = new ArrayList<Integer>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig columnConfig : columnConfigList) {
        if (CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
            columnIdList.add(columnConfig.getColumnNum());
        }
    }
    TrainingDataSet trainingDataSet = new TrainingDataSet(columnIdList);

    // Close the stream once all lines are read; the original leaked this file handle.
    List<String> recordsList;
    FileInputStream dataIn = new FileInputStream(baseDir + "part-m-00479");
    try {
        recordsList = IOUtils.readLines(dataIn);
    } finally {
        IOUtils.closeQuietly(dataIn);
    }
    for (String record : recordsList) {
        addNormalizedRecordIntoTrainDataSet(modelConfig, columnConfigList, trainingDataSet, record);
    }

    // Validate each column in isolation: the working set holds exactly one column id per run.
    Set<Integer> workingList = new HashSet<Integer>();
    for (Integer columnId : trainingDataSet.getDataColumnIdList()) {
        workingList.clear();
        workingList.add(columnId);
        ValidationConductor conductor = new ValidationConductor(modelConfig, columnConfigList, workingList, trainingDataSet);
        double error = conductor.runValidate();
        System.out.println("The error is - " + error + ", for columnId - " + columnId);
    }
}
Usage of ml.shifu.shifu.container.obj.ModelConfig in the shifu project (ShifuML).
Class PerformanceEvaluatorTest, method reviewTest:
/**
 * Reviewing an eval whose data set is a default-constructed {@link RawSourceData} is expected
 * to fail with {@link FileNotFoundException}.
 */
@Test(expectedExceptions = FileNotFoundException.class)
public void reviewTest() throws IOException {
    // Freshly initialised NN model config, forced into local run mode.
    ModelConfig modelConfig = ModelConfig.createInitModelConfig("test", ALGORITHM.NN, ".", false);
    modelConfig.getBasic().setRunMode(RunMode.LOCAL);

    EvalConfig evalConfig = new EvalConfig();
    evalConfig.setName("test");
    evalConfig.setDataSet(new RawSourceData());

    new PerformanceEvaluator(modelConfig, evalConfig).review();
}
Usage of ml.shifu.shifu.container.obj.ModelConfig in the shifu project (ShifuML).
Class ModelInspectorTest, method testValidateNormalize:
/**
 * Probing the NORMALIZE step against the config returned by
 * {@code CommonUtils.loadModelConfig()} should report a successful status.
 */
@Test
public void testValidateNormalize() throws Exception {
    Assert.assertTrue(instance.probe(CommonUtils.loadModelConfig(), ModelStep.NORMALIZE).getStatus());
}
Usage of ml.shifu.shifu.container.obj.ModelConfig in the shifu project (ShifuML).
Class ModelInspectorTest, method testValidateEval:
/**
 * Probing the EVAL step should report a successful status both for the config returned by
 * {@code CommonUtils.loadModelConfig()} and for the same config after its evals are set to
 * {@code null}.
 */
@Test
public void testValidateEval() throws Exception {
    ModelConfig modelConfig = CommonUtils.loadModelConfig();
    ValidateResult withEvals = instance.probe(modelConfig, ModelStep.EVAL);
    Assert.assertTrue(withEvals.getStatus());

    // Clearing the eval list must not make the EVAL probe fail.
    modelConfig.setEvals(null);
    ValidateResult withoutEvals = instance.probe(modelConfig, ModelStep.EVAL);
    Assert.assertTrue(withoutEvals.getStatus());
}
Aggregations