/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.ensemble;

import ai.onnx.proto.OnnxMl;
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.ONNXExportable;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.ensemble.EnsembleExcuse;
import org.tribuo.ensemble.EnsembleModel;
import org.tribuo.provenance.EnsembleModelProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TimestampedTrainerProvenance;
import org.tribuo.util.Util;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXPlaceholder;
import org.tribuo.util.onnx.ONNXRef;

public final class WeightedEnsembleModel<T extends Output<T>>
extends EnsembleModel<T>
implements ONNXExportable {
    private static final long serialVersionUID = 1L;
    protected final float[] weights;
    protected final EnsembleCombiner<T> combiner;

    public WeightedEnsembleModel(String name, EnsembleModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<Model<T>> newModels, EnsembleCombiner<T> combiner) {
        this(name, provenance, featureIDMap, outputIDInfo, newModels, combiner, Util.generateUniformVector(newModels.size(), 1.0f / (float)newModels.size()));
    }

    public WeightedEnsembleModel(String name, EnsembleModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, List<Model<T>> newModels, EnsembleCombiner<T> combiner, float[] weights) {
        super(name, provenance, featureIDMap, outputIDInfo, newModels);
        this.weights = Arrays.copyOf(weights, weights.length);
        this.combiner = combiner;
    }

    @Override
    public Prediction<T> predict(Example<T> example) {
        ArrayList predictions = new ArrayList();
        for (Model model : this.models) {
            predictions.add(model.predict(example));
        }
        return this.combiner.combine(this.outputIDInfo, predictions, this.weights);
    }

    @Override
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        HashMap<String, Map> map = new HashMap<String, Map>();
        Prediction<T> prediction = this.predict(example);
        ArrayList excuses = new ArrayList();
        for (int i = 0; i < this.models.size(); ++i) {
            Optional<Excuse<T>> excuse = ((Model)this.models.get(i)).getExcuse(example);
            if (!excuse.isPresent()) continue;
            excuses.add(excuse.get());
            Map<String, List<Pair<String, Double>>> m = excuse.get().getScores();
            for (Map.Entry<String, List<Pair<String, Double>>> e : m.entrySet()) {
                Map innerMap = map.computeIfAbsent(e.getKey(), k -> new HashMap());
                for (Pair<String, Double> p : e.getValue()) {
                    innerMap.merge(p.getA(), (Double)p.getB() * (double)this.weights[i], Double::sum);
                }
            }
        }
        if (map.isEmpty()) {
            return Optional.empty();
        }
        HashMap<String, List<Pair<String, Double>>> outputMap = new HashMap<String, List<Pair<String, Double>>>();
        for (Map.Entry label : map.entrySet()) {
            ArrayList<Pair> list = new ArrayList<Pair>();
            for (Map.Entry entry : ((Map)label.getValue()).entrySet()) {
                list.add(new Pair(entry.getKey(), entry.getValue()));
            }
            list.sort((o1, o2) -> ((Double)o2.getB()).compareTo((Double)o1.getB()));
            outputMap.put((String)label.getKey(), (List<Pair<String, Double>>)list);
        }
        return Optional.of(new EnsembleExcuse<T>(example, prediction, outputMap, excuses));
    }

    @Override
    protected EnsembleModel<T> copy(String name, EnsembleModelProvenance newProvenance, List<Model<T>> newModels) {
        return new WeightedEnsembleModel<T>(name, newProvenance, this.featureIDMap, this.outputIDInfo, newModels, this.combiner);
    }

    public static <T extends Output<T>> WeightedEnsembleModel<T> createEnsembleFromExistingModels(String name, List<Model<T>> models, EnsembleCombiner<T> combiner) {
        return WeightedEnsembleModel.createEnsembleFromExistingModels(name, models, combiner, Util.generateUniformVector(models.size(), 1.0f / (float)models.size()));
    }

    /*
     * WARNING - void declaration
     */
    public static <T extends Output<T>> WeightedEnsembleModel<T> createEnsembleFromExistingModels(String name, List<Model<T>> models, EnsembleCombiner<T> combiner, float[] weights) {
        void var7_9;
        if (models.size() < 2) {
            throw new IllegalArgumentException("Must supply at least 2 models, found " + models.size());
        }
        if (weights.length != models.size()) {
            throw new IllegalArgumentException("Must supply one weight per model, models.size() = " + models.size() + ", weights.length = " + weights.length);
        }
        ImmutableOutputInfo<T> outputInfo = models.get(0).getOutputIDInfo();
        ArrayList<Pair> firstList = new ArrayList<Pair>();
        for (Pair pair : outputInfo) {
            firstList.add(pair);
        }
        ArrayList<Pair> comparisonList = new ArrayList<Pair>();
        boolean bl = true;
        while (var7_9 < models.size()) {
            comparisonList.clear();
            for (Pair pair : models.get((int)var7_9).getOutputIDInfo()) {
                comparisonList.add(pair);
            }
            if (!firstList.equals(comparisonList)) {
                throw new IllegalArgumentException("Model output domains are not equal.");
            }
            ++var7_9;
        }
        ImmutableFeatureMap immutableFeatureMap = models.get(0).getFeatureIDMap();
        ArrayList<Model<T>> modelList = new ArrayList<Model<T>>(models);
        TimestampedTrainerProvenance timestampedTrainerProvenance = new TimestampedTrainerProvenance();
        EnsembleModelProvenance provenance = new EnsembleModelProvenance(WeightedEnsembleModel.class.getName(), OffsetDateTime.now(), models.get(0).getProvenance().getDatasetProvenance(), (TrainerProvenance)timestampedTrainerProvenance, (ListProvenance<? extends ModelProvenance>)ListProvenance.createListProvenance(models));
        return new WeightedEnsembleModel<T>(name, provenance, immutableFeatureMap, outputInfo, modelList, combiner, weights);
    }

    @Override
    public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
        ONNXContext onnx = new ONNXContext();
        onnx.setName("WeightedEnsembleModel");
        ONNXPlaceholder input = onnx.floatInput(this.featureIDMap.size());
        ONNXPlaceholder output = onnx.floatOutput(this.outputIDInfo.size());
        this.writeONNXGraph((ONNXRef<?>)input).assignTo((ONNXRef)output);
        return ONNXExportable.buildModel(onnx, domain, modelVersion, this);
    }

    @Override
    public ONNXNode writeONNXGraph(ONNXRef<?> input) {
        ONNXContext onnx = input.onnxContext();
        ONNXInitializer unsqueezeAxes = onnx.array("unsqueeze_ensemble_output", new long[]{2L});
        ArrayList<ONNXNode> unsquuezedMembers = new ArrayList<ONNXNode>();
        for (Model model : this.models) {
            if (model instanceof ONNXExportable) {
                ONNXNode memberOutput = ((ONNXExportable)((Object)model)).writeONNXGraph(input);
                unsquuezedMembers.add(memberOutput.apply(ONNXOperators.UNSQUEEZE, (ONNXRef)unsqueezeAxes));
                continue;
            }
            throw new IllegalStateException("Ensemble member '" + model.toString() + "' is not ONNXExportable.");
        }
        ONNXInitializer ensembleWeights = onnx.array("ensemble_weights", this.weights);
        ONNXNode concat = onnx.operation(ONNXOperators.CONCAT, unsquuezedMembers, "ensemble_concat", Collections.singletonMap("axis", 2));
        return this.combiner.exportCombiner(concat, ensembleWeights);
    }
}

