Search in sources:

Example 1 with TrainingDataSet

Use of ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet in project shifu by ShifuML.

From class ValidationConductorTest, method testPartershipModel.

/**
 * Manual experiment: validates every good candidate column of a local "partnership" model set on
 * its own and prints the resulting error for each column.
 *
 * <p>Not registered as a test ({@code @Test} is commented out) because the ModelConfig,
 * ColumnConfig and data part file are read from a developer-specific absolute path.
 *
 * @throws IOException if any of the config or data files cannot be read
 */
// @Test
public void testPartershipModel() throws IOException {
    ModelConfig modelConfig = CommonUtils.loadModelConfig("/Users/zhanhu/temp/partnership_varselect/ModelConfig.json", RawSourceData.SourceType.LOCAL);
    List<ColumnConfig> columnConfigList = CommonUtils.loadColumnConfigList("/Users/zhanhu/temp/partnership_varselect/ColumnConfig.json", RawSourceData.SourceType.LOCAL);
    List<Integer> columnIdList = new ArrayList<Integer>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig columnConfig : columnConfigList) {
        if (CommonUtils.isGoodCandidate(columnConfig, hasCandidates)) {
            columnIdList.add(columnConfig.getColumnNum());
        }
    }
    TrainingDataSet trainingDataSet = new TrainingDataSet(columnIdList);
    List<String> recordsList;
    FileInputStream in = new FileInputStream("/Users/zhanhu/temp/partnership_varselect/part-m-00479");
    try {
        recordsList = IOUtils.readLines(in);
    } finally {
        // IOUtils.readLines does not close the stream it is handed; release the file handle here.
        in.close();
    }
    for (String record : recordsList) {
        addNormalizedRecordIntoTrainDataSet(modelConfig, columnConfigList, trainingDataSet, record);
    }
    // Evaluate each column in isolation: the working set holds exactly one column id per run.
    Set<Integer> workingList = new HashSet<Integer>();
    for (Integer columnId : trainingDataSet.getDataColumnIdList()) {
        workingList.clear();
        workingList.add(columnId);
        ValidationConductor conductor = new ValidationConductor(modelConfig, columnConfigList, workingList, trainingDataSet);
        double error = conductor.runValidate();
        System.out.println("The error is - " + error + ", for columnId - " + columnId);
    }
}
Also used: ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList) FileInputStream(java.io.FileInputStream) ModelConfig(ml.shifu.shifu.container.obj.ModelConfig) TrainingDataSet(ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet) HashSet(java.util.HashSet)

Example 2 with TrainingDataSet

Use of ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet in project shifu by ShifuML.

From class WrapperWorkerConductorTest, method genTrainingDataSet.

/**
 * Builds a TrainingDataSet over the candidate columns of {@code columnConfigList} and fills it with
 * every line of the cancer-judgement sample data part file.
 *
 * @param modelConfig model configuration passed through to the record loader
 * @param columnConfigList column configurations; columns for which {@code isCandidate} holds are kept
 * @return the populated training data set
 * @throws IOException if the sample data file cannot be read
 */
public TrainingDataSet genTrainingDataSet(ModelConfig modelConfig, List<ColumnConfig> columnConfigList) throws IOException {
    List<Integer> columnIdList = new ArrayList<Integer>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig columnConfig : columnConfigList) {
        if (columnConfig.isCandidate(hasCandidates)) {
            columnIdList.add(columnConfig.getColumnNum());
        }
    }
    TrainingDataSet trainingDataSet = new TrainingDataSet(columnIdList);
    List<String> recordsList;
    FileInputStream in = new FileInputStream("src/test/resources/example/cancer-judgement/DataStore/DataSet1/part-00");
    try {
        recordsList = IOUtils.readLines(in);
    } finally {
        // IOUtils.readLines does not close the stream it is handed; release the file handle here.
        in.close();
    }
    for (String record : recordsList) {
        addRecordIntoTrainDataSet(modelConfig, columnConfigList, trainingDataSet, record);
    }
    return trainingDataSet;
}
Also used: ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList) TrainingDataSet(ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet) FileInputStream(java.io.FileInputStream)

Example 3 with TrainingDataSet

Use of ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet in project shifu by ShifuML.

From class ValidationConductorTest, method testRunValidate.

/**
 * Loads the cancer-judgement sample model set and data, then runs a ValidationConductor once per
 * candidate column (each run with a single-column working set) and prints the error.
 *
 * @throws IOException if the config or data files cannot be read
 */
