/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.tool;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.execute.tool.ToolMLInput;
import org.opensearch.ml.common.output.Output;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.Executable;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.transport.client.Client;

/**
 * {@link Executable} registered for {@link FunctionName#TOOL}.
 *
 * <p>Looks up a {@link Tool.Factory} by the tool name carried in a {@link ToolMLInput},
 * creates and validates the tool with the request parameters, runs it, and converts
 * whatever the tool returns into a {@link ModelTensorOutput}.
 *
 * <p>The accessors, {@code equals}, {@code hashCode}, {@code toString} and the no-arg
 * constructor are Lombok-generated members (marked {@code @Generated}).
 */
@Function(value=FunctionName.TOOL)
public class MLToolExecutor implements Executable {
    @Generated
    private static final Logger log = LogManager.getLogger(MLToolExecutor.class);

    /** Name given to the tensor that wraps a tool result which is not already tensor-shaped. */
    private static final String RESPONSE_TENSOR_NAME = "response";

    private Client client;
    private SdkClient sdkClient;
    private Settings settings;
    private ClusterService clusterService;
    private NamedXContentRegistry xContentRegistry;
    /** Tool factories keyed by tool name; may be {@code null} when built via the no-arg constructor. */
    private Map<String, Tool.Factory> toolFactories;

    /**
     * Creates an executor with all collaborators supplied.
     *
     * @param client            stored as-is and exposed through {@link #getClient()}
     * @param sdkClient         stored as-is and exposed through {@link #getSdkClient()}
     * @param settings          stored as-is and exposed through {@link #getSettings()}
     * @param clusterService    stored as-is and exposed through {@link #getClusterService()}
     * @param xContentRegistry  stored as-is and exposed through {@link #getXContentRegistry()}
     * @param toolFactories     tool factories keyed by tool name, consulted by {@link #execute}
     */
    public MLToolExecutor(Client client, SdkClient sdkClient, Settings settings, ClusterService clusterService, NamedXContentRegistry xContentRegistry, Map<String, Tool.Factory> toolFactories) {
        this.client = client;
        this.sdkClient = sdkClient;
        this.settings = settings;
        this.clusterService = clusterService;
        this.xContentRegistry = xContentRegistry;
        this.toolFactories = toolFactories;
    }

    /**
     * Runs the tool named by {@code input} and reports the outcome through {@code listener}.
     *
     * <p>Malformed requests (wrong input type, missing dataset or parameters) are rejected by
     * throwing synchronously. An unknown tool, a failed validation, and any exception raised
     * while creating or starting the tool are delivered to {@code listener.onFailure}.
     *
     * @param input    must be a {@link ToolMLInput} whose dataset is a
     *                 {@link RemoteInferenceInputDataSet} with non-null parameters
     * @param listener receives a {@link ModelTensorOutput} on success or the failure cause
     * @throws IllegalArgumentException if {@code input} is not a {@link ToolMLInput}, or its
     *                                  dataset or the dataset's parameters are {@code null}
     */
    @Override
    public void execute(Input input, ActionListener<Output> listener) {
        if (!(input instanceof ToolMLInput)) {
            throw new IllegalArgumentException("wrong input");
        }
        ToolMLInput toolMLInput = (ToolMLInput) input;
        String toolName = toolMLInput.getToolName();
        RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) toolMLInput.getInputDataset();
        if (inputDataSet == null || inputDataSet.getParameters() == null) {
            throw new IllegalArgumentException("Tool input data can not be empty.");
        }
        Map<String, String> parameters = inputDataSet.getParameters();

        // The registry is null when the instance was built with the no-arg constructor and
        // never configured; report that the same way as an unknown tool instead of an NPE.
        @SuppressWarnings("rawtypes") // the registry field is declared with the raw Tool.Factory type
        Tool.Factory toolFactory = this.toolFactories == null ? null : this.toolFactories.get(toolName);
        if (toolFactory == null) {
            listener.onFailure(new IllegalArgumentException("Tool not found: " + toolName));
            return;
        }

