/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.opensearch.util;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import lombok.Generated;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelFieldCollation;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelDataTypeFieldImpl;
import org.apache.calcite.rex.RexBiVisitorImpl;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan;
import org.opensearch.sql.opensearch.storage.scan.context.PushDownType;
import org.opensearch.sql.opensearch.storage.scan.context.SortExprDigest;

public final class OpenSearchRelOptUtil {

    /**
     * Checks whether {@code expr} orders rows the same way as (or exactly opposite to) a single
     * input column, so that a sort on that column can stand in for a sort on {@code expr}.
     *
     * <p>Recognized shapes, applied recursively:
     *
     * <ul>
     *   <li>a bare input reference;
     *   <li>unary {@code +} (order kept) and unary {@code -} (order reversed);
     *   <li>binary {@code +} / {@code -} where exactly one operand is a literal; only
     *       {@code literal - x} reverses the order;
     *   <li>{@code *} by a non-null, non-zero literal; a negative literal reverses the order;
     *   <li>{@code CAST} / {@code SAFE_CAST} accepted by {@link #isOrderPreservingCast}.
     * </ul>
     *
     * @param expr expression to analyze
     * @return {@code (inputIndex, reversed)}, where {@code reversed} is {@code true} when ascending
     *     order of {@code expr} corresponds to descending order of input {@code inputIndex}; empty
     *     for any other expression shape
     */
    public static Optional<Pair<Integer, Boolean>> getOrderEquivalentInputInfo(RexNode expr) {
        switch (expr.getKind()) {
            case INPUT_REF:
                return Optional.of(Pair.of(((RexInputRef) expr).getIndex(), false));
            case PLUS_PREFIX:
                return getOrderEquivalentInputInfo(operand(expr, 0));
            case MINUS_PREFIX:
                // Negation reverses the order of its operand.
                return getOrderEquivalentInputInfo(operand(expr, 0))
                        .map(info -> toggleReversed(info, true));
            case PLUS:
            case MINUS: {
                RexNode left = operand(expr, 0);
                RexNode right = operand(expr, 1);
                boolean leftIsLiteral = left.isA(SqlKind.LITERAL);
                boolean rightIsLiteral = right.isA(SqlKind.LITERAL);
                // Exactly one side must be a literal; two variables or two literals are rejected.
                if (leftIsLiteral == rightIsLiteral) {
                    return Optional.empty();
                }
                RexNode variable = leftIsLiteral ? right : left;
                // Only "literal - x" reverses the order; "x - literal" and both "+" forms keep it.
                boolean flip = expr.getKind() == SqlKind.MINUS && leftIsLiteral;
                return getOrderEquivalentInputInfo(variable).map(info -> toggleReversed(info, flip));
            }
            case TIMES: {
                RexNode left = operand(expr, 0);
                RexNode right = operand(expr, 1);
                RexNode literal;
                RexNode variable;
                if (left.isA(SqlKind.LITERAL)) {
                    literal = left;
                    variable = right;
                } else if (right.isA(SqlKind.LITERAL)) {
                    literal = right;
                    variable = left;
                } else {
                    return Optional.empty();
                }
                BigDecimal factor = ((RexLiteral) literal).getValueAs(BigDecimal.class);
                // Multiplying by zero yields a constant, which carries no ordering information.
                if (factor == null || factor.signum() == 0) {
                    return Optional.empty();
                }
                boolean flip = factor.signum() < 0;
                return getOrderEquivalentInputInfo(variable).map(info -> toggleReversed(info, flip));
            }
            case CAST:
            case SAFE_CAST: {
                RexNode child = operand(expr, 0);
                if (!isOrderPreservingCast(child.getType(), expr.getType())) {
                    return Optional.empty();
                }
                return getOrderEquivalentInputInfo(child);
            }
            default:
                return Optional.empty();
        }
    }

    /** Returns operand {@code index} of {@code expr}; callers only use this on {@link RexCall}s. */
    private static RexNode operand(RexNode expr, int index) {
        return ((RexCall) expr).getOperands().get(index);
    }

    /** Copies {@code info}, inverting its reversal flag when {@code flip} is {@code true}. */
    private static Pair<Integer, Boolean> toggleReversed(Pair<Integer, Boolean> info, boolean flip) {
        boolean reversed = info.getRight();
        return Pair.of(info.getLeft(), flip != reversed);
    }

