Use of org.encog.ml.BasicML in project shifu by ShifuML.
Class TrainModelProcessor, method inputOutputModelCheckSuccess.
/**
 * Checks whether the existing NN model stored at {@code modelPath} is structurally compatible with the current
 * training configuration. The model is accepted only if:
 * <ul>
 * <li>its input count does not exceed the configured input count and its output count equals the configured output
 * count;</li>
 * <li>it has no more hidden layers than configured by {@link CommonConstants#NUM_HIDDEN_LAYERS};</li>
 * <li>each hidden layer has no more neurons than the configured node count for that layer; and</li>
 * <li>each hidden layer uses exactly the Encog activation class named by the configured activation (unrecognized
 * names are compared against {@link ActivationSigmoid}).</li>
 * </ul>
 *
 * @param fileSystem
 *            file system holding the model
 * @param modelPath
 *            path of the existing model
 * @param modelParams
 *            training parameters; read for {@link CommonConstants#NUM_HIDDEN_LAYERS},
 *            {@link CommonConstants#NUM_HIDDEN_NODES} and {@link CommonConstants#ACTIVATION_FUNC}
 * @return {@code true} if the model is compatible; {@code false} otherwise, including when no model could be loaded
 *         or the hidden-layer parameters are missing or shorter than the model's hidden layers
 * @throws IOException
 *             propagated from {@code ModelSpecLoaderUtils.loadModel}
 */
@SuppressWarnings("unchecked")
private boolean inputOutputModelCheckSuccess(FileSystem fileSystem, Path modelPath, Map<String, Object> modelParams)
        throws IOException {
    BasicML basicML = ModelSpecLoaderUtils.loadModel(this.modelConfig, modelPath, fileSystem);
    if (basicML == null) {
        // Nothing was loaded (e.g. no file at modelPath): report incompatibility instead of an NPE below.
        return false;
    }
    BasicFloatNetwork model = (BasicFloatNetwork) ModelSpecLoaderUtils.getBasicNetwork(basicML);

    int[] outputCandidateCounts = DTrainUtils.getInputOutputCandidateCounts(modelConfig.getNormalizeType(),
            getColumnConfigList());
    // NOTE(review): index 0 looks like the selected-input count, 1 the output count and 2 the candidate count;
    // when index 0 is zero the candidate count is used instead — confirm against DTrainUtils.
    int inputs = outputCandidateCounts[0] == 0 ? outputCandidateCounts[2] : outputCandidateCounts[0];
    if (model.getInputCount() > inputs || model.getOutputCount() != outputCandidateCounts[1]) {
        return false;
    }

    // Hidden layers are every layer except the first (input) and the last (output).
    int modelHiddenLayers = model.getLayerCount() - 2;
    Integer configuredHiddenLayers = (Integer) modelParams.get(CommonConstants.NUM_HIDDEN_LAYERS);
    if (configuredHiddenLayers == null || modelHiddenLayers > configuredHiddenLayers) {
        return false;
    }

    List<Integer> hiddenNodeList = (List<Integer>) modelParams.get(CommonConstants.NUM_HIDDEN_NODES);
    List<String> actFuncList = (List<String>) modelParams.get(CommonConstants.ACTIVATION_FUNC);
    for (int layer = 1; layer <= modelHiddenLayers; layer++) {
        // hidden-layer parameters are 0-based while the model's hidden layers start at index 1
        int paramIndex = layer - 1;
        if (hiddenNodeList == null || paramIndex >= hiddenNodeList.size()
                || model.getLayerNeuronCount(layer) > hiddenNodeList.get(paramIndex)) {
            return false;
        }
        if (actFuncList == null || paramIndex >= actFuncList.size()
                || !isConfiguredHiddenActivation(actFuncList.get(paramIndex), model.getActivation(layer))) {
            return false;
        }
    }
    return true;
}

/**
 * Returns whether {@code activation} is exactly the Encog activation class named by {@code actFunc} (names compared
 * case-insensitively). Unrecognized or {@code null} names are compared against {@link ActivationSigmoid}.
 */