        try {
            // Work on a copy so the request's own parameter map is never handed to the tool.
            Map<String, String> mutableParams = new HashMap<>(parameters);
            @SuppressWarnings("unchecked") // raw factory call; unavoidable while the registry is raw
            Tool tool = toolFactory.create(mutableParams);
            if (!tool.validate(mutableParams)) {
                listener.onFailure(new IllegalArgumentException("Invalid parameters for tool: " + toolName));
                return;
            }
            tool.run(mutableParams, ActionListener.wrap(result -> {
                List<ModelTensor> modelTensors = new ArrayList<>();
                this.processOutput(result, modelTensors);
                ModelTensors tensors = ModelTensors.builder().mlModelTensors(modelTensors).build();
                listener.onResponse(new ModelTensorOutput(List.of(tensors)));
            }, error -> {
                log.error("Failed to execute tool: {}", toolName, error);
                listener.onFailure(error);
            }));
        } catch (Exception e) {
            log.error("Failed to execute tool: {}", toolName, e);
            listener.onFailure(e);
        }
    }

    /**
     * Flattens a tool result into {@code modelTensors}.
     *
     * <ul>
     *   <li>{@link ModelTensorOutput}: every tensor of every contained {@link ModelTensors} is appended.</li>
     *   <li>{@link ModelTensor}: appended as-is.</li>
     *   <li>Non-empty {@link List}: dispatched on the type of its first element. Lists of
     *       {@link ModelTensor} are appended, lists of {@link ModelTensors} are flattened, and
     *       any other list is serialized to JSON and wrapped in one {@code "response"} tensor.</li>
     *   <li>Empty {@link List}: nothing is appended.</li>
     *   <li>Anything else: a {@link String} is used verbatim, other values are serialized to
     *       JSON, and the text is wrapped in one {@code "response"} tensor.</li>
     * </ul>
     *
     * @param output       the value produced by the tool
     * @param modelTensors destination list, mutated in place
     * @throws ClassCastException if a list whose first element is a {@link ModelTensor} or
     *                            {@link ModelTensors} contains a non-null element of another type
     */
    private void processOutput(Object output, List<ModelTensor> modelTensors) {
        if (output instanceof ModelTensorOutput) {
            ModelTensorOutput modelTensorOutput = (ModelTensorOutput) output;
            modelTensorOutput.getMlModelOutputs().forEach(outs -> modelTensors.addAll(outs.getMlModelTensors()));
            return;
        }
        if (output instanceof ModelTensor) {
            modelTensors.add((ModelTensor) output);
            return;
        }
        if (output instanceof List) {
            List<?> items = (List<?>) output;
            if (items.isEmpty()) {
                return;
            }
            Object first = items.get(0);
            if (first instanceof ModelTensor) {
                // Cast element by element so a mixed list fails here, at the offending element,
                // rather than polluting the typed destination list with foreign objects.
                for (Object item : items) {
                    modelTensors.add((ModelTensor) item);
                }
            } else if (first instanceof ModelTensors) {
                for (Object item : items) {
                    modelTensors.addAll(((ModelTensors) item).getMlModelTensors());
                }
            } else {
                modelTensors.add(responseTensor(StringUtils.toJson(output)));
            }
            return;
        }
        String result = output instanceof String ? (String) output : StringUtils.toJson(output);
        modelTensors.add(responseTensor(result));
    }

    /** Wraps {@code result} in a tensor named {@code "response"}. */
    private static ModelTensor responseTensor(String result) {
        return ModelTensor.builder().name(RESPONSE_TENSOR_NAME).result(result).build();
    }

    @Generated
    public Client getClient() {
        return this.client;
    }

    @Generated
    public SdkClient getSdkClient() {
        return this.sdkClient;
    }

    @Generated
    public Settings getSettings() {
        return this.settings;
    }

    @Generated
    public ClusterService getClusterService() {
        return this.clusterService;
    }

    @Generated
    public NamedXContentRegistry getXContentRegistry() {
        return this.xContentRegistry;
    }

    @Generated
    public Map<String, Tool.Factory> getToolFactories() {
        return this.toolFactories;
    }

    @Generated
    public void setClient(Client client) {
        this.client = client;
    }

    @Generated
    public void setSdkClient(SdkClient sdkClient) {
        this.sdkClient = sdkClient;
    }

    @Generated
    public void setSettings(Settings settings) {
        this.settings = settings;
    }

    @Generated
    public void setClusterService(ClusterService clusterService) {
        this.clusterService = clusterService;
    }

    @Generated
    public void setXContentRegistry(NamedXContentRegistry xContentRegistry) {
        this.xContentRegistry = xContentRegistry;
    }

    @Generated
    public void setToolFactories(Map<String, Tool.Factory> toolFactories) {
        this.toolFactories = toolFactories;
    }

    /**
     * Two executors are equal when all six collaborator fields are pairwise equal
     * (both {@code null}, or equal according to {@code equals}).
     */
    @Generated
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof MLToolExecutor)) {
            return false;
        }
        MLToolExecutor other = (MLToolExecutor) o;
        if (!other.canEqual(this)) {
            return false;
        }
        // Fields are read through the getters and compared in declaration order.
        return nullSafeEquals(this.getClient(), other.getClient())
            && nullSafeEquals(this.getSdkClient(), other.getSdkClient())
            && nullSafeEquals(this.getSettings(), other.getSettings())
            && nullSafeEquals(this.getClusterService(), other.getClusterService())
            && nullSafeEquals(this.getXContentRegistry(), other.getXContentRegistry())
            && nullSafeEquals(this.getToolFactories(), other.getToolFactories());
    }

    /** Hook that lets a subclass refuse equality with instances of this class. */
    @Generated
    protected boolean canEqual(Object other) {
        return other instanceof MLToolExecutor;
    }

    /** Combines the six collaborator fields; a {@code null} field contributes 43. */
    @Generated
    public int hashCode() {
        final int PRIME = 59;
        int result = 1;
        result = result * PRIME + nullSafeHash(this.getClient());
        result = result * PRIME + nullSafeHash(this.getSdkClient());
        result = result * PRIME + nullSafeHash(this.getSettings());
        result = result * PRIME + nullSafeHash(this.getClusterService());
        result = result * PRIME + nullSafeHash(this.getXContentRegistry());
        result = result * PRIME + nullSafeHash(this.getToolFactories());
        return result;
    }

    @Generated
    public String toString() {
        return "MLToolExecutor(client=" + this.getClient()
            + ", sdkClient=" + this.getSdkClient()
            + ", settings=" + this.getSettings()
            + ", clusterService=" + this.getClusterService()
            + ", xContentRegistry=" + this.getXContentRegistry()
            + ", toolFactories=" + this.getToolFactories() + ")";
    }

    /** Creates an executor with every collaborator unset ({@code null}); configure it via the setters. */
    @Generated
    public MLToolExecutor() {
    }

    /** Returns whether both values are {@code null} or {@code left.equals(right)}. */
    private static boolean nullSafeEquals(Object left, Object right) {
        return left == null ? right == null : left.equals(right);
    }

    /** Returns {@code value.hashCode()}, or 43 when {@code value} is {@code null}. */
    private static int nullSafeHash(Object value) {
        return value == null ? 43 : value.hashCode();
    }
}

