Usage of ml.shifu.shifu.container.obj.ModelTrainConf.ALGORITHM in project shifu by ShifuML.
Class ComboModelProcessor, method validate:
/**
 * Validate the combo algorithm list supplied by the user and, on success, store the parsed algorithms in
 * {@code this.comboAlgs} in the order they were given.
 *
 * <p>The input is split with {@code algorithms.split(ALG_DELIMITER)} (note: {@link String#split} treats the
 * delimiter as a regular expression). Each token, after trimming surrounding whitespace, must be the exact name
 * of a {@link ModelTrainConf.ALGORITHM} constant. At least three tokens are required: two basic algorithms
 * plus one assembling algorithm.
 *
 * @param algorithms
 *            algorithm list that the user wants to combo
 * @return 0 - success;
 *         1 - the list is null or blank;
 *         2 - fewer than three algorithms were given;
 *         3 - a token is not a supported algorithm name
 */
private int validate(String algorithms) {
    if (StringUtils.isBlank(algorithms)) {
        LOG.error("The combo algorithms should not be empty");
        return 1;
    }

    String[] algs = algorithms.split(ALG_DELIMITER);
    if (algs.length < 3) {
        LOG.error("At least, you should have 2 basic algorithms, and 1 assembling algorithm.");
        return 2;
    }

    this.comboAlgs = new ArrayList<ModelTrainConf.ALGORITHM>();
    for (String alg : algs) {
        ModelTrainConf.ALGORITHM algorithm;
        try {
            // Enum.valueOf never returns null; an unknown name is reported via IllegalArgumentException.
            // Trimming lets users write "NN, LR, GBT" with spaces after the delimiter.
            algorithm = ModelTrainConf.ALGORITHM.valueOf(alg.trim());
        } catch (IllegalArgumentException e) {
            LOG.error("Unsupported algorithm - {}", alg);
            return 3;
        }
        this.comboAlgs.add(algorithm);
    }
    return 0;
}
Usage of ml.shifu.shifu.container.obj.ModelTrainConf.ALGORITHM in project shifu by ShifuML.
Class ModelSpecLoaderUtils, method getModelsAlgAndSpecFiles:
/**
 * Scan the immediate children of a (sub-)model directory and detect its model algorithm.
 *
 * <p>The algorithm is taken from the first plain file whose name ends with {@code ".nn"}, {@code ".lr"} or
 * {@code ".gbt"}. Every plain file carrying that algorithm's extension is appended to {@code modelFileStats}.
 * Files whose names equal {@code Constants.MODEL_CONFIG_JSON_FILE_NAME} or
 * {@code Constants.COLUMN_CONFIG_JSON_FILE_NAME} (case-insensitive) are reported through {@code subConfigs[0]}
 * and {@code subConfigs[1]} respectively. Sub-directories are skipped.
 *
 * @param fileStatus
 *            directory to scan
 * @param sourceType
 *            {@link SourceType} used to obtain the file system
 * @param modelFileStats
 *            output list (must not be null); matching model spec files are appended to it
 * @param subConfigs
 *            output array with slots 0 (ModelConfig file) and 1 (ColumnConfig file); a slot is only written
 *            when the corresponding file is present
 * @return the detected {@link ALGORITHM}, or {@code null} if no recognized model spec file was found
 * @throws IOException
 *             if listing the directory fails
 */
@SuppressWarnings("deprecation")
public static ALGORITHM getModelsAlgAndSpecFiles(FileStatus fileStatus, RawSourceData.SourceType sourceType,
        List<FileStatus> modelFileStats, FileStatus[] subConfigs) throws IOException {
    assert modelFileStats != null;
    FileSystem fileSystem = ShifuFileUtils.getFileSystemBySourceType(sourceType);
    FileStatus[] children = fileSystem.listStatus(fileStatus.getPath());
    if (children == null) {
        return null;
    }

    ALGORITHM detected = null;
    for (FileStatus child : children) {
        if (child.isDir()) {
            continue; // only plain files can be model specs or config files
        }
        String name = child.getPath().getName();

        // The first recognized extension fixes the algorithm for the whole directory.
        if (detected == null) {
            for (ALGORITHM candidate : new ALGORITHM[] { ALGORITHM.NN, ALGORITHM.LR, ALGORITHM.GBT }) {
                if (name.endsWith("." + candidate.name().toLowerCase())) {
                    detected = candidate;
                    break;
                }
            }
        }
        if (detected != null && name.endsWith("." + detected.name().toLowerCase())) {
            modelFileStats.add(child);
        }

        if (name.equalsIgnoreCase(Constants.MODEL_CONFIG_JSON_FILE_NAME)) {
            subConfigs[0] = child;
        } else if (name.equalsIgnoreCase(Constants.COLUMN_CONFIG_JSON_FILE_NAME)) {
            subConfigs[1] = child;
        }
    }
    return detected;
}
Usage of ml.shifu.shifu.container.obj.ModelTrainConf.ALGORITHM in project shifu by ShifuML.
Class ShifuCLI, method createNewModel:
/**
 * Create a new model: create the directory and the ModelConfig for the model via {@link CreateModelProcessor}.
 *
 * @param modelSetName
 *            name of the model set to create
 * @param modelType
 *            algorithm name, matched case-insensitively (after trimming) against the {@link ALGORITHM}
 *            constants; {@code null} selects {@link ALGORITHM#NN}
 * @param description
 *            description of the model
 * @return 2 if {@code modelType} does not name a supported algorithm, otherwise the result of
 *         {@link CreateModelProcessor#run()}
 * @throws Exception
 *             if creating or running the {@link CreateModelProcessor} fails
 */
