/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.agent.tools.utils.clustering;

import com.google.common.collect.Lists;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.Generated;
import org.apache.commons.math3.ml.clustering.CentroidCluster;
import org.apache.commons.math3.ml.clustering.DoublePoint;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.agent.tools.utils.clustering.HierarchicalAgglomerativeClustering;

/**
 * Selects representative trace IDs from a map of trace ID to embedding vector by clustering.
 *
 * <p>Inputs of at most {@value #DIRECT_CLUSTERING_MAX_SIZE} vectors are clustered directly with
 * {@link HierarchicalAgglomerativeClustering} (complete linkage). Larger inputs are first
 * pre-clustered with K-means++ using cosine distance; each K-means cluster is then refined with
 * hierarchical clustering in batches of at most {@value #MAX_HIERARCHICAL_INPUT_SIZE} vectors.
 */
public class ClusteringHelper {
    @Generated
    private static final Logger log = LogManager.getLogger(ClusteringHelper.class);

    /** Inputs with more vectors than this are pre-clustered with K-means before hierarchical clustering. */
    private static final int DIRECT_CLUSTERING_MAX_SIZE = 1000;

    /** Target K-means cluster size, and the largest batch handed to one hierarchical run on the large-input path. */
    private static final int MAX_HIERARCHICAL_INPUT_SIZE = 500;

    /** Iteration cap passed to the K-means++ clusterer. */
    private static final int KMEANS_MAX_ITERATIONS = 300;

    /**
     * Threshold in [0.0, 1.0]. It is passed to {@link HierarchicalAgglomerativeClustering#fit} and is
     * also the cosine-similarity cutoff above which two partition representatives are considered duplicates.
     */
    private final double logVectorsClusteringThreshold;

    /**
     * @param logVectorsClusteringThreshold clustering threshold, must be within [0.0, 1.0]
     * @throws IllegalArgumentException if the threshold is outside [0.0, 1.0]
     */
    public ClusteringHelper(double logVectorsClusteringThreshold) {
        if (logVectorsClusteringThreshold < 0.0 || logVectorsClusteringThreshold > 1.0) {
            throw new IllegalArgumentException("Clustering threshold must be between 0.0 and 1.0, got: " + logVectorsClusteringThreshold);
        }
        this.logVectorsClusteringThreshold = logVectorsClusteringThreshold;
    }

    /**
     * Clusters the given vectors and returns the trace ID chosen as representative for each cluster.
     *
     * @param logVectors trace ID to vector; may be {@code null} or empty
     * @return representative trace IDs; an empty list when {@code logVectors} is {@code null} or empty
     * @throws IllegalArgumentException if a trace ID is null/empty, or a vector is null, empty,
     *     of inconsistent dimension, or contains NaN/infinite values
     */
    public List<String> clusterLogVectorsAndGetRepresentative(Map<String, double[]> logVectors) {
        if (logVectors == null || logVectors.isEmpty()) {
            return new ArrayList<>();
        }
        validateLogVectors(logVectors);
        log.debug("Starting two-phase clustering for {} log vectors", logVectors.size());
        double[][] vectors = new double[logVectors.size()][];
        Map<Integer, String> indexTraceIdMap = new HashMap<>();
        convertLogVectorsToArrays(logVectors, vectors, indexTraceIdMap);
        List<String> finalCentroids = logVectors.size() > DIRECT_CLUSTERING_MAX_SIZE
            ? processTwoPhaseClusteringForLargeDataset(vectors, indexTraceIdMap)
            : performClustering(vectors, indexTraceIdMap);
        log.debug("Two-phase clustering completed: {} input vectors -> {} representative centroids", logVectors.size(), finalCentroids.size());
        return finalCentroids;
    }

    /** Copies the map's vectors into {@code vectors} and records each position's trace ID in {@code indexTraceIdMap}. */
    private void convertLogVectorsToArrays(Map<String, double[]> logVectors, double[][] vectors, Map<Integer, String> indexTraceIdMap) {
        int i = 0;
        for (Map.Entry<String, double[]> entry : logVectors.entrySet()) {
            vectors[i] = entry.getValue();
            indexTraceIdMap.put(i, entry.getKey());
            ++i;
        }
    }

