/*
 * Decompiled with CFR 0.152.
 */
package org.example.common.ai.impl;

import com.google.common.collect.Maps;
import com.volcengine.ark.runtime.model.bot.completion.chat.BotChatCompletionRequest;
import com.volcengine.ark.runtime.model.bot.completion.chat.BotChatCompletionResult;
import com.volcengine.ark.runtime.model.completion.chat.ChatCompletionChoice;
import com.volcengine.ark.runtime.model.completion.chat.ChatCompletionChunk;
import com.volcengine.ark.runtime.model.completion.chat.ChatCompletionRequest;
import com.volcengine.ark.runtime.model.completion.chat.ChatFunctionCall;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessage;
import com.volcengine.ark.runtime.model.completion.chat.ChatMessageRole;
import com.volcengine.ark.runtime.model.completion.chat.ChatTool;
import com.volcengine.ark.runtime.model.completion.chat.ChatToolCall;
import com.volcengine.ark.runtime.model.embeddings.Embedding;
import com.volcengine.ark.runtime.model.embeddings.EmbeddingRequest;
import com.volcengine.ark.runtime.model.embeddings.EmbeddingResult;
import com.volcengine.ark.runtime.service.ArkService;
import io.reactivex.Flowable;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.example.common.ai.ChatService;
import org.example.common.ai.ToolCallBackProvider;
import org.example.common.ai.entity.ChatRequest;
import org.springframework.beans.factory.annotation.Value;

/**
 * {@link ChatService} backed by the Volcengine Ark runtime: streaming chat with tool calling,
 * bot chat, and text embeddings.
 *
 * <p>Tools come from the injected {@link ToolCallBackProvider}s. Each provider contributes one
 * {@link ChatTool} definition and builds the tool-result message for calls to its tool.
 */
