use of com.tencent.angel.ml.matrix.psf.get.single.GetRowResult in project angel by Tencent.
the class PSFTestTask method run.
/**
 * Runs the PSF test loop: while the task epoch is below 100, pulls rows 0 and 1 of the
 * "psf_test" matrix, logs the pull latency and the sum of each row, then pushes a constant
 * delta to both rows (1.0 per element for row 0, 2.0 per element for row 1) and advances the
 * epoch.
 *
 * @param taskContext context of the running task
 * @throws AngelException wrapping anything thrown while the loop is running
 */
@Override
public void run(TaskContext taskContext) throws AngelException {
  final int maxEpoch = 100;
  final int dim = 10000000;
  try {
    MatrixClient matrix = taskContext.getMatrix("psf_test");
    Pull pullRow0 = new Pull(matrix.getMatrixId(), 0);
    Pull pullRow1 = new Pull(matrix.getMatrixId(), 1);

    while (taskContext.getEpoch() < maxEpoch) {
      taskContext.globalSync(matrix.getMatrixId());

      // Pull both rows and report how long the two pulls took.
      long pullStartMs = System.currentTimeMillis();
      TVector pulled0 = ((GetRowResult) matrix.get(pullRow0)).getRow();
      TVector pulled1 = ((GetRowResult) matrix.get(pullRow1)).getRow();
      String report = "Task " + taskContext.getTaskId()
          + " in iteration " + taskContext.getEpoch()
          + " pull use time=" + (System.currentTimeMillis() - pullStartMs)
          + ", sum of row 0=" + sum((DenseDoubleVector) pulled0)
          + " sum of row 1=" + sum((DenseDoubleVector) pulled1);
      LOG.info(report);

      // Fresh delta for row 0: every element is 1.0.
      double[] onesBuf = new double[dim];
      java.util.Arrays.fill(onesBuf, 1.0);
      DenseDoubleVector row0Delta = new DenseDoubleVector(dim, onesBuf);
      row0Delta.setMatrixId(matrix.getMatrixId());
      row0Delta.setRowId(0);

      // Fresh delta for row 1: every element is 2.0.
      double[] twosBuf = new double[dim];
      java.util.Arrays.fill(twosBuf, 2.0);
      DenseDoubleVector row1Delta = new DenseDoubleVector(dim, twosBuf);
      row1Delta.setMatrixId(matrix.getMatrixId());
      row1Delta.setRowId(1);

      // Push both deltas, wait for the clock operation to finish, then move to the next epoch.
      matrix.increment(row0Delta);
      matrix.increment(row1Delta);
      matrix.clock().get();
      taskContext.incEpoch();
    }
  } catch (Throwable cause) {
    throw new AngelException("run task failed ", cause);
  }
}
use of com.tencent.angel.ml.matrix.psf.get.single.GetRowResult in project angel by Tencent.
the class AggrFuncTest method testPull.
/**
 * Pulls row 1 through {@code w2Client} with a {@code Pull} function and checks that the first
 * {@code dim} values of the returned dense row match {@code localArray1} within {@code delta}.
 */
@Test
public void testPull() throws InvalidParameterException, InterruptedException, ExecutionException {
  GetFunc pullRow1 = new Pull(w2Client.getMatrixId(), 1);
  GetRowResult pulled = (GetRowResult) w2Client.get(pullRow1);
  DenseDoubleVector denseRow = (DenseDoubleVector) pulled.getRow();
  double[] pulledValues = denseRow.getValues();

  int idx = 0;
  while (idx < dim) {
    // NOTE(review): if Assert is org.junit.Assert, the first argument is the "expected" value;
    // the pulled value is passed there, so failure messages may read reversed - confirm.
    Assert.assertEquals(pulledValues[idx], localArray1[idx], delta);
    idx++;
  }
}
use of com.tencent.angel.ml.matrix.psf.get.single.GetRowResult in project angel by Tencent.
the class IndexGetFunc method merge.
/**
 * Merges the partition results of an index get into a single row result and logs how long the
 * merge took.
 *
 * @param partResults results returned by the individual partitions
 * @return a {@code GetRowResult} with response type SUCCESS holding the merged row
 * @throws UnsupportedOperationException if the matrix row type is not one of the handled types
 */
@Override
public GetResult merge(List<PartitionGetResult> partResults) {
  final long beginMs = System.currentTimeMillis();
  final RowType rowType = PSAgentContext.get()
      .getMatrixMetaManager()
      .getMatrixMeta(param.getMatrixId())
      .getRowType();
  final GetRowResult merged = mergeByRowType(rowType, partResults);
  LOG.info("Merge use time=" + (System.currentTimeMillis() - beginMs) + " ms");
  return merged;
}

