Usage of `ai.djl.pytorch.engine.PtSymbolBlock` in the djl project by deepjavalibrary.
From class `IValueTest`, method `testIValueModel`:
/**
 * Loads a TorchScript model whose forward pass takes IValue inputs (a list of longs and two
 * scalar longs) and checks that the returned IValue converts to {@code {0, 1, 2, 3, 4}}.
 *
 * <p>Every IValue is opened in try-with-resources so its native handle is released even when
 * {@code forward} or the assertion throws.
 *
 * @throws IOException if the model archive cannot be downloaded or read
 * @throws ModelException if the model cannot be loaded
 */
@Test
public void testIValueModel() throws IOException, ModelException {
    Criteria<NDList, NDList> criteria =
            Criteria.builder()
                    .setTypes(NDList.class, NDList.class)
                    .optModelUrls("https://resources.djl.ai/test-models/ivalue_jit.zip")
                    .optProgress(new ProgressBar())
                    .build();
    try (ZooModel<NDList, NDList> model = criteria.loadModel()) {
        PtSymbolBlock block = (PtSymbolBlock) model.getBlock();
        // Resources are closed in reverse declaration order, after the assertion runs.
        try (IValue tokens = IValue.listFrom(1, 2, 3);
                IValue cls = IValue.from(0);
                IValue sep = IValue.from(4);
                IValue ret = block.forward(tokens, cls, sep)) {
            long[] actual = ret.toLongArray();
            // TestNG argument order: (actual, expected).
            Assert.assertEquals(actual, new long[] {0, 1, 2, 3, 4});
        }
    }
}
Usage of `ai.djl.pytorch.engine.PtSymbolBlock` in the djl project by deepjavalibrary.
From class `TorchScriptTest`, method `testInputOutput`:
/**
 * Round-trips the traced ResNet-18 TorchScript module through memory. The module is loaded
 * from the model file, serialized with {@code JniUtils.writeModule} into a byte buffer, and
 * loaded again from those bytes.
 *
 * <p>Both loaded blocks are closed with try-with-resources so their native handles are released
 * deterministically. The serialized bytes are asserted to be non-empty so a silent write
 * failure is caught.
 *
 * @throws IOException if the model file cannot be read
 * @throws ModelException if the model cannot be loaded
 */
@Test
public void testInputOutput() throws IOException, ModelException {
    Criteria<NDList, NDList> criteria =
            Criteria.builder()
                    .setTypes(NDList.class, NDList.class)
                    .optModelUrls("djl://ai.djl.pytorch/resnet/0.0.1/traced_resnet18")
                    .optProgress(new ProgressBar())
                    .build();
    try (ZooModel<NDList, NDList> model = criteria.loadModel()) {
        PtNDManager manager = (PtNDManager) model.getNDManager();
        Path modelFile = model.getModelPath().resolve("traced_resnet18.pt");

        ByteArrayOutputStream os = new ByteArrayOutputStream();
        try (InputStream is = Files.newInputStream(modelFile)) {
            try (PtSymbolBlock block = JniUtils.loadModule(manager, is, true, false)) {
                JniUtils.writeModule(block, os, true);
            }
        }
        byte[] serialized = os.toByteArray();
        Assert.assertTrue(serialized.length > 0, "writeModule produced no bytes");

        // NOTE(review): the last flag differs from the first load (true vs false), presumably
        // to match what writeModule(..., true) emits — confirm against JniUtils.
        try (ByteArrayInputStream bis = new ByteArrayInputStream(serialized);
                PtSymbolBlock reloaded = JniUtils.loadModule(manager, bis, true, true)) {
            Assert.assertNotNull(reloaded);
        }
    }
}
Usage of `ai.djl.pytorch.engine.PtSymbolBlock` in the djl project by deepjavalibrary.
From class `JniUtils`, method `loadModule`:
/**
 * Loads a module from {@code path} through the native library onto the manager's device and
 * wraps the returned native handle in a {@link PtSymbolBlock} owned by {@code manager}.
 *
 * @param manager the manager whose device is the load target and which receives the block
 * @param path filesystem location of the module
 * @param mapLocation forwarded unchanged to the native {@code moduleLoad} call
 * @param extraFileKeys forwarded unchanged to the native {@code moduleLoad} call
 * @param extraFileValues forwarded unchanged to the native {@code moduleLoad} call
 * @return a block wrapping the native module handle
 */
public static PtSymbolBlock loadModule(PtNDManager manager, Path path, boolean mapLocation, String[] extraFileKeys, String[] extraFileValues) {
    Device target = manager.getDevice();
    String location = path.toString();
    // The native side expects the device as a {deviceType, deviceId} pair.
    int deviceTypeCode = PtDeviceType.toDeviceType(target);
    int deviceIndex = target.getDeviceId();
    int[] deviceSpec = {deviceTypeCode, deviceIndex};
    long moduleHandle =
            PyTorchLibrary.LIB.moduleLoad(
                    location, deviceSpec, mapLocation, extraFileKeys, extraFileValues);
    return new PtSymbolBlock(manager, moduleHandle);
}
Aggregations