/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.utils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import java.util.stream.StreamSupport;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.opensearch.common.collect.Tuple;
import org.opensearch.ml.common.dataframe.ColumnMeta;
import org.opensearch.ml.common.dataframe.ColumnValue;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.Row;
import org.opensearch.ml.engine.contants.TribuoOutputType;
import org.tribuo.DataSource;
import org.tribuo.MutableDataset;
import org.tribuo.Output;
import org.tribuo.OutputFactory;
import org.tribuo.anomaly.Event;
import org.tribuo.clustering.ClusterID;
import org.tribuo.datasource.ListDataSource;
import org.tribuo.impl.ArrayExample;
import org.tribuo.provenance.DataSourceProvenance;
import org.tribuo.provenance.SimpleDataSourceProvenance;
import org.tribuo.regression.Regressor;

/**
 * Helpers for converting ML Commons {@link DataFrame}s into Tribuo datasets.
 *
 * <p>Every column value is read as a {@code double} via {@link ColumnValue#doubleValue()}.
 */
public final class TribuoUtil {

    /** Cluster id attached to unlabeled clustering examples (presumably Tribuo's "unassigned" id — confirm). */
    private static final int UNASSIGNED_CLUSTER_ID = -1;

    /** Output name of the placeholder regressor attached to unlabeled regression examples. */
    private static final String DEFAULT_REGRESSOR_NAME = "DIM-0";

    /**
     * Splits a data frame into its column names and a row-major matrix of numeric values.
     *
     * @param dataFrame the data frame to convert
     * @return a tuple whose first element holds the column names (in column-meta order) and whose
     *     second element holds one {@code double[]} per row, in iteration order
     */
    public static Tuple<String[], double[][]> transformDataFrame(DataFrame dataFrame) {
        String[] featureNames = Arrays.stream(dataFrame.columnMetas()).map(ColumnMeta::getName).toArray(String[]::new);
        double[][] featureValues = new double[dataFrame.size()][];
        Iterator<Row> rows = dataFrame.iterator();
        int i = 0;
        while (rows.hasNext()) {
            Row row = rows.next();
            featureValues[i++] = StreamSupport.stream(row.spliterator(), false)
                    .mapToDouble(ColumnValue::doubleValue)
                    .toArray();
        }
        return new Tuple<>(featureNames, featureValues);
    }

    /**
     * Builds an unlabeled dataset: every column becomes a feature and every row gets a placeholder
     * output chosen by {@code outputType}.
     *
     * @param dataFrame the source data
     * @param outputFactory output factory for the dataset; its {@code T} must match {@code outputType}
     * @param desc description recorded in the data source provenance
     * @param outputType selects the placeholder output: {@code ClusterID(-1)},
     *     {@code Regressor("DIM-0", NaN)} or {@code Event(EXPECTED)}
     * @return a mutable dataset with one example per row
     * @throws IllegalArgumentException if {@code outputType} is not CLUSTERID, REGRESSOR or
     *     ANOMALY_DETECTION_LIBSVM
     */
    public static <T extends Output<T>> MutableDataset<T> generateDataset(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType) {
        // Resolve the output type up front so unsupported types are rejected even for empty frames.
        Supplier<T> placeholderOutput = placeholderOutput(outputType);
        Tuple<String[], double[][]> featureNamesValues = transformDataFrame(dataFrame);
        String[] featureNames = featureNamesValues.v1();
        List<ArrayExample<T>> examples = new ArrayList<>();
        for (double[] rowValues : featureNamesValues.v2()) {
            // A fresh output per example, matching one-object-per-row construction.
            examples.add(new ArrayExample<>(placeholderOutput.get(), featureNames, rowValues));
        }
        return toMutableDataset(examples, outputFactory, desc);
    }

