/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.dataprepper.plugins.lambda.common;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import org.opensearch.dataprepper.model.codec.InputCodec;
import org.opensearch.dataprepper.model.event.Event;
import org.opensearch.dataprepper.model.plugin.PluginFactory;
import org.opensearch.dataprepper.model.record.Record;
import org.opensearch.dataprepper.plugins.lambda.common.accumlator.Buffer;
import org.opensearch.dataprepper.plugins.lambda.common.config.ResponseHandling;
import org.opensearch.dataprepper.plugins.lambda.common.config.StreamingOptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.services.lambda.LambdaAsyncClient;
import software.amazon.awssdk.services.lambda.model.InvokeRequest;
import software.amazon.awssdk.services.lambda.model.InvokeWithResponseStreamRequest;
import software.amazon.awssdk.services.lambda.model.InvokeWithResponseStreamResponseHandler;
import software.amazon.awssdk.services.lambda.model.ResponseStreamingInvocationType;
import software.amazon.awssdk.services.lambda.model.invokewithresponsestreamresponseevent.DefaultPayloadChunk;

/**
 * Invokes an AWS Lambda function through {@code InvokeWithResponseStream}, accumulates the
 * streamed payload chunks in memory, and decodes the complete payload into events using the
 * configured response codec. When the streaming options request
 * {@link ResponseHandling#RECONSTRUCT_DOCUMENT}, the decoded events are merged back into the
 * single original event of the input buffer.
 */
public class StreamingLambdaHandler {
    private static final Logger LOG = LoggerFactory.getLogger(StreamingLambdaHandler.class);
    private final LambdaAsyncClient lambdaAsyncClient;
    // NOTE(review): stored but never read by this class; kept for constructor compatibility.
    private final PluginFactory pluginFactory;
    private final InputCodec responseCodec;
    private final String functionName;
    private final StreamingOptions streamingOptions;

    /**
     * @param lambdaAsyncClient async Lambda client used to issue the streaming invocation
     * @param pluginFactory     plugin factory (stored, not used by this class)
     * @param responseCodec     codec used to parse the aggregated response bytes into events
     * @param functionName      name of the Lambda function to invoke
     * @param streamingOptions  streaming configuration; may be {@code null}, in which case the
     *                          parsed records are returned without reconstruction
     */
    public StreamingLambdaHandler(LambdaAsyncClient lambdaAsyncClient, PluginFactory pluginFactory, InputCodec responseCodec, String functionName, StreamingOptions streamingOptions) {
        this.lambdaAsyncClient = lambdaAsyncClient;
        this.pluginFactory = pluginFactory;
        this.responseCodec = responseCodec;
        this.functionName = functionName;
        this.streamingOptions = streamingOptions;
    }

    /**
     * Sends the payload of {@code inputBuffer} to the Lambda function with a streaming
     * (REQUEST_RESPONSE) invocation and decodes the full streamed response once it completes.
     *
     * @param inputBuffer buffer whose request payload is sent and whose records are used for
     *                    document reconstruction
     * @return a future that completes with the resulting records, or completes exceptionally with
     *     an {@link IllegalArgumentException} when the buffer yields no payload, or with the error
     *     raised during invocation or response processing
     */
    public CompletableFuture<List<Record<Event>>> invokeWithStreaming(Buffer inputBuffer) {
        CompletableFuture<List<Record<Event>>> resultFuture = new CompletableFuture<>();
        InvokeRequest invokeRequest = inputBuffer.getRequestPayload(functionName, "RequestResponse");
        if (invokeRequest == null) {
            resultFuture.completeExceptionally(new IllegalArgumentException("No payload in buffer"));
            return resultFuture;
        }

        InvokeWithResponseStreamRequest request = InvokeWithResponseStreamRequest.builder()
                .functionName(functionName)
                .invocationType(ResponseStreamingInvocationType.REQUEST_RESPONSE)
                .payload(invokeRequest.payload())
                .build();

        // ByteArrayOutputStream#write(byte[], int, int) and #toByteArray are synchronized, so
        // chunk appends and the final snapshot need no extra locking here.
        ByteArrayOutputStream responseStream = new ByteArrayOutputStream();
        InvokeWithResponseStreamResponseHandler responseHandler = InvokeWithResponseStreamResponseHandler.builder()
                .onResponse(response -> LOG.debug("Streaming response started for function: {}", functionName))
                .onEventStream(publisher -> publisher.subscribe(event -> {
                    if (event instanceof DefaultPayloadChunk) {
                        byte[] chunkBytes = ((DefaultPayloadChunk) event).payload().asByteArray();
                        // The offset/length overload does not declare IOException.
                        responseStream.write(chunkBytes, 0, chunkBytes.length);
                        LOG.debug("Received chunk of size: {} bytes", chunkBytes.length);
                    } else {
                        LOG.debug("Ignoring non-payload Lambda stream event: {}", event.getClass().getSimpleName());
                    }
                }))
                .onComplete(() -> completeWithResponse(responseStream, inputBuffer, resultFuture))
                .onError(throwable -> {
                    LOG.error("Error in streaming Lambda invocation", throwable);
                    resultFuture.completeExceptionally(throwable);
                })
                .build();

        // Defensive: if a failure is surfaced only through the returned future, make sure the
        // caller's future does not stay pending. completeExceptionally is a no-op once completed.
        lambdaAsyncClient.invokeWithResponseStream(request, responseHandler)
                .whenComplete((ignored, throwable) -> {
                    if (throwable != null) {
                        resultFuture.completeExceptionally(throwable);
                    }
                });
        return resultFuture;
    }

