AdaptivePerformanceEstimator.java

package com.morphiqlabs.wavelet.performance;

import java.io.File;
import java.util.Locale;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Adaptive performance estimator that learns from actual execution times.
 *
 * <p>This class provides performance predictions that improve over time by
 * learning from actual measurements. When calibration is enabled it periodically
 * checks the accuracy of its models and recalibrates them on a background daemon
 * thread once accuracy degrades.</p>
 *
 * <p>Learning and persistence are both opt-in and disabled by default:</p>
 * <ul>
 *   <li><b>Calibration</b>: system property {@code vectorwave.perf.calibration}
 *       (alias {@code vectorwave.perf.calibration.enabled}) or environment variable
 *       {@code VECTORWAVE_PERF_CALIBRATION} (alias {@code VECTORWAVE_PERF_CALIBRATION_ENABLED}).</li>
 *   <li><b>Persistence</b>: system property {@code vectorwave.perf.persist} or environment
 *       variable {@code VECTORWAVE_PERF_PERSIST}. Models are read from and written to
 *       {@code ${user.home}/.vectorwave/performance/performance_models.dat}.</li>
 * </ul>
 * <p>System properties take precedence over environment variables. The values
 * {@code true}, {@code 1}, {@code yes} and {@code on} (case-insensitive) enable a
 * feature. Any other non-blank value disables it, and an unset or blank value
 * leaves it disabled.</p>
 */
public class AdaptivePerformanceEstimator {
    private static final System.Logger LOG = com.morphiqlabs.wavelet.util.Logging.getLogger(AdaptivePerformanceEstimator.class);

    private static final String MODEL_CACHE_DIR = ".vectorwave/performance";
    private static final String MODEL_FILE = "performance_models.dat";
    /** Number of recorded measurements between accuracy checks. */
    private static final int RECALIBRATION_THRESHOLD = 1000;
    /** Minimum acceptable model confidence; below this a recalibration is scheduled. */
    private static final double ACCURACY_THRESHOLD = 0.85;
    /** Complexity factor used for wavelets without a specific entry (and for {@code null}). */
    private static final double DEFAULT_COMPLEXITY_FACTOR = 2.0;

    // Opt-in flags (default disabled), resolved once at construction time.
    private final boolean calibrationEnabled;
    private final boolean persistenceEnabled;

    // Singleton instance. It is declared after LOG so the logger is initialised before
    // the constructor (which may log while loading models) runs.
    private static final AdaptivePerformanceEstimator INSTANCE = new AdaptivePerformanceEstimator();

    // Most recently loaded or calibrated models. It may be written by the background
    // calibration thread, so it is volatile for cross-thread visibility.
    private volatile PerformanceCalibrator.CalibratedModels models;
    private final ConcurrentHashMap<String, PerformanceModel> operationModels;
    private final AtomicInteger measurementCount;
    private final AtomicLong lastCalibrationTime;
    private final ExecutorService calibrationExecutor;
    private final AtomicBoolean isCalibrating;

    private AdaptivePerformanceEstimator() {
        this.calibrationEnabled = getBoolean(
            "vectorwave.perf.calibration", "vectorwave.perf.calibration.enabled",
            "VECTORWAVE_PERF_CALIBRATION", "VECTORWAVE_PERF_CALIBRATION_ENABLED",
            false
        );
        this.persistenceEnabled = getBoolean(
            "vectorwave.perf.persist", null,
            "VECTORWAVE_PERF_PERSIST", null,
            false
        );
        this.operationModels = new ConcurrentHashMap<>();
        this.measurementCount = new AtomicInteger(0);
        this.lastCalibrationTime = new AtomicLong(System.currentTimeMillis());
        this.isCalibrating = new AtomicBoolean(false);

        // Background calibration runs on a single daemon thread so it never blocks
        // JVM shutdown.
        ThreadFactory daemonFactory = runnable -> {
            Thread thread = new Thread(runnable, "PerformanceCalibration");
            thread.setDaemon(true);
            return thread;
        };
        this.calibrationExecutor = Executors.newSingleThreadExecutor(daemonFactory);

        // Try to load previously persisted models; otherwise defaults are created lazily.
        if (persistenceEnabled) {
            loadModels();
        }
    }

    /**
     * Gets the singleton instance.
     *
     * @return The adaptive performance estimator
     */
    public static AdaptivePerformanceEstimator getInstance() {
        return INSTANCE;
    }

