/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.training;

import java.io.IOException;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.common.ValidationException;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.knn.training.TrainingJob;
import org.opensearch.threadpool.ThreadPool;

/**
 * Runs k-NN model training jobs on the {@code "training"} executor while bounding how many jobs this
 * runner accepts at once.
 *
 * <p>The runner is a lazily created singleton. {@link #initialize(ThreadPool, ModelDao)} must be
 * called before {@link #execute(TrainingJob, ActionListener)}; otherwise the static thread pool and
 * model DAO are {@code null}.
 *
 * <p>Capacity is tracked with a single-permit {@link Semaphore}, so at most one job is in flight at a
 * time. Every accepted job gives its permit back (and decrements the job count) exactly once. This
 * holds whichever path finishes the job: a failed initial write, a rejected training submission, or
 * the end of the background training task.
 */
public class TrainingJobRunner {
    public static Logger logger = LogManager.getLogger(TrainingJobRunner.class);

    /** Lazily created singleton; created under the class monitor in {@link #getInstance()}. */
    private static TrainingJobRunner INSTANCE;
    /** Persists model documents; set by {@link #initialize(ThreadPool, ModelDao)}. */
    private static ModelDao modelDao;
    /** Supplies the {@code "training"} executor; set by {@link #initialize(ThreadPool, ModelDao)}. */
    private static ThreadPool threadPool;

    /** Single permit: only one training job may be in flight at a time. */
    private final Semaphore semaphore;
    /** Number of jobs that have been accepted and have not yet given their capacity back. */
    private final AtomicInteger jobCount = new AtomicInteger(0);

    /**
     * Returns the shared runner, creating it on first use.
     *
     * @return the singleton instance
     */
    public static synchronized TrainingJobRunner getInstance() {
        if (INSTANCE == null) {
            INSTANCE = new TrainingJobRunner();
        }
        return INSTANCE;
    }

    private TrainingJobRunner() {
        this.semaphore = new Semaphore(1);
    }

    /**
     * Supplies the dependencies used by every runner operation. Must be called before
     * {@link #execute(TrainingJob, ActionListener)}.
     *
     * @param threadPool thread pool providing the {@code "training"} executor
     * @param modelDao DAO used to write model documents
     */
    public static void initialize(ThreadPool threadPool, ModelDao modelDao) {
        TrainingJobRunner.threadPool = threadPool;
        TrainingJobRunner.modelDao = modelDao;
    }

    /**
     * Accepts a training job and writes its initial model document. If that write succeeds, the
     * actual training is scheduled in the background.
     *
     * <p>{@code listener} is completed with the result of the initial write only. It is not notified
     * when training finishes; the training outcome is recorded by a later update of the model document.
     *
     * @param trainingJob the job to run
     * @param listener notified with the result of the initial model write
     * @throws ValidationException if this runner has no free capacity; the job is not started
     * @throws IOException if the initial model write throws it synchronously; capacity is released
     *     before rethrowing
     */
    public void execute(TrainingJob trainingJob, ActionListener<IndexResponse> listener) throws IOException {
        if (!semaphore.tryAcquire()) {
            ValidationException exception = new ValidationException();
            exception.addValidationError("Unable to run training job: No training capacity on node.");
            KNNCounter.TRAINING_ERRORS.increment();
            throw exception;
        }
        jobCount.incrementAndGet();

        // Several paths below give the permit back, and they can overlap. For example, if the success
        // handler throws after train() has taken ownership of the permit, ActionListener.wrap (as
        // implemented in the OpenSearch versions reviewed; confirm for the version in use) routes
        // that exception to the failure handler as well. The guard makes the release idempotent,
        // because an extra Semaphore.release() would raise capacity above one job.
        Runnable releaseCapacity = newCapacityReleaser();
        try {
            serializeModel(trainingJob, ActionListener.wrap(indexResponse -> {
                listener.onResponse(indexResponse);
                train(trainingJob, releaseCapacity);
            }, exception -> {
                releaseCapacity.run();
                logger.error("Unable to initialize model serialization: " + exception.getMessage(), exception);
                listener.onFailure(exception);
            }), false);
        } catch (IOException | RuntimeException e) {
            // Unchecked failures (e.g. a null modelDao when initialize() was never called) must not
            // leak the single permit, or this runner could never accept another job.
            releaseCapacity.run();
            throw e;
        }
    }

