/*
 * Decompiled with CFR 0.152.
 */
package com.linecorp.armeria.server.grpc;

import com.linecorp.armeria.common.ExchangeType;
import com.linecorp.armeria.common.HttpHeaders;
import com.linecorp.armeria.common.HttpHeadersBuilder;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.HttpResponseWriter;
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.MediaType;
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.ResponseHeaders;
import com.linecorp.armeria.common.ResponseHeadersBuilder;
import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.grpc.GrpcExceptionHandlerFunction;
import com.linecorp.armeria.common.grpc.GrpcJsonMarshaller;
import com.linecorp.armeria.common.grpc.GrpcSerializationFormats;
import com.linecorp.armeria.common.grpc.protocol.GrpcHeaderNames;
import com.linecorp.armeria.common.logging.RequestLogProperty;
import com.linecorp.armeria.common.util.SafeCloseable;
import com.linecorp.armeria.common.util.TimeoutMode;
import com.linecorp.armeria.internal.common.grpc.GrpcExchangeTypeUtil;
import com.linecorp.armeria.internal.common.grpc.MetadataUtil;
import com.linecorp.armeria.internal.common.grpc.TimeoutHeaderUtil;
import com.linecorp.armeria.internal.server.grpc.AbstractServerCall;
import com.linecorp.armeria.internal.server.grpc.ServerStatusAndMetadata;
import com.linecorp.armeria.internal.shaded.guava.base.MoreObjects;
import com.linecorp.armeria.internal.shaded.guava.collect.ImmutableList;
import com.linecorp.armeria.internal.shaded.guava.collect.ImmutableMap;
import com.linecorp.armeria.internal.shaded.guava.collect.ImmutableSet;
import com.linecorp.armeria.internal.shaded.guava.primitives.Ints;
import com.linecorp.armeria.internal.shaded.guava.util.concurrent.MoreExecutors;
import com.linecorp.armeria.server.AbstractHttpService;
import com.linecorp.armeria.server.Route;
import com.linecorp.armeria.server.RoutingContext;
import com.linecorp.armeria.server.ServiceConfig;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.grpc.AbstractUnframedGrpcService;
import com.linecorp.armeria.server.grpc.GrpcHealthCheckService;
import com.linecorp.armeria.server.grpc.GrpcService;
import com.linecorp.armeria.server.grpc.GrpcServiceBuilder;
import com.linecorp.armeria.server.grpc.HandlerRegistry;
import com.linecorp.armeria.server.grpc.ProtoReflectionServiceInterceptor;
import com.linecorp.armeria.server.grpc.StreamingServerCall;
import com.linecorp.armeria.server.grpc.UnaryServerCall;
import io.grpc.Codec;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerMethodDefinition;
import io.grpc.ServerServiceDefinition;
import io.grpc.ServiceDescriptor;
import io.grpc.Status;
import io.netty.util.AttributeKey;
import java.time.Duration;
import java.util.AbstractMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link GrpcService} that serves framed gRPC requests: it resolves the target
 * {@link ServerMethodDefinition} from the {@link HandlerRegistry}, applies the client-supplied
 * {@code grpc-timeout} (when enabled) and starts a unary or streaming server call.
 */
final class FramedGrpcService extends AbstractHttpService implements GrpcService {

    private static final Logger logger = LoggerFactory.getLogger(FramedGrpcService.class);

    /** A no-op listener installed on a call whose {@code ServerCallHandler.startCall()} threw. */
    static final ServerCall.Listener<?> EMPTY_LISTENER = new EmptyListener<>();

    /**
     * A context attribute holding a pre-resolved method. It is consulted by
     * {@link #methodDefinition(ServiceRequestContext)} only when {@code lookupMethodFromAttribute} is set.
     */
    static final AttributeKey<ServerMethodDefinition<?, ?>> RESOLVED_GRPC_METHOD =
            AttributeKey.valueOf(FramedGrpcService.class, "RESOLVED_GRPC_METHOD");

