
Example 6 with UserAggregator

Use of org.neo4j.internal.kernel.api.procs.UserAggregator in project neo4j by neo4j.

From the class AllStoreHolder, method createAggregationFunction.

private UserAggregator createAggregationFunction(int id) throws ProcedureException {
    ktx.assertOpen();
    AccessMode mode = ktx.securityContext().mode();
    // Built-in aggregating functions are always allowed; user-defined ones need execute permission.
    if (!globalProcedures.isBuiltInAggregatingFunction(id) && !mode.allowsExecuteAggregatingFunction(id)) {
        String message = format("Executing a user defined aggregating function is not allowed for %s.", ktx.securityContext().description());
        throw ktx.securityAuthorizationHandler().logAndGetAuthorizationException(ktx.securityContext(), message);
    }
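    // Run the function with READ access: boosted to exactly READ (OverriddenAccessMode) when the mode
    // says to boost it, otherwise limited to what both the caller's mode and READ allow (RestrictedAccessMode).
    final SecurityContext securityContext = mode.shouldBoostAggregatingFunction(id)
            ? ktx.securityContext().withMode(new OverriddenAccessMode(mode, AccessMode.Static.READ))
            : ktx.securityContext().withMode(new RestrictedAccessMode(mode, AccessMode.Static.READ));
    try (KernelTransaction.Revertable ignore = ktx.overrideWith(securityContext)) {
        UserAggregator aggregator = globalProcedures.createAggregationFunction(prepareContext(securityContext, ProcedureCallContext.EMPTY), id);
        // Wrap the aggregator so every later update()/result() call re-enters the same security context;
        // the override opened here is reverted as soon as this method returns.
        return new UserAggregator() {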

            @Override
            public void update(AnyValue[] input) throws ProcedureException {
                try (KernelTransaction.Revertable ignore = ktx.overrideWith(securityContext)) {
                    aggregator.update(input);
                }
            }

            @Override
            public AnyValue result() throws ProcedureException {
                try (KernelTransaction.Revertable ignore = ktx.overrideWith(securityContext)) {
                    return aggregator.result();
                }
            }
        };
    }
}
Also used : KernelTransaction(org.neo4j.kernel.api.KernelTransaction) OverriddenAccessMode(org.neo4j.kernel.impl.api.security.OverriddenAccessMode) RestrictedAccessMode(org.neo4j.kernel.impl.api.security.RestrictedAccessMode) SecurityContext(org.neo4j.internal.kernel.api.security.SecurityContext) UserAggregator(org.neo4j.internal.kernel.api.procs.UserAggregator) AdminAccessMode(org.neo4j.internal.kernel.api.security.AdminAccessMode) AccessMode(org.neo4j.internal.kernel.api.security.AccessMode)
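The anonymous UserAggregator returned by createAggregationFunction acts as a decorator: each update and result call re-enters securityContext through try-with-resources, so the override is reverted even when the delegate throws. A minimal sketch of that pattern follows. It does not use the neo4j kernel API; Aggregator, Revertable, Txn and ContextSwitchingAggregator are hypothetical simplified stand-ins, and the access mode is just a string.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for UserAggregator; the real interface throws ProcedureException.
interface Aggregator {
    void update(String input);
    String result();
}

// Hypothetical stand-in for KernelTransaction.Revertable: close() restores the previous mode.
interface Revertable extends AutoCloseable {
    @Override
    void close();
}

// Hypothetical transaction that only remembers which access mode is active.
class Txn {
    String mode = "FULL";

    Revertable overrideWith(String newMode) {
        String previous = mode;
        mode = newMode;
        return () -> mode = previous;
    }
}

// Decorator: every delegate call runs under the given mode, then the previous mode is restored.
final class ContextSwitchingAggregator implements Aggregator {
    private final Txn txn;
    private final String mode;
    private final Aggregator delegate;

    ContextSwitchingAggregator(Txn txn, String mode, Aggregator delegate) {
        this.txn = txn;
        this.mode = mode;
        this.delegate = delegate;
    }

    @Override
    public void update(String input) {
        try (Revertable ignore = txn.overrideWith(mode)) {
            delegate.update(input);
        }
    }

    @Override
    public String result() {
        try (Revertable ignore = txn.overrideWith(mode)) {
            return delegate.result();
        }
    }

    public static void main(String[] args) {
        Txn txn = new Txn();
        List<String> seen = new ArrayList<>();
        // Records which mode was active whenever update() actually ran.
        Aggregator recording = new Aggregator() {
            @Override
            public void update(String input) { seen.add(input + "@" + txn.mode); }

            @Override
            public String result() { return String.join(",", seen); }
        };
        Aggregator wrapped = new ContextSwitchingAggregator(txn, "READ", recording);
        wrapped.update("a");
        wrapped.update("b");
        System.out.println(wrapped.result()); // a@READ,b@READ
        System.out.println(txn.mode);         // FULL
    }
}
```

Because the override is scoped to each call rather than to the aggregator's lifetime, the caller's transaction is back in its original mode between calls, which is what the real code achieves by opening a fresh Revertable inside update and result.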

Example 7 with UserAggregator

Use of org.neo4j.internal.kernel.api.procs.UserAggregator in project neo4j by neo4j.

From the class ProcedureCompilationTest, method aggregationShouldAccessContext.

@Test
void aggregationShouldAccessContext() throws ProcedureException, NoSuchFieldException {
    // Given
    UserFunctionSignature signature = functionSignature("test", "foo").in("in", NTString).out(NTString).build();
    FieldSetter setter = createSetter(InnerClass.class, "thread", Context::thread);
    String threadName = Thread.currentThread().getName();
    UserAggregator aggregator = compileAggregation(
            signature,
            singletonList(setter),
            method(InnerClass.class, "create"),
            method(InnerClass.Aggregator.class, "update", String.class),
            method(InnerClass.Aggregator.class, "result"))
        .create(ctx);
    // When
    aggregator.update(new AnyValue[] { stringValue("1:") });
    aggregator.update(new AnyValue[] { stringValue("2:") });
    aggregator.update(new AnyValue[] { stringValue("3:") });
    // Then
    assertEquals(stringValue(format("1: %s, 2: %s, 3: %s", threadName, threadName, threadName)), aggregator.result());
}
Also used : Context(org.neo4j.kernel.api.procedure.Context) UserAggregator(org.neo4j.internal.kernel.api.procs.UserAggregator) NTString(org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTString) UserFunctionSignature(org.neo4j.internal.kernel.api.procs.UserFunctionSignature) Test(org.junit.jupiter.api.Test)
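The test checks that the compiled aggregator sees the injected Context::thread value on every update. The expected string format("1: %s, 2: %s, 3: %s", ...) is consistent with each input being followed by a space and the thread name, with entries joined by ", ". The sketch below reproduces that behaviour with a hypothetical ThreadTaggingAggregator that receives the thread through its constructor; this is an assumption for illustration, since the test injects a field through a FieldSetter instead.

```java
import java.util.StringJoiner;

// Hypothetical sketch: an aggregator whose output depends on a value injected at creation time,
// standing in for the context-injected "thread" field in the test's InnerClass.
final class ThreadTaggingAggregator {
    private final Thread thread; // assumed injected context value
    private final StringJoiner joined = new StringJoiner(", ");

    ThreadTaggingAggregator(Thread thread) {
        this.thread = thread;
    }

    // Append the input followed by the injected thread's name.
    void update(String input) {
        joined.add(input + " " + thread.getName());
    }

    String result() {
        return joined.toString();
    }

    public static void main(String[] args) {
        ThreadTaggingAggregator agg = new ThreadTaggingAggregator(Thread.currentThread());
        agg.update("1:");
        agg.update("2:");
        agg.update("3:");
        // Same shape as the test's expectation: "1: <name>, 2: <name>, 3: <name>"
        System.out.println(agg.result());
    }
}
```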

Example 8 with UserAggregator

Use of org.neo4j.internal.kernel.api.procs.UserAggregator in project neo4j by neo4j.

From the class ProcedureCompilationTest, method shouldHandleThrowingAggregations.

@Test
void shouldHandleThrowingAggregations() throws ProcedureException {
    UserFunctionSignature signature = functionSignature("test", "foo").out(NTInteger).build();
    UserAggregator aggregator = compileAggregation(
            signature, emptyList(), method("blackAdder"),
            method(BlackAdder.class, "update"), method(BlackAdder.class, "result"))
        .create(ctx);
    assertThrows(ProcedureException.class, () -> aggregator.update(EMPTY));
    assertThrows(ProcedureException.class, aggregator::result);
}
Also used : UserAggregator(org.neo4j.internal.kernel.api.procs.UserAggregator) UserFunctionSignature(org.neo4j.internal.kernel.api.procs.UserFunctionSignature) Test(org.junit.jupiter.api.Test)
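The test expects an exception thrown by user code in update or result to surface as a ProcedureException. One way to get that behaviour is to unwrap the user's exception and rethrow it as the cause of a checked failure type. The sketch below does this with plain reflection (Method.invoke and InvocationTargetException); reflection is a swapped-in technique for illustration, not a claim about how compileAggregation works. AggregationFailure (standing in for ProcedureException), Failing (in the spirit of BlackAdder) and ReflectiveAggregator are hypothetical names.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

// Hypothetical checked exception standing in for ProcedureException.
class AggregationFailure extends Exception {
    AggregationFailure(String message, Throwable cause) {
        super(message, cause);
    }
}

// Hypothetical user aggregator that always fails, similar in spirit to the test's BlackAdder.
class Failing {
    public void update() {
        throw new UnsupportedOperationException("update failed");
    }

    public long result() {
        throw new UnsupportedOperationException("result failed");
    }
}

// Invokes update()/result() reflectively and converts user-code failures into AggregationFailure.
final class ReflectiveAggregator {
    private final Object target;
    private final Method update;
    private final Method result;

    ReflectiveAggregator(Object target) {
        this.target = target;
        this.update = find(target.getClass(), "update");
        this.result = find(target.getClass(), "result");
    }

    private static Method find(Class<?> type, String name) {
        try {
            return type.getMethod(name);
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException(type.getName() + " has no public " + name + "()", e);
        }
    }

    void update() throws AggregationFailure {
        invoke(update);
    }

    Object result() throws AggregationFailure {
        return invoke(result);
    }

    private Object invoke(Method method) throws AggregationFailure {
        try {
            return method.invoke(target);
        } catch (InvocationTargetException e) {
            // The user code threw: keep its original exception as the cause of a checked failure.
            throw new AggregationFailure("Failed to invoke " + method.getName(), e.getCause());
        } catch (IllegalAccessException e) {
            throw new AggregationFailure("Cannot access " + method.getName(), e);
        }
    }

    public static void main(String[] args) {
        ReflectiveAggregator agg = new ReflectiveAggregator(new Failing());
        try {
            agg.update();
        } catch (AggregationFailure e) {
            System.out.println(e.getMessage() + ": " + e.getCause().getMessage());
        }
    }
}
```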

Example 9 with UserAggregator

Use of org.neo4j.internal.kernel.api.procs.UserAggregator in project neo4j by neo4j.

From the class UserAggregationFunctionTest, method shouldLoadWhiteListedFunction.

@Test
void shouldLoadWhiteListedFunction() throws Throwable {
    // Given
    procedureCompiler = new ProcedureCompiler(
            new TypeCheckers(),
            components,
            new ComponentRegistry(),
            NullLog.getInstance(),
            new ProcedureConfig(Config.defaults(
                    GraphDatabaseSettings.procedure_allowlist,
                    List.of("org.neo4j.procedure.impl.collectCool"))));
    CallableUserAggregationFunction method = compile(SingleAggregationFunction.class).get(0);
    // Expect
    UserAggregator created = method.create(prepareContext());
    created.update(new AnyValue[] { stringValue("Bonnie") });
    assertThat(created.result()).isEqualTo(VirtualValues.list(stringValue("Bonnie")));
}
Also used : CallableUserAggregationFunction(org.neo4j.kernel.api.procedure.CallableUserAggregationFunction) UserAggregator(org.neo4j.internal.kernel.api.procs.UserAggregator) Test(org.junit.jupiter.api.Test)
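The test configures GraphDatabaseSettings.procedure_allowlist with the single entry "org.neo4j.procedure.impl.collectCool" and expects that function to compile and run. How entries are matched is not shown here; the sketch below assumes glob-style entries in which `*` matches any run of characters and everything else, including '.', is literal. Check the Neo4j configuration reference for the exact semantics. Allowlist and isAllowed are hypothetical names.

```java
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

// Hypothetical allowlist matcher; the glob semantics ('*' matches any characters) are an assumption.
final class Allowlist {
    private final List<Pattern> patterns;

    Allowlist(List<String> entries) {
        this.patterns = entries.stream().map(Allowlist::globToRegex).collect(Collectors.toList());
    }

    // Quote literal text and translate each '*' into '.*'.
    private static Pattern globToRegex(String glob) {
        String[] parts = glob.split("\\*", -1);
        StringBuilder regex = new StringBuilder();
        for (int i = 0; i < parts.length; i++) {
            if (i > 0) {
                regex.append(".*");
            }
            regex.append(Pattern.quote(parts[i]));
        }
        return Pattern.compile(regex.toString());
    }

    // A fully qualified function name is allowed only if some entry matches it completely.
    boolean isAllowed(String qualifiedName) {
        return patterns.stream().anyMatch(p -> p.matcher(qualifiedName).matches());
    }

    public static void main(String[] args) {
        Allowlist allowlist = new Allowlist(List.of("org.neo4j.procedure.impl.collectCool"));
        System.out.println(allowlist.isAllowed("org.neo4j.procedure.impl.collectCool")); // true
        System.out.println(allowlist.isAllowed("org.neo4j.procedure.impl.other"));       // false
    }
}
```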

Example 10 with UserAggregator

Use of org.neo4j.internal.kernel.api.procs.UserAggregator in project neo4j by neo4j.

From the class UserAggregationFunctionTest, method shouldRunClassWithMultipleFunctionsDeclared.

@Test
void shouldRunClassWithMultipleFunctionsDeclared() throws Throwable {
    // Given
    List<CallableUserAggregationFunction> compiled = compile(MultiFunction.class);
    CallableUserAggregationFunction f1 = compiled.get(0);
    CallableUserAggregationFunction f2 = compiled.get(1);
    // When
    UserAggregator f1Aggregator = f1.create(prepareContext());
    f1Aggregator.update(new AnyValue[] { stringValue("Bonnie") });
    f1Aggregator.update(new AnyValue[] { stringValue("Clyde") });
    UserAggregator f2Aggregator = f2.create(prepareContext());
    f2Aggregator.update(new AnyValue[] { stringValue("Bonnie"), longValue(1337L) });
    f2Aggregator.update(new AnyValue[] { stringValue("Bonnie"), longValue(42L) });
    // Then
    assertThat(f1Aggregator.result()).isEqualTo(VirtualValues.list(stringValue("Bonnie"), stringValue("Clyde")));
    assertThat(((MapValue) f2Aggregator.result()).get("Bonnie")).isEqualTo(longValue(1337L));
}
Also used : CallableUserAggregationFunction(org.neo4j.kernel.api.procedure.CallableUserAggregationFunction) UserAggregator(org.neo4j.internal.kernel.api.procs.UserAggregator) MapValue(org.neo4j.values.virtual.MapValue) Test(org.junit.jupiter.api.Test)
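One class can declare several aggregation functions, and each created aggregator keeps its own state. The assertions only pin down that "Bonnie" is still mapped to 1337 after the values 1337 and 42, which is consistent with keeping either the first or the largest value per key; the MultiFunction source is not shown. The sketch below assumes "keep the largest" via Map.merge. MultiAggregations, Collect and MaxPerKey are hypothetical names.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical holder with two independent aggregations, mirroring the shape of MultiFunction.
final class MultiAggregations {

    // First aggregation: collect every input string in arrival order.
    static final class Collect {
        private final List<String> values = new ArrayList<>();

        void update(String value) { values.add(value); }

        List<String> result() { return List.copyOf(values); }
    }

    // Second aggregation: one entry per key; keeping the maximum is an assumption.
    static final class MaxPerKey {
        private final Map<String, Long> values = new HashMap<>();

        void update(String key, long value) { values.merge(key, value, Math::max); }

        Map<String, Long> result() { return Map.copyOf(values); }
    }

    public static void main(String[] args) {
        Collect collect = new Collect();
        collect.update("Bonnie");
        collect.update("Clyde");
        MaxPerKey maxPerKey = new MaxPerKey();
        maxPerKey.update("Bonnie", 1337L);
        maxPerKey.update("Bonnie", 42L);
        System.out.println(collect.result());                 // [Bonnie, Clyde]
        System.out.println(maxPerKey.result().get("Bonnie")); // 1337
    }
}
```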

Aggregations

UserAggregator (org.neo4j.internal.kernel.api.procs.UserAggregator): 10
Test (org.junit.jupiter.api.Test): 9
CallableUserAggregationFunction (org.neo4j.kernel.api.procedure.CallableUserAggregationFunction): 7
UserFunctionSignature (org.neo4j.internal.kernel.api.procs.UserFunctionSignature): 4
NTString (org.neo4j.internal.kernel.api.procs.Neo4jTypes.NTString): 1
AccessMode (org.neo4j.internal.kernel.api.security.AccessMode): 1
AdminAccessMode (org.neo4j.internal.kernel.api.security.AdminAccessMode): 1
SecurityContext (org.neo4j.internal.kernel.api.security.SecurityContext): 1
KernelTransaction (org.neo4j.kernel.api.KernelTransaction): 1
Context (org.neo4j.kernel.api.procedure.Context): 1
OverriddenAccessMode (org.neo4j.kernel.impl.api.security.OverriddenAccessMode): 1
RestrictedAccessMode (org.neo4j.kernel.impl.api.security.RestrictedAccessMode): 1
Log (org.neo4j.logging.Log): 1
NullLog (org.neo4j.logging.NullLog): 1
MapValue (org.neo4j.values.virtual.MapValue): 1