/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.search.processor.mmr;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.core.action.ActionListener;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.SpaceTypeResolver;
import org.opensearch.knn.plugin.transport.GetModelAction;
import org.opensearch.knn.plugin.transport.GetModelRequest;
import org.opensearch.knn.search.extension.MMRSearchExtBuilder;
import org.opensearch.knn.search.processor.mmr.MMRVectorFieldInfo;
import org.opensearch.search.pipeline.ProcessorGenerationContext;
import org.opensearch.transport.client.Client;
import reactor.util.annotation.NonNull;
import reactor.util.annotation.Nullable;

public class MMRUtil {
    private static List<MMRVectorFieldInfo> collectKnnVectorFieldInfos(@NonNull String path, @NonNull List<IndexMetadata> indexMetadataList) {
        ArrayList<MMRVectorFieldInfo> vectorFieldInfos = new ArrayList<MMRVectorFieldInfo>();
        for (IndexMetadata indexMetadata : indexMetadataList) {
            vectorFieldInfos.add(MMRUtil.collectKnnVectorFieldInfo(indexMetadata, path));
        }
        return vectorFieldInfos;
    }

    private static MMRVectorFieldInfo collectKnnVectorFieldInfo(IndexMetadata indexMetadata, String path) {
        MMRVectorFieldInfo vectorFieldInfo = new MMRVectorFieldInfo();
        vectorFieldInfo.setIndexNameByIndexMetadata(indexMetadata);
        MappingMetadata mappingMetadata = indexMetadata.mapping();
        if (mappingMetadata == null) {
            vectorFieldInfo.setUnmapped(true);
            return vectorFieldInfo;
        }
        Map mapping = mappingMetadata.sourceAsMap();
        Map<String, Object> config = MMRUtil.getMMRFieldMappingByPath(mapping, path);
        if (config == null) {
            vectorFieldInfo.setUnmapped(true);
            return vectorFieldInfo;
        }
        vectorFieldInfo.setUnmapped(false);
        vectorFieldInfo.setFieldPath(path);
        String fieldType = (String)config.get("type");
        vectorFieldInfo.setFieldType(fieldType);
        if (!"knn_vector".equals(fieldType)) {
            return vectorFieldInfo;
        }
        vectorFieldInfo.setKnnConfig(config);
        return vectorFieldInfo;
    }

    private static MMRVectorFieldInfo resolveKnnVectorFieldInfo(SpaceType userProvidedSpaceType, VectorDataType userProvidedVectorDataType, List<MMRVectorFieldInfo> MMRVectorFieldInfoList) throws IllegalArgumentException {
        boolean allUnmapped = true;
        ArrayList<MMRVectorFieldInfo> nonKnnFields = new ArrayList<MMRVectorFieldInfo>();
        SpaceType resolvedSpaceType = null;
        VectorDataType resolvedVectorDataType = null;
        for (MMRVectorFieldInfo info2 : MMRVectorFieldInfoList) {
            if (info2.isUnmapped()) continue;
            allUnmapped = false;
            if (!info2.isKNNVectorField()) {
                nonKnnFields.add(info2);
                continue;
            }
            resolvedSpaceType = MMRUtil.resolveConsistentValue(resolvedSpaceType, info2.getSpaceType(), SpaceType::getValue, "space type", info2.getFieldPath());
            resolvedVectorDataType = MMRUtil.resolveConsistentValue(resolvedVectorDataType, info2.getVectorDataType(), VectorDataType::getValue, "vector data type", info2.getFieldPath());
        }
        if (allUnmapped) {
            resolvedSpaceType = userProvidedSpaceType != null ? userProvidedSpaceType : SpaceTypeResolver.getDefaultSpaceType(VectorDataType.DEFAULT);
            resolvedVectorDataType = userProvidedVectorDataType != null ? userProvidedVectorDataType : VectorDataType.DEFAULT;
            return new MMRVectorFieldInfo(resolvedSpaceType, resolvedVectorDataType);
        }
        if (!nonKnnFields.isEmpty()) {
            throw new IllegalArgumentException(String.format("MMR query extension cannot support non knn_vector field [%s].", nonKnnFields.stream().map(info -> String.format(Locale.ROOT, "%s:%s", info.getIndexName(), info.getFieldPath())).collect(Collectors.joining(","))));
        }
        return MMRUtil.resolveFinalKnnVectorFieldInfo(userProvidedSpaceType, resolvedSpaceType, userProvidedVectorDataType, resolvedVectorDataType);
    }