    /**
     * Decides whether a source sort key satisfies a target sort key, given the target's
     * order-equivalence information from {@link #getOrderEquivalentInputInfo}.
     *
     * <p>The source must sort on the equivalent input index. Its direction must equal the target
     * direction, or the reverse of it when the equivalence is flagged as reversed.
     * NOTE(review): null direction is not compared here.
     *
     * @param sourceFieldCollation collation the source already provides
     * @param targetFieldCollation collation that is required
     * @param orderEquivInfo equivalence info for the target key; {@code null} or empty means
     *     "no known equivalence"
     * @return {@code true} if the source collation can serve the target collation
     */
    public static boolean sourceCollationSatisfiesTargetCollation(
            RelFieldCollation sourceFieldCollation,
            RelFieldCollation targetFieldCollation,
            Optional<Pair<Integer, Boolean>> orderEquivInfo) {
        // A null argument typically comes from a Map#get miss; treat it like "no equivalence".
        if (orderEquivInfo == null || orderEquivInfo.isEmpty()) {
            return false;
        }
        Pair<Integer, Boolean> info = orderEquivInfo.get();
        int equivalentSourceIndex = info.getLeft();
        RelFieldCollation.Direction targetDirection = targetFieldCollation.getDirection();
        RelFieldCollation.Direction requiredSourceDirection =
                info.getRight() ? targetDirection.reverse() : targetDirection;
        return equivalentSourceIndex == sourceFieldCollation.getFieldIndex()
                && requiredSourceDirection == sourceFieldCollation.getDirection();
    }

    /**
     * Conservatively decides whether casting from {@code src} to {@code dst} keeps the relative
     * order of values. Any combination not explicitly recognized returns {@code false}.
     */
    private static boolean isOrderPreservingCast(RelDataType src, RelDataType dst) {
        SqlTypeName srcType = src.getSqlTypeName();
        SqlTypeName dstType = dst.getSqlTypeName();
        if (SqlTypeUtil.isIntType(src) && SqlTypeUtil.isApproximateNumeric(dst)) {
            // Accept only when the integer width is no larger than the significand width assumed
            // for the floating-point destination, i.e. the conversion is exact.
            int intBits = switch (srcType) {
                case TINYINT -> 8;
                case SMALLINT -> 16;
                case INTEGER -> 32;
                case BIGINT -> 64;
                default -> 0;
            };
            int significandBits = switch (dstType) {
                case FLOAT -> 24;
                case DOUBLE -> 53;
                default -> 0;
            };
            return intBits > 0 && significandBits > 0 && intBits <= significandBits;
        }
        if (SqlTypeUtil.isExactNumeric(src) && SqlTypeUtil.isExactNumeric(dst)) {
            // NOTE(review): compares total precision only, not integer digits (precision - scale).
            return dst.getPrecision() >= src.getPrecision();
        }
        if (SqlTypeUtil.isCharacter(src) && SqlTypeUtil.isCharacter(dst)) {
            int srcLength = src.getPrecision();
            int dstLength = dst.getPrecision();
            // A destination length of -1 is accepted as unbounded.
            // NOTE(review): a source length of -1 also passes against any bounded destination.
            return dstLength >= srcLength || dstLength == -1;
        }
        boolean dstIsTimestamp =
                dstType == SqlTypeName.TIMESTAMP
                        || dstType == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
        if ((srcType == SqlTypeName.DATE || srcType == SqlTypeName.TIME) && dstIsTimestamp) {
            return true;
        }
        if (srcType == dstType) {
            return dst.getPrecision() >= src.getPrecision() && dst.getScale() >= src.getScale();
        }
        return false;
    }

    /**
     * Rebuilds {@code rowType} with dot-free field names, keeping each field's index and type.
     *
     * @param typeFactory factory used to build the new row type
     * @param rowType row type whose field names may contain {@code '.'}
     * @return a row type whose names come from {@link #resolveColumnNameConflicts}
     */
    public static RelDataType replaceDot(RelDataTypeFactory typeFactory, RelDataType rowType) {
        List<RelDataTypeField> fields = rowType.getFieldList();
        List<String> originalNames = new ArrayList<>(fields.size());
        for (RelDataTypeField field : fields) {
            originalNames.add(field.getName());
        }
        List<String> resolvedNames = resolveColumnNameConflicts(originalNames);
        RelDataTypeFactory.Builder builder = typeFactory.builder();
        for (int i = 0; i < fields.size(); i++) {
            RelDataTypeField field = fields.get(i);
            builder.add(
                    new RelDataTypeFieldImpl(resolvedNames.get(i), field.getIndex(), field.getType()));
        }
        return builder.build();
    }

    /**
     * Replaces {@code '.'} with {@code '_'} in every name containing a dot.
     *
     * <p>If the rewritten name collides with any original name or with a name generated earlier in
     * the list, the smallest free numeric suffix ({@code 0}, {@code 1}, ...) is appended. Names
     * without a dot are returned unchanged. For example, {@code ["a.b", "a_b"]} becomes
     * {@code ["a_b0", "a_b"]}.
     *
     * @param originalNames names to resolve; not modified
     * @return a new list of the same size with resolved names
     */
    public static List<String> resolveColumnNameConflicts(List<String> originalNames) {
        List<String> result = new ArrayList<>(originalNames);
        // Seed with every original name so generated names never shadow an existing column.
        Set<String> usedNames = new HashSet<>(originalNames);
        for (int i = 0; i < originalNames.size(); i++) {
            String originalName = originalNames.get(i);
            if (!originalName.contains(".")) {
                continue;
            }
            String newName = generateUniqueName(originalName.replace('.', '_'), usedNames);
            result.set(i, newName);
            usedNames.add(newName);
        }
        return result;
    }

