Usage of ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet in the shifu project by ShifuML.
From class ValidationConductorTest, method testPartershipModel.
// @Test -- disabled: this test reads ModelConfig/ColumnConfig/data from absolute paths on a
// developer machine (/Users/zhanhu/...), so it cannot run in a normal build.
/**
 * Loads the partnership var-select configuration and one data part file, builds a
 * {@link TrainingDataSet} from the good candidate columns, then runs a
 * {@link ValidationConductor} once per data column (each time with a single-column working set)
 * and prints the resulting validation error.
 *
 * @throws IOException if a configuration or data file cannot be read
 */
public void testPartershipModel() throws IOException {
    ModelConfig modelConfig = CommonUtils.loadModelConfig("/Users/zhanhu/temp/partnership_varselect/ModelConfig.json", RawSourceData.SourceType.LOCAL);
    List<ColumnConfig> columnConfigList = CommonUtils.loadColumnConfigList("/Users/zhanhu/temp/partnership_varselect/ColumnConfig.json", RawSourceData.SourceType.LOCAL);

    // Collect the ids of every column that CommonUtils considers a good candidate.
    List<Integer> columnIdList = new ArrayList<Integer>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig columnConfig : columnConfigList) {
        if (CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
            columnIdList.add(columnConfig.getColumnNum());
        }
    }
    TrainingDataSet trainingDataSet = new TrainingDataSet(columnIdList);

    // Close the stream explicitly; IOUtils.readLines does not close the stream it is given.
    List<String> recordsList;
    FileInputStream dataIn = new FileInputStream("/Users/zhanhu/temp/partnership_varselect/part-m-00479");
    try {
        recordsList = IOUtils.readLines(dataIn);
    } finally {
        dataIn.close();
    }
    for (String record : recordsList) {
        addNormalizedRecordIntoTrainDataSet(modelConfig, columnConfigList, trainingDataSet, record);
    }

    // Validate each data column on its own.
    Set<Integer> workingList = new HashSet<Integer>();
    for (Integer columnId : trainingDataSet.getDataColumnIdList()) {
        workingList.clear();
        workingList.add(columnId);
        ValidationConductor conductor = new ValidationConductor(modelConfig, columnConfigList, workingList, trainingDataSet);
        double error = conductor.runValidate();
        System.out.println("The error is - " + error + ", for columnId - " + columnId);
    }
}
Usage of ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet in the shifu project by ShifuML.
From class WrapperWorkerConductorTest, method genTrainingDataSet.
/**
 * Builds a {@link TrainingDataSet} from the bundled cancer-judgement sample data.
 *
 * <p>Every column for which {@code ColumnConfig.isCandidate(hasCandidates)} is true becomes a
 * data column. Each line of {@code DataSet1/part-00} is then added via
 * {@code addRecordIntoTrainDataSet}.
 *
 * @param modelConfig      model configuration passed through to record parsing
 * @param columnConfigList column configurations used to select candidate columns
 * @return the populated training data set
 * @throws IOException if the sample data file cannot be read
 */
public TrainingDataSet genTrainingDataSet(ModelConfig modelConfig, List<ColumnConfig> columnConfigList) throws IOException {
    List<Integer> columnIdList = new ArrayList<Integer>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig columnConfig : columnConfigList) {
        if (columnConfig.isCandidate(hasCandidates)) {
            columnIdList.add(columnConfig.getColumnNum());
        }
    }
    TrainingDataSet trainingDataSet = new TrainingDataSet(columnIdList);

    // IOUtils.readLines does not close its input, so release the file handle ourselves.
    List<String> recordsList;
    FileInputStream dataIn = new FileInputStream("src/test/resources/example/cancer-judgement/DataStore/DataSet1/part-00");
    try {
        recordsList = IOUtils.readLines(dataIn);
    } finally {
        dataIn.close();
    }
    for (String record : recordsList) {
        addRecordIntoTrainDataSet(modelConfig, columnConfigList, trainingDataSet, record);
    }
    return trainingDataSet;
}
Usage of ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet in the shifu project by ShifuML.
From class ValidationConductorTest, method testRunValidate.
/**
 * Builds a {@link TrainingDataSet} from the bundled cancer-judgement sample data, then runs a
 * {@link ValidationConductor} once per data column (each time with a single-column working set)
 * and prints the resulting validation error.
 *
 * @throws IOException if a configuration or data file cannot be read
 */
@Test
public void testRunValidate() throws IOException {
    ModelConfig modelConfig = CommonUtils.loadModelConfig("src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ModelConfig.json", RawSourceData.SourceType.LOCAL);
    List<ColumnConfig> columnConfigList = CommonUtils.loadColumnConfigList("src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ColumnConfig.json", RawSourceData.SourceType.LOCAL);

    // Candidate columns become the data columns of the training set.
    List<Integer> columnIdList = new ArrayList<Integer>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig columnConfig : columnConfigList) {
        if (columnConfig.isCandidate(hasCandidates)) {
            columnIdList.add(columnConfig.getColumnNum());
        }
    }
    TrainingDataSet trainingDataSet = new TrainingDataSet(columnIdList);

    // IOUtils.readLines does not close its input, so release the file handle ourselves.
    List<String> recordsList;
    FileInputStream dataIn = new FileInputStream("src/test/resources/example/cancer-judgement/DataStore/DataSet1/part-00");
    try {
        recordsList = IOUtils.readLines(dataIn);
    } finally {
        dataIn.close();
    }
    for (String record : recordsList) {
        addRecordIntoTrainDataSet(modelConfig, columnConfigList, trainingDataSet, record);
    }

    // Validate each data column on its own.
    Set<Integer> workingList = new HashSet<Integer>();
    for (Integer columnId : trainingDataSet.getDataColumnIdList()) {
        workingList.clear();
        workingList.add(columnId);
        ValidationConductor conductor = new ValidationConductor(modelConfig, columnConfigList, workingList, trainingDataSet);
        double error = conductor.runValidate();
        System.out.println("The error is - " + error + ", for columnId - " + columnId);
    }
}
Usage of ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet in the shifu project by ShifuML.
From class WrapperWorkerConductorTest, method testWrapperConductor.
/**
 * Drives one round of a {@link WrapperWorkerConductor}. It loads the sample cancer-judgement
 * data and feeds the conductor ten overlapping six-column candidate seeds. It then checks that
 * a non-empty seed-performance list comes back.
 *
 * @throws IOException if a configuration or data file cannot be read
 */
