2021-01-18
< view all posts我们在AI项目的在线特征计算中遇到的一个问题是,Flink对于小步长(slide)的时间窗口,优化并不是很好,或者可以说就没有优化。举个例子,对于每个卡号,计算这张卡片在过去5分钟、15分钟、60分钟……的平均交易金额,并且需要较快的刷新速度,也就是说,不能每5分钟才更新一次,因为卡片在这5分钟以内可以做很多笔交易,线上交易尤其如此,对这些交易必须要算出接近实时的特征,才能保证后续的AI模型预测准确。
那么需要利用到的就是Flink的滑动窗口功能。以5分钟为例,可以定义长度为5分钟,步长为5秒的窗口,这样统计数据每5秒钟就会更新一次,代码示意:
DataStream<Tuple2<String, String>> statistics = messages.keyBy(0) .timeWindow(Time.seconds(300), Time.seconds(5)) .aggregate(new Min5Aggregate()) .setParallelism(20); statistics.addSink(new CustomRedisSink()).setParallelism(20);
但这样带来的一个问题是,对于很多交易并不是很活跃的卡片,有大量的窗口计算是重复的,白白浪费了很多资源。
例如一张卡,每几小时或更长的时间才会有一次交易,但每当一次交易发生,Flink都会产生300/5=60个窗口。而这60个窗口由于在过期前无法接受到新的交易,最终的计算结果实际上是一模一样的。
我们之前使用的策略是,在向redis中写数据之前,比较输出值和redis中的已有值是否相同,如果相同则跳过写入。这样虽然避免了重复写入相同数据,但是依然要做非常多次的比较,也就是读redis。在性能测试中,确认了这个实现确实导致了性能瓶颈。
优化后的解决方案是,对输出流使用RichFlatMapFunction进行过滤,RichFlatMapFunction是可以有状态的,可以将之前的输出结果作为状态缓存起来,当判断输出和缓存的状态相同时,直接从输出流中滤去,使得重复的结果不会到达Redis Sink。
这里顺便说一下FlatMap,其实它和Map的区别就是,Map相当于FlatMap的一个特殊情况:Map没法把原来的数据单元拆开来,比如Map的输入是一个Tuple2,因为Map的输出是靠return来传递的,你没法return两次,把Tuple2给拆开来。而FlatMap的输出是用Collector.collect()来传递的,调用两次就能分别传出两个元素。
另外,RichFlatMapFunction的状态是支持过期时间的,我们可以通过配置过期时间来保证旧状态不会无限地占用内存。下面是RichFlatMapFunction的一个示例程序,Flink版本1.12:
import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.common.state.StateTtlConfig; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.time.Time; import org.apache.flink.api.java.io.TextInputFormat; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.core.fs.Path; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.configuration.Configuration; import org.apache.flink.streaming.api.functions.source.FileProcessingMode; import org.apache.flink.util.Collector; public class FlinkStateFilterBeforeOutput { public static void main(String[] args) throws Exception { StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); env.setParallelism(2); DataStream<Tuple2<String, String>> dataStream = env .readFile(new TextInputFormat(new Path()), "D:\\tmp\\flinktestinput.txt", FileProcessingMode.PROCESS_CONTINUOUSLY, 5000) .map(new Splitter()) .keyBy(0) .flatMap(new RichFlatMapFilter(10)); dataStream.print(); env.execute("Window WordCount"); } public static class Splitter implements MapFunction<String, Tuple2<String, String>> { @Override public Tuple2<String, String> map(String s) throws Exception { String name = s.split(" ")[0]; String value = s.split(" ")[1]; return new Tuple2<String, String>(name, value); } } public static class RichFlatMapFilter extends RichFlatMapFunction<Tuple2<String, String>, Tuple2<String, String>> { private ValueState<String> state; private final int ttlSeconds; public RichFlatMapFilter(int seconds) { ttlSeconds = seconds; } @Override public void open(Configuration config) { ValueStateDescriptor<String> stateDescriptor = new ValueStateDescriptor<String>("output", String.class); StateTtlConfig stateTtlConfig = StateTtlConfig.newBuilder(Time.seconds(ttlSeconds)).build(); stateDescriptor.enableTimeToLive(stateTtlConfig); state = getRuntimeContext().getState(stateDescriptor); } @Override public void flatMap(Tuple2<String, String> input, Collector<Tuple2<String, String>> collector) throws Exception { if (state.value() == null) { state.update(input.f1); collector.collect(input); } else if (!state.value().equals(input.f1)) { state.update(input.f1); collector.collect(input); } else { collector.collect(new Tuple2<>("Out put of " + input.f0 + " " + input.f1, "is skiped")); } } } }