
Example 1 with DimensionRenamer

Use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer in project vespa by vespa-engine.

Class TensorFlowImporter, method findDimensionNames.

/**
 * Find dimension names to avoid excessive renaming while evaluating the model.
 */
private static void findDimensionNames(TensorFlowModel model, OperationIndex index) {
    DimensionRenamer renamer = new DimensionRenamer();
    // First pass: collect naming constraints from every signature output.
    for (TensorFlowModel.Signature signature : model.signatures().values()) {
        for (String output : signature.outputs().values()) {
            addDimensionNameConstraints(index.get(output), renamer);
        }
    }
    // Resolve all collected constraints at once.
    renamer.solve();
    // Second pass: apply the solved names to the same outputs.
    for (TensorFlowModel.Signature signature : model.signatures().values()) {
        for (String output : signature.outputs().values()) {
            renameDimensions(index.get(output), renamer);
        }
    }
}
Also used: DimensionRenamer (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer)

Example 2 with DimensionRenamer

Use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer in project vespa by vespa-engine.

Class DimensionRenamerTest, method testMnistRenaming.

@Test
public void testMnistRenaming() {
    DimensionRenamer renamer = new DimensionRenamer();
    renamer.addDimension("first_dimension_of_x");
    renamer.addDimension("second_dimension_of_x");
    renamer.addDimension("first_dimension_of_w");
    renamer.addDimension("second_dimension_of_w");
    renamer.addDimension("first_dimension_of_b");
    // which dimension to join on matmul
    renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null);
    // other dimensions in matmul can't be equal
    renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null);
    // for efficiency, put dimension to join on innermost
    renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null);
    renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null);
    // bias
    renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null);
    renamer.solve();
    String firstDimensionOfXName = renamer.dimensionNameOf("first_dimension_of_x").get();
    String secondDimensionOfXName = renamer.dimensionNameOf("second_dimension_of_x").get();
    String firstDimensionOfWName = renamer.dimensionNameOf("first_dimension_of_w").get();
    String secondDimensionOfWName = renamer.dimensionNameOf("second_dimension_of_w").get();
    String firstDimensionOfBName = renamer.dimensionNameOf("first_dimension_of_b").get();
    assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfXName) < 0);
    assertTrue(firstDimensionOfWName.compareTo(secondDimensionOfWName) > 0);
    assertTrue(secondDimensionOfXName.compareTo(firstDimensionOfWName) == 0);
    assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfWName) < 0);
    assertTrue(secondDimensionOfWName.compareTo(firstDimensionOfBName) == 0);
}
Also used: DimensionRenamer (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer), Test (org.junit.Test)
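
The test above registers dimensions, adds pairwise constraints, calls solve(), and then reads the assigned names back. As a rough illustration of what such a solver has to do, here is a minimal standalone sketch. It is not Vespa's implementation: the class TinyRenamer, its methods, and the "d" + number naming are all hypothetical, and it uses a plain backtracking search that is only practical for a handful of dimensions.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiPredicate;

// Hypothetical stand-in for a dimension renamer; not the Vespa class.
class TinyRenamer {

    // One binary constraint between two named dimensions.
    private static final class Constraint {
        final String from;
        final String to;
        final BiPredicate<Integer, Integer> test;

        Constraint(String from, String to, BiPredicate<Integer, Integer> test) {
            this.from = from;
            this.to = to;
            this.test = test;
        }
    }

    private final List<String> dimensions = new ArrayList<>();
    private final List<Constraint> constraints = new ArrayList<>();
    private final Map<String, Integer> assignment = new LinkedHashMap<>();

    void addDimension(String name) {
        if (!dimensions.contains(name)) dimensions.add(name);
    }

    void addConstraint(String from, String to, BiPredicate<Integer, Integer> test) {
        addDimension(from);
        addDimension(to);
        constraints.add(new Constraint(from, to, test));
    }

    static boolean equalTo(Integer a, Integer b) { return a.equals(b); }
    static boolean lessThan(Integer a, Integer b) { return a < b; }
    static boolean greaterThan(Integer a, Integer b) { return a > b; }

    // Backtracking search: tries the values 0..n-1 for each dimension in turn.
    boolean solve() {
        assignment.clear();
        return assign(0);
    }

    private boolean assign(int index) {
        if (index == dimensions.size()) return true;
        String dimension = dimensions.get(index);
        for (int value = 0; value < dimensions.size(); ++value) {
            assignment.put(dimension, value);
            if (consistent() && assign(index + 1)) return true;
        }
        assignment.remove(dimension); // no value worked, let the caller try another
        return false;
    }

    // True if no constraint with both ends assigned is violated.
    private boolean consistent() {
        for (Constraint c : constraints) {
            Integer a = assignment.get(c.from);
            Integer b = assignment.get(c.to);
            if (a != null && b != null && !c.test.test(a, b)) return false;
        }
        return true;
    }

    // Names sort in numeric order only while fewer than ten values are in use.
    Optional<String> dimensionNameOf(String dimension) {
        Integer value = assignment.get(dimension);
        return value == null ? Optional.empty() : Optional.of("d" + value);
    }

    public static void main(String[] args) {
        TinyRenamer renamer = new TinyRenamer();
        renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", TinyRenamer::equalTo);
        renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", TinyRenamer::lessThan);
        renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", TinyRenamer::lessThan);
        renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", TinyRenamer::greaterThan);
        renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", TinyRenamer::equalTo);
        System.out.println("solved: " + renamer.solve());
        System.out.println("x0 -> " + renamer.dimensionNameOf("first_dimension_of_x"));
    }
}
```

The sketch checks constraints only once both ends have a value, so partial assignments are pruned as early as a violation becomes visible. A production solver would likely use something more scalable than exhaustive search.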

Example 3 with DimensionRenamer

Use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer in project vespa by vespa-engine.

Class Join, method addDimensionNameConstraints.

@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
    // Skip unless both input types are available.
    if (!allInputTypesPresent(2)) {
        return;
    }
    OrderedTensorType a = largestInput().type().get();
    OrderedTensorType b = smallestInput().type().get();
    // Align b's dimensions with the trailing dimensions of a.
    int sizeDifference = a.rank() - b.rank();
    for (int i = 0; i < b.rank(); ++i) {
        String bDim = b.dimensions().get(i).name();
        String aDim = a.dimensions().get(i + sizeDifference).name();
        // Each aligned pair is constrained to get the same name.
        renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
    }
}
Also used: OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType), DimensionRenamer (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer)
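
The loop in Example 3 pairs dimension i of the smaller input with dimension i + sizeDifference of the larger one, so the two inputs are matched from their last dimensions backwards. The standalone sketch below isolates that index arithmetic on plain lists of names. The class TrailingAlignment and the method alignTrailing are hypothetical helpers for illustration only and use none of the Vespa types.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper; illustrates only the index arithmetic of the loop.
class TrailingAlignment {

    /**
     * Pairs every dimension of the smaller list with the dimension of the
     * larger list that sits at the same distance from the end.
     * Each pair is returned as {largerName, smallerName}.
     */
    static List<String[]> alignTrailing(List<String> larger, List<String> smaller) {
        if (smaller.size() > larger.size()) {
            throw new IllegalArgumentException("smaller must not have more dimensions than larger");
        }
        int sizeDifference = larger.size() - smaller.size();
        List<String[]> pairs = new ArrayList<>();
        for (int i = 0; i < smaller.size(); ++i) {
            pairs.add(new String[] { larger.get(i + sizeDifference), smaller.get(i) });
        }
        return pairs;
    }

    public static void main(String[] args) {
        List<String> a = Arrays.asList("a0", "a1", "a2"); // rank 3
        List<String> b = Arrays.asList("b0", "b1");       // rank 2
        for (String[] pair : alignTrailing(a, b)) {
            // In Example 3 each such pair would become an equality constraint.
            System.out.println(pair[0] + " == " + pair[1]);
        }
    }
}
```

With a rank-3 and a rank-2 input, the leading dimension of the larger input is left unpaired and the remaining two are matched in order.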

Aggregations

DimensionRenamer (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer): 3
OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType): 1
Test (org.junit.Test): 1