Usage of com.yahoo.tensor.Tensor in the vespa project by vespa-engine.
From class Rename, method evaluate.
/**
 * Evaluates the argument and returns an equivalent tensor whose dimensions have been renamed
 * according to {@code fromToMap}; dimensions without a mapping keep their name.
 */
@Override
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
    Tensor input = argument.evaluate(context);
    TensorType inputType = input.type();
    TensorType outputType = type(inputType);

    // targetPositions[d] is the index, within outputType, of the dimension that input
    // dimension d is renamed to.
    int rank = inputType.dimensions().size();
    int[] targetPositions = new int[rank];
    for (int d = 0; d < rank; d++) {
        String oldName = inputType.dimensions().get(d).name();
        String mappedName = fromToMap.getOrDefault(oldName, oldName);
        targetPositions[d] = outputType.indexOfDimension(mappedName).get();
    }

    // Copy every cell over, rewriting only its address.
    Tensor.Builder result = Tensor.Builder.of(outputType);
    Iterator<Tensor.Cell> cells = input.cellIterator();
    while (cells.hasNext()) {
        Map.Entry<TensorAddress, Double> entry = cells.next();
        result.cell(rename(entry.getKey(), targetPositions), entry.getValue());
    }
    return result.build();
}
Usage of com.yahoo.tensor.Tensor in the vespa project by vespa-engine.
From class JsonReaderTestCase, method testParsingOfIndexedTensorWithCells.
/**
 * Parses a tensor given in the "cells" JSON form into an indexed tensor field and checks both
 * the resulting values and the concrete tensor implementation.
 */
@Test
public void testParsingOfIndexedTensorWithCells() {
    String fieldName = "indexedtensorfield";
    String cellsJson = "{ "
            + " \"cells\": [ "
            + " { \"address\": { \"x\": \"0\", \"y\": \"0\" }, "
            + " \"value\": 2.0 }, "
            + " { \"address\": { \"x\": \"1\", \"y\": \"0\" }, "
            + " \"value\": 3.0 } "
            + " ]"
            + "}";
    Tensor tensor = assertTensorField("{{x:0,y:0}:2.0,{x:1,y:0}:3.0}}",
                                      createPutWithTensor(cellsJson, fieldName),
                                      fieldName);
    // The parsed tensor is expected to be an IndexedTensor: this matters for performance.
    assertTrue(tensor instanceof IndexedTensor);
}
Usage of com.yahoo.tensor.Tensor in the vespa project by vespa-engine.
From class ExpandDims, method lazyGetType.
/**
 * Computes the result type of this ExpandDims operation: the type of the first input with one
 * extra indexed dimension of size 1 inserted at the position given by the constant, scalar
 * axis input. The names of inserted dimensions are recorded in {@code expandDimensions}.
 *
 * <p>The axis follows TensorFlow's ExpandDims convention. It must lie in
 * {@code [-rank - 1, rank]}, and a negative axis counts from the end, so {@code -1} appends the
 * new dimension after the last existing one.
 *
 * @return the expanded type, or null if not all input types are available yet
 * @throws IllegalArgumentException if the axis input is not a constant scalar, or is out of range
 */
@Override
protected OrderedTensorType lazyGetType() {
    if (!allInputTypesPresent(2)) {
        return null;
    }
    TensorFlowOperation axisOperation = inputs().get(1);
    if (!axisOperation.getConstantValue().isPresent()) {
        throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " + "axis must be a constant.");
    }
    Tensor axis = axisOperation.getConstantValue().get().asTensor();
    if (axis.type().rank() != 0) {
        throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " + "axis argument must be a scalar.");
    }
    OrderedTensorType inputType = inputs.get(0).type().get();
    int rank = inputType.dimensions().size();
    int requestedAxis = (int) axis.asDouble();

    // A negative axis counts from the end: -1 -> rank (append), -(rank + 1) -> 0 (prepend).
    // (Previously this computed rank - axis, which is always out of range for axis < 0.)
    int dimensionToInsert = requestedAxis < 0 ? rank + 1 + requestedAxis : requestedAxis;
    if (dimensionToInsert < 0 || dimensionToInsert > rank) {
        throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " +
                                           "axis " + requestedAxis + " is out of range for input of rank " + rank + ".");
    }

    OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
    expandDimensions = new ArrayList<>();
    int dimensionIndex = 0;
    for (TensorType.Dimension dimension : inputType.dimensions()) {
        if (dimensionIndex == dimensionToInsert) {
            addExpandDimension(typeBuilder, dimensionIndex);
        }
        typeBuilder.add(dimension);
        dimensionIndex++;
    }
    // Inserting at position == rank means appending after every existing dimension,
    // a position the loop above never visits.
    if (dimensionToInsert == rank) {
        addExpandDimension(typeBuilder, dimensionToInsert);
    }
    return typeBuilder.build();
}

/** Adds a new size-1 indexed dimension, named after this operation and its position, and records its name. */
private void addExpandDimension(OrderedTensorType.Builder typeBuilder, int position) {
    String name = String.format("%s_%d", vespaName(), position);
    expandDimensions.add(name);
    typeBuilder.add(TensorType.Dimension.indexed(name, 1L));
}
Aggregations