    private static <T> T resolveConsistentValue(T current, T next, Function<T, String> valueFormatter, String fieldDescription, String fieldPath) {
        if (next == null) {
            return current;
        }
        if (current == null) {
            return next;
        }
        if (!current.equals(next)) {
            throw new IllegalArgumentException(String.format("MMR query extension cannot support different %s [%s, %s] for the knn_vector field at path %s.", fieldDescription, valueFormatter.apply(current), valueFormatter.apply(next), fieldPath));
        }
        return current;
    }

    private static MMRVectorFieldInfo resolveFinalKnnVectorFieldInfo(SpaceType userProvidedSpaceType, SpaceType resolvedSpaceType, VectorDataType userProvidedVectorDataType, VectorDataType resolvedVectorDataType) throws IllegalArgumentException {
        SpaceType finalSpaceType = MMRUtil.resolveFinalValue(userProvidedSpaceType, resolvedSpaceType, () -> SpaceTypeResolver.getDefaultSpaceType(VectorDataType.DEFAULT), SpaceType::getValue, "space type");
        VectorDataType finalVectorDataType = MMRUtil.resolveFinalValue(userProvidedVectorDataType, resolvedVectorDataType, () -> VectorDataType.DEFAULT, VectorDataType::getValue, "vector data type");
        return new MMRVectorFieldInfo(finalSpaceType, finalVectorDataType);
    }

    private static <T> T resolveFinalValue(T userProvided, T resolved, Supplier<T> defaultSupplier, Function<T, String> valueFormatter, String fieldDescription) {
        if (userProvided != null && resolved != null && !userProvided.equals(resolved)) {
            throw new IllegalArgumentException(String.format("The %s [%s] provided in the MMR query extension does not match the %s [%s] in target indices.", fieldDescription, valueFormatter.apply(userProvided), fieldDescription, valueFormatter.apply(resolved)));
        }
        if (userProvided != null) {
            return userProvided;
        }
        if (resolved != null) {
            return resolved;
        }
        return defaultSupplier.get();
    }

    private static MMRVectorFieldInfo resolveVectorFieldInfoFromModel(VectorDataType userProvidedVectorDataType, SpaceType userProvidedSpaceType, List<MMRVectorFieldInfo> MMRVectorFieldInfoList, Map<String, MMRVectorFieldInfo> modelIdToVectorFieldInfo) throws IllegalArgumentException {
        SpaceType resolvedSpaceType = null;
        VectorDataType resolvedVectorDataType = null;
        for (MMRVectorFieldInfo info : MMRVectorFieldInfoList) {
            SpaceType spaceType;
            VectorDataType vectorDataType;
            if (info.getModelId() != null) {
                MMRVectorFieldInfo infoFromModel = modelIdToVectorFieldInfo.get(info.getModelId());
                if (infoFromModel == null) {
                    throw new IllegalStateException(String.format("Unexpected null when try to resolve the info of the vector field at path [%s] based on its model [%s].", info.getModelId(), info.getFieldPath()));
                }
                vectorDataType = infoFromModel.getVectorDataType() != null ? infoFromModel.getVectorDataType() : VectorDataType.DEFAULT;
                spaceType = infoFromModel.getSpaceType() != null ? infoFromModel.getSpaceType() : SpaceTypeResolver.getDefaultSpaceType(vectorDataType);
            } else {
                spaceType = info.getSpaceType();
                vectorDataType = info.getVectorDataType();
            }
            resolvedSpaceType = MMRUtil.resolveConsistentValue(resolvedSpaceType, spaceType, SpaceType::getValue, "space type", info.getFieldPath());
            resolvedVectorDataType = MMRUtil.resolveConsistentValue(resolvedVectorDataType, vectorDataType, VectorDataType::getValue, "vector data type", info.getFieldPath());
        }
        return MMRUtil.resolveFinalKnnVectorFieldInfo(userProvidedSpaceType, resolvedSpaceType, userProvidedVectorDataType, resolvedVectorDataType);
    }

