Usage of com.tencent.angel.ml.math2.matrix.RowBasedMatrix in the Angel project by Tencent.
Class MatrixFormatImpl, method initMatrix:
/**
 * Builds a row-based matrix whose shape and row type come from the given file metadata.
 *
 * <p>Every row id that appears in the row metadata of at least one partition gets a row. The
 * element counts reported for that row by all partitions are summed, and the sum is passed to
 * {@code initRow}. Row ids that no partition mentions are not set by this method.
 *
 * @param matrixFilesMeta metadata describing the matrix files (row type, dimensions and
 *     per-partition row metadata)
 * @return the matrix created by {@code rbMatrix}, with one initialized row per row id seen
 */
public Matrix initMatrix(MatrixFilesMeta matrixFilesMeta) {
  // Total element count of each row, summed over every partition that holds part of it.
  Int2LongOpenHashMap elemNumPerRow = new Int2LongOpenHashMap();
  for (MatrixPartitionMeta partitionMeta : matrixFilesMeta.getPartMetas().values()) {
    for (Map.Entry<Integer, RowPartitionMeta> rowMeta : partitionMeta.getRowMetas().entrySet()) {
      elemNumPerRow.addTo(rowMeta.getKey(), rowMeta.getValue().getElementNum());
    }
  }

  RowType rowType = RowType.valueOf(matrixFilesMeta.getRowType());
  RowBasedMatrix result = rbMatrix(rowType, matrixFilesMeta.getRow(), matrixFilesMeta.getCol());

  // fastIterator may reuse one entry object, so each entry is consumed immediately.
  for (ObjectIterator<Int2LongMap.Entry> it = elemNumPerRow.int2LongEntrySet().fastIterator();
      it.hasNext(); ) {
    Int2LongMap.Entry rowEntry = it.next();
    result.setRow(
        rowEntry.getIntKey(),
        initRow(rowType, matrixFilesMeta.getCol(), rowEntry.getLongValue()));
  }
  return result;
}
Usage of com.tencent.angel.ml.math2.matrix.RowBasedMatrix in the Angel project by Tencent.
Class DotMatrixExecutor, method apply:
/**
 * Multiplies {@code mat1} by {@code mat2}, optionally transposing either operand, by dispatching
 * on the concrete matrix classes.
 *
 * <p>Supported combinations:
 *
 * <ul>
 *   <li>{@code BlasDoubleMatrix x BlasDoubleMatrix} and {@code BlasFloatMatrix x BlasFloatMatrix}:
 *       delegated to {@code applyParallel} when {@code parallel} is {@code true}, otherwise to the
 *       sequential {@code apply} overload. A {@code null} flag is treated as {@code false}.
 *   <li>Dense (Blas) x row-based and row-based x dense for the specific pairs handled below,
 *       delegated to the matching typed {@code apply} overload. {@code parallel} is ignored.
 *   <li>{@code RowBasedMatrix x RowBasedMatrix}: only {@code mat1 * mat2^T} ({@code trans1 ==
 *       false}, {@code trans2 == true}). The result is a dense matrix with
 *       {@code mat1.getNumRows()} rows and {@code mat2.getNumRows()} columns. Column {@code i} is
 *       {@code mat1.dot(mat2.getRow(i))}.
 * </ul>
 *
 * @param mat1 left operand
 * @param trans1 whether {@code mat1} is transposed
 * @param mat2 right operand
 * @param trans2 whether {@code mat2} is transposed
 * @param parallel whether the dense x dense cases should use the parallel implementation;
 *     {@code null} means sequential
 * @return the product matrix
 * @throws AngelException if the operand classes, transpose flags or element types are not
 *     supported
 */
public static Matrix apply(Matrix mat1, boolean trans1, Matrix mat2, boolean trans2, Boolean parallel) {
  // Auto-unboxing a null Boolean would throw NullPointerException; treat null as "sequential".
  boolean runParallel = Boolean.TRUE.equals(parallel);

  if (mat1 instanceof BlasDoubleMatrix && mat2 instanceof BlasDoubleMatrix) {
    if (runParallel) {
      return applyParallel((BlasDoubleMatrix) mat1, trans1, (BlasDoubleMatrix) mat2, trans2);
    } else {
      return apply((BlasDoubleMatrix) mat1, trans1, (BlasDoubleMatrix) mat2, trans2);
    }
  } else if (mat1 instanceof BlasFloatMatrix && mat2 instanceof BlasFloatMatrix) {
    if (runParallel) {
      return applyParallel((BlasFloatMatrix) mat1, trans1, (BlasFloatMatrix) mat2, trans2);
    } else {
      return apply((BlasFloatMatrix) mat1, trans1, (BlasFloatMatrix) mat2, trans2);
    }
  } else if (mat1 instanceof BlasDoubleMatrix && mat2 instanceof RBIntDoubleMatrix) {
    return apply((BlasDoubleMatrix) mat1, trans1, (RBIntDoubleMatrix) mat2, trans2);
  } else if (mat1 instanceof BlasDoubleMatrix && mat2 instanceof RBLongDoubleMatrix) {
    return apply((BlasDoubleMatrix) mat1, trans1, (RBLongDoubleMatrix) mat2, trans2);
  } else if (mat1 instanceof BlasFloatMatrix && mat2 instanceof RBIntFloatMatrix) {
    return apply((BlasFloatMatrix) mat1, trans1, (RBIntFloatMatrix) mat2, trans2);
  } else if (mat1 instanceof BlasFloatMatrix && mat2 instanceof RBLongFloatMatrix) {
    return apply((BlasFloatMatrix) mat1, trans1, (RBLongFloatMatrix) mat2, trans2);
  } else if (mat1 instanceof RBIntDoubleMatrix && mat2 instanceof BlasDoubleMatrix) {
    return apply((RBIntDoubleMatrix) mat1, trans1, (BlasDoubleMatrix) mat2, trans2);
  } else if (mat1 instanceof RBIntFloatMatrix && mat2 instanceof BlasFloatMatrix) {
    return apply((RBIntFloatMatrix) mat1, trans1, (BlasFloatMatrix) mat2, trans2);
  } else if (mat1 instanceof RowBasedMatrix && mat2 instanceof RowBasedMatrix) {
    if (!trans1 && trans2) {
      // mat1 * mat2^T: each row of mat2 yields one output column.
      int outputRow = mat1.getNumRows();
      int outputCol = mat2.getNumRows();
      // NOTE(review): the element type is taken from row 0 of each operand only; this assumes
      // row 0 exists, is initialized, and all rows share its storage type — confirm with callers.
      RowType type1 = mat1.getRow(0).getStorage().getType();
      RowType type2 = mat2.getRow(0).getStorage().getType();
      if (type1.isDouble() && type2.isDouble()) {
        BlasDoubleMatrix res = MFactory.denseDoubleMatrix(outputRow, outputCol);
        for (int i = 0; i < outputCol; i++) {
          Vector row = mat2.getRow(i);
          Vector col = mat1.dot(row);
          res.setCol(i, col);
        }
        return res;
      } else if (type1.isFloat() && type2.isFloat()) {
        BlasFloatMatrix res = MFactory.denseFloatMatrix(outputRow, outputCol);
        for (int i = 0; i < outputCol; i++) {
          Vector row = mat2.getRow(i);
          Vector col = mat1.dot(row);
          res.setCol(i, col);
        }
        return res;
      } else {
        throw new AngelException("the operation is not supported!");
      }
    } else {
      throw new AngelException("the operation is not supported!");
    }
  } else {
    throw new AngelException("the operation is not supported!");
  }
}
Aggregations