/*
 * Decompiled with CFR 0.152.
 */
package dev.aisandbox.server.simulation.bandit;

import dev.aisandbox.server.engine.Agent;
import dev.aisandbox.server.engine.Simulation;
import dev.aisandbox.server.engine.Theme;
import dev.aisandbox.server.engine.exception.IllegalActionException;
import dev.aisandbox.server.engine.exception.SimulationException;
import dev.aisandbox.server.engine.output.OutputConstants;
import dev.aisandbox.server.engine.output.OutputRenderer;
import dev.aisandbox.server.engine.widget.RollingStatisticsWidget;
import dev.aisandbox.server.engine.widget.RollingValueChartWidget;
import dev.aisandbox.server.engine.widget.TextWidget;
import dev.aisandbox.server.engine.widget.TitleWidget;
import dev.aisandbox.server.simulation.bandit.BanditWidget;
import dev.aisandbox.server.simulation.bandit.model.Bandit;
import dev.aisandbox.server.simulation.bandit.model.BanditNormalEnumeration;
import dev.aisandbox.server.simulation.bandit.model.BanditStdEnumeration;
import dev.aisandbox.server.simulation.bandit.model.BanditUpdateEnumeration;
import dev.aisandbox.server.simulation.bandit.proto.BanditAction;
import dev.aisandbox.server.simulation.bandit.proto.BanditResult;
import dev.aisandbox.server.simulation.bandit.proto.BanditState;
import dev.aisandbox.server.simulation.bandit.proto.Signal;
import java.awt.Graphics2D;
import java.awt.Image;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Random;
import java.util.UUID;
import java.util.stream.IntStream;
import lombok.Generated;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class BanditRuntime
implements Simulation {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(BanditRuntime.class);
    private static final int BANDIT_WIDTH = 1370;
    private static final int BANDIT_HEIGHT = 518;
    private static final int RESULTS_WIDTH = 573;
    private static final int RESULTS_HEIGHT = 312;
    private final Agent agent;
    private final Random random;
    private final int banditCount;
    private final int pullCount;
    private final BanditNormalEnumeration normal;
    private final BanditStdEnumeration std;
    private final BanditUpdateEnumeration updateRule;
    private final Theme theme;
    private final String sessionID = UUID.randomUUID().toString();
    private final List<Bandit> bandits = new ArrayList<Bandit>();
    private final TextWidget logWidget;
    private final BanditWidget banditWidget;
    private final RollingValueChartWidget episodeScoreWidget;
    private final RollingValueChartWidget episodeSuccessWidget;
    private final RollingStatisticsWidget statisticsWidget;
    private final TitleWidget titleWidget;
    private int sessionStep = 0;
    private double episodeScore = 0.0;
    private double episodeBestMoveCount = 0.0;
    private String episodeID = UUID.randomUUID().toString();

    public BanditRuntime(Agent agent, Random random, int banditCount, int pullCount, BanditNormalEnumeration normal, BanditStdEnumeration std, BanditUpdateEnumeration updateRule, Theme theme) {
        this.agent = agent;
        this.random = random;
        this.banditCount = banditCount;
        this.pullCount = pullCount;
        this.normal = normal;
        this.std = std;
        this.updateRule = updateRule;
        this.theme = theme;
        this.titleWidget = TitleWidget.builder().theme(theme).title("Multi-armed Bandit").build();
        this.logWidget = TextWidget.builder().width(573).height(312).font(OutputConstants.LOG_FONT).theme(theme).build();
        this.banditWidget = BanditWidget.builder().width(1370).height(518).theme(theme).build();
        this.episodeScoreWidget = RollingValueChartWidget.builder().width(573).height(312).window(200).theme(theme).title("Score per episode").build();
        this.episodeSuccessWidget = RollingValueChartWidget.builder().width(573).height(312).window(200).theme(theme).title("% best moves per episode").build();
        this.statisticsWidget = RollingStatisticsWidget.builder().width(400).height(518).theme(theme).windowSize(200).build();
        this.initialise();
    }

    public void initialise() {
        this.bandits.clear();
        for (int i = 0; i < this.banditCount; ++i) {
            this.bandits.add(new Bandit(this.normal.getNormalValue(this.random), this.std.getValue()));
        }
        this.sessionStep = 0;
        this.episodeID = UUID.randomUUID().toString();
        this.episodeScore = 0.0;
        this.episodeBestMoveCount = 0.0;
    }

    @Override
    public void step(OutputRenderer output) throws SimulationException {
        ++this.sessionStep;
        log.debug("Starting step {}", (Object)this.sessionStep);
        int bestPull = IntStream.range(0, this.bandits.size()).boxed().max(Comparator.comparingDouble(i -> this.bandits.get((int)i).getStd())).orElse(-1);
        this.agent.send(this.getState());
        BanditAction action = this.agent.receive(BanditAction.class);
        int arm = action.getArm();
        log.debug("Received request to pull arm {}", (Object)arm);
        if (arm < 0 || arm >= this.bandits.size()) {
            throw new IllegalActionException("Invalid arm.");
        }
        if (arm == bestPull) {
            this.episodeBestMoveCount += 1.0;
        }
        double score = this.bandits.get(arm).pull(this.random);
        this.episodeScore += score;
        this.logWidget.addText(this.agent.getAgentName() + " selects bandit " + arm + " gets reward " + String.format("%.4f", score));
        boolean reset = this.sessionStep == this.pullCount;
        this.banditWidget.setBandits(this.bandits, arm);
        if (reset) {
            this.episodeScoreWidget.addValue(this.episodeScore);
            this.episodeSuccessWidget.addValue(this.episodeBestMoveCount / (double)this.pullCount);
            this.statisticsWidget.addScore(this.episodeScore);
        }
        output.display();
        this.agent.send(BanditResult.newBuilder().setArm(arm).setScore(score).setSignal(reset ? Signal.RESET : Signal.CONTINUE).build());
        if (reset) {
            this.logWidget.addText("pull count reached, starting a new game.");
            this.initialise();
        } else {
            switch (this.updateRule) {
                case RANDOM: {
                    this.updateRandom();
                    break;
                }
                case EQUALISE: {
                    this.updateEqualise(arm);
                    break;
                }
                case FADE: {
                    this.updateFade(arm);
                    break;
                }
            }
        }
    }

    private BanditState getState() {
        BanditState.Builder builder = BanditState.newBuilder();
        builder.setSessionID(this.sessionID);
        builder.setEpisodeID(this.episodeID);
        builder.setBanditCount(this.bandits.size());
        builder.setPull(this.sessionStep);
        builder.setPullCount(this.pullCount);
        return builder.build();
    }

    public void updateRandom() {
        for (Bandit b : this.bandits) {
            b.setMean(b.getMean() + this.random.nextGaussian() * 0.001);
        }
    }

    public void updateEqualise(int chosen) {
        double reward = 0.001 / (double)this.bandits.size();
        for (int i = 0; i < this.bandits.size(); ++i) {
            Bandit b = this.bandits.get(i);
            if (i == chosen) {
                b.setMean(b.getMean() - 0.001);
                continue;
            }
            b.setMean(b.getMean() + reward);
        }
    }

    public void updateFade(int chosen) {
        Bandit target = this.bandits.get(chosen);
        target.setMean(target.getMean() - 0.001);
    }

    @Override
    public void visualise(Graphics2D graphics2D) {
        graphics2D.setColor(this.theme.getBase());
        graphics2D.fillRect(0, 0, 1920, 1080);
        graphics2D.drawImage((Image)this.titleWidget.getImage(), 0, 50, null);
        graphics2D.drawImage((Image)this.banditWidget.getImage(), 50, 150, null);
        graphics2D.drawImage((Image)this.statisticsWidget.getImage(), 1470, 150, null);
        graphics2D.drawImage((Image)this.episodeScoreWidget.getImage(), 50, 718, null);
        graphics2D.drawImage((Image)this.episodeSuccessWidget.getImage(), 673, 718, null);
        graphics2D.drawImage((Image)this.logWidget.getImage(), 1296, 718, null);
        graphics2D.drawImage((Image)this.theme.getLogoImage(), 1779, 21, null);
    }
}