    private static void retrieveFieldInfoFromModel(@NonNull Set<String> modelIds, @NonNull Client client, @NonNull ActionListener<Map<String, MMRVectorFieldInfo>> listener) {
        ConcurrentHashMap modelIdToVectorFieldInfo = new ConcurrentHashMap();
        List errors = Collections.synchronizedList(new ArrayList());
        AtomicInteger counter = new AtomicInteger(modelIds.size());
        for (String modelId : modelIds) {
            client.execute((ActionType)GetModelAction.INSTANCE, (ActionRequest)new GetModelRequest(modelId), ActionListener.wrap(response -> {
                SpaceType spaceTypeFromModel = null;
                VectorDataType vectorDataTypeFromModel = null;
                if (response != null && response.getModel() != null && response.getModel().getModelMetadata() != null) {
                    spaceTypeFromModel = response.getModel().getModelMetadata().getSpaceType();
                    vectorDataTypeFromModel = response.getModel().getModelMetadata().getVectorDataType();
                }
                modelIdToVectorFieldInfo.put(modelId, new MMRVectorFieldInfo(spaceTypeFromModel, vectorDataTypeFromModel));
                if (counter.decrementAndGet() == 0) {
                    listener.onResponse((Object)modelIdToVectorFieldInfo);
                }
            }, e -> {
                errors.add(e.getMessage());
                if (counter.decrementAndGet() == 0) {
                    listener.onFailure((Exception)new RuntimeException(String.format(Locale.ROOT, "Failed to retrieve model(s) to resolve the space type and vector data type for the MMR query extension. Errors: %s.", String.join((CharSequence)", ", errors))));
                }
            }));
        }
    }

    public static void resolveKnnVectorFieldInfo(@NonNull String path, @Nullable SpaceType userProvidedSpaceType, @Nullable VectorDataType userProvidedVectorDataType, @NonNull List<IndexMetadata> localIndexMetadataList, @NonNull Client client, @NonNull ActionListener<MMRVectorFieldInfo> continuation) {
        try {
            List<MMRVectorFieldInfo> knnVectorFieldInfos = MMRUtil.collectKnnVectorFieldInfos(path, localIndexMetadataList);
            MMRUtil.resolveKnnVectorFieldInfo(knnVectorFieldInfos, userProvidedSpaceType, userProvidedVectorDataType, client, continuation);
        }
        catch (Exception e) {
            continuation.onFailure(e);
        }
    }

    public static void resolveKnnVectorFieldInfo(@NonNull List<MMRVectorFieldInfo> MMRVectorFieldInfoList, @Nullable SpaceType userProvidedSpaceType, @Nullable VectorDataType userProvidedVectorDataType, @NonNull Client client, @NonNull ActionListener<MMRVectorFieldInfo> continuation) {
        try {
            MMRVectorFieldInfo resolvedVectorFieldInfo = MMRUtil.resolveKnnVectorFieldInfo(userProvidedSpaceType, userProvidedVectorDataType, MMRVectorFieldInfoList);
            Set<String> modelIds = MMRVectorFieldInfoList.stream().map(MMRVectorFieldInfo::getModelId).filter(Objects::nonNull).collect(Collectors.toSet());
            if (modelIds.isEmpty()) {
                continuation.onResponse((Object)resolvedVectorFieldInfo);
            } else {
                MMRUtil.retrieveFieldInfoFromModel(modelIds, client, (ActionListener<Map<String, MMRVectorFieldInfo>>)ActionListener.wrap(modelIdToVectorFieldInfo -> {
                    MMRVectorFieldInfo resolvedVectorFieldInfoFromModel = MMRUtil.resolveVectorFieldInfoFromModel(userProvidedVectorDataType, userProvidedSpaceType, MMRVectorFieldInfoList, modelIdToVectorFieldInfo);
                    continuation.onResponse((Object)resolvedVectorFieldInfoFromModel);
                }, arg_0 -> continuation.onFailure(arg_0)));
            }
        }
        catch (Exception e) {
            continuation.onFailure(e);
        }
    }

