Use of deepboof.tensors.Tensor_F32 in project BoofCV by lessthanoptimal.
Class CheckBaseImageClassifier, method createDummyNetwork:
/**
 * Fills {@code alg} with placeholder category names and a one-layer network so that the
 * classifier can be exercised without loading a real model.
 *
 * <p>The labels added are "Category 0" through "Category (numCategories-1)". The network is a
 * single {@link FunctionLinear_F32} with {@code numCategories} outputs. It is initialized for a
 * 3 x height x width input, and both of its parameter tensors are filled from {@code rand}.</p>
 *
 * @param alg    classifier to populate; its category list, network and output tensor are modified
 * @param width  input width passed to the linear layer
 * @param height input height passed to the linear layer
 */
private void createDummyNetwork(BaseImageClassifier alg, int width, int height) {
    for (int label = 0; label < numCategories; label++) {
        alg.getCategories().add("Category " + label);
    }

    FunctionLinear_F32 linear = new FunctionLinear_F32(numCategories);
    linear.initialize(3, height, width);

    // Parameter shapes 0 and 1, drawn from rand in that order.
    List<Tensor_F32> randomParams = new ArrayList<>();
    for (int shapeIndex = 0; shapeIndex < 2; shapeIndex++) {
        randomParams.add(TensorFactory_F32.random(rand, false, linear.getParameterShapes().get(shapeIndex)));
    }
    linear.setParameters(randomParams);

    Node<Tensor_F32, Function<Tensor_F32>> onlyLayer = new Node<>();
    onlyLayer.function = linear;

    List<Node<Tensor_F32, Function<Tensor_F32>>> layers = new ArrayList<>();
    layers.add(onlyLayer);
    alg.network = new FunctionSequence<>(layers, Tensor_F32.class);

    // NOTE(review): WI(1, shape) presumably prepends a leading dimension of 1 (a batch of one).
    alg.tensorOutput = new Tensor_F32(WI(1, alg.network.getOutputShape()));
}
Use of deepboof.tensors.Tensor_F32 in project BoofCV by lessthanoptimal.
Class TestDataManipulationOps, method imageToTensor_fail:
/**
 * Checks that {@code DataManipulationOps.imageToTensor} throws a RuntimeException for every
 * output tensor shape that does not fit the 2-band test image.
 */
@Test
public void imageToTensor_fail() {
    Planar<GrayF32> image = new Planar<>(GrayF32.class, 30, 25, 2);

    // NOTE(review): the shapes look like (batch, bands, rows, cols); confirm against imageToTensor.
    int[][] incompatibleShapes = {
            {2, 25, 30},      // only three dimensions
            {0, 3, 25, 30},   // first dimension 0 and second dimension 3
            {1, 2, 26, 30},   // third dimension one too large
            {1, 2, 25, 31},   // fourth dimension one too large
    };

    for (int[] shape : incompatibleShapes) {
        try {
            // The tensor is built inside the try, so an exception from either call satisfies the check.
            DataManipulationOps.imageToTensor(image, new Tensor_F32(shape), 0);
            fail("expected exception");
        } catch (RuntimeException expected) {
            // Expected: the shape was rejected.
        }
    }
}
Use of deepboof.tensors.Tensor_F32 in project BoofCV by lessthanoptimal.
Class ImageClassifierResNet, method loadModel:
/**
 * Parses the ResNet Torch7 model file {@code resnet/resnet-<resnetID>.t7} inside {@code path} and
 * converts it into a BoofCV forward sequence.
 *
 * <p>NOTE(review): this loader is unfinished. The converted sequence is not used yet, and this
 * method assigns no network, output tensor, normalization statistics or category labels. A caller
 * should not expect a usable classifier after this call. TODO: finish the implementation.</p>
 *
 * @param path directory containing the {@code resnet/} model folder
 * @throws IOException if the model file cannot be read or parsed
 */
