// MathUtils.java

package com.morphiqlabs.wavelet.util;

/**
 * Utility class providing general-purpose mathematical algorithms.
 *
 * <p>This class contains static methods for various mathematical operations
 * that are used throughout the wavelet library but are general enough to be
 * useful in other contexts.</p>
 *
 * <p>All methods are static and keep no state between calls. Note that
 * {@link #quickSelect(double[], int)} and {@link #median(double[])} reorder the
 * array they are given.</p>
 */
public final class MathUtils {

    // Private constructor to prevent instantiation
    private MathUtils() {
        throw new AssertionError("Utility class should not be instantiated");
    }

    /**
     * Applies symmetric boundary extension for a given index.
     *
     * <p>This method implements half-sample symmetric reflection, i.e. the boundary
     * sample itself is repeated in the extension. For a signal [a, b, c, d], the
     * extension is: ... b a | a b c d | d c b a | a b c d | d c ...</p>
     *
     * <p>The extended signal is periodic with period {@code 2 * signalLength}, so any
     * index is mapped with a single modular reduction instead of repeated reflection.
     * The arithmetic is carried out in {@code long} so that large signal lengths
     * cannot overflow.</p>
     *
     * @param idx the index to reflect (may be negative or beyond the signal end)
     * @param signalLength the length of the signal
     * @return the reflected index within [0, signalLength)
     * @throws IllegalArgumentException if {@code signalLength} is not positive
     */
    public static int symmetricBoundaryExtension(int idx, int signalLength) {
        if (signalLength <= 0) {
            throw new IllegalArgumentException("Signal length must be positive");
        }

        // Fast path: index already addresses a real sample
        if (idx >= 0 && idx < signalLength) {
            return idx;
        }

        // 2 * signalLength does not fit in an int once signalLength exceeds 2^30,
        // so the period and the reduced position are kept as longs.
        long period = 2L * signalLength;
        long position = Math.floorMod((long) idx, period);

        // The second half of each period is the mirrored copy of the signal
        if (position >= signalLength) {
            position = period - 1 - position;
        }

        // position is now in [0, signalLength), which always fits in an int
        return (int) position;
    }

    /**
     * Finds the kth smallest element in an array using the QuickSelect algorithm.
     * This is more efficient than full sorting when only a specific order statistic is needed.
     *
     * <p>Time complexity:</p>
     * <ul>
     *   <li>Average case: O(n)</li>
     *   <li>Worst case: O(n²)</li>
     * </ul>
     *
     * <p>Elements equal to the pivot are grouped in a single pass, so arrays with many
     * repeated values (for example all zeros) are handled in linear time. The selection
     * is iterative and uses constant extra space.</p>
     *
     * <p>The algorithm modifies the input array. If the original array must be preserved,
     * pass a copy of the array. On return, {@code arr[k]} holds the result, no element
     * before index {@code k} is greater than it and no element after is smaller.</p>
     *
     * <p>Elements are ordered with the primitive {@code <} and {@code >} operators, which
     * do not order NaN; the result is unspecified if the array contains NaN.</p>
     *
     * @param arr the array to select from (will be modified)
     * @param k the index of the desired element (0-based)
     * @return the kth smallest element
     * @throws IllegalArgumentException if k is out of bounds or array is null/empty
     */
    public static double quickSelect(double[] arr, int k) {
        if (arr == null) {
            throw new IllegalArgumentException("Array cannot be null");
        }
        if (arr.length == 0) {
            throw new IllegalArgumentException("Array cannot be empty");
        }
        if (k < 0 || k >= arr.length) {
            throw new IllegalArgumentException(
                String.format("k=%d is out of bounds [0, %d)", k, arr.length));
        }

        return quickSelectInternal(arr, 0, arr.length - 1, k);
    }

    /**
     * Finds the median of an array using QuickSelect.
     * More efficient than sorting for just finding the median.
     *
     * <p>For an even number of elements the median is the mean of the two middle
     * order statistics. Only one selection pass is needed: once the upper middle
     * element is in place, the lower middle element is the largest value to its left.</p>
     *
     * @param arr the array (will be modified)
     * @return the median value
     * @throws IllegalArgumentException if array is null or empty
     */
    public static double median(double[] arr) {
        if (arr == null || arr.length == 0) {
            throw new IllegalArgumentException("Array cannot be null or empty");
        }

        int n = arr.length;
        int mid = n / 2;
        double upper = quickSelectInternal(arr, 0, n - 1, mid);
        if ((n & 1) == 1) {
            // Odd length: the middle element is the median
            return upper;
        }

        // Even length: selection left every element before 'mid' no greater than
        // arr[mid], so the (mid - 1)th order statistic is the maximum of that prefix.
        double lower = arr[0];
        for (int i = 1; i < mid; i++) {
            if (arr[i] > lower) {
                lower = arr[i];
            }
        }
        return midpoint(lower, upper);
    }