    /**
     * Returns {@code baseName} if unused, otherwise {@code baseName} followed by the smallest
     * non-negative integer that yields an unused name. Does not modify {@code usedNames}.
     */
    private static String generateUniqueName(String baseName, Set<String> usedNames) {
        if (!usedNames.contains(baseName)) {
            return baseName;
        }
        for (int suffix = 0; ; suffix++) {
            String candidate = baseName + suffix;
            if (!usedNames.contains(candidate)) {
                return candidate;
            }
        }
    }

    /**
     * Checks whether the sort expressions already pushed into {@code scan} can provide
     * {@code toCollation} over the output of {@code project}.
     *
     * <p>The i-th required field collation is matched against the i-th pushed-down sort-expression
     * digest. Each pair is accepted in one of two cases:
     *
     * <ul>
     *   <li>the scan's effective sort expression equals the projected expression, and the
     *       direction and null direction match;
     *   <li>the scan sorts on a plain input reference, the projected expression is a call, and
     *       {@link #sourceCollationSatisfiesTargetCollation} accepts it using
     *       {@code orderEquivInfoMap}.
     * </ul>
     *
     * @param scan index scan whose push-down context is inspected
     * @param project projection whose expressions the required collation refers to
     * @param toCollation required collation, indexed by project output field
     * @param orderEquivInfoMap order-equivalence info keyed by project output field index; a
     *     missing key is treated as "no equivalence"
     * @return {@code true} only if every required field collation is satisfied
     */
    public static boolean canScanProvideSortCollation(
            AbstractCalciteIndexScan scan,
            Project project,
            RelCollation toCollation,
            Map<Integer, Optional<Pair<Integer, Boolean>>> orderEquivInfoMap) {
        if (scan.getPushDownContext().stream()
                .noneMatch(operation -> operation.type() == PushDownType.SORT_EXPR)) {
            return false;
        }
        List<?> sortExprDigests =
                (List<?>) scan.getPushDownContext().getDigestByType(PushDownType.SORT_EXPR);
        List<RelFieldCollation> requiredCollations = toCollation.getFieldCollations();
        if (sortExprDigests.isEmpty() || sortExprDigests.size() < requiredCollations.size()) {
            return false;
        }
        for (int i = 0; i < requiredCollations.size(); i++) {
            RelFieldCollation required = requiredCollations.get(i);
            RexNode projectExpr = project.getProjects().get(required.getFieldIndex());
            SortExprDigest scanSortInfo = (SortExprDigest) sortExprDigests.get(i);
            RexNode scanSortExpression = scanSortInfo.getEffectiveExpression(scan);

            if (scanSortExpression != null && scanSortExpression.equals(projectExpr)) {
                // Identical expression: direction and null placement must match exactly.
                if (required.getDirection() != scanSortInfo.getDirection()
                        || required.nullDirection != scanSortInfo.getNullDirection()) {
                    return false;
                }
                continue;
            }

            if (!(scanSortExpression instanceof RexInputRef) || !(projectExpr instanceof RexCall)) {
                return false;
            }
            RelFieldCollation scanCollation =
                    new RelFieldCollation(
                            ((RexInputRef) scanSortExpression).getIndex(),
                            scanSortInfo.getDirection(),
                            scanSortInfo.getNullDirection());
            Optional<Pair<Integer, Boolean>> equivInfo =
                    orderEquivInfoMap.getOrDefault(required.getFieldIndex(), Optional.empty());
            if (!sourceCollationSatisfiesTargetCollation(scanCollation, required, equivInfo)) {
                return false;
            }
        }
        return true;
    }

    @Generated
    private OpenSearchRelOptUtil() {
        throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
    }

    /**
     * Collects the input indexes referenced by a visited expression, in first-seen order.
     *
     * <p>The visitor argument is {@code (seen, mappings)}. Each input index not yet set in
     * {@code seen} is marked there and appended to {@code mappings}. Always returns {@code null}.
     * NOTE(review): not referenced elsewhere in this file.
     */
    private static class RemapIndexBiVisitor
            extends RexBiVisitorImpl<Void, Pair<BitSet, List<Integer>>> {
        protected RemapIndexBiVisitor(boolean deep) {
            super(deep);
        }

        @Override
        public Void visitInputRef(RexInputRef inputRef, Pair<BitSet, List<Integer>> args) {
            BitSet seenOldIndex = args.getLeft();
            List<Integer> newMappings = args.getRight();
            int oldIdx = inputRef.getIndex();
            if (!seenOldIndex.get(oldIdx)) {
                seenOldIndex.set(oldIdx);
                newMappings.add(oldIdx);
            }
            return null;
        }
    }
}

