Use of `com.tencent.angel.tools.ModelLineConvert` in project Angel by Tencent.
Class `ModelMergeAndConverterTask`, method `run`:
/**
 * Merges and converts every requested model from the load directory into the save directory.
 *
 * <p>The source directory is read from {@code AngelConf.ANGEL_LOAD_MODEL_PATH} and the target
 * directory from {@code AngelConf.ANGEL_SAVE_MODEL_PATH}; both must be set. The line serde is
 * the class named by {@code angel.modelconverts.serde.class} (default
 * {@code TextModelLineConvert}), created through its no-arg constructor. Model names come from
 * the comma-separated {@code angel.modelconverts.model.names}; blanks around names are ignored
 * and empty entries are skipped. If that key is unset, every sub-directory of the load directory
 * is treated as one model. Each model {@code name} is then passed to
 * {@code ModelMergeAndConvert.convert(conf, loadDir/name, saveDir/name, serde)}.
 *
 * @param taskContext task context (not used by this method)
 * @throws AngelException wrapping any failure (missing configuration, serde construction
 *     errors, no models found, or an error raised during conversion)
 */
@Override
public void run(TaskContext taskContext) throws AngelException {
  try {
    // Get input path, output path
    String modelLoadDir = conf.get(AngelConf.ANGEL_LOAD_MODEL_PATH);
    if (modelLoadDir == null) {
      throw new InvalidParameterException(
          "convert source path " + AngelConf.ANGEL_LOAD_MODEL_PATH + " must be set");
    }
    String convertedModelSaveDir = conf.get(AngelConf.ANGEL_SAVE_MODEL_PATH);
    if (convertedModelSaveDir == null) {
      // Name the key that is actually missing (the save path, not the load path).
      throw new InvalidParameterException(
          "converted model save path " + AngelConf.ANGEL_SAVE_MODEL_PATH + " must be set");
    }

    // Init serde. asSubclass checks the configured class really is a ModelLineConvert up
    // front, instead of relying on an unchecked cast that only fails after instantiation.
    String modelSerdeClass =
        conf.get("angel.modelconverts.serde.class", TextModelLineConvert.class.getName());
    Class<? extends ModelLineConvert> funcClass =
        Class.forName(modelSerdeClass).asSubclass(ModelLineConvert.class);
    // getConstructor() only returns public constructors, which would make setAccessible
    // pointless; getDeclaredConstructor() also allows a non-public no-arg constructor.
    Constructor<? extends ModelLineConvert> constructor = funcClass.getDeclaredConstructor();
    constructor.setAccessible(true);
    ModelLineConvert serde = constructor.newInstance();

    // Parse need convert model names, if not set, we will convert all models in input directory
    String needConvertModelNames = conf.get("angel.modelconverts.model.names");
    List<String> modelNameList = new ArrayList<>();
    if (needConvertModelNames == null) {
      LOG.info("we will convert all models save in " + modelLoadDir);
      Path modelLoadPath = new Path(modelLoadDir);
      FileSystem fs = modelLoadPath.getFileSystem(conf);
      FileStatus[] fileStatus = fs.listStatus(modelLoadPath);
      if (fileStatus != null) {
        for (FileStatus status : fileStatus) {
          // Each sub-directory of the load directory is treated as one model.
          if (status.isDirectory()) {
            modelNameList.add(status.getPath().getName());
          }
        }
      }
    } else {
      for (String name : needConvertModelNames.split(",")) {
        // Tolerate "a, b" and stray commas: surrounding blanks would otherwise end up in the
        // model path, and an empty name would make us convert the load directory itself.
        String trimmed = name.trim();
        if (!trimmed.isEmpty()) {
          modelNameList.add(trimmed);
        }
      }
    }
    if (modelNameList.isEmpty()) {
      throw new IOException("can not find any models in " + modelLoadDir);
    }

    for (String modelName : modelNameList) {
      LOG.info("===================start to convert model " + modelName);
      ModelMergeAndConvert.convert(conf, modelLoadDir + Path.SEPARATOR + modelName,
          convertedModelSaveDir + Path.SEPARATOR + modelName, serde);
      LOG.info("===================end to convert model " + modelName);
    }
  } catch (Throwable x) {
    LOG.fatal("convert model failed, ", x);
    throw new AngelException(x);
  }
}
Use of `com.tencent.angel.tools.ModelLineConvert` in project Angel by Tencent.
Class `ModelConverterTask`, method `run`:
/**
 * Converts every requested model from the load directory into the save directory.
 *
 * <p>The source directory is read from {@code AngelConf.ANGEL_LOAD_MODEL_PATH} and the target
 * directory from {@code AngelConf.ANGEL_SAVE_MODEL_PATH}; both must be set. The line serde is
 * the class named by {@code angel.modelconverts.serde.class} (default
 * {@code TextModelLineConvert}), created through its no-arg constructor. Model names come from
 * the comma-separated {@code angel.modelconverts.model.names}; blanks around names are ignored
 * and empty entries are skipped. If that key is unset, every sub-directory of the load directory
 * is treated as one model. Each model {@code name} is then passed to
 * {@code ModelConverter.convert(conf, loadDir/name, saveDir/name, serde)}.
 *
 * @param taskContext task context (not used by this method)
 * @throws AngelException wrapping any failure (missing configuration, serde construction
 *     errors, no models found, or an error raised during conversion)
 */