    private final HandlerRegistry registry;
    private final Set<Route> routes;
    // Keyed by "/" + full method name; see the constructor.
    private final Map<String, ExchangeType> exchangeTypes;
    private final DecompressorRegistry decompressorRegistry;
    private final CompressorRegistry compressorRegistry;
    private final Set<SerializationFormat> supportedSerializationFormats;
    // Keyed by service name; empty when no JSON format is supported or marshaller creation failed.
    private final Map<String, GrpcJsonMarshaller> jsonMarshallers;
    @Nullable
    private final ProtoReflectionServiceInterceptor protoReflectionServiceInterceptor;
    private final int maxResponseMessageLength;
    private final boolean useBlockingTaskExecutor;
    private final boolean unsafeWrapRequestBuffers;
    private final boolean useClientTimeoutHeader;
    private final boolean useMethodMarshaller;
    private final String advertisedEncodingsHeader;
    private final Map<SerializationFormat, ResponseHeaders> defaultHeaders;
    @Nullable
    private final GrpcHealthCheckService grpcHealthCheckService;
    // Not final: a value of -1 is replaced in serviceAdded() with the service config's maxRequestLength.
    private int maxRequestMessageLength;
    private final boolean lookupMethodFromAttribute;
    private final boolean autoCompression;

    /**
     * Creates one {@link GrpcJsonMarshaller} per distinct service descriptor, keyed by service name.
     *
     * <p>Returns an empty map when none of {@code supportedSerializationFormats} is a JSON format.
     * If the factory throws, a warning is logged and an empty map is returned instead of failing.
     */
    private static Map<String, GrpcJsonMarshaller> getJsonMarshallers(
            HandlerRegistry registry,
            Set<SerializationFormat> supportedSerializationFormats,
            Function<? super ServiceDescriptor, ? extends GrpcJsonMarshaller> jsonMarshallerFactory) {
        if (supportedSerializationFormats.stream().noneMatch(GrpcSerializationFormats::isJson)) {
            return ImmutableMap.of();
        }
        try {
            return registry.services().stream()
                           .map(ServerServiceDefinition::getServiceDescriptor)
                           .distinct()
                           .collect(ImmutableMap.toImmutableMap(ServiceDescriptor::getName,
                                                                jsonMarshallerFactory));
        } catch (Exception e) {
            logger.warn("Failed to instantiate a JSON marshaller. Consider disabling gRPC-JSON " +
                        "serialization with {}.supportedSerializationFormats() or using " +
                        "{}.ofGson() instead.",
                        GrpcServiceBuilder.class.getName(), GrpcJsonMarshaller.class.getName(), e);
            return ImmutableMap.of();
        }
    }

    /**
     * Builds the default {@code 200 OK} response headers for {@code format}. They carry the
     * format's content type and an identity {@code grpc-encoding}. They also carry
     * {@code grpc-accept-encoding} when {@code advertisedEncodingsHeader} is non-empty.
     */
    private static ResponseHeaders newDefaultHeaders(SerializationFormat format,
                                                     String advertisedEncodingsHeader) {
        final ResponseHeadersBuilder builder =
                ResponseHeaders.builder(HttpStatus.OK)
                               .contentType(format.mediaType())
                               .add(GrpcHeaderNames.GRPC_ENCODING,
                                    Codec.Identity.NONE.getMessageEncoding());
        if (!advertisedEncodingsHeader.isEmpty()) {
            builder.add(GrpcHeaderNames.GRPC_ACCEPT_ENCODING, advertisedEncodingsHeader);
        }
        return builder.build();
    }

