package org.apache.flink.table.planner.plan.rules.physical.stream;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelDistribution;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexProgramBuilder;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Pair;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalCalc;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalChangelogNormalize;
import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalExchange;
import org.apache.flink.table.planner.plan.rules.physical.stream.ImmutablePushCalcPastChangelogNormalizeRule;
import org.apache.flink.table.planner.plan.trait.FlinkRelDistribution;
import org.apache.flink.table.planner.plan.utils.RexNodeExtractor;
import org.immutables.value.Value;

/**
 * Planner rule that matches a {@link StreamPhysicalCalc} whose input is a {@link
 * StreamPhysicalChangelogNormalize} whose input is a {@link StreamPhysicalExchange}, and moves as
 * much of the calc as possible below the changelog-normalize node.
 *
 * <p>Two things are pushed down:
 *
 * <ul>
 *   <li>the conjuncts of the calc condition that reference only fields listed in {@link
 *       StreamPhysicalChangelogNormalize#uniqueKeys()};
 *   <li>a projection that keeps only the input fields the calc actually references, plus the
 *       unique-key fields themselves.
 * </ul>
 *
 * <p>Whatever cannot be pushed (remaining conjuncts and the original projections) stays in a calc
 * on top of the rewritten changelog-normalize node, with its input references re-indexed to the
 * narrowed row type.
 */
