Search in sources:

Example 1 with PersistBasicFloatNetwork

Use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.

In class VarSelectMapper, method loadModel.

/**
 * Loads the first model in the model path ({@code model0.<algorithm>}) from the local d-cache model file and
 * stores it, as an {@link MLRegression}, in the {@code model} field.
 *
 * @throws IOException
 *             if the model file cannot be read
 */
private synchronized void loadModel() throws IOException {
    LOG.debug("Before loading model with memory {} in thread {}.", MemoryUtils.getRuntimeMemoryStats(), Thread.currentThread().getName());
    final long startMillis = System.currentTimeMillis();

    // NN models can only be de-serialized once PersistBasicFloatNetwork is registered
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());

    // the model file lives in the local d-cache, so it is read through the LOCAL file system
    FileSystem localFs = ShifuFileUtils.getFileSystemBySourceType(SourceType.LOCAL);
    Path firstModelPath = new Path("model0." + modelConfig.getAlgorithm().toLowerCase());
    model = (MLRegression) ModelSpecLoaderUtils.loadModel(modelConfig, firstModelPath, localFs);

    long elapsedMillis = System.currentTimeMillis() - startMillis;
    LOG.debug("After load model class {} with time {}ms and memory {} in thread {}.", model.getClass().getName(), elapsedMillis, MemoryUtils.getRuntimeMemoryStats(), Thread.currentThread().getName());
}
Also used : Path(org.apache.hadoop.fs.Path) FileSystem(org.apache.hadoop.fs.FileSystem) PersistBasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork)

Example 2 with PersistBasicFloatNetwork

Use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.

In class IndependentNNModel, method loadFromStream.

/**
 * Load model instance from input stream which is saved in NNOutput for specified binary format.
 * <p>
 * Both gzip-compressed and plain streams are accepted: the first two bytes are peeked and compared with the gzip
 * magic number. The payload is then read in this order: format version, normalization type name, per-column
 * {@code NNColumnStats}, the column index mapping and finally the serialized networks.
 * <p>
 * This method does not close {@code input}.
 *
 * @param input
 *            the input stream, flat input stream or gzip input stream both OK
 * @param isRemoveNameSpace
 *            if true, name-spaces are removed from column names so columns can be referenced by simple name
 * @return the nn model instance
 * @throws IOException
 *             any IOException while sniffing the gzip header or in de-serialization.
 */