    /**
     * Estimates execution time for a MODWT operation.
     *
     * <p>The base prediction of the "MODWT" model is scaled by a per-wavelet
     * complexity factor. Unknown or {@code null} wavelet names use a default factor.</p>
     *
     * @param signalLength Length of the input signal
     * @param waveletName Name of the wavelet (case-insensitive, may be {@code null})
     * @param hasVectorization Whether vectorization is available
     * @return Prediction with confidence bounds
     */
    public PredictionResult estimateMODWT(int signalLength, String waveletName,
                                         boolean hasVectorization) {
        PerformanceModel model = getOrCreateModel("MODWT");
        double complexityFactor = getWaveletComplexityFactor(waveletName);
        return scale(model.predict(signalLength, hasVectorization), complexityFactor);
    }

    /**
     * Estimates execution time for a convolution operation.
     *
     * <p>The base prediction of the "Convolution" model is scaled by
     * {@code sqrt(filterLength / 4)}, which normalises the estimate to a length-4 filter.</p>
     *
     * @param signalLength Length of the input signal
     * @param filterLength Length of the filter (must be non-negative)
     * @param hasVectorization Whether vectorization is available
     * @return Prediction with confidence bounds
     * @throws IllegalArgumentException if {@code filterLength} is negative
     */
    public PredictionResult estimateConvolution(int signalLength, int filterLength,
                                              boolean hasVectorization) {
        if (filterLength < 0) {
            // sqrt of a negative number would silently produce NaN estimates.
            throw new IllegalArgumentException("filterLength must be non-negative: " + filterLength);
        }
        PerformanceModel model = getOrCreateModel("Convolution");
        double filterFactor = Math.sqrt(filterLength / 4.0); // Normalized to length 4
        return scale(model.predict(signalLength, hasVectorization), filterFactor);
    }

    /**
     * Estimates execution time for batch operations.
     *
     * <p>The per-signal prediction of the "Batch" model is multiplied by
     * {@code batchSize} and divided by an efficiency factor
     * {@code 1 + log(batchSize) / log(32)}. This models the diminishing returns of
     * batching. An empty batch is predicted to take no time.</p>
     *
     * @param batchSize Number of signals in the batch (must be non-negative)
     * @param signalLength Length of each signal
     * @param hasVectorization Whether vectorization is available
     * @return Prediction with confidence bounds
     * @throws IllegalArgumentException if {@code batchSize} is negative
     */
    public PredictionResult estimateBatch(int batchSize, int signalLength,
                                        boolean hasVectorization) {
        if (batchSize < 0) {
            // log of a negative number would silently produce NaN estimates.
            throw new IllegalArgumentException("batchSize must be non-negative: " + batchSize);
        }
        PerformanceModel model = getOrCreateModel("Batch");
        PredictionResult base = model.predict(signalLength, hasVectorization);

        if (batchSize == 0) {
            // log(0) = -Infinity would otherwise yield a -0.0 estimate.
            return scale(base, 0.0);
        }

        // Batch efficiency improves with size but has diminishing returns
        double batchEfficiency = 1.0 + Math.log(batchSize) / Math.log(32);

        return new PredictionResult(
            base.estimatedTime() * batchSize / batchEfficiency,
            base.lowerBound() * batchSize / batchEfficiency,
            base.upperBound() * batchSize / batchEfficiency,
            base.confidence()
        );
    }

    /**
     * Records an actual measurement to improve the model.
     *
     * <p>This is a no-op unless calibration is enabled. Every
     * {@value #RECALIBRATION_THRESHOLD} measurements the models' accuracy is
     * checked, and a background recalibration is scheduled if needed.</p>
     *
     * @param operation The operation type ("MODWT", "Convolution", "Batch")
     * @param inputSize The input size
     * @param actualTime The actual execution time in milliseconds
     * @param hasVectorization Whether vectorization was used
     */
    public void recordMeasurement(String operation, int inputSize,
                                double actualTime, boolean hasVectorization) {
        if (!calibrationEnabled) {
            return; // no-op when calibration is disabled
        }
        PerformanceModel model = getOrCreateModel(operation);
        model.updateWithMeasurement(inputSize, actualTime, hasVectorization);

        int count = measurementCount.incrementAndGet();
        if (count % RECALIBRATION_THRESHOLD == 0) {
            checkRecalibration();
        }
    }

