Use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer in project vespa by vespa-engine.
In class TensorFlowImporter, method findDimensionNames:
/**
 * Picks dimension names for the operations behind every signature output so that
 * evaluating the imported model requires as little dimension renaming as possible.
 * <p>
 * This works in two passes. The first pass registers naming constraints from every
 * output of every signature. The renamer is then solved once. The second pass applies
 * the solved names to those same outputs.
 *
 * @param model the imported model whose signatures define the outputs to process
 * @param index lookup from output name to the operation that produces it
 */
private static void findDimensionNames(TensorFlowModel model, OperationIndex index) {
    DimensionRenamer renamer = new DimensionRenamer();

    // Pass 1: collect constraints from each signature output, in signature order.
    model.signatures().values().stream()
            .flatMap(signature -> signature.outputs().values().stream())
            .forEachOrdered(outputName -> addDimensionNameConstraints(index.get(outputName), renamer));

    // Solve all constraints at once, after every output has contributed.
    renamer.solve();

    // Pass 2: apply the solved names to the same outputs, in the same order.
    model.signatures().values().stream()
            .flatMap(signature -> signature.outputs().values().stream())
            .forEachOrdered(outputName -> renameDimensions(index.get(outputName), renamer));
}
Use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer in project vespa by vespa-engine.
In class DimensionRenamerTest, method testMnistRenaming:
@Test
public void testMnistRenaming() {
    // Logical dimension ids for a matmul of x and w followed by a bias add with b.
    String xFirst = "first_dimension_of_x";
    String xSecond = "second_dimension_of_x";
    String wFirst = "first_dimension_of_w";
    String wSecond = "second_dimension_of_w";
    String bFirst = "first_dimension_of_b";

    DimensionRenamer renamer = new DimensionRenamer();
    for (String dimension : new String[] { xFirst, xSecond, wFirst, wSecond, bFirst }) {
        renamer.addDimension(dimension);
    }

    // The matmul joins x's second dimension with w's first dimension.
    renamer.addConstraint(xSecond, wFirst, DimensionRenamer::equals, null);
    // The remaining matmul dimensions must not collapse into one name.
    renamer.addConstraint(xFirst, wSecond, DimensionRenamer::lesserThan, null);
    // Keep the join dimension innermost for efficient evaluation.
    renamer.addConstraint(xFirst, xSecond, DimensionRenamer::lesserThan, null);
    renamer.addConstraint(wFirst, wSecond, DimensionRenamer::greaterThan, null);
    // The bias lines up with the matmul result's trailing dimension.
    renamer.addConstraint(wSecond, bFirst, DimensionRenamer::equals, null);

    renamer.solve();

    String xFirstName = renamer.dimensionNameOf(xFirst).get();
    String xSecondName = renamer.dimensionNameOf(xSecond).get();
    String wFirstName = renamer.dimensionNameOf(wFirst).get();
    String wSecondName = renamer.dimensionNameOf(wSecond).get();
    String bFirstName = renamer.dimensionNameOf(bFirst).get();

    // Every constraint above must hold for the solved names.
    assertTrue(xFirstName.compareTo(xSecondName) < 0);
    assertTrue(wFirstName.compareTo(wSecondName) > 0);
    assertTrue(xSecondName.equals(wFirstName));
    assertTrue(xFirstName.compareTo(wSecondName) < 0);
    assertTrue(wSecondName.equals(bFirstName));
}
Use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer in project vespa by vespa-engine.
In class Join, method addDimensionNameConstraints:
/**
 * Requires the aligned trailing dimensions of the two join inputs to share a name.
 * Does nothing until the types of both inputs are known.
 *
 * @param renamer the renamer that receives one equality constraint per aligned pair
 */
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
    if (allInputTypesPresent(2)) {
        OrderedTensorType larger = largestInput().type().get();
        OrderedTensorType smaller = smallestInput().type().get();
        // Pair dimensions from the end: the last dimension of `smaller` matches the last
        // dimension of `larger`, and so on. The first `offset` dimensions of `larger` get
        // no constraint. NOTE(review): assumes larger.rank() >= smaller.rank().
        int offset = larger.rank() - smaller.rank();
        for (int largerIndex = offset; largerIndex < larger.rank(); largerIndex++) {
            String smallerDim = smaller.dimensions().get(largerIndex - offset).name();
            String largerDim = larger.dimensions().get(largerIndex).name();
            renamer.addConstraint(largerDim, smallerDim, DimensionRenamer::equals, this);
        }
    }
}
Aggregations