/*
 * Decompiled with CFR 0.152.
 */
package org.apache.druid.indexing.seekablestream.supervisor.autoscaler;

import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.apache.druid.indexing.overlord.supervisor.SupervisorSpec;
import org.apache.druid.indexing.overlord.supervisor.autoscaler.LagStats;
import org.apache.druid.indexing.overlord.supervisor.autoscaler.SupervisorTaskAutoScaler;
import org.apache.druid.indexing.seekablestream.supervisor.SeekableStreamSupervisor;
import org.apache.druid.indexing.seekablestream.supervisor.autoscaler.CostBasedAutoScalerConfig;
import org.apache.druid.indexing.seekablestream.supervisor.autoscaler.CostMetrics;
import org.apache.druid.indexing.seekablestream.supervisor.autoscaler.CostResult;
import org.apache.druid.indexing.seekablestream.supervisor.autoscaler.WeightedCostFunction;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.emitter.EmittingLogger;
import org.apache.druid.java.util.emitter.service.ServiceEmitter;
import org.apache.druid.java.util.emitter.service.ServiceEventBuilder;
import org.apache.druid.java.util.emitter.service.ServiceMetricEvent;

/**
 * Task auto-scaler that picks a task count for a seekable-stream supervisor by
 * evaluating a {@link WeightedCostFunction} over a small set of candidate task counts.
 *
 * <p>A single-threaded scheduled executor periodically runs
 * {@link #computeTaskCountForScaleAction()}; the metrics collected by that run are kept
 * in {@link #lastKnownMetrics} so that {@link #computeTaskCountForRollover()} can reuse them.
 *
 * <p>Every method that proposes a task count returns {@code -1} when it has no proposal
 * (no metrics, invalid counts, or the best candidate equals the current task count).
 *
 * <p>NOTE(review): {@link #metricBuilder} is a single shared instance used from whichever
 * thread calls {@link #computeOptimalTaskCount(CostMetrics)}; confirm that the builder is
 * safe for that usage if rollover and scale actions can run concurrently.
 */
