// PerformanceModel.java
package com.morphiqlabs.wavelet.performance;
import java.io.Serializable;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* Empirical performance model for wavelet operations with platform-specific calibration.
*
* <p>This class represents a performance model that can predict execution time based on
* input size and system characteristics. The model is calibrated using actual measurements
* from the target platform.</p>
*
* <p>The model uses a piecewise polynomial approach where different size ranges have
* different coefficients, allowing accurate predictions across a wide range of input sizes.</p>
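*
* <p>A minimal usage sketch (how the {@link PlatformFactors} instance is obtained is
* up to the caller; the {@code platformFactors} and {@code measuredMillis} variables
* below are illustrative):</p>
* <pre>{@code
* PerformanceModel model = new PerformanceModel(platformFactors);
* PredictionResult estimate = model.predict(4096, true);
* model.updateWithMeasurement(4096, measuredMillis, true);
* }</pre>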
*/
public class PerformanceModel implements Serializable {
private static final long serialVersionUID = 1L;
/**
* Model coefficients for different size ranges.
* Each range's model has the form: time = a + b*n + c*n^2.
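* For example, the default MEDIUM coefficients (a=0.5, b=0.0001, c=0) give
* 0.5 + 0.0001*2048 ≈ 0.70 ms for a 2048-sample input, before platform adjustments.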
*/
private final ConcurrentHashMap<SizeRange, ModelCoefficients> coefficients = new ConcurrentHashMap<>();
/**
* Platform-specific scaling factors.
*/
private final PlatformFactors platformFactors;
/**
* Confidence intervals for predictions.
*/
private final ConcurrentHashMap<SizeRange, ConfidenceInterval> confidenceIntervals = new ConcurrentHashMap<>();
/**
* Number of measurements used for calibration.
*/
private int measurementCount = 0;
/**
* Timestamp of last calibration.
*/
private long lastCalibrationTime = 0;
/**
* Model accuracy metrics.
*/
private ModelAccuracy accuracy = new ModelAccuracy();
/**
* Creates a new performance model with platform-specific factors.
*
* @param platformFactors Platform-specific scaling factors
*/
public PerformanceModel(PlatformFactors platformFactors) {
this.platformFactors = platformFactors;
initializeDefaultCoefficients();
}
/**
* Predicts execution time for a given input size.
*
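* <p>Example (sketch): {@code double ms = model.predict(8192, true).estimatedTime();}</p>
*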
* @param inputSize The size of the input
* @param hasVectorization Whether vectorization is available
* @return Prediction result containing the estimated time in milliseconds and confidence bounds
*/
public PredictionResult predict(int inputSize, boolean hasVectorization) {
SizeRange range = SizeRange.forSize(inputSize);
ModelCoefficients coeff = coefficients.get(range);
if (coeff == null) {
// No coefficients fitted for this range; fall back to the nearest populated one
coeff = extrapolateCoefficients(inputSize);
}
// Base prediction using polynomial model
double baseTime = coeff.evaluate(inputSize);
// Ensure base time is positive - safety guard
if (baseTime <= 0) {
baseTime = 0.001; // 1 microsecond minimum
}
// Apply platform-specific factors
double adjustedTime = baseTime * platformFactors.cpuSpeedFactor;
// Apply vectorization speedup if available
if (hasVectorization) {
adjustedTime /= platformFactors.vectorSpeedup;
}
// Apply cache effects
adjustedTime *= getCacheEffectMultiplier(inputSize);
// Final safety guard to ensure positive time
if (adjustedTime <= 0) {
adjustedTime = 0.001; // 1 microsecond minimum
}
// Get confidence interval
ConfidenceInterval ci = confidenceIntervals.getOrDefault(range,
new ConfidenceInterval(0.9, 1.1));
return new PredictionResult(
adjustedTime,
ci.getLowerBound(adjustedTime),
ci.getUpperBound(adjustedTime),
accuracy.getConfidence()
);
}
/**
* Updates the model with a new measurement.
*
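* <p>Typical feedback loop (sketch; the workload and timing harness are illustrative):</p>
* <pre>{@code
* long start = System.nanoTime();
* runTransform(signal);  // hypothetical workload
* double elapsedMs = (System.nanoTime() - start) / 1_000_000.0;
* model.updateWithMeasurement(signal.length, elapsedMs, vectorApiEnabled);
* }</pre>
*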
* @param inputSize The input size
* @param actualTime The actual execution time in milliseconds
* @param hasVectorization Whether vectorization was used
*/
public void updateWithMeasurement(int inputSize, double actualTime, boolean hasVectorization) {
// Normalize time to base conditions
double normalizedTime = actualTime;
if (hasVectorization) {
normalizedTime *= platformFactors.vectorSpeedup;
}
normalizedTime /= platformFactors.cpuSpeedFactor;
normalizedTime /= getCacheEffectMultiplier(inputSize);
// Update coefficients for the appropriate range
SizeRange range = SizeRange.forSize(inputSize);
ModelCoefficients coeff = coefficients.computeIfAbsent(range,
k -> new ModelCoefficients());
// Track accuracy against the pre-update model so the metric reflects
// out-of-sample prediction error
double predicted = predict(inputSize, hasVectorization).estimatedTime();
// Use online learning to update coefficients
coeff.updateWithMeasurement(inputSize, normalizedTime);
// Update confidence intervals
updateConfidenceInterval(range, inputSize, actualTime, normalizedTime);
// Update model accuracy
accuracy.updateWithPrediction(predicted, actualTime);
measurementCount++;
lastCalibrationTime = System.currentTimeMillis();
}
/**
* Performs batch calibration with multiple measurements.
*
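* <p>Sketch of a calibration pass (the sizes and the {@code benchmark} helper are
* illustrative):</p>
* <pre>{@code
* java.util.List<Measurement> samples = new java.util.ArrayList<>();
* for (int size : new int[] {128, 512, 2048, 8192, 32768}) {
*     samples.add(new Measurement(size, benchmark(size), vectorApiEnabled));
* }
* model.calibrate(samples.toArray(new Measurement[0]));
* }</pre>
*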
* @param measurements Array of measurements
*/
public void calibrate(Measurement[] measurements) {
// Clear existing model
coefficients.clear();
confidenceIntervals.clear();
accuracy = new ModelAccuracy();
// Group measurements by size range
Map<SizeRange, java.util.List<Measurement>> grouped =
java.util.Arrays.stream(measurements)
.collect(java.util.stream.Collectors.groupingBy(
m -> SizeRange.forSize(m.inputSize)));
// Fit coefficients for each range
grouped.forEach((range, rangeMeasurements) -> {
ModelCoefficients coeff = fitCoefficients(rangeMeasurements);
coefficients.put(range, coeff);
// Calculate confidence intervals
ConfidenceInterval ci = calculateConfidenceInterval(rangeMeasurements, coeff);
confidenceIntervals.put(range, ci);
});
measurementCount = measurements.length;
lastCalibrationTime = System.currentTimeMillis();
// Validate model accuracy
for (Measurement m : measurements) {
double predicted = predict(m.inputSize, m.hasVectorization).estimatedTime();
accuracy.updateWithPrediction(predicted, m.actualTime);
}
}
/**
* Gets model accuracy metrics.
*
* @return Current model accuracy
*/
public ModelAccuracy getAccuracy() {
return accuracy;
}
/**
* Checks if the model needs recalibration.
*
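* <p>Example (sketch, where {@code collectMeasurements()} is a hypothetical helper):
* {@code if (model.needsRecalibration()) model.calibrate(collectMeasurements()); }</p>
*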
* @return true if recalibration is recommended
*/
public boolean needsRecalibration() {
// Recalibrate if the mean absolute percentage error exceeds 15%
if (accuracy.getMeanAbsolutePercentageError() > 0.15) {
return true;
}
// Recalibrate if model is old (> 30 days)
long daysSinceCalibration = (System.currentTimeMillis() - lastCalibrationTime)
/ (1000 * 60 * 60 * 24);
if (daysSinceCalibration > 30) {
return true;
}
// Recalibrate if too few measurements
return measurementCount < 100;
}
/**
* Exports model coefficients for persistence.
*
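* <p>Persistence round trip via standard Java serialization (sketch; the {@code path}
* variable and exception handling are omitted for brevity):</p>
* <pre>{@code
* try (java.io.ObjectOutputStream out =
*         new java.io.ObjectOutputStream(java.nio.file.Files.newOutputStream(path))) {
*     out.writeObject(model.exportModel());
* }
* try (java.io.ObjectInputStream in =
*         new java.io.ObjectInputStream(java.nio.file.Files.newInputStream(path))) {
*     model.importModel((ModelData) in.readObject());
* }
* }</pre>
*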
* @return Serializable model data
*/
public ModelData exportModel() {
return new ModelData(
new ConcurrentHashMap<>(coefficients),
new ConcurrentHashMap<>(confidenceIntervals),
platformFactors,
accuracy,
measurementCount,
lastCalibrationTime
);
}
/**
* Imports model coefficients from persisted data.
*
* @param data Previously exported model data
*/
public void importModel(ModelData data) {
coefficients.clear();
coefficients.putAll(data.coefficients);
confidenceIntervals.clear();
confidenceIntervals.putAll(data.confidenceIntervals);
accuracy = data.accuracy;
measurementCount = data.measurementCount;
lastCalibrationTime = data.lastCalibrationTime;
}
// Private helper methods
private void initializeDefaultCoefficients() {
// Initialize with reasonable defaults based on complexity analysis
coefficients.put(SizeRange.TINY, new ModelCoefficients(0.01, 0.00001, 0));
coefficients.put(SizeRange.SMALL, new ModelCoefficients(0.1, 0.00005, 0));
coefficients.put(SizeRange.MEDIUM, new ModelCoefficients(0.5, 0.0001, 0));
coefficients.put(SizeRange.LARGE, new ModelCoefficients(2.0, 0.0002, 0));
coefficients.put(SizeRange.HUGE, new ModelCoefficients(8.0, 0.0003, 0.0000001));
}
private ModelCoefficients extrapolateCoefficients(int inputSize) {
// Reuse the coefficients of the nearest populated range (no true extrapolation)
if (coefficients.isEmpty()) {
return new ModelCoefficients(0.1, 0.0001, 0); // conservative default when nothing has been fitted
}
SizeRange nearest = SizeRange.MEDIUM;
int minDistance = Integer.MAX_VALUE;
for (SizeRange range : coefficients.keySet()) {
int distance = Math.abs(range.getMidpoint() - inputSize);
if (distance < minDistance) {
minDistance = distance;
nearest = range;
}
}
return coefficients.get(nearest);
}
private double getCacheEffectMultiplier(int inputSize) {
// Model cache effects based on data size
double dataSize = inputSize * 8.0; // doubles are 8 bytes
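// Example: 4096 doubles = 32 KiB, which fits a typical 32-64 KiB L1 data cache
// (no penalty); larger inputs fall through to the L2/L3/main-memory multipliers below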
if (dataSize <= platformFactors.l1CacheSize) {
return 1.0; // Fits in L1 cache
} else if (dataSize <= platformFactors.l2CacheSize) {
return 1.2; // Fits in L2 cache
} else if (dataSize <= platformFactors.l3CacheSize) {
return 1.5; // Fits in L3 cache
} else {
return 2.0; // Main memory access
}
}
private double normalizeMeasurementTime(Measurement m) {
// Strip platform and cache effects so batch calibration fits coefficients in the
// same normalized units used by the online updates in updateWithMeasurement()
double t = m.hasVectorization ? m.actualTime * platformFactors.vectorSpeedup : m.actualTime;
return t / platformFactors.cpuSpeedFactor / getCacheEffectMultiplier(m.inputSize);
}
private ModelCoefficients fitCoefficients(java.util.List<Measurement> measurements) {
// Least-squares fit for the polynomial model. This is a simplified
// implementation; production code would use a proper numerical library.
int n = measurements.size();
if (n < 3) {
// Not enough data for quadratic fit
return new ModelCoefficients(0.1, 0.0001, 0);
}
// Calculate sums for least squares
double sumX = 0, sumX2 = 0, sumX3 = 0, sumX4 = 0;
double sumY = 0, sumXY = 0, sumX2Y = 0;
for (Measurement m : measurements) {
double x = m.inputSize;
double y = normalizeMeasurementTime(m); // fit against platform-normalized time
sumX += x;
sumX2 += x * x;
sumX3 += x * x * x;
sumX4 += x * x * x * x;
sumY += y;
sumXY += x * y;
sumX2Y += x * x * y;
}
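// Closed-form simple linear regression (the quadratic term is dropped for numerical
// stability, so the higher-order sums above are currently unused):
//   b = (n*sumXY - sumX*sumY) / (n*sumX2 - sumX^2)
//   a = (sumY - b*sumX) / n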
// In practice, use a proper linear algebra library for a full quadratic fit
double b = (n * sumXY - sumX * sumY) / (n * sumX2 - sumX * sumX);
double a = (sumY - b * sumX) / n; // least-squares intercept rather than the plain mean of y
double c = 0; // quadratic term omitted for stability
return new ModelCoefficients(a, b, c);
}
private ConfidenceInterval calculateConfidenceInterval(
java.util.List<Measurement> measurements, ModelCoefficients coeff) {
// Calculate residuals
double sumSquaredError = 0;
for (Measurement m : measurements) {
double predicted = coeff.evaluate(m.inputSize);
double error = normalizeMeasurementTime(m) - predicted;
sumSquaredError += error * error;
}
// Calculate standard deviation
double stdDev = Math.sqrt(sumSquaredError / measurements.size());
// Approximate 95% interval (about 2 standard deviations), expressed as multipliers
// relative to the predicted time at a reference size of 1000
double lowerMultiplier = 1.0 - 2.0 * stdDev / coeff.evaluate(1000);
double upperMultiplier = 1.0 + 2.0 * stdDev / coeff.evaluate(1000);
// Bound the multipliers to reasonable ranges
lowerMultiplier = Math.max(0.5, Math.min(0.95, lowerMultiplier));
upperMultiplier = Math.max(1.05, Math.min(2.0, upperMultiplier));
return new ConfidenceInterval(lowerMultiplier, upperMultiplier);
}
private void updateConfidenceInterval(SizeRange range, int inputSize,
double actualTime, double normalizedTime) {
ConfidenceInterval current = confidenceIntervals.get(range);
if (current == null) {
current = new ConfidenceInterval(0.9, 1.1);
confidenceIntervals.put(range, current);
}
// Update interval based on prediction error
double predicted = coefficients.get(range).evaluate(inputSize);
double error = Math.abs(normalizedTime - predicted) / predicted;
// Exponential moving average update
current.updateWithError(error);
}
/**
* Size ranges for piecewise modeling.
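* <p>Example: {@code SizeRange.forSize(3000)} returns {@code MEDIUM} (1025-4096 samples).</p>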
*/
public enum SizeRange {
TINY(0, 256),
SMALL(257, 1024),
MEDIUM(1025, 4096),
LARGE(4097, 16384),
HUGE(16385, Integer.MAX_VALUE);
private final int minSize;
private final int maxSize;
SizeRange(int minSize, int maxSize) {
this.minSize = minSize;
this.maxSize = maxSize;
}
public static SizeRange forSize(int size) {
for (SizeRange range : values()) {
if (size >= range.minSize && size <= range.maxSize) {
return range;
}
}
return HUGE;
}
public int getMidpoint() {
if (this == HUGE) {
return 32768; // Representative size for huge range
}
return (minSize + maxSize) / 2;
}
}
/**
* Measurement data for calibration.
*/
public static class Measurement {
public final int inputSize;
public final double actualTime;
public final boolean hasVectorization;
public final double normalizedTime;
public Measurement(int inputSize, double actualTime, boolean hasVectorization) {
this.inputSize = inputSize;
this.actualTime = actualTime;
this.hasVectorization = hasVectorization;
this.normalizedTime = actualTime; // Raw time at construction; PerformanceModel normalizes internally when fitting
}
}
/**
* Serializable model data for persistence.
*/
public static class ModelData implements Serializable {
private static final long serialVersionUID = 1L;
public final ConcurrentHashMap<SizeRange, ModelCoefficients> coefficients;
public final ConcurrentHashMap<SizeRange, ConfidenceInterval> confidenceIntervals;
public final PlatformFactors platformFactors;
public final ModelAccuracy accuracy;
public final int measurementCount;
public final long lastCalibrationTime;
public ModelData(ConcurrentHashMap<SizeRange, ModelCoefficients> coefficients,
ConcurrentHashMap<SizeRange, ConfidenceInterval> confidenceIntervals,
PlatformFactors platformFactors,
ModelAccuracy accuracy,
int measurementCount,
long lastCalibrationTime) {
this.coefficients = coefficients;
this.confidenceIntervals = confidenceIntervals;
this.platformFactors = platformFactors;
this.accuracy = accuracy;
this.measurementCount = measurementCount;
this.lastCalibrationTime = lastCalibrationTime;
}
}
}