/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.randomcutforest.executor;

import com.amazon.randomcutforest.ComponentList;
import com.amazon.randomcutforest.IMultiVisitorFactory;
import com.amazon.randomcutforest.IVisitorFactory;
import com.amazon.randomcutforest.executor.AbstractForestTraversalExecutor;
import com.amazon.randomcutforest.returntypes.ConvergingAccumulator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Collectors;

/**
 * Forest traversal executor that traverses the components of the forest in parallel.
 *
 * <p>Every traversal is wrapped in a task that is submitted to a private {@link ForkJoinPool}.
 * The pool's parallelism is {@code threadPoolSize}. The traversals are written as parallel
 * streams over {@code components}. Running them from inside a task of this pool is meant to keep
 * their work on this pool rather than on {@link ForkJoinPool#commonPool()}. Note that this relies
 * on the JDK's fork/join implementation, not on a documented {@code Stream} guarantee.
 */
public class ParallelForestTraversalExecutor
extends AbstractForestTraversalExecutor {
    /** Pool that runs every traversal; created in the constructor, re-created lazily if null. */
    private ForkJoinPool forkJoinPool;

    /** Parallelism of {@link #forkJoinPool}; also the batch size for converging traversals. */
    private final int threadPoolSize;

    /**
     * Creates an executor over the given components, backed by a new fork/join pool.
     *
     * @param treeExecutors  the components to traverse; handed to the superclass constructor
     * @param threadPoolSize parallelism of the backing pool; {@link ForkJoinPool#ForkJoinPool(int)}
     *                       throws {@link IllegalArgumentException} if it is not positive
     */
    public ParallelForestTraversalExecutor(ComponentList<?, ?> treeExecutors, int threadPoolSize) {
        super(treeExecutors);
        this.threadPoolSize = threadPoolSize;
        this.forkJoinPool = new ForkJoinPool(threadPoolSize);
    }

    /**
     * Traverses every component in parallel and reduces the per-component results with
     * {@code accumulator}. It then applies {@code finisher} to the reduced value.
     *
     * @param point          the query point passed to each component traversal
     * @param visitorFactory creates the visitor used by each component traversal
     * @param accumulator    combines two per-component results
     * @param finisher       maps the combined result to the returned value
     * @return the finished result
     * @throws IllegalStateException if the reduction produced no value (no components), or if
     *                               {@code finisher} returned {@code null} ({@code Optional.map}
     *                               turns that into an empty result)
     */
    @Override
    public <R, S> S traverseForest(double[] point, IVisitorFactory<R> visitorFactory,
            BinaryOperator<R> accumulator, Function<R, S> finisher) {
        return submitAndJoin(() -> components.parallelStream()
                .map(c -> c.traverse(point, visitorFactory))
                .reduce(accumulator)
                .map(finisher))
                .orElseThrow(() -> new IllegalStateException("accumulator returned an empty result"));
    }

    /**
     * Traverses every component in parallel and collects the per-component results with
     * {@code collector}.
     *
     * @param point          the query point passed to each component traversal
     * @param visitorFactory creates the visitor used by each component traversal
     * @param collector      collects the per-component results into the returned value
     * @return the collector's result
     */
    @Override
    public <R, S> S traverseForest(double[] point, IVisitorFactory<R> visitorFactory,
            Collector<R, ?, S> collector) {
        return submitAndJoin(() -> components.parallelStream()
                .map(c -> c.traverse(point, visitorFactory))
                .collect(collector));
    }

    /**
     * Traverses the components in consecutive batches of at most {@code threadPoolSize}.
     * Each batch is traversed in parallel. After each batch, its results are fed to
     * {@code accumulator} in list order. The loop stops early once the accumulator reports
     * convergence, so later batches are never traversed.
     *
     * @param point          the query point passed to each component traversal
     * @param visitorFactory creates the visitor used by each component traversal
     * @param accumulator    receives every per-component result and decides when to stop
     * @param finisher       maps the accumulated value to the returned value
     * @return {@code finisher} applied to the accumulator's value once convergence is reached or
     *         all components have been traversed
     */
    @Override
    public <R, S> S traverseForest(double[] point, IVisitorFactory<R> visitorFactory,
            ConvergingAccumulator<R> accumulator, Function<R, S> finisher) {
        for (int batchStart = 0; batchStart < components.size(); batchStart += threadPoolSize) {
            // Lambda captures must be effectively final, hence the copies of the loop bounds.
            int from = batchStart;
            int to = Math.min(from + threadPoolSize, components.size());
            List<R> batchResults = submitAndJoin(() -> components.subList(from, to)
                    .parallelStream()
                    .map(c -> c.traverse(point, visitorFactory))
                    .collect(Collectors.toList()));
            batchResults.forEach(accumulator::accept);
            if (accumulator.isConverged()) {
                break;
            }
        }
        return finisher.apply(accumulator.getAccumulatedValue());
    }

    /**
     * Runs the multi-visitor traversal on every component in parallel and reduces the results
     * with {@code accumulator}. It then applies {@code finisher} to the reduced value.
     *
     * @param point          the query point passed to each component traversal
     * @param visitorFactory creates the multi-visitor used by each component traversal
     * @param accumulator    combines two per-component results
     * @param finisher       maps the combined result to the returned value
     * @return the finished result
     * @throws IllegalStateException if the reduction produced no value (no components), or if
     *                               {@code finisher} returned {@code null}
     */
    @Override
    public <R, S> S traverseForestMulti(double[] point, IMultiVisitorFactory<R> visitorFactory,
            BinaryOperator<R> accumulator, Function<R, S> finisher) {
        return submitAndJoin(() -> components.parallelStream()
                .map(c -> c.traverseMulti(point, visitorFactory))
                .reduce(accumulator)
                .map(finisher))
                .orElseThrow(() -> new IllegalStateException("accumulator returned an empty result"));
    }

    /**
     * Runs the multi-visitor traversal on every component in parallel and collects the results
     * with {@code collector}.
     *
     * @param point          the query point passed to each component traversal
     * @param visitorFactory creates the multi-visitor used by each component traversal
     * @param collector      collects the per-component results into the returned value
     * @return the collector's result
     */
    @Override
    public <R, S> S traverseForestMulti(double[] point, IMultiVisitorFactory<R> visitorFactory,
            Collector<R, ?, S> collector) {
        return submitAndJoin(() -> components.parallelStream()
                .map(c -> c.traverseMulti(point, visitorFactory))
                .collect(collector));
    }

    /**
     * Submits {@code callable} to the backing pool and blocks until it completes.
     *
     * <p>Per {@link java.util.concurrent.ForkJoinTask#join()}, a failure inside the task is
     * rethrown as an unchecked exception rather than as {@code ExecutionException}.
     *
     * @param callable the work to run inside the pool
     * @param <T>      the result type
     * @return the value returned by {@code callable}
     */
    private <T> T submitAndJoin(Callable<T> callable) {
        // NOTE(review): the constructor always creates the pool, so this branch only matters if
        // the field is cleared elsewhere (e.g. by a serialization framework) — confirm. The check
        // is not synchronized; concurrent first calls could each create a pool.
        if (forkJoinPool == null) {
            forkJoinPool = new ForkJoinPool(threadPoolSize);
        }
        return forkJoinPool.submit(callable).join();
    }
}

