Search in sources:

Example 16 with BasicML

use of org.encog.ml.BasicML in project shifu by ShifuML.

Method inputOutputModelCheckSuccess of class TrainModelProcessor.

/**
 * Check whether the model stored at {@code modelPath} is structurally compatible with the current training
 * parameters.
 *
 * <p>The stored network passes only if all of the following hold:
 * <ul>
 * <li>its input count is not larger than the expected input count (candidate count [0], or [2] when [0] is 0);</li>
 * <li>its output count equals candidate count [1];</li>
 * <li>it has no more hidden layers than {@code NUM_HIDDEN_LAYERS};</li>
 * <li>each hidden layer has no more neurons than the configured node count for that layer;</li>
 * <li>each hidden layer's activation class matches the configured activation name (unknown names are treated
 * as sigmoid).</li>
 * </ul>
 *
 * @param fileSystem
 *            file system holding the model
 * @param modelPath
 *            path of the stored model
 * @param modelParams
 *            training parameters; reads NUM_HIDDEN_LAYERS, NUM_HIDDEN_NODES and ACTIVATION_FUNC
 * @return true if every check above passes
 * @throws IOException
 *             if loading the stored model fails
 */
@SuppressWarnings("unchecked")
private boolean inputOutputModelCheckSuccess(FileSystem fileSystem, Path modelPath, Map<String, Object> modelParams)
        throws IOException {
    BasicML loadedModel = ModelSpecLoaderUtils.loadModel(this.modelConfig, modelPath, fileSystem);
    BasicFloatNetwork network = (BasicFloatNetwork) ModelSpecLoaderUtils.getBasicNetwork(loadedModel);

    int[] candidateCounts = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(),
            getColumnConfigList());
    int expectedInputCount = candidateCounts[0];
    if(expectedInputCount == 0) {
        expectedInputCount = candidateCounts[2];
    }
    if(network.getInputCount() > expectedInputCount || network.getOutputCount() != candidateCounts[1]) {
        return false;
    }

    // The stored network may have fewer hidden layers than configured, but never more.
    int storedHiddenLayers = network.getLayerCount() - 2;
    if(storedHiddenLayers > (Integer) modelParams.get(CommonConstants.NUM_HIDDEN_LAYERS)) {
        return false;
    }

    List<Integer> configuredNodes = (List<Integer>) modelParams.get(CommonConstants.NUM_HIDDEN_NODES);
    List<String> configuredActivations = (List<String>) modelParams.get(CommonConstants.ACTIVATION_FUNC);
    // Layer 0 is the input layer and the last layer is the output layer; only hidden layers are compared.
    for(int layer = 1; layer < network.getLayerCount() - 1; layer++) {
        if(network.getLayerNeuronCount(layer) > configuredNodes.get(layer - 1)) {
            return false;
        }
        ActivationFunction storedActivation = network.getActivation(layer);
        String configuredName = configuredActivations.get(layer - 1);
        if(expectedHiddenActivationClass(configuredName) != storedActivation.getClass()) {
            return false;
        }
    }
    return true;
}

/**
 * Map a configured hidden-layer activation name (case-insensitive) to the activation class a compatible network
 * is expected to use. Unrecognized names map to {@link ActivationSigmoid}.
 */
private static Class<?> expectedHiddenActivationClass(String actFunc) {
    if(actFunc.equalsIgnoreCase(NNConstants.NN_LINEAR)) {
        return ActivationLinear.class;
    }
    if(actFunc.equalsIgnoreCase(NNConstants.NN_SIGMOID)) {
        return ActivationSigmoid.class;
    }
    if(actFunc.equalsIgnoreCase(NNConstants.NN_TANH)) {
        return ActivationTANH.class;
    }
    if(actFunc.equalsIgnoreCase(NNConstants.NN_LOG)) {
        return ActivationLOG.class;
    }
    if(actFunc.equalsIgnoreCase(NNConstants.NN_SIN)) {
        return ActivationSIN.class;
    }
    if(actFunc.equalsIgnoreCase(NNConstants.NN_RELU)) {
        return ActivationReLU.class;
    }
    if(actFunc.equalsIgnoreCase(NNConstants.NN_LEAKY_RELU)) {
        return ActivationLeakyReLU.class;
    }
    if(actFunc.equalsIgnoreCase(NNConstants.NN_SWISH)) {
        return ActivationSwish.class;
    }
    if(actFunc.equalsIgnoreCase(NNConstants.NN_PTANH)) {
        return ActivationPTANH.class;
    }
    return ActivationSigmoid.class;
}
Also used : BasicML(org.encog.ml.BasicML) RequiredFieldList(org.apache.pig.LoadPushDown.RequiredFieldList) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)

Example 17 with BasicML