    FramedGrpcService(HandlerRegistry registry, DecompressorRegistry decompressorRegistry,
                      CompressorRegistry compressorRegistry,
                      Set<SerializationFormat> supportedSerializationFormats,
                      Function<? super ServiceDescriptor, ? extends GrpcJsonMarshaller> jsonMarshallerFactory,
                      @Nullable ProtoReflectionServiceInterceptor protoReflectionServiceInterceptor,
                      int maxRequestMessageLength, int maxResponseMessageLength,
                      boolean useBlockingTaskExecutor, boolean unsafeWrapRequestBuffers,
                      boolean useClientTimeoutHeader, boolean lookupMethodFromAttribute,
                      @Nullable GrpcHealthCheckService grpcHealthCheckService, boolean autoCompression,
                      boolean useMethodMarshaller) {
        this.registry = Objects.requireNonNull(registry, "registry");
        routes = ImmutableSet.copyOf(registry.methodsByRoute().keySet());
        exchangeTypes = registry.methods().entrySet().stream().collect(
                ImmutableMap.toImmutableMap(
                        e -> '/' + e.getKey(),
                        e -> GrpcExchangeTypeUtil.toExchangeType(
                                e.getValue().getMethodDescriptor().getType())));
        this.decompressorRegistry = Objects.requireNonNull(decompressorRegistry, "decompressorRegistry");
        this.compressorRegistry = Objects.requireNonNull(compressorRegistry, "compressorRegistry");
        this.supportedSerializationFormats = supportedSerializationFormats;
        this.useClientTimeoutHeader = useClientTimeoutHeader;
        jsonMarshallers = getJsonMarshallers(registry, supportedSerializationFormats, jsonMarshallerFactory);
        this.protoReflectionServiceInterceptor = protoReflectionServiceInterceptor;
        this.maxRequestMessageLength = maxRequestMessageLength;
        this.maxResponseMessageLength = maxResponseMessageLength;
        this.useBlockingTaskExecutor = useBlockingTaskExecutor;
        this.unsafeWrapRequestBuffers = unsafeWrapRequestBuffers;
        this.lookupMethodFromAttribute = lookupMethodFromAttribute;
        this.autoCompression = autoCompression;
        this.useMethodMarshaller = useMethodMarshaller;
        advertisedEncodingsHeader = String.join(",", decompressorRegistry.getAdvertisedMessageEncodings());
        final String acceptEncodings = advertisedEncodingsHeader;
        defaultHeaders = supportedSerializationFormats.stream().collect(
                ImmutableMap.toImmutableMap(Function.identity(),
                                            format -> newDefaultHeaders(format, acceptEncodings)));
        this.grpcHealthCheckService = grpcHealthCheckService;
    }

    /**
     * Returns the exchange type of the gRPC method matched by the routing result's path.
     * Returns {@link ExchangeType#BIDI_STREAMING} when no method is registered under that path.
     */
    @Override
    public ExchangeType exchangeType(RoutingContext routingContext) {
        return MoreObjects.firstNonNull(
                exchangeTypes.get(routingContext.result().routingResult().path()),
                ExchangeType.BIDI_STREAMING);
    }

