package ai.djl.huggingface.translator;

import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.util.Map;

/* loaded from: input_file:ai/djl/huggingface/translator/TextEmbeddingTranslator.class */
/**
 * A {@link Translator} that converts a text string into a single pooled embedding vector.
 *
 * <p>The input text is tokenized with a {@link HuggingFaceTokenizer}; the model output is reduced
 * along axis 0 using one of the pooling modes {@code mean}, {@code mean_sqrt_len}, {@code max},
 * {@code weightedmean} or {@code cls}, and is optionally L2-normalized before being returned as a
 * {@code float[]}.
 *
 * <p>Instances are created through {@link #builder(HuggingFaceTokenizer)} or {@link
 * #builder(HuggingFaceTokenizer, Map)}.
 */
public class TextEmbeddingTranslator implements Translator<String, float[]> {

    /** Axis reduced by every pooling mode; the attention mask is aligned with this axis. */
    private static final int[] AXIS = {0};

    private final HuggingFaceTokenizer tokenizer;
    private final Batchifier batchifier;
    private final boolean normalize;
    private final String pooling;
    private final boolean includeTokenTypes;

    /** Fluent builder for {@link TextEmbeddingTranslator}. */
    public static final class Builder {

        private final HuggingFaceTokenizer tokenizer;
        private Batchifier batchifier = Batchifier.STACK;
        private boolean normalize = true;
        private String pooling = "mean";
        private boolean includeTokenTypes;

        /**
         * Creates a builder bound to the given tokenizer.
         *
         * @param tokenizer the tokenizer the built translator will use
         */
        Builder(HuggingFaceTokenizer tokenizer) {
            this.tokenizer = tokenizer;
        }

        /**
         * Sets the {@link Batchifier} of the translator (defaults to {@link Batchifier#STACK}).
         *
         * @param batchifier the batchifier to use
         * @return this builder
         */
        public Builder optBatchifier(Batchifier batchifier) {
            this.batchifier = batchifier;
            return this;
        }

        /**
         * Sets whether the pooled embedding is L2-normalized (defaults to {@code true}).
         *
         * @param normalize {@code true} to normalize the embedding
         * @return this builder
         */
        public Builder optNormalize(boolean normalize) {
            this.normalize = normalize;
            return this;
        }

        /**
         * Sets the pooling mode (defaults to {@code "mean"}).
         *
         * @param poolingMode one of {@code mean}, {@code max}, {@code cls}, {@code mean_sqrt_len},
         *     {@code weightedmean}
         * @return this builder
         * @throws IllegalArgumentException if {@code poolingMode} is {@code null} or not one of
         *     the supported values
         */
        public Builder optPoolingMode(String poolingMode) {
            // Constant-first equals keeps a null argument on the IllegalArgumentException path.
            boolean supported =
                    "mean".equals(poolingMode)
                            || "max".equals(poolingMode)
                            || "cls".equals(poolingMode)
                            || "mean_sqrt_len".equals(poolingMode)
                            || "weightedmean".equals(poolingMode);
            if (!supported) {
                throw new IllegalArgumentException("Invalid pooling model, must be one of [mean, max, cls, mean_sqrt_len, weightedmean].");
            }
            this.pooling = poolingMode;
            return this;
        }

        /**
         * Sets whether token type ids are included in the model input (defaults to {@code false}).
         *
         * @param includeTokenTypes {@code true} to include token types
         * @return this builder
         */
        public Builder optIncludeTokenTypes(boolean includeTokenTypes) {
            this.includeTokenTypes = includeTokenTypes;
            return this;
        }

        /**
         * Configures the builder from an argument map.
         *
         * <p>Keys read: {@code batchifier} (default {@code "stack"}), {@code normalize} (default
         * {@code true}), {@code pooling} (default {@code "mean"}) and {@code includeTokenTypes}.
         *
         * @param arguments the model arguments
         * @throws IllegalArgumentException if the {@code pooling} value is not supported
         */
        public void configure(Map<String, ?> arguments) {
            optBatchifier(Batchifier.fromString(ArgumentsUtil.stringValue(arguments, "batchifier", "stack")));
            optNormalize(ArgumentsUtil.booleanValue(arguments, "normalize", true));
            optPoolingMode(ArgumentsUtil.stringValue(arguments, "pooling", "mean"));
            optIncludeTokenTypes(ArgumentsUtil.booleanValue(arguments, "includeTokenTypes"));
        }