use of org.encog.ml.BasicML in project shifu by ShifuML.

Method loadSubModelSpec of class ModelSpecLoaderUtils.

/**
 * Load a sub-model into a {@link ModelSpec}.
 *
 * <p>The model files reported by {@code getModelsAlgAndSpecFiles} for {@code fileStatus} are loaded in file-name
 * order. If the sub-model also provides its own model config and/or column config file, those are loaded and used
 * instead of {@code modelConfig} / {@code columnConfigList}.
 *
 * @param modelConfig
 *            - {@link ModelConfig}, needed since the model file may exist in HDFS; also the fallback config when
 *            the sub-model has no model config of its own
 * @param columnConfigList
 *            - List of {@link ColumnConfig}; fallback when the sub-model has no column config of its own
 * @param fileStatus
 *            - {@link FileStatus} of the sub-model location; its name becomes the sub-model name. It is
 *            dereferenced immediately, so it must not be null.
 * @param sourceType
 *            - {@link SourceType}, HDFS or Local?
 * @param gbtConvertToProb
 *            - convert to probability or not for gbt model
 * @param gbtScoreConvertStrategy
 *            - gbt score conversion strategy
 * @return {@link ModelSpec} for the sub-model, or {@code null} if no model file was found for it
 * @throws IOException
 *             if locating or reading the sub-model files fails
 */
private static ModelSpec loadSubModelSpec(ModelConfig modelConfig, List<ColumnConfig> columnConfigList,
        FileStatus fileStatus, RawSourceData.SourceType sourceType, Boolean gbtConvertToProb,
        String gbtScoreConvertStrategy) throws IOException {
    FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType);
    String subModelName = fileStatus.getPath().getName();
    // Both collections below are filled in by getModelsAlgAndSpecFiles:
    // modelFileStats receives the model files, subConfigs[0] the sub-model ModelConfig file and subConfigs[1]
    // the sub-model ColumnConfig file (either slot may stay null).
    List<FileStatus> modelFileStats = new ArrayList<FileStatus>();
    FileStatus[] subConfigs = new FileStatus[2];
    ALGORITHM algorithm = getModelsAlgAndSpecFiles(fileStatus, sourceType, modelFileStats, subConfigs);
    if(modelFileStats.isEmpty()) {
        // nothing to load for this sub-model
        return null;
    }

    // Load models in a deterministic (file-name) order.
    Collections.sort(modelFileStats, new Comparator<FileStatus>() {
        @Override
        public int compare(FileStatus fa, FileStatus fb) {
            return fa.getPath().getName().compareTo(fb.getPath().getName());
        }
    });
    List<BasicML> models = new ArrayList<BasicML>(modelFileStats.size());
    for(FileStatus f: modelFileStats) {
        models.add(loadModel(modelConfig, f.getPath(), fs, gbtConvertToProb, gbtScoreConvertStrategy));
    }

    // Prefer the sub-model's own configs when present.
    ModelConfig subModelConfig = modelConfig;
    if(subConfigs[0] != null) {
        subModelConfig = CommonUtils.loadModelConfig(subConfigs[0].getPath().toString(), sourceType);
    }
    List<ColumnConfig> subColumnConfigList = columnConfigList;
    if(subConfigs[1] != null) {
        subColumnConfigList = CommonUtils.loadColumnConfigList(subConfigs[1].getPath().toString(), sourceType);
    }
    return new ModelSpec(subModelName, subModelConfig, subColumnConfigList, algorithm, models);
}
Also used : FileStatus(org.apache.hadoop.fs.FileStatus) ALGORITHM(ml.shifu.shifu.container.obj.ModelTrainConf.ALGORITHM) BasicML(org.encog.ml.BasicML) FileSystem(org.apache.hadoop.fs.FileSystem) ModelSpec(ml.shifu.shifu.core.model.ModelSpec)

Example 18 with BasicML

use of org.encog.ml.BasicML in project shifu by ShifuML.

Method loadBasicModels of class ModelSpecLoaderUtils.

/**
 * Load basic models by configuration.
 *
 * <p>For GENERIC or TENSORFLOW algorithms the Shifu model loader is bypassed and generic models are loaded instead;
 * otherwise every model file located for the configuration is loaded.
 *
 * @param modelConfig
 *            model config
 * @param evalConfig
 *            eval config
 * @param sourceType
 *            source type
 * @param gbtConvertToProb
 *            convert gbt score to prob or not
 * @param gbtScoreConvertStrategy
 *            specify how to convert gbt raw score
 * @return list of models (empty if no model file is located for a non-generic algorithm)
 * @throws IOException
 *             if any IO exception occurs while locating or reading model files
 * @throws IllegalArgumentException
 *             if {@code modelConfig} is invalid or has an unsupported model algorithm (raised by callees)
 * @throws IllegalStateException
 *             if the algorithm is GENERIC or TENSORFLOW but no generic model can be found; callees may also raise
 *             it for a source type other than HDFS or LOCAL, or for an unsupported algorithm
 */
