Search in sources :

Example 11 with BasicML

use of org.encog.ml.BasicML in project shifu by ShifuML.

the class NNModelSpecTest method testModelTraverse.

// @Test
/**
 * Diagnostic walk over the weights of a base network (model0.nn) and an extended
 * network (model1.nn). For every layer transition it prints, for each weight slot
 * of the base network, the slot it is paired with in the extended network; the last
 * input slot of each row (commented as the bias in this code) is paired with the
 * last slot of the extended row.
 */
public void testModelTraverse() {
    File baseModelFile = new File("src/test/resources/model/model0.nn");
    BasicNetwork baseNetwork = (BasicNetwork) BasicML.class.cast(EncogDirectoryPersistence.loadObject(baseModelFile));
    FlatNetwork baseFlat = baseNetwork.getFlat();

    File extendedModelFile = new File("src/test/resources/model/model1.nn");
    BasicNetwork widerNetwork = (BasicNetwork) BasicML.class.cast(EncogDirectoryPersistence.loadObject(extendedModelFile));
    FlatNetwork widerFlat = widerNetwork.getFlat();

    int layerTotal = baseFlat.getLayerIndex().length;
    int layer = layerTotal - 1;
    while (layer > 0) {
        int rows = baseFlat.getLayerFeedCounts()[layer - 1];
        int cols = baseFlat.getLayerCounts()[layer];
        System.out.println("Weight index for layer " + (baseFlat.getLayerIndex().length - layer));
        int widerCols = widerFlat.getLayerCounts()[layer];
        int baseOffset = baseFlat.getWeightIndex()[layer - 1];
        int widerOffset = widerFlat.getWeightIndex()[layer - 1];

        for (int row = 0; row < rows; row++) {
            // start of this output neuron's weight row in each flat weight array
            int baseRowStart = baseOffset + (row * cols);
            int widerRowStart = widerOffset + (row * widerCols);
            for (int col = 0; col < cols; col++) {
                boolean lastSlot = (col == cols - 1);
                // move bias to end
                int target = lastSlot ? widerRowStart + (widerCols - 1) : widerRowStart + col;
                System.out.println((baseRowStart + col) + " --> " + target);
            }
        }
        layer--;
    }
}
Also used : FlatNetwork(org.encog.neural.flat.FlatNetwork) BasicNetwork(org.encog.neural.networks.BasicNetwork) BasicML(org.encog.ml.BasicML) File(java.io.File)

Example 12 with BasicML

use of org.encog.ml.BasicML in project shifu by ShifuML.

the class NNModelSpecTest method testModelFitIn.

/**
 * Loads a base network (model5.nn) and an extended network (model6.nn), fits the
 * base model into the extended one through {@code NNMaster.fitExistingModelIn} and
 * checks the number of weight indexes reported back for both overloads.
 */
@Test
public void testModelFitIn() {
    // Register the persistor before any model file is read.
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());

    File sourceModel = new File("src/test/resources/model/model5.nn");
    FlatNetwork sourceFlat = ((BasicNetwork) BasicML.class.cast(EncogDirectoryPersistence.loadObject(sourceModel))).getFlat();

    File targetModel = new File("src/test/resources/model/model6.nn");
    FlatNetwork targetFlat = ((BasicNetwork) BasicML.class.cast(EncogDirectoryPersistence.loadObject(targetModel))).getFlat();

    NNMaster nnMaster = new NNMaster();

    // Two-list-argument overload.
    Set<Integer> fixedIndexes = nnMaster.fitExistingModelIn(sourceFlat, targetFlat, Arrays.asList(1, 2, 3));
    Assert.assertEquals(fixedIndexes.size(), 931);

    // Overload taking an extra boolean flag, passed as false here.
    fixedIndexes = nnMaster.fitExistingModelIn(sourceFlat, targetFlat, Arrays.asList(1, 2, 3), false);
    Assert.assertEquals(fixedIndexes.size(), 910);
}
Also used : FlatNetwork(org.encog.neural.flat.FlatNetwork) NNMaster(ml.shifu.shifu.core.dtrain.nn.NNMaster) BasicNetwork(org.encog.neural.networks.BasicNetwork) BasicML(org.encog.ml.BasicML) File(java.io.File) PersistBasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork) Test(org.testng.annotations.Test)

Example 13 with BasicML

use of org.encog.ml.BasicML in project shifu by ShifuML.