private static boolean isConfiguredHiddenActivation(String actFunc, ActivationFunction activation) {
    Class<?> expected;
    if (NNConstants.NN_LINEAR.equalsIgnoreCase(actFunc)) {
        expected = ActivationLinear.class;
    } else if (NNConstants.NN_SIGMOID.equalsIgnoreCase(actFunc)) {
        expected = ActivationSigmoid.class;
    } else if (NNConstants.NN_TANH.equalsIgnoreCase(actFunc)) {
        expected = ActivationTANH.class;
    } else if (NNConstants.NN_LOG.equalsIgnoreCase(actFunc)) {
        expected = ActivationLOG.class;
    } else if (NNConstants.NN_SIN.equalsIgnoreCase(actFunc)) {
        expected = ActivationSIN.class;
    } else if (NNConstants.NN_RELU.equalsIgnoreCase(actFunc)) {
        expected = ActivationReLU.class;
    } else if (NNConstants.NN_LEAKY_RELU.equalsIgnoreCase(actFunc)) {
        expected = ActivationLeakyReLU.class;
    } else if (NNConstants.NN_SWISH.equalsIgnoreCase(actFunc)) {
        expected = ActivationSwish.class;
    } else if (NNConstants.NN_PTANH.equalsIgnoreCase(actFunc)) {
        expected = ActivationPTANH.class;
    } else {
        expected = ActivationSigmoid.class;
    }
    // exact class match, as before: subclasses of the expected activation are not accepted
    return expected == activation.getClass();
}
Use of org.encog.ml.BasicML in project shifu by ShifuML.
Class ModelSpecLoaderUtils, method loadSubModelSpec.
/**
 * Loads a sub-model from the location described by {@code fileStatus}.
 * <p>
 * Model files and optional sub-model config files are discovered by {@code getModelsAlgAndSpecFiles}. Model files are
 * loaded in ascending file-name order. If a sub-model ModelConfig file is found it is used instead of
 * {@code modelConfig} for the returned spec; if a sub-model ColumnConfig file is found it is used instead of
 * {@code columnConfigList}.
 *
 * @param modelConfig
 *            - {@link ModelConfig}, need this, since the model file may exist in HDFS; also the default config of
 *            the returned spec
 * @param columnConfigList
 *            - List of {@link ColumnConfig}, the default column configs of the returned spec
 * @param fileStatus
 *            - {@link FileStatus} of the sub-model location; its file name becomes the sub-model name
 * @param sourceType
 *            - {@link SourceType}, HDFS or Local?
 * @param gbtConvertToProb
 *            - convert to probability or not for gbt model
 * @param gbtScoreConvertStrategy
 *            - gbt score conversion strategy
 * @return {@link ModelSpec} for sub-model, or {@code null} if no model file is found for it
 * @throws IOException
 *             propagated from listing the sub-model files or loading models/configs
 */
private static ModelSpec loadSubModelSpec(ModelConfig modelConfig, List<ColumnConfig> columnConfigList,
        FileStatus fileStatus, RawSourceData.SourceType sourceType, Boolean gbtConvertToProb,
        String gbtScoreConvertStrategy) throws IOException {
    FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType);
    String subModelName = fileStatus.getPath().getName();
    List<FileStatus> modelFileStats = new ArrayList<FileStatus>();
    // filled by getModelsAlgAndSpecFiles: [0] = sub-model ModelConfig file, [1] = sub-model ColumnConfig file
    FileStatus[] subConfigs = new FileStatus[2];
    ALGORITHM algorithm = getModelsAlgAndSpecFiles(fileStatus, sourceType, modelFileStats, subConfigs);
    if (CollectionUtils.isEmpty(modelFileStats)) {
        return null;
    }

    // load models in a deterministic (file name) order
    Collections.sort(modelFileStats, new Comparator<FileStatus>() {
        @Override
        public int compare(FileStatus fa, FileStatus fb) {
            return fa.getPath().getName().compareTo(fb.getPath().getName());
        }
    });
    List<BasicML> models = new ArrayList<BasicML>(modelFileStats.size());
    for (FileStatus f : modelFileStats) {
        models.add(loadModel(modelConfig, f.getPath(), fs, gbtConvertToProb, gbtScoreConvertStrategy));
    }

    FileStatus subModelConfigFile = subConfigs[0];
    ModelConfig subModelConfig = modelConfig;
    if (subModelConfigFile != null) {
        subModelConfig = CommonUtils.loadModelConfig(subModelConfigFile.getPath().toString(), sourceType);
    }
    FileStatus subColumnConfigFile = subConfigs[1];
    List<ColumnConfig> subColumnConfigList = columnConfigList;
    if (subColumnConfigFile != null) {
        subColumnConfigList = CommonUtils.loadColumnConfigList(subColumnConfigFile.getPath().toString(), sourceType);
    }
    return new ModelSpec(subModelName, subModelConfig, subColumnConfigList, algorithm, models);
}
Use of org.encog.ml.BasicML in project shifu by ShifuML.
Class ModelSpecLoaderUtils, method loadBasicModels.
/**
 * Load basic models by configuration.
 * <p>
 * For the generic and TensorFlow algorithms (compared case-insensitively) the regular Shifu model loader is bypassed:
 * the models found by {@code findGenericModels} are loaded via {@code loadGenericModels}. For every other algorithm,
 * each model file returned by {@code locateBasicModels} is loaded with
 * {@link #loadModel(ModelConfig, Path, FileSystem, boolean, String)}.
 *
 * @param modelConfig
 *            model config; its algorithm selects the loading path
 * @param evalConfig
 *            eval config, passed on to the model-locating helpers
 * @param sourceType
 *            source type, used to resolve the file system
 * @param gbtConvertToProb
 *            convert gbt score to prob or not
 * @param gbtScoreConvertStrategy
 *            specify how to convert gbt raw score
 * @return list of models; empty if no model file is located for a non-generic algorithm
 * @throws IOException
 *             if any IO exception in reading model file.
 * @throws IllegalArgumentException
 *             if the model algorithm is invalid (NOTE(review): raised by helper methods, not directly here — confirm)
 * @throws IllegalStateException
 *             if not HDFS or LOCAL source type or algorithm not supported, or if the algorithm is generic/TensorFlow
 *             and no such model can be found
 */
