/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.dataprepper.plugins.ml_inference.processor.common;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;
import org.json.JSONArray;
import org.json.JSONObject;
import org.opensearch.dataprepper.aws.api.AwsCredentialsSupplier;
import org.opensearch.dataprepper.common.utils.RetryUtil;
import org.opensearch.dataprepper.logging.DataPrepperMarkers;
import org.opensearch.dataprepper.metrics.PluginMetrics;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.event.EventKey;
import org.opensearch.dataprepper.model.failures.DlqObject;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.ml_inference.processor.MLProcessor;
import org.opensearch.dataprepper.plugins.ml_inference.processor.MLProcessorConfig;
import org.opensearch.dataprepper.plugins.ml_inference.processor.client.S3ClientFactory;
import org.opensearch.dataprepper.plugins.ml_inference.processor.common.AbstractBatchJobCreator;
import org.opensearch.dataprepper.plugins.ml_inference.processor.dlq.DlqPushHandler;
import org.opensearch.dataprepper.plugins.ml_inference.processor.exception.MLBatchJobException;
import org.slf4j.Logger;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;

/**
 * Buffers incoming records and submits them to SageMaker, through ML Commons, as batch transform jobs.
 *
 * <p>Records passed to {@link #createMLBatchJob(List, List)} are queued until either
 * {@code maxBatchSize} records are pending or {@link #INACTIVITY_TIMEOUT_MS} has elapsed since the
 * last addition. A due batch is described by a manifest file written to the S3 bucket named in the
 * batch's first record. A transform-job request referencing that manifest is then sent with retries.
 *
 * <p>Outcomes:
 * <ul>
 *   <li>Throttled batches are parked in a retry queue and merged back into later batches until
 *       they succeed or expire.</li>
 *   <li>Other failures are tagged, queued for emission and, when a DLQ handler is configured,
 *       pushed to the DLQ.</li>
 *   <li>Finished records are emitted via {@link #addProcessedBatchRecordsToResults(List)}.</li>
 * </ul>
 *
 * <p>Batch evaluation and result emission are serialized with a non-blocking {@code tryLock}; a
 * caller that finds the lock held skips its turn.
 */