public static List<BasicML> loadBasicModels(ModelConfig modelConfig, EvalConfig evalConfig, SourceType sourceType,
        boolean gbtConvertToProb, String gbtScoreConvertStrategy) throws IOException {
    List<BasicML> models = new ArrayList<BasicML>();
    FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType);
    String algorithm = modelConfig.getAlgorithm();

    // Generic or TensorFlow models bypass the Shifu model loader procedure.
    if(Constants.GENERIC.equalsIgnoreCase(algorithm) || Constants.TENSORFLOW.equalsIgnoreCase(algorithm)) {
        List<FileStatus> genericModelConfigs = findGenericModels(modelConfig, evalConfig, sourceType);
        if(genericModelConfigs.isEmpty()) {
            // IllegalStateException extends RuntimeException, so existing callers still catch it.
            throw new IllegalStateException(
                    "Load generic model failed: no generic model found for algorithm " + algorithm + ".");
        }
        loadGenericModels(modelConfig, genericModelConfigs, sourceType, models);
        log.debug("return generic model {}", models.size());
        return models;
    }

    List<FileStatus> modelFileStats = locateBasicModels(modelConfig, evalConfig, sourceType);
    if(CollectionUtils.isNotEmpty(modelFileStats)) {
        for(FileStatus fst: modelFileStats) {
            models.add(loadModel(modelConfig, fst.getPath(), fs, gbtConvertToProb, gbtScoreConvertStrategy));
        }
    }
    return models;
}
Also used : FileStatus(org.apache.hadoop.fs.FileStatus) FileSystem(org.apache.hadoop.fs.FileSystem) BasicML(org.encog.ml.BasicML)

Example 19 with BasicML

use of org.encog.ml.BasicML in project shifu by ShifuML.

Method loadModel of class ModelSpecLoaderUtils.

/**
 * Load a single model from {@code modelPath}, choosing the deserializer from the file name.
 *
 * <p>A file name ending with the lower-cased LR algorithm name is first read as a one-line LR model string. If that
 * fails, the file is re-opened and read through Encog persistence instead. A name ending with the lower-cased RF or
 * GBT algorithm name is read as a tree model. Any other file is read through {@code NNModel} when the stream is
 * gzip-compressed, and through Encog persistence otherwise.
 *
 * @param modelConfig
 *            model config (not read by this implementation)
 * @param modelPath
 *            path of the model file
 * @param fs
 *            file system holding the model file
 * @param gbtConvertToProb
 *            whether GBT scores are converted to probabilities; only forwarded to tree-model loading
 * @param gbtScoreConvertStrategy
 *            how raw GBT scores are converted; only forwarded to tree-model loading
 * @return the loaded model, or {@code null} if {@code modelPath} does not exist
 * @throws IOException
 *             if checking whether {@code modelPath} exists fails; any failure while reading the model is wrapped
 *             in a {@link ShifuException} instead
 */
public static BasicML loadModel(ModelConfig modelConfig, Path modelPath, FileSystem fs, boolean gbtConvertToProb,
        String gbtScoreConvertStrategy) throws IOException {
    if(!fs.exists(modelPath)) {
        // no model at this path; callers get null
        return null;
    }
    // PersistBasicFloatNetwork must be registered before Encog can deserialize such models.
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());

    FSDataInputStream in = null;
    BufferedReader lineReader = null;
    try {
        in = fs.open(modelPath);
        String fileName = modelPath.getName();

        if(fileName.endsWith(LogisticRegressionContants.LR_ALG_NAME.toLowerCase())) {
            lineReader = new BufferedReader(new InputStreamReader(in));
            try {
                return LR.loadFromString(lineReader.readLine());
            } catch (Exception notTextLr) {
                // Not a text LR model (perhaps a local Encog LR model): close, re-open and read it via Encog.
                IOUtils.closeQuietly(lineReader);
                in = fs.open(modelPath);
                return BasicML.class.cast(EncogDirectoryPersistence.loadObject(in));
            }
        }

        boolean treeModel = fileName.endsWith(CommonConstants.RF_ALG_NAME.toLowerCase())
                || fileName.endsWith(CommonConstants.GBT_ALG_NAME.toLowerCase());
        if(treeModel) {
            return TreeModel.loadFromStream(in, gbtConvertToProb, gbtScoreConvertStrategy);
        }

        // Everything else: gzip-compressed streams go through NNModel, plain ones through Encog persistence.
        GzipStreamPair pair = GzipStreamPair.isGZipFormat(in);
        if(pair.isGzip()) {
            return BasicML.class.cast(NNModel.loadFromStream(pair.getInput()));
        }
        return BasicML.class.cast(EncogDirectoryPersistence.loadObject(pair.getInput()));
    } catch (Exception e) {
        throw new ShifuException(ShifuErrorCode.ERROR_FAIL_TO_LOAD_MODEL_FILE, e,
                "the expecting model file is: " + modelPath);
    } finally {
        IOUtils.closeQuietly(lineReader);
        IOUtils.closeQuietly(in);
    }
}
Also used : FSDataInputStream(org.apache.hadoop.fs.FSDataInputStream) BasicML(org.encog.ml.BasicML) ShifuException(ml.shifu.shifu.exception.ShifuException) PersistBasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork)

