package org.example.redisearch.server.service.impl;

import com.alibaba.fastjson2.JSON;
import com.volcengine.ark.runtime.model.embeddings.Embedding;
import com.volcengine.ark.runtime.model.embeddings.EmbeddingRequest;
import com.volcengine.ark.runtime.service.ArkService;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.annotation.Resource;
import javax.validation.constraints.NotNull;
import lombok.Generated;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.example.redisearch.server.constant.VectorConstant;
import org.example.redisearch.server.entity.AudienceBaseInfo;
import org.example.redisearch.server.entity.BusinessBaseInfo;
import org.example.redisearch.server.service.VectorService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.jdbc.core.simple.JdbcClient;
import org.springframework.stereotype.Service;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.search.Document;
import redis.clients.jedis.search.IndexDefinition;
import redis.clients.jedis.search.IndexOptions;
import redis.clients.jedis.search.Query;
import redis.clients.jedis.search.Schema;

/**
 * {@link VectorService} implementation that stores text embeddings in Redis hashes and queries
 * them through RediSearch KNN vector search.
 *
 * <p>Embeddings are requested from the Ark embedding service, truncated to
 * {@link #EMBEDDING_DIM} components, L2-normalised and serialised as little-endian float32.
 */
@Service
public class VectorServiceImpl implements VectorService {

    @Generated
    private static final Logger log = LoggerFactory.getLogger(VectorServiceImpl.class);

    /**
     * Number of leading embedding components that are kept. The same value is used as the
     * {@code DIM} attribute of the vector fields, so stored blobs and the index schema agree.
     */
    private static final int EMBEDDING_DIM = 512;

    /** KNN query template shared by all vector searches; {@code $K} and {@code $BLOB} are query params. */
    private static final String KNN_QUERY = "*=>[KNN $K @embedding $BLOB AS score]";

    /** Name of the hash field that holds the serialised vector. */
    private static final String EMBEDDING_FIELD = "embedding";

    /** Alias under which the KNN query exposes the vector distance. */
    private static final String SCORE_FIELD = "score";

    // NOTE(review): presumably these match VectorConstant.AUDIENCE_INDEX_PREFIX / BUSINESS_INDEX_PREFIX,
    // otherwise the written hashes would not be picked up by the indexes — confirm.
    private static final String AUDIENCE_KEY_PREFIX = "idx:audience:";
    private static final String BUSINESS_KEY_PREFIX = "idx:business:";

    @Resource
    private JedisPooled jedisPooled;

    @Resource
    private ArkService arkService;

    @Value("${ark.embedding:doubao-embedding-text-240715}")
    private String embeddingModelId;

    @Resource
    private JdbcClient jdbcClient;

    /**
     * Drops the RediSearch index with the given name and logs the server reply.
     *
     * @param str name of the index to drop
     */
    @Override // org.example.redisearch.server.service.VectorService
    public void dropIndex(String str) {
        log.info("删除索引结果：{}", this.jedisPooled.ftDropIndex(str));
    }

    /**
     * Creates the audience index over hashes: weighted text fields {@code businessName} and
     * {@code name}, numeric {@code id} and a FLAT vector field {@code embedding}.
     */
    @Override // org.example.redisearch.server.service.VectorService
    public void createAudienceVectorIndex() {
        IndexDefinition definition = new IndexDefinition(IndexDefinition.Type.HASH)
                .setPrefixes(VectorConstant.AUDIENCE_INDEX_PREFIX)
                .setLanguage("chinese");
        Schema schema = new Schema()
                .addTextField("businessName", 2.0d)
                .addNumericField("id")
                .addTextField("name", 1.0d)
                .addFlatVectorField(EMBEDDING_FIELD, getVectorAttribute());
        log.info("创建索引结果：{}", this.jedisPooled.ftCreate(
                VectorConstant.AUDIENCE_INDEX_NAME,
                IndexOptions.defaultOptions().setDefinition(definition),
                schema));
    }

    /**
     * Runs a 3-nearest-neighbour search on the audience index for the given text.
     *
     * @param str text to embed and search for
     * @return distinct {@code businessName} values of the hits, in the order Redis returned them;
     *         hits without a {@code businessName} field are ignored
     */
    @Override // org.example.redisearch.server.service.VectorService
    public List<String> searchAudienceVector(String str) {
        List<Document> documents =
                knnSearch(VectorConstant.AUDIENCE_INDEX_NAME, str, 3, "businessName", "id", SCORE_FIELD);
        log.info("查询结果：{}", JSON.toJSONString(documents));
        return documents.stream()
                .map(document -> document.get("businessName"))
                // A hit may lack the field (e.g. a hash written without it); skip instead of NPE.
                .filter(value -> value != null)
                .map(Object::toString)
                .distinct()
                .toList();
    }