the class ShifuCLI method analysisModelFi.

/**
 * Computes the feature importance of a single tree model file and hands it to
 * {@code CommonUtils.writeFeatureImportance} under the name
 * {@code <model file name>.fi}.
 *
 * @param modelPath
 *            path of the model file; it must exist and its upper-cased path must end with
 *            {@code "." + CommonConstants.GBT_ALG_NAME} or {@code "." + CommonConstants.RF_ALG_NAME}
 * @return 0 on success; 1 when the file is missing, has an unsupported extension, or an
 *         {@link IOException} is raised while loading the model or writing the result
 */
public static int analysisModelFi(String modelPath) {
    File modelFile = new File(modelPath);
    if (!modelFile.exists() || !(modelPath.toUpperCase().endsWith("." + CommonConstants.GBT_ALG_NAME) || modelPath.toUpperCase().endsWith("." + CommonConstants.RF_ALG_NAME))) {
        log.error("The model {} doesn't exist or it isn't GBT/RF model.", modelPath);
        return 1;
    }
    FileInputStream inputStream = null;
    // Only the file name is used, not the model's parent directory.
    String fiFileName = modelFile.getName() + ".fi";
    try {
        inputStream = new FileInputStream(modelFile);
        BasicML basicML = TreeModel.loadFromStream(inputStream);
        Map<Integer, MutablePair<String, Double>> featureImportances = CommonUtils.computeTreeModelFeatureImportance(Arrays.asList(new BasicML[] { basicML }));
        CommonUtils.writeFeatureImportance(fiFileName, featureImportances);
    } catch (IOException e) {
        // Pass the exception as the trailing argument so the cause and stack trace are
        // logged instead of being silently dropped.
        log.error("Fail to analysis model FI for {}", modelPath, e);
        return 1;
    } finally {
        // Null-safe close; a failure while closing must not mask the result above.
        IOUtils.closeQuietly(inputStream);
    }
    return 0;
}
Also used : MutablePair(org.apache.commons.lang3.tuple.MutablePair) BasicML(org.encog.ml.BasicML) IOException(java.io.IOException) JarFile(java.util.jar.JarFile) File(java.io.File) FileInputStream(java.io.FileInputStream)

Example 14 with BasicML

use of org.encog.ml.BasicML in project shifu by ShifuML.

the class PostTrainMapper method setup.

/**
 * Prepares this mapper before any record is processed: loads configuration, creates the
 * reusable output key/value holders, loads the trained models and builds the
 * {@code ModelRunner} and {@code MultipleOutputs} used later by the task.
 *
 * @param context
 *            the Hadoop task context passed to the config loader and wrapped by
 *            {@code MultipleOutputs}
 * @throws IOException
 *             declared by the overridden method; presumably raised while loading configs or
 *             models - confirm against the helpers
 * @throws InterruptedException
 *             declared by the overridden method
 */
@SuppressWarnings({ "rawtypes", "unchecked" })
@Override
protected void setup(Context context) throws IOException, InterruptedException {
    // The two loaders run first; the statements below read this.modelConfig and
    // columnConfigList, which are presumably populated here - TODO confirm.
    loadConfigFiles(context);
    loadTagWeightNum();
    this.dataPurifier = new DataPurifier(this.modelConfig, false);
    // Writable holders created once in setup.
    this.outputKey = new IntWritable();
    this.outputValue = new Text();
    // Copy the flattened tags into a HashSet.
    this.tags = new HashSet<String>(modelConfig.getFlattenTags());
    SourceType sourceType = this.modelConfig.getDataSet().getSource();
    // Second argument is passed as null here; its meaning is defined by ModelSpecLoaderUtils.
    List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelConfig, null, sourceType);
    this.headers = CommonUtils.getFinalHeaders(modelConfig);
    this.modelRunner = new ModelRunner(modelConfig, columnConfigList, this.headers, modelConfig.getDataSetDelimiter(), models);
    // Raw cast of the context; this appears to be what the rawtypes/unchecked suppression covers.
    this.mos = new MultipleOutputs<NullWritable, Text>((TaskInputOutputContext) context);
    this.initFeatureStats();
}
Also used : DataPurifier(ml.shifu.shifu.core.DataPurifier) SourceType(ml.shifu.shifu.container.obj.RawSourceData.SourceType) TaskInputOutputContext(org.apache.hadoop.mapreduce.TaskInputOutputContext) Text(org.apache.hadoop.io.Text) BasicML(org.encog.ml.BasicML) NullWritable(org.apache.hadoop.io.NullWritable) IntWritable(org.apache.hadoop.io.IntWritable) ModelRunner(ml.shifu.shifu.core.ModelRunner)

