package org.example.common.ai.impl;

import com.google.common.collect.Maps;
import com.volcengine.ark.runtime.model.bot.completion.chat.BotChatCompletionRequest;
import com.volcengine.ark.runtime.model.bot.completion.chat.BotChatCompletionResult;
import com.volcengine.ark.runtime.model.completion.chat.*;
import com.volcengine.ark.runtime.model.embeddings.EmbeddingRequest;
import com.volcengine.ark.runtime.model.embeddings.EmbeddingResult;
import com.volcengine.ark.runtime.service.ArkService;
import io.reactivex.Flowable;
import org.apache.commons.lang3.StringUtils;
import org.example.common.ai.ChatService;
import org.example.common.ai.ToolCallBackProvider;
import org.example.common.ai.entity.ChatRequest;
import org.springframework.beans.factory.annotation.Value;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * {@link ChatService} implementation backed by the Volcengine Ark runtime ({@link ArkService}):
 * streaming chat with tool calling, bot chat, and text embeddings.
 *
 * @author liutao
 * @since 2025/4/23
 */
public class ChatServiceImpl implements ChatService {

    /** Finish reason reported by the model when it stops to request tool invocations. */
    private static final String FINISH_REASON_TOOL_CALLS = "tool_calls";

    /**
     * Upper bound on follow-up requests carrying tool results within one {@link #chat} call;
     * prevents a model that keeps requesting tools from looping indefinitely.
     */
    private static final int MAX_TOOL_ROUNDS = 5;

    private final ArkService arkService;
    /** Tools offered to the model; matched to tool calls by {@code getToolName()}. */
    private final List<ToolCallBackProvider> toolCallBackProviders;

    /** Chat model/endpoint id used when the request does not name one. */
    @Value("${ark.modelId:ep-m-20250414114216-wq492}")
    private String defaultModelId;
    /** Bot id used by {@link #botChat} when the request does not name one. */
    @Value("${ark.botId:bot-20250414113448-lf27j}")
    private String defaultBotId;
    /** Model id used by {@link #getEmbedding}. */
    @Value("${ark.embedding:doubao-embedding-text-240715}")
    private String embeddingModelId;

    /**
     * @param arkService            client used for every Ark call
     * @param toolCallBackProviders tools exposed to the chat model; {@code null} means no tools
     */
    public ChatServiceImpl(ArkService arkService, List<ToolCallBackProvider> toolCallBackProviders) {
        this.arkService = arkService;
        this.toolCallBackProviders = toolCallBackProviders == null ? Collections.emptyList() : toolCallBackProviders;
    }

    /**
     * Streams a chat completion for the request's user message.
     *
     * <p>Chunks carrying reasoning content are dropped. When the model finishes a round with
     * {@code tool_calls}, the requested tools are run through their providers and a follow-up
     * completion (with the assistant tool-call message and the tool results appended) is streamed
     * in place of the finishing chunk, up to {@link #MAX_TOOL_ROUNDS} times.
     *
     * @param chatRequest carries the user message and an optional model id
     * @return a cold stream; each subscription runs its own conversation
     */
    @Override
    public Flowable<ChatCompletionChunk> chat(ChatRequest chatRequest) {
        ChatMessage userMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(chatRequest.getMessage()).build();
        String modelId = StringUtils.isNotBlank(chatRequest.getModelId()) ? chatRequest.getModelId() : defaultModelId;
        return streamRound(modelId, Collections.singletonList(userMessage), MAX_TOOL_ROUNDS);
    }

    /**
     * Streams one completion round for {@code history} and chains a follow-up round when the model
     * asks for tools.
     *
     * @param modelId       model/endpoint to call
     * @param history       conversation so far; never mutated here
     * @param followUpsLeft how many more tool follow-up rounds may still be issued
     */
    private Flowable<ChatCompletionChunk> streamRound(String modelId, List<ChatMessage> history, int followUpsLeft) {
        // defer: tool-call buffers must be fresh for every subscription, not shared across subscribers
        return Flowable.defer(() -> {
            // keyed by tool-call index; TreeMap keeps calls in the order the model issued them
            Map<Integer, ChatToolCall> pendingToolCalls = Maps.newTreeMap();
            StringBuilder assistantText = new StringBuilder();
            return arkService.streamChatCompletion(buildStreamRequest(modelId, history))
                    .filter(chunk -> !isReasoningChunk(chunk))
                    .concatMap(chunk -> {
                        ChatCompletionChoice choice = firstChoice(chunk);
                        if (choice == null || choice.getMessage() == null) {
                            return Flowable.just(chunk);
                        }
                        ChatMessage delta = choice.getMessage();
                        Object content = delta.getContent();
                        boolean hasText = content != null && !content.toString().isEmpty();
                        if (hasText) {
                            assistantText.append(content);
                        }
                        mergeToolCallDeltas(pendingToolCalls, delta.getToolCalls());
                        if (followUpsLeft <= 0 || !FINISH_REASON_TOOL_CALLS.equals(choice.getFinishReason())) {
                            return Flowable.just(chunk);
                        }
                        List<ChatMessage> followUp = runToolsAndExtend(history, assistantText.toString(), pendingToolCalls);
                        if (followUp.isEmpty()) {
                            // no tool produced a result: re-sending the same conversation would be pointless
                            return Flowable.just(chunk);
                        }
                        Flowable<ChatCompletionChunk> next = streamRound(modelId, followUp, followUpsLeft - 1);
                        // keep any text the finishing chunk carried before continuing with the follow-up
                        return hasText ? Flowable.just(chunk).concatWith(next) : next;
                    });
        });
    }