public static List<BasicML> loadBasicModels(ModelConfig modelConfig, EvalConfig evalConfig, SourceType sourceType,
        boolean gbtConvertToProb, String gbtScoreConvertStrategy) throws IOException {
    List<BasicML> models = new ArrayList<BasicML>();
    FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType);

    String algorithm = modelConfig.getAlgorithm();
    // generic or TensorFlow models bypass the Shifu model loader procedure
    if (Constants.GENERIC.equalsIgnoreCase(algorithm) || Constants.TENSORFLOW.equalsIgnoreCase(algorithm)) {
        List<FileStatus> genericModelConfigs = findGenericModels(modelConfig, evalConfig, sourceType);
        if (genericModelConfigs.isEmpty()) {
            // IllegalStateException is a RuntimeException, so callers catching RuntimeException are unaffected
            throw new IllegalStateException("Load generic model failed.");
        }
        loadGenericModels(modelConfig, genericModelConfigs, sourceType, models);
        log.debug("return generic model {}", models.size());
        return models;
    }

    List<FileStatus> modelFileStats = locateBasicModels(modelConfig, evalConfig, sourceType);
    if (CollectionUtils.isNotEmpty(modelFileStats)) {
        for (FileStatus fst : modelFileStats) {
            models.add(loadModel(modelConfig, fst.getPath(), fs, gbtConvertToProb, gbtScoreConvertStrategy));
        }
    }
    return models;
}
Use of org.encog.ml.BasicML in project shifu by ShifuML.
Class ModelSpecLoaderUtils, method loadModel.
/**
 * Loads a model from an existing model path, choosing the format by the file name suffix.
 * <ul>
 * <li>Names ending with the lower-cased LR algorithm name: the first line is parsed with {@code LR.loadFromString};
 * if that fails, the file is reopened and read with Encog persistence instead.</li>
 * <li>Names ending with the lower-cased RF or GBT algorithm name: read with {@code TreeModel.loadFromStream}.</li>
 * <li>Anything else is treated as an NN model: gzip content is read with {@code NNModel.loadFromStream}, other
 * content with Encog persistence.</li>
 * </ul>
 *
 * @param modelConfig
 *            model config (NOTE(review): not read by this method body)
 * @param modelPath
 *            the path to store model
 * @param fs
 *            file system used to store model
 * @param gbtConvertToProb
 *            convert gbt score to prob or not (only used for RF/GBT model files)
 * @param gbtScoreConvertStrategy
 *            specify how to convert gbt raw score (only used for RF/GBT model files)
 * @return the loaded model, or {@code null} if {@code modelPath} does not exist
 * @throws IOException
 *             if checking whether {@code modelPath} exists fails; failures while opening or parsing the file are
 *             wrapped in a {@code ShifuException} with {@code ERROR_FAIL_TO_LOAD_MODEL_FILE} instead
 */
