
Example 1 with SetAccumulator

Use of org.apache.spark.sql.delta.util.SetAccumulator in project parent by Daytime-Don-t-Know-Dark-Night.

From the class YouData, the replace method:

public static void replace(Dataset<Row> ds, String uri, JdbcOptionsInWrite opts, Column part) throws SQLException, ExecutionException, SparkException, InvocationTargetException {
    final double commit_factor = 0.1;
    final double max_allowed_factor = 0.3;
    final double max_memory_factor = 0.1;
    JdbcDialect dialect = JdbcDialects.get(opts.url());
    long max_allowed_packet = 4 * 1024 * 1024;
    long buffer_pool_size = 128 * 1024 * 1024;
    try (Connection conn = JdbcUtils.createConnectionFactory(opts).apply();
        Statement statement = conn.createStatement()) {
        try (ResultSet packetRes = statement.executeQuery("show global variables like 'max_allowed_packet'")) {
            while (packetRes.next()) {
                max_allowed_packet = packetRes.getLong("Value");
            }
        }
        try (ResultSet bufferRes = statement.executeQuery("show global variables like 'innodb_buffer_pool_size'")) {
            while (bufferRes.next()) {
                buffer_pool_size = bufferRes.getLong("Value");
            }
        }
    }
    StructType schema = ds.schema();
    String sql_ = JdbcUtils.getInsertStatement(opts.table(), schema, Option.empty(), true, dialect);
    // Types whose values are appended to the SQL without surrounding quotes
    List<DataType> specialType = ImmutableList.of(DataTypes.BooleanType, DataTypes.LongType, DataTypes.IntegerType);
    MemoryMXBean memoryMXBean = ManagementFactory.getMemoryMXBean();
    // Heap memory usage of the driver JVM
    MemoryUsage memoryUsage = memoryMXBean.getHeapMemoryUsage();
    // Maximum available heap memory
    long maxMemorySize = memoryUsage.getMax();
    SparkContext sparkContext = ds.sparkSession().sparkContext();
    int cores = sparkContext.defaultParallelism();
    int executorNum = sparkContext.getExecutorIds().size() + 1;
    long partNum = Math.min(Math.round(max_allowed_packet * max_allowed_factor), Math.round(maxMemorySize * max_memory_factor * executorNum / cores));
    long bufferLength = Math.round(buffer_pool_size * commit_factor);
    Preconditions.checkArgument(partNum > 0, "computed partNum <= 0");
    Set<Object> lastSet = Sets.newHashSet();
    SetAccumulator<Object> collectionAc = new SetAccumulator<>();
    collectionAc.register(SparkSession.active().sparkContext(), Option.apply("setAc"), false);
    SparkSession.active().sparkContext().addSparkListener(new SparkListener() {

        public synchronized void onTaskEnd(SparkListenerTaskEnd taskEnd) {
            // lastSet holds the values seen last time; collectionAc.value() holds the current accumulator contents
            synchronized (lastSet) {
                Set<Object> diffSet;
                synchronized (collectionAc.value()) {
                    diffSet = Sets.difference(collectionAc.value(), lastSet).immutableCopy();
                }
                if (!diffSet.isEmpty()) {
                    for (Object i : diffSet) {
                        System.out.println("监听器检测到差异..." + i);
                    }
                    lastSet.addAll(diffSet);
                }
            }
        }
    });
    StructType structType = ds.schema();
    if (Objects.nonNull(part)) {
        ds = ds.withColumn("partNum", part).repartition(col("partNum")).sortWithinPartitions(col("partNum"));
    } else {
        ds = ds.withColumn("partNum", expr("'0'"));
    }
    ds.foreachPartition(rows -> {
        try (Connection conn = JdbcUtils.createConnectionFactory(opts).apply();
            Statement statement = conn.createStatement()) {
            conn.setAutoCommit(false);
            int numFields = schema.fields().length;
            int executeLength = 0;
            StringBuilder sb = null;
            String sqlPrefix = null;
            Object lastPartNum = null;
            while (rows.hasNext()) {
                Row row = rows.next();
                // If this row goes to the same table as the previous row, just append it; otherwise execute the previous batch, create the new table, then insert
                Object tmpPartNum = row.getAs("partNum");
                Preconditions.checkArgument(Objects.nonNull(tmpPartNum));
                if (!Objects.equals(tmpPartNum, lastPartNum)) {
                    String newTableName = opts.table() + partSuffix(tmpPartNum);
                    // If the table already exists, drop it so its data is cleared
                    String drop_table = String.format("drop table if exists %s", newTableName);
                    statement.executeUpdate(drop_table);
                    // Create the table
                    String col_sql = tableCol(structType);
                    col_sql = col_sql.substring(0, col_sql.length() - 1);
                    String create_table = String.format("create table if not exists %s(%s) DEFAULT CHARSET=utf8mb4", newTableName, col_sql);
                    statement.executeUpdate(create_table);
                    sqlPrefix = sql_.substring(0, sql_.indexOf("(?"));
                    // Substitute the per-partition table name into the INSERT prefix
                    sqlPrefix = sqlPrefix.replace(sqlPrefix.substring(12, sqlPrefix.indexOf("(")), newTableName);
                    if (Objects.nonNull(lastPartNum)) {
                        // Track the total length of SQL executed so far
                        executeLength += sb.length();
                        statement.executeUpdate(sb.substring(0, sb.length() - 1));
                        sb.setLength(0);
                        sb.append(sqlPrefix);
                    } else {
                        sb = new StringBuilder((int) partNum + 1000);
                        sb.append(sqlPrefix);
                    }
                    collectionAc.add(tmpPartNum);
                    lastPartNum = tmpPartNum;
                }
                StringBuilder group = new StringBuilder("(");
                for (int i = 0; i < numFields; i++) {
                    DataType type = schema.apply(i).dataType();
                    if (row.isNullAt(i)) {
                        // Handle null values
                        group.append("null,");
                    } else if (specialType.contains(type)) {
                        // These types do not need surrounding quotes
                        Object tmp = row.getAs(i);
                        group.append(tmp).append(',');
                    } else if (type == DataTypes.StringType) {
                        // For string values, escape embedded single quotes
                        String tmp = row.getAs(i);
                        group.append("'").append(tmp.replaceAll("'", "''")).append("',");
                    } else {
                        Object tmp = row.getAs(i);
                        group.append("'").append(tmp).append("',");
                    }
                }
                group.delete(group.length() - 1, group.length());
                group.append("),");
                sb.append(group);
                if (sb.length() * 2L >= partNum) {
                    // Each execute adds sb.length() to the running total
                    executeLength += sb.length();
                    statement.executeLargeUpdate(sb.delete(sb.length() - 1, sb.length()).toString());
                    sb.setLength(0);
                    sb.append(sqlPrefix);
                }
                // Once the executed length accumulated above reaches the buffer-pool threshold, commit
                if (executeLength >= bufferLength) {
                    logger.info("commit执行时间: {}", Instant.now());
                    conn.commit();
                    executeLength = 0;
                }
            }
            // Flush any rows that are still buffered but not yet executed
            if (Objects.nonNull(sb) && sb.length() > sqlPrefix.length()) {
                String ex_sql = sb.substring(0, sb.length() - 1);
                statement.executeUpdate(ex_sql);
                sb.setLength(0);
                sb.append(sqlPrefix);
            }
            {
                logger.info("commit执行时间: {}", Instant.now());
                conn.commit();
            }
        }
    });
}
Also used : ResultSet(java.sql.ResultSet) Set(java.util.Set) StructType(org.apache.spark.sql.types.StructType) Statement(java.sql.Statement) SparkListenerTaskEnd(org.apache.spark.scheduler.SparkListenerTaskEnd) Connection(java.sql.Connection) MemoryUsage(java.lang.management.MemoryUsage) SparkListener(org.apache.spark.scheduler.SparkListener) MemoryMXBean(java.lang.management.MemoryMXBean) SparkContext(org.apache.spark.SparkContext) DataType(org.apache.spark.sql.types.DataType) JdbcDialect(org.apache.spark.sql.jdbc.JdbcDialect) SetAccumulator(org.apache.spark.sql.delta.util.SetAccumulator) Row(org.apache.spark.sql.Row)
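
The sizing logic in the method above can be read in isolation: each batched INSERT statement is capped by the smaller of a fraction of MySQL's max_allowed_packet and a fraction of the driver heap divided among executor cores, and a commit is issued once the executed SQL length reaches a fraction of innodb_buffer_pool_size. The following sketch reproduces just that arithmetic with hypothetical server and cluster values (the real method reads them from "show global variables" and MemoryMXBean); the class name BatchSizing and the sample numbers are illustrative only.

public class BatchSizing {

    public static void main(String[] args) {
        final double commitFactor = 0.1;      // fraction of the buffer pool executed before each commit
        final double maxAllowedFactor = 0.3;  // fraction of max_allowed_packet per statement
        final double maxMemoryFactor = 0.1;   // fraction of heap reserved for batch buffers
        // Hypothetical values; replace() obtains them from "show global variables" and the JVM's MemoryMXBean.
        long maxAllowedPacket = 64L * 1024 * 1024;      // 64 MB
        long bufferPoolSize = 1024L * 1024 * 1024;      // 1 GB
        long maxMemorySize = 4L * 1024 * 1024 * 1024;   // 4 GB heap
        int cores = 8;
        int executorNum = 3;
        long partNum = Math.min(
                Math.round(maxAllowedPacket * maxAllowedFactor),
                Math.round(maxMemorySize * maxMemoryFactor * executorNum / cores));
        long bufferLength = Math.round(bufferPoolSize * commitFactor);
        // partNum = min(~19.2 MB, ~153.6 MB) -> each statement batch stays well under max_allowed_packet
        // bufferLength = ~102 MB -> commit roughly every ~102 MB of executed SQL
        System.out.println("partNum=" + partNum + ", bufferLength=" + bufferLength);
    }
}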

Example 2 with SetAccumulator

Use of org.apache.spark.sql.delta.util.SetAccumulator in project parent by Daytime-Don-t-Know-Dark-Night.

From the class YouData2, the replace method:

public static void replace(Dataset<Row> ds, String uri, JdbcOptionsInWrite opts, int partNum) throws SQLException, ExecutionException, SparkException, InvocationTargetException {
    final double commit_factor = 0.1;
    final double max_allowed_factor = 0.3;
    final double max_memory_factor = 0.1;
    JdbcDialect dialect = JdbcDialects.get(opts.url());
    long max_allowed_packet = 4 * 1024 * 1024;
    long buffer_pool_size = 128 * 1024 * 1024;
    try (Connection conn = JdbcUtils.createConnectionFactory(opts).apply();
        Statement statement = conn.createStatement()) {
        for (int i = 0; i < partNum; i++) {
            String newTableName = opts.table() + partSuffix(i);
            try {
                // Truncate the table
                String truncate_table = String.format("truncate table %s", newTableName);
                statement.executeUpdate(truncate_table);
                // drop
                String drop_table = String.format("drop table if exists %s", newTableName);
                statement.executeUpdate(drop_table);
            } catch (SQLSyntaxErrorException e) {
                e.printStackTrace();
            }
            // Create the table
            String col_sql = tableCol(ds.schema());
            col_sql = col_sql.substring(0, col_sql.length() - 1);
            String create_table = String.format("create table if not exists %s(%s) DEFAULT CHARSET=utf8mb4", newTableName, col_sql);
            statement.executeUpdate(create_table);
        }
        try (ResultSet packetRes = statement.executeQuery("show global variables like 'max_allowed_packet'")) {
            while (packetRes.next()) {
                max_allowed_packet = packetRes.getLong("Value");
            }
        }
        try (ResultSet bufferRes = statement.executeQuery("show global variables like 'innodb_buffer_pool_size'")) {
            while (bufferRes.next()) {
                buffer_pool_size = bufferRes.getLong("Value");
            }
        }
    }
    StructType schema = ds.schema();
    String sql_ = JdbcUtils.getInsertStatement(opts.table(), schema, Option.empty(), true, dialect);
    // Types whose values are appended to the SQL without surrounding quotes
    List<DataType> specialType = ImmutableList.of(DataTypes.BooleanType, DataTypes.LongType, DataTypes.IntegerType);
    MemoryMXBean memoryMXBean = ManagementFactory.getMemoryMXBean();
    // Heap memory usage of the driver JVM
    MemoryUsage memoryUsage = memoryMXBean.getHeapMemoryUsage();
    // Maximum available heap memory
    long maxMemorySize = memoryUsage.getMax();
    SparkContext sparkContext = ds.sparkSession().sparkContext();
    int cores = sparkContext.defaultParallelism();
    int executorNum = sparkContext.getExecutorIds().size() + 1;
    long partLimit = Math.min(Math.round(max_allowed_packet * max_allowed_factor), Math.round(maxMemorySize * max_memory_factor * executorNum / cores));
    long bufferLength = Math.round(buffer_pool_size * commit_factor);
    Preconditions.checkArgument(partLimit > 0, "computed partLimit <= 0");
    // Key sizing parameters
    logger.info("cores: {}, max heap: {}, executorNum: {}, partLimit: {}, bufferLength: {}", cores, maxMemorySize, executorNum, partLimit, bufferLength);
    SetAccumulator<Integer> collectionAc = new SetAccumulator<>();
    collectionAc.register(SparkSession.active().sparkContext(), Option.apply("setAc"), false);
    Map<Integer, Integer> partitionMap = Maps.newHashMap();
    SparkSession.active().sparkContext().addSparkListener(new SparkListener() {

        public void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) {
            int stageId = stageSubmitted.stageInfo().stageId();
            int numTasks = stageSubmitted.stageInfo().numTasks();
            partitionMap.put(stageId, numTasks);
        }

        public void onTaskEnd(SparkListenerTaskEnd taskEnd) {
            int currStageId = taskEnd.stageId();
            Preconditions.checkArgument(partitionMap.containsKey(currStageId), "stageId of current task: " + currStageId + ", stages in map: " + partitionMap);
            int partitions = partitionMap.get(currStageId);
            synchronized (collectionAc.value()) {
                // Work out which table group this finished task belongs to
                Set<Integer> currSet = collectionAc.value();
                long index = taskEnd.taskInfo().index();
                long currIndex = index % partNum;
                // How many partition ids in this stage satisfy partitionId % partNum == currIndex
                long num1 = Stream.iterate(0, k -> k + 1).limit(partitions).filter(j -> j % partNum == currIndex).count();
                // How many of those partitions have already reported via the accumulator
                long num2 = currSet.stream().filter(j -> j % partNum == currIndex).count();
                if (num1 == num2) {
                    // Every partition for this group has finished
                    logger.info("Data for partNum={} in the accumulator has been fully exported", currIndex);
                    collectionAc.value().removeIf(k -> k % partNum == currIndex);
                }
            }
        }
    });
    ds.foreachPartition(rows -> {
        int taskId = TaskContext.getPartitionId();
        // taskId is the partition id, e.g. 0-199
        collectionAc.add(taskId);
        long tmpPartNum = taskId % partNum;
        try (Connection conn = JdbcUtils.createConnectionFactory(opts).apply();
            Statement statement = conn.createStatement()) {
            conn.setAutoCommit(false);
            int numFields = schema.fields().length;
            long executeLength = 0;
            // Substitute the per-partition table name into the INSERT prefix
            String newTableName = opts.table() + partSuffix(tmpPartNum);
            String sqlPrefix = sql_.substring(0, sql_.indexOf("(?"));
            sqlPrefix = sqlPrefix.replace(sqlPrefix.substring(12, sqlPrefix.indexOf("(")), newTableName);
            StringBuilder sb = new StringBuilder((int) partLimit + 1000);
            sb.append(sqlPrefix);
            while (rows.hasNext()) {
                Row row = rows.next();
                StringBuilder group = new StringBuilder("(");
                for (int i = 0; i < numFields; i++) {
                    DataType type = schema.apply(i).dataType();
                    if (row.isNullAt(i)) {
                        // Handle null values
                        group.append("null,");
                    } else if (specialType.contains(type)) {
                        // These types do not need surrounding quotes
                        Object tmp = row.getAs(i);
                        group.append(tmp).append(',');
                    } else if (type == DataTypes.StringType) {
                        // For string values, escape embedded single quotes
                        String tmp = row.getAs(i);
                        group.append("'").append(tmp.replaceAll("'", "''")).append("',");
                    } else {
                        Object tmp = row.getAs(i);
                        group.append("'").append(tmp).append("',");
                    }
                }
                group.delete(group.length() - 1, group.length());
                group.append("),");
                sb.append(group);
                if (sb.length() * 2L >= partLimit) {
                    // Each execute adds sb.length() to the running total
                    executeLength += sb.length();
                    logger.info("任务ID为: {}, execute过的数据长度: {}", taskId, executeLength);
                    statement.executeLargeUpdate(sb.delete(sb.length() - 1, sb.length()).toString());
                    sb.setLength(0);
                    sb.append(sqlPrefix);
                    // After each execute, check how many free pages remain in the buffer pool; if fewer than 1000, commit
                    int bufferPageFreeNum = -1;
                    try (ResultSet bufferPageRes = statement.executeQuery("show status like 'Innodb_buffer_pool_pages_free'")) {
                        while (bufferPageRes.next()) {
                            bufferPageFreeNum = bufferPageRes.getInt("Value");
                        }
                    }
                    if (bufferPageFreeNum > 0 && bufferPageFreeNum < 1000) {
                        logger.info("缓冲池剩余空闲页为: {}, 执行commit: {}", bufferPageFreeNum, Instant.now());
                        conn.commit();
                        executeLength = 0;
                    }
                }
            }
            // Flush any rows that are still buffered but not yet executed
            if (sb.length() > sqlPrefix.length()) {
                String ex_sql = sb.substring(0, sb.length() - 1);
                logger.info("execute过的数据长度: {}", executeLength);
                statement.executeUpdate(ex_sql);
                sb.setLength(0);
                sb.append(sqlPrefix);
            }
            {
                logger.info("commit执行时间: {}", Instant.now());
                conn.commit();
            }
        }
    });
}
Also used : DataType(org.apache.spark.sql.types.DataType) java.sql(java.sql) JdbcDialect(org.apache.spark.sql.jdbc.JdbcDialect) Dataset(org.apache.spark.sql.Dataset) LoggerFactory(org.slf4j.LoggerFactory) JdbcOptionsInWrite(org.apache.spark.sql.execution.datasources.jdbc.JdbcOptionsInWrite) MemoryMXBean(java.lang.management.MemoryMXBean) ImmutableList(com.google.common.collect.ImmutableList) Map(java.util.Map) SparkListenerTaskEnd(org.apache.spark.scheduler.SparkListenerTaskEnd) ManagementFactory(java.lang.management.ManagementFactory) SparkException(org.apache.spark.SparkException) MemoryUsage(java.lang.management.MemoryUsage) SparkSession(org.apache.spark.sql.SparkSession) DataTypes(org.apache.spark.sql.types.DataTypes) StructType(org.apache.spark.sql.types.StructType) SetAccumulator(org.apache.spark.sql.delta.util.SetAccumulator) Logger(org.slf4j.Logger) JdbcUtils(org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils) TaskContext(org.apache.spark.TaskContext) SparkListener(org.apache.spark.scheduler.SparkListener) SparkContext(org.apache.spark.SparkContext) Set(java.util.Set) SparkListenerStageSubmitted(org.apache.spark.scheduler.SparkListenerStageSubmitted) Row(org.apache.spark.sql.Row) Option(scala.Option) Instant(java.time.Instant) Maps(com.google.common.collect.Maps) InvocationTargetException(java.lang.reflect.InvocationTargetException) Objects(java.util.Objects) ExecutionException(java.util.concurrent.ExecutionException) List(java.util.List) Stream(java.util.stream.Stream) Preconditions(com.google.common.base.Preconditions) JdbcDialects(org.apache.spark.sql.jdbc.JdbcDialects)
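
The listener registered in this method treats a target table as complete once every Spark partition whose id maps onto it (partitionId % partNum) has reported through the accumulator: num1 counts the partition ids in the stage that belong to the group, num2 counts how many of them are already in the accumulator, and equality means the group is done. That check can be exercised without Spark; the sketch below is a standalone illustration with made-up numbers (8 partitions, 3 target tables), and the class and method names are hypothetical.

import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class CompletionCheck {

    // True when every partition id in [0, partitions) with id % partNum == group
    // is present in finished, i.e. the table for that group can be considered exported.
    static boolean groupDone(int partitions, int partNum, int group, Set<Integer> finished) {
        long expected = IntStream.range(0, partitions).filter(id -> id % partNum == group).count();
        long seen = finished.stream().filter(id -> id % partNum == group).count();
        return expected == seen;
    }

    public static void main(String[] args) {
        int partitions = 8; // hypothetical numTasks reported by onStageSubmitted
        int partNum = 3;    // hypothetical number of target tables
        // Partition ids that have already called collectionAc.add(taskId).
        Set<Integer> finished = IntStream.of(0, 3, 6, 1, 4).boxed().collect(Collectors.toSet());
        // Group 0 needs partitions {0, 3, 6}: all present, so it is done.
        System.out.println(groupDone(partitions, partNum, 0, finished)); // true
        // Group 1 needs partitions {1, 4, 7}: 7 is missing, so it is not done yet.
        System.out.println(groupDone(partitions, partNum, 1, finished)); // false
    }
}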

Example 3 with SetAccumulator

Use of org.apache.spark.sql.delta.util.SetAccumulator in project parent by Daytime-Don-t-Know-Dark-Night.

From the class Listen, the main method:

// Spark listener
public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().master("local[*]").getOrCreate();
    StructType schema = new StructType().add("date", "date").add("pay_no", "string").add("amount", "double").add("pay_time", "timestamp");
    Row row1 = RowFactory.create(Date.valueOf("2021-11-11"), "001", 1.0, Timestamp.from(Instant.now()));
    Row row2 = RowFactory.create(Date.valueOf("2021-11-11"), "002", 2.0, Timestamp.from(Instant.now()));
    Row row3 = RowFactory.create(Date.valueOf("2021-11-12"), "003", 3.0, Timestamp.from(Instant.now()));
    Row row4 = RowFactory.create(Date.valueOf("2021-11-12"), "004", 4.0, Timestamp.from(Instant.now()));
    Row row5 = RowFactory.create(Date.valueOf("2021-11-13"), "005", 5.0, Timestamp.from(Instant.now()));
    Dataset<Row> ds = spark.createDataFrame(ImmutableList.of(row1, row2, row3, row4, row5), schema);
    Set<Object> lastSet = Sets.newHashSet();
    SetAccumulator<Object> collectionAc = new SetAccumulator<>();
    collectionAc.register(SparkSession.active().sparkContext(), Option.apply("setAc"), false);
    SparkSession.active().sparkContext().addSparkListener(new SparkListener() {

        public void onTaskEnd(SparkListenerTaskEnd taskEnd) {
            // lastSet holds the values seen in the previous callback;
            // collectionAc.value() holds the current accumulator contents
            Set<Object> diffSet = Sets.difference(collectionAc.value(), lastSet).immutableCopy();
            if (!diffSet.isEmpty()) {
                for (Object obj : diffSet) {
                    System.out.println("Listener detected a difference..." + obj);
                }
                // Remember what was already reported so each value is printed only once
                lastSet.addAll(diffSet);
            }
        }
    });
    ds.repartition(col("date")).sortWithinPartitions(col("date")).foreachPartition(rows -> {
        Object a = null;
        while (rows.hasNext()) {
            Row row = rows.next();
            Object tmpPartNum = row.getAs("date");
            if (!Objects.equals(tmpPartNum, a)) {
                collectionAc.add(tmpPartNum);
                a = tmpPartNum;
            }
        }
    });
    ds.show(false);
}
Also used : SparkListener(org.apache.spark.scheduler.SparkListener) SparkSession(org.apache.spark.sql.SparkSession) Set(java.util.Set) StructType(org.apache.spark.sql.types.StructType) SparkListenerTaskEnd(org.apache.spark.scheduler.SparkListenerTaskEnd) Row(org.apache.spark.sql.Row) SetAccumulator(org.apache.spark.sql.delta.util.SetAccumulator)
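
Note that org.apache.spark.sql.delta.util.SetAccumulator is an internal Delta Lake utility rather than public Spark API. If a project would rather not depend on it, a set-valued accumulator with the same behaviour can be written against Spark's public AccumulatorV2 class. The sketch below is one minimal equivalent under that assumption (the class name JavaSetAccumulator is made up); it can be registered with spark.sparkContext().register(acc, "setAc") and then used like the accumulator in the examples above.

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import org.apache.spark.util.AccumulatorV2;

// Minimal set-valued accumulator built on Spark's public AccumulatorV2 API;
// a stand-in for the internal Delta SetAccumulator used in the examples above.
public class JavaSetAccumulator<T> extends AccumulatorV2<T, Set<T>> {

    private final Set<T> set = Collections.synchronizedSet(new HashSet<>());

    @Override
    public boolean isZero() {
        return set.isEmpty();
    }

    @Override
    public AccumulatorV2<T, Set<T>> copy() {
        JavaSetAccumulator<T> copy = new JavaSetAccumulator<>();
        copy.set.addAll(set);
        return copy;
    }

    @Override
    public void reset() {
        set.clear();
    }

    @Override
    public void add(T v) {
        set.add(v);
    }

    @Override
    public void merge(AccumulatorV2<T, Set<T>> other) {
        set.addAll(other.value());
    }

    @Override
    public Set<T> value() {
        return set;
    }
}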

Aggregations

Set (java.util.Set) 3
SparkListener (org.apache.spark.scheduler.SparkListener) 3
SparkListenerTaskEnd (org.apache.spark.scheduler.SparkListenerTaskEnd) 3
Row (org.apache.spark.sql.Row) 3
SetAccumulator (org.apache.spark.sql.delta.util.SetAccumulator) 3
StructType (org.apache.spark.sql.types.StructType) 3
MemoryMXBean (java.lang.management.MemoryMXBean) 2
MemoryUsage (java.lang.management.MemoryUsage) 2
SparkContext (org.apache.spark.SparkContext) 2
SparkSession (org.apache.spark.sql.SparkSession) 2
JdbcDialect (org.apache.spark.sql.jdbc.JdbcDialect) 2
DataType (org.apache.spark.sql.types.DataType) 2
Preconditions (com.google.common.base.Preconditions) 1
ImmutableList (com.google.common.collect.ImmutableList) 1
Maps (com.google.common.collect.Maps) 1
ManagementFactory (java.lang.management.ManagementFactory) 1
InvocationTargetException (java.lang.reflect.InvocationTargetException) 1
java.sql (java.sql) 1
Connection (java.sql.Connection) 1
ResultSet (java.sql.ResultSet) 1