    /** Builds a streaming request for {@code messages} offering all registered tools. */
    private ChatCompletionRequest buildStreamRequest(String modelId, List<ChatMessage> messages) {
        List<ChatTool> tools = getTools();
        return ChatCompletionRequest.builder()
                .model(modelId)
                .messages(messages)
                .stream(true)
                // NOTE(review): null (rather than an empty list) is sent when no tools are registered;
                // assumes the SDK omits null fields — confirm against the Ark serializer settings
                .tools(tools.isEmpty() ? null : tools)
                .build();
    }

    /** Returns the first choice of {@code chunk}, or {@code null} when it carries none. */
    private static ChatCompletionChoice firstChoice(ChatCompletionChunk chunk) {
        if (chunk.getChoices() == null || chunk.getChoices().isEmpty()) {
            return null;
        }
        return chunk.getChoices().get(0);
    }

    /** Whether {@code chunk}'s first choice carries non-blank reasoning content. */
    private static boolean isReasoningChunk(ChatCompletionChunk chunk) {
        ChatCompletionChoice choice = firstChoice(chunk);
        return choice != null
                && choice.getMessage() != null
                && StringUtils.isNotBlank(choice.getMessage().getReasoningContent());
    }

    /**
     * Folds streamed tool-call fragments into {@code pending}: the first fragment for an index is
     * kept as-is, later fragments append their argument text to it (nulls treated as empty).
     */
    private static void mergeToolCallDeltas(Map<Integer, ChatToolCall> pending, List<ChatToolCall> deltas) {
        if (deltas == null) {
            return;
        }
        for (ChatToolCall delta : deltas) {
            int index = delta.getIndex();
            ChatToolCall existing = pending.get(index);
            if (existing == null) {
                pending.put(index, delta);
            } else if (existing.getFunction() != null && delta.getFunction() != null) {
                ChatFunctionCall target = existing.getFunction();
                target.setArguments(StringUtils.defaultString(target.getArguments())
                        + StringUtils.defaultString(delta.getFunction().getArguments()));
            }
        }
    }

    /**
     * Runs every pending tool call through its provider and returns {@code history} extended with
     * the assistant tool-call message followed by the tool results, or an empty list when no tool
     * produced a result.
     */
    private List<ChatMessage> runToolsAndExtend(List<ChatMessage> history, String assistantText,
                                                Map<Integer, ChatToolCall> pendingToolCalls) {
        List<ChatToolCall> answeredCalls = new ArrayList<>();
        List<ChatMessage> toolMessages = new ArrayList<>();
        for (ChatToolCall call : pendingToolCalls.values()) {
            ChatFunctionCall function = call.getFunction();
            ToolCallBackProvider provider = function == null ? null : findProvider(function.getName());
            if (provider == null) {
                // unknown tool: skipped, and left out of the assistant message below so the
                // conversation never contains a tool call without a matching tool result
                continue;
            }
            ChatMessage toolMessage = provider.getToolMessage(function.getArguments(), call.getId());
            if (toolMessage != null) {
                answeredCalls.add(call);
                toolMessages.add(toolMessage);
            }
        }
        if (toolMessages.isEmpty()) {
            return Collections.emptyList();
        }
        List<ChatMessage> extended = new ArrayList<>(history);
        // tool results must be preceded by the assistant message that requested them
        extended.add(ChatMessage.builder()
                .role(ChatMessageRole.ASSISTANT)
                .content(assistantText)
                .toolCalls(answeredCalls)
                .build());
        extended.addAll(toolMessages);
        return extended;
    }

    /** Returns the first provider whose tool name equals {@code toolName}, or {@code null}. */
    private ToolCallBackProvider findProvider(String toolName) {
        for (ToolCallBackProvider provider : toolCallBackProviders) {
            if (StringUtils.equals(provider.getToolName(), toolName)) {
                return provider;
            }
        }
        return null;
    }

    /** Collects the tool definitions of all registered providers, in registration order. */
    private List<ChatTool> getTools() {
        return toolCallBackProviders.stream().map(ToolCallBackProvider::getChatTool).collect(Collectors.toList());
    }

    /**
     * Sends the request's message to a bot and returns its first reply.
     *
     * @param chatRequest carries the user message and an optional bot id
     * @return the first choice's content as text, or {@code ""} when the reply has no content
     */
    @Override
    public String botChat(ChatRequest chatRequest) {
        ChatMessage userMessage = ChatMessage.builder().role(ChatMessageRole.USER).content(chatRequest.getMessage()).build();
        BotChatCompletionResult botChatCompletion = arkService.createBotChatCompletion(BotChatCompletionRequest.builder()
                .botId(StringUtils.isNotBlank(chatRequest.getBotId()) ? chatRequest.getBotId() : defaultBotId)
                .messages(Collections.singletonList(userMessage))
                .build());
        Object content = botChatCompletion.getChoices().get(0).getMessage().getContent();
        return content == null ? "" : content.toString();
    }

    /**
     * Embeds {@code text} with the configured embedding model.
     *
     * @param text text to embed
     * @return the first embedding vector, each component narrowed to {@code float} and packed as
     *         consecutive little-endian 4-byte values
     */
    @Override
    public byte[] getEmbedding(String text) {
        EmbeddingResult embeddingResult = arkService.createEmbeddings(
                EmbeddingRequest.builder().input(Collections.singletonList(text)).model(embeddingModelId).build());
        List<Double> embedding = embeddingResult.getData().get(0).getEmbedding();

        ByteBuffer buffer = ByteBuffer.allocate(embedding.size() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
        for (Double value : embedding) {
            buffer.putFloat(value.floatValue());
        }
        return buffer.array();
    }
}