public static IndependentNNModel loadFromStream(InputStream input, boolean isRemoveNameSpace) throws IOException {
    // Header read failures are propagated instead of falling back to the raw stream: the buffered wrapper may
    // already have consumed bytes from 'input', so a fallback would read from a wrong offset.
    DataInputStream dis = openPossiblyGzippedStream(input);
    int version = dis.readInt();
    IndependentNNModel.setVersion(version);
    String normStr = ml.shifu.shifu.core.dtrain.StringUtils.readString(dis);
    NormType normType = NormType.valueOf(normStr.toUpperCase());

    // names, for all features
    Map<Integer, String> numNameMap = new HashMap<Integer, String>();
    Map<Integer, List<String>> cateColumnNameNames = new HashMap<Integer, List<String>>();
    // for categorical features
    Map<Integer, Map<String, Double>> cateWoeMap = new HashMap<Integer, Map<String, Double>>();
    Map<Integer, Map<String, Double>> cateWgtWoeMap = new HashMap<Integer, Map<String, Double>>();
    Map<Integer, Map<String, Double>> binPosRateMap = new HashMap<Integer, Map<String, Double>>();
    // for numerical features
    Map<Integer, List<Double>> numerBinBoundaries = new HashMap<Integer, List<Double>>();
    Map<Integer, List<Double>> numerWoes = new HashMap<Integer, List<Double>>();
    Map<Integer, List<Double>> numerWgtWoes = new HashMap<Integer, List<Double>>();
    // statistics, for all features
    Map<Integer, Double> numerMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> numerStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> woeMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> woeStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> wgtWoeMeanMap = new HashMap<Integer, Double>();
    Map<Integer, Double> wgtWoeStddevMap = new HashMap<Integer, Double>();
    Map<Integer, Double> cutoffMap = new HashMap<Integer, Double>();
    Map<Integer, ColumnType> columnTypeMap = new HashMap<Integer, ColumnType>();
    Map<Integer, Map<String, Integer>> cateIndexMapping = new HashMap<Integer, Map<String, Integer>>();

    int columnSize = dis.readInt();
    for (int i = 0; i < columnSize; i++) {
        NNColumnStats cs = new NNColumnStats();
        cs.readFields(dis);
        List<Double> binWoes = cs.getBinCountWoes();
        List<Double> binWgtWoes = cs.getBinWeightWoes();
        List<Double> binPosRates = cs.getBinPosRates();
        int columnNum = cs.getColumnNum();
        columnTypeMap.put(columnNum, cs.getColumnType());
        // optionally remove the name-space so the column can be referenced by its simple name
        String columnName = isRemoveNameSpace ? StringUtils.getSimpleColumnName(cs.getColumnName())
                : cs.getColumnName();
        numNameMap.put(columnNum, columnName);

        // for categorical features
        Map<String, Double> woeMap = new HashMap<String, Double>();
        Map<String, Double> woeWgtMap = new HashMap<String, Double>();
        Map<String, Double> posRateMap = new HashMap<String, Double>();
        Map<String, Integer> cateIndexMap = new HashMap<String, Integer>();
        if (cs.isCategorical() || cs.isHybrid()) {
            List<String> binCategories = cs.getBinCategories();
            cateColumnNameNames.put(columnNum, binCategories);
            for (int j = 0; j < binCategories.size(); j++) {
                String currCate = binCategories.get(j);
                // merged category should be flattened so every raw value maps to the merged bin; use own split
                // function to avoid depending on guava jar in prediction
                String[] cateValues = currCate.contains(Constants.CATEGORICAL_GROUP_VAL_DELIMITER)
                        ? StringUtils.split(currCate, Constants.CATEGORICAL_GROUP_VAL_DELIMITER)
                        : new String[] { currCate };
                for (String cateValue : cateValues) {
                    woeMap.put(cateValue, binWoes.get(j));
                    woeWgtMap.put(cateValue, binWgtWoes.get(j));
                    posRateMap.put(cateValue, binPosRates.get(j));
                    cateIndexMap.put(cateValue, j);
                }
            }
            // append last missing bin: the entry right after the last category (index binCategories.size())
            woeMap.put(Constants.EMPTY_CATEGORY, binWoes.get(binCategories.size()));
            woeWgtMap.put(Constants.EMPTY_CATEGORY, binWgtWoes.get(binCategories.size()));
            posRateMap.put(Constants.EMPTY_CATEGORY, binPosRates.get(binCategories.size()));
        }
        if (cs.isNumerical() || cs.isHybrid()) {
            numerBinBoundaries.put(columnNum, cs.getBinBoundaries());
            numerWoes.put(columnNum, binWoes);
            numerWgtWoes.put(columnNum, binWgtWoes);
        }
        cateWoeMap.put(columnNum, woeMap);
        cateWgtWoeMap.put(columnNum, woeWgtMap);
        binPosRateMap.put(columnNum, posRateMap);
        cateIndexMapping.put(columnNum, cateIndexMap);
        numerMeanMap.put(columnNum, cs.getMean());
        numerStddevMap.put(columnNum, cs.getStddev());
        woeMeanMap.put(columnNum, cs.getWoeMean());
        woeStddevMap.put(columnNum, cs.getWoeStddev());
        wgtWoeMeanMap.put(columnNum, cs.getWoeWgtMean());
        wgtWoeStddevMap.put(columnNum, cs.getWoeWgtStddev());
        cutoffMap.put(columnNum, cs.getCutoff());
    }

    // column mapping, serialized as (key, value) int pairs; arguments are evaluated left to right
    Map<Integer, Integer> columnMap = new HashMap<Integer, Integer>();
    int columnMapSize = dis.readInt();
    for (int i = 0; i < columnMapSize; i++) {
        columnMap.put(dis.readInt(), dis.readInt());
    }

    int size = dis.readInt();
    List<BasicFloatNetwork> networks = new ArrayList<BasicFloatNetwork>(Math.max(size, 0));
    for (int i = 0; i < size; i++) {
        networks.add(new PersistBasicFloatNetwork().readNetwork(dis));
    }
    return new IndependentNNModel(networks, normType, numNameMap, cateColumnNameNames, columnMap, cateWoeMap, cateWgtWoeMap, binPosRateMap, numerBinBoundaries, numerWgtWoes, numerWoes, cutoffMap, numerMeanMap, numerStddevMap, woeMeanMap, woeStddevMap, wgtWoeMeanMap, wgtWoeStddevMap, columnTypeMap, cateIndexMapping);
}