public static int createNewModel(String modelSetName, String modelType, String description) throws Exception {
    ALGORITHM modelAlg = null;
    if (modelType == null) {
        modelAlg = ALGORITHM.NN;
    } else {
        String requested = modelType.trim(); // hoisted: trim once, not per enum constant
        for (ALGORITHM alg : ALGORITHM.values()) {
            if (alg.name().equalsIgnoreCase(requested)) {
                modelAlg = alg;
                break;
            }
        }
    }

    if (modelAlg == null) {
        log.error("Unsupported algorithm - {}", modelType);
        return 2;
    }

    CreateModelProcessor p = new CreateModelProcessor(modelSetName, modelAlg, description);
    return p.run();
}
Usage of ml.shifu.shifu.container.obj.ModelTrainConf.ALGORITHM in project shifu by ShifuML.
Class ModelSpecLoaderUtils, method loadSubModelSpec:
/**
 * Load the {@link ModelSpec} of one sub-model directory.
 *
 * <p>Model spec files are detected with {@code getModelsAlgAndSpecFiles}, sorted by file name and loaded in that
 * order. If the directory contains its own ModelConfig / ColumnConfig JSON files, they are loaded and used for
 * this sub-model instead of the {@code modelConfig} / {@code columnConfigList} passed in.
 *
 * @param modelConfig
 *            - {@link ModelConfig}, needed since the model file may exist in HDFS; also the fallback config
 *            when the sub-model directory has no ModelConfig of its own
 * @param columnConfigList
 *            - List of {@link ColumnConfig}; fallback when the sub-model directory has no ColumnConfig of its own
 * @param fileStatus
 *            - status of the sub-model directory (must not be null); its name becomes the sub-model name
 * @param sourceType
 *            - {@link SourceType}, HDFS or Local?
 * @param gbtConvertToProb
 *            - convert to probability or not for gbt model
 * @param gbtScoreConvertStrategy
 *            - gbt score conversion strategy
 * @return {@link ModelSpec} for the sub-model, or {@code null} if no model spec file was found in the directory
 * @throws IOException
 *             if an I/O error occurs while listing the directory or loading model/config files
 */
private static ModelSpec loadSubModelSpec(ModelConfig modelConfig, List<ColumnConfig> columnConfigList,
        FileStatus fileStatus, RawSourceData.SourceType sourceType, Boolean gbtConvertToProb,
        String gbtScoreConvertStrategy) throws IOException {
    FileSystem fs = ShifuFileUtils.getFileSystemBySourceType(sourceType);
    String subModelName = fileStatus.getPath().getName();

    List<FileStatus> modelFileStats = new ArrayList<FileStatus>();
    // slot 0: sub-model ModelConfig file, slot 1: sub-model ColumnConfig file (filled only if present)
    FileStatus[] subConfigs = new FileStatus[2];
    ALGORITHM algorithm = getModelsAlgAndSpecFiles(fileStatus, sourceType, modelFileStats, subConfigs);

    ModelSpec modelSpec = null;
    if (CollectionUtils.isNotEmpty(modelFileStats)) {
        // Load models in a deterministic order (by file name) regardless of directory listing order.
        Collections.sort(modelFileStats, new Comparator<FileStatus>() {
            @Override
            public int compare(FileStatus fa, FileStatus fb) {
                return fa.getPath().getName().compareTo(fb.getPath().getName());
            }
        });

        List<BasicML> models = new ArrayList<BasicML>();
        for (FileStatus f : modelFileStats) {
            models.add(loadModel(modelConfig, f.getPath(), fs, gbtConvertToProb, gbtScoreConvertStrategy));
        }

        // Sub-model specific configs override the parent ones when present.
        ModelConfig subModelConfig = modelConfig;
        if (subConfigs[0] != null) {
            subModelConfig = CommonUtils.loadModelConfig(subConfigs[0].getPath().toString(), sourceType);
        }
        List<ColumnConfig> subColumnConfigList = columnConfigList;
        if (subConfigs[1] != null) {
            subColumnConfigList = CommonUtils.loadColumnConfigList(subConfigs[1].getPath().toString(), sourceType);
        }

        modelSpec = new ModelSpec(subModelName, subModelConfig, subColumnConfigList, algorithm, models);
    }
    return modelSpec;
}
Aggregations