    /**
     * Decodes the fully buffered response and completes {@code resultFuture} with the outcome.
     * Any exception raised while decoding or reconstructing completes the future exceptionally.
     */
    private void completeWithResponse(ByteArrayOutputStream responseStream, Buffer inputBuffer,
                                      CompletableFuture<List<Record<Event>>> resultFuture) {
        try {
            byte[] completeResponse = responseStream.toByteArray();
            LOG.debug("Streaming response complete. Total size: {} bytes", completeResponse.length);
            List<Record<Event>> processedRecords = processStreamingResponse(completeResponse, inputBuffer, streamingOptions);
            resultFuture.complete(processedRecords);
        } catch (Exception e) {
            // Callback boundary: every failure must reach the caller through the future.
            LOG.error("Error processing complete streaming response", e);
            resultFuture.completeExceptionally(e);
        }
    }

    /**
     * Parses {@code responseBytes} with the response codec and applies the configured response
     * handling.
     *
     * @throws IOException if the codec fails to parse the response
     */
    private List<Record<Event>> processStreamingResponse(byte[] responseBytes, Buffer inputBuffer, StreamingOptions streamingOptions) throws IOException {
        List<Record<Event>> resultRecords = new ArrayList<>();
        try (InputStream responseStream = new ByteArrayInputStream(responseBytes)) {
            responseCodec.parse(responseStream, record -> resultRecords.add(new Record<>(record.getData())));
        }
        LOG.info("Processed streaming response: {} records from {} bytes", resultRecords.size(), responseBytes.length);
        return applyResponseHandling(resultRecords, inputBuffer, streamingOptions);
    }

    /**
     * Returns {@code parsedRecords} unchanged unless the options request
     * {@link ResponseHandling#RECONSTRUCT_DOCUMENT}, in which case the records are merged into
     * the original event.
     */
    private List<Record<Event>> applyResponseHandling(List<Record<Event>> parsedRecords, Buffer inputBuffer, StreamingOptions streamingOptions) {
        boolean reconstruct = streamingOptions != null
                && streamingOptions.getResponseHandling() == ResponseHandling.RECONSTRUCT_DOCUMENT;
        return reconstruct ? reconstructDocument(parsedRecords, inputBuffer) : parsedRecords;
    }

    /**
     * Merges every parsed event into the single original event held by {@code inputBuffer} and
     * returns the buffer's original records.
     *
     * @throws IllegalStateException if the buffer holds more than one record
     */
    private List<Record<Event>> reconstructDocument(List<Record<Event>> parsedRecords, Buffer inputBuffer) {
        // NOTE(review): an empty Lambda response returns the empty list, so the original event is
        // not emitted in reconstruct mode — confirm this is intended.
        if (parsedRecords.isEmpty()) {
            return parsedRecords;
        }
        List<Record<Event>> originalRecords = inputBuffer.getRecords();
        if (originalRecords.isEmpty()) {
            LOG.warn("No original records found in buffer for reconstruction");
            return parsedRecords;
        }
        if (originalRecords.size() != 1) {
            String errorMsg = String.format("reconstruct-document mode requires exactly 1 event per buffer, found %d events. This should have been prevented by configuration validation. Please ensure batch.threshold.event_count is set to 1.", originalRecords.size());
            LOG.error(errorMsg);
            throw new IllegalStateException(errorMsg);
        }
        Event reconstructedEvent = originalRecords.get(0).getData();
        for (Record<Event> parsedRecord : parsedRecords) {
            reconstructedEvent.merge(parsedRecord.getData());
        }
        LOG.info("Reconstructed {} chunks into {} document(s)", parsedRecords.size(), originalRecords.size());
        return originalRecords;
    }
}