/**
 * Wraps {@code input} in a {@link DataInputStream}, transparently decompressing it when its first two bytes are
 * the gzip magic number. The header is peeked with mark/reset, so no bytes are lost for plain streams.
 *
 * @param input
 *            the raw model stream
 * @return a data stream positioned at the start of the (decompressed) payload
 * @throws IOException
 *             if the header cannot be read or the gzip header is invalid
 */
private static DataInputStream openPossiblyGzippedStream(InputStream input) throws IOException {
    BufferedInputStream bis = new BufferedInputStream(input);
    byte[] header = new byte[2];
    bis.mark(header.length);
    // InputStream.read may return fewer bytes than requested, so loop until the header is full or EOF
    int read = 0;
    while (read < header.length) {
        int n = bis.read(header, read, header.length - read);
        if (n < 0) {
            break;
        }
        read += n;
    }
    bis.reset();
    if (read == header.length) {
        int magic = (header[0] & 0xff) | ((header[1] & 0xff) << 8);
        if (magic == GZIPInputStream.GZIP_MAGIC) {
            return new DataInputStream(new GZIPInputStream(bis));
        }
    }
    return new DataInputStream(bis);
}
Also used : ColumnType(ml.shifu.shifu.container.obj.ColumnType) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) GZIPInputStream(java.util.zip.GZIPInputStream) BufferedInputStream(java.io.BufferedInputStream) List(java.util.List) PersistBasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork) NormType(ml.shifu.shifu.container.obj.ModelNormalizeConf.NormType) IOException(java.io.IOException) DataInputStream(java.io.DataInputStream) Map(java.util.Map)

Example 3 with PersistBasicFloatNetwork

Use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.

In class ModelSpecLoaderUtils, method locateBasicModels.

/**
 * Find the model spec files produced by the latest training.
 * <p>
 * Models are first looked up with {@code findModels}. If none are found, the generic models returned by
 * {@code findGenericModels} are returned as-is when present. Otherwise the located files are sorted by file
 * name, ignoring case, and truncated to the number of models the current training configuration produces. That
 * number starts from the bagging number and is overridden, in order, by the one-vs-all tag count, the k-fold
 * count and the grid-search combination count.
 *
 * @param modelConfig
 *            model config
 * @param evalConfig
 *            eval configuration
 * @param sourceType
 *            {@link SourceType} LOCAL or HDFS?
 * @return The basic model file list
 * @throws IOException
 *             Exception when fail to load basic models
 */
