Search in sources:

Example 1 with RowBasedMatrix

Use of `com.tencent.angel.ml.math2.matrix.RowBasedMatrix` in project angel by Tencent.

Class `MatrixFormatImpl`, method `initMatrix`.

/**
 * Creates a row-based matrix shaped according to {@code matrixFilesMeta}.
 *
 * <p>The same row id may be listed by several partitions' row metas; their element counts are
 * summed. Every row id that appears at least once is created via {@code initRow} with that
 * total. Row ids that no partition lists are never set on the returned matrix.
 *
 * @param matrixFilesMeta file metadata supplying the row type, the row/column dimensions and
 *     the per-partition row metas
 * @return the matrix produced by {@code rbMatrix}, with one row set per discovered row id
 */
public Matrix initMatrix(MatrixFilesMeta matrixFilesMeta) {
    Map<Integer, MatrixPartitionMeta> partitions = matrixFilesMeta.getPartMetas();
    Int2LongOpenHashMap totalElemsPerRow = new Int2LongOpenHashMap();
    for (MatrixPartitionMeta partition : partitions.values()) {
        for (Map.Entry<Integer, RowPartitionMeta> rowMeta : partition.getRowMetas().entrySet()) {
            // Accumulate rather than overwrite: the same row id can occur in several partitions.
            totalElemsPerRow.addTo(rowMeta.getKey(), rowMeta.getValue().getElementNum());
        }
    }

    RowType rowType = RowType.valueOf(matrixFilesMeta.getRowType());
    RowBasedMatrix result = rbMatrix(rowType, matrixFilesMeta.getRow(), matrixFilesMeta.getCol());
    for (ObjectIterator<Int2LongMap.Entry> it = totalElemsPerRow.int2LongEntrySet().fastIterator();
        it.hasNext(); ) {
        Int2LongMap.Entry rowAndCount = it.next();
        result.setRow(
            rowAndCount.getIntKey(),
            initRow(rowType, matrixFilesMeta.getCol(), rowAndCount.getLongValue()));
    }
    return result;
}
Also used: RowType (com.tencent.angel.ml.matrix.RowType), RowBasedMatrix (com.tencent.angel.ml.math2.matrix.RowBasedMatrix), Int2LongOpenHashMap (it.unimi.dsi.fastutil.ints.Int2LongOpenHashMap), Int2LongMap (it.unimi.dsi.fastutil.ints.Int2LongMap), Map (java.util.Map)

Example 2 with RowBasedMatrix

Use of `com.tencent.angel.ml.math2.matrix.RowBasedMatrix` in project angel by Tencent.

Class `DotMatrixExecutor`, method `apply`.

/**
 * Multiplies {@code mat1} by {@code mat2}, dispatching on the runtime types of both operands.
 *
 * <p>Supported combinations:
 * <ul>
 *   <li>{@code BlasDoubleMatrix} x {@code BlasDoubleMatrix} and {@code BlasFloatMatrix} x
 *       {@code BlasFloatMatrix}: uses the parallel overload when {@code parallel} is true,
 *       otherwise the sequential one.</li>
 *   <li>{@code BlasDoubleMatrix} x {@code RBIntDoubleMatrix}/{@code RBLongDoubleMatrix} and
 *       {@code BlasFloatMatrix} x {@code RBIntFloatMatrix}/{@code RBLongFloatMatrix}.</li>
 *   <li>{@code RBIntDoubleMatrix} x {@code BlasDoubleMatrix} and {@code RBIntFloatMatrix} x
 *       {@code BlasFloatMatrix}.</li>
 *   <li>{@code RowBasedMatrix} x {@code RowBasedMatrix}, only when {@code trans1} is false and
 *       {@code trans2} is true. The result is a dense matrix whose column {@code i} is
 *       {@code mat1.dot(mat2.getRow(i))}. The element types are taken from row 0 of each
 *       operand and must both be double or both be float.</li>
 * </ul>
 * The {@code parallel} flag only affects the Blas x Blas cases.
 *
 * @param mat1 left operand
 * @param trans1 whether {@code mat1} is to be transposed
 * @param mat2 right operand
 * @param trans2 whether {@code mat2} is to be transposed
 * @param parallel whether to use the parallel Blas x Blas kernel; {@code null} is treated as
 *     {@code false}
 * @return the product matrix
 * @throws AngelException if the operand types, transposition flags or element types are not
 *     supported
 */
