使用状态缓存优化Flink的输出性能

2021-01-18

< view all posts

我们在AI项目的在线特征计算中遇到的一个问题是,Flink对于小步长(slide)的时间窗口,优化并不是很好,或者可以说就没有优化。举个例子,对于每个卡号,计算这张卡片在过去5分钟、15分钟、60分钟……的平均交易金额,并且需要较快的刷新速度,也就是说,不能每5分钟才更新一次,因为卡片在这5分钟以内可以做很多笔交易,线上交易尤其如此,对这些交易必须要算出接近实时的特征,才能保证后续的AI模型预测准确。

那么需要利用到的就是Flink的滑动窗口功能。以5分钟为例,可以定义长度为5分钟,步长为5秒的窗口,这样统计数据每5秒钟就会更新一次,代码示意:

DataStream<Tuple2<String, String>> statistics = messages.keyBy(0)
    .timeWindow(Time.seconds(300), Time.seconds(5))
    .aggregate(new Min5Aggregate())
    .setParallelism(20);

statistics.addSink(new CustomRedisSink()).setParallelism(20);

但这样带来的一个问题是,对于很多交易并不是很活跃的卡片,有大量的窗口计算是重复的,白白浪费了很多资源。

例如一张卡,每几小时或更长的时间才会有一次交易,但每当一次交易发生,Flink都会产生300/5=60个窗口。而这60个窗口由于在过期前无法接受到新的交易,最终的计算结果实际上是一模一样的。

我们之前使用的策略是,在向redis中写数据之前,比较输出值和redis中的已有值是否相同,如果相同则跳过写入。这样虽然避免了重复写入相同数据,但是依然要做非常多次的比较,也就是读redis。在性能测试中,确认了这个实现确实导致了性能瓶颈。

优化后的解决方案是,对输出流使用RichFlatMapFunction进行过滤,RichFlatMapFunction是可以有状态的,可以将之前的输出结果作为状态缓存起来,当判断输出和缓存的状态相同时,直接从输出流中滤去,使得重复的结果不会到达Redis Sink。

这里顺便说一下FlatMap,其实它和Map的区别就是,Map相当于FlatMap的一个特殊情况:Map没法把原来的数据单元拆开来,比如Map的输入是一个Tuple2,因为Map的输出是靠return来传递的,你没法return两次,把Tuple2给拆开来。而FlatMap的输出是用Collector.collect()来传递的,调用两次就能分别传出两个元素。

另外,RichFlatMapFunction的状态是支持过期时间的,我们可以通过配置过期时间来保证旧状态不会无限地占用内存。下面是RichFlatMapFunction的一个示例程序,Flink版本1.12:

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.java.io.TextInputFormat;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.core.fs.Path;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.source.FileProcessingMode;
import org.apache.flink.util.Collector;

public class FlinkStateFilterBeforeOutput {

    public static void main(String[] args) throws Exception {

        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(2);

        DataStream<Tuple2<String, String>> dataStream = env
                .readFile(new TextInputFormat(new Path()),
                        "D:\\tmp\\flinktestinput.txt",
                        FileProcessingMode.PROCESS_CONTINUOUSLY,
                        5000)
                .map(new Splitter())
                .keyBy(0)
                .flatMap(new RichFlatMapFilter(10));
        dataStream.print();

        env.execute("Window WordCount");
    }

    public static class Splitter implements MapFunction<String, Tuple2<String, String>> {
        @Override
        public Tuple2<String, String> map(String s) throws Exception {
            String name = s.split(" ")[0];
            String value = s.split(" ")[1];
            return new Tuple2<String, String>(name, value);
        }
    }

    public static class RichFlatMapFilter extends RichFlatMapFunction<Tuple2<String, String>, Tuple2<String, String>> {

        private ValueState<String> state;
        private final int ttlSeconds;

        public RichFlatMapFilter(int seconds) {
            ttlSeconds = seconds;
        }

        @Override
        public void open(Configuration config) {
            ValueStateDescriptor<String> stateDescriptor = new ValueStateDescriptor<String>("output", String.class);
            StateTtlConfig stateTtlConfig = StateTtlConfig.newBuilder(Time.seconds(ttlSeconds)).build();
            stateDescriptor.enableTimeToLive(stateTtlConfig);
            state = getRuntimeContext().getState(stateDescriptor);
        }

        @Override
        public void flatMap(Tuple2<String, String> input, Collector<Tuple2<String, String>> collector) throws Exception {
            if (state.value() == null) {
                state.update(input.f1);
                collector.collect(input);
            } else if (!state.value().equals(input.f1)) {
                state.update(input.f1);
                collector.collect(input);
            } else {
                collector.collect(new Tuple2<>("Out put of " + input.f0 + " " + input.f1, "is skiped"));
            }
        }
    }
}