/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.clustering.evaluation;

import com.oracle.labs.mlrg.olcut.util.MutableLong;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import org.apache.commons.math3.special.Gamma;
import org.tribuo.clustering.ClusterID;
import org.tribuo.clustering.evaluation.ClusteringMetric;
import org.tribuo.evaluation.metrics.MetricTarget;
import org.tribuo.util.infotheory.InformationTheory;
import org.tribuo.util.infotheory.impl.PairDistribution;

/**
 * Information-theoretic metrics for comparing a predicted clustering against a
 * ground-truth clustering.
 */
public enum ClusteringMetrics {
    /**
     * Mutual information divided by the smaller of the two clusterings' entropies.
     */
    NORMALIZED_MI((target, context) -> ClusteringMetrics.normalizedMI(context)),
    /**
     * Mutual information adjusted for chance (expected MI subtracted), normalized by
     * the smaller of the two clusterings' entropies.
     */
    ADJUSTED_MI((target, context) -> ClusteringMetrics.adjustedMI(context));

    private final BiFunction<MetricTarget<ClusterID>, ClusteringMetric.Context, Double> impl;

    private ClusteringMetrics(BiFunction<MetricTarget<ClusterID>, ClusteringMetric.Context, Double> impl) {
        this.impl = impl;
    }

    /**
     * Returns the function that computes this metric.
     *
     * @return the metric implementation, taking a target and a clustering context.
     */
    public BiFunction<MetricTarget<ClusterID>, ClusteringMetric.Context, Double> getImpl() {
        return this.impl;
    }

    /**
     * Builds a {@link ClusteringMetric} for the supplied target, named after this constant.
     *
     * @param tgt the metric target.
     * @return a new metric wrapping this constant's implementation.
     */
    public ClusteringMetric forTarget(MetricTarget<ClusterID> tgt) {
        return new ClusteringMetric(tgt, this.name(), this.getImpl());
    }

    /**
     * Computes the adjusted mutual information between the predicted and true cluster ids:
     * {@code (MI - E[MI]) / (min(H(pred), H(true)) - E[MI])}.
     * <p>
     * When the denominator is zero (e.g. both clusterings are a single cluster) the result
     * is not finite.
     * <p>
     * NOTE(review): {@code expectedMI} is computed with the natural logarithm
     * ({@link Math#log}); if {@code InformationTheory.mi}/{@code entropy} use a different
     * log base, the units of the terms disagree. TODO confirm.
     *
     * @param context the clustering context supplying predicted and true ids.
     * @return the adjusted mutual information.
     */
    public static double adjustedMI(ClusteringMetric.Context context) {
        List<Integer> predicted = context.getPredictedIDs();
        List<Integer> truth = context.getTrueIDs();
        double mi = InformationTheory.mi(predicted, truth);
        double predEntropy = InformationTheory.entropy(predicted);
        double trueEntropy = InformationTheory.entropy(truth);
        double expectedMI = ClusteringMetrics.expectedMI(predicted, truth);
        double minEntropy = Math.min(predEntropy, trueEntropy);
        return (mi - expectedMI) / (minEntropy - expectedMI);
    }

    /**
     * Computes the mutual information between the predicted and true cluster ids,
     * normalized by the smaller of the two entropies.
     * <p>
     * When the smaller entropy is zero (one clustering is a single cluster) the result
     * is not finite.
     *
     * @param context the clustering context supplying predicted and true ids.
     * @return the normalized mutual information.
     */
    public static double normalizedMI(ClusteringMetric.Context context) {
        List<Integer> predicted = context.getPredictedIDs();
        List<Integer> truth = context.getTrueIDs();
        double mi = InformationTheory.mi(predicted, truth);
        double predEntropy = InformationTheory.entropy(predicted);
        double trueEntropy = InformationTheory.entropy(truth);
        return mi / Math.min(predEntropy, trueEntropy);
    }

    /**
     * Computes the expected mutual information (in nats) between two clusterings under
     * the hypergeometric (fixed-marginals) model.
     * <p>
     * For marginal counts {@code a} (first) and {@code b} (second) over {@code N} samples,
     * this sums, over every cluster pair and every feasible cell count
     * {@code nij in [max(1, a + b - N), min(a, b)]} (inclusive),
     * {@code (nij / N) * log(N * nij / (a * b)) * P(nij)}, where {@code P(nij)} is the
     * hypergeometric probability, evaluated in log space via {@link Gamma#logGamma}.
     *
     * @param first  cluster ids of the first clustering.
     * @param second cluster ids of the second clustering.
     * @return the expected mutual information; 0.0 if there are no samples.
     */
    private static double expectedMI(List<Integer> first, List<Integer> second) {
        PairDistribution<Integer, Integer> pd = PairDistribution.constructFromLists(first, second);
        long count = pd.count;
        // log(N!) is shared by every term.
        double logCountFactorial = Gamma.logGamma((double) (count + 1L));

        double output = 0.0;
        for (MutableLong firstMarginal : pd.firstCount.values()) {
            long fVal = firstMarginal.longValue();
            for (MutableLong secondMarginal : pd.secondCount.values()) {
                long sVal = secondMarginal.longValue();
                // The upper bound is inclusive: nij == min(a, b) is a feasible cell value.
                long start = Math.max(1L, fVal + sVal - count);
                long end = Math.min(fVal, sVal);
                // log(a! b! (N-a)! (N-b)! / N!) does not depend on nij.
                double logNumerator = Gamma.logGamma((double) (fVal + 1L))
                        + Gamma.logGamma((double) (sVal + 1L))
                        + Gamma.logGamma((double) (count - fVal + 1L))
                        + Gamma.logGamma((double) (count - sVal + 1L))
                        - logCountFactorial;
                for (long nij = start; nij <= end; ++nij) {
                    // Products done in double to avoid long overflow on large inputs.
                    double term = ((double) nij / (double) count)
                            * Math.log(((double) count * (double) nij) / ((double) fVal * (double) sVal));
                    double logProb = logNumerator
                            - Gamma.logGamma((double) (nij + 1L))
                            - Gamma.logGamma((double) (fVal - nij + 1L))
                            - Gamma.logGamma((double) (sVal - nij + 1L))
                            - Gamma.logGamma((double) (count - fVal - sVal + nij + 1L));
                    output += term * Math.exp(logProb);
                }
            }
        }
        return output;
    }
}