    /**
     * Calculates the Median Absolute Deviation (MAD) of an array.
     * MAD is a robust measure of variability: median(|x_i - median(x)|)
     *
     * <p>The input array is not modified; all work is done on an internal copy.</p>
     *
     * @param values the input values
     * @return the median absolute deviation
     * @throws IllegalArgumentException if array is null or empty
     */
    public static double medianAbsoluteDeviation(double[] values) {
        if (values == null || values.length == 0) {
            throw new IllegalArgumentException("Array cannot be null or empty");
        }

        // Selection reorders its argument, so operate on a scratch copy
        double[] work = values.clone();
        double center = median(work);

        // Reuse the scratch buffer for the absolute deviations; read from the
        // untouched input because 'work' was reordered by the selection above.
        for (int i = 0; i < work.length; i++) {
            work[i] = Math.abs(values[i] - center);
        }

        return median(work);
    }

    /**
     * Computes the sample standard deviation of an array of values.
     *
     * <p>The mean is computed first, then the squared deviations from it are summed
     * and divided by {@code n - 1} (Bessel's correction) before taking the square root.</p>
     *
     * @param values the input values
     * @return the standard deviation
     * @throws IllegalArgumentException if array is null or has less than 2 elements
     */
    public static double standardDeviation(double[] values) {
        if (values == null) {
            throw new IllegalArgumentException("Array cannot be null");
        }
        if (values.length < 2) {
            throw new IllegalArgumentException("Need at least 2 values for standard deviation");
        }

        // First pass: mean
        double sum = 0.0;
        for (double value : values) {
            sum += value;
        }
        double mean = sum / values.length;

        // Second pass: sum of squared deviations from the mean
        double sumSquaredDiff = 0.0;
        for (double value : values) {
            double diff = value - mean;
            sumSquaredDiff += diff * diff;
        }

        return Math.sqrt(sumSquaredDiff / (values.length - 1));
    }

    /**
     * Internal QuickSelect implementation with explicit bounds.
     *
     * <p>Runs iteratively, narrowing {@code [lo, hi]} around {@code k}. Each round picks a
     * median-of-three pivot and splits the range into three zones: less than, equal to
     * and greater than the pivot. Because the pivot value is always present in the range,
     * the "equal" zone is never empty and the range strictly shrinks every round.</p>
     *
     * @param arr the array to select from (will be modified)
     * @param left the left boundary of the search range (inclusive)
     * @param right the right boundary of the search range (inclusive)
     * @param k the index of the desired element (0-based, must be within [left, right])
     * @return the kth smallest element in the range [left, right]
     */
    private static double quickSelectInternal(double[] arr, int left, int right, int k) {
        int lo = left;
        int hi = right;

        while (lo < hi) {
            int mid = lo + (hi - lo) / 2;
            double pivot = medianOfThree(arr[lo], arr[mid], arr[hi]);

            // Three-way partition:
            //   arr[lo .. lt-1]  < pivot
            //   arr[lt .. i-1]  == pivot
            //   arr[i  .. gt]    not yet examined
            //   arr[gt+1 .. hi]  > pivot
            int lt = lo;
            int gt = hi;
            int i = lo;
            while (i <= gt) {
                double current = arr[i];
                if (current < pivot) {
                    swap(arr, lt, i);
                    lt++;
                    i++;
                } else if (current > pivot) {
                    // The element swapped in from the right is unexamined: do not advance i
                    swap(arr, i, gt);
                    gt--;
                } else {
                    i++;
                }
            }

            if (k < lt) {
                hi = lt - 1;
            } else if (k > gt) {
                lo = gt + 1;
            } else {
                // k falls inside the run of pivot-equal values
                return arr[k];
            }
        }

        // Range collapsed to the single index k
        return arr[lo];
    }

    /**
     * Returns the median of three values; used to pick a pivot that avoids the
     * degenerate splits a fixed-position pivot produces on sorted input.
     *
     * @param a first candidate
     * @param b second candidate
     * @param c third candidate
     * @return the middle value of the three
     */
    private static double medianOfThree(double a, double b, double c) {
        if (a > b) {
            double tmp = a;
            a = b;
            b = tmp;
        }
        // Here a <= b. If c is below b, the median is the larger of a and c.
        if (b > c) {
            b = (a > c) ? a : c;
        }
        return b;
    }

    /**
     * Returns the arithmetic mean of two values without overflowing to infinity
     * when both inputs are finite but their sum is not representable.
     *
     * @param a first value
     * @param b second value
     * @return {@code (a + b) / 2}
     */
    private static double midpoint(double a, double b) {
        double sum = a + b;
        if (Double.isInfinite(sum) && !Double.isInfinite(a) && !Double.isInfinite(b)) {
            // Halve first: each half is finite and so is their sum
            return a / 2.0 + b / 2.0;
        }
        return sum / 2.0;
    }

    /**
     * Swaps two elements in an array.
     *
     * @param arr the array
     * @param i first index
     * @param j second index
     */
    private static void swap(double[] arr, int i, int j) {
        if (i != j) {
            double temp = arr[i];
            arr[i] = arr[j];
            arr[j] = temp;
        }
    }
}