public static Matrix apply(Matrix mat1, boolean trans1, Matrix mat2, boolean trans2, Boolean parallel) {
    // `parallel` is boxed: unboxing a null would throw NullPointerException, so normalise it once
    // and treat null as "sequential".
    boolean useParallel = Boolean.TRUE.equals(parallel);
    if (mat1 instanceof BlasDoubleMatrix && mat2 instanceof BlasDoubleMatrix) {
        if (useParallel) {
            return applyParallel((BlasDoubleMatrix) mat1, trans1, (BlasDoubleMatrix) mat2, trans2);
        } else {
            return apply((BlasDoubleMatrix) mat1, trans1, (BlasDoubleMatrix) mat2, trans2);
        }
    } else if (mat1 instanceof BlasFloatMatrix && mat2 instanceof BlasFloatMatrix) {
        if (useParallel) {
            return applyParallel((BlasFloatMatrix) mat1, trans1, (BlasFloatMatrix) mat2, trans2);
        } else {
            return apply((BlasFloatMatrix) mat1, trans1, (BlasFloatMatrix) mat2, trans2);
        }
    } else if (mat1 instanceof BlasDoubleMatrix && mat2 instanceof RBIntDoubleMatrix) {
        return apply((BlasDoubleMatrix) mat1, trans1, (RBIntDoubleMatrix) mat2, trans2);
    } else if (mat1 instanceof BlasDoubleMatrix && mat2 instanceof RBLongDoubleMatrix) {
        return apply((BlasDoubleMatrix) mat1, trans1, (RBLongDoubleMatrix) mat2, trans2);
    } else if (mat1 instanceof BlasFloatMatrix && mat2 instanceof RBIntFloatMatrix) {
        return apply((BlasFloatMatrix) mat1, trans1, (RBIntFloatMatrix) mat2, trans2);
    } else if (mat1 instanceof BlasFloatMatrix && mat2 instanceof RBLongFloatMatrix) {
        return apply((BlasFloatMatrix) mat1, trans1, (RBLongFloatMatrix) mat2, trans2);
    } else if (mat1 instanceof RBIntDoubleMatrix && mat2 instanceof BlasDoubleMatrix) {
        return apply((RBIntDoubleMatrix) mat1, trans1, (BlasDoubleMatrix) mat2, trans2);
    } else if (mat1 instanceof RBIntFloatMatrix && mat2 instanceof BlasFloatMatrix) {
        return apply((RBIntFloatMatrix) mat1, trans1, (BlasFloatMatrix) mat2, trans2);
    } else if (mat1 instanceof RowBasedMatrix && mat2 instanceof RowBasedMatrix) {
        if (!trans1 && trans2) {
            // Result has one row per row of mat1 and one column per row of mat2.
            int outputRow = mat1.getNumRows();
            int outputCol = mat2.getNumRows();
            // Element precision is inferred from the first row's storage of each operand.
            RowType type1 = mat1.getRow(0).getStorage().getType();
            RowType type2 = mat2.getRow(0).getStorage().getType();
            if (type1.isDouble() && type2.isDouble()) {
                BlasDoubleMatrix res = MFactory.denseDoubleMatrix(outputRow, outputCol);
                for (int i = 0; i < outputCol; i++) {
                    Vector row = mat2.getRow(i);
                    Vector col = mat1.dot(row);
                    res.setCol(i, col);
                }
                return res;
            } else if (type1.isFloat() && type2.isFloat()) {
                BlasFloatMatrix res = MFactory.denseFloatMatrix(outputRow, outputCol);
                for (int i = 0; i < outputCol; i++) {
                    Vector row = mat2.getRow(i);
                    Vector col = mat1.dot(row);
                    res.setCol(i, col);
                }
                return res;
            } else {
                throw new AngelException("the operation is not supported!");
            }
        } else {
            throw new AngelException("the operation is not supported!");
        }
    } else {
        throw new AngelException("the operation is not supported!");
    }
}
Also used: RBLongDoubleMatrix (com.tencent.angel.ml.math2.matrix.RBLongDoubleMatrix), AngelException (com.tencent.angel.exception.AngelException), RBLongFloatMatrix (com.tencent.angel.ml.math2.matrix.RBLongFloatMatrix), RBIntFloatMatrix (com.tencent.angel.ml.math2.matrix.RBIntFloatMatrix), RowBasedMatrix (com.tencent.angel.ml.math2.matrix.RowBasedMatrix), RBIntDoubleMatrix (com.tencent.angel.ml.math2.matrix.RBIntDoubleMatrix), RowType (com.tencent.angel.ml.matrix.RowType), BlasDoubleMatrix (com.tencent.angel.ml.math2.matrix.BlasDoubleMatrix), BlasFloatMatrix (com.tencent.angel.ml.math2.matrix.BlasFloatMatrix), IntLongVector (com.tencent.angel.ml.math2.vector.IntLongVector), IntFloatVector (com.tencent.angel.ml.math2.vector.IntFloatVector), LongDoubleVector (com.tencent.angel.ml.math2.vector.LongDoubleVector), Vector (com.tencent.angel.ml.math2.vector.Vector), LongFloatVector (com.tencent.angel.ml.math2.vector.LongFloatVector), IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector), IntIntVector (com.tencent.angel.ml.math2.vector.IntIntVector), IntDummyVector (com.tencent.angel.ml.math2.vector.IntDummyVector)