public static List<FileStatus> locateBasicModels(ModelConfig modelConfig, EvalConfig evalConfig, SourceType sourceType) throws IOException {
    // PersistBasicFloatNetwork must be registered before such models can be loaded
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());

    List<FileStatus> located = findModels(modelConfig, evalConfig, sourceType);
    if (CollectionUtils.isEmpty(located)) {
        // No ERROR_MODEL_FILE_NOT_FOUND here: there may be sub-models, and an empty result lets eval keep
        // working while training is still in progress.
        located = findGenericModels(modelConfig, evalConfig, sourceType);
        if (CollectionUtils.isNotEmpty(located)) {
            return located;
        }
    }

    // sort explicitly so the result does not depend on the *nix or windows file list order
    Collections.sort(located, new Comparator<FileStatus>() {

        @Override
        public int compare(FileStatus left, FileStatus right) {
            String leftName = left.getPath().getName();
            String rightName = right.getPath().getName();
            return leftName.compareToIgnoreCase(rightName);
        }
    });

    // added in shifu 0.2.5: keep only the models belonging to the last training
    int modelCount = modelConfig.getTrain().getBaggingNum();
    if (modelConfig.isClassification() && modelConfig.getTrain().isOneVsAll()) {
        // one-vs-all: one model per tag
        modelCount = modelConfig.getTags().size();
    }
    Integer folds = modelConfig.getTrain().getNumKFold();
    if (folds != null && folds > 0) {
        // k-fold: one model per fold
        modelCount = folds;
    }
    GridSearch gridSearch = new GridSearch(modelConfig.getTrain().getParams(), modelConfig.getTrain().getGridConfigFileContent());
    if (gridSearch.hasHyperParam()) {
        // grid search: one model per flattened parameter combination
        modelCount = gridSearch.getFlattenParams().size();
    }

    return located.size() > modelCount ? located.subList(0, modelCount) : located;
}
Also used : FileStatus(org.apache.hadoop.fs.FileStatus) PersistBasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork) GridSearch(ml.shifu.shifu.core.dtrain.gs.GridSearch)

Example 4 with PersistBasicFloatNetwork

Use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.

In class ModelSpecLoaderUtils, method loadBasicModels.

/**
 * Load models of the given algorithm from every file in {@code modelsPath} whose name ends with
 * {@code "." + alg.name().toLowerCase()}, sorted by file name.
 * <p>
 * NN files may be gzip-compressed (loaded with {@code NNModel}) or plain Encog files (loaded with
 * {@code EncogDirectoryPersistence}). LR files are loaded with {@code LR} and GBT/RF files with {@code TreeModel}.
 * Matching files of any other algorithm are skipped.
 *
 * @param modelsPath
 *            - a directory that contains the model files
 * @param alg
 *            the algorithm; must not be null or DT
 * @param isConvertToProb
 *            if convert to prob for gbt model
 * @param gbtScoreConvertStrategy
 *            specify how to convert gbt raw score
 * @return - a list of @BasicML
 * @throws IllegalArgumentException
 *             if {@code modelsPath} or {@code alg} is null, or {@code alg} is DT
 * @throws IOException
 *             - throw exception when the directory cannot be listed or loading model files fails
 */
public static List<BasicML> loadBasicModels(final String modelsPath, final ALGORITHM alg, boolean isConvertToProb, String gbtScoreConvertStrategy) throws IOException {
    if (modelsPath == null) {
        throw new IllegalArgumentException("The model path shouldn't be null");
    }
    if (alg == null) {
        throw new IllegalArgumentException("The algorithm shouldn't be null");
    }
    if (ALGORITHM.DT.equals(alg)) {
        throw new IllegalArgumentException("DT models are not supported by loadBasicModels");
    }
    // we have to register PersistBasicFloatNetwork for loading such models
    if (ALGORITHM.NN.equals(alg)) {
        PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());
    }

    // computed once instead of on every filter callback
    final String modelSuffix = "." + alg.name().toLowerCase();
    File modelsPathDir = new File(modelsPath);
    File[] modelFiles = modelsPathDir.listFiles(new FilenameFilter() {

        @Override
        public boolean accept(File dir, String name) {
            return name.endsWith(modelSuffix);
        }
    });
    if (modelFiles == null) {
        throw new IOException(String.format("Failed to list files in %s", modelsPathDir.getAbsolutePath()));
    }

    // sort file names so models are returned in a stable order
    Arrays.sort(modelFiles, new Comparator<File>() {

        @Override
        public int compare(File from, File to) {
            return from.getName().compareTo(to.getName());
        }
    });

    List<BasicML> models = new ArrayList<BasicML>(modelFiles.length);
    for (File modelFile : modelFiles) {
        InputStream is = null;
        InputStream nnInput = null;
        try {
            is = new FileInputStream(modelFile);
            if (ALGORITHM.NN.equals(alg)) {
                GzipStreamPair pair = GzipStreamPair.isGZipFormat(is);
                nnInput = pair.getInput();
                if (pair.isGzip()) {
                    models.add(BasicML.class.cast(NNModel.loadFromStream(nnInput)));
                } else {
                    models.add(BasicML.class.cast(EncogDirectoryPersistence.loadObject(nnInput)));
                }
            } else if (ALGORITHM.LR.equals(alg)) {
                models.add(LR.loadFromStream(is));
            } else if (ALGORITHM.GBT.equals(alg) || ALGORITHM.RF.equals(alg)) {
                models.add(TreeModel.loadFromStream(is, isConvertToProb, gbtScoreConvertStrategy));
            }
        } finally {
            // NOTE(review): pair.getInput() presumably wraps 'is' (e.g. with a decompressor); closing it releases
            // any wrapper resources that closing 'is' alone would not.
            IOUtils.closeQuietly(nnInput);
            IOUtils.closeQuietly(is);
        }
    }
    return models;
}
Also used : FSDataInputStream(org.apache.hadoop.fs.FSDataInputStream) BasicML(org.encog.ml.BasicML) PersistBasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork)