    /**
     * Submits {@code trainingJob} to the {@code "training"} executor and, after it has run, writes the
     * updated model. If the executor rejects the submission, the model is marked
     * {@link ModelState#FAILED} with an explanatory error, and that state is written instead.
     *
     * <p>This method takes ownership of the job's capacity. Capacity is released when the background
     * task finishes or, on rejection, after the failure has been recorded. Failures while writing the
     * rejection are logged rather than propagated.
     *
     * @param trainingJob the job to run
     * @param releaseCapacity idempotent callback that gives this job's capacity back
     */
    private void train(TrainingJob trainingJob, Runnable releaseCapacity) {
        // Only logs the outcome of the post-training model update; no caller is waiting on it.
        ActionListener<IndexResponse> loggingListener = ActionListener.wrap(
            indexResponse -> logger.debug("[KNN] Model serialization update for \"" + trainingJob.getModelId() + "\" was successful"),
            e -> {
                logger.error("[KNN] Model serialization update for \"" + trainingJob.getModelId() + "\" failed: " + e.getMessage(), e);
                KNNCounter.TRAINING_ERRORS.increment();
            }
        );
        try {
            threadPool.executor("training").execute(() -> {
                try {
                    trainingJob.run();
                    serializeModel(trainingJob, loggingListener, true);
                } catch (IOException e) {
                    logger.error("Unable to serialize model \"" + trainingJob.getModelId() + "\": " + e.getMessage(), e);
                    KNNCounter.TRAINING_ERRORS.increment();
                } catch (Exception e) {
                    logger.error("Unable to complete training for \"" + trainingJob.getModelId() + "\": " + e.getMessage(), e);
                    KNNCounter.TRAINING_ERRORS.increment();
                } finally {
                    releaseCapacity.run();
                }
            });
        } catch (RejectedExecutionException ree) {
            logger.error("Unable to train model \"" + trainingJob.getModelId() + "\": " + ree.getMessage());
            ModelMetadata modelMetadata = trainingJob.getModel().getModelMetadata();
            modelMetadata.setState(ModelState.FAILED);
            modelMetadata.setError("Training job execution was rejected. Node's training queue is at capacity.");
            try {
                serializeModel(trainingJob, loggingListener, true);
            } catch (Exception e) {
                // Catch everything here: an exception escaping after the permit is released would
                // otherwise reach execute()'s failure handler and complete the caller's listener twice.
                logger.error("Unable to serialize the failure for model \"" + trainingJob.getModelId() + "\": " + e, e);
            } finally {
                releaseCapacity.run();
                KNNCounter.TRAINING_ERRORS.increment();
            }
        }
    }

    /**
     * Creates a callback that decrements {@link #jobCount} and releases one {@link #semaphore} permit
     * the first time it runs. Later calls do nothing.
     *
     * @return a release callback dedicated to a single job
     */
    private Runnable newCapacityReleaser() {
        AtomicBoolean released = new AtomicBoolean(false);
        return () -> {
            if (released.compareAndSet(false, true)) {
                jobCount.decrementAndGet();
                semaphore.release();
            }
        };
    }

    /**
     * Writes the job's model through {@link ModelDao}.
     *
     * @param trainingJob job whose model is written
     * @param listener notified with the result of the write
     * @param update {@code true} calls {@code ModelDao.update}, {@code false} calls {@code ModelDao.put}
     * @throws IOException if the DAO call throws it synchronously
     */
    private void serializeModel(TrainingJob trainingJob, ActionListener<IndexResponse> listener, boolean update) throws IOException {
        if (update) {
            modelDao.update(trainingJob.getModel(), listener);
        } else {
            modelDao.put(trainingJob.getModel(), listener);
        }
    }

    /**
     * Returns the number of jobs that have been accepted and have not yet given their capacity back.
     *
     * @return the current job count
     */
    public int getJobCount() {
        return jobCount.get();
    }
}