Aggregations

RowBasedMatrix (com.tencent.angel.ml.math2.matrix.RowBasedMatrix): 2, RowType (com.tencent.angel.ml.matrix.RowType): 2, AngelException (com.tencent.angel.exception.AngelException): 1, BlasDoubleMatrix (com.tencent.angel.ml.math2.matrix.BlasDoubleMatrix): 1, BlasFloatMatrix (com.tencent.angel.ml.math2.matrix.BlasFloatMatrix): 1, RBIntDoubleMatrix (com.tencent.angel.ml.math2.matrix.RBIntDoubleMatrix): 1, RBIntFloatMatrix (com.tencent.angel.ml.math2.matrix.RBIntFloatMatrix): 1, RBLongDoubleMatrix (com.tencent.angel.ml.math2.matrix.RBLongDoubleMatrix): 1, RBLongFloatMatrix (com.tencent.angel.ml.math2.matrix.RBLongFloatMatrix): 1, IntDoubleVector (com.tencent.angel.ml.math2.vector.IntDoubleVector): 1, IntDummyVector (com.tencent.angel.ml.math2.vector.IntDummyVector): 1, IntFloatVector (com.tencent.angel.ml.math2.vector.IntFloatVector): 1, IntIntVector (com.tencent.angel.ml.math2.vector.IntIntVector): 1, IntLongVector (com.tencent.angel.ml.math2.vector.IntLongVector): 1, LongDoubleVector (com.tencent.angel.ml.math2.vector.LongDoubleVector): 1, LongFloatVector (com.tencent.angel.ml.math2.vector.LongFloatVector): 1, Vector (com.tencent.angel.ml.math2.vector.Vector): 1, Int2LongMap (it.unimi.dsi.fastutil.ints.Int2LongMap): 1, Int2LongOpenHashMap (it.unimi.dsi.fastutil.ints.Int2LongOpenHashMap): 1, Map (java.util.Map): 1