    /**
     * Loads one audience row by id and indexes it. Does nothing when the query returns no row.
     *
     * @param num audience id bound to the query parameter
     */
    @Override // org.example.redisearch.server.service.VectorService
    public void addAudienceDataToRedis(Integer num) {
        // NOTE(review): the column is spelled "statuts" in this query; presumably that matches the
        // table schema — confirm before "fixing" the spelling.
        Optional<AudienceBaseInfo> row = this.jdbcClient
                .sql("select id,business_name as businessName from smdm_audience_baseinfo\n"
                        + "where statuts = 1 and business_name != '' and business_name is not null and id = ?\n")
                .param(num)
                .query(AudienceBaseInfo.class)
                .optional();
        row.ifPresent(info -> batchAddAudienceDataToRedis(List.of(info)));
    }

    /**
     * Writes each audience record to Redis as a hash with {@code id}, {@code businessName} and
     * {@code embedding} fields. Records that are null or have a null/blank business name are
     * skipped with a warning because there is nothing to embed.
     *
     * @param list records to index; {@code null} is treated as empty
     */
    @Override // org.example.redisearch.server.service.VectorService
    public void batchAddAudienceDataToRedis(List<AudienceBaseInfo> list) {
        if (list == null) {
            return;
        }
        for (AudienceBaseInfo audience : list) {
            if (audience == null || isBlank(audience.getBusinessName())) {
                log.warn("Skipping audience record without a business name: {}", audience);
                continue;
            }
            storeDocument(AUDIENCE_KEY_PREFIX + audience.getId(), audience.getId(),
                    "businessName", audience.getBusinessName());
        }
    }

    /**
     * Creates the business index over hashes: weighted text field {@code name}, numeric
     * {@code id} and a FLAT vector field {@code embedding}.
     */
    @Override // org.example.redisearch.server.service.VectorService
    public void createBusinessVectorIndex() {
        IndexDefinition definition = new IndexDefinition(IndexDefinition.Type.HASH)
                .setPrefixes(VectorConstant.BUSINESS_INDEX_PREFIX)
                .setLanguage("chinese");
        Schema schema = new Schema()
                .addTextField("name", 2.0d)
                .addNumericField("id")
                .addFlatVectorField(EMBEDDING_FIELD, getVectorAttribute());
        log.info("创建核心企业索引结果：{}", this.jedisPooled.ftCreate(
                VectorConstant.BUSINESS_INDEX_NAME,
                IndexOptions.defaultOptions().setDefinition(definition),
                schema));
    }

    /**
     * Runs a 1-nearest-neighbour search on the business index for the given text.
     *
     * @param str text to embed and search for
     * @return distinct ids of the hits; hits without an {@code id} field are ignored
     * @throws NumberFormatException if a returned {@code id} is not an integer
     */
    @Override // org.example.redisearch.server.service.VectorService
    public List<Integer> searchBusinessVector(String str) {
        List<Document> documents =
                knnSearch(VectorConstant.BUSINESS_INDEX_NAME, str, 1, "name", "id", SCORE_FIELD);
        log.info("查询核心企业结果：{}", JSON.toJSONString(documents));
        return documents.stream()
                .map(document -> document.get("id"))
                .filter(value -> value != null)
                .map(value -> Integer.valueOf(value.toString()))
                .distinct()
                .toList();
    }

    /**
     * Writes each business record to Redis as a hash with {@code id}, {@code name} and
     * {@code embedding} fields. Records that are null or have a null/blank name are skipped
     * with a warning because there is nothing to embed.
     *
     * @param list records to index; {@code null} is treated as empty
     */
    @Override // org.example.redisearch.server.service.VectorService
    public void batchAddBusinessDataToRedis(List<BusinessBaseInfo> list) {
        if (list == null) {
            return;
        }
        for (BusinessBaseInfo business : list) {
            if (business == null || isBlank(business.getName())) {
                log.warn("Skipping business record without a name: {}", business);
                continue;
            }
            storeDocument(BUSINESS_KEY_PREFIX + business.getId(), business.getId(),
                    "name", business.getName());
        }
    }