        /**
         * Builds the translator from the current settings.
         *
         * @return a new {@link TextEmbeddingTranslator}
         * @throws IOException declared for API compatibility; nothing in this method throws it
         */
        public TextEmbeddingTranslator build() throws IOException {
            return new TextEmbeddingTranslator(tokenizer, batchifier, pooling, normalize, includeTokenTypes);
        }
    }

    /**
     * Creates the translator; use {@link Builder#build()} from outside the package.
     *
     * @param tokenizer the tokenizer used to encode input text
     * @param batchifier the batchifier reported by {@link #getBatchifier()}
     * @param pooling the pooling mode name
     * @param normalize whether the pooled embedding is L2-normalized
     * @param includeTokenTypes whether token type ids are part of the model input
     */
    TextEmbeddingTranslator(
            HuggingFaceTokenizer tokenizer,
            Batchifier batchifier,
            String pooling,
            boolean normalize,
            boolean includeTokenTypes) {
        this.tokenizer = tokenizer;
        this.batchifier = batchifier;
        this.pooling = pooling;
        this.normalize = normalize;
        this.includeTokenTypes = includeTokenTypes;
    }

    /** {@inheritDoc} */
    @Override
    public Batchifier getBatchifier() {
        return batchifier;
    }

    /**
     * Tokenizes the input text and converts the encoding into the model input.
     *
     * <p>The {@link Encoding} is stored in the context under the key {@code "encoding"} so that
     * {@link #processOutput(TranslatorContext, NDList)} can read its attention mask.
     *
     * @param ctx the translator context
     * @param input the text to embed
     * @return the model input arrays
     */
    @Override
    public NDList processInput(TranslatorContext ctx, String input) {
        Encoding encoding = tokenizer.encode(input);
        ctx.setAttachment("encoding", encoding);
        return encoding.toNDList(ctx.getNDManager(), includeTokenTypes);
    }

    /**
     * Pools the model output into one embedding and optionally L2-normalizes it.
     *
     * @param ctx the translator context; must carry the {@code "encoding"} attachment set by
     *     {@link #processInput(TranslatorContext, String)}
     * @param list the model output
     * @return the embedding as a flat float array
     */
    @Override
    public float[] processOutput(TranslatorContext ctx, NDList list) {
        Encoding encoding = (Encoding) ctx.getAttachment("encoding");
        NDArray embedding = processEmbedding(ctx.getNDManager(), list, encoding, pooling);
        if (normalize) {
            embedding = embedding.normalize(2.0, 0L);
        }
        return embedding.toFloatArray();
    }

    /**
     * Alias of {@link #processOutput(TranslatorContext, NDList)} under the name the decompiler
     * assigned; kept so existing references to that name continue to resolve.
     *
     * @param translatorContext the translator context
     * @param nDList the model output
     * @return the embedding as a flat float array
     */
    public float[] m20processOutput(TranslatorContext translatorContext, NDList nDList) {
        return processOutput(translatorContext, nDList);
    }

    /**
     * Switches the shared tokenizer to batch mode and returns a batch translator that uses the
     * same tokenizer, pooling mode and normalization setting.
     *
     * @param batchifier the batchifier for the batch translator
     * @return a new {@link TextEmbeddingBatchTranslator}
     */
    @Override
    public TextEmbeddingBatchTranslator toBatchTranslator(Batchifier batchifier) {
        // NOTE(review): includeTokenTypes is not forwarded here; confirm whether the batch
        // translator is expected to honor it.
        tokenizer.enableBatch();
        return new TextEmbeddingBatchTranslator(tokenizer, batchifier, pooling, normalize);
    }

    /**
     * Alias of {@link #toBatchTranslator(Batchifier)} under the name the decompiler assigned;
     * kept so existing references to that name continue to resolve.
     *
     * @param batchifier the batchifier for the batch translator
     * @return a new {@link TextEmbeddingBatchTranslator}
     */
    public TextEmbeddingBatchTranslator m19toBatchTranslator(Batchifier batchifier) {
        return toBatchTranslator(batchifier);
    }

