Example 6 with FloatArrayList

Use of it.unimi.dsi.fastutil.floats.FloatArrayList in the project angel by Tencent.

From the class CooLongFloatMatrix, method getCol.

@Override
public Vector getCol(int idx) {
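    // Collect the (row, value) pair of every stored entry whose column index is idx.
    // The list holds row indices, so it is named rows rather than cols.
    LongArrayList rows = new LongArrayList();
    FloatArrayList data = new FloatArrayList();
    for (int i = 0; i < colIndices.length; i++) {
        if (colIndices[i] == idx) {
            rows.add(rowIndices[i]);
            data.add(values[i]);
        }
    }
    // A column has one slot per matrix row, so the vector dimension is shape[0].
    LongFloatSparseVectorStorage storage =
        new LongFloatSparseVectorStorage(shape[0], rows.toLongArray(), data.toFloatArray());
    return new LongFloatVector(getMatrixId(), 0, getClock(), shape[0], storage);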
}
Also used: LongArrayList (it.unimi.dsi.fastutil.longs.LongArrayList), FloatArrayList (it.unimi.dsi.fastutil.floats.FloatArrayList), LongFloatVector (com.tencent.angel.ml.math2.vector.LongFloatVector), LongFloatSparseVectorStorage (com.tencent.angel.ml.math2.storage.LongFloatSparseVectorStorage)
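The COO column scan above can be reproduced outside Angel with plain Java arrays. The sketch below is a minimal illustration, not Angel code: CooColumnSketch and SparseColumn are hypothetical names, and the inputs are assumed to be parallel rowIndices/colIndices/values arrays as in CooLongFloatMatrix. It also swaps the growable fastutil lists for a two-pass count-then-fill, which sizes the output arrays exactly at the cost of scanning the column indices twice.

```java
import java.util.Arrays;

// Hypothetical helper, not part of Angel: extracts one column of a COO matrix.
final class CooColumnSketch {

    // Sparse column: row indices and the matching stored values, in storage order.
    static final class SparseColumn {
        final long[] rows;
        final float[] values;

        SparseColumn(long[] rows, float[] values) {
            this.rows = rows;
            this.values = values;
        }
    }

    static SparseColumn getCol(long[] rowIndices, long[] colIndices, float[] values, long col) {
        // First pass: count matching entries so the output arrays can be sized exactly.
        int n = 0;
        for (long c : colIndices) {
            if (c == col) {
                n++;
            }
        }
        long[] rows = new long[n];
        float[] data = new float[n];
        // Second pass: copy each (row, value) pair whose column index matches.
        int k = 0;
        for (int i = 0; i < colIndices.length; i++) {
            if (colIndices[i] == col) {
                rows[k] = rowIndices[i];
                data[k] = values[i];
                k++;
            }
        }
        return new SparseColumn(rows, data);
    }

    public static void main(String[] args) {
        // 3x3 matrix with entries (0,1)=1.5, (2,1)=2.5, (1,0)=4.0
        long[] r = {0, 2, 1};
        long[] c = {1, 1, 0};
        float[] v = {1.5f, 2.5f, 4.0f};
        SparseColumn col1 = getCol(r, c, v, 1);
        // prints "[0, 2] [1.5, 2.5]"
        System.out.println(Arrays.toString(col1.rows) + " " + Arrays.toString(col1.values));
    }
}
```

Because COO entries carry no ordering guarantee, every entry must be inspected, so the scan is linear in the number of stored entries regardless of which column is requested.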

Example 7 with FloatArrayList

Use of it.unimi.dsi.fastutil.floats.FloatArrayList in the project angel by Tencent.

From the class CsrFloatMatrix, method getCol.

@Override
public Vector getCol(int idx) {
    // Walk the CSR structure row by row: the stored entries of row r occupy
    // positions indptr[r] (inclusive) to indptr[r + 1] (exclusive) of indices/values.
    // This replaces a row-expansion loop that advanced its write position by one per
    // non-empty row instead of by the row's length, mislabelling rows with 2+ entries.
    IntArrayList rows = new IntArrayList();
    FloatArrayList data = new FloatArrayList();
    for (int r = 0; r < indptr.length - 1; r++) {
        for (int p = indptr[r]; p < indptr[r + 1]; p++) {
            if (indices[p] == idx) {
                rows.add(r);
                data.add(values[p]);
            }
        }
    }
    // A column has one slot per matrix row, so the vector dimension is shape[0].
    IntFloatSparseVectorStorage storage =
        new IntFloatSparseVectorStorage(shape[0], rows.toIntArray(), data.toFloatArray());
    return new IntFloatVector(getMatrixId(), 0, getClock(), shape[0], storage);
}
Also used: IntFloatSparseVectorStorage (com.tencent.angel.ml.math2.storage.IntFloatSparseVectorStorage), FloatArrayList (it.unimi.dsi.fastutil.floats.FloatArrayList), IntArrayList (it.unimi.dsi.fastutil.ints.IntArrayList), IntFloatVector (com.tencent.angel.ml.math2.vector.IntFloatVector)
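For CSR storage, the row of each stored entry is implied by indptr rather than stored explicitly. The self-contained sketch below shows one way to extract a column by walking indptr directly; CsrColumnSketch and SparseColumn are hypothetical names, the plain arrays stand in for Angel's vector classes, and the input is assumed to be a well-formed CSR triple (indptr of length numRows + 1, indices and values of length indptr[numRows]).

```java
import java.util.Arrays;

// Hypothetical helper, not part of Angel: extracts one column of a CSR matrix.
final class CsrColumnSketch {

    // Sparse column: row indices (ascending) and the matching stored values.
    static final class SparseColumn {
        final int[] rows;
        final float[] values;

        SparseColumn(int[] rows, float[] values) {
            this.rows = rows;
            this.values = values;
        }
    }

    static SparseColumn getCol(int[] indptr, int[] indices, float[] values, int col) {
        // Count matches first so the output arrays can be sized exactly.
        int n = 0;
        for (int c : indices) {
            if (c == col) {
                n++;
            }
        }
        int[] rows = new int[n];
        float[] data = new float[n];
        int k = 0;
        // Row r owns positions indptr[r] (inclusive) to indptr[r + 1] (exclusive).
        for (int r = 0; r < indptr.length - 1; r++) {
            for (int p = indptr[r]; p < indptr[r + 1]; p++) {
                if (indices[p] == col) {
                    rows[k] = r;
                    data[k] = values[p];
                    k++;
                }
            }
        }
        return new SparseColumn(rows, data);
    }

    public static void main(String[] args) {
        // 3x3 matrix:
        // [1 0 2]
        // [0 0 3]
        // [4 5 6]
        int[] indptr = {0, 2, 3, 6};
        int[] indices = {0, 2, 2, 0, 1, 2};
        float[] values = {1f, 2f, 3f, 4f, 5f, 6f};
        SparseColumn col2 = getCol(indptr, indices, values, 2);
        // prints "[0, 1, 2] [2.0, 3.0, 6.0]"
        System.out.println(Arrays.toString(col2.rows) + " " + Arrays.toString(col2.values));
    }
}
```

Rows 0 and 2 each hold more than one stored entry, which is the case a row-expansion step must get right: the write position has to advance by the row's length (indptr[r + 1] - indptr[r]), not by one per non-empty row. Walking indptr directly, as above, avoids building an intermediate per-entry row array at all.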

Aggregations

Classes used across the FloatArrayList examples, with the number of examples using each:

FloatArrayList (it.unimi.dsi.fastutil.floats.FloatArrayList): 7
IntFloatSparseVectorStorage (com.tencent.angel.ml.math2.storage.IntFloatSparseVectorStorage): 4
IntFloatVector (com.tencent.angel.ml.math2.vector.IntFloatVector): 4
IntArrayList (it.unimi.dsi.fastutil.ints.IntArrayList): 4
LongArrayList (it.unimi.dsi.fastutil.longs.LongArrayList): 3
LongFloatSparseVectorStorage (com.tencent.angel.ml.math2.storage.LongFloatSparseVectorStorage): 2
LongFloatVector (com.tencent.angel.ml.math2.vector.LongFloatVector): 2
InvalidParameterException (com.tencent.angel.exception.InvalidParameterException): 1
GeneralPartGetParam (com.tencent.angel.ml.matrix.psf.get.base.GeneralPartGetParam): 1
ServerLongFloatRow (com.tencent.angel.ps.storage.vector.ServerLongFloatRow): 1
KeyPart (com.tencent.angel.psagent.matrix.transport.router.KeyPart): 1