Example 15 with BasicML

use of org.encog.ml.BasicML in project shifu by ShifuML.

the class ExportModelProcessor method run.

/*
 * (non-Javadoc)
 *
 * Exports model artifacts according to the configured export type. A blank type defaults to
 * PMML. Supported types, matched case-insensitively: ONE_BAGGING_MODEL, PMML,
 * ONE_BAGGING_PMML_MODEL, COLUMN_STATS, WOE_MAPPING, WOE and CORRELATION.
 *
 * Returns 0 on success, -1 for an unsupported type, 2 when the correlation file is missing,
 * or the result of exportVariableCorr() for the CORRELATION type.
 *
 * @see ml.shifu.shifu.core.processor.Processor#run()
 */
@Override
public int run() throws Exception {
    setUp(ModelStep.EXPORT);
    int status = 0;
    // The local "pmmls" output directory is created up front for every export type.
    File pmmls = new File("pmmls");
    FileUtils.forceMkdir(pmmls);
    if (StringUtils.isBlank(type)) {
        type = PMML;
    }
    String modelsPath = pathFinder.getModelsPath(SourceType.LOCAL);
    if (type.equalsIgnoreCase(ONE_BAGGING_MODEL)) {
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm()) && !CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging model is only supported in NN/GBT/RF algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            if (models.size() < 1) {
                log.warn("No model is found in {}.", modelsPath);
            } else {
                log.info("Convert nn models into one binary bagging model.");
                Configuration conf = new Configuration();
                Path output = new Path(pathFinder.getBaggingModelPath(SourceType.LOCAL), "model.b" + modelConfig.getAlgorithm());
                if ("nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
                    BinaryNNSerializer.save(modelConfig, columnConfigList, models, FileSystem.getLocal(conf), output);
                } else if (CommonUtils.isTreeModel(modelConfig.getAlgorithm())) {
                    List<List<TreeNode>> baggingTrees = new ArrayList<List<TreeNode>>();
                    for (int i = 0; i < models.size(); i++) {
                        TreeModel tm = (TreeModel) models.get(i);
                        // TreeModel only has one TreeNode instance although it is list inside
                        baggingTrees.add(tm.getIndependentTreeModel().getTrees().get(0));
                    }
                    int[] inputOutputIndex = DTrainUtils.getNumericAndCategoricalInputAndOutputCounts(this.columnConfigList);
                    // numerical + categorical = # of all input
                    int inputCount = inputOutputIndex[0] + inputOutputIndex[1];
                    // NOTE(review): assumes the "Loss" parameter is present in the model params - confirm.
                    BinaryDTSerializer.save(modelConfig, columnConfigList, baggingTrees, modelConfig.getParams().get("Loss").toString(), inputCount, FileSystem.getLocal(conf), output);
                }
                log.info("Please find one unified bagging model in local {}.", output);
            }
        }
    } else if (type.equalsIgnoreCase(PMML)) {
        // typical pmml generation: one PMML file per loaded model, suffixed with the model index
        List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
        PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), false);
        for (int index = 0; index < models.size(); index++) {
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + Integer.toString(index) + ".pmml";
            log.info("\t Start to generate " + path);
            PMML pmml = translator.build(Arrays.asList(new BasicML[] { models.get(index) }));
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(ONE_BAGGING_PMML_MODEL)) {
        // one unified bagging pmml generation
        log.info("Convert models into one bagging pmml model {} format", type);
        if (!"nn".equalsIgnoreCase(modelConfig.getAlgorithm())) {
            log.warn("Currently one bagging pmml model is only supported in NN algorithm.");
        } else {
            List<BasicML> models = ModelSpecLoaderUtils.loadBasicModels(modelsPath, ALGORITHM.valueOf(modelConfig.getAlgorithm().toUpperCase()));
            PMMLTranslator translator = PMMLConstructorFactory.produce(modelConfig, columnConfigList, isConcise(), true);
            String path = "pmmls" + File.separator + modelConfig.getModelSetName() + ".pmml";
            log.info("\t Start to generate one unified model to: " + path);
            PMML pmml = translator.build(models);
            PMMLUtils.savePMML(pmml, path);
        }
    } else if (type.equalsIgnoreCase(COLUMN_STATS)) {
        saveColumnStatus();
    } else if (type.equalsIgnoreCase(WOE_MAPPING)) {
        // Export every column when no variables were requested, otherwise only the requested ones.
        List<ColumnConfig> exportCatColumns = new ArrayList<ColumnConfig>();
        List<String> catVariables = getRequestVars();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            if (CollectionUtils.isEmpty(catVariables) || isRequestColumn(catVariables, columnConfig)) {
                exportCatColumns.add(columnConfig);
            }
        }
        if (CollectionUtils.isNotEmpty(exportCatColumns)) {
            List<String> woeMappings = new ArrayList<String>();
            for (ColumnConfig columnConfig : exportCatColumns) {
                String woeMapText = rebinAndExportWoeMapping(columnConfig);
                woeMappings.add(woeMapText);
            }
            FileUtils.write(new File("woemapping.txt"), StringUtils.join(woeMappings, ",\n"));
        }
    } else if (type.equalsIgnoreCase(WOE)) {
        List<String> woeInfos = new ArrayList<String>();
        for (ColumnConfig columnConfig : this.columnConfigList) {
            if (columnConfig.getBinLength() > 1 && ((columnConfig.isCategorical() && CollectionUtils.isNotEmpty(columnConfig.getBinCategory())) || (columnConfig.isNumerical() && CollectionUtils.isNotEmpty(columnConfig.getBinBoundary()) && columnConfig.getBinBoundary().size() > 1))) {
                List<String> varWoeInfos = generateWoeInfos(columnConfig);
                if (CollectionUtils.isNotEmpty(varWoeInfos)) {
                    woeInfos.addAll(varWoeInfos);
                    // blank line separates the entries of consecutive variables
                    woeInfos.add("");
                }
            }
        }
        // Write the accumulated lines once, after all columns are processed, rather than
        // rewriting the whole file on every loop iteration.
        FileUtils.writeLines(new File("varwoe_info.txt"), woeInfos);
    } else if (type.equalsIgnoreCase(CORRELATION)) {
        // export correlation into mapping list
        // NOTE(review): both returns below leave before clearUp(ModelStep.EXPORT) and the
        // "Done." log are reached - confirm this is intended.
        if (!ShifuFileUtils.isFileExists(pathFinder.getLocalCorrelationCsvPath(), SourceType.LOCAL)) {
            log.warn("The correlation file doesn't exist. Please make sure you have ran `shifu stats -c`.");
            return 2;
        }
        return exportVariableCorr();
    } else {
        log.error("Unsupported output format - {}", type);
        status = -1;
    }
    clearUp(ModelStep.EXPORT);
    log.info("Done.");
    return status;
}
Also used : Path(org.apache.hadoop.fs.Path) Configuration(org.apache.hadoop.conf.Configuration) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) BasicML(org.encog.ml.BasicML) PMMLTranslator(ml.shifu.shifu.core.pmml.PMMLTranslator) TreeModel(ml.shifu.shifu.core.TreeModel) TreeNode(ml.shifu.shifu.core.dtrain.dt.TreeNode) PMML(org.dmg.pmml.PMML) File(java.io.File)

Aggregations

BasicML (org.encog.ml.BasicML): 23, File (java.io.File): 6, BasicNetwork (org.encog.neural.networks.BasicNetwork): 5, IOException (java.io.IOException): 4, ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 4, BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork): 4, PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork): 4, FileSystem (org.apache.hadoop.fs.FileSystem): 4, FlatNetwork (org.encog.neural.flat.FlatNetwork): 4, ArrayList (java.util.ArrayList): 3, NSColumn (ml.shifu.shifu.column.NSColumn): 3, ModelRunner (ml.shifu.shifu.core.ModelRunner): 3, ModelSpec (ml.shifu.shifu.core.model.ModelSpec): 3, MutablePair (org.apache.commons.lang3.tuple.MutablePair): 3, Configuration (org.apache.hadoop.conf.Configuration): 3, FileStatus (org.apache.hadoop.fs.FileStatus): 3, Path (org.apache.hadoop.fs.Path): 3, JarFile (java.util.jar.JarFile): 2, CaseScoreResult (ml.shifu.shifu.container.CaseScoreResult): 2, SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType): 2