Example 5 with PersistBasicFloatNetwork

Use of ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork in project shifu by ShifuML.

In class NNModelSpecTest, method testModelFitIn.

@Test
public void testModelFitIn() {
    // the .nn fixtures can only be loaded once PersistBasicFloatNetwork is registered
    PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());
    FlatNetwork baseFlat = loadFlatNetwork("src/test/resources/model/model5.nn");
    FlatNetwork extendedFlat = loadFlatNetwork("src/test/resources/model/model6.nn");

    NNMaster nnMaster = new NNMaster();
    // overload without the trailing boolean flag
    Set<Integer> fixedWeights = nnMaster.fitExistingModelIn(baseFlat, extendedFlat, Arrays.asList(1, 2, 3));
    Assert.assertEquals(fixedWeights.size(), 931);
    // same inputs with the trailing boolean flag set to false
    fixedWeights = nnMaster.fitExistingModelIn(baseFlat, extendedFlat, Arrays.asList(1, 2, 3), false);
    Assert.assertEquals(fixedWeights.size(), 910);
}

/**
 * Loads the Encog-persisted network stored at {@code path} and returns its flat representation.
 */
private static FlatNetwork loadFlatNetwork(String path) {
    BasicML loaded = BasicML.class.cast(EncogDirectoryPersistence.loadObject(new File(path)));
    return ((BasicNetwork) loaded).getFlat();
}
Also used : FlatNetwork(org.encog.neural.flat.FlatNetwork) NNMaster(ml.shifu.shifu.core.dtrain.nn.NNMaster) BasicNetwork(org.encog.neural.networks.BasicNetwork) BasicML(org.encog.ml.BasicML) File(java.io.File) PersistBasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork) Test(org.testng.annotations.Test)

Aggregations

PersistBasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.PersistBasicFloatNetwork): 9; BasicML (org.encog.ml.BasicML): 4; ArrayList (java.util.ArrayList): 2; List (java.util.List): 2; BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork): 2; FSDataInputStream (org.apache.hadoop.fs.FSDataInputStream): 2; FileStatus (org.apache.hadoop.fs.FileStatus): 2; FileSystem (org.apache.hadoop.fs.FileSystem): 2; Path (org.apache.hadoop.fs.Path): 2; BufferedInputStream (java.io.BufferedInputStream): 1; DataInputStream (java.io.DataInputStream): 1; DataOutputStream (java.io.DataOutputStream): 1; File (java.io.File): 1; IOException (java.io.IOException): 1; HashMap (java.util.HashMap): 1; Map (java.util.Map): 1; GZIPInputStream (java.util.zip.GZIPInputStream): 1; GZIPOutputStream (java.util.zip.GZIPOutputStream): 1; ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig): 1; ColumnType (ml.shifu.shifu.container.obj.ColumnType): 1