Use of com.tencent.angel.ml.math.TVector in project Angel by Tencent.
The class DenseFloatMatrixTest, method plusBySparseFloatVectorTest.
/**
 * Adding a sparse vector tagged with row id 0 must add its single non-zero entry (column 1) into
 * row 0 and leave every other cell of the matrix unchanged.
 */
@Test
public void plusBySparseFloatVectorTest() throws Exception {
  // With org.junit.Assert, a two-argument float assertEquals binds to the deprecated
  // assertEquals(double, double), which always fails; pass an explicit delta instead.
  final float eps = 1e-6f;
  float[][] value = { { 1.0f, 2.0f }, { 3.0f, 4.0f } };
  DenseFloatMatrix mat = new DenseFloatMatrix(2, 2, value);
  SparseFloatVector sparse = new SparseFloatVector(2);
  sparse.set(1, 1.0f);
  // Keep the static type TVector so the same plusBy overload as before is exercised.
  TVector vec = sparse;
  vec.setRowId(0);
  mat.plusBy(vec);
  assertEquals(1.0f, mat.get(0, 0), eps);
  assertEquals(3.0f, mat.get(0, 1), eps);
  assertEquals(3.0f, mat.get(1, 0), eps);
  assertEquals(4.0f, mat.get(1, 1), eps);
}
Use of com.tencent.angel.ml.math.TVector in project Angel by Tencent.
The class DenseFloatMatrixTest, method plusByDenseFloatVectorTest.
/**
 * Adding a dense vector tagged with row id 0 must add element-wise into row 0 only, both for a
 * matrix built from initial values and for one built without them.
 */
@Test
public void plusByDenseFloatVectorTest() throws Exception {
  // With org.junit.Assert, a two-argument float assertEquals binds to the deprecated
  // assertEquals(double, double), which always fails; pass an explicit delta instead.
  final float eps = 1e-6f;
  float[][] value = { { 1.0f, 2.0f }, { 3.0f, 4.0f } };
  DenseFloatMatrix mat = new DenseFloatMatrix(2, 2, value);
  TVector vec = new DenseFloatVector(2, new float[] { 1.0f, 1.0f });
  vec.setRowId(0);
  mat.plusBy(vec);
  assertEquals(2.0f, mat.get(0, 0), eps);
  assertEquals(3.0f, mat.get(0, 1), eps);
  assertEquals(3.0f, mat.get(1, 0), eps);
  assertEquals(4.0f, mat.get(1, 1), eps);

  // A matrix constructed without initial values is expected to start at zero,
  // so only row 0 should pick up the added vector.
  DenseFloatMatrix emptyMat = new DenseFloatMatrix(2, 2);
  DenseFloatVector denseVec = new DenseFloatVector(2, new float[] { 1.0f, 1.0f });
  denseVec.setRowId(0);
  emptyMat.plusBy(denseVec);
  assertEquals(1.0f, emptyMat.get(0, 0), eps);
  assertEquals(1.0f, emptyMat.get(0, 1), eps);
  assertEquals(0.0f, emptyMat.get(1, 0), eps);
  assertEquals(0.0f, emptyMat.get(1, 1), eps);
}
Use of com.tencent.angel.ml.math.TVector in project Angel by Tencent.
The class CompSparseIntMatrixOpLog, method flushToLocalStorage.
/**
 * Flush the update in cache to local matrix storage.
 *
 * <p>For each row index in {@code [0, rowNum)}, the cached delta row returned by
 * {@code getRow(rowIndex)} is added into the matching row of the local matrix storage, scaled by
 * {@code 1.0 / totalTaskNum}. A row is skipped when either the delta or the storage row is
 * {@code null}. The storage's write lock is held for the entire pass.
 */