@Override
public void loadModel(File path) throws IOException {
    // Use concatenation rather than String.format("%d"). Formatter localizes digits, so under some
    // default locales %d emits non-ASCII digits and the file would not be found.
    String name = "resnet/resnet-" + resnetID + ".t7";
    List<TorchObject> list = new ParseBinaryTorch7().parse(new File(path, name));
    TorchGeneric torchSequence = (TorchGeneric) list.get(0);

    // The conversion still runs so that a malformed model fails here, as it did before.
    // TODO: read the normalization mean/std, build the network from seqparam, allocate the
    //       output tensor and load the category labels.
    SequenceAndParameters<Tensor_F32, Function<Tensor_F32>> seqparam = ConvertTorchToBoofForward.convert(torchSequence);
}
Use of deepboof.tensors.Tensor_F32 in project BoofCV by lessthanoptimal.
Class ImageClassifierNiNImageNet, method loadModel:
/**
 * Loads the NiN ImageNet classifier from {@code directory}.
 *
 * <p>The binary Torch7 file {@code nin_bn_final.t7} supplies the network. Its "model" entry is
 * converted and built for a 3 x imageCrop x imageCrop input. Its "transform" entry supplies the
 * "mean" and "std" lists stored in {@code mean} and {@code stdev}. The ASCII Torch7 file
 * {@code synset.t7} supplies the category names, which replace the current contents of
 * {@code categories}.</p>
 *
 * @param directory directory containing nin_bn_final.t7 and synset.t7
 * @throws IOException if either file cannot be read or parsed
 */
@Override
public void loadModel(File directory) throws IOException {
    File modelFile = new File(directory, "nin_bn_final.t7");
    List<TorchObject> parsedModel = new ParseBinaryTorch7().parse(modelFile);
    TorchGeneric model = ((TorchGeneric) parsedModel.get(0)).get("model");

    TorchGeneric transform = model.get("transform");
    mean = torchListToArray((TorchList) transform.get("mean"));
    stdev = torchListToArray((TorchList) transform.get("std"));

    SequenceAndParameters<Tensor_F32, Function<Tensor_F32>> converted = ConvertTorchToBoofForward.convert(model);
    network = converted.createForward(3, imageCrop, imageCrop);
    tensorOutput = new Tensor_F32(WI(1, network.getOutputShape()));

    File labelFile = new File(directory, "synset.t7");
    TorchList labels = (TorchList) new ParseAsciiTorch7().parse(labelFile).get(0);
    categories.clear();
    for (Object entry : labels.list) {
        categories.add(((TorchString) entry).message);
    }
}
Use of deepboof.tensors.Tensor_F32 in project BoofCV by lessthanoptimal.
Class ImageClassifierVggCifar10, method loadModel:
/**
 * Loads the VGG CIFAR-10 classifier from a directory that must contain two files:<br>
 * YuvStatistics.txt<br>
 * model.net<br>
 *
 * <p>{@code YuvStatistics.txt} is read into {@code stats}. Its border name selects the
 * {@link BorderType} for {@code localNorm}, and its kernel values are turned into {@code kernel}.
 * {@code model.net} is parsed as a Torch7 model, and the network is built for a
 * 3 x inputSize x inputSize input.</p>
 *
 * @param directory Directory containing model files
 * @throws IOException Thrown if anything goes wrong while reading data
 */
@Override
public void loadModel(File directory) throws IOException {
    File statisticsFile = new File(directory, "YuvStatistics.txt");
    File modelFile = new File(directory, "model.net");

    stats = DeepModelIO.load(statisticsFile);

    SequenceAndParameters<Tensor_F32, Function<Tensor_F32>> parsed = new ParseBinaryTorch7().parseIntoBoof(modelFile);
    network = parsed.createForward(3, inputSize, inputSize);
    tensorOutput = new Tensor_F32(WI(1, network.getOutputShape()));

    localNorm = new ImageLocalNormalization<>(GrayF32.class, BorderType.valueOf(stats.border));
    kernel = DataManipulationOps.create1D_F32(stats.kernel);
}
Aggregations