/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.opensearch.planner.physical;

import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.opensearch.client.node.NodeClient;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.dataframe.ColumnMeta;
import org.opensearch.ml.common.dataframe.ColumnValue;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
import org.opensearch.ml.common.dataframe.Row;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.MLAlgoParams;
import org.opensearch.ml.common.parameter.MLInput;
import org.opensearch.ml.common.parameter.MLPredictionOutput;
import org.opensearch.sql.data.model.ExprBooleanValue;
import org.opensearch.sql.data.model.ExprDoubleValue;
import org.opensearch.sql.data.model.ExprFloatValue;
import org.opensearch.sql.data.model.ExprIntegerValue;
import org.opensearch.sql.data.model.ExprLongValue;
import org.opensearch.sql.data.model.ExprShortValue;
import org.opensearch.sql.data.model.ExprStringValue;
import org.opensearch.sql.data.model.ExprTupleValue;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.opensearch.client.MLClient;
import org.opensearch.sql.planner.physical.PhysicalPlan;

public abstract class MLCommonsOperatorActions
extends PhysicalPlan {
    protected DataFrame generateInputDataset(final PhysicalPlan input) {
        LinkedList<1> inputData = new LinkedList<1>();
        while (input.hasNext()) {
            inputData.add(new HashMap<String, Object>(){
                {
                    ((ExprValue)input.next()).tupleValue().forEach((key, value) -> this.put(key, value.value()));
                }
            });
        }
        return DataFrameBuilder.load(inputData);
    }

    protected Map<String, ExprValue> convertRowIntoExprValue(ColumnMeta[] columnMetas, Row row) {
        ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder();
        for (int i = 0; i < columnMetas.length; ++i) {
            ColumnValue columnValue = row.getValue(i);
            String resultKeyName = columnMetas[i].getName();
            this.populateResultBuilder(columnValue, resultKeyName, (ImmutableMap.Builder<String, ExprValue>)resultBuilder);
        }
        return resultBuilder.build();
    }

    protected void populateResultBuilder(ColumnValue columnValue, String resultKeyName, ImmutableMap.Builder<String, ExprValue> resultBuilder) {
        switch (columnValue.columnType()) {
            case INTEGER: {
                resultBuilder.put((Object)resultKeyName, (Object)new ExprIntegerValue((Number)columnValue.intValue()));
                break;
            }
            case DOUBLE: {
                resultBuilder.put((Object)resultKeyName, (Object)new ExprDoubleValue((Number)columnValue.doubleValue()));
                break;
            }
            case STRING: {
                resultBuilder.put((Object)resultKeyName, (Object)new ExprStringValue(columnValue.stringValue()));
                break;
            }
            case SHORT: {
                resultBuilder.put((Object)resultKeyName, (Object)new ExprShortValue((Number)columnValue.shortValue()));
                break;
            }
            case LONG: {
                resultBuilder.put((Object)resultKeyName, (Object)new ExprLongValue((Number)columnValue.longValue()));
                break;
            }
            case FLOAT: {
                resultBuilder.put((Object)resultKeyName, (Object)new ExprFloatValue((Number)Float.valueOf(columnValue.floatValue())));
                break;
            }
            case BOOLEAN: {
                resultBuilder.put((Object)resultKeyName, (Object)ExprBooleanValue.of((Boolean)columnValue.booleanValue()));
                break;
            }
        }
    }

    protected Map<String, ExprValue> convertResultRowIntoExprValue(ColumnMeta[] columnMetas, Row row, Map<String, ExprValue> schema) {
        ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder();
        for (int i = 0; i < columnMetas.length; ++i) {
            ColumnValue columnValue = row.getValue(i);
            Object resultKeyName = columnMetas[i].getName();
            if (schema.containsKey(resultKeyName)) {
                resultKeyName = (String)resultKeyName + "1";
            }
            this.populateResultBuilder(columnValue, (String)resultKeyName, (ImmutableMap.Builder<String, ExprValue>)resultBuilder);
        }
        return resultBuilder.build();
    }

    protected ExprTupleValue buildResult(Iterator<Row> inputRowIter, DataFrame inputDataFrame, MLPredictionOutput predictionResult, Iterator<Row> resultRowIter) {
        ImmutableMap.Builder resultSchemaBuilder = new ImmutableMap.Builder();
        resultSchemaBuilder.putAll(this.convertRowIntoExprValue(inputDataFrame.columnMetas(), inputRowIter.next()));
        ImmutableMap resultSchema = resultSchemaBuilder.build();
        ImmutableMap.Builder resultBuilder = new ImmutableMap.Builder();
        resultBuilder.putAll(this.convertResultRowIntoExprValue(predictionResult.getPredictionResult().columnMetas(), resultRowIter.next(), (Map<String, ExprValue>)resultSchema));
        resultBuilder.putAll((Map)resultSchema);
        return ExprTupleValue.fromExprValueMap((Map)resultBuilder.build());
    }

    protected MLPredictionOutput getMLPredictionResult(FunctionName functionName, MLAlgoParams mlAlgoParams, DataFrame inputDataFrame, NodeClient nodeClient) {
        MLInput mlinput = MLInput.builder().algorithm(functionName).parameters(mlAlgoParams).inputDataset((MLInputDataset)new DataFrameInputDataset(inputDataFrame)).build();
        MachineLearningNodeClient machineLearningClient = MLClient.getMLClient(nodeClient);
        return (MLPredictionOutput)machineLearningClient.trainAndPredict(mlinput).actionGet(30L, TimeUnit.SECONDS);
    }
}

