/*
 * Decompiled with CFR 0.152.
 */
package dev.aisandbox.server.simulation.bandit;

import dev.aisandbox.server.engine.Agent;
import dev.aisandbox.server.engine.Simulation;
import dev.aisandbox.server.engine.SimulationBuilder;
import dev.aisandbox.server.engine.SimulationParameter;
import dev.aisandbox.server.engine.Theme;
import dev.aisandbox.server.simulation.bandit.BanditRuntime;
import dev.aisandbox.server.simulation.bandit.model.BanditCountEnumeration;
import dev.aisandbox.server.simulation.bandit.model.BanditNormalEnumeration;
import dev.aisandbox.server.simulation.bandit.model.BanditPullEnumeration;
import dev.aisandbox.server.simulation.bandit.model.BanditStdEnumeration;
import dev.aisandbox.server.simulation.bandit.model.BanditUpdateEnumeration;
import java.util.List;
import java.util.Random;
import lombok.Generated;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link SimulationBuilder} for the multi-armed bandit simulation.
 *
 * <p>Holds the configurable options and turns them into a {@link BanditRuntime} in
 * {@link #build(List, Theme, Random)}. The options are:
 *
 * <ul>
 *   <li>the number of bandits;
 *   <li>how bandits change between pulls;
 *   <li>how each bandit's normal and standard deviation are chosen;
 *   <li>the number of pulls per episode.
 * </ul>
 *
 * <p>Exactly one agent is supported.
 *
 * <p>NOTE(review): the parameter names returned by {@link #getParameters()} match the bean
 * properties exposed by the setters/getters below. Presumably the engine binds parameter values
 * through them by name; confirm before renaming either side.
 */
public final class BanditScenario implements SimulationBuilder {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(BanditScenario.class);

    /** The number of bandit "pulls" in each episode. */
    private BanditPullEnumeration banditPulls = BanditPullEnumeration.ONE_HUNDRED;

    /** How the normal (average) for each bandit is chosen. */
    private BanditNormalEnumeration banditNormal = BanditNormalEnumeration.NORMAL_0_1;

    /** How the standard deviation for each bandit is chosen. */
    private BanditStdEnumeration banditStd = BanditStdEnumeration.ONE;

    /** How bandits change between pulls. */
    private BanditUpdateEnumeration banditUpdate = BanditUpdateEnumeration.FIXED;

    /** The number of bandits. */
    private BanditCountEnumeration banditCount = BanditCountEnumeration.FIVE;

    /** Returns the display name of this simulation. */
    @Override
    public String getSimulationName() {
        return "Bandit";
    }

    /** Returns a one-sentence, user-facing description of this simulation. */
    @Override
    public String getDescription() {
        // Fixed: the original text opened a quote before "Multi-Armed Bandit" but never closed it.
        return "The classic 'Multi-Armed Bandit' scenario where an agent needs to learn which 'bandit' returns the best results.";
    }

    /**
     * Describes the configurable options of this simulation.
     *
     * <p>Each parameter name corresponds to a field (and its getter/setter) on this class. Each
     * value type is the enumeration of allowed choices.
     *
     * @return an immutable list of the five bandit parameters
     */
    @Override
    public List<SimulationParameter> getParameters() {
        return List.of(
                new SimulationParameter("banditCount", "The number of bandits", BanditCountEnumeration.class),
                new SimulationParameter("banditUpdate", "How bandits change between pulls", BanditUpdateEnumeration.class),
                new SimulationParameter("banditStd", "How the standard deviation for each bandit is chosen", BanditStdEnumeration.class),
                new SimulationParameter("banditNormal", "How the normal (average) for each bandit is chosen", BanditNormalEnumeration.class),
                new SimulationParameter("banditPulls", "The number of bandit 'pulls' in each episode", BanditPullEnumeration.class));
    }

    /** Returns the minimum number of agents (always 1). */
    @Override
    public int getMinAgentCount() {
        return 1;
    }

    /** Returns the maximum number of agents (always 1). */
    @Override
    public int getMaxAgentCount() {
        return 1;
    }

    /**
     * Returns the display names of the agents.
     *
     * @param agentCount ignored; this simulation always has exactly one agent
     * @return a single-element array containing {@code "Agent 1"}
     */
    @Override
    public String[] getAgentNames(int agentCount) {
        return new String[]{"Agent 1"};
    }

    /**
     * Creates a bandit simulation from the current configuration.
     *
     * @param agents the participating agents; only the first is used
     * @param theme  the theme passed through to the runtime
     * @param random the random source passed through to the runtime
     * @return a new {@link BanditRuntime}
     * @throws java.util.NoSuchElementException if {@code agents} is empty (from {@code getFirst()})
     */
    @Override
    public Simulation build(List<Agent> agents, Theme theme, Random random) {
        return new BanditRuntime(
                agents.getFirst(),
                random,
                this.banditCount.getNumber(),
                this.banditPulls.getNumber(),
                this.banditNormal,
                this.banditStd,
                this.banditUpdate,
                theme);
    }

    /** Sets the number of pulls per episode. */
    @Generated
    public void setBanditPulls(BanditPullEnumeration banditPulls) {
        this.banditPulls = banditPulls;
    }

    /** Sets how the normal (average) for each bandit is chosen. */
    @Generated
    public void setBanditNormal(BanditNormalEnumeration banditNormal) {
        this.banditNormal = banditNormal;
    }

    /** Sets how the standard deviation for each bandit is chosen. */
    @Generated
    public void setBanditStd(BanditStdEnumeration banditStd) {
        this.banditStd = banditStd;
    }

    /** Sets how bandits change between pulls. */
    @Generated
    public void setBanditUpdate(BanditUpdateEnumeration banditUpdate) {
        this.banditUpdate = banditUpdate;
    }

    /** Sets the number of bandits. */
    @Generated
    public void setBanditCount(BanditCountEnumeration banditCount) {
        this.banditCount = banditCount;
    }

    /** Returns the configured number of pulls per episode. */
    @Generated
    public BanditPullEnumeration getBanditPulls() {
        return this.banditPulls;
    }

    /** Returns how the normal (average) for each bandit is chosen. */
    @Generated
    public BanditNormalEnumeration getBanditNormal() {
        return this.banditNormal;
    }

    /** Returns how the standard deviation for each bandit is chosen. */
    @Generated
    public BanditStdEnumeration getBanditStd() {
        return this.banditStd;
    }

    /** Returns how bandits change between pulls. */
    @Generated
    public BanditUpdateEnumeration getBanditUpdate() {
        return this.banditUpdate;
    }

    /** Returns the configured number of bandits. */
    @Generated
    public BanditCountEnumeration getBanditCount() {
        return this.banditCount;
    }
}