@Override
public void run(TaskContext taskContext) throws AngelException {
  try {
    // Get input path, output path
    String modelLoadDir = conf.get(AngelConf.ANGEL_LOAD_MODEL_PATH);
    if (modelLoadDir == null) {
      throw new InvalidParameterException(
          "convert source path " + AngelConf.ANGEL_LOAD_MODEL_PATH + " must be set");
    }
    String convertedModelSaveDir = conf.get(AngelConf.ANGEL_SAVE_MODEL_PATH);
    if (convertedModelSaveDir == null) {
      // Name the key that is actually missing (the save path, not the load path).
      throw new InvalidParameterException(
          "converted model save path " + AngelConf.ANGEL_SAVE_MODEL_PATH + " must be set");
    }

    // Init serde. asSubclass checks the configured class really is a ModelLineConvert up
    // front, instead of relying on an unchecked cast that only fails after instantiation.
    String modelSerdeClass =
        conf.get("angel.modelconverts.serde.class", TextModelLineConvert.class.getName());
    Class<? extends ModelLineConvert> funcClass =
        Class.forName(modelSerdeClass).asSubclass(ModelLineConvert.class);
    // getConstructor() only returns public constructors, which would make setAccessible
    // pointless; getDeclaredConstructor() also allows a non-public no-arg constructor.
    Constructor<? extends ModelLineConvert> constructor = funcClass.getDeclaredConstructor();
    constructor.setAccessible(true);
    ModelLineConvert serde = constructor.newInstance();

    // Parse need convert model names, if not set, we will convert all models in input directory
    String needConvertModelNames = conf.get("angel.modelconverts.model.names");
    List<String> modelNameList = new ArrayList<>();
    if (needConvertModelNames == null) {
      LOG.info("we will convert all models save in " + modelLoadDir);
      Path modelLoadPath = new Path(modelLoadDir);
      FileSystem fs = modelLoadPath.getFileSystem(conf);
      FileStatus[] fileStatus = fs.listStatus(modelLoadPath);
      if (fileStatus != null) {
        for (FileStatus status : fileStatus) {
          // Each sub-directory of the load directory is treated as one model.
          if (status.isDirectory()) {
            modelNameList.add(status.getPath().getName());
          }
        }
      }
    } else {
      for (String name : needConvertModelNames.split(",")) {
        // Tolerate "a, b" and stray commas: surrounding blanks would otherwise end up in the
        // model path, and an empty name would make us convert the load directory itself.
        String trimmed = name.trim();
        if (!trimmed.isEmpty()) {
          modelNameList.add(trimmed);
        }
      }
    }
    if (modelNameList.isEmpty()) {
      throw new IOException("can not find any models in " + modelLoadDir);
    }

    for (String modelName : modelNameList) {
      LOG.info("===================start to convert model " + modelName);
      ModelConverter.convert(conf, modelLoadDir + Path.SEPARATOR + modelName,
          convertedModelSaveDir + Path.SEPARATOR + modelName, serde);
      LOG.info("===================end to convert model " + modelName);
    }
  } catch (Throwable x) {
    LOG.fatal("convert model failed, ", x);
    throw new AngelException(x);
  }
}
Aggregations