    /**
     * Large-input path: K-means pre-clustering followed by hierarchical refinement of each cluster.
     * If anything in the K-means phase throws, all vectors are clustered hierarchically in bounded partitions instead.
     */
    private List<String> processTwoPhaseClusteringForLargeDataset(double[][] vectors, Map<Integer, String> indexTraceIdMap) {
        log.debug("Large dataset detected ({}), applying K-means pre-clustering", vectors.length);
        // Ceiling division: enough clusters for an average size of MAX_HIERARCHICAL_INPUT_SIZE.
        int numKMeansClusters = (vectors.length + MAX_HIERARCHICAL_INPUT_SIZE - 1) / MAX_HIERARCHICAL_INPUT_SIZE;
        log.debug("Using {} K-means clusters for pre-clustering", numKMeansClusters);
        List<String> finalCentroids = new ArrayList<>();
        try {
            List<List<Integer>> kMeansClusters = performKMeansClustering(vectors, numKMeansClusters);
            for (int clusterIdx = 0; clusterIdx < kMeansClusters.size(); ++clusterIdx) {
                List<Integer> kMeansCluster = kMeansClusters.get(clusterIdx);
                log.debug("Processing K-means cluster {} with {} points", clusterIdx, kMeansCluster.size());
                finalCentroids.addAll(processCluster(kMeansCluster, vectors, indexTraceIdMap, clusterIdx));
            }
        } catch (Exception e) {
            log.warn("K-means clustering failed, falling back to hierarchical clustering only: {}", e.getMessage());
            // Discard any partial results. A single hierarchical run over the whole (large) dataset is
            // what the pre-clustering exists to avoid, so cluster all indices in bounded partitions.
            List<Integer> allIndices = new ArrayList<>(vectors.length);
            for (int i = 0; i < vectors.length; ++i) {
                allIndices.add(i);
            }
            finalCentroids = performHierarchicalClusteringOfPartition(allIndices, vectors, indexTraceIdMap);
        }
        return finalCentroids;
    }

    /**
     * Returns representatives for one K-means cluster (given as indices into {@code vectors}).
     * Clusters larger than {@value #MAX_HIERARCHICAL_INPUT_SIZE} are processed in partitions.
     */
    private List<String> processCluster(List<Integer> kMeansCluster, double[][] vectors, Map<Integer, String> indexTraceIdMap, int clusterIdx) {
        if (kMeansCluster.isEmpty()) {
            return List.of();
        }
        if (kMeansCluster.size() == 1) {
            return List.of(indexTraceIdMap.get(kMeansCluster.getFirst()));
        }
        if (kMeansCluster.size() > MAX_HIERARCHICAL_INPUT_SIZE) {
            log.debug("The cluster size is greater than 500, performing partitioned clustering");
            return performHierarchicalClusteringOfPartition(kMeansCluster, vectors, indexTraceIdMap);
        }
        log.debug("Applying hierarchical clustering to K-means cluster {} with {} points", clusterIdx, kMeansCluster.size());
        double[][] clusterVectors = extractVectors(kMeansCluster, vectors);
        Map<Integer, String> clusterIndexTraceIdMap = createTraceIdMapping(kMeansCluster, indexTraceIdMap);
        return performClustering(clusterVectors, clusterIndexTraceIdMap);
    }

    /**
     * Runs K-means++ and returns, per non-empty cluster, the indices into {@code vectors} of its members.
     * The requested cluster count is clamped to [1, vectors.length].
     *
     * @throws RuntimeException wrapping any failure from the clusterer
     */
    private List<List<Integer>> performKMeansClustering(double[][] vectors, int numClusters) {
        if (vectors == null || vectors.length == 0) {
            return new ArrayList<>();
        }
        int clusterCount = Math.min(Math.max(numClusters, 1), vectors.length);
        try {
            KMeansPlusPlusClusterer<DoublePoint> clusterer = createKMeansClusterer(clusterCount);
            List<DoublePoint> points = convertVectorsToPoints(vectors);
            List<CentroidCluster<DoublePoint>> clusters = clusterer.cluster(points);
            return extractClusterIndices(clusters, points);
        } catch (Exception e) {
            log.error("K-means clustering failed: {}", e.getMessage(), e);
            throw new RuntimeException("K-means clustering failed: " + e.getMessage(), e);
        }
    }

    /** Builds a K-means++ clusterer whose distance is {@code 1 - cosineSimilarity(a, b)}. */
    private KMeansPlusPlusClusterer<DoublePoint> createKMeansClusterer(int numClusters) {
        DistanceMeasure cosineDistance = (DistanceMeasure & Serializable) (a, b) -> 1.0 - HierarchicalAgglomerativeClustering.calculateCosineSimilarity(a, b);
        return new KMeansPlusPlusClusterer<>(numClusters, KMEANS_MAX_ITERATIONS, cosineDistance);
    }

    /** Wraps each vector in a new {@link DoublePoint}; list position equals the vector's index. */
    private List<DoublePoint> convertVectorsToPoints(double[][] vectors) {
        List<DoublePoint> points = new ArrayList<>(vectors.length);
        for (double[] vector : vectors) {
            points.add(new DoublePoint(vector));
        }
        return points;
    }