    public static Object extractVectorFromHit(Map<String, Object> sourceAsMap, String fieldPath, String docId, boolean isFloatVector) throws IllegalArgumentException {
        String baseError = String.format(Locale.ROOT, "Failed to extract the vector from the doc [%s] for MMR rerank", docId);
        if (sourceAsMap == null || fieldPath == null) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "%s: source map and fieldPath must not be null.", baseError));
        }
        String[] pathParts = fieldPath.split("\\.");
        Object current = sourceAsMap;
        if (pathParts.length == 0) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "%s: fieldPath must not be an empty string.", baseError));
        }
        for (int i = 0; i < pathParts.length; ++i) {
            String part = pathParts[i];
            if (!(current instanceof Map)) {
                throw new IllegalArgumentException(String.format("%s: expected object at [%s], but found [%s]", baseError, part, current.getClass().getName()));
            }
            Map<String, Object> map = current;
            current = map.get(part);
            if (current == null) {
                throw new IllegalArgumentException(String.format("%s: field path [%s] not found in document source.", baseError, fieldPath));
            }
            if (i != pathParts.length - 1) continue;
            if (current instanceof List) {
                List list = (List)current;
                float[] floatVector = null;
                byte[] byteVector = null;
                if (isFloatVector) {
                    floatVector = new float[list.size()];
                } else {
                    byteVector = new byte[list.size()];
                }
                try {
                    for (int j = 0; j < list.size(); ++j) {
                        if (isFloatVector) {
                            floatVector[j] = (float)((Double)list.get(j)).doubleValue();
                            continue;
                        }
                        byteVector[j] = (byte)((Double)list.get(j)).doubleValue();
                    }
                }
                catch (Exception e) {
                    throw new IllegalArgumentException(String.format("%s: unexpected value at the vector field [%s]. error: %s", baseError, fieldPath, e.getMessage()), e);
                }
                if (isFloatVector) {
                    return floatVector;
                }
                return byteVector;
            }
            throw new IllegalArgumentException(String.format("%s: expected vector (list of numbers) at field path [%s], but found type [%s]", baseError, fieldPath, current.getClass().getName()));
        }
        throw new IllegalStateException(String.format("%s: unexpected error resolving field path [%s].", baseError, fieldPath));
    }

    public static boolean shouldGenerateMMRProcessor(ProcessorGenerationContext processorGenerationContext) {
        SearchRequest request = processorGenerationContext.searchRequest();
        if (request == null || request.source() == null || request.source().ext() == null) {
            return false;
        }
        return request.source().ext().stream().anyMatch(MMRSearchExtBuilder.class::isInstance);
    }

    public static Map<String, Object> getMMRFieldMappingByPath(Map<String, Object> mappings, @NonNull String fieldPath) {
        if (mappings == null) {
            return null;
        }
        String[] parts = fieldPath.split("\\.");
        Map current = mappings;
        for (int i = 0; i < parts.length; ++i) {
            String part = parts[i];
            Object propertiesObj = current.get("properties");
            if (!(propertiesObj instanceof Map)) {
                return null;
            }
            Map properties = (Map)propertiesObj;
            Object fieldConfig = properties.get(part);
            if (!(fieldConfig instanceof Map)) {
                return null;
            }
            current = (Map)fieldConfig;
            String fieldType = (String)current.get("type");
            if (!"nested".equals(fieldType)) continue;
            throw new IllegalArgumentException(String.format("MMR search extension cannot support the field %s because it is in the nested field %s.", fieldPath, part));
        }
        return current;
    }
}