    /**
     * Forces recalibration of all models on the calling thread.
     *
     * <p>This is a no-op when calibration is disabled or when another recalibration
     * is already in progress. On success the calibrated models replace the current
     * "MODWT", "Convolution" and "Batch" models. They are persisted if persistence
     * is enabled, and the measurement counter is reset. Exceptions thrown by the
     * calibrator propagate to the caller.</p>
     */
    public void recalibrate() {
        if (!calibrationEnabled) {
            return; // no-op when calibration is disabled
        }
        // Prevent concurrent calibrations
        if (!isCalibrating.compareAndSet(false, true)) {
            return; // Already calibrating
        }

        try {
            LOG.log(System.Logger.Level.INFO, "Recalibrating performance models...");

            PerformanceCalibrator calibrator = new PerformanceCalibrator();
            PerformanceCalibrator.CalibratedModels calibrated = calibrator.calibrate();
            models = calibrated;
            installModels(calibrated);

            saveModels();

            lastCalibrationTime.set(System.currentTimeMillis());
            measurementCount.set(0);
        } finally {
            isCalibrating.set(false);
        }
    }

    /**
     * Gets current model accuracy statistics.
     *
     * <p>The "last calibration" age is measured from construction time if no
     * recalibration has happened yet.</p>
     *
     * @return Accuracy report for all models
     */
    public String getAccuracyReport() {
        StringBuilder report = new StringBuilder();
        report.append("Performance Model Accuracy Report\n");
        report.append("=================================\n");

        operationModels.forEach((operation, model) -> {
            report.append("\n").append(operation).append(" Model:\n");
            report.append(model.getAccuracy().getSummary());
            report.append("\n");
        });

        long hoursSinceCalibration =
            (System.currentTimeMillis() - lastCalibrationTime.get()) / (1000 * 60 * 60);
        report.append("\nLast calibration: ").append(hoursSinceCalibration).append(" hours ago\n");
        report.append("Total measurements: ").append(measurementCount.get()).append("\n");

        return report.toString();
    }