    @Override
    protected HttpResponse doPost(ServiceRequestContext ctx, HttpRequest req) throws Exception {
        final SerializationFormat serializationFormat = findSerializationFormat(req.contentType());
        if (serializationFormat == null) {
            return HttpResponse.of(HttpStatus.UNSUPPORTED_MEDIA_TYPE,
                                   MediaType.PLAIN_TEXT_UTF_8,
                                   "Missing or invalid Content-Type header.");
        }

        ctx.logBuilder().serializationFormat(serializationFormat);

        final ServerMethodDefinition<?, ?> method = methodDefinition(ctx);
        if (method == null) {
            return newTrailersOnlyResponse(
                    ctx, serializationFormat,
                    Status.UNIMPLEMENTED.withDescription(
                            "Method not found: " + ctx.config().route().patternString()),
                    new Metadata());
        }

        if (useClientTimeoutHeader) {
            final String timeoutHeader = req.headers().get(GrpcHeaderNames.GRPC_TIMEOUT);
            if (timeoutHeader != null) {
                try {
                    final long timeout = TimeoutHeaderUtil.fromHeaderValue(timeoutHeader);
                    if (timeout == 0) {
                        // A zero value clears the request timeout. Arming a zero-length timeout
                        // right afterwards would contradict that, so the two branches are exclusive.
                        ctx.clearRequestTimeout();
                    } else {
                        ctx.setRequestTimeout(TimeoutMode.SET_FROM_NOW, Duration.ofNanos(timeout));
                    }
                } catch (IllegalArgumentException e) {
                    // The header could not be parsed. Let the method's exception handler map the
                    // failure to a status.
                    final Metadata metadata = new Metadata();
                    final GrpcExceptionHandlerFunction exceptionHandler = registry.getExceptionHandler(method);
                    return newTrailersOnlyResponse(ctx, serializationFormat,
                                                   exceptionHandler.apply(ctx, e, metadata), metadata);
                }
            } else if (!Boolean.TRUE.equals(ctx.attr(AbstractUnframedGrpcService.IS_UNFRAMED_GRPC))) {
                // No grpc-timeout header on a framed request: drop the server-side request timeout.
                ctx.clearRequestTimeout();
            }
        }

        ctx.logBuilder().defer(RequestLogProperty.REQUEST_CONTENT, RequestLogProperty.RESPONSE_CONTENT);

        final String simpleMethodName = registry.simpleMethodName(method.getMethodDescriptor());
        final HttpResponse res;
        if (method.getMethodDescriptor().getType() == MethodDescriptor.MethodType.UNARY) {
            // Unary calls complete the response through a future.
            final CompletableFuture<HttpResponse> resFuture = new CompletableFuture<>();
            res = HttpResponse.of(resFuture);
            startCall(simpleMethodName, method, ctx, req, res, resFuture, serializationFormat);
        } else {
            // Other method types write to a streaming response.
            res = HttpResponse.streaming();
            startCall(simpleMethodName, method, ctx, req, res, null, serializationFormat);
        }
        return res;
    }

    /**
     * Converts {@code status} and {@code metadata} into trailers, starting from the default headers
     * for {@code serializationFormat}. Returns them as a headers-only response.
     */
    private HttpResponse newTrailersOnlyResponse(ServiceRequestContext ctx,
                                                 SerializationFormat serializationFormat,
                                                 Status status, Metadata metadata) {
        final HttpHeadersBuilder builder = defaultHeaders.get(serializationFormat).toBuilder();
        return HttpResponse.of((ResponseHeaders) AbstractServerCall.statusToTrailers(
                ctx, builder, status, metadata));
    }

    /**
     * Creates the server call and starts it.
     *
     * <p>A sequential executor over the context's blocking task executor is used when blocking
     * execution is configured globally or required by the method. Otherwise the call starts
     * synchronously with {@code ctx} pushed.
     */
    private <I, O> void startCall(String simpleMethodName, ServerMethodDefinition<I, O> methodDef,
                                  ServiceRequestContext ctx, HttpRequest req, HttpResponse res,
                                  @Nullable CompletableFuture<HttpResponse> resFuture,
                                  SerializationFormat serializationFormat) {
        final MethodDescriptor<I, O> methodDescriptor = methodDef.getMethodDescriptor();
        final Executor blockingExecutor;
        if (useBlockingTaskExecutor || registry.needToUseBlockingTaskExecutor(methodDef)) {
            blockingExecutor = MoreExecutors.newSequentialExecutor(ctx.blockingTaskExecutor());
        } else {
            blockingExecutor = null;
        }
        final AbstractServerCall<I, O> call = newServerCall(simpleMethodName, methodDef, ctx, req, res,
                                                            resFuture, serializationFormat,
                                                            blockingExecutor);
        if (blockingExecutor != null) {
            blockingExecutor.execute(() -> startCall(methodDef, ctx, req, methodDescriptor, call));
        } else {
            try (SafeCloseable ignored = ctx.push()) {
                startCall(methodDef, ctx, req, methodDescriptor, call);
            }
        }
    }

