/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.dataprepper.plugins.processor.oteltrace;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.opensearch.dataprepper.model.trace.Span;
import org.opensearch.dataprepper.plugins.processor.oteltrace.GenAiAttributeMappings;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Batch-level enrichment of GenAI spans.
 *
 * <p>For every span in a batch this helper (1) copies attributes listed in
 * {@link GenAiAttributeMappings#getLookupTable()} onto their mapped target key when the target is
 * not already present, and (2) removes index-suffixed flattened sub-keys (e.g.
 * {@code llm.input_messages.0.role}) when the parent attribute itself is present. It then, per
 * trace, fills in selected {@code gen_ai.*} string attributes and summed token usage on the
 * trace's root span from the trace's other spans in the same batch.
 *
 * <p>All methods mutate the supplied spans in place; no synchronization is performed.
 */
final class GenAiEnrichmentHelper {
    private static final Logger LOG = LoggerFactory.getLogger(GenAiEnrichmentHelper.class);
    private static final String GEN_AI_SYSTEM_KEY = "gen_ai.system";
    private static final String GEN_AI_PROVIDER_NAME_KEY = "gen_ai.provider.name";
    private static final String GEN_AI_AGENT_NAME_KEY = "gen_ai.agent.name";
    private static final String GEN_AI_REQUEST_MODEL_KEY = "gen_ai.request.model";
    private static final String GEN_AI_OPERATION_NAME_KEY = "gen_ai.operation.name";
    private static final String GEN_AI_INPUT_TOKENS_KEY = "gen_ai.usage.input_tokens";
    private static final String GEN_AI_OUTPUT_TOKENS_KEY = "gen_ai.usage.output_tokens";
    /** Prefix used when writing/deleting an attribute through the span's path-based accessors. */
    private static final String ATTRIBUTES_PREFIX = "attributes/";
    /** Parent span id treated as "no parent" in addition to null/empty. */
    private static final String ZERO_PARENT_SPAN_ID = "0000000000000000";
    /** String attributes copied from non-root spans onto the root span when the root lacks them. */
    private static final String[] PROPAGATED_STRING_KEYS = new String[]{
            GEN_AI_SYSTEM_KEY, GEN_AI_PROVIDER_NAME_KEY, GEN_AI_AGENT_NAME_KEY, GEN_AI_REQUEST_MODEL_KEY};
    /** Attributes whose {@code <key>.<digit>...} flattened sub-keys are removed when the parent exists. */
    private static final String[] FLATTENED_PARENT_KEYS = new String[]{
            "llm.input_messages", "llm.output_messages", "gen_ai.prompt", "gen_ai.completion"};

    private GenAiEnrichmentHelper() {
    }

    /**
     * Normalizes and cleans every span in {@code spans}, then enriches the root span of each trace
     * in the batch with values found on that trace's other spans.
     *
     * <p>Only spans present in this batch are considered. A trace whose root span is absent from the
     * batch, or whose root is the only span of that trace in the batch, is not enriched.
     *
     * @param spans spans to process in place; {@code null} or empty is a no-op
     */
    static void enrichBatch(final List<Span> spans) {
        if (spans == null || spans.isEmpty()) {
            return;
        }
        for (final Span span : spans) {
            normalizeAttributes(span);
            stripFlattenedSubkeys(span);
        }

        final Map<String, List<Span>> spansByTrace = new HashMap<>();
        for (final Span span : spans) {
            spansByTrace.computeIfAbsent(span.getTraceId(), traceId -> new ArrayList<>()).add(span);
        }

        for (final List<Span> traceSpans : spansByTrace.values()) {
            Span rootSpan = null;
            // "children" means every non-root span of the trace, not only direct children.
            final List<Span> children = new ArrayList<>(traceSpans.size());
            for (final Span span : traceSpans) {
                if (isRootSpan(span)) {
                    // With several roots for one trace, only the last one encountered is enriched;
                    // the others are neither enriched nor used as children.
                    rootSpan = span;
                } else {
                    children.add(span);
                }
            }
            if (rootSpan != null && !children.isEmpty()) {
                enrichRootSpan(rootSpan, children);
            }
        }
    }

    /** Returns true when the span has no parent (null, empty, or all-zero parent span id). */
    private static boolean isRootSpan(final Span span) {
        final String parentSpanId = span.getParentSpanId();
        return parentSpanId == null || parentSpanId.isEmpty() || ZERO_PARENT_SPAN_ID.equals(parentSpanId);
    }

    /**
     * Copies missing {@code gen_ai.*} string attributes and summed token usage from {@code children}
     * onto {@code rootSpan}.
     *
     * <p>Each key in {@link #PROPAGATED_STRING_KEYS} that the root lacks is taken from the first
     * child (in iteration order) holding a String value for it. Token counts are summed over all
     * children and written only when the root carries neither token attribute and at least one
     * child reports a count; a side that no child reports is written as 0.
     *
     * <p>NOTE(review): if intermediate spans already carry totals of their own descendants, those
     * tokens are counted more than once here — confirm against the instrumentation in use.
     *
     * @param rootSpan span receiving the attributes
     * @param children the other spans of the same trace
     */
    static void enrichRootSpan(final Span rootSpan, final Collection<Span> children) {
        final Map<String, Object> rootAttrs = rootSpan.getAttributes();

        // Keys the root is missing; the value stays null until some child supplies one.
        final Map<String, String> toPropagate = new LinkedHashMap<>();
        for (final String key : PROPAGATED_STRING_KEYS) {
            if (rootAttrs == null || !rootAttrs.containsKey(key)) {
                toPropagate.put(key, null);
            }
        }

        // Either token key on the root means it already reports usage: never overwrite or mix.
        final boolean rootHasTokens = rootAttrs != null
                && (rootAttrs.containsKey(GEN_AI_INPUT_TOKENS_KEY) || rootAttrs.containsKey(GEN_AI_OUTPUT_TOKENS_KEY));
        if (toPropagate.isEmpty() && rootHasTokens) {
            return;
        }

        long totalInputTokens = 0L;
        long totalOutputTokens = 0L;
        boolean foundTokens = false;
        for (final Span child : children) {
            final Map<String, Object> attrs = child.getAttributes();
            if (attrs == null) {
                continue;
            }
            for (final Map.Entry<String, String> entry : toPropagate.entrySet()) {
                if (entry.getValue() != null) {
                    continue;
                }
                final Object candidate = attrs.get(entry.getKey());
                // Only String values are propagated; other types are skipped instead of failing the batch.
                if (candidate instanceof String) {
                    entry.setValue((String) candidate);
                }
            }
            if (rootHasTokens) {
                continue;
            }
            final Long inputTokens = toLong(attrs.get(GEN_AI_INPUT_TOKENS_KEY));
            final Long outputTokens = toLong(attrs.get(GEN_AI_OUTPUT_TOKENS_KEY));
            if (inputTokens != null) {
                totalInputTokens += inputTokens;
                foundTokens = true;
            }
            if (outputTokens != null) {
                totalOutputTokens += outputTokens;
                foundTokens = true;
            }
        }

        for (final Map.Entry<String, String> entry : toPropagate.entrySet()) {
            final String value = entry.getValue();
            if (value == null) {
                continue;
            }
            rootSpan.put(ATTRIBUTES_PREFIX + entry.getKey(), value);
            LOG.debug("Propagated {} = {} to root span {}", entry.getKey(), value, rootSpan.getSpanId());
        }
        if (!rootHasTokens && foundTokens) {
            rootSpan.put(ATTRIBUTES_PREFIX + GEN_AI_INPUT_TOKENS_KEY, totalInputTokens);
            rootSpan.put(ATTRIBUTES_PREFIX + GEN_AI_OUTPUT_TOKENS_KEY, totalOutputTokens);
            LOG.debug("Aggregated tokens (input={}, output={}) to root span {}",
                    totalInputTokens, totalOutputTokens, rootSpan.getSpanId());
        }
    }

    /**
     * Converts a token-count attribute value to a long.
     *
     * @return {@link Number#longValue()} for numbers, the parsed value for a String holding a base-10
     *     integer, otherwise {@code null} (absent or unusable value)
     */
    private static Long toLong(final Object value) {
        if (value instanceof Number) {
            return ((Number) value).longValue();
        }
        if (value instanceof String) {
            try {
                return Long.parseLong(((String) value).trim());
            } catch (final NumberFormatException e) {
                LOG.debug("Ignoring non-numeric token count '{}'", value);
            }
        }
        return null;
    }

    /**
     * Copies each attribute found in {@link GenAiAttributeMappings#getLookupTable()} onto its mapped
     * target key, unless the span already has that target key. Source attributes are left in place.
     *
     * @param span span to update in place; a span without attributes is left untouched
     */
    static void normalizeAttributes(final Span span) {
        final Map<String, Object> attrs = span.getAttributes();
        if (attrs == null) {
            return;
        }
        // Iterate over a snapshot in case attrs is a live view that span.put writes through to.
        for (final Map.Entry<String, Object> entry : new ArrayList<>(attrs.entrySet())) {
            final GenAiAttributeMappings.MappingTarget target =
                    GenAiAttributeMappings.getLookupTable().get(entry.getKey());
            if (target == null || attrs.containsKey(target.getKey())) {
                continue;
            }
            span.put(ATTRIBUTES_PREFIX + target.getKey(), toTargetValue(target, entry.getValue()));
        }
    }

    /**
     * Converts a source attribute value for storage under {@code target}. Non-String values pass
     * through unchanged. String operation names are canonicalized via
     * {@link GenAiAttributeMappings#getOperationNameValues()} (case-insensitive lookup), and targets
     * flagged as arrays get the string wrapped as a one-element JSON array.
     */
    private static Object toTargetValue(final GenAiAttributeMappings.MappingTarget target, final Object value) {
        if (!(value instanceof String)) {
            return value;
        }
        String result = (String) value;
        if (GEN_AI_OPERATION_NAME_KEY.equals(target.getKey())) {
            // Locale.ROOT keeps the lookup stable regardless of the JVM default locale (e.g. Turkish "I").
            final String mapped = GenAiAttributeMappings.getOperationNameValues().get(result.toLowerCase(Locale.ROOT));
            if (mapped != null) {
                result = mapped;
            }
        }
        if (target.isWrapAsArray()) {
            return "[\"" + escapeJsonString(result) + "\"]";
        }
        return result;
    }

    /** Escapes {@code value} for embedding inside a JSON string literal (RFC 8259 section 7). */
    private static String escapeJsonString(final String value) {
        final StringBuilder sb = new StringBuilder(value.length() + 8);
        for (int i = 0; i < value.length(); i++) {
            final char c = value.charAt(i);
            switch (c) {
                case '"':
                    sb.append("\\\"");
                    break;
                case '\\':
                    sb.append("\\\\");
                    break;
                case '\n':
                    sb.append("\\n");
                    break;
                case '\r':
                    sb.append("\\r");
                    break;
                case '\t':
                    sb.append("\\t");
                    break;
                default:
                    if (c < 0x20) {
                        sb.append(String.format(Locale.ROOT, "\\u%04x", (int) c));
                    } else {
                        sb.append(c);
                    }
                    break;
            }
        }
        return sb.toString();
    }

    /**
     * Deletes flattened sub-keys of the form {@code <parent>.<digit>...} for every parent in
     * {@link #FLATTENED_PARENT_KEYS} that is itself present on the span. Deletion failures are
     * logged and skipped (best effort).
     *
     * @param span span to update in place; a span without attributes is left untouched
     */
    static void stripFlattenedSubkeys(final Span span) {
        final Map<String, Object> attrs = span.getAttributes();
        if (attrs == null) {
            return;
        }
        // Collect first, delete afterwards, so the key set is never modified while iterating it.
        final List<String> toRemove = new ArrayList<>();
        for (final String parentKey : FLATTENED_PARENT_KEYS) {
            if (!attrs.containsKey(parentKey)) {
                continue;
            }
            final String prefix = parentKey + ".";
            for (final String key : attrs.keySet()) {
                if (isIndexedSubkey(key, prefix)) {
                    toRemove.add(key);
                }
            }
        }
        for (final String key : toRemove) {
            try {
                span.delete(ATTRIBUTES_PREFIX + key);
            } catch (final Exception e) {
                LOG.warn("Failed to delete flattened sub-key {}: {}", key, e.getMessage());
            }
        }
    }

    /** Returns true when {@code key} is {@code prefix} followed by at least one char that is a digit. */
    private static boolean isIndexedSubkey(final String key, final String prefix) {
        return key.startsWith(prefix)
                && key.length() > prefix.length()
                && Character.isDigit(key.charAt(prefix.length()));
    }
}

