Line data Source code
1 : #include "audio/dsp/convolution_engine.h"
2 :
3 : #include <algorithm>
4 : #include <cmath>
5 : #include <cstring>
6 :
7 : #include "kiss_fft.h"
8 :
9 : namespace Amplitron {
10 :
11 : // Helper: complex multiply-accumulate (accum += a * b)
12 12 : static void complex_multiply_accumulate(kiss_fft_cpx* accum, const kiss_fft_cpx* a,
13 : const kiss_fft_cpx* b, int n) {
14 6156 : for (int i = 0; i < n; ++i) {
15 6144 : accum[i].r += a[i].r * b[i].r - a[i].i * b[i].i;
16 6144 : accum[i].i += a[i].r * b[i].i + a[i].i * b[i].r;
17 2048 : }
18 12 : }
19 :
20 : // =============================================================================
21 : // ConvolutionKernel
22 : // =============================================================================
23 :
24 104 : ConvolutionKernel::ConvolutionKernel(const std::vector<float>& ir_samples, int block_size)
25 52 : : block_size_(block_size),
26 52 : fft_size_(block_size * 2),
27 78 : ir_length_(static_cast<int>(ir_samples.size())),
28 78 : ir_time_(ir_samples) {
29 78 : if (ir_samples.empty() || block_size <= 0) {
30 3 : num_partitions_ = 0;
31 3 : return;
32 : }
33 :
34 : // Number of partitions = ceil(ir_length / block_size)
35 75 : num_partitions_ = (ir_length_ + block_size_ - 1) / block_size_;
36 :
37 : // Allocate forward FFT config
38 75 : fft_cfg_ = kiss_fft_alloc(fft_size_, 0, nullptr, nullptr);
39 :
40 : // Pre-compute frequency-domain representation of each partition
41 100 : std::vector<kiss_fft_cpx> time_buf(static_cast<size_t>(fft_size_));
42 75 : std::vector<kiss_fft_cpx> freq_buf(static_cast<size_t>(fft_size_));
43 :
44 75 : partitions_freq_.resize(static_cast<size_t>(num_partitions_));
45 :
46 162 : for (int p = 0; p < num_partitions_; ++p) {
47 : // Zero the time buffer
48 87 : std::fill(time_buf.begin(), time_buf.end(), kiss_fft_cpx{0.0f, 0.0f});
49 :
50 : // Copy this partition's IR samples into the real part
51 87 : int offset = p * block_size_;
52 87 : int count = std::min(block_size_, ir_length_ - offset);
53 4830 : for (int i = 0; i < count; ++i) {
54 4743 : time_buf[static_cast<size_t>(i)].r = ir_samples[static_cast<size_t>(offset + i)];
55 4743 : time_buf[static_cast<size_t>(i)].i = 0.0f;
56 1581 : }
57 :
58 : // Forward FFT
59 87 : kiss_fft(fft_cfg_, time_buf.data(), freq_buf.data());
60 :
61 : // Store the frequency-domain partition
62 87 : size_t byte_size = sizeof(kiss_fft_cpx) * static_cast<size_t>(fft_size_);
63 87 : partitions_freq_[static_cast<size_t>(p)].resize(byte_size);
64 87 : std::memcpy(partitions_freq_[static_cast<size_t>(p)].data(), freq_buf.data(), byte_size);
65 29 : }
66 76 : }
67 :
68 104 : ConvolutionKernel::~ConvolutionKernel() {
69 78 : if (fft_cfg_) kiss_fft_free(fft_cfg_);
70 104 : }
71 :
72 21 : const void* ConvolutionKernel::partition_freq(int index) const {
73 17 : if (index < 0 || index >= num_partitions_) return nullptr;
74 15 : return partitions_freq_[static_cast<size_t>(index)].data();
75 7 : }
76 :
77 : // =============================================================================
78 : // ConvolutionEngine
79 : // =============================================================================
80 :
81 180 : ConvolutionEngine::ConvolutionEngine() = default;
82 :
83 180 : ConvolutionEngine::~ConvolutionEngine() { cleanup_fft(); }
84 :
85 33 : void ConvolutionEngine::init_fft(int fft_size) {
86 33 : cleanup_fft();
87 33 : fft_cfg_ = kiss_fft_alloc(fft_size, 0, nullptr, nullptr); // forward
88 33 : ifft_cfg_ = kiss_fft_alloc(fft_size, 1, nullptr, nullptr); // inverse
89 33 : current_fft_size_ = fft_size;
90 33 : }
91 :
92 210 : void ConvolutionEngine::cleanup_fft() {
93 210 : if (fft_cfg_) {
94 33 : kiss_fft_free(fft_cfg_);
95 33 : fft_cfg_ = nullptr;
96 11 : }
97 210 : if (ifft_cfg_) {
98 33 : kiss_fft_free(ifft_cfg_);
99 33 : ifft_cfg_ = nullptr;
100 11 : }
101 210 : current_fft_size_ = 0;
102 210 : }
103 :
104 36 : void ConvolutionEngine::set_kernel(const ConvolutionKernel* kernel) {
105 36 : kernel_ = kernel;
106 36 : reset();
107 36 : }
108 :
109 78 : void ConvolutionEngine::reset() {
110 78 : if (!kernel_) {
111 42 : cleanup_fft();
112 42 : fdl_.clear();
113 42 : overlap_.clear();
114 42 : direct_input_.clear();
115 42 : direct_overlap_.clear();
116 42 : input_cpx_.clear();
117 42 : accum_cpx_.clear();
118 42 : ifft_out_cpx_.clear();
119 42 : fdl_index_ = 0;
120 42 : return;
121 : }
122 :
123 36 : int fft_size = kernel_->fft_size();
124 36 : int num_parts = kernel_->num_partitions();
125 :
126 : // Initialize FFT if size changed
127 36 : if (current_fft_size_ != fft_size) {
128 33 : init_fft(fft_size);
129 11 : }
130 :
131 : // Initialize frequency-domain delay line
132 36 : if (fft_size <= 0 || fft_size > 65536) return; // Sanity check
133 :
134 36 : size_t cpx_bytes = sizeof(kiss_fft_cpx) * static_cast<size_t>(fft_size);
135 36 : fdl_.resize(static_cast<size_t>(num_parts));
136 81 : for (auto& buf : fdl_) {
137 45 : buf.assign(cpx_bytes, 0);
138 : }
139 36 : fdl_index_ = 0;
140 :
141 : // Initialize overlap buffer
142 36 : overlap_.assign(static_cast<size_t>(kernel_->block_size()), 0.0f);
143 :
144 : // Initialize direct convolution overlap
145 36 : int ir_len = kernel_->ir_length();
146 36 : if (ir_len > 0) {
147 36 : direct_overlap_.assign(static_cast<size_t>(ir_len - 1), 0.0f);
148 12 : } else {
149 0 : direct_overlap_.clear();
150 : }
151 :
152 : // Scratch input copy for direct convolution fallback
153 36 : direct_input_.assign(static_cast<size_t>(kernel_->block_size()), 0.0f);
154 :
155 : // Initialize FFT workspace buffers (allocation-free during process())
156 36 : input_cpx_.assign(cpx_bytes, 0);
157 36 : accum_cpx_.assign(cpx_bytes, 0);
158 36 : ifft_out_cpx_.assign(cpx_bytes, 0);
159 26 : }
160 :
161 3033 : void ConvolutionEngine::process_direct(float* buffer, int num_samples) {
162 3033 : const auto& ir = kernel_->ir_time_domain();
163 3033 : int ir_len = static_cast<int>(ir.size());
164 3033 : if (ir_len == 0) return;
165 :
166 : // Output length = num_samples + ir_len - 1.
167 : // We output num_samples and carry over the tail (overlap-add) in direct_overlap_.
168 3033 : const int tail_len = ir_len - 1;
169 3033 : if (tail_len <= 0) return;
170 :
171 : // direct_overlap_ should be pre-sized in reset(); avoid allocations here.
172 3027 : if (static_cast<int>(direct_overlap_.size()) != tail_len) return;
173 :
174 : // Need original input; avoid per-call allocations by using direct_input_ scratch.
175 3027 : if (static_cast<int>(direct_input_.size()) < num_samples) return;
176 3027 : std::memcpy(direct_input_.data(), buffer, sizeof(float) * static_cast<size_t>(num_samples));
177 :
178 : // First num_samples samples: convolution + previous overlap
179 776883 : for (int n = 0; n < num_samples; ++n) {
180 773856 : float y = (n < tail_len) ? direct_overlap_[static_cast<size_t>(n)] : 0.0f;
181 :
182 : // y += sum_{j=0}^{ir_len-1} x[n-j] * h[j]
183 773856 : int j0 = std::max(0, n - (num_samples - 1));
184 773856 : int j1 = std::min(ir_len - 1, n);
185 3087390 : for (int j = j0; j <= j1; ++j) {
186 2313534 : y += direct_input_[static_cast<size_t>(n - j)] * ir[static_cast<size_t>(j)];
187 771178 : }
188 773856 : buffer[n] = y;
189 257952 : }
190 :
191 : // Tail samples for next block: y[num_samples .. num_samples+tail_len-1]
192 9081 : for (int t = 0; t < tail_len; ++t) {
193 6054 : const int n = num_samples + t;
194 6054 : float y = 0.0f;
195 :
196 : // y += sum_{j=0}^{ir_len-1} x[n-j] * h[j], where (n-j) in [0, num_samples-1]
197 6054 : int j0 = std::max(0, n - (num_samples - 1));
198 6054 : int j1 = std::min(ir_len - 1, n);
199 15144 : for (int j = j0; j <= j1; ++j) {
200 9090 : int xi = n - j;
201 9090 : if (xi >= 0 && xi < num_samples) {
202 9090 : y += direct_input_[static_cast<size_t>(xi)] * ir[static_cast<size_t>(j)];
203 3030 : }
204 3030 : }
205 6054 : direct_overlap_[static_cast<size_t>(t)] = y;
206 2018 : }
207 1011 : }
208 :
209 3036 : void ConvolutionEngine::process(float* buffer, int num_samples) {
210 3036 : if (!kernel_ || kernel_->ir_length() == 0) return;
211 :
212 3036 : int block_size = kernel_->block_size();
213 3036 : int fft_size = kernel_->fft_size();
214 3036 : int num_parts = kernel_->num_partitions();
215 :
216 : // If block size doesn't match, fall back to direct convolution
217 3036 : if (num_samples != block_size) {
218 3 : process_direct(buffer, num_samples);
219 3 : return;
220 : }
221 :
222 : // Also use direct convolution for very short IRs (1 partition, IR <= block_size)
223 3033 : if (num_parts == 1 && kernel_->ir_length() <= block_size) {
224 3030 : process_direct(buffer, num_samples);
225 3030 : return;
226 : }
227 :
228 : // --- Partitioned overlap-add convolution ---
229 :
230 : // 1. Prepare input: zero-pad to fft_size
231 3 : auto* input_cpx = reinterpret_cast<kiss_fft_cpx*>(input_cpx_.data());
232 771 : for (int i = 0; i < block_size; ++i) {
233 768 : input_cpx[static_cast<size_t>(i)].r = buffer[i];
234 768 : input_cpx[static_cast<size_t>(i)].i = 0.0f;
235 256 : }
236 771 : for (int i = block_size; i < fft_size; ++i) {
237 768 : input_cpx[static_cast<size_t>(i)].r = 0.0f;
238 768 : input_cpx[static_cast<size_t>(i)].i = 0.0f;
239 256 : }
240 :
241 : // 2. Forward FFT of input -> store in FDL at current index
242 3 : auto* fdl_data = reinterpret_cast<kiss_fft_cpx*>(fdl_[static_cast<size_t>(fdl_index_)].data());
243 3 : kiss_fft(fft_cfg_, input_cpx, fdl_data);
244 :
245 : // 3. Complex multiply-accumulate across all partitions
246 3 : auto* accum = reinterpret_cast<kiss_fft_cpx*>(accum_cpx_.data());
247 3 : std::memset(accum, 0, sizeof(kiss_fft_cpx) * static_cast<size_t>(fft_size));
248 :
249 15 : for (int k = 0; k < num_parts; ++k) {
250 : // FDL index for partition k (circular)
251 12 : int fdl_idx = (fdl_index_ - k + num_parts) % num_parts;
252 8 : const auto* fdl_block =
253 12 : reinterpret_cast<const kiss_fft_cpx*>(fdl_[static_cast<size_t>(fdl_idx)].data());
254 12 : const auto* ir_block = reinterpret_cast<const kiss_fft_cpx*>(kernel_->partition_freq(k));
255 :
256 12 : complex_multiply_accumulate(accum, fdl_block, ir_block, fft_size);
257 4 : }
258 :
259 : // 4. Inverse FFT
260 3 : auto* ifft_out = reinterpret_cast<kiss_fft_cpx*>(ifft_out_cpx_.data());
261 3 : kiss_fft(ifft_cfg_, accum, ifft_out);
262 :
263 : // kiss_fft inverse does NOT normalize -- divide by fft_size
264 3 : float norm = 1.0f / static_cast<float>(fft_size);
265 :
266 : // 5. Output = first block_size samples + overlap from previous block
267 771 : for (int i = 0; i < block_size; ++i) {
268 768 : buffer[i] = ifft_out[static_cast<size_t>(i)].r * norm + overlap_[static_cast<size_t>(i)];
269 256 : }
270 :
271 : // 6. Store new overlap (last block_size samples of IFFT result)
272 771 : for (int i = 0; i < block_size; ++i) {
273 768 : overlap_[static_cast<size_t>(i)] = ifft_out[static_cast<size_t>(block_size + i)].r * norm;
274 256 : }
275 :
276 : // 7. Advance FDL index
277 3 : fdl_index_ = (fdl_index_ + 1) % num_parts;
278 1012 : }
279 :
280 : } // namespace Amplitron
|