package com.bsc36.project11cs.models;

import com.bsc36.project11cs.config.BasicConfig;

import java.util.*;


/**
 * Genetic algorithm that evolves a population of {@link Individual}s for a {@code KnapsackGA}
 * and writes the chromosomes of the best individual found back via
 * {@code KnapsackGA#updateCargoSpace}.
 */
public class GeneticAlgo {
    /** Number of individuals in the population. */
    private static final int POPULATION_SIZE = 300;
    /** Fraction of the ranked population copied unchanged into the next generation. */
    private static final double ELITE_FRACTION = 0.1;
    /** Fraction of the ranked population kept when the population is reinitialized. */
    private static final double SURVIVOR_FRACTION = 0.15;
    /** Consecutive non-improving generations after which the population is reinitialized. */
    private static final int MAX_PLATEAU_THRESHOLD = 15;
    /** Fitness-variance threshold used to decide whether to raise or reset the mutation rate. */
    private static final double VARIANCE_THRESHOLD = 10.0;
    /** Upper bound for the adaptively increased mutation rate. */
    private static final double MAX_MUTATION_RATE = 0.5;
    /** Mutation rate used until {@link #setMutationRate(double)} is called. */
    private static final double DEFAULT_MUTATION_RATE = 0.2;

    /** Orders individuals from highest to lowest score. */
    private static final Comparator<Individual> BEST_FIRST =
            (a, b) -> Integer.compare(b.getScore(), a.getScore());

    private final KnapsackGA knapsackGA;
    private final int maxRounds = BasicConfig.GA_MAX_GENERATION;
    private final Random rn = new Random();

    private List<Individual> populationList = new ArrayList<>();
    /** Mutation rate requested by the caller; the adaptive scheme returns to this value. */
    private double baseMutationRate = DEFAULT_MUTATION_RATE;
    /** Current (possibly adapted) mutation rate passed to newly constructed individuals. */
    private double mutationRate = DEFAULT_MUTATION_RATE;
    private int previousMax = 0;
    private int plateauCounter = 0;

    /**
     * Constructor
     * @param knapsackGA the knapsackGA to use
     */
    public GeneticAlgo(KnapsackGA knapsackGA) {
        this.knapsackGA = knapsackGA;
    }

    /**
     * Sets the mutation rate passed to newly created individuals. The value also becomes the
     * baseline that the adaptive scheme falls back to when fitness variance is high.
     *
     * @param newMutation the new mutation rate
     */
    public void setMutationRate(double newMutation) {
        this.baseMutationRate = newMutation;
        this.mutationRate = newMutation;
    }

    /**
     * Runs the genetic algorithm (entry point).
     *
     * <p>Clears the cargo space, evolves the population for {@code maxRounds} generations and
     * finally writes the chromosomes of the highest-scoring individual back to the knapsack.
     */
    public void run() {
        knapsackGA.cargoSpace.clearCargoSpace();
        resetSearchState();
        initializePopulation();
        sortIndividuals(populationList);
        for (int generation = 0; generation < maxRounds; generation++) {
            newGeneration();
            // Children are appended unranked; re-sort so index 0 is the true current best and
            // the plateau check / survivor selection operate on a ranked population.
            sortIndividuals(populationList);
            if (plateaus()) {
                reInitializePopulation();
                // Fresh random individuals were appended at the end; rank them before they take
                // part in elitism and parent selection in the next generation.
                sortIndividuals(populationList);
            }
        }
        knapsackGA.updateCargoSpace(populationList.get(0).getChromosomes());
    }

    /**
     * Clears plateau tracking and restores the baseline mutation rate so that every call to
     * {@link #run()} starts from the same state.
     */
    private void resetSearchState() {
        previousMax = 0;
        plateauCounter = 0;
        mutationRate = baseMutationRate;
    }

    /**
     * Replaces the population with {@code POPULATION_SIZE} freshly constructed individuals.
     */
    private void initializePopulation() {
        populationList.clear();
        for (int i = 0; i < POPULATION_SIZE; i++) {
            populationList.add(new Individual(knapsackGA, mutationRate));
        }
    }

    /**
     * Sorts the individuals in place, highest score first.
     *
     * @param individuals the population to sort
     */
    private void sortIndividuals(List<Individual> individuals) {
        individuals.sort(BEST_FIRST);
    }

    /**
     * Keeps the top {@code SURVIVOR_FRACTION} of the (ranked) population and refills the rest
     * with freshly constructed individuals up to {@code POPULATION_SIZE}.
     */
    private void reInitializePopulation() {
        int survivors = (int) (populationList.size() * SURVIVOR_FRACTION);
        populationList.subList(survivors, populationList.size()).clear();
        while (populationList.size() < POPULATION_SIZE) {
            populationList.add(new Individual(knapsackGA, mutationRate));
        }
    }

    /**
     * Builds the next generation: the top {@code ELITE_FRACTION} of the ranked population is
     * copied unchanged, and the remaining slots are filled with two-point-crossover children of
     * parents drawn uniformly from the better half. The population size is preserved exactly.
     */
    private void newGeneration() {
        int size = populationList.size();
        int eliteCount = (int) (size * ELITE_FRACTION);
        // Random.nextInt requires a positive bound.
        int parentPool = Math.max(1, size / 2);

        List<Individual> next = new ArrayList<>(size);
        next.addAll(populationList.subList(0, eliteCount));
        for (int i = eliteCount; i < size; i++) {
            Individual parent1 = populationList.get(rn.nextInt(parentPool));
            Individual parent2 = populationList.get(rn.nextInt(parentPool));
            next.add(parent1.twoPointCrossover(parent2));
        }
        populationList = next;
    }

    /**
     * Updates plateau tracking from the current best score, adapts the mutation rate from the
     * population's fitness variance, and reports whether the search has stalled.
     *
     * <p>Assumes the population is sorted best-first.
     *
     * @return true if the best score has not improved for {@code MAX_PLATEAU_THRESHOLD}
     *     consecutive checks (the counter is reset when this happens)
     */
    private boolean plateaus() {
        int currentMax = populationList.get(0).getScore();
        if (currentMax <= previousMax) {
            plateauCounter++;
        } else {
            plateauCounter = 0;
        }
        previousMax = currentMax;

        adaptMutationRate(calculateFitnessVariance());

        if (plateauCounter >= MAX_PLATEAU_THRESHOLD) {
            plateauCounter = 0;
            return true;
        }
        return false;
    }

    /**
     * Doubles the mutation rate (capped at {@code MAX_MUTATION_RATE}) when variance is below the
     * threshold, and restores the caller's baseline rate when it is above. A variance exactly
     * equal to the threshold leaves the rate unchanged.
     *
     * @param variance population fitness variance
     */
    private void adaptMutationRate(double variance) {
        if (variance < VARIANCE_THRESHOLD) {
            mutationRate = Math.min(MAX_MUTATION_RATE, mutationRate * 2);
        } else if (variance > VARIANCE_THRESHOLD) {
            mutationRate = baseMutationRate;
        }
    }

    /**
     * Computes the population variance of the individuals' scores.
     *
     * @return the variance, or 0.0 for an empty population
     */
    private double calculateFitnessVariance() {
        double mean = populationList.stream()
                .mapToInt(Individual::getScore)
                .average()
                .orElse(0.0);
        return populationList.stream()
                .mapToDouble(ind -> {
                    double deviation = ind.getScore() - mean;
                    return deviation * deviation;
                })
                .average()
                .orElse(0.0);
    }
}