/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.opensearch.planner.physical;

import java.util.List;
import java.util.function.Predicate;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.AbstractRelNode;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.immutables.value.Value;
import org.opensearch.sql.calcite.type.ExprSqlType;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.expression.function.PPLBuiltinOperators;
import org.opensearch.sql.opensearch.planner.physical.ImmutableOpenSearchAggregateIndexScanRule;
import org.opensearch.sql.opensearch.planner.physical.OpenSearchIndexScanRule;
import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan;

@Value.Enclosing
public class OpenSearchAggregateIndexScanRule
extends RelRule<Config> {
    protected OpenSearchAggregateIndexScanRule(Config config) {
        super((RelRule.Config)config);
    }

    public void onMatch(RelOptRuleCall call) {
        if (call.rels.length == 3) {
            LogicalAggregate aggregate = (LogicalAggregate)call.rel(0);
            LogicalProject project = (LogicalProject)call.rel(1);
            CalciteLogicalIndexScan scan = (CalciteLogicalIndexScan)call.rel(2);
            if (aggregate.getGroupSet().length() > 1 && Config.containsWidthBucketFuncOnDate(project)) {
                return;
            }
            this.apply(call, aggregate, project, scan);
        } else if (call.rels.length == 2) {
            LogicalAggregate aggregate = (LogicalAggregate)call.rel(0);
            CalciteLogicalIndexScan scan = (CalciteLogicalIndexScan)call.rel(1);
            this.apply(call, aggregate, null, scan);
        } else {
            throw new AssertionError((Object)String.format("The length of rels should be %s but got %s", this.operands.size(), call.rels.length));
        }
    }

    protected void apply(RelOptRuleCall call, LogicalAggregate aggregate, LogicalProject project, CalciteLogicalIndexScan scan) {
        AbstractRelNode newRelNode = scan.pushDownAggregate((Aggregate)aggregate, (Project)project);
        if (newRelNode != null) {
            call.transformTo((RelNode)newRelNode);
        }
    }

    @Value.Immutable
    public static interface Config
    extends RelRule.Config {
        public static final Config DEFAULT = ImmutableOpenSearchAggregateIndexScanRule.Config.builder().build().withDescription("Agg-Project-TableScan").withOperandSupplier(b0 -> b0.operand(LogicalAggregate.class).oneInput(b1 -> b1.operand(LogicalProject.class).predicate(Predicate.not(OpenSearchIndexScanRule::containsRexOver).and(OpenSearchIndexScanRule::distinctProjectList).or(Config::containsWidthBucketFuncOnDate)).oneInput(b2 -> b2.operand(CalciteLogicalIndexScan.class).predicate(Predicate.not(OpenSearchIndexScanRule::isLimitPushed).and(OpenSearchIndexScanRule::noAggregatePushed)).noInputs())));
        public static final Config COUNT_STAR = ImmutableOpenSearchAggregateIndexScanRule.Config.builder().build().withDescription("Agg[count()]-TableScan").withOperandSupplier(b0 -> b0.operand(LogicalAggregate.class).predicate(agg -> agg.getGroupSet().isEmpty() && agg.getAggCallList().stream().allMatch(call -> call.getAggregation().kind == SqlKind.COUNT && call.getArgList().isEmpty())).oneInput(b1 -> b1.operand(CalciteLogicalIndexScan.class).predicate(Predicate.not(OpenSearchIndexScanRule::isLimitPushed).and(OpenSearchIndexScanRule::noAggregatePushed)).noInputs()));

        default public OpenSearchAggregateIndexScanRule toRule() {
            return new OpenSearchAggregateIndexScanRule(this);
        }

        public static boolean containsWidthBucketFuncOnDate(LogicalProject project) {
            return project.getProjects().stream().anyMatch(expr -> {
                RexCall rexCall;
                return expr instanceof RexCall && (rexCall = (RexCall)expr).getOperator().equals((Object)PPLBuiltinOperators.WIDTH_BUCKET) && Config.dateRelatedType(((RexNode)rexCall.getOperands().getFirst()).getType());
            });
        }

        /*
         * Enabled force condition propagation
         * Lifted jumps to return sites
         */
        public static boolean dateRelatedType(RelDataType type) {
            if (!(type instanceof ExprSqlType)) return false;
            ExprSqlType exprSqlType = (ExprSqlType)type;
            if (!List.of(OpenSearchTypeFactory.ExprUDT.EXPR_DATE, OpenSearchTypeFactory.ExprUDT.EXPR_TIME, OpenSearchTypeFactory.ExprUDT.EXPR_TIMESTAMP).contains((Object)exprSqlType.getUdt())) return false;
            return true;
        }
    }
}

