Usage of ai.djl.ndarray.types.Shape in project djl-demo by deepjavalibrary.
Class CSVDataset, method encodeData.
/**
 * Converts a URL string into an NDArray encoding of its characters.
 *
 * <p>The result has shape {@code (alphabets.size(), FEATURE_LENGTH)} and starts as all zeros.
 * For each of the first {@code FEATURE_LENGTH} characters of {@code url}, if the character is a
 * key of {@code alphabetsIndex}, the cell at row {@code alphabetsIndex.get(ch)} and column equal
 * to the character's position is set to 1. Characters not in {@code alphabetsIndex} leave their
 * column all zeros. Characters past position {@code FEATURE_LENGTH - 1} are ignored.
 *
 * @param manager NDManager used to allocate the NDArray
 * @param url URL in string format
 * @return the encoded NDArray
 */
private NDArray encodeData(NDManager manager, String url) {
    NDArray encoded = manager.zeros(new Shape(alphabets.size(), FEATURE_LENGTH));
    char[] arrayText = url.toCharArray();
    // Valid column indices are 0..FEATURE_LENGTH-1; stop before writing column FEATURE_LENGTH,
    // which would be out of bounds for the shape allocated above.
    for (int i = 0; i < arrayText.length && i < FEATURE_LENGTH; i++) {
        if (alphabetsIndex.containsKey(arrayText[i])) {
            encoded.set(new NDIndex(alphabetsIndex.get(arrayText[i]), i), 1);
        }
    }
    return encoded;
}
Usage of ai.djl.ndarray.types.Shape in project djl-demo by deepjavalibrary.
Class Training, method main.
/**
 * Trains the shoe classifier and saves it, together with its labels, under {@code models/}.
 *
 * <p>Loads an ImageFolder dataset from {@code ut-zap50k-images-square} and splits it 8:2 into
 * training and validation sets. It then trains for {@code EPOCHS} epochs with softmax
 * cross-entropy loss and records the epoch count, validation accuracy and validation loss as
 * model properties.
 *
 * @param args unused
 * @throws IOException if the dataset cannot be read or the model/synset cannot be written
 * @throws TranslateException if data translation fails during training
 */
public static void main(String[] args) throws IOException, TranslateException {
    // the location to save the model
    Path modelDir = Paths.get("models");
    // create ImageFolder dataset from directory
    ImageFolder dataset = initDataset("ut-zap50k-images-square");
    // Split the dataset into a training dataset and a validation dataset
    RandomAccessDataset[] datasets = dataset.randomSplit(8, 2);
    // The loss function scores the model's predictions against the correct answers during
    // training; higher values mean more errors, and training seeks to minimize it.
    Loss loss = Loss.softmaxCrossEntropyLoss();
    // setting training parameters (i.e. hyperparameters)
    TrainingConfig config = setupTrainingConfig(loss);
    try (// empty model instance to hold patterns
            Model model = Models.getModel();
            Trainer trainer = model.newTrainer(config)) {
        // metrics collect and report key performance indicators, like accuracy
        trainer.setMetrics(new Metrics());
        // (batch, channels, height, width) -- presumably NCHW; width must use IMAGE_WIDTH,
        // not IMAGE_HEIGHT, or non-square images would be mis-shaped.
        Shape inputShape = new Shape(1, 3, Models.IMAGE_HEIGHT, Models.IMAGE_WIDTH);
        // initialize trainer with proper input shape
        trainer.initialize(inputShape);
        // find the patterns in data
        EasyTrain.fit(trainer, EPOCHS, datasets[0], datasets[1]);
        // set model properties
        TrainingResult result = trainer.getTrainingResult();
        model.setProperty("Epoch", String.valueOf(EPOCHS));
        // Locale.ROOT keeps '.' as the decimal separator regardless of the JVM default locale,
        // so the saved properties are stable across machines.
        model.setProperty(
                "Accuracy",
                String.format(
                        java.util.Locale.ROOT,
                        "%.5f",
                        result.getValidateEvaluation("Accuracy")));
        model.setProperty(
                "Loss", String.format(java.util.Locale.ROOT, "%.5f", result.getValidateLoss()));
        // save the model after done training for inference later
        // model saved as shoeclassifier-0000.params
        model.save(modelDir, Models.MODEL_NAME);
        // save labels into model directory
        Models.saveSynset(modelDir, dataset.getSynset());
    }
}
Usage of ai.djl.ndarray.types.Shape in project djl-demo by deepjavalibrary.
Class FaceDetectionTranslator, method processInput.
/**
 * Prepares an image for the face-detection network.
 *
 * <p>Records the image's width and height in the {@code width}/{@code height} fields. It then
 * converts the image to a colour NDArray, reorders it from HWC to CHW and reverses the channel
 * axis (RGB to BGR). Finally it casts to float32 if necessary and subtracts the per-channel
 * values (104, 117, 123).
 *
 * @param ctx translator context supplying the NDManager
 * @param input the image to preprocess
 * @return a single-element NDList holding the preprocessed tensor
 */