public class CostBasedAutoScaler
implements SupervisorTaskAutoScaler {
    private static final EmittingLogger log = new EmittingLogger(CostBasedAutoScaler.class);
    /**
     * Subtracted from the current partitions-per-task to obtain the smallest
     * partitions-per-task considered (which yields the largest candidate task count).
     * NOTE(review): paired with the lower bound by value (2); confirm the naming intent.
     */
    private static final int MAX_INCREASE_IN_PARTITIONS_PER_TASK = 2;
    /**
     * Added to the current partitions-per-task to obtain the largest
     * partitions-per-task considered (which yields the smallest candidate task count).
     * NOTE(review): paired with the upper bound by value (4); confirm the naming intent.
     */
    private static final int MAX_DECREASE_IN_PARTITIONS_PER_TASK = 4;
    /** Moving-average windows consulted by {@link #extractMovingAverage(Map)}, in order of preference. */
    private static final String[] MOVING_AVERAGE_WINDOWS = {"15m", "5m", "1m"};
    /** Return value used throughout this class for "no task count proposed". */
    private static final int NO_PROPOSAL = -1;
    public static final String LAG_COST_METRIC = "task/autoScaler/costBased/lagCost";
    public static final String IDLE_COST_METRIC = "task/autoScaler/costBased/idleCost";
    public static final String OPTIMAL_TASK_COUNT_METRIC = "task/autoScaler/costBased/optimalTaskCount";
    private final String supervisorId;
    private final SeekableStreamSupervisor supervisor;
    private final ServiceEmitter emitter;
    private final SupervisorSpec spec;
    private final CostBasedAutoScalerConfig config;
    private final ServiceMetricEvent.Builder metricBuilder;
    private final ScheduledExecutorService autoscalerExecutor;
    private final WeightedCostFunction costFunction;
    /** Result of the most recent {@link #collectMetrics()} call; {@code null} if that call produced nothing. */
    private volatile CostMetrics lastKnownMetrics;

    /**
     * Creates the auto-scaler and its dedicated executor; nothing is scheduled until {@link #start()}.
     *
     * @param supervisor supervisor whose lag, stats and IO config are read
     * @param config     cost-based auto-scaler configuration
     * @param spec       supervisor spec; supplies the supervisor id and suspended flag
     * @param emitter    emitter used for the cost / optimal-task-count metrics
     */
    public CostBasedAutoScaler(SeekableStreamSupervisor supervisor, CostBasedAutoScalerConfig config, SupervisorSpec spec, ServiceEmitter emitter) {
        this.config = config;
        this.spec = spec;
        this.supervisor = supervisor;
        this.supervisorId = spec.getId();
        this.emitter = emitter;
        this.costFunction = new WeightedCostFunction();
        this.autoscalerExecutor = Execs.scheduledSingleThreaded("CostBasedAutoScaler-" + StringUtils.encodeForFormat(spec.getId()));
        this.metricBuilder = ServiceMetricEvent.builder()
            .setDimension("supervisorId", this.supervisorId)
            .setDimension("stream", this.supervisor.getIoConfig().getStream());
    }

    /**
     * Schedules the periodic scale evaluation. Both the initial delay and the period are
     * {@code config.getScaleActionPeriodMillis()}.
     */
    public void start() {
        final long periodMillis = this.config.getScaleActionPeriodMillis();
        this.autoscalerExecutor.scheduleAtFixedRate(
            this.supervisor.buildDynamicAllocationTask(this::computeTaskCountForScaleAction, () -> {}, this.emitter),
            periodMillis,
            periodMillis,
            TimeUnit.MILLISECONDS
        );
        log.info("CostBasedAutoScaler started for supervisorId[%s]: evaluating scaling every [%d]ms", this.supervisorId, periodMillis);
    }

    /** Shuts the executor down immediately via {@code shutdownNow()}. */
    public void stop() {
        this.autoscalerExecutor.shutdownNow();
        log.info("CostBasedAutoScaler stopped for supervisorId [%s]", this.supervisorId);
    }

    /** No-op: this auto-scaler keeps no state that needs resetting. */
    public void reset() {
    }

    /**
     * Computes a task count from the most recently collected metrics, without collecting new ones.
     *
     * @return the optimal task count, or {@code -1} if there is no proposal
     */
    public int computeTaskCountForRollover() {
        return this.computeOptimalTaskCount(this.lastKnownMetrics);
    }

    /**
     * Collects fresh metrics, stores them as the last known metrics, and computes a task count.
     * Only proposals that are not below the current task count are returned.
     *
     * @return the proposed task count, or {@code -1} when metrics are unavailable, the
     *         optimal count equals the current one, or the optimal count would scale down
     */
    public int computeTaskCountForScaleAction() {
        // Work on a local: collectMetrics() may return null (suspended supervisor or
        // missing lag stats) and the volatile field must not be dereferenced blindly.
        final CostMetrics metrics = this.collectMetrics();
        this.lastKnownMetrics = metrics;
        if (metrics == null) {
            return NO_PROPOSAL;
        }
        final int optimalTaskCount = this.computeOptimalTaskCount(metrics);
        return optimalTaskCount >= metrics.getCurrentTaskCount() ? optimalTaskCount : NO_PROPOSAL;
    }

    public CostBasedAutoScalerConfig getConfig() {
        return this.config;
    }

    /**
     * Evaluates the cost function for every candidate task count and returns the cheapest one.
     * Emits the optimal task count and its lag / idle cost as metrics.
     *
     * @param metrics metrics to evaluate; may be {@code null}
     * @return the cheapest candidate, or {@code -1} if {@code metrics} is {@code null}, the
     *         partition or task count is not positive, there are no candidates, no candidate
     *         beats the initial cost, or the cheapest candidate equals the current task count
     */
    int computeOptimalTaskCount(CostMetrics metrics) {
        if (metrics == null) {
            log.debug("No metrics available yet for supervisorId [%s]", this.supervisorId);
            return NO_PROPOSAL;
        }
        final int partitionCount = metrics.getPartitionCount();
        final int currentTaskCount = metrics.getCurrentTaskCount();
        if (partitionCount <= 0 || currentTaskCount <= 0) {
            return NO_PROPOSAL;
        }
        final int[] validTaskCounts = CostBasedAutoScaler.computeValidTaskCounts(partitionCount, currentTaskCount);
        if (validTaskCounts.length == 0) {
            log.warn("No valid task counts after applying constraints for supervisorId [%s]", this.supervisorId);
            return NO_PROPOSAL;
        }
        int optimalTaskCount = NO_PROPOSAL;
        // NOTE(review): presumably a default CostResult carries a worst-case total cost so
        // that any real candidate beats it — confirm against CostResult.
        CostResult optimalCost = new CostResult();
        for (int taskCount : validTaskCounts) {
            final CostResult costResult = this.costFunction.computeCost(metrics, taskCount, this.config);
            final double cost = costResult.totalCost();
            log.debug("Proposed task count: %d, Cost: %.4f (lag: %.4f, idle: %.4f)", taskCount, cost, costResult.lagCost(), costResult.idleCost());
            // Strict comparison: on equal cost the earlier (smaller) candidate is kept.
            if (cost < optimalCost.totalCost()) {
                optimalTaskCount = taskCount;
                optimalCost = costResult;
            }
        }
        this.emitMetric(OPTIMAL_TASK_COUNT_METRIC, optimalTaskCount);
        this.emitMetric(LAG_COST_METRIC, optimalCost.lagCost());
        this.emitMetric(IDLE_COST_METRIC, optimalCost.idleCost());
        log.debug("Cost-based scaling evaluation for supervisorId [%s]: current=%d, optimal=%d, cost=%.4f, avgPartitionLag=%.2f, pollIdleRatio=%.3f", this.supervisorId, metrics.getCurrentTaskCount(), optimalTaskCount, optimalCost.totalCost(), metrics.getAvgPartitionLag(), metrics.getPollIdleRatio());
        if (optimalTaskCount == currentTaskCount) {
            return NO_PROPOSAL;
        }
        return optimalTaskCount;
    }

    /**
     * Lists the distinct task counts obtained by varying partitions-per-task around its
     * current value ({@code partitionCount / currentTaskCount}): from
     * {@code current - MAX_INCREASE_IN_PARTITIONS_PER_TASK} (at least 1) up to
     * {@code current + MAX_DECREASE_IN_PARTITIONS_PER_TASK} (at most {@code partitionCount}).
     * Each partitions-per-task value maps to {@code ceil(partitionCount / partitionsPerTask)} tasks.
     *
     * @param partitionCount   number of stream partitions
     * @param currentTaskCount current number of tasks
     * @return the distinct candidate task counts in ascending order; empty if either
     *         argument is not positive
     */
    static int[] computeValidTaskCounts(int partitionCount, int currentTaskCount) {
        // A non-positive task count would divide by zero (or yield meaningless bounds) below.
        if (partitionCount <= 0 || currentTaskCount <= 0) {
            return new int[0];
        }
        final int currentPartitionsPerTask = partitionCount / currentTaskCount;
        final int minPartitionsPerTask = Math.max(1, currentPartitionsPerTask - MAX_INCREASE_IN_PARTITIONS_PER_TASK);
        final int maxPartitionsPerTask = Math.min(partitionCount, currentPartitionsPerTask + MAX_DECREASE_IN_PARTITIONS_PER_TASK);
        final HashSet<Integer> result = new HashSet<>();
        for (int partitionsPerTask = maxPartitionsPerTask; partitionsPerTask >= minPartitionsPerTask; --partitionsPerTask) {
            // Ceiling division: tasks needed so that none holds more than partitionsPerTask.
            result.add((partitionCount + partitionsPerTask - 1) / partitionsPerTask);
        }
        // Sorted so that callers iterate candidates in a deterministic order.
        return result.stream().mapToInt(Integer::intValue).sorted().toArray();
    }

    /**
     * Averages {@code autoscalerMetrics.pollIdleRatio} over all task entries that carry a
     * numeric value for it.
     *
     * @param taskStats task stats keyed by group, then by task; may be {@code null}
     * @return the average ratio, or {@code 0.0} if no task reports one
     */
    static double extractPollIdleRatio(Map<String, Map<String, Object>> taskStats) {
        if (taskStats == null || taskStats.isEmpty()) {
            return 0.0;
        }
        double sum = 0.0;
        int count = 0;
        for (Map<String, Object> groupMetrics : taskStats.values()) {
            if (groupMetrics == null) {
                continue;
            }
            for (Object taskMetric : groupMetrics.values()) {
                final Object pollIdleRatio = CostBasedAutoScaler.lookup(CostBasedAutoScaler.lookup(taskMetric, "autoscalerMetrics"), "pollIdleRatio");
                if (pollIdleRatio instanceof Number) {
                    sum += ((Number) pollIdleRatio).doubleValue();
                    ++count;
                }
            }
        }
        return count > 0 ? sum / (double) count : 0.0;
    }

    /**
     * Averages {@code movingAverages.buildSegments.<window>.processed} over all task entries
     * that carry a numeric value for it. For each task the first non-null window among
     * {@code 15m}, {@code 5m}, {@code 1m} is used.
     *
     * @param taskStats task stats keyed by group, then by task; may be {@code null}
     * @return the average processed rate, or {@code -1.0} if no task reports one
     */
    static double extractMovingAverage(Map<String, Map<String, Object>> taskStats) {
        if (taskStats == null || taskStats.isEmpty()) {
            return -1.0;
        }
        double sum = 0.0;
        int count = 0;
        for (Map<String, Object> groupMetrics : taskStats.values()) {
            if (groupMetrics == null) {
                continue;
            }
            for (Object taskMetric : groupMetrics.values()) {
                final Object buildSegments = CostBasedAutoScaler.lookup(CostBasedAutoScaler.lookup(taskMetric, "movingAverages"), "buildSegments");
                if (!(buildSegments instanceof Map)) {
                    continue;
                }
                Object window = null;
                for (String windowKey : MOVING_AVERAGE_WINDOWS) {
                    window = CostBasedAutoScaler.lookup(buildSegments, windowKey);
                    if (window != null) {
                        break;
                    }
                }
                final Object processedRate = CostBasedAutoScaler.lookup(window, "processed");
                if (processedRate instanceof Number) {
                    sum += ((Number) processedRate).doubleValue();
                    ++count;
                }
            }
        }
        return count > 0 ? sum / (double) count : -1.0;
    }

    /** Returns {@code container.get(key)} if {@code container} is a {@link Map}, otherwise {@code null}. */
    private static Object lookup(Object container, String key) {
        return container instanceof Map ? ((Map<?, ?>) container).get(key) : null;
    }

    /** Emits a single metric value using the shared, pre-dimensioned metric builder. */
    private void emitMetric(String metricName, Number value) {
        this.emitter.emit(this.metricBuilder.setMetric(metricName, value));
    }

    /**
     * Gathers the inputs of the cost function from the supervisor.
     *
     * <p>The processing rate is the moving average reported by the tasks when it is positive;
     * otherwise it falls back to {@code config.getDefaultProcessingRate()} scaled by
     * {@code max(0.01, 1 - pollIdleRatio)}.
     *
     * @return the collected metrics, or {@code null} if the supervisor is suspended or lag
     *         stats are unavailable
     */
    private CostMetrics collectMetrics() {
        if (this.spec.isSuspended()) {
            log.debug("Supervisor [%s] is suspended, skipping a metrics collection", this.supervisorId);
            return null;
        }
        final LagStats lagStats = this.supervisor.computeLagStats();
        if (lagStats == null) {
            log.debug("Lag stats unavailable for supervisorId [%s], skipping collection", this.supervisorId);
            return null;
        }
        final int currentTaskCount = this.supervisor.getIoConfig().getTaskCount();
        final int partitionCount = this.supervisor.getPartitionCount();
        final Map<String, Map<String, Object>> taskStats = this.supervisor.getStats();
        final double movingAvgRate = CostBasedAutoScaler.extractMovingAverage(taskStats);
        final double pollIdleRatio = CostBasedAutoScaler.extractPollIdleRatio(taskStats);
        final double avgPartitionLag = lagStats.getAvgLag();
        final double avgProcessingRate;
        if (movingAvgRate > 0.0) {
            avgProcessingRate = movingAvgRate;
        } else {
            // Floor the utilization at 1% so the fallback rate never collapses to zero.
            final double utilizationRatio = Math.max(0.01, 1.0 - pollIdleRatio);
            avgProcessingRate = this.config.getDefaultProcessingRate() * utilizationRatio;
        }
        return new CostMetrics(avgPartitionLag, currentTaskCount, partitionCount, pollIdleRatio, this.supervisor.getIoConfig().getTaskDuration().getStandardSeconds(), avgProcessingRate);
    }
}

