package com.google.j2cl.transpiler.passes;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.j2cl.transpiler.ast.AbstractRewriter;
import com.google.j2cl.transpiler.ast.AstUtils;
import com.google.j2cl.transpiler.ast.CastExpression;
import com.google.j2cl.transpiler.ast.CompilationUnit;
import com.google.j2cl.transpiler.ast.DeclaredTypeDescriptor;
import com.google.j2cl.transpiler.ast.Expression;
import com.google.j2cl.transpiler.ast.ForEachStatement;
import com.google.j2cl.transpiler.ast.Statement;
import com.google.j2cl.transpiler.ast.TypeDescriptor;
import com.google.j2cl.transpiler.ast.TypeDescriptors;
import com.google.j2cl.transpiler.ast.TypeVariable;
import com.google.j2cl.transpiler.ast.UnionTypeDescriptor;
import java.util.Objects;
import java.util.stream.Stream;

/* loaded from: input_file:com/google/j2cl/transpiler/passes/NormalizeForEachIterable.class */
/**
 * Normalization pass that rewrites for-each statements whose iterable expression has a union
 * type, wrapping that expression in an explicit cast to {@code Iterable<T>}.
 *
 * <p>For-each statements whose iterable expression is not of a union type are left untouched.
 */
public class NormalizeForEachIterable extends NormalizationPass {
    @Override // com.google.j2cl.transpiler.passes.NormalizationPass
    public void applyTo(CompilationUnit compilationUnit) {
        compilationUnit.accept(
            new AbstractRewriter() {
                // This must carry the exact name of the rewriter hook; under any other name it
                // is just an unrelated method that the traversal never invokes.
                @Override
                public Statement rewriteForEachStatement(ForEachStatement forEachStatement) {
                    return normalizeIterable(forEachStatement);
                }
            });
    }

    /**
     * Returns {@code forEachStatement} unchanged unless its iterable expression has a union type,
     * in which case a copy is returned whose iterable expression is cast to {@code Iterable<T>}.
     *
     * @param forEachStatement the statement to normalize
     * @return the original statement, or a rebuilt one with the cast iterable expression
     */
    private static ForEachStatement normalizeIterable(ForEachStatement forEachStatement) {
        Expression iterableExpression = forEachStatement.getIterableExpression();
        TypeDescriptor iterableTypeDescriptor = iterableExpression.getTypeDescriptor();
        if (!iterableTypeDescriptor.isUnion()) {
            return forEachStatement;
        }

        // Guarded by isUnion() above.
        // NOTE(review): assumes isUnion() is true only for UnionTypeDescriptor instances - confirm.
        UnionTypeDescriptor unionTypeDescriptor = (UnionTypeDescriptor) iterableTypeDescriptor;
        TypeDescriptor elementTypeDescriptor =
            computeTargetElementType(
                unionTypeDescriptor, forEachStatement.getLoopVariable().getTypeDescriptor());

        return ForEachStatement.Builder.from(forEachStatement)
            .setIterableExpression(castToIterable(iterableExpression, elementTypeDescriptor))
            .build();
    }

    /**
     * Chooses the element type {@code T} to use for the {@code Iterable<T>} cast.
     *
     * <p>The iterable element type of every component of the union is collected. If they all
     * agree, or if the loop variable is primitive, one of the collected element types is returned
     * (after checking that they are identical up to nullability). Otherwise the loop variable's
     * own type is returned.
     *
     * @param unionTypeDescriptor the union type of the iterable expression
     * @param typeDescriptor the declared type of the loop variable
     * @return the element type to specialize {@code Iterable} with
     * @throws IllegalStateException if an element type is to be preserved but the components of
     *     the union disagree on it beyond nullability
     * @throws ClassCastException if a component of the union is not a {@link
     *     DeclaredTypeDescriptor}
     */
    private static TypeDescriptor computeTargetElementType(
            UnionTypeDescriptor unionTypeDescriptor, TypeDescriptor typeDescriptor) {
        ImmutableSet<TypeDescriptor> elementTypeDescriptors =
            unionTypeDescriptor.getUnionTypeDescriptors().stream()
                .map(DeclaredTypeDescriptor.class::cast)
                .map(AstUtils::getIterableElement)
                .collect(ImmutableSet.toImmutableSet());

        if (elementTypeDescriptors.size() == 1 || typeDescriptor.isPrimitive()) {
            // The set may hold several entries that differ only in nullability; that is
            // acceptable, anything else is an inconsistent state.
            Preconditions.checkState(
                elementTypeDescriptors.stream().map(TypeDescriptor::toNullable).distinct().count()
                    == 1);
            // Non-empty at this point: the check above fails on an empty set.
            return elementTypeDescriptors.iterator().next();
        }

        // The components disagree on the element type; fall back to the loop variable's type.
        return typeDescriptor;
    }

    /**
     * Wraps {@code expression} in a cast to {@code java.lang.Iterable} specialized with the given
     * element type.
     *
     * @param expression the expression to cast
     * @param typeDescriptor the type argument to substitute for {@code Iterable}'s single type
     *     parameter
     * @return a cast expression of type {@code Iterable<typeDescriptor>}
     */
    private static Expression castToIterable(Expression expression, TypeDescriptor typeDescriptor) {
        DeclaredTypeDescriptor iterableTypeDescriptor = TypeDescriptors.get().javaLangIterable;
        // getOnlyElement throws unless there is exactly one type parameter.
        TypeVariable iterableTypeParameter =
            Iterables.getOnlyElement(
                iterableTypeDescriptor.getTypeDeclaration().getTypeParameterDescriptors());

        return CastExpression.newBuilder()
            .setCastTypeDescriptor(
                iterableTypeDescriptor.specializeTypeVariables(
                    ImmutableMap.of(iterableTypeParameter, typeDescriptor)))
            .setExpression(expression)
            .build();
    }
}
