/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.hyperparameter;

import ai.djl.Model;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.TrainingResult;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.hyperparameter.optimizer.HpORandom;
import ai.djl.training.hyperparameter.param.HpSet;
import ai.djl.translate.TranslateException;
import ai.djl.util.Pair;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Template for a simple random-search hyperparameter optimization loop.
 *
 * <p>Subclasses supply the search space, datasets, model factory, training configuration, input
 * shape and epoch count. {@link #fit()} evaluates {@link #numHyperParameterTests()} configurations
 * drawn from an {@link HpORandom} optimizer. It scores each one by the validation loss reported in
 * its {@link TrainingResult}, then retrains with the optimizer's best configuration and passes that
 * model to {@link #saveModel(Model, TrainingResult)}.
 */
public abstract class EasyHpo {
    private static final Logger logger = LoggerFactory.getLogger(EasyHpo.class);

    /**
     * Runs the hyperparameter search and trains a final model with the best configuration found.
     *
     * <p>Models trained during the search are closed as soon as their loss has been recorded. The
     * model from the final run is not closed and is owned by the caller. If {@link #saveModel}
     * fails, that model is closed before the exception propagates.
     *
     * <p>NOTE(review): when {@link #numHyperParameterTests()} returns 0 or less, no trial runs and
     * the outcome depends on how {@code HpORandom.getBest()} handles an empty history. Confirm.
     *
     * @return the final model and its training result
     * @throws IOException if loading data, training, or saving fails with an I/O error
     * @throws TranslateException if training fails while translating data
     * @throws IllegalStateException if a trial produced no validation loss to score it by
     */
    public Pair<Model, TrainingResult> fit() throws IOException, TranslateException {
        RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN);
        // The TEST split is used as the validation set whose loss scores each trial.
        RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST);
        HpSet hyperParams = setupHyperParams();
        HpORandom hpOptimizer = new HpORandom(hyperParams);
        int hyperparameterTests = numHyperParameterTests();
        for (int i = 0; i < hyperparameterTests; ++i) {
            HpSet hpVals = hpOptimizer.nextConfig();
            Pair<Model, TrainingResult> trial = train(hpVals, trainingSet, validateSet);
            // Trial models are discarded; only their TrainingResult is needed for scoring.
            trial.getKey().close();
            Float validateLoss = trial.getValue().getValidateLoss();
            if (validateLoss == null) {
                throw new IllegalStateException(
                        "No validation loss recorded for hyperparameters "
                                + hpVals
                                + " (does getDataset(Dataset.Usage.TEST) return a validation set?)");
            }
            float loss = validateLoss;
            hpOptimizer.update(hpVals, loss);
            logger.info(
                    "--------- hp test {}/{} - Loss {} - {}", i, hyperparameterTests, loss, hpVals);
        }
        HpSet bestHpVals = hpOptimizer.getBest().getKey();
        Pair<Model, TrainingResult> trained = train(bestHpVals, trainingSet, validateSet);
        Model model = trained.getKey();
        try {
            saveModel(model, trained.getValue());
        } catch (Throwable t) {
            // The caller never receives the model if saving fails, so release it here.
            closeAfterFailure(model, t);
            throw t;
        }
        return trained;
    }

    /**
     * Builds, initializes and trains a fresh model for one hyperparameter configuration.
     *
     * <p>The trainer is always closed. The model is returned open on success and closed if
     * anything fails.
     *
     * @param hpVals the hyperparameter values to train with
     * @param trainingSet the dataset to train on
     * @param validateSet the dataset to validate on
     * @return the trained model and its training result
     * @throws IOException if training fails with an I/O error
     * @throws TranslateException if training fails while translating data
     */
    private Pair<Model, TrainingResult> train(
            HpSet hpVals, RandomAccessDataset trainingSet, RandomAccessDataset validateSet)
            throws IOException, TranslateException {
        Model model = buildModel(hpVals);
        try {
            TrainingConfig config = setupTrainingConfig(hpVals);
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(new Metrics());
                trainer.initialize(inputShape(hpVals));
                EasyTrain.fit(trainer, numEpochs(hpVals), trainingSet, validateSet);
                return new Pair<>(model, trainer.getTrainingResult());
            }
        } catch (Throwable t) {
            // Without this, a failure anywhere above would leak the freshly built model.
            closeAfterFailure(model, t);
            throw t;
        }
    }

    /**
     * Closes {@code model} after {@code failure} occurred. A runtime exception from closing is
     * attached as suppressed so it does not mask the original failure.
     */
    private static void closeAfterFailure(Model model, Throwable failure) {
        try {
            model.close();
        } catch (RuntimeException closeFailure) {
            failure.addSuppressed(closeFailure);
        }
    }

    /**
     * Returns the hyperparameter search space.
     *
     * <p>Called once per {@link #fit()}; the result seeds the random optimizer.
     *
     * @return the hyperparameters to search over
     */
    protected abstract HpSet setupHyperParams();

    /**
     * Returns the dataset for the given usage.
     *
     * <p>{@link #fit()} requests {@code Dataset.Usage.TRAIN} for training and {@code
     * Dataset.Usage.TEST} as the validation set.
     *
     * @param usage which split to return
     * @return the dataset for {@code usage}
     * @throws IOException if the dataset cannot be loaded
     */
    protected abstract RandomAccessDataset getDataset(Dataset.Usage usage) throws IOException;

    /**
     * Creates the training configuration used to build the trainer for one run.
     *
     * @param hpVals the hyperparameter values for this run
     * @return the configuration passed to {@code Model.newTrainer}
     */
    protected abstract TrainingConfig setupTrainingConfig(HpSet hpVals);

    /**
     * Builds a new, untrained model for one run.
     *
     * <p>Called once per trial and once more for the final run.
     *
     * @param hpVals the hyperparameter values for this run
     * @return a new model
     */
    protected abstract Model buildModel(HpSet hpVals);

    /**
     * Returns the input shape used to initialize the trainer.
     *
     * @param hpVals the hyperparameter values for this run
     * @return the shape passed to {@code Trainer.initialize}
     */
    protected abstract Shape inputShape(HpSet hpVals);

    /**
     * Returns the number of epochs to train for.
     *
     * @param hpVals the hyperparameter values for this run
     * @return the epoch count passed to {@code EasyTrain.fit}
     */
    protected abstract int numEpochs(HpSet hpVals);

    /**
     * Returns how many configurations to evaluate before the final run.
     *
     * @return the number of search trials
     */
    protected abstract int numHyperParameterTests();

    /**
     * Hook invoked with the final model and its training result. The default does nothing.
     *
     * @param model the model trained with the best hyperparameters found
     * @param result the training result of that model
     * @throws IOException if saving fails
     */
    protected void saveModel(Model model, TrainingResult result) throws IOException {
    }
}