public void flushToLocalStorage() {
  MatrixStorage storage =
      PSAgentContext.get().getMatrixStorageManager().getMatrixStoage(matrixId);
  MatrixMeta matrixMeta = PSAgentContext.get().getMatrixMetaManager().getMatrixMeta(matrixId);
  int rowNum = matrixMeta.getRowNum();
  // The scale factor does not depend on the row, so compute it once instead of per row.
  double scale = 1.0 / PSAgentContext.get().getTotalTaskNum();

  ReentrantReadWriteLock globalStorageLock = storage.getLock();
  // Acquire before entering try: if lock() itself failed inside the try, the finally block
  // would call unlock() on a lock this thread does not hold (IllegalMonitorStateException).
  globalStorageLock.writeLock().lock();
  try {
    for (int rowIndex = 0; rowIndex < rowNum; rowIndex++) {
      TVector deltaVector = getRow(rowIndex);
      TVector vector = storage.getRow(rowIndex);
      if (deltaVector == null || vector == null) {
        continue;
      }
      vector.plusBy(deltaVector, scale);
    }
  } finally {
    globalStorageLock.writeLock().unlock();
  }
}
Use of com.tencent.angel.ml.math.TVector in project Angel by Tencent.
The class RowSplitCombineUtils, method combineServerDenseIntRowSplits.
/**
 * Combines the dense int row splits of one matrix row into a single {@link DenseIntVector}.
 *
 * <p>Note: {@code rowSplits} is sorted in place with {@code serverRowComp} before merging. Every
 * split is cast to {@code ServerDenseIntRow} and merged into one {@code int[]} of length
 * {@code colNum}. The result carries the matrix id, the given row index, and the minimum clock
 * across all splits ({@code Integer.MAX_VALUE} when the list is empty).
 */
private static TVector combineServerDenseIntRowSplits(List<ServerRow> rowSplits, MatrixMeta matrixMeta, int rowIndex) {
  final int colNum = (int) matrixMeta.getColNum();
  final int[] merged = new int[colNum];
  Collections.sort(rowSplits, serverRowComp);

  int minClock = Integer.MAX_VALUE;
  for (ServerRow split : rowSplits) {
    minClock = Math.min(minClock, split.getClock());
    ((ServerDenseIntRow) split).mergeTo(merged);
  }

  TVector combined = new DenseIntVector(colNum, merged);
  combined.setMatrixId(matrixMeta.getId());
  combined.setRowId(rowIndex);
  combined.setClock(minClock);
  return combined;
}
Use of com.tencent.angel.ml.math.TVector in project Angel by Tencent.
The class RowSplitCombineUtils, method combineServerSparseFloatRowSplits.
/**
 * Combines the sparse float row splits of one matrix row into a single {@link SparseFloatVector}.
 *
 * <p>Note: {@code rowSplits} is sorted in place with {@code serverRowComp}. The sizes of the
 * sorted splits are then summed so the index/value arrays can be allocated once. Each split
 * (cast to {@code ServerSparseFloatRow}) then writes its entries into its own consecutive
 * segment. The result carries the matrix id, the given row index, and the minimum clock across
 * all splits ({@code Integer.MAX_VALUE} when the list is empty).
 */
private static TVector combineServerSparseFloatRowSplits(List<ServerRow> rowSplits, MatrixMeta matrixMeta, int rowIndex) {
  final int colNum = (int) matrixMeta.getColNum();
  final int splitNum = rowSplits.size();
  final int[] splitSizes = new int[splitNum];
  Collections.sort(rowSplits, serverRowComp);

  // First pass: per-split element counts (in sorted order) and their total.
  int total = 0;
  for (int s = 0; s < splitNum; s++) {
    splitSizes[s] = rowSplits.get(s).size();
    total += splitSizes[s];
  }

  final int[] indexes = new int[total];
  final float[] values = new float[total];

  // Second pass: copy each split into its segment and track the smallest clock.
  int minClock = Integer.MAX_VALUE;
  int offset = 0;
  for (int s = 0; s < splitNum; s++) {
    ServerRow split = rowSplits.get(s);
    minClock = Math.min(minClock, split.getClock());
    ((ServerSparseFloatRow) split).mergeTo(indexes, values, offset, splitSizes[s]);
    offset += splitSizes[s];
  }

  TVector combined = new SparseFloatVector(colNum, indexes, values);
  combined.setMatrixId(matrixMeta.getId());
  combined.setRowId(rowIndex);
  combined.setClock(minClock);
  return combined;
}
Aggregations