@Test
public void testRunValidate() throws IOException {
    ModelConfig modelConfig = CommonUtils.loadModelConfig("src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ModelConfig.json", RawSourceData.SourceType.LOCAL);
    List<ColumnConfig> columnConfigList = CommonUtils.loadColumnConfigList("src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ColumnConfig.json", RawSourceData.SourceType.LOCAL);
    List<Integer> columnIdList = new ArrayList<Integer>();
    boolean hasCandidates = CommonUtils.hasCandidateColumns(columnConfigList);
    for (ColumnConfig columnConfig : columnConfigList) {
        if (columnConfig.isCandidate(hasCandidates)) {
            columnIdList.add(columnConfig.getColumnNum());
        }
    }
    TrainingDataSet trainingDataSet = new TrainingDataSet(columnIdList);
    List<String> recordsList;
    FileInputStream in = new FileInputStream("src/test/resources/example/cancer-judgement/DataStore/DataSet1/part-00");
    try {
        recordsList = IOUtils.readLines(in);
    } finally {
        // IOUtils.readLines does not close the stream it is handed; release the file handle here.
        in.close();
    }
    for (String record : recordsList) {
        addRecordIntoTrainDataSet(modelConfig, columnConfigList, trainingDataSet, record);
    }
    // Evaluate each column in isolation: the working set holds exactly one column id per run.
    Set<Integer> workingList = new HashSet<Integer>();
    for (Integer columnId : trainingDataSet.getDataColumnIdList()) {
        workingList.clear();
        workingList.add(columnId);
        ValidationConductor conductor = new ValidationConductor(modelConfig, columnConfigList, workingList, trainingDataSet);
        double error = conductor.runValidate();
        System.out.println("The error is - " + error + ", for columnId - " + columnId);
    }
}
Also used: ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList) FileInputStream(java.io.FileInputStream) ModelConfig(ml.shifu.shifu.container.obj.ModelConfig) TrainingDataSet(ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet) HashSet(java.util.HashSet) Test(org.testng.annotations.Test)

Example 4 with TrainingDataSet

Use of ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet in project shifu by ShifuML.

From class WrapperWorkerConductorTest, method testWrapperConductor.

/**
 * Drives a WrapperWorkerConductor end to end on the cancer-judgement sample: retains the sample
 * training data, feeds it ten overlapping six-column candidate seeds, and checks that a non-empty
 * seed performance list comes back.
 *
 * @throws IOException if the sample config or data files cannot be read
 */
@Test
public void testWrapperConductor() throws IOException {
    ModelConfig modelConfig = CommonUtils.loadModelConfig(
            "src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ModelConfig.json",
            RawSourceData.SourceType.LOCAL);
    List<ColumnConfig> columnConfigList = CommonUtils.loadColumnConfigList(
            "src/test/resources/example/cancer-judgement/ModelStore/ModelSet1/ColumnConfig.json",
            RawSourceData.SourceType.LOCAL);

    WrapperWorkerConductor conductor = new WrapperWorkerConductor(modelConfig, columnConfigList);
    conductor.retainData(genTrainingDataSet(modelConfig, columnConfigList));

    // Candidate column ids 2 through 29 inclusive.
    List<Integer> candidateIds = new ArrayList<Integer>();
    int nextId = 2;
    while (nextId < 30) {
        candidateIds.add(nextId++);
    }

    // Ten sliding windows of six ids each, starting at list offsets 1..10.
    List<CandidateSeed> seeds = new ArrayList<CandidateSeed>();
    for (int from = 1; from <= 10; from++) {
        seeds.add(new CandidateSeed(0, candidateIds.subList(from, from + 6)));
    }

    conductor.consumeMasterResult(new VarSelMasterResult(seeds));
    VarSelWorkerResult result = conductor.generateVarSelResult();
    Assert.assertNotNull(result);
    Assert.assertTrue(result.getSeedPerfList().size() > 0);
}
Also used: ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) ArrayList(java.util.ArrayList) VarSelWorkerResult(ml.shifu.shifu.core.dvarsel.VarSelWorkerResult) CandidateSeed(ml.shifu.shifu.core.dvarsel.CandidateSeed) ModelConfig(ml.shifu.shifu.container.obj.ModelConfig) VarSelMasterResult(ml.shifu.shifu.core.dvarsel.VarSelMasterResult) TrainingDataSet(ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet) Test(org.testng.annotations.Test)

Example 5 with TrainingDataSet

Use of ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet in project shifu by ShifuML.

From class VarSelWorker, method init.

/**
 * Initializes this variable-selection worker from the job properties.
 *
 * <p>Loads the ModelConfig and ColumnConfig list. The source type is read from
 * {@code CommonConstants.MODELSET_SOURCE_TYPE} and defaults to HDFS. The method then:
 * <ul>
 *   <li>reflectively instantiates the worker conductor named by
 *       {@code Constants.VAR_SEL_WORKER_CONDUCTOR} through its {@code (ModelConfig, List)} constructor;</li>
 *   <li>allocates an empty TrainingDataSet over the normalized column ids;</li>
 *   <li>creates the DataPurifier;</li>
 *   <li>resolves the target column id and, when configured, the weight column id.</li>
 * </ul>
 *
 * @param workerContext worker context whose properties carry the config locations and conductor class
 * @throws RuntimeException if the configs cannot be loaded, the conductor class property is missing,
 *         the conductor cannot be instantiated, or the DataPurifier cannot be created
 */