    /**
     * Shuts down the calibration executor service.
     *
     * <p>This waits up to 10 seconds for a running calibration before forcing
     * termination. Should be called when the application terminates.</p>
     */
    public void shutdown() {
        calibrationExecutor.shutdown();
        try {
            if (!calibrationExecutor.awaitTermination(10, TimeUnit.SECONDS)) {
                calibrationExecutor.shutdownNow();
            }
        } catch (InterruptedException e) {
            calibrationExecutor.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }

    // Private helper methods

    /** Returns the model for {@code operation}, creating a platform-default one on first use. */
    private PerformanceModel getOrCreateModel(String operation) {
        return operationModels.computeIfAbsent(operation, k -> {
            PlatformFactors factors = PlatformFactors.detectPlatform();
            return new PerformanceModel(factors);
        });
    }

    /** Multiplies the estimate and both bounds of {@code base} by {@code factor}, keeping its confidence. */
    private static PredictionResult scale(PredictionResult base, double factor) {
        return new PredictionResult(
            base.estimatedTime() * factor,
            base.lowerBound() * factor,
            base.upperBound() * factor,
            base.confidence()
        );
    }

    /** Replaces the per-operation models with those from {@code calibrated}. */
    private void installModels(PerformanceCalibrator.CalibratedModels calibrated) {
        operationModels.put("MODWT", calibrated.modwtModel());
        operationModels.put("Convolution", calibrated.convolutionModel());
        operationModels.put("Batch", calibrated.batchModel());
    }

    /**
     * Returns the relative cost of a wavelet based on its filter length.
     * Matching is case-insensitive and locale-independent. Unknown or null names get
     * {@link #DEFAULT_COMPLEXITY_FACTOR}.
     */
    private double getWaveletComplexityFactor(String waveletName) {
        if (waveletName == null) {
            return DEFAULT_COMPLEXITY_FACTOR;
        }
        // Locale.ROOT: in e.g. a Turkish locale "DAUBECHIES4" would otherwise
        // lowercase to "daubechıes4" (dotless i) and miss its case.
        return switch (waveletName.toLowerCase(Locale.ROOT)) {
            case "haar" -> 1.0;
            case "db2", "daub2", "daubechies2" -> 1.5;
            case "db4", "daub4", "daubechies4" -> 2.0;
            case "db6", "daub6", "daubechies6" -> 2.5;
            case "db8", "daub8", "daubechies8" -> 3.0;
            default -> DEFAULT_COMPLEXITY_FACTOR;
        };
    }

    /**
     * Schedules a background recalibration if any model reports it is needed or
     * has fallen below {@link #ACCURACY_THRESHOLD}. Nothing is scheduled while a
     * calibration is already running or after {@link #shutdown()}.
     */
    private void checkRecalibration() {
        if (!calibrationEnabled) {
            return;
        }
        boolean needsRecalibration = operationModels.values().stream()
            .anyMatch(model -> model.needsRecalibration()
                || model.getAccuracy().getConfidence() < ACCURACY_THRESHOLD);

        if (!needsRecalibration || isCalibrating.get()) {
            return;
        }
        try {
            // Recalibrate in background to avoid blocking the measuring thread.
            calibrationExecutor.execute(this::runBackgroundRecalibration);
        } catch (RejectedExecutionException e) {
            // Executor already shut down: recording measurements must not fail because of it.
            LOG.log(System.Logger.Level.DEBUG, "Skipping background recalibration: executor is shut down");
        }
    }

    /** Runs {@link #recalibrate()} and logs failures, which would otherwise vanish on the executor. */
    private void runBackgroundRecalibration() {
        try {
            recalibrate();
        } catch (RuntimeException e) {
            LOG.log(System.Logger.Level.WARNING, "Background recalibration failed", e);
        }
    }

    /** Loads persisted models from the user's cache directory, if present. Failures fall back to defaults. */
    private void loadModels() {
        File modelFile = new File(System.getProperty("user.home"),
            MODEL_CACHE_DIR + "/" + MODEL_FILE);

        if (!modelFile.exists()) {
            return;
        }
        try {
            PerformanceCalibrator.CalibratedModels loaded =
                PerformanceCalibrator.CalibratedModels.load(modelFile.getPath());
            models = loaded;
            installModels(loaded);
            LOG.log(System.Logger.Level.INFO, () -> "Loaded performance models from " + modelFile);
        } catch (Exception e) {
            // Best effort: default models will be created lazily instead.
            LOG.log(System.Logger.Level.WARNING, "Failed to load performance models: " + e.getMessage(), e);
        }
    }

    /** Persists the current models when persistence is enabled. Failures are logged, never thrown. */
    private void saveModels() {
        PerformanceCalibrator.CalibratedModels snapshot = models;
        if (snapshot == null || !persistenceEnabled) {
            return;
        }

        File modelDir = new File(System.getProperty("user.home"), MODEL_CACHE_DIR);
        if (!modelDir.isDirectory() && !modelDir.mkdirs()) {
            LOG.log(System.Logger.Level.WARNING, () -> "Failed to create performance model directory " + modelDir);
            return;
        }

        File modelFile = new File(modelDir, MODEL_FILE);
        try {
            snapshot.save(modelFile.getPath());
            LOG.log(System.Logger.Level.INFO, () -> "Saved performance models to " + modelFile);
        } catch (Exception e) {
            LOG.log(System.Logger.Level.WARNING, "Failed to save performance models: " + e.getMessage(), e);
        }
    }

    /**
     * Reads a boolean flag from system properties, then environment variables.
     * Each name may be {@code null} to skip it, and the first non-blank value wins.
     * If every source is unset or blank, {@code defaultVal} is returned. Otherwise
     * {@code true}/{@code 1}/{@code yes}/{@code on} (case-insensitive) mean
     * {@code true} and anything else means {@code false}.
     */
    private static boolean getBoolean(String prop, String altProp, String env, String altEnv, boolean defaultVal) {
        String v = null;
        if (prop != null) v = System.getProperty(prop);
        if ((v == null || v.isBlank()) && altProp != null) v = System.getProperty(altProp);
        if ((v == null || v.isBlank()) && env != null) v = System.getenv(env);
        if ((v == null || v.isBlank()) && altEnv != null) v = System.getenv(altEnv);
        if (v == null || v.isBlank()) {
            return defaultVal;
        }
        return switch (v.trim().toLowerCase(Locale.ROOT)) {
            case "true", "1", "yes", "on" -> true;
            default -> false;
        };
    }

    // Package-private diagnostic helpers (primarily for tests)
    boolean isCalibrationEnabled() { return calibrationEnabled; }
    boolean isPersistenceEnabled() { return persistenceEnabled; }
    int getTotalMeasurements() { return measurementCount.get(); }
}