@Internal
@Value.Enclosing
public class PushCalcPastChangelogNormalizeRule
        extends RelRule<PushCalcPastChangelogNormalizeRule.Config> {

    /** Rule instance built from {@link Config#DEFAULT}. */
    public static final RelOptRule INSTANCE = new PushCalcPastChangelogNormalizeRule(Config.DEFAULT);

    /** Rule configuration describing the Calc / ChangelogNormalize / Exchange operand tree. */
    @Value.Immutable(singleton = false)
    public interface Config extends RelRule.Config {
        Config DEFAULT =
                ImmutablePushCalcPastChangelogNormalizeRule.Config.builder().build().onMatch();

        @Override
        default RelOptRule toRule() {
            return new PushCalcPastChangelogNormalizeRule(this);
        }

        /**
         * Returns a config whose operand pattern is {@code StreamPhysicalCalc} over {@code
         * StreamPhysicalChangelogNormalize} over {@code StreamPhysicalExchange} (any inputs).
         */
        default Config onMatch() {
            final RelRule.OperandTransform exchangeTransform =
                    operandBuilder ->
                            operandBuilder.operand(StreamPhysicalExchange.class).anyInputs();

            final RelRule.OperandTransform changelogNormalizeTransform =
                    operandBuilder ->
                            operandBuilder
                                    .operand(StreamPhysicalChangelogNormalize.class)
                                    .oneInput(exchangeTransform);

            final RelRule.OperandTransform calcTransform =
                    operandBuilder ->
                            operandBuilder
                                    .operand(StreamPhysicalCalc.class)
                                    .oneInput(changelogNormalizeTransform);

            return withOperandSupplier(calcTransform).as(Config.class);
        }
    }

    public PushCalcPastChangelogNormalizeRule(Config config) {
        super(config);
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final StreamPhysicalCalc calc = call.rel(0);
        final StreamPhysicalChangelogNormalize changelogNormalize = call.rel(1);
        final Set<Integer> primaryKeyIndices =
                IntStream.of(changelogNormalize.uniqueKeys()).boxed().collect(Collectors.toSet());

        // Split the condition (in CNF) into conjuncts that touch only key fields and the rest.
        final List<RexNode> primaryKeyPredicates = new ArrayList<>();
        final List<RexNode> otherPredicates = new ArrayList<>();
        final RexProgram program = calc.getProgram();
        if (program.getCondition() != null) {
            final RexNode condition =
                    RexUtil.toCnf(
                            call.builder().getRexBuilder(),
                            program.expandLocalRef(program.getCondition()));
            partitionPrimaryKeyPredicates(
                    RelOptUtil.conjunctions(condition),
                    primaryKeyIndices,
                    primaryKeyPredicates,
                    otherPredicates);
        }

        final int[] usedInputFields = extractUsedInputFields(calc, primaryKeyIndices);

        // New ChangelogNormalize with the key-only predicates and the field pruning below it.
        final StreamPhysicalChangelogNormalize newChangelogNormalize =
                pushCalcThroughChangelogNormalize(call, primaryKeyPredicates, usedInputFields);

        // Keep the predicates that were not pushed, and the projections, on top.
        transformWithRemainingPredicates(
                call, newChangelogNormalize, otherPredicates, usedInputFields);
    }

    /**
     * Collects, in ascending order, the input field indices referenced by the calc's projections
     * and condition, always including the given key indices.
     *
     * @param calc calc whose program is inspected
     * @param primaryKeyIndices indices that must be retained even if the calc does not use them
     * @return sorted, de-duplicated field indices
     */
    private int[] extractUsedInputFields(StreamPhysicalCalc calc, Set<Integer> primaryKeyIndices) {
        final RexProgram program = calc.getProgram();
        final List<RexNode> projectsAndCondition =
                program.getProjectList().stream()
                        .map(program::expandLocalRef)
                        .collect(Collectors.toList());
        if (program.getCondition() != null) {
            projectsAndCondition.add(program.expandLocalRef(program.getCondition()));
        }
        final Set<Integer> usedFields =
                Arrays.stream(RexNodeExtractor.extractRefInputFields(projectsAndCondition))
                        .boxed()
                        .collect(Collectors.toSet());
        // Key fields are kept regardless of whether the calc references them.
        usedFields.addAll(primaryKeyIndices);
        return usedFields.stream().sorted().mapToInt(Integer::intValue).toArray();
    }

    /**
     * Distributes each predicate into one of the two output lists: predicates whose referenced
     * input fields are all contained in {@code primaryKeyIndices} go to {@code
     * primaryKeyPredicates}; every other predicate goes to {@code remainingPredicates}. A
     * predicate that references no input field at all therefore lands in {@code
     * primaryKeyPredicates}.
     */
    private void partitionPrimaryKeyPredicates(
            List<RexNode> predicates,
            Set<Integer> primaryKeyIndices,
            List<RexNode> primaryKeyPredicates,
            List<RexNode> remainingPredicates) {
        for (RexNode predicate : predicates) {
            final int[] referencedFields =
                    RexNodeExtractor.extractRefInputFields(Collections.singletonList(predicate));
            if (Arrays.stream(referencedFields).allMatch(primaryKeyIndices::contains)) {
                primaryKeyPredicates.add(predicate);
            } else {
                remainingPredicates.add(predicate);
            }
        }
    }

    /**
     * Builds the replacement ChangelogNormalize: a new calc (pushed predicates + used-field
     * projection) is placed under the exchange, and the exchange distribution and the unique keys
     * are re-indexed to the projected row.
     *
     * <p>If nothing would be pushed (no predicates and every field is used) the matched
     * ChangelogNormalize is returned unchanged.
     */
    private StreamPhysicalChangelogNormalize pushCalcThroughChangelogNormalize(
            RelOptRuleCall call, List<RexNode> primaryKeyPredicates, int[] usedInputFields) {
        final StreamPhysicalChangelogNormalize changelogNormalize = call.rel(1);
        final StreamPhysicalExchange exchange = call.rel(2);
        final Set<Integer> primaryKeyIndices =
                IntStream.of(changelogNormalize.uniqueKeys()).boxed().collect(Collectors.toSet());

        if (primaryKeyPredicates.isEmpty()
                && usedInputFields.length == changelogNormalize.getRowType().getFieldCount()) {
            return changelogNormalize;
        }

        final StreamPhysicalCalc pushedCalc =
                projectUsedFieldsWithConditions(
                        call.builder(), exchange.getInput(), primaryKeyPredicates, usedInputFields);

        // Translate the key indices from the original row to the projected row.
        final Map<Integer, Integer> oldToNewIndex = buildFieldsMapping(usedInputFields);
        final List<Integer> shrunkPrimaryKeyIndices =
                primaryKeyIndices.stream().map(oldToNewIndex::get).collect(Collectors.toList());

        final FlinkRelDistribution newDistribution =
                FlinkRelDistribution.hash(shrunkPrimaryKeyIndices, true);
        final RelNode newExchange =
                exchange.copy(
                        exchange.getTraitSet().replace(newDistribution),
                        pushedCalc,
                        newDistribution);

        return (StreamPhysicalChangelogNormalize)
                changelogNormalize.copy(
                        changelogNormalize.getTraitSet(),
                        newExchange,
                        shrunkPrimaryKeyIndices.stream().mapToInt(Integer::intValue).toArray());
    }

    /**
     * Creates a calc over {@code input} that projects exactly {@code usedInputFields} (keeping the
     * input's field names) and filters by the conjunction of {@code conditions}. No condition is
     * added when that conjunction is always true. The new calc reuses the input's cluster and
     * trait set.
     */
    private StreamPhysicalCalc projectUsedFieldsWithConditions(
            RelBuilder relBuilder, RelNode input, List<RexNode> conditions, int[] usedInputFields) {
        final RelDataType inputRowType = input.getRowType();
        final List<String> inputFieldNames = inputRowType.getFieldNames();
        final RexProgramBuilder programBuilder =
                new RexProgramBuilder(inputRowType, relBuilder.getRexBuilder());

        for (int fieldIndex : usedInputFields) {
            programBuilder.addProject(
                    programBuilder.makeInputRef(fieldIndex), inputFieldNames.get(fieldIndex));
        }

        final RexNode condition = relBuilder.and(conditions);
        if (!condition.isAlwaysTrue()) {
            programBuilder.addCondition(condition);
        }

        final RexProgram newProgram = programBuilder.getProgram();
        return new StreamPhysicalCalc(
                input.getCluster(),
                input.getTraitSet(),
                input,
                newProgram,
                newProgram.getOutputRowType());
    }

    /**
     * Registers the final result of the rule: the original projections and the non-pushed
     * predicates are rebuilt against the row type of {@code changelogNormalize}, with input
     * references re-indexed via {@code usedInputFields}. If the rebuilt program is trivial the
     * ChangelogNormalize itself is the result; otherwise a new calc is put on top of it.
     */
    private void transformWithRemainingPredicates(
            RelOptRuleCall call,
            StreamPhysicalChangelogNormalize changelogNormalize,
            List<RexNode> predicates,
            int[] usedInputFields) {
        final StreamPhysicalCalc calc = call.rel(0);
        final RelBuilder relBuilder = call.builder();
        final RexProgram originalProgram = calc.getProgram();
        final RexProgramBuilder programBuilder =
                new RexProgramBuilder(changelogNormalize.getRowType(), relBuilder.getRexBuilder());

        final Map<Integer, Integer> oldToNewIndex = buildFieldsMapping(usedInputFields);

        for (Pair<RexLocalRef, String> namedProject : originalProgram.getNamedProjects()) {
            final RexNode expandedProject = originalProgram.expandLocalRef(namedProject.left);
            programBuilder.addProject(
                    adjustInputRef(expandedProject, oldToNewIndex), namedProject.right);
        }

        final List<RexNode> adjustedPredicates =
                predicates.stream()
                        .map(predicate -> adjustInputRef(predicate, oldToNewIndex))
                        .collect(Collectors.toList());
        final RexNode condition = relBuilder.and(adjustedPredicates);
        if (!condition.isAlwaysTrue()) {
            programBuilder.addCondition(condition);
        }

        final RexProgram newProgram = programBuilder.getProgram();
        if (newProgram.isTrivial()) {
            call.transformTo(changelogNormalize);
        } else {
            call.transformTo(
                    new StreamPhysicalCalc(
                            changelogNormalize.getCluster(),
                            changelogNormalize.getTraitSet(),
                            changelogNormalize,
                            newProgram,
                            newProgram.getOutputRowType()));
        }
    }

    /**
     * Rewrites every {@link RexInputRef} inside {@code expr} so that its index is replaced by the
     * value mapped to it in {@code fieldsOldToNewIndexMapping}; the reference type is preserved.
     * An index missing from the mapping results in a {@link NullPointerException} on unboxing.
     */
    private RexNode adjustInputRef(
            RexNode expr, final Map<Integer, Integer> fieldsOldToNewIndexMapping) {
        return expr.accept(
                new RexShuttle() {
                    @Override
                    public RexNode visitInputRef(RexInputRef inputRef) {
                        final Integer newIndex =
                                fieldsOldToNewIndexMapping.get(inputRef.getIndex());
                        return new RexInputRef(newIndex, inputRef.getType());
                    }
                });
    }

    /**
     * Maps each value of {@code usedInputFields} (an index in the original row) to its position
     * in the array (the index in the projected row).
     */
    private Map<Integer, Integer> buildFieldsMapping(int[] usedInputFields) {
        final Map<Integer, Integer> oldToNewIndex = new HashMap<>();
        for (int newIndex = 0; newIndex < usedInputFields.length; newIndex++) {
            oldToNewIndex.put(usedInputFields[newIndex], newIndex);
        }
        return oldToNewIndex;
    }
}