@Override
public void init(WorkerContext<VarSelMasterResult, VarSelWorkerResult> workerContext) {
    Properties props = workerContext.getProps();
    try {
        RawSourceData.SourceType sourceType = RawSourceData.SourceType.valueOf(props.getProperty(CommonConstants.MODELSET_SOURCE_TYPE, RawSourceData.SourceType.HDFS.toString()));
        this.modelConfig = CommonUtils.loadModelConfig(props.getProperty(CommonConstants.SHIFU_MODEL_CONFIG), sourceType);
        this.columnConfigList = CommonUtils.loadColumnConfigList(props.getProperty(CommonConstants.SHIFU_COLUMN_CONFIG), sourceType);
        String conductorClsName = props.getProperty(Constants.VAR_SEL_WORKER_CONDUCTOR);
        // Report a missing property explicitly rather than letting Class.forName(null) throw a bare NPE.
        if (StringUtils.isBlank(conductorClsName)) {
            throw new IllegalStateException("Worker conductor class is not configured (property " + Constants.VAR_SEL_WORKER_CONDUCTOR + ")");
        }
        this.workerConductor = (AbstractWorkerConductor) Class.forName(conductorClsName).getDeclaredConstructor(ModelConfig.class, List.class).newInstance(this.modelConfig, this.columnConfigList);
    } catch (IOException e) {
        throw new RuntimeException("Fail to load ModelConfig or List<ColumnConfig>", e);
    } catch (ClassNotFoundException e) {
        // The class name comes from the worker-conductor property, so the message names the worker conductor.
        throw new RuntimeException("Invalid Worker Conductor class", e);
    } catch (InstantiationException e) {
        throw new RuntimeException("Fail to create instance", e);
    } catch (IllegalAccessException e) {
        throw new RuntimeException("Illegal access when creating instance", e);
    } catch (NoSuchMethodException e) {
        throw new RuntimeException("Fail to call method when creating instance", e);
    } catch (InvocationTargetException e) {
        throw new RuntimeException("Fail to invoke when creating instance", e);
    }
    List<Integer> normalizedColumnIdList = this.getNormalizedColumnIdList();
    this.inputNodeCount = normalizedColumnIdList.size();
    this.outputNodeCount = this.getTargetColumnCount();
    trainingDataSet = new TrainingDataSet(normalizedColumnIdList);
    try {
        dataPurifier = new DataPurifier(modelConfig, false);
    } catch (IOException e) {
        throw new RuntimeException("Fail to create DataPurifier", e);
    }
    this.targetColumnId = CommonUtils.getTargetColumnNum(this.columnConfigList);
    String weightColumnName = modelConfig.getWeightColumnName();
    if (StringUtils.isNotBlank(weightColumnName)) {
        // Trim once outside the loop. Calling equalsIgnoreCase on the configured name tolerates
        // columns whose name is null.
        String trimmedWeightName = weightColumnName.trim();
        for (ColumnConfig columnConfig : columnConfigList) {
            if (trimmedWeightName.equalsIgnoreCase(columnConfig.getColumnName())) {
                this.weightColumnId = columnConfig.getColumnNum();
                break;
            }
        }
        // NOTE(review): if no column matches, weightColumnId is left unchanged by this method.
    }
}
Also used: ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) IOException(java.io.IOException) Properties(java.util.Properties) RawSourceData(ml.shifu.shifu.container.obj.RawSourceData) InvocationTargetException(java.lang.reflect.InvocationTargetException) ModelConfig(ml.shifu.shifu.container.obj.ModelConfig) DataPurifier(ml.shifu.shifu.core.DataPurifier) TrainingDataSet(ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet) ArrayList(java.util.ArrayList) List(java.util.List)

Aggregations

Number of examples above that use each type:
- ArrayList (java.util.ArrayList): 5
- ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 5
- TrainingDataSet (ml.shifu.shifu.core.dvarsel.dataset.TrainingDataSet): 5
- ModelConfig (ml.shifu.shifu.container.obj.ModelConfig): 4
- FileInputStream (java.io.FileInputStream): 3
- HashSet (java.util.HashSet): 2
- Test (org.testng.annotations.Test): 2
- IOException (java.io.IOException): 1
- InvocationTargetException (java.lang.reflect.InvocationTargetException): 1
- List (java.util.List): 1
- Properties (java.util.Properties): 1
- RawSourceData (ml.shifu.shifu.container.obj.RawSourceData): 1
- DataPurifier (ml.shifu.shifu.core.DataPurifier): 1
- CandidateSeed (ml.shifu.shifu.core.dvarsel.CandidateSeed): 1
- VarSelMasterResult (ml.shifu.shifu.core.dvarsel.VarSelMasterResult): 1
- VarSelWorkerResult (ml.shifu.shifu.core.dvarsel.VarSelWorkerResult): 1