    /**
     * Invokes the method's {@code ServerCallHandler}, attaches the returned listener and starts
     * deframing. It also registers a cancellation hook that closes the call with the status produced
     * by the call's exception handler.
     *
     * @throws NullPointerException if the handler returns a {@code null} listener
     */
    private <I, O> void startCall(ServerMethodDefinition<I, O> methodDef, ServiceRequestContext ctx,
                                  HttpRequest req, MethodDescriptor<I, O> methodDescriptor,
                                  AbstractServerCall<I, O> call) {
        final Metadata headers = MetadataUtil.copyFromHeaders(req.headers());
        final ServerCall.Listener<I> listener;
        try {
            listener = methodDef.getServerCallHandler().startCall(call, headers);
        } catch (Throwable t) {
            // The handler failed: install a no-op listener, then close the call with the failure.
            @SuppressWarnings("unchecked")
            final ServerCall.Listener<I> emptyListener = (ServerCall.Listener<I>) EMPTY_LISTENER;
            call.setListener(emptyListener);
            call.close(t);
            return;
        }
        if (listener == null) {
            throw new NullPointerException(
                    "startCall() returned a null listener for method " + methodDescriptor.getFullMethodName());
        }
        call.setListener(listener);
        call.startDeframing();
        ctx.whenRequestCancelling().handle((cancellationCause, unused) -> {
            final Status status = call.exceptionHandler().apply(ctx, cancellationCause, headers);
            assert status != null;
            call.close(new ServerStatusAndMetadata(status, new Metadata(), true, true));
            return null;
        });
    }

    /**
     * Creates a {@link UnaryServerCall} for unary methods and a {@link StreamingServerCall} for
     * all other method types.
     */
    private <I, O> AbstractServerCall<I, O> newServerCall(
            String simpleMethodName, ServerMethodDefinition<I, O> methodDef, ServiceRequestContext ctx,
            HttpRequest req, HttpResponse res, @Nullable CompletableFuture<HttpResponse> resFuture,
            SerializationFormat serializationFormat, @Nullable Executor blockingExecutor) {
        final MethodDescriptor<I, O> methodDescriptor = methodDef.getMethodDescriptor();
        final GrpcExceptionHandlerFunction exceptionHandler = registry.getExceptionHandler(methodDef);
        final GrpcJsonMarshaller jsonMarshaller = jsonMarshallers.get(methodDescriptor.getServiceName());
        final ResponseHeaders defaultHeader = defaultHeaders.get(serializationFormat);
        if (methodDescriptor.getType() == MethodDescriptor.MethodType.UNARY) {
            assert resFuture != null;
            return new UnaryServerCall<>(req, methodDescriptor, simpleMethodName, compressorRegistry,
                                         decompressorRegistry, res, resFuture, maxRequestMessageLength,
                                         maxResponseMessageLength, ctx, serializationFormat,
                                         jsonMarshaller, unsafeWrapRequestBuffers, defaultHeader,
                                         exceptionHandler, blockingExecutor, autoCompression,
                                         useMethodMarshaller);
        }
        // doPost() passes an HttpResponse.streaming() writer for every non-unary method.
        return new StreamingServerCall<>(req, methodDescriptor, simpleMethodName, compressorRegistry,
                                         decompressorRegistry, (HttpResponseWriter) res,
                                         maxRequestMessageLength, maxResponseMessageLength, ctx,
                                         serializationFormat, jsonMarshaller, unsafeWrapRequestBuffers,
                                         defaultHeader, exceptionHandler, blockingExecutor,
                                         autoCompression, useMethodMarshaller);
    }