@Override
public NDList processInput(TranslatorContext ctx, Image input) {
    width = input.getWidth();
    height = input.getHeight();

    // HWC -> CHW, then flip the channel axis so RGB becomes BGR
    NDArray chw =
            input.toNDArray(ctx.getNDManager(), Image.Flag.COLOR).transpose(2, 0, 1).flip(0);

    // The network by default takes float32
    NDArray floats =
            chw.getDataType().equals(DataType.FLOAT32)
                    ? chw
                    : chw.toType(DataType.FLOAT32, false);

    // Shape (3, 1, 1) broadcasts one value per channel across every pixel
    float[] channelMean = {104f, 117f, 123f};
    NDArray centred = floats.sub(ctx.getNDManager().create(channelMean, new Shape(3, 1, 1)));
    return new NDList(centred);
}
Usage of ai.djl.ndarray.types.Shape in project djl by deepjavalibrary.
Class CsvDataset, method toNDList.
/**
 * Builds a one-dimensional NDArray from the selected columns of a CSV record.
 *
 * <p>For each feature in {@code selected}, in order, the column named by the feature is read
 * from {@code record} and handed to that feature's featurizer, which writes into a shared
 * buffer. The buffer's contents become a 1-D array whose length is the buffer's reported
 * length.
 *
 * @param manager NDManager used to create the array
 * @param record the CSV row to read
 * @param selected the features to extract, in output order
 * @return an NDList containing the single 1-D array
 */
protected NDList toNDList(NDManager manager, CSVRecord record, List<Feature> selected) {
    DynamicBuffer buffer = new DynamicBuffer();
    for (Feature f : selected) {
        f.featurizer.featurize(buffer, record.get(f.getName()));
    }
    FloatBuffer data = buffer.getBuffer();
    Shape shape = new Shape(buffer.getLength());
    return new NDList(manager.create(data, shape));
}
Usage of ai.djl.ndarray.types.Shape in project djl by deepjavalibrary.
Class AirfoilRandomAccess, method toNDList.
/**
 * {@inheritDoc}
 *
 * <p>Parses each selected column of {@code record} as a float. When {@code normalize} is set,
 * each value is standardised as {@code (value - mean.get(name)) / std.get(name)}. The values
 * are packed, in {@code selected} order, into a 1-D float NDArray of length
 * {@code selected.size()}.
 *
 * <p>NOTE(review): if {@code mean}/{@code std} lack an entry for a selected feature, the
 * unboxing in the normalisation step presumably throws a NullPointerException -- confirm
 * callers always populate both maps for every selected feature.
 */
@Override
protected NDList toNDList(NDManager manager, CSVRecord record, List<Feature> selected) {
    int length = selected.size();
    // Direct buffer sized for `length` floats, viewed as a FloatBuffer for writing.
    FloatBuffer buf = manager.allocateDirect(length * Float.BYTES).asFloatBuffer();
    for (Feature feature : selected) {
        String name = feature.getName();
        float value = Float.parseFloat(record.get(name));
        if (normalize) {
            value = (value - mean.get(name)) / std.get(name);
        }
        buf.put(value);
    }
    // Reset the position so create() reads from the first element.
    buf.rewind();
    return new NDList(manager.create(buf, new Shape(length)));
}
Aggregations