Search in sources :

Example 6 with PersistBasicFloatNetwork

use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.

the class ModelSpecLoaderUtils method loadSubModels.

/**
 * Scans the models directory and loads one {@link ModelSpec} for every sub-directory found in it.
 * <p>
 * The directory scanned is {@code evalConfig.getModelsPath()} when an eval config with a non-empty models
 * path is supplied; otherwise it is the models path resolved by {@link PathFinder} for the given source type.
 * Plain files directly under that directory are skipped, as are sub-directories for which
 * {@code loadSubModelSpec} returns {@code null}.
 *
 * @param modelConfig
 *            - {@link ModelConfig} used to resolve the default models path and passed on to the spec loader
 * @param columnConfigList
 *            - List of {@link ColumnConfig}, passed on to the spec loader
 * @param evalConfig
 *            - {@link EvalConfig}, may be null
 * @param sourceType
 *            - {@link SourceType} selecting the file system to read from
 * @param gbtConvertToProb
 *            - convert to probability or not for gbt model
 * @param gbtScoreConvertStrategy
 *            - gbt score conversion strategy
 * @return list of {@link ModelSpec} for sub models; if an {@link IOException} occurs it is logged and the
 *         specs collected so far (possibly none) are returned
 */
@SuppressWarnings("deprecation")
public static List<ModelSpec> loadSubModels(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, EvalConfig evalConfig, RawSourceData.SourceType sourceType, Boolean gbtConvertToProb, String gbtScoreConvertStrategy) {
    List<ModelSpec> loadedSpecs = new ArrayList<ModelSpec>();
    FileSystem fileSystem = ShifuFileUtils.getFileSystemBySourceType(sourceType);
    // PersistBasicFloatNetwork must be registered before such models can be loaded
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());
    PathFinder pathFinder = new PathFinder(modelConfig);
    // an explicit models path on the eval config wins over the default location
    boolean hasEvalModelsPath = evalConfig != null && !StringUtils.isEmpty(evalConfig.getModelsPath());
    String modelsPath = hasEvalModelsPath ? evalConfig.getModelsPath() : pathFinder.getModelsPath(sourceType);
    try {
        for (FileStatus entry : fileSystem.listStatus(new Path(modelsPath))) {
            if (!entry.isDir()) {
                // only directories are treated as sub-models
                continue;
            }
            ModelSpec spec = loadSubModelSpec(modelConfig, columnConfigList, entry, sourceType, gbtConvertToProb, gbtScoreConvertStrategy);
            if (spec != null) {
                loadedSpecs.add(spec);
            }
        }
    } catch (IOException e) {
        log.error("Error occurred when loading sub-models.", e);
    }
    return loadedSpecs;
}
Also used : Path(org.apache.hadoop.fs.Path) FileStatus(org.apache.hadoop.fs.FileStatus) FileSystem(org.apache.hadoop.fs.FileSystem) PathFinder(ml.shifu.shifu.fs.PathFinder) ModelSpec(ml.shifu.shifu.core.model.ModelSpec) PersistBasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork)

Example 7 with PersistBasicFloatNetwork

use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.

the class ModelSpecLoaderUtils method loadModel.

/**
 * Loads the model stored at {@code modelPath}, picking the loader from the file name suffix.
 * <ul>
 * <li>Name ends with the lower-cased LR algorithm name: the first line is parsed with
 * {@code LR.loadFromString}; if that fails for any reason the file is reopened and read through
 * {@code EncogDirectoryPersistence}.</li>
 * <li>Name ends with the lower-cased RF or GBT algorithm name: read with {@code TreeModel.loadFromStream}.</li>
 * <li>Anything else: read with {@code NNModel.loadFromStream} when the content is detected as gzip,
 * otherwise through {@code EncogDirectoryPersistence}.</li>
 * </ul>
 * Any exception raised after the existence check is wrapped in a {@link ShifuException} with error code
 * {@code ERROR_FAIL_TO_LOAD_MODEL_FILE}.
 *
 * @param modelConfig
 *            model config
 * @param modelPath
 *            the path the model is read from
 * @param fs
 *            file system holding the model file
 * @param gbtConvertToProb
 *            convert gbt score to prob or not
 * @param gbtScoreConvertStrategy
 *            specify how to convert gbt raw score
 * @return model object, or null if {@code modelPath} does not exist
 * @throws IOException
 *             if checking whether {@code modelPath} exists fails
 */
