Usage of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
Example from class ValidationConductorTest, method testRunValidate.
/**
 * Loads the cancer-judgement example ModelConfig/ColumnConfig, builds a TrainingDataSet over the
 * candidate columns, feeds every record of the example data file into it, and then runs a
 * ValidationConductor once per data column (a working set containing only that column), printing
 * the resulting error.
 *
 * <p>The data file stream is closed in a {@code finally} block: Commons IO's {@code IOUtils}
 * methods deliberately do not close the streams passed to them, so the original code leaked the
 * file handle.
 */
@Test
public void testRunValidate() throws IOException {
    ModelConfig modelConfig = CommonUtils.loadModelConfig("src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ModelConfig.json", RawSourceData.SourceType.LOCAL);
    List<ColumnConfig> columnConfigList = CommonUtils.loadColumnConfigList("src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ColumnConfig.json", RawSourceData.SourceType.LOCAL);

    // Collect the column numbers of every candidate column.
    List<Integer> columnIdList = new ArrayList<Integer>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig columnConfig : columnConfigList) {
        if (columnConfig.isCandidate(hasCandidates)) {
            columnIdList.add(columnConfig.getColumnNum());
        }
    }
    TrainingDataSet trainingDataSet = new TrainingDataSet(columnIdList);

    // IOUtils.readLines does not close the stream, so close it here even if reading fails.
    // NOTE(review): the single-argument readLines decodes with the platform default charset.
    FileInputStream dataStream = new FileInputStream("src/test/resources/example/cancer-judgement/DataStore/DataSet1/part-00");
    List<String> recordsList;
    try {
        recordsList = IOUtils.readLines(dataStream);
    } finally {
        dataStream.close();
    }
    for (String record : recordsList) {
        addRecordIntoTrainDataSet(modelConfig, columnConfigList, trainingDataSet, record);
    }

    // Validate each data column in isolation, reusing one working set.
    Set<Integer> workingList = new HashSet<Integer>();
    for (Integer columnId : trainingDataSet.getDataColumnIdList()) {
        workingList.clear();
        workingList.add(columnId);
        ValidationConductor conductor = new ValidationConductor(modelConfig, columnConfigList, workingList, trainingDataSet);
        double error = conductor.runValidate();
        System.out.println("The error is - " + error + ", for columnId - " + columnId);
    }
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
Example from class WrapperWorkerConductorTest, method testWrapperConductor.
/**
 * Drives a WrapperWorkerConductor through one round of variable selection.
 *
 * <p>The cancer-judgement example configs are loaded and the conductor retains the generated
 * training data. It then receives ten candidate seeds; each seed is a six-id sliding window over
 * column ids 2..29. The worker result it produces must be non-null and must contain at least one
 * seed performance entry.
 */
