Use of com.tencent.angel.ml.math2.MatrixExecutors in the Angel project by Tencent.
Class DotMatrixExecutor, method applyParallel (BlasDoubleMatrix overload).
/**
 * Computes {@code op(mat1) * op(mat2)} in parallel, where {@code op(X)} is the transpose of
 * {@code X} when the corresponding {@code trans} flag is set and {@code X} itself otherwise.
 *
 * <p>When {@code trans1} is set, {@code mat1} is first materialized as a new n x m matrix from
 * {@code transform(mat1)} (presumably its transpose; confirm against {@code transform}).
 * {@code mat2} is never copied; {@code trans2} is forwarded to {@code DotForkJoinOp} instead.
 * The rows of the left operand are split into chunks of about {@code rows / parallelism} rows.
 * The work is submitted as a fork-join op to the shared {@code MatrixExecutors} instance, and this
 * method blocks until that op has completed.
 *
 * @param mat1 left operand (m x n)
 * @param trans1 whether to use the transpose of {@code mat1}
 * @param mat2 right operand (p x q)
 * @param trans2 whether to use the transpose of {@code mat2}
 * @return a newly allocated dense result matrix that carries {@code mat1}'s matrix id and clock
 * @throws ArithmeticException if the number of result elements does not fit in an {@code int}
 */
private static Matrix applyParallel(BlasDoubleMatrix mat1, boolean trans1, BlasDoubleMatrix mat2,
    boolean trans2) {
  int m = mat1.getNumRows(), n = mat1.getNumCols();
  int p = mat2.getNumRows(), q = mat2.getNumCols();
  MatrixExecutors executors = MatrixExecutors.getInstance();

  // Effective operand shapes: op(mat1) is outRows x leftInner, op(mat2) is rightInner x outCols.
  int outRows = trans1 ? n : m;
  int leftInner = trans1 ? m : n;
  int rightInner = trans2 ? q : p;
  int outCols = trans2 ? p : q;
  assert leftInner == rightInner
      : "inner dimensions differ: " + leftInner + " vs " + rightInner;

  // outRows * outCols can exceed Integer.MAX_VALUE even when both inputs fit in arrays
  // (e.g. a tall-thin op(mat1) times a wide op(mat2)). A silent int wrap would allocate a
  // wrongly sized buffer, so fail loudly instead.
  double[] resBlas = new double[Math.multiplyExact(outRows, outCols)];
  BlasDoubleMatrix retMat =
      new BlasDoubleMatrix(mat1.getMatrixId(), mat1.getClock(), outRows, outCols, resBlas);

  // DotForkJoinOp receives no trans1 flag, so the left operand must already be oriented as
  // op(mat1); materialize it only when a transpose is requested.
  BlasDoubleMatrix transMat1 = trans1
      ? new BlasDoubleMatrix(mat1.getMatrixId(), mat1.getClock(), n, m, transform(mat1))
      : mat1;

  // Split the row indices of the left operand into chunks of about rows / parallelism each.
  int subM = Math.max(1, transMat1.getNumRows() / executors.getParallel());
  int[] leftRowOffIndices = splitRowIds(transMat1.getNumRows(), subM);

  // Run the fork-join op on the shared executors and wait for it to finish.
  DotForkJoinOp op = new DotForkJoinOp(transMat1, mat2, retMat, leftRowOffIndices, 0,
      leftRowOffIndices.length, subM, trans2);
  executors.execute(op);
  op.join();
  return retMat;
}
Use of com.tencent.angel.ml.math2.MatrixExecutors in the Angel project by Tencent.
Class DotMatrixExecutor, method applyParallel (BlasFloatMatrix overload).
/**
 * Computes {@code op(mat1) * op(mat2)} in parallel, where {@code op(X)} is the transpose of
 * {@code X} when the corresponding {@code trans} flag is set and {@code X} itself otherwise.
 *
 * <p>When {@code trans1} is set, {@code mat1} is first materialized as a new n x m matrix from
 * {@code transform(mat1)} (presumably its transpose; confirm against {@code transform}).
 * {@code mat2} is never copied; {@code trans2} is forwarded to {@code DotForkJoinOp} instead.
 * The rows of the left operand are split into chunks of about {@code rows / parallelism} rows.
 * The work is submitted as a fork-join op to the shared {@code MatrixExecutors} instance, and this
 * method blocks until that op has completed.
 *
 * @param mat1 left operand (m x n)
 * @param trans1 whether to use the transpose of {@code mat1}
 * @param mat2 right operand (p x q)
 * @param trans2 whether to use the transpose of {@code mat2}
 * @return a newly allocated dense result matrix that carries {@code mat1}'s matrix id and clock
 * @throws ArithmeticException if the number of result elements does not fit in an {@code int}
 */
private static Matrix applyParallel(BlasFloatMatrix mat1, boolean trans1, BlasFloatMatrix mat2,
    boolean trans2) {
  int m = mat1.getNumRows(), n = mat1.getNumCols();
  int p = mat2.getNumRows(), q = mat2.getNumCols();
  MatrixExecutors executors = MatrixExecutors.getInstance();

  // Effective operand shapes: op(mat1) is outRows x leftInner, op(mat2) is rightInner x outCols.
  int outRows = trans1 ? n : m;
  int leftInner = trans1 ? m : n;
  int rightInner = trans2 ? q : p;
  int outCols = trans2 ? p : q;
  assert leftInner == rightInner
      : "inner dimensions differ: " + leftInner + " vs " + rightInner;

  // outRows * outCols can exceed Integer.MAX_VALUE even when both inputs fit in arrays
  // (e.g. a tall-thin op(mat1) times a wide op(mat2)). A silent int wrap would allocate a
  // wrongly sized buffer, so fail loudly instead.
  float[] resBlas = new float[Math.multiplyExact(outRows, outCols)];
  BlasFloatMatrix retMat =
      new BlasFloatMatrix(mat1.getMatrixId(), mat1.getClock(), outRows, outCols, resBlas);

  // DotForkJoinOp receives no trans1 flag, so the left operand must already be oriented as
  // op(mat1); materialize it only when a transpose is requested.
  BlasFloatMatrix transMat1 = trans1
      ? new BlasFloatMatrix(mat1.getMatrixId(), mat1.getClock(), n, m, transform(mat1))
      : mat1;

  // Split the row indices of the left operand into chunks of about rows / parallelism each.
  int subM = Math.max(1, transMat1.getNumRows() / executors.getParallel());
  int[] leftRowOffIndices = splitRowIds(transMat1.getNumRows(), subM);

  // Run the fork-join op on the shared executors and wait for it to finish.
  DotForkJoinOp op = new DotForkJoinOp(transMat1, mat2, retMat, leftRowOffIndices, 0,
      leftRowOffIndices.length, subM, trans2);
  executors.execute(op);
  op.join();
  return retMat;
}
Aggregations