
Example 1 with Function

Use of deepboof.Function in project BoofCV by lessthanoptimal.

From the class CheckBaseImageClassifier, method createDummyNetwork.

private void createDummyNetwork(BaseImageClassifier alg, int width, int height) {
    // Register one placeholder label per output category
    for (int i = 0; i < numCategories; i++) {
        alg.getCategories().add("Category " + i);
    }
    // A single linear function with numCategories outputs, initialized for a 3 x height x width input
    FunctionLinear_F32 function = new FunctionLinear_F32(numCategories);
    function.initialize(3, height, width);
    // Fill the function's first two parameter tensors with random values of the declared shapes
    List<Tensor_F32> parameters = new ArrayList<>();
    parameters.add(TensorFactory_F32.random(rand, false, function.getParameterShapes().get(0)));
    parameters.add(TensorFactory_F32.random(rand, false, function.getParameterShapes().get(1)));
    function.setParameters(parameters);
    // Wrap the function in a graph node and build a one-node sequence from it
    Node<Tensor_F32, Function<Tensor_F32>> node = new Node<>();
    node.function = function;
    List<Node<Tensor_F32, Function<Tensor_F32>>> sequence = new ArrayList<>();
    sequence.add(node);
    alg.network = new FunctionSequence<>(sequence, Tensor_F32.class);
    // Allocate the output tensor from the network's output shape
    alg.tensorOutput = new Tensor_F32(WI(1, alg.network.getOutputShape()));
}
Also used: Function (deepboof.Function), FunctionLinear_F32 (deepboof.impl.forward.standard.FunctionLinear_F32), Node (deepboof.graph.Node), ArrayList (java.util.ArrayList), Tensor_F32 (deepboof.tensors.Tensor_F32)
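
For readers unfamiliar with what a single linear function in such a dummy network computes, here is a minimal plain-Java sketch. It assumes the layer flattens its input and evaluates y = W*x + b, with one weight row and one bias value per category. The names LinearLayerSketch and forward are hypothetical, and this is not DeepBoof's implementation.

```java
import java.util.Arrays;

// Hypothetical illustration only; not part of DeepBoof or BoofCV.
class LinearLayerSketch {
    /**
     * Computes y = W*x + b for a flattened input.
     * weights has one row per output category; bias has one entry per category.
     */
    static float[] forward(float[][] weights, float[] bias, float[] input) {
        float[] output = new float[bias.length];
        for (int m = 0; m < bias.length; m++) {
            float sum = bias[m];
            for (int d = 0; d < input.length; d++) {
                sum += weights[m][d] * input[d];
            }
            output[m] = sum;
        }
        return output;
    }

    public static void main(String[] args) {
        // Two categories, two input values (assumed toy numbers)
        float[][] weights = {{1f, 2f}, {3f, 4f}};
        float[] bias = {0.5f, -0.5f};
        float[] input = {1f, 1f};
        System.out.println(Arrays.toString(forward(weights, bias, input)));
    }
}
```

Under this assumption, a 3-band image of size height x width would be flattened to 3 * height * width input values, so each weight row would have that many entries.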

Example 2 with Function

Use of deepboof.Function in project BoofCV by lessthanoptimal.

From the class ImageClassifierResNet, method loadModel.

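@Override
public void loadModel(File path) throws IOException {
    // Build the model file name from the ResNet variant ID and parse the binary Torch7 file
    String name = String.format("resnet/resnet-%d.t7", resnetID);
    List<TorchObject> list = new ParseBinaryTorch7().parse(new File(path, name));
    TorchGeneric torchSequence = (TorchGeneric) list.get(0);
    // Normalization statistics are disabled here; torchNorm is not defined in this method
    // mean = torchListToArray( (TorchList)torchNorm.get("mean"));
    // stdev = torchListToArray( (TorchList)torchNorm.get("std"));
    // Convert the Torch sequence; the result is not used further in this method as shown
    SequenceAndParameters<Tensor_F32, Function<Tensor_F32>> seqparam = ConvertTorchToBoofForward.convert(torchSequence);
    // Debug output left in the original source
    System.out.println("Bread here");
}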
Also used: TorchGeneric (deepboof.io.torch7.struct.TorchGeneric), Function (deepboof.Function), TorchObject (deepboof.io.torch7.struct.TorchObject), ParseBinaryTorch7 (deepboof.io.torch7.ParseBinaryTorch7), Tensor_F32 (deepboof.tensors.Tensor_F32), File (java.io.File)
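
The way the method above locates its model file can be tried out in isolation with only the standard library. The sketch below mirrors the String.format and two-argument File constructor calls from the excerpt. The class name ModelPathSketch, the "models" root directory, and the ID 18 are assumptions chosen purely for illustration.

```java
import java.io.File;

// Hypothetical illustration only; not part of DeepBoof or BoofCV.
class ModelPathSketch {
    /** Resolves the Torch7 model file for a given ResNet ID below a root directory. */
    static File resolveModelFile(File root, int resnetID) {
        // Same relative-name pattern as in the excerpt
        String name = String.format("resnet/resnet-%d.t7", resnetID);
        // The child path is resolved against the parent directory
        return new File(root, name);
    }

    public static void main(String[] args) {
        // "models" and 18 are assumed example values
        File modelFile = resolveModelFile(new File("models"), 18);
        System.out.println(modelFile.getPath());
    }
}
```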

Example 3 with Function

Use of deepboof.Function in project BoofCV by lessthanoptimal.

From the class ImageClassifierNiNImageNet, method loadModel.

@Override
public void loadModel(File directory) throws IOException {
    // Parse the binary Torch7 file and pull out its "model" entry
    List<TorchObject> list = new ParseBinaryTorch7().parse(new File(directory, "nin_bn_final.t7"));
    TorchGeneric torchSequence = ((TorchGeneric) list.get(0)).get("model");
    // Read the normalization statistics stored under "transform"
    TorchGeneric torchNorm = torchSequence.get("transform");
    mean = torchListToArray((TorchList) torchNorm.get("mean"));
    stdev = torchListToArray((TorchList) torchNorm.get("std"));
    // Convert the Torch model and create a forward network for a 3 x imageCrop x imageCrop input
    SequenceAndParameters<Tensor_F32, Function<Tensor_F32>> seqparam = ConvertTorchToBoofForward.convert(torchSequence);
    network = seqparam.createForward(3, imageCrop, imageCrop);
    tensorOutput = new Tensor_F32(WI(1, network.getOutputShape()));
    // Load the category labels from the ASCII Torch7 synset file
    TorchList torchCategories = (TorchList) new ParseAsciiTorch7().parse(new File(directory, "synset.t7")).get(0);
    categories.clear();
    for (int i = 0; i < torchCategories.list.size(); i++) {
        categories.add(((TorchString) torchCategories.list.get(i)).message);
    }
}
Also used: Function (deepboof.Function), ParseBinaryTorch7 (deepboof.io.torch7.ParseBinaryTorch7), Tensor_F32 (deepboof.tensors.Tensor_F32), File (java.io.File), ParseAsciiTorch7 (deepboof.io.torch7.ParseAsciiTorch7)
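
The excerpt loads mean and stdev arrays but does not show how they are applied to an image. A common convention, assumed here rather than taken from BoofCV, is to subtract the band mean from each pixel and divide by the band standard deviation, with one value per color band. The sketch below illustrates that convention in plain Java. BandNormalizationSketch and normalize are hypothetical names.

```java
import java.util.Arrays;

// Hypothetical illustration only; not part of DeepBoof or BoofCV.
class BandNormalizationSketch {
    /**
     * Normalizes pixels stored band by band as pixels[band][i].
     * Returns a new array and leaves the input untouched.
     */
    static float[][] normalize(float[][] pixels, float[] mean, float[] stdev) {
        float[][] out = new float[pixels.length][];
        for (int band = 0; band < pixels.length; band++) {
            out[band] = new float[pixels[band].length];
            for (int i = 0; i < pixels[band].length; i++) {
                out[band][i] = (pixels[band][i] - mean[band]) / stdev[band];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Two bands with two pixels each (assumed toy values)
        float[][] pixels = {{0.5f, 1.0f}, {0.25f, 0.75f}};
        float[] mean = {0.5f, 0.25f};
        float[] stdev = {0.5f, 0.25f};
        System.out.println(Arrays.deepToString(normalize(pixels, mean, stdev)));
    }
}
```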

Example 4 with Function

Use of deepboof.Function in project BoofCV by lessthanoptimal.

From the class ImageClassifierVggCifar10, method loadModel.

/**
 * Expects there to be two files in the provided directory:<br>
 * YuvStatistics.txt<br>
 * model.net<br>
 *
 * @param directory Directory containing model files
 * @throws IOException Thrown if anything goes wrong while reading data
 */
@Override
public void loadModel(File directory) throws IOException {
    // Load the statistics file that sits alongside the model
    stats = DeepModelIO.load(new File(directory, "YuvStatistics.txt"));
    // Parse the Torch7 model and create a forward network for a 3 x inputSize x inputSize input
    SequenceAndParameters<Tensor_F32, Function<Tensor_F32>> sequence = new ParseBinaryTorch7().parseIntoBoof(new File(directory, "model.net"));
    network = sequence.createForward(3, inputSize, inputSize);
    tensorOutput = new Tensor_F32(WI(1, network.getOutputShape()));
    // Set up local image normalization using the border type and kernel from the statistics
    BorderType type = BorderType.valueOf(stats.border);
    localNorm = new ImageLocalNormalization<>(GrayF32.class, type);
    kernel = DataManipulationOps.create1D_F32(stats.kernel);
}
Also used: Function (deepboof.Function), GrayF32 (boofcv.struct.image.GrayF32), ParseBinaryTorch7 (deepboof.io.torch7.ParseBinaryTorch7), Tensor_F32 (deepboof.tensors.Tensor_F32), File (java.io.File), BorderType (boofcv.core.image.border.BorderType)
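
Because the Javadoc above states that two specific files are expected in the directory, a caller may want to check for them before invoking loadModel. The following is a minimal sketch of such a check using only java.io.File. The class ModelFilesCheckSketch and the method missingFiles are hypothetical helpers, not part of BoofCV.

```java
import java.io.File;
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration only; not part of DeepBoof or BoofCV.
class ModelFilesCheckSketch {
    // File names taken from the Javadoc of loadModel
    static final String[] EXPECTED = {"YuvStatistics.txt", "model.net"};

    /** Returns the names of expected model files that are absent from the directory. */
    static List<String> missingFiles(File directory) {
        List<String> missing = new ArrayList<>();
        for (String name : EXPECTED) {
            if (!new File(directory, name).isFile()) {
                missing.add(name);
            }
        }
        return missing;
    }

    public static void main(String[] args) {
        // Uses the first argument as the model directory, or the working directory if none is given
        File directory = new File(args.length > 0 ? args[0] : ".");
        List<String> missing = missingFiles(directory);
        if (missing.isEmpty()) {
            System.out.println("All model files found in " + directory);
        } else {
            System.out.println("Missing model files: " + missing);
        }
    }
}
```

Reporting every missing file at once, rather than failing on the first, gives a more useful message than the IOException the parser would otherwise raise.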

Aggregations

Function (deepboof.Function): 4
Tensor_F32 (deepboof.tensors.Tensor_F32): 4
ParseBinaryTorch7 (deepboof.io.torch7.ParseBinaryTorch7): 3
File (java.io.File): 3
BorderType (boofcv.core.image.border.BorderType): 1
GrayF32 (boofcv.struct.image.GrayF32): 1
Node (deepboof.graph.Node): 1
FunctionLinear_F32 (deepboof.impl.forward.standard.FunctionLinear_F32): 1
ParseAsciiTorch7 (deepboof.io.torch7.ParseAsciiTorch7): 1
TorchGeneric (deepboof.io.torch7.struct.TorchGeneric): 1
TorchObject (deepboof.io.torch7.struct.TorchObject): 1
ArrayList (java.util.ArrayList): 1