public static BasicML loadModel(ModelConfig modelConfig, Path modelPath, FileSystem fs, boolean gbtConvertToProb, String gbtScoreConvertStrategy) throws IOException {
if (!fs.exists(modelPath)) {
// no such existing model, return null.
return null;
}
// we have to register PersistBasicFloatNetwork so Encog persistence can load BasicFloatNetwork models;
// this registration runs on every call.
PersistorRegistry.getInstance().add(new PersistBasicFloatNetwork());
FSDataInputStream stream = null;
BufferedReader br = null;
try {
stream = fs.open(modelPath);
if (modelPath.getName().endsWith(LogisticRegressionContants.LR_ALG_NAME.toLowerCase())) {
// LR model: the model is expected on the first line (read with the platform default charset)
br = new BufferedReader(new InputStreamReader(stream));
try {
return LR.loadFromString(br.readLine());
} catch (Exception e) {
// The first line is not an LR model string. NOTE(review): presumably a local LR model persisted in Encog
// format — confirm. Closing br also closes the underlying stream, so the file is reopened into `stream`,
// which the finally block then closes. The LR parse exception is discarded.
IOUtils.closeQuietly(br);
stream = fs.open(modelPath);
return BasicML.class.cast(EncogDirectoryPersistence.loadObject(stream));
}
} else if (// RF or GBT
modelPath.getName().endsWith(CommonConstants.RF_ALG_NAME.toLowerCase()) || modelPath.getName().endsWith(CommonConstants.GBT_ALG_NAME.toLowerCase())) {
return TreeModel.loadFromStream(stream, gbtConvertToProb, gbtScoreConvertStrategy);
} else {
// NN model; NOTE(review): GzipStreamPair presumably inspects the stream header and hands back a stream that
// still yields the full content — confirm.
GzipStreamPair pair = GzipStreamPair.isGZipFormat(stream);
if (pair.isGzip()) {
return BasicML.class.cast(NNModel.loadFromStream(pair.getInput()));
} else {
return BasicML.class.cast(EncogDirectoryPersistence.loadObject(pair.getInput()));
}
}
} catch (Exception e) {
// any failure while opening or parsing is reported with the offending path
String msg = "the expecting model file is: " + modelPath;
throw new ShifuException(ShifuErrorCode.ERROR_FAIL_TO_LOAD_MODEL_FILE, e, msg);
} finally {
// closeQuietly ignores null and already-closed resources (br may already be closed by the LR fallback)
IOUtils.closeQuietly(br);
IOUtils.closeQuietly(stream);
}
}
Use of org.encog.ml.BasicML in project shifu by ShifuML.
Class NNModelSpecTest, method testModelStructureCompare.
@Test
public void testModelStructureCompare() {
    FlatNetwork base = loadFlatNetworkForCompare("src/test/resources/model/model0.nn");
    FlatNetwork extended = loadFlatNetworkForCompare("src/test/resources/model/model1.nn");
    FlatNetwork different = loadFlatNetworkForCompare("src/test/resources/model/model2.nn");
    FlatNetwork deep = loadFlatNetworkForCompare("src/test/resources/model/model3.nn");

    // model1 extends model0: ordered in both directions, and a network equals itself
    Assert.assertEquals(compareStructure(base, extended), -1);
    Assert.assertEquals(compareStructure(base, base), 0);
    Assert.assertEquals(compareStructure(extended, base), 1);

    // model2 differs from model0: -1 regardless of argument order
    Assert.assertEquals(compareStructure(base, different), -1);
    Assert.assertEquals(compareStructure(different, base), -1);
    Assert.assertEquals(compareStructure(extended, different), 1);
    Assert.assertEquals(compareStructure(different, extended), -1);

    // model3 is deeper than model0
    Assert.assertEquals(compareStructure(deep, base), 1);
    Assert.assertEquals(compareStructure(base, deep), -1);
}

/** Loads an Encog-persisted network from {@code path} and returns its flat representation. */
private static FlatNetwork loadFlatNetworkForCompare(String path) {
    BasicML loaded = BasicML.class.cast(EncogDirectoryPersistence.loadObject(new File(path)));
    return ((BasicNetwork) loaded).getFlat();
}

/** Compares two flat networks using a fresh {@link NNStructureComparator}. */
private static int compareStructure(FlatNetwork left, FlatNetwork right) {
    return new NNStructureComparator().compare(left, right);
}
Aggregations