LCOV - code coverage report
Current view: top level - src/audio/dsp - convolution_engine.cpp (source / functions) Coverage Total Hit
Test: merged.info Lines: 99.4 % 173 172
Test Date: 2026-06-01 11:15:25 Functions: 100.0 % 12 12

            Line data    Source code
       1              : #include "audio/dsp/convolution_engine.h"
       2              : #include "kiss_fft.h"
       3              : #include <cstring>
       4              : #include <cmath>
       5              : #include <algorithm>
       6              : 
       7              : namespace Amplitron {
       8              : 
       9              : // Helper: complex multiply-accumulate (accum += a * b)
      10           12 : static void complex_multiply_accumulate(kiss_fft_cpx* accum,
      11              :                                          const kiss_fft_cpx* a,
      12              :                                          const kiss_fft_cpx* b,
      13              :                                          int n) {
      14         6156 :     for (int i = 0; i < n; ++i) {
      15         6144 :         accum[i].r += a[i].r * b[i].r - a[i].i * b[i].i;
      16         6144 :         accum[i].i += a[i].r * b[i].i + a[i].i * b[i].r;
      17         2048 :     }
      18           12 : }
      19              : 
      20              : // =============================================================================
      21              : // ConvolutionKernel
      22              : // =============================================================================
      23              : 
      24          104 : ConvolutionKernel::ConvolutionKernel(const std::vector<float>& ir_samples,
      25           52 :                                      int block_size)
      26           52 :     : block_size_(block_size)
      27           52 :     , fft_size_(block_size * 2)
      28           78 :     , ir_length_(static_cast<int>(ir_samples.size()))
      29           78 :     , ir_time_(ir_samples) {
      30              : 
      31           78 :     if (ir_samples.empty() || block_size <= 0) {
      32            3 :         num_partitions_ = 0;
      33            3 :         return;
      34              :     }
      35              : 
      36              :     // Number of partitions = ceil(ir_length / block_size)
      37           75 :     num_partitions_ = (ir_length_ + block_size_ - 1) / block_size_;
      38              : 
      39              :     // Allocate forward FFT config
      40           75 :     fft_cfg_ = kiss_fft_alloc(fft_size_, 0, nullptr, nullptr);
      41              : 
      42              :     // Pre-compute frequency-domain representation of each partition
      43          100 :     std::vector<kiss_fft_cpx> time_buf(static_cast<size_t>(fft_size_));
      44           75 :     std::vector<kiss_fft_cpx> freq_buf(static_cast<size_t>(fft_size_));
      45              : 
      46           75 :     partitions_freq_.resize(static_cast<size_t>(num_partitions_));
      47              : 
      48          162 :     for (int p = 0; p < num_partitions_; ++p) {
      49              :         // Zero the time buffer
      50           87 :         std::fill(time_buf.begin(), time_buf.end(), kiss_fft_cpx{0.0f, 0.0f});
      51              : 
      52              :         // Copy this partition's IR samples into the real part
      53           87 :         int offset = p * block_size_;
      54           87 :         int count = std::min(block_size_, ir_length_ - offset);
      55         4830 :         for (int i = 0; i < count; ++i) {
      56         4743 :             time_buf[static_cast<size_t>(i)].r = ir_samples[static_cast<size_t>(offset + i)];
      57         4743 :             time_buf[static_cast<size_t>(i)].i = 0.0f;
      58         1581 :         }
      59              : 
      60              :         // Forward FFT
      61           87 :         kiss_fft(fft_cfg_, time_buf.data(), freq_buf.data());
      62              : 
      63              :         // Store the frequency-domain partition
      64           87 :         size_t byte_size = sizeof(kiss_fft_cpx) * static_cast<size_t>(fft_size_);
      65           87 :         partitions_freq_[static_cast<size_t>(p)].resize(byte_size);
      66          116 :         std::memcpy(partitions_freq_[static_cast<size_t>(p)].data(),
      67           87 :                     freq_buf.data(), byte_size);
      68           29 :     }
      69           76 : }
      70              : 
      71          104 : ConvolutionKernel::~ConvolutionKernel() {
      72           78 :     if (fft_cfg_) kiss_fft_free(fft_cfg_);
      73          104 : }
      74              : 
      75           21 : const void* ConvolutionKernel::partition_freq(int index) const {
      76           17 :     if (index < 0 || index >= num_partitions_) return nullptr;
      77           15 :     return partitions_freq_[static_cast<size_t>(index)].data();
      78            7 : }
      79              : 
      80              : // =============================================================================
      81              : // ConvolutionEngine
      82              : // =============================================================================
      83              : 
      84          180 : ConvolutionEngine::ConvolutionEngine() = default;
      85              : 
      86          180 : ConvolutionEngine::~ConvolutionEngine() {
      87          135 :     cleanup_fft();
      88          180 : }
      89              : 
      90           33 : void ConvolutionEngine::init_fft(int fft_size) {
      91           33 :     cleanup_fft();
      92           33 :     fft_cfg_ = kiss_fft_alloc(fft_size, 0, nullptr, nullptr);   // forward
      93           33 :     ifft_cfg_ = kiss_fft_alloc(fft_size, 1, nullptr, nullptr);  // inverse
      94           33 :     current_fft_size_ = fft_size;
      95           33 : }
      96              : 
      97          210 : void ConvolutionEngine::cleanup_fft() {
      98          210 :     if (fft_cfg_) { kiss_fft_free(fft_cfg_); fft_cfg_ = nullptr; }
      99          210 :     if (ifft_cfg_) { kiss_fft_free(ifft_cfg_); ifft_cfg_ = nullptr; }
     100          210 :     current_fft_size_ = 0;
     101          210 : }
     102              : 
     103           36 : void ConvolutionEngine::set_kernel(const ConvolutionKernel* kernel) {
     104           36 :     kernel_ = kernel;
     105           36 :     reset();
     106           36 : }
     107              : 
     108           78 : void ConvolutionEngine::reset() {
     109           78 :     if (!kernel_) {
     110           42 :         cleanup_fft();
     111           42 :         fdl_.clear();
     112           42 :         overlap_.clear();
     113           42 :         direct_input_.clear();
     114           42 :         direct_overlap_.clear();
     115           42 :         input_cpx_.clear();
     116           42 :         accum_cpx_.clear();
     117           42 :         ifft_out_cpx_.clear();
     118           42 :         fdl_index_ = 0;
     119           42 :         return;
     120              :     }
     121              : 
     122           36 :     int fft_size = kernel_->fft_size();
     123           36 :     int num_parts = kernel_->num_partitions();
     124              : 
     125              :     // Initialize FFT if size changed
     126           36 :     if (current_fft_size_ != fft_size) {
     127           33 :         init_fft(fft_size);
     128           11 :     }
     129              : 
     130              :     // Initialize frequency-domain delay line
     131           36 :     if (fft_size <= 0 || fft_size > 65536) return; // Sanity check
     132              : 
     133           36 :     size_t cpx_bytes = sizeof(kiss_fft_cpx) * static_cast<size_t>(fft_size);
     134           36 :     fdl_.resize(static_cast<size_t>(num_parts));
     135           81 :     for (auto& buf : fdl_) {
     136           45 :         buf.assign(cpx_bytes, 0);
     137              :     }
     138           36 :     fdl_index_ = 0;
     139              : 
     140              :     // Initialize overlap buffer
     141           36 :     overlap_.assign(static_cast<size_t>(kernel_->block_size()), 0.0f);
     142              : 
     143              :     // Initialize direct convolution overlap
     144           36 :     int ir_len = kernel_->ir_length();
     145           36 :     if (ir_len > 0) {
     146           36 :         direct_overlap_.assign(static_cast<size_t>(ir_len - 1), 0.0f);
     147           12 :     } else {
     148            0 :         direct_overlap_.clear();
     149              :     }
     150              : 
     151              :     // Scratch input copy for direct convolution fallback
     152           36 :     direct_input_.assign(static_cast<size_t>(kernel_->block_size()), 0.0f);
     153              : 
     154              :     // Initialize FFT workspace buffers (allocation-free during process())
     155           36 :     input_cpx_.assign(cpx_bytes, 0);
     156           36 :     accum_cpx_.assign(cpx_bytes, 0);
     157           36 :     ifft_out_cpx_.assign(cpx_bytes, 0);
     158           26 : }
     159              : 
     160         3033 : void ConvolutionEngine::process_direct(float* buffer, int num_samples) {
     161         3033 :     const auto& ir = kernel_->ir_time_domain();
     162         3033 :     int ir_len = static_cast<int>(ir.size());
     163         3033 :     if (ir_len == 0) return;
     164              : 
     165              :     // Output length = num_samples + ir_len - 1.
     166              :     // We output num_samples and carry over the tail (overlap-add) in direct_overlap_.
     167         3033 :     const int tail_len = ir_len - 1;
     168         3033 :     if (tail_len <= 0) return;
     169              : 
     170              :     // direct_overlap_ should be pre-sized in reset(); avoid allocations here.
     171         3027 :     if (static_cast<int>(direct_overlap_.size()) != tail_len) return;
     172              : 
     173              :     // Need original input; avoid per-call allocations by using direct_input_ scratch.
     174         3027 :     if (static_cast<int>(direct_input_.size()) < num_samples) return;
     175         3027 :     std::memcpy(direct_input_.data(), buffer, sizeof(float) * static_cast<size_t>(num_samples));
     176              : 
     177              :     // First num_samples samples: convolution + previous overlap
     178       776883 :     for (int n = 0; n < num_samples; ++n) {
     179       773856 :         float y = (n < tail_len) ? direct_overlap_[static_cast<size_t>(n)] : 0.0f;
     180              : 
     181              :         // y += sum_{j=0}^{ir_len-1} x[n-j] * h[j]
     182       773856 :         int j0 = std::max(0, n - (num_samples - 1));
     183       773856 :         int j1 = std::min(ir_len - 1, n);
     184      3087390 :         for (int j = j0; j <= j1; ++j) {
     185      2313534 :             y += direct_input_[static_cast<size_t>(n - j)] * ir[static_cast<size_t>(j)];
     186       771178 :         }
     187       773856 :         buffer[n] = y;
     188       257952 :     }
     189              : 
     190              :     // Tail samples for next block: y[num_samples .. num_samples+tail_len-1]
     191         9081 :     for (int t = 0; t < tail_len; ++t) {
     192         6054 :         const int n = num_samples + t;
     193         6054 :         float y = 0.0f;
     194              : 
     195              :         // y += sum_{j=0}^{ir_len-1} x[n-j] * h[j], where (n-j) in [0, num_samples-1]
     196         6054 :         int j0 = std::max(0, n - (num_samples - 1));
     197         6054 :         int j1 = std::min(ir_len - 1, n);
     198        15144 :         for (int j = j0; j <= j1; ++j) {
     199         9090 :             int xi = n - j;
     200         9090 :             if (xi >= 0 && xi < num_samples) {
     201         9090 :                 y += direct_input_[static_cast<size_t>(xi)] * ir[static_cast<size_t>(j)];
     202         3030 :             }
     203         3030 :         }
     204         6054 :         direct_overlap_[static_cast<size_t>(t)] = y;
     205         2018 :     }
     206         1011 : }
     207              : 
     208         3036 : void ConvolutionEngine::process(float* buffer, int num_samples) {
     209         3036 :     if (!kernel_ || kernel_->ir_length() == 0) return;
     210              : 
     211         3036 :     int block_size = kernel_->block_size();
     212         3036 :     int fft_size = kernel_->fft_size();
     213         3036 :     int num_parts = kernel_->num_partitions();
     214              : 
     215              :     // If block size doesn't match, fall back to direct convolution
     216         3036 :     if (num_samples != block_size) {
     217            3 :         process_direct(buffer, num_samples);
     218            3 :         return;
     219              :     }
     220              : 
     221              :     // Also use direct convolution for very short IRs (1 partition, IR <= block_size)
     222         3033 :     if (num_parts == 1 && kernel_->ir_length() <= block_size) {
     223         3030 :         process_direct(buffer, num_samples);
     224         3030 :         return;
     225              :     }
     226              : 
     227              :     // --- Partitioned overlap-add convolution ---
     228              : 
     229              :     // 1. Prepare input: zero-pad to fft_size
     230            3 :     auto* input_cpx = reinterpret_cast<kiss_fft_cpx*>(input_cpx_.data());
     231          771 :     for (int i = 0; i < block_size; ++i) {
     232          768 :         input_cpx[static_cast<size_t>(i)].r = buffer[i];
     233          768 :         input_cpx[static_cast<size_t>(i)].i = 0.0f;
     234          256 :     }
     235          771 :     for (int i = block_size; i < fft_size; ++i) {
     236          768 :         input_cpx[static_cast<size_t>(i)].r = 0.0f;
     237          768 :         input_cpx[static_cast<size_t>(i)].i = 0.0f;
     238          256 :     }
     239              : 
     240              :     // 2. Forward FFT of input -> store in FDL at current index
     241            2 :     auto* fdl_data = reinterpret_cast<kiss_fft_cpx*>(
     242            3 :         fdl_[static_cast<size_t>(fdl_index_)].data());
     243            3 :     kiss_fft(fft_cfg_, input_cpx, fdl_data);
     244              : 
     245              :     // 3. Complex multiply-accumulate across all partitions
     246            3 :     auto* accum = reinterpret_cast<kiss_fft_cpx*>(accum_cpx_.data());
     247            3 :     std::memset(accum, 0, sizeof(kiss_fft_cpx) * static_cast<size_t>(fft_size));
     248              : 
     249           15 :     for (int k = 0; k < num_parts; ++k) {
     250              :         // FDL index for partition k (circular)
     251           12 :         int fdl_idx = (fdl_index_ - k + num_parts) % num_parts;
     252            8 :         const auto* fdl_block = reinterpret_cast<const kiss_fft_cpx*>(
     253           12 :             fdl_[static_cast<size_t>(fdl_idx)].data());
     254            8 :         const auto* ir_block = reinterpret_cast<const kiss_fft_cpx*>(
     255           12 :             kernel_->partition_freq(k));
     256              : 
     257           12 :         complex_multiply_accumulate(accum, fdl_block, ir_block, fft_size);
     258            4 :     }
     259              : 
     260              :     // 4. Inverse FFT
     261            3 :     auto* ifft_out = reinterpret_cast<kiss_fft_cpx*>(ifft_out_cpx_.data());
     262            3 :     kiss_fft(ifft_cfg_, accum, ifft_out);
     263              : 
     264              :     // kiss_fft inverse does NOT normalize -- divide by fft_size
     265            3 :     float norm = 1.0f / static_cast<float>(fft_size);
     266              : 
     267              :     // 5. Output = first block_size samples + overlap from previous block
     268          771 :     for (int i = 0; i < block_size; ++i) {
     269         1024 :         buffer[i] = ifft_out[static_cast<size_t>(i)].r * norm +
     270          768 :                     overlap_[static_cast<size_t>(i)];
     271          256 :     }
     272              : 
     273              :     // 6. Store new overlap (last block_size samples of IFFT result)
     274          771 :     for (int i = 0; i < block_size; ++i) {
     275          768 :         overlap_[static_cast<size_t>(i)] =
     276          768 :             ifft_out[static_cast<size_t>(block_size + i)].r * norm;
     277          256 :     }
     278              : 
     279              :     // 7. Advance FDL index
     280            3 :     fdl_index_ = (fdl_index_ + 1) % num_parts;
     281         1012 : }
     282              : 
     283              : } // namespace Amplitron
        

Generated by: LCOV version 2.0-1