public static BasicML loadModel(ModelConfig modelConfig, Path modelPath, FileSystem fs, boolean gbtConvertToProb, String gbtScoreConvertStrategy) throws IOException {
    if (!fs.exists(modelPath)) {
        // nothing stored at that location
        return null;
    }
    // PersistBasicFloatNetwork must be registered before such models can be loaded
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());
    FSDataInputStream stream = null;
    BufferedReader br = null;
    try {
        stream = fs.open(modelPath);
        String fileName = modelPath.getName();

        if (fileName.endsWith(LogisticRegressionContants.LR_ALG_NAME.toLowerCase())) {
            // LR model stored as a single text line
            br = new BufferedReader(new InputStreamReader(stream));
            try {
                return LR.loadFromString(br.readLine());
            } catch (Exception lrParseFailure) {
                // possibly a locally trained LR model: drop the consumed reader (and the stream under it),
                // then start over on a fresh stream that the outer finally will close
                IOUtils.closeQuietly(br);
                stream = fs.open(modelPath);
                return BasicML.class.cast(EncogDirectoryPersistence.loadObject(stream));
            }
        }

        boolean isTreeModel = fileName.endsWith(CommonConstants.RF_ALG_NAME.toLowerCase())
                || fileName.endsWith(CommonConstants.GBT_ALG_NAME.toLowerCase());
        if (isTreeModel) {
            // RF or GBT
            return TreeModel.loadFromStream(stream, gbtConvertToProb, gbtScoreConvertStrategy);
        }

        GzipStreamPair pair = GzipStreamPair.isGZipFormat(stream);
        if (pair.isGzip()) {
            return BasicML.class.cast(NNModel.loadFromStream(pair.getInput()));
        }
        return BasicML.class.cast(EncogDirectoryPersistence.loadObject(pair.getInput()));
    } catch (Exception e) {
        String msg = "the expecting model file is: " + modelPath;
        throw new ShifuException(ShifuErrorCode.ERROR_FAIL_TO_LOAD_MODEL_FILE, e, msg);
    } finally {
        IOUtils.closeQuietly(br);
        IOUtils.closeQuietly(stream);
    }
}
Also used : FSDataInputStream(org.apache.hadoop.fs.FSDataInputStream) BasicML(org.encog.ml.BasicML) ShifuException(ml.shifu.shifu.exception.ShifuException) PersistBasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork) ShifuException(ml.shifu.shifu.exception.ShifuException)

Example 8 with PersistBasicFloatNetwork

use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.

the class NNOutput method initNetwork.

/**
 * Builds {@code this.network} from the validated training parameters and the feature subset configured
 * for this run, then registers {@link PersistBasicFloatNetwork} so the network can be saved later.
 * <p>
 * The feature subset is read from the {@code CommonConstants.SHIFU_NN_FEATURE_SUBSET} property of the
 * context as a comma-separated list of integers; when that property is blank, every feature returned by
 * {@code NormalUtils.getAllFeatureList} is used.
 *
 * @param context
 *            master context whose properties carry the optional feature subset
 * @throws NumberFormatException
 *             if a feature subset token is not an integer after trimming surrounding whitespace
 */
@SuppressWarnings("unchecked")
private void initNetwork(MasterContext<NNParams, NNParams> context) {
    int[] inputOutputIndex = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(), this.columnConfigList);
    boolean isLinearTarget = CommonUtils.isLinearTarget(modelConfig, columnConfigList);

    // Linear target, regression and one-vs-all all take the candidate output count as is; plain
    // multi-class classification uses one node per class, except that two classes collapse to one node.
    int classes = modelConfig.getTags().size();
    int outputNodeCount;
    if (isLinearTarget || modelConfig.isRegression() || modelConfig.getTrain().isOneVsAll()) {
        outputNodeCount = inputOutputIndex[1];
    } else {
        outputNodeCount = (classes == 2) ? 1 : classes;
    }

    int numLayers = (Integer) validParams.get(CommonConstants.NUM_HIDDEN_LAYERS);
    List<String> actFunc = (List<String>) validParams.get(CommonConstants.ACTIVATION_FUNC);
    List<Integer> hiddenNodeList = (List<Integer>) validParams.get(CommonConstants.NUM_HIDDEN_NODES);

    boolean isAfterVarSelect = inputOutputIndex[0] != 0;
    // cache all feature list for sampling features
    List<Integer> allFeatures = NormalUtils.getAllFeatureList(columnConfigList, isAfterVarSelect);
    String subsetStr = context.getProps().getProperty(CommonConstants.SHIFU_NN_FEATURE_SUBSET);
    if (StringUtils.isBlank(subsetStr)) {
        this.subFeatures = new HashSet<Integer>(allFeatures);
    } else {
        this.subFeatures = new HashSet<Integer>();
        for (String split : subsetStr.split(",")) {
            // Integer.parseInt rejects surrounding whitespace, so tolerate values such as "1, 2"
            this.subFeatures.add(Integer.parseInt(split.trim()));
        }
    }

    int featureInputsCnt = DTrainUtils.getFeatureInputsCnt(modelConfig, columnConfigList, this.subFeatures);
    String outputActivationFunc = (String) validParams.get(CommonConstants.OUTPUT_ACTIVATION_FUNC);
    this.network = DTrainUtils.generateNetwork(featureInputsCnt, outputNodeCount, numLayers, actFunc, hiddenNodeList, false, this.dropoutRate, this.wgtInit, isLinearTarget, outputActivationFunc);
    ((BasicFloatNetwork) this.network).setFeatureSet(this.subFeatures);
    // register here to save models
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());
}
Also used : List(java.util.List) PersistBasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) PersistBasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork)

Example 9 with PersistBasicFloatNetwork

use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.

the class BinaryNNSerializer method save.