    /**
     * Builds a labeled dataset in which the {@code target} column supplies each example's output
     * and all remaining columns become features.
     *
     * @param dataFrame the source data
     * @param outputFactory output factory for the dataset; expected to be a regression factory
     * @param desc description recorded in the data source provenance
     * @param outputType must be {@code REGRESSOR}; no other type supports a target column
     * @param target name of the column holding the label
     * @return a mutable dataset with one example per row, each labeled with a {@code Regressor}
     *     named {@code target}
     * @throws IllegalArgumentException if {@code target} is null/empty, matches no column, or
     *     {@code outputType} is not {@code REGRESSOR}
     */
    public static <T extends Output<T>> MutableDataset<T> generateDatasetWithTarget(DataFrame dataFrame, OutputFactory<T> outputFactory, String desc, TribuoOutputType outputType, String target) {
        if (StringUtils.isEmpty(target)) {
            throw new IllegalArgumentException("Empty target when generating dataset from data frame.");
        }
        Tuple<String[], double[][]> featureNamesValues = transformDataFrame(dataFrame);
        String[] columnNames = featureNamesValues.v1();
        // First column whose name equals the target.
        int targetIndex = IntStream.range(0, columnNames.length)
                .filter(idx -> target.equals(columnNames[idx]))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("No matched target when generating dataset from data frame."));
        // Checked once, after target resolution, rather than once per row.
        if (outputType != TribuoOutputType.REGRESSOR) {
            throw new IllegalArgumentException("unknown type:" + outputType);
        }
        String[] featureNames = withoutIndex(columnNames, targetIndex);
        List<ArrayExample<T>> examples = new ArrayList<>();
        for (double[] rowValues : featureNamesValues.v2()) {
            // Unchecked: assumes outputFactory is a Regressor factory (T = Regressor) — TODO confirm at callers.
            @SuppressWarnings("unchecked")
            T output = (T) new Regressor(target, rowValues[targetIndex]);
            examples.add(new ArrayExample<>(output, featureNames, withoutIndex(rowValues, targetIndex)));
        }
        return toMutableDataset(examples, outputFactory, desc);
    }

    /**
     * Returns a supplier of the placeholder output used for unlabeled examples of {@code outputType}.
     *
     * @throws IllegalArgumentException if {@code outputType} has no placeholder output
     */
    // Unchecked: relies on the caller pairing outputType with an OutputFactory of the matching T.
    @SuppressWarnings("unchecked")
    private static <T extends Output<T>> Supplier<T> placeholderOutput(TribuoOutputType outputType) {
        switch (outputType) {
            case CLUSTERID:
                return () -> (T) new ClusterID(UNASSIGNED_CLUSTER_ID);
            case REGRESSOR:
                return () -> (T) new Regressor(DEFAULT_REGRESSOR_NAME, Double.NaN);
            case ANOMALY_DETECTION_LIBSVM:
                return () -> (T) new Event(Event.EventType.EXPECTED);
            default:
                throw new IllegalArgumentException("unknown type:" + outputType);
        }
    }

    /** Returns a copy of {@code values} with the element at {@code skip} removed. */
    private static String[] withoutIndex(String[] values, int skip) {
        String[] result = new String[values.length - 1];
        System.arraycopy(values, 0, result, 0, skip);
        System.arraycopy(values, skip + 1, result, skip, values.length - skip - 1);
        return result;
    }

    /** Returns a copy of {@code values} with the element at {@code skip} removed. */
    private static double[] withoutIndex(double[] values, int skip) {
        double[] result = new double[values.length - 1];
        System.arraycopy(values, 0, result, 0, skip);
        System.arraycopy(values, skip + 1, result, skip, values.length - skip - 1);
        return result;
    }

    /** Wraps {@code examples} in a {@link ListDataSource} described by {@code desc} and builds a dataset from it. */
    private static <T extends Output<T>> MutableDataset<T> toMutableDataset(List<ArrayExample<T>> examples, OutputFactory<T> outputFactory, String desc) {
        DataSourceProvenance provenance = new SimpleDataSourceProvenance(desc, outputFactory);
        // unmodifiableList widens List<ArrayExample<T>> to the List<Example<T>> that ListDataSource
        // expects, without copying and without resorting to raw types.
        DataSource<T> source = new ListDataSource<>(Collections.unmodifiableList(examples), outputFactory, provenance);
        return new MutableDataset<>(source);
    }

    @Generated
    private TribuoUtil() {
        throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
    }
}