    /**
     * Rejects null/empty trace IDs, and vectors that are null, empty, of mismatched dimension, or that
     * contain NaN or infinite values.
     */
    private void validateLogVectors(Map<String, double[]> logVectors) {
        int vectorDimension = -1;
        for (Map.Entry<String, double[]> entry : logVectors.entrySet()) {
            String traceId = entry.getKey();
            double[] vector = entry.getValue();
            if (traceId == null || traceId.isEmpty()) {
                throw new IllegalArgumentException("Trace ID cannot be null or empty");
            }
            if (vector == null) {
                throw new IllegalArgumentException("Vector for trace ID '" + traceId + "' is null");
            }
            if (vector.length == 0) {
                throw new IllegalArgumentException("Vector for trace ID '" + traceId + "' is empty");
            }
            if (vectorDimension == -1) {
                vectorDimension = vector.length;
            } else if (vector.length != vectorDimension) {
                throw new IllegalArgumentException("Vector dimension mismatch: expected " + vectorDimension + " but got " + vector.length + " for trace ID '" + traceId + "'");
            }
            for (int i = 0; i < vector.length; ++i) {
                if (!Double.isFinite(vector[i])) {
                    throw new IllegalArgumentException("Vector for trace ID '" + traceId + "' contains invalid value at index " + i + ": " + vector[i]);
                }
            }
        }
    }

    /**
     * Converts K-means clusters back to lists of input indices.
     *
     * <p>Points are matched by instance identity, not by coordinates. KMeansPlusPlusClusterer assigns
     * the input {@link DoublePoint} objects themselves to clusters. Identity lookup therefore keeps
     * every trace ID even when several share identical vectors, and costs O(1) per point instead of a
     * linear scan.
     *
     * @param clusters clusters returned by the clusterer
     * @param points exactly the list that was passed to the clusterer; position is the input index
     * @throws IllegalStateException if a clustered point is not one of {@code points}
     */
    private List<List<Integer>> extractClusterIndices(List<CentroidCluster<DoublePoint>> clusters, List<DoublePoint> points) {
        Map<DoublePoint, Integer> indexByPoint = new IdentityHashMap<>(points.size());
        for (int i = 0; i < points.size(); ++i) {
            indexByPoint.put(points.get(i), i);
        }
        List<List<Integer>> result = new ArrayList<>(clusters.size());
        for (CentroidCluster<DoublePoint> cluster : clusters) {
            List<Integer> clusterIndices = new ArrayList<>(cluster.getPoints().size());
            for (DoublePoint point : cluster.getPoints()) {
                Integer index = indexByPoint.get(point);
                if (index == null) {
                    // Never drop a point silently; the caller falls back to partitioned clustering on failure.
                    throw new IllegalStateException("K-means returned a point that was not part of the input");
                }
                clusterIndices.add(index);
            }
            if (!clusterIndices.isEmpty()) {
                result.add(clusterIndices);
            }
        }
        return result;
    }

    /**
     * Runs complete-linkage hierarchical clustering and returns each cluster's centroid trace ID.
     * If clustering throws, the trace ID at index 0 is returned as the sole representative.
     */
    private List<String> performClustering(double[][] vectors, Map<Integer, String> indexTraceIdMap) {
        if (vectors == null || vectors.length == 0) {
            return List.of();
        }
        if (vectors.length == 1) {
            return List.of(indexTraceIdMap.get(0));
        }
        List<String> centroids = new ArrayList<>();
        try {
            HierarchicalAgglomerativeClustering hac = new HierarchicalAgglomerativeClustering(vectors);
            List<HierarchicalAgglomerativeClustering.ClusterNode> clusters = hac.fit(HierarchicalAgglomerativeClustering.LinkageMethod.COMPLETE, this.logVectorsClusteringThreshold);
            for (HierarchicalAgglomerativeClustering.ClusterNode cluster : clusters) {
                int centroidIndex = hac.getClusterCentroid(cluster);
                centroids.add(indexTraceIdMap.get(centroidIndex));
            }
        } catch (Exception e) {
            log.error("Hierarchical clustering failed: {}", e.getMessage(), e);
            centroids.add(indexTraceIdMap.get(0));
        }
        return centroids;
    }

    /**
     * Clusters the given indices in partitions of at most {@value #MAX_HIERARCHICAL_INPUT_SIZE}. It then
     * drops partition representatives that are too similar to an earlier one and returns the remaining trace IDs.
     */
    private List<String> performHierarchicalClusteringOfPartition(List<Integer> kMeansCluster, double[][] vectors, Map<Integer, String> indexTraceIdMap) {
        List<List<Integer>> partitions = Lists.partition(kMeansCluster, MAX_HIERARCHICAL_INPUT_SIZE);
        List<double[]> vectorRes = new ArrayList<>();
        Map<Integer, String> index2Trace = new HashMap<>();
        for (List<Integer> partList : partitions) {
            double[][] clusterVectors = extractVectors(partList, vectors);
            Map<Integer, String> clusterIndexTraceIdMap = createTraceIdMapping(partList, indexTraceIdMap);
            log.debug("Starting performHierarchicalClusteringOfPartition!");
            processPartition(clusterVectors, clusterIndexTraceIdMap, vectorRes, index2Trace);
        }
        return removeSimilarVectors(vectorRes, index2Trace);
    }