Example 20 with BasicML

use of org.encog.ml.BasicML in project shifu by ShifuML.

Method testModelStructureCompare of class NNModelSpecTest.

/**
 * Verifies {@link NNStructureComparator} ordering across the test models model0 (base), model1 (extended),
 * model2 (different) and model3 (deep).
 */
@Test
public void testModelStructureCompare() {
    FlatNetwork baseNetwork = loadFlatNetworkForStructureCompare("src/test/resources/model/model0.nn");
    FlatNetwork extendedNetwork = loadFlatNetworkForStructureCompare("src/test/resources/model/model1.nn");
    // model1 compares greater than model0, and model0 equals itself.
    Assert.assertEquals(compareNetworkStructure(baseNetwork, extendedNetwork), -1);
    Assert.assertEquals(compareNetworkStructure(baseNetwork, baseNetwork), 0);
    Assert.assertEquals(compareNetworkStructure(extendedNetwork, baseNetwork), 1);

    // model0 vs model2 yields -1 in both directions; model1 compares greater than model2.
    FlatNetwork differentNetwork = loadFlatNetworkForStructureCompare("src/test/resources/model/model2.nn");
    Assert.assertEquals(compareNetworkStructure(baseNetwork, differentNetwork), -1);
    Assert.assertEquals(compareNetworkStructure(differentNetwork, baseNetwork), -1);
    Assert.assertEquals(compareNetworkStructure(extendedNetwork, differentNetwork), 1);
    Assert.assertEquals(compareNetworkStructure(differentNetwork, extendedNetwork), -1);

    // The "deep" model3 compares greater than model0.
    FlatNetwork deepNetwork = loadFlatNetworkForStructureCompare("src/test/resources/model/model3.nn");
    Assert.assertEquals(compareNetworkStructure(deepNetwork, baseNetwork), 1);
    Assert.assertEquals(compareNetworkStructure(baseNetwork, deepNetwork), -1);
}

/** Load an Encog-persisted {@link BasicNetwork} from {@code path} and return its flat representation. */
private static FlatNetwork loadFlatNetworkForStructureCompare(String path) {
    BasicML loaded = BasicML.class.cast(EncogDirectoryPersistence.loadObject(new File(path)));
    return ((BasicNetwork) loaded).getFlat();
}

/** Compare two flat networks with a fresh {@link NNStructureComparator} instance. */
private static int compareNetworkStructure(FlatNetwork left, FlatNetwork right) {
    return new NNStructureComparator().compare(left, right);
}
Also used : FlatNetwork(org.encog.neural.flat.FlatNetwork) BasicNetwork(org.encog.neural.networks.BasicNetwork) NNStructureComparator(ml.shifu.shifu.core.dtrain.nn.NNStructureComparator) BasicML(org.encog.ml.BasicML) File(java.io.File) Test(org.testng.annotations.Test)

Aggregations

BasicML (org.encog.ml.BasicML)23 File (java.io.File)6 BasicNetwork (org.encog.neural.networks.BasicNetwork)5 IOException (java.io.IOException)4 ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)4 BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)4 PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork)4 FileSystem (org.apache.hadoop.fs.FileSystem)4 FlatNetwork (org.encog.neural.flat.FlatNetwork)4 ArrayList (java.util.ArrayList)3 NSColumn (ml.shifu.shifu.column.NSColumn)3 ModelRunner (ml.shifu.shifu.core.ModelRunner)3 ModelSpec (ml.shifu.shifu.core.model.ModelSpec)3 MutablePair (org.apache.commons.lang3.tuple.MutablePair)3 Configuration (org.apache.hadoop.conf.Configuration)3 FileStatus (org.apache.hadoop.fs.FileStatus)3 Path (org.apache.hadoop.fs.Path)3 JarFile (java.util.jar.JarFile)2 CaseScoreResult (ml.shifu.shifu.container.CaseScoreResult)2 SourceType (ml.shifu.shifu.container.obj.RawSourceData.SourceType)2