@Test
public void testWrapperConductor() throws IOException {
    ModelConfig config = CommonUtils.loadModelConfig("src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ModelConfig.json", RawSourceData.SourceType.LOCAL);
    List<ColumnConfig> columns = CommonUtils.loadColumnConfigList("src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ColumnConfig.json", RawSourceData.SourceType.LOCAL);

    WrapperWorkerConductor conductor = new WrapperWorkerConductor(config, columns);
    TrainingDataSet dataSet = genTrainingDataSet(config, columns);
    conductor.retainData(dataSet);

    // Candidate column ids: 2 through 29 inclusive.
    List<Integer> candidateIds = new ArrayList<Integer>();
    int id = 2;
    while (id < 30) {
        candidateIds.add(id);
        id++;
    }

    // Windows [1,7), [2,8), ..., [10,16) over candidateIds, each six ids wide.
    List<CandidateSeed> seeds = new ArrayList<CandidateSeed>();
    for (int start = 1; start <= 10; start++) {
        seeds.add(new CandidateSeed(0, candidateIds.subList(start, start + 6)));
    }

    conductor.consumeMasterResult(new VarSelMasterResult(seeds));
    VarSelWorkerResult result = conductor.generateVarSelResult();

    Assert.assertNotNull(result);
    Assert.assertTrue(0 < result.getSeedPerfList().size());
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
Example from class CommonUtilsTest, method getTargetColumnNumTest.
/**
 * Verifies that CommonUtils.getTargetColumnNum returns the column number (20) of the one column
 * flagged as Target, when it sits between two columns whose flag is explicitly null.
 */
@Test
public void getTargetColumnNumTest() {
    ColumnConfig leading = new ColumnConfig();
    leading.setColumnFlag(null);

    ColumnConfig target = new ColumnConfig();
    target.setColumnFlag(ColumnFlag.Target);
    target.setColumnNum(20);

    ColumnConfig trailing = new ColumnConfig();
    trailing.setColumnFlag(null);

    List<ColumnConfig> columns = new ArrayList<ColumnConfig>();
    columns.add(leading);
    columns.add(target);
    columns.add(trailing);

    Integer expected = Integer.valueOf(20);
    Assert.assertEquals(expected, CommonUtils.getTargetColumnNum(columns));
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
Example from class CommonUtilsTest, method syncTest.
/**
 * Checks that CommonUtils.copyConfFromLocalToHDFS copies a local model set into
 * ModelSets/testModel.
 *
 * <p>The copied model set consists of ModelConfig.json, ColumnConfig.json, models/model1.nn and
 * EvalSets/test/EvalConfig.json. The test also expects ModelSets/testModel/ReasonCodeMap.json to
 * exist afterwards.
 *
 * <p>All fixtures are written into the current working directory. Two robustness measures apply:
 * <ul>
 *   <li>Cleanup runs in a {@code finally} block, so a failed assertion does not leave fixtures
 *       behind.</li>
 *   <li>Each fixture writer is closed even if writing to it fails.</li>
 * </ul>
 *
 * <p>NOTE: the {@code @Test} annotation is commented out, so JUnit does not run this method.
 */
// @Test
public void syncTest() throws IOException {
    ModelConfig config = ModelConfig.createInitModelConfig(".", ALGORITHM.NN, "test", false);
    config.setModelSetName("testModel");
    try {
        jsonMapper.writerWithDefaultPrettyPrinter().writeValue(new File("ModelConfig.json"), config);

        ColumnConfig col = new ColumnConfig();
        col.setColumnName("ColumnA");
        List<ColumnConfig> columnConfigList = new ArrayList<ColumnConfig>();
        columnConfigList.add(col);
        // NOTE(review): the source is switched to LOCAL only after ModelConfig.json was written, so
        // only the in-memory config (not the file on disk) carries LOCAL — confirm this is intended.
        config.getDataSet().setSource(SourceType.LOCAL);
        jsonMapper.writerWithDefaultPrettyPrinter().writeValue(new File("ColumnConfig.json"), columnConfigList);

        File modelsDir = new File("models");
        if (!modelsDir.exists()) {
            FileUtils.forceMkdir(modelsDir);
        }
        File modelFile = new File("models/model1.nn");
        if (!modelFile.exists()) {
            if (modelFile.createNewFile()) {
                BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(modelFile), Constants.DEFAULT_CHARSET));
                try {
                    writer.write("test string");
                } finally {
                    writer.close();
                }
            } else {
                LOG.warn("Create file {} failed", modelFile.getAbsolutePath());
            }
        }

        File evalDir = new File("EvalSets/test");
        if (!evalDir.exists()) {
            FileUtils.forceMkdir(evalDir);
        }
        File evalConfigFile = new File("EvalSets/test/EvalConfig.json");
        if (!evalConfigFile.exists()) {
            BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(evalConfigFile), Constants.DEFAULT_CHARSET));
            try {
                writer.write("test string");
            } finally {
                writer.close();
            }
        }

        CommonUtils.copyConfFromLocalToHDFS(config, new PathFinder(config));

        Assert.assertTrue(new File("ModelSets").exists());
        Assert.assertTrue(new File("ModelSets/testModel").exists());
        Assert.assertTrue(new File("ModelSets/testModel/ModelConfig.json").exists());
        Assert.assertTrue(new File("ModelSets/testModel/ColumnConfig.json").exists());
        Assert.assertTrue(new File("ModelSets/testModel/ReasonCodeMap.json").exists());
        Assert.assertTrue(new File("ModelSets/testModel/models/model1.nn").exists());
        Assert.assertTrue(new File("ModelSets/testModel/EvalSets/test/EvalConfig.json").exists());
    } finally {
        // Best-effort cleanup. deleteQuietly removes directories recursively and never throws,
        // so a cleanup problem cannot mask the real test failure.
        FileUtils.deleteQuietly(new File("ModelSets"));
        FileUtils.deleteQuietly(new File("ColumnConfig.json"));
        FileUtils.deleteQuietly(new File("ModelConfig.json"));
        FileUtils.deleteQuietly(new File("models"));
        FileUtils.deleteQuietly(new File("EvalSets"));
    }
}
Usage of ml.shifu.shifu.container.obj.ColumnConfig in project shifu by ShifuML.
Example from class CommonUtilsTest, method updateColumnConfigFlagsTest.
/**
 * Checks that ColumnConfigUpdater.updateColumnConfigFlags, run for the VARSELECT step, marks
 * column "a" as meta.
 *
 * <p>The config points at the meta-column file ./conf/meta_column_conf.txt and the force-remove
 * file ./conf/remove_column_list.txt. The test runs against three columns named "a", "c" and
 * "d"; those files are read from the working directory.
 *
 * <p>NOTE: the {@code @Test} annotation is commented out, so JUnit does not run this method.
 */
// @Test
public void updateColumnConfigFlagsTest() throws IOException {
    ModelConfig modelConfig = ModelConfig.createInitModelConfig("test", ALGORITHM.NN, "test", false);
    modelConfig.getDataSet().setMetaColumnNameFile("./conf/meta_column_conf.txt");
    modelConfig.getVarSelect().setForceRemoveColumnNameFile("./conf/remove_column_list.txt");

    List<ColumnConfig> columns = new ArrayList<ColumnConfig>();
    for (String columnName : new String[] { "a", "c", "d" }) {
        ColumnConfig column = new ColumnConfig();
        column.setColumnName(columnName);
        columns.add(column);
    }

    ColumnConfigUpdater.updateColumnConfigFlags(modelConfig, columns, ModelInspector.ModelStep.VARSELECT);

    ColumnConfig first = columns.get(0);
    Assert.assertTrue(first.isMeta());
}
Aggregations