@Test
public void testWrapperConductor() throws IOException {
    ModelConfig config = CommonUtils.loadModelConfig("src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ModelConfig.json", RawSourceData.SourceType.LOCAL);
    List<ColumnConfig> columns = CommonUtils.loadColumnConfigList("src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ColumnConfig.json", RawSourceData.SourceType.LOCAL);

    WrapperWorkerConductor conductor = new WrapperWorkerConductor(config, columns);
    conductor.retainData(genTrainingDataSet(config, columns));

    // Column ids 2..29 inclusive form the pool that seeds are sliced from.
    List<Integer> pool = new ArrayList<Integer>();
    for (int id = 2; id <= 29; id++) {
        pool.add(id);
    }

    // Ten sliding windows of width six, starting at pool offsets 1 through 10.
    List<CandidateSeed> seeds = new ArrayList<CandidateSeed>();
    for (int from = 1; from <= 10; from++) {
        int to = from + 6;
        seeds.add(new CandidateSeed(0, pool.subList(from, to)));
    }

    conductor.consumeMasterResult(new VarSelMasterResult(seeds));
    VarSelWorkerResult result = conductor.generateVarSelResult();

    Assert.assertNotNull(result);
    Assert.assertFalse(result.getSeedPerfList().isEmpty());
}
Usage of ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet in the shifu project by ShifuML.
From class VarSelWorker, method init.
/**
 * Initializes this variable-selection worker from the worker context properties.
 *
 * <p>It loads the {@code ModelConfig} and the {@code ColumnConfig} list using the source type
 * from {@code CommonConstants.MODELSET_SOURCE_TYPE} (default {@code HDFS}). It reflectively
 * instantiates the worker conductor class named by {@code Constants.VAR_SEL_WORKER_CONDUCTOR}
 * through its {@code (ModelConfig, List)} constructor. It then prepares an empty
 * {@link TrainingDataSet} over the normalized column ids, a {@code DataPurifier}, the target
 * column id and, if a weight column is configured, the weight column id.
 *
 * @param workerContext context supplying the job properties
 * @throws RuntimeException if configuration cannot be loaded, the conductor class property is
 *                          missing, or the conductor cannot be instantiated
 */
@Override
public void init(WorkerContext<VarSelMasterResult, VarSelWorkerResult> workerContext) {
    Properties props = workerContext.getProps();
    try {
        RawSourceData.SourceType sourceType = RawSourceData.SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, RawSourceData.SourceType.HDFS.toString()));
        this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType);
        this.columnConfigList = CommonUtils.loadColumnConfigList(props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);

        String conductorClsName = props.getProperty(Constants.VAR_SEL_WORKER_CONDUCTOR);
        // Fail with a clear message instead of the bare NPE Class.forName(null) would throw.
        if (StringUtils.isBlank(conductorClsName)) {
            throw new RuntimeException("Worker conductor class is not configured, property: " + Constants.VAR_SEL_WORKER_CONDUCTOR);
        }
        // The conductor class must expose a public (ModelConfig, List) constructor.
        this.workerConductor = (AbstractWorkerConductor) Class.forName(conductorClsName).getDeclaredConstructor(ModelConfig.class, List.class).newInstance(this.modelConfig, this.columnConfigList);
    } catch (IOException e) {
        throw new RuntimeException("Fail to load ModelConfig or List<ColumnConfig>", e);
    } catch (ClassNotFoundException e) {
        throw new RuntimeException("Invalid Worker Conductor class", e);
    } catch (InstantiationException e) {
        throw new RuntimeException("Fail to create instance", e);
    } catch (IllegalAccessException e) {
        throw new RuntimeException("Illegal access when creating instance", e);
    } catch (NoSuchMethodException e) {
        throw new RuntimeException("Fail to call method when creating instance", e);
    } catch (InvocationTargetException e) {
        throw new RuntimeException("Fail to invoke when creating instance", e);
    }

    // The training set holds only the normalized input columns; the node counts follow from it.
    List<Integer> normalizedColumnIdList = this.getNormalizedColumnIdList();
    this.inputNodeCount = normalizedColumnIdList.size();
    this.outputNodeCount = this.getTargetColumnCount();
    trainingDataSet = new TrainingDataSet(normalizedColumnIdList);

    try {
        dataPurifier = new DataPurifier(modelConfig, false);
    } catch (IOException e) {
        throw new RuntimeException("Fail to create DataPurifier", e);
    }

    this.targetColumnId = CommonUtils.getTargetColumnNum(this.columnConfigList);

    // Resolve the optional weight column by case-insensitive name match; trim the name once.
    if (StringUtils.isNotBlank(modelConfig.getWeightColumnName())) {
        String weightColumnName = modelConfig.getWeightColumnName().trim();
        for (ColumnConfig columnConfig : columnConfigList) {
            if (columnConfig.getColumnName().equalsIgnoreCase(weightColumnName)) {
                this.weightColumnId = columnConfig.getColumnNum();
                break;
            }
        }
    }
}
Aggregations