public class SageMakerBatchJobCreator
extends AbstractBatchJobCreator {
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
    private final AwsCredentialsSupplier awsCredentialsSupplier;
    private final S3Client s3Client;
    /** Serializes batch evaluation and result emission; callers tryLock and skip when it is held. */
    private final Lock batchProcessingLock;
    /** Records waiting to be submitted in the next SageMaker batch. */
    private final ConcurrentLinkedQueue<Record<Event>> batch_records = new ConcurrentLinkedQueue<>();
    /** Records whose batch has finished (successfully or with failure tags) and await emission. */
    private final ConcurrentLinkedQueue<Record<Event>> processedBatchRecords = new ConcurrentLinkedQueue<>();
    /** Records from throttled batches; merged back into the pending batch until they expire. */
    private final ConcurrentLinkedQueue<AbstractBatchJobCreator.RetryRecord> retryQueue = new ConcurrentLinkedQueue<>();
    private final int maxBatchSize;
    /** Epoch millis of the latest addition to the pending batch, or -1 when the timer is not running. */
    private final AtomicLong lastUpdateTimestamp = new AtomicLong(-1L);
    /** Idle time in milliseconds after which a non-full pending batch is submitted anyway. */
    private static final long INACTIVITY_TIMEOUT_MS = 60000L;
    /** True once the retry queue has been merged into the pending batch; reset when that batch is taken. */
    private volatile boolean retryRecordsAddedToBatch = false;
    /** Request body template; S3Uri, TransformJobName and S3OutputPath are filled in per job. */
    private static final String SAGEMAKER_PAYLOAD_TEMPLATE = "{\"parameters\":{\"TransformInput\":{\"ContentType\":\"application/json\",\"DataSource\":{\"S3DataSource\":{\"S3DataType\":\"ManifestFile\",\"S3Uri\":\"\"}},\"SplitType\":\"Line\"},\"TransformJobName\":\"\",\"TransformOutput\":{\"AssembleWith\":\"Line\",\"Accept\":\"application/json\",\"S3OutputPath\":\"s3://\"}}}";

    /**
     * Creates the batch job creator and the S3 client used to upload manifest files.
     *
     * @param mlProcessorConfig processor configuration (batch size, input key, output path, ...)
     * @param awsCredentialsSupplier credentials used to build the S3 client
     * @param pluginMetrics metrics sink, forwarded to the base class
     * @param dlqPushHandler DLQ handler for failed records; may be {@code null} to disable DLQ pushes
     */
    public SageMakerBatchJobCreator(MLProcessorConfig mlProcessorConfig, AwsCredentialsSupplier awsCredentialsSupplier, PluginMetrics pluginMetrics, DlqPushHandler dlqPushHandler) {
        super(mlProcessorConfig, awsCredentialsSupplier, pluginMetrics, dlqPushHandler);
        this.awsCredentialsSupplier = awsCredentialsSupplier;
        this.s3Client = S3ClientFactory.createS3Client(mlProcessorConfig, awsCredentialsSupplier);
        this.maxBatchSize = mlProcessorConfig.getMaxBatchSize();
        this.batchProcessingLock = new ReentrantLock();
    }

    /**
     * Queues {@code inputRecords} for the next SageMaker batch and restarts the inactivity timer.
     * The actual submission happens later, from {@link #checkAndProcessBatch()}.
     *
     * @param inputRecords records to enqueue; an empty list is ignored
     * @param resultRecords not used by this implementation
     */
    @Override
    public void createMLBatchJob(List<Record<Event>> inputRecords, List<Record<Event>> resultRecords) {
        if (inputRecords.isEmpty()) {
            return;
        }
        this.batch_records.addAll(inputRecords);
        this.lastUpdateTimestamp.set(System.currentTimeMillis());
        MLProcessor.LOG.info("Added {} records to batch. Current batch size: {}", inputRecords.size(), this.batch_records.size());
    }

    /**
     * Moves every finished record (successful or failure-tagged) into {@code resultRecords}.
     * Does nothing if another thread currently holds the batch-processing lock.
     *
     * @param resultRecords destination list that receives the finished records
     */
    @Override
    public void addProcessedBatchRecordsToResults(List<Record<Event>> resultRecords) {
        if (!this.batchProcessingLock.tryLock()) {
            MLProcessor.LOG.debug("Another thread is currently processing results, skipping this attempt");
            return;
        }
        try {
            // Drain with poll() rather than addAll()+clear(): records can be appended concurrently
            // (e.g. by shutdown(), which runs without the lock), and clear() would drop those.
            int added = 0;
            Record<Event> record;
            while ((record = this.processedBatchRecords.poll()) != null) {
                resultRecords.add(record);
                ++added;
            }
            if (added > 0) {
                MLProcessor.LOG.info("Result records updated: {} processed records added, new total size: {}", added, resultRecords.size());
            }
        }
        finally {
            this.batchProcessingLock.unlock();
        }
    }

    /**
     * Submits the pending batch when it is due, that is, when it reached {@code maxBatchSize} or
     * has been idle for {@link #INACTIVITY_TIMEOUT_MS}. Retry-queue records are merged in first.
     * Does nothing if another thread holds the lock. Any exception is logged, not propagated.
     */
    @Override
    public void checkAndProcessBatch() {
        if (!this.batchProcessingLock.tryLock()) {
            MLProcessor.LOG.debug("Another thread is currently processing the current batch, skipping this attempt");
            return;
        }
        try {
            // Merge throttled records back in once per submitted batch, not on every check.
            if (!this.retryRecordsAddedToBatch) {
                this.processRetryQueue();
            }
            if (this.batch_records.isEmpty() || !this.isBatchDue()) {
                return;
            }
            // Reset the timer BEFORE draining: a concurrent createMLBatchJob() that lands after the
            // drain then keeps its fresh timestamp instead of being stranded at -1 (never timing out).
            this.lastUpdateTimestamp.set(-1L);
            List<Record<Event>> currentBatch = this.drainBatchRecords();
            this.retryRecordsAddedToBatch = false;
            if (!currentBatch.isEmpty()) {
                this.processCurrentBatch(currentBatch);
            }
        }
        catch (Exception e) {
            MLProcessor.LOG.error("Error in batch processing check: ", e);
        }
        finally {
            this.batchProcessingLock.unlock();
        }
    }

    /** Returns true when the pending batch is full or has been idle past the inactivity timeout. */
    private boolean isBatchDue() {
        int pendingCount = this.batch_records.size();
        if (pendingCount >= this.maxBatchSize) {
            MLProcessor.LOG.info("Processing batch due to size limit reached: {}", pendingCount);
            return true;
        }
        long lastUpdate = this.lastUpdateTimestamp.get();
        long idleMs = System.currentTimeMillis() - lastUpdate;
        if (lastUpdate != -1L && idleMs >= INACTIVITY_TIMEOUT_MS) {
            MLProcessor.LOG.info("Processing batch due to inactivity timeout. Time since last update: {} ms", idleMs);
            return true;
        }
        return false;
    }

    /**
     * Removes and returns every record currently queued in {@code batch_records}.
     * Uses poll() so that no concurrently added record is lost (unlike a copy followed by clear()).
     */
    private List<Record<Event>> drainBatchRecords() {
        List<Record<Event>> drained = new ArrayList<>();
        Record<Event> record;
        while ((record = this.batch_records.poll()) != null) {
            drained.add(record);
        }
        return drained;
    }

    /**
     * Uploads a manifest for {@code currentBatch}, sends the transform-job request with retries,
     * and routes the outcome to the success, throttling or failure handling.
     * Exceptions are mapped to failures: IllegalArgumentException becomes status 400, anything else 500.
     */
    private void processCurrentBatch(List<Record<Event>> currentBatch) {
        try {
            // NOTE(review): the bucket is taken from the first record only; this assumes every record
            // in a batch lives in the same bucket.
            String customerBucket = currentBatch.isEmpty()
                    ? null
                    : currentBatch.get(0).getData().getJsonNode().get("bucket").asText();
            String commonPrefix = this.findCommonPrefix(currentBatch);
            String manifestUrl = this.generateManifest(currentBatch, customerBucket, commonPrefix);
            String payload = this.createPayloadSageMaker(manifestUrl, this.mlProcessorConfig);
            RetryUtil.RetryResult result = RetryUtil.retryWithBackoffWithResult(() -> this.mlCommonRequester.sendRequestToMLCommons(payload), (Logger)MLProcessor.LOG);
            if (result.isSuccess()) {
                this.handleSuccess(currentBatch, manifestUrl);
            } else {
                this.handleRetryOrFailure(currentBatch, result, manifestUrl);
            }
        }
        catch (IllegalArgumentException e) {
            MLProcessor.LOG.error(DataPrepperMarkers.NOISY, "Invalid arguments for SageMaker batch job. Error: {}", e.getMessage());
            this.handleFailure(currentBatch, this.processedBatchRecords, e, 400);
        }
        catch (RuntimeException e) {
            MLProcessor.LOG.error(DataPrepperMarkers.NOISY, "Runtime Exception for SageMaker batch job. Error: {}", e.getMessage());
            this.handleFailure(currentBatch, this.processedBatchRecords, e, 500);
        }
        catch (Exception e) {
            MLProcessor.LOG.error(DataPrepperMarkers.NOISY, "Unexpected Error occurred while creating a batch job through SageMaker: {}", e.getMessage(), e);
            this.handleFailure(currentBatch, this.processedBatchRecords, e, 500);
        }
    }

    /**
     * Handles a successful submission. Drops the batch from the retry queue, queues its records
     * for emission, and updates the success metrics.
     */
    private void handleSuccess(List<Record<Event>> currentBatch, String manifestUrl) {
        MLProcessor.LOG.info("Successfully created SageMaker batch job for manifest URL: {}", manifestUrl);
        this.removeCurrentBatchFromRetryQueue(currentBatch);
        this.processedBatchRecords.addAll(currentBatch);
        this.incrementSuccessCounter();
        this.numberOfRecordsSuccessCounter.increment((double)currentBatch.size());
    }

    /**
     * Classifies a failed submission. Throttling (429, or a 400 whose message mentions throttling)
     * sends the batch to the retry queue; anything else fails the batch.
     */
    private void handleRetryOrFailure(List<Record<Event>> currentBatch, RetryUtil.RetryResult result, String manifestUrl) {
        Exception lastException = result.getLastException();
        if (lastException instanceof MLBatchJobException) {
            MLBatchJobException mlException = (MLBatchJobException)lastException;
            int statusCode = mlException.getStatusCode();
            if (statusCode == 429 || (statusCode == 400 && this.isThrottlingError(mlException.getMessage()))) {
                this.handleThrottling(currentBatch);
                return;
            }
            String errorMessage = String.format("Failed to Create SageMaker batch job after %d attempts: %s", result.getAttemptsMade(), mlException.getMessage());
            this.handleFailure(currentBatch, this.processedBatchRecords, new MLBatchJobException(statusCode, errorMessage), statusCode);
            MLProcessor.LOG.error("SageMaker batch job failed for manifest URL: {}. Status: {}, Error: {}", manifestUrl, statusCode, errorMessage);
            return;
        }
        // Substitute an exception when none was recorded. Otherwise the log call below would NPE
        // after handleFailure() already emitted the batch, and processCurrentBatch() would then
        // emit it a second time from its RuntimeException handler.
        Throwable failure = lastException != null
                ? lastException
                : new MLBatchJobException(500, "SageMaker batch job request failed without reporting an exception");
        this.handleFailure(currentBatch, this.processedBatchRecords, failure, 500);
        MLProcessor.LOG.error("SageMaker batch job failed for manifest URL: {}. Status: {}, Error: {}", manifestUrl, 500, failure.getMessage());
    }

    /**
     * Returns true when {@code errorMessage} mentions throttling (case-insensitive).
     * Locale.ROOT keeps lower-casing stable under locales such as Turkish. The longer phrase
     * "request was denied due to remote server throttling" contains "throttling", so it is covered.
     */
    private boolean isThrottlingError(String errorMessage) {
        return errorMessage != null && errorMessage.toLowerCase(Locale.ROOT).contains("throttling");
    }

    /** Adds every record of {@code currentBatch} that is not already queued to the retry queue. */
    private void handleThrottling(List<Record<Event>> currentBatch) {
        MLProcessor.LOG.warn("Rate limited (429). Adding {} records to retry queue", currentBatch.size());
        Set<Record<Event>> existingRetryRecords = this.retryQueue.stream()
                .map(AbstractBatchJobCreator.RetryRecord::getRecord)
                .collect(Collectors.toSet());
        for (Record<Event> record : currentBatch) {
            if (!existingRetryRecords.contains(record)) {
                this.retryQueue.add(new AbstractBatchJobCreator.RetryRecord(this, record));
            }
        }
    }

    /**
     * Fails {@code failedRecords} and updates the failure metrics.
     *
     * <p>The records are removed from the retry queue and queued (with failure tags) into
     * {@code resultRecords}. When a DLQ handler is configured they are also pushed to the DLQ;
     * errors during that push are logged and swallowed.
     *
     * @param throwable cause used for the DLQ message; may be {@code null}
     * @param statusCode status code recorded on each DLQ object
     */
    private void handleFailure(List<Record<Event>> failedRecords, ConcurrentLinkedQueue<Record<Event>> resultRecords, Throwable throwable, int statusCode) {
        if (failedRecords.isEmpty()) {
            this.incrementFailureCounter();
            return;
        }
        this.removeCurrentBatchFromRetryQueue(failedRecords);
        resultRecords.addAll(this.addFailureTags(failedRecords));
        this.incrementFailureCounter();
        this.numberOfRecordsFailedCounter.increment((double)failedRecords.size());
        if (this.dlqPushHandler == null) {
            return;
        }
        // Resolve the message once; a null throwable must not abort the DLQ push with an NPE.
        String failureMessage = throwable != null ? throwable.getMessage() : "Unknown error";
        try {
            ArrayList<DlqObject> dlqObjects = new ArrayList<>();
            for (Record<Event> record : failedRecords) {
                if (record.getData() != null) {
                    dlqObjects.add(this.createDlqObjectFromEvent(record.getData(), statusCode, failureMessage));
                }
            }
            this.dlqPushHandler.perform(dlqObjects);
        }
        catch (Exception ex) {
            MLProcessor.LOG.error(DataPrepperMarkers.NOISY, "Exception occured during error handling: {}", ex.getMessage());
        }
    }

    /**
     * Fails expired retry records, then re-adds the remaining ones to the pending batch, bumping
     * their retry counts. If any are re-added, the inactivity timer restarts.
     */
    private void processRetryQueue() {
        List<Record<Event>> expiredRecords = new ArrayList<>();
        this.retryQueue.removeIf(retryRecord -> {
            if (!retryRecord.isExpired()) {
                return false;
            }
            expiredRecords.add(retryRecord.getRecord());
            MLProcessor.LOG.debug("Record expired after {} attempts over {} ms", retryRecord.getRetryCount(), this.maxRetryTimeWindow);
            return true;
        });
        if (!expiredRecords.isEmpty()) {
            this.handleExpiredRecords(expiredRecords);
        }
        for (AbstractBatchJobCreator.RetryRecord retryRecord : this.retryQueue) {
            this.batch_records.add(retryRecord.getRecord());
            retryRecord.incrementRetryCount();
        }
        this.retryRecordsAddedToBatch = true;
        if (!this.retryQueue.isEmpty()) {
            this.lastUpdateTimestamp.set(System.currentTimeMillis());
            int retryQueueSize = this.retryQueue.size();
            MLProcessor.LOG.info("Added {} records to batch for retry, retry queue size: {}", retryQueueSize, retryQueueSize);
        }
    }

    /** Fails records whose retry window has elapsed, with status 400. */
    private void handleExpiredRecords(List<Record<Event>> expiredRecords) {
        MLProcessor.LOG.warn("{} records expired from retry queue after {} ms timeout", expiredRecords.size(), this.maxRetryTimeWindow);
        this.handleFailure(expiredRecords, this.processedBatchRecords, new MLBatchJobException(400, "Records expired after " + this.maxRetryTimeWindow / 60000L + " minute retry window"), 400);
    }

    /** Removes from the retry queue every entry whose record belongs to {@code currentBatch}. */
    private void removeCurrentBatchFromRetryQueue(List<Record<Event>> currentBatch) {
        int initialRetryQueueSize = this.retryQueue.size();
        if (initialRetryQueueSize == 0) {
            return;
        }
        Set<Record<Event>> processedRecords = new HashSet<>(currentBatch);
        this.retryQueue.removeIf(retryRecord -> processedRecords.contains(retryRecord.getRecord()));
        int removedCount = initialRetryQueueSize - this.retryQueue.size();
        if (removedCount > 0) {
            MLProcessor.LOG.info("Removed {} processed records from retry queue. Remaining: {}", removedCount, this.retryQueue.size());
        }
    }

    /** No preparation is needed for this implementation. */
    @Override
    public void prepareForShutdown() {
    }

    /** Ready once no records are waiting in the pending batch. */
    @Override
    public boolean isReadyForShutdown() {
        return this.batch_records.isEmpty();
    }

    /** Submits whatever is still pending, then runs {@link #prepareForShutdown()}. */
    @Override
    public void shutdown() {
        this.processRemainingBatch();
        this.prepareForShutdown();
    }

    /** Drains the pending batch and submits it immediately, regardless of size or idle time. */
    private void processRemainingBatch() {
        List<Record<Event>> remaining = this.drainBatchRecords();
        if (!remaining.isEmpty()) {
            this.processCurrentBatch(remaining);
        }
    }

    /**
     * Returns the longest common directory prefix (ending in '/', or "") of the records' object keys.
     * Keys come from the configured input key when set, otherwise from each event's {@code key} field.
     *
     * <p>NOTE(review): generateManifest always lists the {@code key} field. Confirm it agrees with the
     * configured input key.
     *
     * @throws IllegalArgumentException if {@code records} is empty
     */
    private String findCommonPrefix(Collection<Record<Event>> records) {
        EventKey inputKey = this.mlProcessorConfig.getInputKey();
        List<String> keys = records.stream()
                .map(record -> inputKey == null
                        ? record.getData().getJsonNode().get("key").asText()
                        : (String)record.getData().get(inputKey, String.class))
                .collect(Collectors.toList());
        if (keys.isEmpty()) {
            throw new IllegalArgumentException("Empty inputs identified from input key : " + String.valueOf(inputKey));
        }
        if (keys.size() == 1) {
            String singleKey = keys.get(0);
            int lastSlashIndex = singleKey.lastIndexOf('/');
            return lastSlashIndex >= 0 ? singleKey.substring(0, lastSlashIndex + 1) : "";
        }
        String prefix = keys.get(0);
        // Stop early once no shared directory remains.
        for (int i = 1; i < keys.size() && !prefix.isEmpty(); ++i) {
            prefix = this.findCommonPrefix(prefix, keys.get(i));
        }
        return prefix;
    }

    /** Returns the longest common prefix of {@code s1} and {@code s2}, cut back to its last '/' (inclusive), or "". */
    private String findCommonPrefix(String s1, String s2) {
        int minLength = Math.min(s1.length(), s2.length());
        int matched = 0;
        while (matched < minLength && s1.charAt(matched) == s2.charAt(matched)) {
            ++matched;
        }
        int lastSlashIndex = s1.lastIndexOf('/', matched - 1);
        return lastSlashIndex >= 0 ? s1.substring(0, lastSlashIndex + 1) : "";
    }

    /**
     * Builds and uploads the batch's manifest.
     *
     * <p>The manifest is a JSON array: a {@code prefix} entry {@code s3://<bucket>/} followed by
     * every record's {@code key}. It is uploaded to {@code <prefix><jobName>/<jobName>.manifest}
     * in {@code customerBucket}.
     *
     * @return the manifest's {@code s3://} URI, or {@code null} if anything fails (the error is logged)
     */
    private String generateManifest(Collection<Record<Event>> records, String customerBucket, String prefix) {
        try {
            String jobName = this.generateJobName();
            String folderName = prefix + jobName;
            String fileName = folderName + "/" + jobName + ".manifest";
            JSONArray manifestArray = new JSONArray();
            manifestArray.put(new JSONObject().put("prefix", "s3://" + customerBucket + "/"));
            for (Record<Event> record : records) {
                manifestArray.put(record.getData().getJsonNode().get("key").asText());
            }
            // Encode explicitly as UTF-8 instead of relying on the platform default charset.
            byte[] jsonData = manifestArray.toString(4).getBytes(StandardCharsets.UTF_8);
            PutObjectRequest putObjectRequest = PutObjectRequest.builder().bucket(customerBucket).key(fileName).build();
            this.s3Client.putObject(putObjectRequest, RequestBody.fromBytes(jsonData));
            return "s3://" + customerBucket + "/" + fileName;
        }
        catch (Exception e) {
            MLProcessor.LOG.error("Unexpected error while generating manifest file for SageMaker job.", e);
            return null;
        }
    }

    /**
     * Builds the ML Commons request body for a SageMaker transform job from the template.
     *
     * <p>The manifest URI and a fresh job name are filled in. When an output path is configured,
     * {@code <outputPath>/<jobName>} becomes {@code S3OutputPath}; otherwise the
     * {@code TransformOutput} section is omitted.
     *
     * @throws IllegalArgumentException if {@code manifestUri} is null or empty
     * @throws RuntimeException if the payload cannot be built or serialized
     */
    private String createPayloadSageMaker(String manifestUri, MLProcessorConfig mlProcessorConfig) {
        if (manifestUri == null || manifestUri.isEmpty()) {
            throw new IllegalArgumentException("Invalid manifest URI: manifestUri is either null or empty. Please ensure the correct input S3 uris are provided");
        }
        try {
            String jobName = this.generateJobName();
            String outputPath = mlProcessorConfig.getOutputPath();
            if (outputPath != null) {
                outputPath = outputPath.endsWith("/") ? outputPath + jobName : outputPath + "/" + jobName;
            }
            JsonNode rootNode = OBJECT_MAPPER.readTree(SAGEMAKER_PAYLOAD_TEMPLATE);
            ObjectNode parameters = (ObjectNode)rootNode.at("/parameters");
            ((ObjectNode)rootNode.at("/parameters/TransformInput/DataSource/S3DataSource")).put("S3Uri", manifestUri);
            parameters.put("TransformJobName", jobName);
            if (outputPath != null) {
                ((ObjectNode)rootNode.at("/parameters/TransformOutput")).put("S3OutputPath", outputPath);
            } else {
                // Drop only TransformOutput. Removing the whole "parameters" object would also
                // discard the manifest URI and job name set above.
                parameters.remove("TransformOutput");
            }
            return OBJECT_MAPPER.writeValueAsString(rootNode);
        }
        catch (JsonProcessingException e) {
            MLProcessor.LOG.error("Failed to process the JSON payload for SageMaker batch job. Error: {}", e.getMessage());
            throw new RuntimeException("Error processing JSON payload for SageMaker batch job", e);
        }
        catch (Exception e) {
            MLProcessor.LOG.error("Failed to create SageMaker batch job payload with input {}.", manifestUri, e);
            throw new RuntimeException("Failed to create payload for SageMaker batch job", e);
        }
    }

    /** Returns the live queue of records pending submission (not a copy). */
    public ConcurrentLinkedQueue<Record<Event>> getBatch_records() {
        return this.batch_records;
    }

    /** Returns the live queue of finished records awaiting emission (not a copy). */
    public ConcurrentLinkedQueue<Record<Event>> getProcessedBatchRecords() {
        return this.processedBatchRecords;
    }

    /** Returns the live retry queue of throttled records (not a copy). */
    public ConcurrentLinkedQueue<AbstractBatchJobCreator.RetryRecord> getRetryQueue() {
        return this.retryQueue;
    }
}