public class ChatServiceImpl
implements ChatService {
    /** Finish reason reported when the model stops to request tool invocations. */
    private static final String FINISH_REASON_TOOL_CALLS = "tool_calls";

    /**
     * Upper bound on tool-call round trips per {@link #chat} call, so a model that keeps asking for
     * tools cannot loop forever. Once reached, the tool-call finish chunk is passed through as-is.
     */
    private static final int MAX_TOOL_ROUNDS = 5;

    private final ArkService arkService;
    private final List<ToolCallBackProvider> toolCallBackProviders;
    // NOTE(review): the defaults below look like environment-specific Ark ids; confirm they are
    // meant to be baked in rather than required configuration.
    @Value(value="${ark.modelId:ep-m-20250414114216-wq492}")
    private String defaultModelId;
    @Value(value="${ark.botId:bot-20250414113448-lf27j}")
    private String defaultBotId;
    @Value(value="${ark.embedding:doubao-embedding-text-240715}")
    private String embeddingModelId;

    /**
     * @param arkService Ark client used for all remote calls
     * @param toolCallBackProviders tool providers; {@code null} is treated as "no tools"
     */
    public ChatServiceImpl(ArkService arkService, List<ToolCallBackProvider> toolCallBackProviders) {
        this.arkService = arkService;
        this.toolCallBackProviders = toolCallBackProviders == null ? Collections.<ToolCallBackProvider>emptyList() : toolCallBackProviders;
    }

    /**
     * Streams a chat completion for {@code chatRequest.getMessage()}.
     *
     * <p>Chunks whose first choice carries non-blank reasoning content are dropped. When the model
     * finishes with {@code tool_calls}, the streamed tool-call fragments are merged. The assistant
     * tool-call message and one tool-result message per call are appended to the conversation. The
     * model's follow-up stream is then emitted in place of the tool-call finish chunk. The follow-up
     * is processed the same way, for up to {@link #MAX_TOOL_ROUNDS} rounds.
     *
     * @param chatRequest user message and optional model id; a blank model id falls back to the
     *     {@code ark.modelId} property
     * @return the filtered completion stream
     */
    @Override
    public Flowable<ChatCompletionChunk> chat(ChatRequest chatRequest) {
        String modelId = StringUtils.isNotBlank(chatRequest.getModelId()) ? chatRequest.getModelId() : this.defaultModelId;
        ChatMessage userMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(chatRequest.getMessage()).build();
        // defer: each subscriber gets its own conversation history and tool-call buffers instead
        // of sharing mutable state created at assembly time.
        return Flowable.defer(() -> {
            List<ChatMessage> messages = new ArrayList<>();
            messages.add(userMessage);
            return this.streamRound(modelId, messages, 0);
        });
    }

    /**
     * Sends one streaming request built from a snapshot of {@code messages} and post-processes
     * its chunks (reasoning filter + tool-call handling).
     *
     * @param round number of tool-call round trips already performed
     */
    private Flowable<ChatCompletionChunk> streamRound(String modelId, List<ChatMessage> messages, int round) {
        ChatCompletionRequest request = ChatCompletionRequest.builder()
                .model(modelId)
                // Snapshot: later appends to the history must not alter an already-sent request.
                .messages(new ArrayList<>(messages))
                .stream(Boolean.TRUE)
                .tools(this.getTools())
                .build();
        // Keyed by tool-call index; TreeMap keeps calls in index order.
        Map<Integer, ChatToolCall> pendingToolCalls = new TreeMap<>();
        return this.arkService.streamChatCompletion(request)
                .filter(chunk -> !isReasoningChunk(chunk))
                .concatMap(chunk -> this.handleChunk(chunk, modelId, messages, pendingToolCalls, round));
    }

    /**
     * Buffers tool-call fragments from {@code chunk}. If the chunk finishes the turn with
     * {@code tool_calls}, runs the tools and continues with a follow-up request. Otherwise it
     * re-emits the chunk unchanged.
     */
    private Flowable<ChatCompletionChunk> handleChunk(ChatCompletionChunk chunk, String modelId,
            List<ChatMessage> messages, Map<Integer, ChatToolCall> pendingToolCalls, int round) {
        ChatCompletionChoice choice = first(chunk.getChoices());
        if (choice == null || choice.getMessage() == null) {
            return Flowable.just(chunk);
        }
        ChatMessage delta = choice.getMessage();
        accumulateToolCalls(delta.getToolCalls(), pendingToolCalls);
        if (!FINISH_REASON_TOOL_CALLS.equals(choice.getFinishReason())
                || pendingToolCalls.isEmpty()
                || round >= MAX_TOOL_ROUNDS) {
            return Flowable.just(chunk);
        }
        // The assistant turn that requested the tools must precede the tool results; the API
        // pairs each tool message with a tool_call_id from this message.
        List<ChatToolCall> calls = new ArrayList<>(pendingToolCalls.values());
        pendingToolCalls.clear();
        messages.add(ChatMessage.builder().role(ChatMessageRole.ASSISTANT).content("").toolCalls(calls).build());
        for (ChatToolCall call : calls) {
            messages.add(this.invokeTool(call));
        }
        Flowable<ChatCompletionChunk> followUp = this.streamRound(modelId, messages, round + 1);
        return hasContent(delta) ? Flowable.just(chunk).concatWith(followUp) : followUp;
    }

    /**
     * Merges streamed tool-call fragments into {@code pending}.
     *
     * <p>The first fragment seen for an index is stored as-is. Later fragments for the same index
     * have their argument text appended. Null argument text is treated as empty.
     */
    private static void accumulateToolCalls(List<ChatToolCall> fragments, Map<Integer, ChatToolCall> pending) {
        if (fragments == null) {
            return;
        }
        for (ChatToolCall fragment : fragments) {
            Integer rawIndex = fragment.getIndex();
            int index = rawIndex == null ? 0 : rawIndex;
            ChatToolCall existing = pending.get(index);
            if (existing == null) {
                pending.put(index, fragment);
                continue;
            }
            ChatFunctionCall target = existing.getFunction();
            ChatFunctionCall piece = fragment.getFunction();
            if (target == null || piece == null || piece.getArguments() == null) {
                continue;
            }
            target.setArguments(StringUtils.defaultString(target.getArguments()) + piece.getArguments());
        }
    }

    /**
     * Builds the tool-result message for one call using the provider registered for its function
     * name. When no provider matches or the provider returns {@code null}, a placeholder tool
     * message is returned instead, so the call still gets an answer.
     */
    private ChatMessage invokeTool(ChatToolCall toolCall) {
        ChatFunctionCall function = toolCall.getFunction();
        String name = function == null ? null : function.getName();
        ToolCallBackProvider provider = this.findProvider(name);
        ChatMessage toolMessage = provider == null ? null : provider.getToolMessage(function.getArguments(), toolCall.getId());
        if (toolMessage != null) {
            return toolMessage;
        }
        return ChatMessage.builder()
                .role(ChatMessageRole.TOOL)
                .toolCallId(toolCall.getId())
                .content("Tool '" + name + "' is not available")
                .build();
    }

    /** Returns the first provider whose tool name equals {@code toolName}, or {@code null}. */
    private ToolCallBackProvider findProvider(String toolName) {
        if (toolName == null) {
            return null;
        }
        return this.toolCallBackProviders.stream()
                .filter(provider -> toolName.equals(provider.getToolName()))
                .findFirst()
                .orElse(null);
    }

    /** True when the chunk's first choice carries non-blank reasoning content. */
    private static boolean isReasoningChunk(ChatCompletionChunk chunk) {
        ChatCompletionChoice choice = first(chunk.getChoices());
        return choice != null
                && choice.getMessage() != null
                && StringUtils.isNotBlank(choice.getMessage().getReasoningContent());
    }

    /** True when the message's content renders to a non-empty string. */
    private static boolean hasContent(ChatMessage message) {
        Object content = message.getContent();
        return content != null && !content.toString().isEmpty();
    }

    /** First element of {@code items}, or {@code null} when the list is null or empty. */
    private static <T> T first(List<T> items) {
        return items == null || items.isEmpty() ? null : items.get(0);
    }

    /** Collects one {@link ChatTool} definition from every registered provider. */
    private List<ChatTool> getTools() {
        return this.toolCallBackProviders.stream().map(ToolCallBackProvider::getChatTool).collect(Collectors.toList());
    }

    /**
     * Sends {@code chatRequest.getMessage()} to an Ark bot and returns the reply text.
     *
     * @param chatRequest user message and optional bot id; a blank bot id falls back to the
     *     {@code ark.botId} property
     * @return the first choice's content as a string, or {@code ""} when it has no content
     * @throws IllegalStateException if the response contains no choice/message
     */
    @Override
    public String botChat(ChatRequest chatRequest) {
        String botId = StringUtils.isNotBlank(chatRequest.getBotId()) ? chatRequest.getBotId() : this.defaultBotId;
        ChatMessage userMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(chatRequest.getMessage()).build();
        BotChatCompletionRequest request = BotChatCompletionRequest.builder()
                .botId(botId)
                .messages(Collections.singletonList(userMessage))
                .build();
        BotChatCompletionResult result = this.arkService.createBotChatCompletion(request);
        ChatCompletionChoice choice = first(result.getChoices());
        if (choice == null || choice.getMessage() == null) {
            throw new IllegalStateException("Bot " + botId + " returned no choices");
        }
        Object content = choice.getMessage().getContent();
        return content == null ? "" : content.toString();
    }

    /**
     * Embeds {@code text} with the {@code ark.embedding} model.
     *
     * @param text input text
     * @return the first embedding vector as little-endian IEEE-754 floats (4 bytes per dimension)
     * @throws IllegalStateException if the response contains no embedding
     */
    @Override
    public byte[] getEmbedding(String text) {
        List<String> input = new ArrayList<>();
        input.add(text);
        EmbeddingResult result = this.arkService.createEmbeddings(EmbeddingRequest.builder().input(input).model(this.embeddingModelId).build());
        Embedding embedding = first(result.getData());
        if (embedding == null || embedding.getEmbedding() == null) {
            throw new IllegalStateException("Embedding model " + this.embeddingModelId + " returned no data");
        }
        List<Double> vector = embedding.getEmbedding();
        // The model returns doubles; they are narrowed to float to halve the stored size.
        ByteBuffer buffer = ByteBuffer.allocate(vector.size() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
        for (Double value : vector) {
            buffer.putFloat(value.floatValue());
        }
        return buffer.array();
    }
}