/**
 * Writes the given networks, together with the column statistics needed to score with them, to
 * {@code output} as a single gzip-compressed binary file.
 * <p>
 * Layout, in write order: format version ({@code CommonConstants.NN_FORMAT_VERSION}), normalization type
 * name, number of column stats followed by each {@code NNColumnStats}, number of column mapping entries
 * followed by each key/value pair, number of networks followed by each network.
 *
 * @param modelConfig
 *            model config supplying the normalization type and the std-dev cutoff
 * @param columnConfigList
 *            column configs; only columns present in the index/name mapping get their stats written
 * @param basicNetworks
 *            networks to persist; each element is cast to {@link BasicFloatNetwork}
 * @param fs
 *            file system on which {@code output} is created
 * @param output
 *            destination file
 * @throws IOException
 *             if creating, writing or closing the output fails
 */
public static void save(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, List<BasicML> basicNetworks, FileSystem fs, Path output) throws IOException {
    DataOutputStream fos = null;
    try {
        fos = new DataOutputStream(new GZIPOutputStream(fs.create(output)));
        // version
        fos.writeInt(CommonConstants.NN_FORMAT_VERSION);
        // write normStr
        String normStr = modelConfig.getNormalize().getNormType().toString();
        ml.shifu.shifu.core.dtrain.StringUtils.writeString(fos, normStr);
        // compute columns needed
        Map<Integer, String> columnIndexNameMapping = getIndexNameMapping(columnConfigList);
        // write column stats to output
        List<NNColumnStats> csList = new ArrayList<NNColumnStats>();
        for (ColumnConfig cc : columnConfigList) {
            if (columnIndexNameMapping.containsKey(cc.getColumnNum())) {
                NNColumnStats cs = new NNColumnStats();
                cs.setCutoff(modelConfig.getNormalizeStdDevCutOff());
                cs.setColumnType(cc.getColumnType());
                cs.setMean(cc.getMean());
                cs.setStddev(cc.getStdDev());
                cs.setColumnNum(cc.getColumnNum());
                cs.setColumnName(cc.getColumnName());
                cs.setBinCategories(cc.getBinCategory());
                cs.setBinBoundaries(cc.getBinBoundary());
                cs.setBinPosRates(cc.getBinPosRate());
                cs.setBinCountWoes(cc.getBinCountWoe());
                cs.setBinWeightWoes(cc.getBinWeightedWoe());
                // TODO cache such computation
                double[] meanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(cc, false);
                cs.setWoeMean(meanAndStdDev[0]);
                cs.setWoeStddev(meanAndStdDev[1]);
                double[] wgtMeanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(cc, true);
                cs.setWoeWgtMean(wgtMeanAndStdDev[0]);
                cs.setWoeWgtStddev(wgtMeanAndStdDev[1]);
                csList.add(cs);
            }
        }
        fos.writeInt(csList.size());
        for (NNColumnStats cs : csList) {
            cs.write(fos);
        }
        // write column index mapping
        Map<Integer, Integer> columnMapping = getColumnMapping(columnConfigList);
        fos.writeInt(columnMapping.size());
        for (Entry<Integer, Integer> entry : columnMapping.entrySet()) {
            fos.writeInt(entry.getKey());
            fos.writeInt(entry.getValue());
        }
        // persist network, set it as list
        fos.writeInt(basicNetworks.size());
        for (BasicML network : basicNetworks) {
            new PersistBasicFloatNetwork().saveNetwork(fos, (BasicFloatNetwork) network);
        }
        // Close on the success path so that a failure while flushing buffered data or finishing the
        // gzip stream reaches the caller; the quiet close in finally would hide it and leave a
        // truncated model file behind that looks successfully written.
        fos.close();
        fos = null;
    } finally {
        // Error path only (fos is null after a successful close): release the stream without masking
        // the exception that is already propagating.
        IOUtils.closeStream(fos);
    }
}
Also used : ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) DataOutputStream(java.io.DataOutputStream) ArrayList(java.util.ArrayList) BasicML(org.encog.ml.BasicML) GZIPOutputStream(java.util.zip.GZIPOutputStream) PersistBasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork)

Aggregations

PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork): 9, BasicML (org.encog.ml.BasicML): 4, ArrayList (java.util.ArrayList): 2, List (java.util.List): 2, BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork): 2, FSDataInputStream (org.apache.hadoop.fs.FSDataInputStream): 2, FileStatus (org.apache.hadoop.fs.FileStatus): 2, FileSystem (org.apache.hadoop.fs.FileSystem): 2, Path (org.apache.hadoop.fs.Path): 2, BufferedInputStream (java.io.BufferedInputStream): 1, DataInputStream (java.io.DataInputStream): 1, DataOutputStream (java.io.DataOutputStream): 1, File (java.io.File): 1, IOException (java.io.IOException): 1, HashMap (java.util.HashMap): 1, Map (java.util.Map): 1, GZIPInputStream (java.util.zip.GZIPInputStream): 1, GZIPOutputStream (java.util.zip.GZIPOutputStream): 1, ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 1, ColumnType (ml.shifu.shifu.container.obj.ColumnType): 1