    /**
     * Reduces the model output to a single embedding using the requested pooling mode.
     *
     * <p>The array named {@code last_hidden_state} is used when present, otherwise the first
     * array of the list. The attention mask of {@code encoding} is converted to {@code FLOAT32}
     * and passed to the pooling routine.
     *
     * @param manager the manager used to create the attention-mask array
     * @param list the model output
     * @param encoding the tokenizer encoding of the input
     * @param pooling the pooling mode name
     * @return the pooled embedding
     * @throws AssertionError if {@code pooling} is not a supported mode
     */
    public static NDArray processEmbedding(NDManager manager, NDList list, Encoding encoding, String pooling) {
        NDArray embeddings = list.get("last_hidden_state");
        if (embeddings == null) {
            embeddings = list.head();
        }
        NDArray attentionMask = manager.create(encoding.getAttentionMask()).toType(DataType.FLOAT32, true);
        switch (pooling) {
            case "mean":
                return meanPool(embeddings, attentionMask, false);
            case "mean_sqrt_len":
                return meanPool(embeddings, attentionMask, true);
            case "max":
                return maxPool(embeddings, attentionMask);
            case "weightedmean":
                return weightedMeanPool(embeddings, attentionMask);
            case "cls":
                // The entry at index 0 along the first axis.
                return embeddings.get(0);
            default:
                throw new AssertionError("Unexpected pooling mode: " + pooling);
        }
    }

    /**
     * Mask-weighted mean over axis 0.
     *
     * @param embeddings the hidden states
     * @param attentionMask the float attention mask, one entry per position on axis 0
     * @param sqrt if {@code true}, divide by the square root of the mask sum instead of the sum
     * @return the pooled embedding
     */
    private static NDArray meanPool(NDArray embeddings, NDArray attentionMask, boolean sqrt) {
        NDArray mask = attentionMask.expandDims(-1).broadcast(embeddings.getShape().getShape());
        // Clip the divisor away from zero so an all-zero mask cannot divide by zero.
        NDArray maskSum = mask.sum(AXIS).clip(1.0E-9d, 1.0E12d);
        NDArray weightedSum = embeddings.mul(mask).sum(AXIS);
        return sqrt ? weightedSum.div(maskSum.sqrt()) : weightedSum.div(maskSum);
    }

    /**
     * Maximum over axis 0, ignoring positions whose mask value is 0.
     *
     * @param embeddings the hidden states
     * @param attentionMask the float attention mask, one entry per position on axis 0
     * @return the pooled embedding; the reduced axis is kept
     */
    private static NDArray maxPool(NDArray embeddings, NDArray attentionMask) {
        NDArray masked = attentionMask.expandDims(-1).broadcast(embeddings.getShape().getShape()).eq(0);
        // Work on a copy so the model output itself is left unmodified.
        NDArray copy = embeddings.duplicate();
        // Masked-out positions get a large negative value so they cannot win the max.
        copy.set(masked, -1.0E9d);
        // NOTE(review): unlike the other modes this keeps the reduced axis (keepDims = true);
        // confirm this interacts as intended with the axis-0 normalization in processOutput.
        return copy.max(AXIS, true);
    }

    /**
     * Position-weighted mean over axis 0: position {@code i} (1-based) has weight {@code i},
     * multiplied by its attention-mask value.
     *
     * @param embeddings the hidden states
     * @param attentionMask the float attention mask, one entry per position on axis 0
     * @return the pooled embedding
     */
    private static NDArray weightedMeanPool(NDArray embeddings, NDArray attentionMask) {
        long[] shape = embeddings.getShape().getShape();
        NDArray positions =
                embeddings.getManager().arange(1.0f, (float) (shape[0] + 1)).expandDims(-1).broadcast(shape);
        NDArray weights = attentionMask.expandDims(-1).broadcast(shape).mul(positions);
        return embeddings.mul(weights).sum(AXIS).div(weights.sum(AXIS));
    }

    /**
     * Creates a builder for a {@link TextEmbeddingTranslator}.
     *
     * @param tokenizer the tokenizer the translator will use
     * @return a new builder
     */
    public static Builder builder(HuggingFaceTokenizer tokenizer) {
        return new Builder(tokenizer);
    }

    /**
     * Creates a builder and configures it from an argument map.
     *
     * @param tokenizer the tokenizer the translator will use
     * @param arguments the model arguments, see {@link Builder#configure(Map)}
     * @return a new, configured builder
     */
    public static Builder builder(HuggingFaceTokenizer tokenizer, Map<String, ?> arguments) {
        Builder builder = builder(tokenizer);
        builder.configure(arguments);
        return builder;
    }
}
