Use of ai.djl.translate.TranslateException in project djl by deepjavalibrary.
From the class FixedBucketSamplerTest, method testFixedBucketSampler:
@Test
public void testFixedBucketSampler() throws IOException, TranslateException {
    FixedBucketSampler fixedBucketSampler = new FixedBucketSampler(10, 10, false);
    TatoebaEnglishFrenchDataset dataset =
            TatoebaEnglishFrenchDataset.builder()
                    .setSampling(fixedBucketSampler)
                    .optDataBatchifier(
                            PaddingStackBatchifier.builder()
                                    .optIncludeValidLengths(true)
                                    .addPad(0, 0, (m) -> m.zeros(new Shape(1)), 10)
                                    .build())
                    .optLabelBatchifier(
                            PaddingStackBatchifier.builder()
                                    .optIncludeValidLengths(true)
                                    .addPad(0, 0, (m) -> m.ones(new Shape(1)), 10)
                                    .build())
                    .optLimit(200)
                    .build();
    dataset.prepare();
    Iterator<List<Long>> iterator = fixedBucketSampler.sample(dataset);
    long count = 0;
    Set<Long> indicesSet = new HashSet<>();
    while (iterator.hasNext()) {
        List<Long> indices = iterator.next();
        indicesSet.addAll(indices);
        count += indices.size();
    }
    Assert.assertEquals(count, dataset.size());
    Assert.assertEquals(indicesSet.size(), dataset.size());

    fixedBucketSampler = new FixedBucketSampler(10, 5, true);
    iterator = fixedBucketSampler.sample(dataset);
    count = 0;
    indicesSet.clear();
    while (iterator.hasNext()) {
        List<Long> indices = iterator.next();
        indicesSet.addAll(indices);
        count += indices.size();
    }
    Assert.assertEquals(count, dataset.size());
    Assert.assertEquals(indicesSet.size(), dataset.size());
}
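As a usage note (not part of the test above), a dataset prepared with a FixedBucketSampler can also be consumed as regular DJL batches through an NDManager. The following is a minimal sketch under that assumption; it reuses the same sampler and padding configuration and assumes the dataset's default text configuration is enough to materialize batches:

void iterateBucketedBatches() throws IOException, TranslateException {
    try (NDManager manager = NDManager.newBaseManager()) {
        FixedBucketSampler sampler = new FixedBucketSampler(10, 10, false);
        TatoebaEnglishFrenchDataset dataset =
                TatoebaEnglishFrenchDataset.builder()
                        .setSampling(sampler)
                        .optDataBatchifier(
                                PaddingStackBatchifier.builder()
                                        .optIncludeValidLengths(true)
                                        .addPad(0, 0, (m) -> m.zeros(new Shape(1)), 10)
                                        .build())
                        .optLabelBatchifier(
                                PaddingStackBatchifier.builder()
                                        .optIncludeValidLengths(true)
                                        .addPad(0, 0, (m) -> m.ones(new Shape(1)), 10)
                                        .build())
                        .optLimit(200)
                        .build();
        dataset.prepare();
        for (Batch batch : dataset.getData(manager)) {
            // batches come out bucket by bucket, so padding within a batch stays small
            System.out.println(batch.getData().head().getShape());
            batch.close();
        }
    }
}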
Use of ai.djl.translate.TranslateException in project djl by deepjavalibrary.
From the class TrtTest, method testSerializedEngine:
@Test
public void testSerializedEngine() throws ModelException, IOException, TranslateException {
    Engine engine;
    try {
        engine = Engine.getEngine("TensorRT");
    } catch (Exception ignore) {
        throw new SkipException("Your OS configuration doesn't support TensorRT.");
    }
    Device device = engine.defaultDevice();
    if (!device.isGpu()) {
        throw new SkipException("TensorRT only supports GPU.");
    }
    String sm = CudaUtils.getComputeCapability(device.getDeviceId());
    Criteria<float[], float[]> criteria =
            Criteria.builder()
                    .setTypes(float[].class, float[].class)
                    .optModelPath(Paths.get("src/test/resources/identity_" + sm + ".trt"))
                    .optTranslator(new MyTranslator())
                    .optEngine("TensorRT")
                    .build();
    try (ZooModel<float[], float[]> model = criteria.loadModel();
            Predictor<float[], float[]> predictor = model.newPredictor()) {
        float[] data = new float[] {1, 2, 3, 4};
        float[] ret = predictor.predict(data);
        Assert.assertEquals(ret, data);
    }
}
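The MyTranslator referenced in the Criteria above is not shown in this excerpt. A minimal identity-style translator for float arrays could look like the sketch below; the class name matches the reference, but the actual implementation in the DJL test may differ:

// Hypothetical stand-in for the MyTranslator referenced above.
private static final class MyTranslator implements Translator<float[], float[]> {

    @Override
    public NDList processInput(TranslatorContext ctx, float[] input) {
        // wrap the raw float array in an NDList for the engine
        return new NDList(ctx.getNDManager().create(input));
    }

    @Override
    public float[] processOutput(TranslatorContext ctx, NDList list) {
        // copy the single output tensor back into a float array
        return list.singletonOrThrow().toFloatArray();
    }

    @Override
    public Batchifier getBatchifier() {
        // no batching: the test feeds one input at a time (an assumption)
        return null;
    }
}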
Use of ai.djl.translate.TranslateException in project djl-serving by deepjavalibrary.
From the class WorkerThread, method run:
/** {@inheritDoc} */
@Override
public void run() {
    Thread thread = Thread.currentThread();
    thread.setName(workerName);
    currentThread.set(thread);
    this.state = WorkerState.WORKER_STARTED;
    List<Input> req = null;
    String errorMessage = "Worker shutting down";
    try {
        while (isRunning() && !aggregator.isFinished()) {
            req = aggregator.getRequest();
            if (req != null && !req.isEmpty()) {
                try {
                    List<Output> reply = predictor.batchPredict(req);
                    aggregator.sendResponse(reply);
                } catch (TranslateException e) {
                    logger.warn("Failed to predict", e);
                    aggregator.sendError(e);
                }
            }
            req = null;
        }
    } catch (InterruptedException e) {
        logger.debug("Shutting down the thread .. Scaling down.");
    } catch (Throwable t) {
        logger.error("Server error", t);
        errorMessage = t.getMessage();
    } finally {
        logger.debug("Shutting down worker thread .. {}", currentThread.get().getName());
        currentThread.set(null);
        shutdown(WorkerState.WORKER_STOPPED);
        if (req != null) {
            Exception e = new WlmException(errorMessage);
            aggregator.sendError(e);
        }
    }
}
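Outside of djl-serving, the same error-handling idea (report recoverable TranslateExceptions per batch and keep serving, stop only on interruption) can be sketched with plain DJL types. The queue, callbacks, and method below are hypothetical illustrations, not djl-serving's API:

// Hypothetical sketch of the worker pattern above using only ai.djl types.
void serveLoop(
        Predictor<Input, Output> predictor,
        BlockingQueue<List<Input>> requests, // hypothetical request feed
        Consumer<List<Output>> onReply,      // hypothetical success callback
        Consumer<Exception> onError) {       // hypothetical error callback
    try {
        while (!Thread.currentThread().isInterrupted()) {
            List<Input> batch = requests.take();
            if (batch.isEmpty()) {
                continue;
            }
            try {
                onReply.accept(predictor.batchPredict(batch));
            } catch (TranslateException e) {
                onError.accept(e); // this batch fails, but the worker keeps serving
            }
        }
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt(); // graceful shutdown, as in the worker above
    }
}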
Use of ai.djl.translate.TranslateException in project djl by deepjavalibrary.
From the class BaseModelLoader, method loadModel:
/** {@inheritDoc} */
@Override
@SuppressWarnings("unchecked")
public <I, O> ZooModel<I, O> loadModel(Criteria<I, O> criteria)
        throws IOException, ModelNotFoundException, MalformedModelException {
    Artifact artifact = mrl.match(criteria.getFilters());
    if (artifact == null) {
        throw new ModelNotFoundException("No matching filter found");
    }
    Progress progress = criteria.getProgress();
    Map<String, Object> arguments = artifact.getArguments(criteria.getArguments());
    Map<String, String> options = artifact.getOptions(criteria.getOptions());
    try {
        TranslatorFactory factory = getTranslatorFactory(criteria, arguments);
        Class<I> input = criteria.getInputClass();
        Class<O> output = criteria.getOutputClass();
        if (factory == null || !factory.isSupported(input, output)) {
            factory = defaultFactory;
            if (!factory.isSupported(input, output)) {
                throw new ModelNotFoundException(getFactoryLookupErrorMessage(factory));
            }
        }
        mrl.prepare(artifact, progress);
        if (progress != null) {
            progress.reset("Loading", 2);
            progress.update(1);
        }
        Path modelPath = mrl.getRepository().getResourceDirectory(artifact);
        Path modelDir = Files.isRegularFile(modelPath) ? modelPath.getParent() : modelPath;
        if (modelDir == null) {
            throw new AssertionError("Directory should not be null.");
        }
        loadServingProperties(modelDir, arguments, options);
        Application application = criteria.getApplication();
        if (application != Application.UNDEFINED) {
            arguments.put("application", application.getPath());
        }
        // Engine resolution: prefer the engine from the Criteria, then serving.properties,
        // then one of the model zoo's supported engines (favoring the default engine).
        String engine = criteria.getEngine();
        if (engine == null) {
            // get engine from serving.properties
            engine = (String) arguments.get("engine");
        }
        if (engine == null) {
            ModelZoo modelZoo = ModelZoo.getModelZoo(mrl.getGroupId());
            if (modelZoo != null) {
                String defaultEngine = Engine.getDefaultEngineName();
                for (String supportedEngine : modelZoo.getSupportedEngines()) {
                    if (supportedEngine.equals(defaultEngine)) {
                        engine = supportedEngine;
                        break;
                    } else if (Engine.hasEngine(supportedEngine)) {
                        engine = supportedEngine;
                    }
                }
                if (engine == null) {
                    throw new ModelNotFoundException(
                            "No supported engine available for model zoo: " + modelZoo.getGroupId());
                }
            }
        }
        if (engine != null && !Engine.hasEngine(engine)) {
            throw new ModelNotFoundException(engine + " is not supported");
        }
        String modelName = criteria.getModelName();
        if (modelName == null) {
            modelName = artifact.getName();
        }
        Model model =
                createModel(
                        modelDir, modelName, criteria.getDevice(), criteria.getBlock(), arguments, engine);
        model.load(modelPath, null, options);
        Translator<I, O> translator =
                (Translator<I, O>) factory.newInstance(input, output, model, arguments);
        return new ZooModel<>(model, translator);
    } catch (TranslateException e) {
        throw new ModelNotFoundException("No matching translator found", e);
    } finally {
        if (progress != null) {
            progress.end();
        }
    }
}
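From the caller's side, this code path is typically reached through Criteria.loadModel(). A minimal sketch, assuming an image-classification model is available in the configured model zoo (the application and engine values are illustrative):

void loadFromZoo() throws IOException, ModelException {
    Criteria<Image, Classifications> criteria =
            Criteria.builder()
                    .setTypes(Image.class, Classifications.class)
                    .optApplication(Application.CV.IMAGE_CLASSIFICATION)
                    .optEngine("PyTorch") // optional; otherwise resolved as shown above
                    .optProgress(new ProgressBar())
                    .build();
    try (ZooModel<Image, Classifications> model = criteria.loadModel()) {
        // the loader has matched an artifact, resolved the engine, and picked a translator
    }
}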
Use of ai.djl.translate.TranslateException in project djl by deepjavalibrary.
From the class Predictor, method batchPredict:
/**
 * Predicts a batch for inference.
 *
 * @param inputs a list of inputs
 * @return a list of output objects defined by the user
 * @throws TranslateException if an error occurs during prediction
 */
@SuppressWarnings({"PMD.AvoidRethrowingException", "PMD.IdenticalCatchBranches"})
public List<O> batchPredict(List<I> inputs) throws TranslateException {
    long begin = System.nanoTime();
    try (PredictorContext context = new PredictorContext()) {
        if (!prepared) {
            translator.prepare(context);
            prepared = true;
        }
        Batchifier batchifier = translator.getBatchifier();
        if (batchifier == null) {
            // no batchifier: run each input through the translator and model individually
            List<O> ret = new ArrayList<>(inputs.size());
            for (I input : inputs) {
                timestamp = System.nanoTime();
                begin = timestamp;
                NDList ndList = translator.processInput(context, input);
                preprocessEnd(ndList);
                NDList result = predictInternal(context, ndList);
                predictEnd(result);
                ret.add(translator.processOutput(context, result));
                postProcessEnd(begin);
            }
            return ret;
        }
        timestamp = System.nanoTime();
        NDList inputBatch = processInputs(context, inputs);
        preprocessEnd(inputBatch);
        NDList result = predictInternal(context, inputBatch);
        predictEnd(result);
        List<O> ret = processOutputs(context, result);
        postProcessEnd(begin);
        return ret;
    } catch (TranslateException e) {
        throw e;
    } catch (Exception e) {
        throw new TranslateException(e);
    }
}
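As a usage counterpart, batchPredict is called on a Predictor obtained from a loaded model. The helper below is a hypothetical sketch, not DJL API; it simply shows where the TranslateException surfaces:

// Hypothetical helper: run a whole list of inputs through one predictor.
static <I, O> List<O> predictAll(ZooModel<I, O> model, List<I> inputs) throws TranslateException {
    try (Predictor<I, O> predictor = model.newPredictor()) {
        // any translator or engine failure is reported as a TranslateException
        return predictor.batchPredict(inputs);
    }
}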