    /** Embeds {@code text} and runs the shared KNN query against {@code indexName}. */
    private List<Document> knnSearch(String indexName, String text, int k, String... returnFields) {
        // NOTE(review): setSortBy is called with `false`, as in the original code. If that flag is
        // "ascending", hits are ordered by descending distance (farthest first) — confirm intent.
        Query query = new Query(KNN_QUERY)
                .addParam("K", k)
                .addParam("BLOB", getEmbedding(text))
                .returnFields(returnFields)
                .setSortBy(SCORE_FIELD, false)
                .dialect(2);
        return this.jedisPooled.ftSearch(indexName, query).getDocuments();
    }

    /**
     * Stores one document hash with a single HSET. The embedding is computed before anything is
     * written, so a failing embedding call cannot leave a hash that has text fields but no vector.
     */
    private void storeDocument(String key, Object id, String nameField, String name) {
        byte[] embedding = getEmbedding(name);
        Map<byte[], byte[]> fields = new HashMap<>();
        fields.put(utf8("id"), utf8(String.valueOf(id)));
        fields.put(utf8(nameField), utf8(name));
        fields.put(utf8(EMBEDDING_FIELD), embedding);
        this.jedisPooled.hset(utf8(key), fields);
    }

    /** Encodes a string as UTF-8 bytes for the binary Redis commands. */
    private static byte[] utf8(String value) {
        return value.getBytes(StandardCharsets.UTF_8);
    }

    /** Returns whether {@code value} is null or contains only whitespace. */
    private static boolean isBlank(String value) {
        return value == null || value.isBlank();
    }

    /**
     * Requests an embedding for {@code str} and serialises its first {@link #EMBEDDING_DIM}
     * components, L2-normalised, as little-endian float32 bytes.
     *
     * @throws IllegalStateException if the embedding service returns no embedding
     */
    private byte[] getEmbedding(String str) {
        List<String> inputs = new ArrayList<>();
        inputs.add(str);
        EmbeddingRequest request = EmbeddingRequest.builder()
                .input(inputs)
                .model(this.embeddingModelId)
                .build();
        List<Embedding> data = this.arkService.createEmbeddings(request).getData();
        if (data == null || data.isEmpty()) {
            throw new IllegalStateException(
                    "Embedding service returned no data for model " + this.embeddingModelId);
        }
        List<Double> normalized = slicedNormL2(data.get(0).getEmbedding(), EMBEDDING_DIM);
        if (normalized.size() != EMBEDDING_DIM) {
            // The vector fields are declared with DIM = EMBEDDING_DIM; a shorter blob will not match.
            log.warn("Embedding has {} dimensions, expected {}", normalized.size(), EMBEDDING_DIM);
        }
        ByteBuffer buffer = ByteBuffer.allocate(normalized.size() * Float.BYTES)
                .order(ByteOrder.LITTLE_ENDIAN);
        for (Double component : normalized) {
            buffer.putFloat(component.floatValue());
        }
        return buffer.array();
    }

    /** Builds the attribute map (TYPE, DIM, DISTANCE_METRIC) used for the vector fields. */
    @NotNull
    private static Map<String, Object> getVectorAttribute() {
        Map<String, Object> attributes = new HashMap<>();
        attributes.put("TYPE", VectorConstant.VECTOR_ATTRIBUTE_TYPE_FLOAT32);
        attributes.put("DIM", EMBEDDING_DIM);
        attributes.put("DISTANCE_METRIC", VectorConstant.VECTOR_ATTRIBUTE_DISTANCE_METRIC);
        return attributes;
    }

    /**
     * Keeps the first {@code i} components of {@code list} (or all of them if it is shorter) and
     * scales them to unit L2 norm.
     *
     * @param list full embedding vector
     * @param i    maximum number of leading components to keep
     * @return a new list with the normalised components; if the slice has zero norm it is
     *         returned unscaled, since dividing by zero would turn every component into NaN
     */
    private static List<Double> slicedNormL2(List<Double> list, int i) {
        List<Double> head = list.subList(0, Math.min(i, list.size()));
        double norm = new ArrayRealVector(
                head.stream().mapToDouble(Double::doubleValue).toArray()).getNorm();
        if (norm == 0.0d) {
            return new ArrayList<>(head);
        }
        return head.stream()
                .map(component -> component / norm)
                .collect(Collectors.toList());
    }
}