/**
 * Picks the merge routine for the given row type. Dense and sparse row types of the same
 * element type share the sparse merge routine; component row types use the "Comp" variant.
 */
private GetRowResult mergeByRowType(RowType rowType, List<PartitionGetResult> partResults) {
  switch (rowType) {
    case T_DOUBLE_DENSE:
    case T_DOUBLE_SPARSE:
      return new GetRowResult(ResponseType.SUCCESS,
          ValuesCombineUtils.mergeSparseDoubleVector((IndexGetParam) param, partResults));

    case T_DOUBLE_SPARSE_COMPONENT:
      return new GetRowResult(ResponseType.SUCCESS,
          ValuesCombineUtils.mergeSparseDoubleCompVector((IndexGetParam) param, partResults));

    case T_FLOAT_DENSE:
    case T_FLOAT_SPARSE:
      return new GetRowResult(ResponseType.SUCCESS,
          ValuesCombineUtils.mergeSparseFloatVector((IndexGetParam) param, partResults));

    case T_FLOAT_SPARSE_COMPONENT:
      return new GetRowResult(ResponseType.SUCCESS,
          ValuesCombineUtils.mergeSparseFloatCompVector((IndexGetParam) param, partResults));

    case T_INT_DENSE:
    case T_INT_SPARSE:
      return new GetRowResult(ResponseType.SUCCESS,
          ValuesCombineUtils.mergeSparseIntVector((IndexGetParam) param, partResults));

    case T_INT_SPARSE_COMPONENT:
      return new GetRowResult(ResponseType.SUCCESS,
          ValuesCombineUtils.mergeSparseIntCompVector((IndexGetParam) param, partResults));

    default:
      throw new UnsupportedOperationException(
          "Unsupport operation: update " + rowType + " to " + this.getClass().getName());
  }
}
use of com.tencent.angel.ml.matrix.psf.get.single.GetRowResult in project angel by Tencent.
the class ConsistencyController method getRow.
/**
 * Get row from storage/cache or pss.
 *
 * <p>When the staleness obtained for the matrix is non-negative, the row held by the local
 * matrix storage is used if its clock is not behind the task's PS matrix clock and is within
 * the staleness bound of the task's matrix clock; otherwise the row is requested from the PS,
 * passing the task's matrix clock minus the staleness. In both cases the row goes through
 * {@code cloneRow} before it is returned. When the staleness is negative (ASYNC mode), the row
 * is fetched from the PS through a {@code GetRowFunc} and returned directly.
 *
 * @param taskContext task context
 * @param matrixId matrix id
 * @param rowIndex row index
 * @return TVector matrix row
 * @throws IOException if the ASYNC-mode request yields no result or a FAILED response
 * @throws Exception if any of the underlying storage or PS calls fails
 */
public TVector getRow(TaskContext taskContext, int matrixId, int rowIndex) throws Exception {
  int staleness = getStaleness(matrixId);
  if (staleness >= 0) {
    // Get row from cache.
    TVector row = PSAgentContext.get().getMatrixStorageManager().getRow(matrixId, rowIndex);
    // If the row clock satisfies the ssp staleness limit, just return it.
    if (row != null && (taskContext.getPSMatrixClock(matrixId) <= row.getClock())
        && (taskContext.getMatrixClock(matrixId) - row.getClock() <= staleness)) {
      // Build the message only when debug logging is on, so the normal path skips the string
      // concatenation and the extra clock lookups.
      if (LOG.isDebugEnabled()) {
        LOG.debug("task " + taskContext.getIndex() + " matrix " + matrixId + " clock "
            + taskContext.getMatrixClock(matrixId) + ", row clock " + row.getClock()
            + ", staleness " + staleness + ", just get from global storage");
      }
      return cloneRow(matrixId, rowIndex, row, taskContext);
    }
    // Get row from ps.
    row = PSAgentContext.get().getMatrixClientAdapter()
        .getRow(matrixId, rowIndex, taskContext.getMatrixClock(matrixId) - staleness);
    return cloneRow(matrixId, rowIndex, row, taskContext);
  }

  // For ASYNC mode, just get from pss.
  GetRowFunc func = new GetRowFunc(new GetRowParam(matrixId, rowIndex));
  GetRowResult result = (GetRowResult) PSAgentContext.get().getMatrixClientAdapter().get(func);
  // A missing result is reported the same way as a FAILED response instead of surfacing as a
  // NullPointerException on the next line.
  if (result == null || result.getResponseType() == ResponseType.FAILED) {
    throw new IOException("get row from ps failed.");
  }
  return result.getRow();
}
Aggregations