    /** Gathers {@code vectors[partList.get(j)]} into a new array at position {@code j}. */
    private double[][] extractVectors(List<Integer> partList, double[][] vectors) {
        double[][] clusterVectors = new double[partList.size()][];
        for (int j = 0; j < partList.size(); ++j) {
            clusterVectors[j] = vectors[partList.get(j)];
        }
        return clusterVectors;
    }

    /** Maps local position {@code j} to the trace ID of original index {@code partList.get(j)}. */
    private Map<Integer, String> createTraceIdMapping(List<Integer> partList, Map<Integer, String> indexTraceIdMap) {
        Map<Integer, String> clusterIndexTraceIdMap = new HashMap<>();
        for (int j = 0; j < partList.size(); ++j) {
            clusterIndexTraceIdMap.put(j, indexTraceIdMap.get(partList.get(j)));
        }
        return clusterIndexTraceIdMap;
    }

    /**
     * Clusters one partition and appends each representative's vector to {@code vectorRes}. The
     * representative's trace ID is recorded in {@code index2Trace} under its position in {@code vectorRes}.
     * If clustering throws, the partition's first vector is appended instead.
     */
    private void processPartition(double[][] clusterVectors, Map<Integer, String> clusterIndexTraceIdMap, List<double[]> vectorRes, Map<Integer, String> index2Trace) {
        if (clusterVectors.length == 0) {
            return;
        }
        if (clusterVectors.length == 1) {
            vectorRes.add(clusterVectors[0]);
            index2Trace.put(vectorRes.size() - 1, clusterIndexTraceIdMap.get(0));
            return;
        }
        try {
            HierarchicalAgglomerativeClustering hac = new HierarchicalAgglomerativeClustering(clusterVectors);
            List<HierarchicalAgglomerativeClustering.ClusterNode> clusters = hac.fit(HierarchicalAgglomerativeClustering.LinkageMethod.COMPLETE, this.logVectorsClusteringThreshold);
            // Per-partition progress; DEBUG like the matching "Starting" message to avoid flooding INFO logs.
            log.debug("Completing performHierarchicalClusteringOfPartition!");
            for (HierarchicalAgglomerativeClustering.ClusterNode cluster : clusters) {
                int centroidIndex = hac.getClusterCentroid(cluster);
                vectorRes.add(clusterVectors[centroidIndex]);
                index2Trace.put(vectorRes.size() - 1, clusterIndexTraceIdMap.get(centroidIndex));
            }
        } catch (Exception e) {
            log.error("Hierarchical clustering failed: {}", e.getMessage(), e);
            vectorRes.add(clusterVectors[0]);
            index2Trace.put(vectorRes.size() - 1, clusterIndexTraceIdMap.get(0));
        }
    }

    /**
     * Greedy de-duplication across partitions. A representative is dropped when its cosine similarity
     * to an earlier, still-kept representative exceeds the threshold.
     */
    private List<String> removeSimilarVectors(List<double[]> vectorRes, Map<Integer, String> index2Trace) {
        Set<Integer> toRemove = new HashSet<>();
        for (int i = 0; i < vectorRes.size(); ++i) {
            if (toRemove.contains(i)) {
                continue;
            }
            for (int j = i + 1; j < vectorRes.size(); ++j) {
                if (toRemove.contains(j)) {
                    continue;
                }
                double similarity = HierarchicalAgglomerativeClustering.calculateCosineSimilarity(vectorRes.get(i), vectorRes.get(j));
                if (similarity > this.logVectorsClusteringThreshold) {
                    log.debug("Removing similar vector with similarity: {}", similarity);
                    toRemove.add(j);
                }
            }
        }
        log.debug("Removed {} similar vectors out of {}", toRemove.size(), vectorRes.size());
        return collectNonRemovedTraceIds(vectorRes, index2Trace, toRemove);
    }

    /** Returns, in index order, the trace IDs of all positions not in {@code indicesToRemove}. */
    private List<String> collectNonRemovedTraceIds(List<double[]> vectors, Map<Integer, String> indexToTraceMap, Set<Integer> indicesToRemove) {
        List<String> result = new ArrayList<>(vectors.size() - indicesToRemove.size());
        for (int i = 0; i < vectors.size(); ++i) {
            if (!indicesToRemove.contains(i)) {
                result.add(indexToTraceMap.get(i));
            }
        }
        return result;
    }
}