    /**
     * Finishes configuration once the service is bound to a server. This method:
     * <ul>
     *   <li>resolves an unset ({@code -1}) maximum request message length from the service config;</li>
     *   <li>exposes every {@link FramedGrpcService}'s service definitions on the server to the
     *       reflection interceptor, keeping the first definition for each service name;</li>
     *   <li>forwards the event to the health check service, if any.</li>
     * </ul>
     */
    @Override
    public void serviceAdded(ServiceConfig cfg) {
        if (maxRequestMessageLength == -1) {
            maxRequestMessageLength = Ints.saturatedCast(cfg.maxRequestLength());
        }
        if (protoReflectionServiceInterceptor != null) {
            final Map<String, ServerServiceDefinition> grpcServices =
                    cfg.server().config().virtualHosts().stream()
                       .flatMap(host -> host.serviceConfigs().stream())
                       .map(serviceConfig -> serviceConfig.service().as(FramedGrpcService.class))
                       .filter(Objects::nonNull)
                       .flatMap(service -> service.services().stream())
                       .collect(ImmutableMap.toImmutableMap(
                               def -> def.getServiceDescriptor().getName(),
                               Function.identity(),
                               (a, b) -> a));
            protoReflectionServiceInterceptor.setServer(newDummyServer(grpcServices));
        }
        if (grpcHealthCheckService != null) {
            grpcHealthCheckService.serviceAdded(cfg);
        }
    }

    /**
     * Returns the method stored in {@link #RESOLVED_GRPC_METHOD} when attribute lookup is enabled
     * and the attribute is set. Otherwise this delegates to the default resolution.
     */
    @Override
    public ServerMethodDefinition<?, ?> methodDefinition(ServiceRequestContext ctx) {
        if (lookupMethodFromAttribute) {
            final ServerMethodDefinition<?, ?> method = ctx.attr(RESOLVED_GRPC_METHOD);
            if (method != null) {
                return method;
            }
        }
        return GrpcService.super.methodDefinition(ctx);
    }

    /**
     * Creates a {@link Server} view that only reports {@code grpcServices}. Lifecycle methods throw
     * {@link UnsupportedOperationException}.
     */
    private static Server newDummyServer(Map<String, ServerServiceDefinition> grpcServices) {
        return new Server() {
            @Override
            public Server start() {
                throw new UnsupportedOperationException();
            }

            @Override
            public List<ServerServiceDefinition> getServices() {
                return ImmutableList.copyOf(grpcServices.values());
            }

            @Override
            public List<ServerServiceDefinition> getImmutableServices() {
                return getServices();
            }

            @Override
            public List<ServerServiceDefinition> getMutableServices() {
                return ImmutableList.of();
            }

            @Override
            public Server shutdown() {
                throw new UnsupportedOperationException();
            }

            @Override
            public Server shutdownNow() {
                throw new UnsupportedOperationException();
            }

            @Override
            public boolean isShutdown() {
                throw new UnsupportedOperationException();
            }

            @Override
            public boolean isTerminated() {
                throw new UnsupportedOperationException();
            }

            @Override
            public boolean awaitTermination(long timeout, TimeUnit unit) {
                throw new UnsupportedOperationException();
            }

            @Override
            public void awaitTermination() {
                throw new UnsupportedOperationException();
            }
        };
    }

    @Override
    public boolean isFramed() {
        return true;
    }

    @Override
    public List<ServerServiceDefinition> services() {
        final List<ServerServiceDefinition> services = registry.services();
        assert services instanceof ImmutableList;
        return services;
    }

    @Override
    public Map<String, ServerMethodDefinition<?, ?>> methods() {
        return registry.methods();
    }

    @Override
    public Map<Route, ServerMethodDefinition<?, ?>> methodsByRoute() {
        return registry.methodsByRoute();
    }

    @Override
    public Set<SerializationFormat> supportedSerializationFormats() {
        return supportedSerializationFormats;
    }

    /**
     * Returns the first supported format that accepts {@code contentType}. Returns {@code null}
     * when {@code contentType} is {@code null} or no supported format accepts it.
     */
    @Nullable
    private SerializationFormat findSerializationFormat(@Nullable MediaType contentType) {
        if (contentType == null) {
            return null;
        }
        for (SerializationFormat format : supportedSerializationFormats) {
            if (format.isAccepted(contentType)) {
                return format;
            }
        }
        return null;
    }

    @Override
    public Set<Route> routes() {
        return routes;
    }

    /** A {@link ServerCall.Listener} that ignores every callback. */
    private static final class EmptyListener<T> extends ServerCall.Listener<T> {
        private EmptyListener() {
        }
    }
}

