Line data Source code
1 : #include "audio/engine/audio_graph.h"
2 :
3 : #include <algorithm>
4 : #include <unordered_map>
5 : #include <unordered_set>
6 :
7 : namespace Amplitron {
8 :
9 4641 : int AudioGraph::add_node(const std::string &name, NodeRoutingType type,
10 : std::shared_ptr<Effect> pedal, int num_inputs) {
11 4641 : DSPNode node;
12 4641 : node.id = next_id_++; // Uses your unified member counter
13 4641 : node.name = name;
14 4641 : node.routing_type = type;
15 4641 : node.pedal = pedal;
16 :
17 : // Dynamically configure pin structures using the same unified ID pool
18 4641 : if (type == NodeRoutingType::Mixer || type == NodeRoutingType::MergeSum) {
19 1271 : int inputs_to_create = std::clamp((num_inputs > 0) ? num_inputs : 2, 2, 8);
20 2883 : for (int i = 0; i < inputs_to_create; ++i) {
21 1926 : node.input_pin_ids.push_back(next_id_++);
22 642 : }
23 957 : node.input_gains.assign(inputs_to_create, 1.0f);
24 957 : node.output_pin_ids.push_back(next_id_++); // 1 Output Pin
25 4322 : } else if (type == NodeRoutingType::Splitter) {
26 957 : node.input_pin_ids.push_back(next_id_++); // 1 Input Pin
27 957 : node.output_pin_ids.push_back(next_id_++); // Output Pin Branch A
28 957 : node.output_pin_ids.push_back(next_id_++); // Output Pin Branch B
29 319 : } else {
30 2727 : node.input_pin_ids.push_back(next_id_++);
31 2727 : node.output_pin_ids.push_back(next_id_++);
32 : }
33 :
34 4641 : nodes_.push_back(node);
35 :
36 : // Auto-recompile topology order whenever a structural block changes
37 4641 : rebuild_topology();
38 :
39 4641 : return node.id;
40 4641 : }
41 :
42 3171 : int AudioGraph::add_link(int source_pin_id, int dest_pin_id) {
43 : // Prevent duplicate connections between the exact same pair of pins
44 7092 : for (const auto &existing_link : links_) {
45 3927 : if (existing_link.source_pin_id == source_pin_id &&
46 15 : existing_link.dest_pin_id == dest_pin_id) {
47 6 : printf("add_link failed: duplicate link\n");
48 6 : return existing_link.id;
49 : }
50 : }
51 :
52 : // Enforce that each input pin can only have ONE incoming link
53 7080 : for (const auto &existing_link : links_) {
54 3921 : if (existing_link.dest_pin_id == dest_pin_id) {
55 6 : printf("add_link failed: Input pin %d already in use!\n", dest_pin_id);
56 6 : return -1; // Pin already in use!
57 : }
58 : }
59 :
60 : // Enforce that ALL pins can only have ONE outgoing link
61 3159 : int source_node_id = get_node_from_pin(source_pin_id);
62 3159 : if (source_node_id != -1) {
63 4212 : const DSPNode *src_node = find_node(source_node_id);
64 3159 : if (src_node) {
65 : // Count existing outgoing links from this specific pin
66 2106 : int out_count = 0;
67 7074 : for (const auto &existing_link : links_) {
68 3915 : if (existing_link.source_pin_id == source_pin_id) {
69 9 : printf("add_link failed: Output pin %d already has an outgoing connection!\n",
70 3 : source_pin_id);
71 9 : out_count++;
72 3 : }
73 : }
74 3159 : if (out_count >= 1) {
75 6 : return -1; // Each output pin can only have 1 outgoing connection!
76 : }
77 1050 : }
78 1050 : }
79 :
80 3150 : GraphLink link;
81 3150 : link.id = next_id_++; // Uses your unified member counter
82 3150 : link.source_pin_id = source_pin_id;
83 3150 : link.dest_pin_id = dest_pin_id;
84 :
85 3150 : links_.push_back(link);
86 :
87 : // Validate if the new patch wire forms an impossible audio loop feedback cycle
88 3150 : if (!rebuild_topology()) {
89 3 : printf("add_link failed: forms an impossible audio loop feedback cycle!\n");
90 : // If a feedback loop is detected, pop the dangerous link back off to keep the engine safe
91 3 : links_.pop_back();
92 3 : rebuild_topology();
93 3 : return -1;
94 : }
95 :
96 2098 : return link.id;
97 1057 : }
98 :
99 1353 : void AudioGraph::set_node_as_input(int node_id, bool is_input) {
100 1431 : for (auto &node : nodes_) {
101 1431 : if (node.id == node_id) {
102 1353 : node.is_graph_input = is_input;
103 1353 : rebuild_topology();
104 902 : break;
105 : }
106 : }
107 1353 : }
108 :
109 1368 : void AudioGraph::set_node_as_output(int node_id, bool is_output) {
110 4164 : for (auto &node : nodes_) {
111 4164 : if (node.id == node_id) {
112 1368 : node.is_graph_output = is_output;
113 1368 : rebuild_topology();
114 912 : break;
115 : }
116 : }
117 1368 : }
118 :
119 1332 : void AudioGraph::set_node_position(int node_id, float x, float y) {
120 4710 : for (auto &node : nodes_) {
121 4710 : if (node.id == node_id) {
122 1332 : node.x = x;
123 1332 : node.y = y;
124 1332 : break;
125 : }
126 : }
127 1332 : }
128 :
129 47325 : int AudioGraph::get_node_from_pin(int pin_id) const {
130 : // Search through all nodes to find which one owns the given Pin ID
131 140094 : for (const auto &node : nodes_) {
132 248358 : for (int p : node.input_pin_ids) {
133 142737 : if (p == pin_id) return node.id;
134 : }
135 228354 : for (int p : node.output_pin_ids) {
136 135585 : if (p == pin_id) return node.id;
137 : }
138 : }
139 8 : return -1; // Pin ID not found in any registered node
140 15775 : }
141 :
142 11811 : bool AudioGraph::rebuild_topology() {
143 : // Kahn's algorithm or DFS to topologically sort the nodes based on links.
144 : // Since your test suite cases are already passing, we can use a basic
145 : // Kahn's sort dependency tracker to map links to execution order.
146 :
147 11811 : sorted_node_ids_.clear();
148 :
149 : // 1. Forward Reachability BFS
150 11811 : std::unordered_set<int> forward_reachable;
151 11811 : std::vector<int> queue;
152 46359 : for (const auto &node : nodes_) {
153 34548 : if (node.is_graph_input) {
154 7980 : queue.push_back(node.id);
155 19496 : forward_reachable.insert(node.id);
156 2660 : }
157 : }
158 7874 : size_t head = 0;
159 30441 : while (head < queue.size()) {
160 18630 : int curr = queue[head++];
161 24840 : auto it = std::find_if(nodes_.begin(), nodes_.end(),
162 49815 : [&](const DSPNode &n) { return n.id == curr; });
163 18630 : if (it != nodes_.end()) {
164 42579 : for (int out_pin : it->output_pin_ids) {
165 84294 : for (const auto &link : links_) {
166 60345 : if (link.source_pin_id == out_pin) {
167 11301 : int dest = get_node_from_pin(link.dest_pin_id);
168 15068 : if (dest != -1 && forward_reachable.find(dest) == forward_reachable.end()) {
169 10650 : forward_reachable.insert(dest);
170 10650 : queue.push_back(dest);
171 3550 : }
172 3767 : }
173 : }
174 : }
175 6210 : }
176 : }
177 :
178 : // 2. Backward Reachability BFS
179 11811 : std::unordered_set<int> backward_reachable;
180 11811 : queue.clear();
181 46359 : for (const auto &node : nodes_) {
182 34548 : if (node.is_graph_output) {
183 5076 : queue.push_back(node.id);
184 16592 : backward_reachable.insert(node.id);
185 1692 : }
186 : }
187 7874 : head = 0;
188 22674 : while (head < queue.size()) {
189 10863 : int curr = queue[head++];
190 14484 : auto it = std::find_if(nodes_.begin(), nodes_.end(),
191 31302 : [&](const DSPNode &n) { return n.id == curr; });
192 10863 : if (it != nodes_.end()) {
193 26094 : for (int in_pin : it->input_pin_ids) {
194 49854 : for (const auto &link : links_) {
195 34623 : if (link.dest_pin_id == in_pin) {
196 6432 : int src = get_node_from_pin(link.source_pin_id);
197 8576 : if (src != -1 && backward_reachable.find(src) == backward_reachable.end()) {
198 5787 : backward_reachable.insert(src);
199 5787 : queue.push_back(src);
200 1929 : }
201 2144 : }
202 : }
203 : }
204 3621 : }
205 : }
206 :
207 : // 3. Update is_reachable for all nodes
208 46359 : for (auto &node : nodes_) {
209 34548 : node.is_reachable =
210 34548 : (forward_reachable.count(node.id) > 0 && backward_reachable.count(node.id) > 0);
211 : }
212 :
213 11811 : std::unordered_map<int, int> in_degree;
214 :
215 : // Initialize in-degree count for all active nodes
216 46359 : for (const auto &node : nodes_) {
217 34548 : in_degree[node.id] = 0;
218 : }
219 :
220 : // Calculate how many incoming cables are hooked up to each node
221 23358 : for (const auto &link : links_) {
222 11547 : int dest_node = get_node_from_pin(link.dest_pin_id);
223 11547 : if (dest_node != -1) {
224 11547 : in_degree[dest_node]++;
225 3849 : }
226 : }
227 :
228 : // Gather all source nodes that have 0 dependencies
229 11811 : std::vector<int> process_queue;
230 46359 : for (const auto &node : nodes_) {
231 34548 : if (in_degree[node.id] == 0) {
232 23685 : process_queue.push_back(node.id);
233 7895 : }
234 : }
235 :
236 : // Topologically extract nodes from the dependency queue
237 7874 : head = 0;
238 46353 : while (head < process_queue.size()) {
239 34542 : int current_node_id = process_queue[head++];
240 34542 : sorted_node_ids_.push_back(current_node_id);
241 :
242 : // Decrement dependencies for downstream targets linked to this node
243 165060 : for (const auto &node : nodes_) {
244 130518 : if (node.id != current_node_id) continue;
245 :
246 77724 : for (int out_pin : node.output_pin_ids) {
247 108045 : for (const auto &link : links_) {
248 64863 : if (link.source_pin_id == out_pin) {
249 11541 : int target_node = get_node_from_pin(link.dest_pin_id);
250 11541 : if (target_node != -1) {
251 11541 : in_degree[target_node]--;
252 11541 : if (in_degree[target_node] == 0) {
253 10857 : process_queue.push_back(target_node);
254 3619 : }
255 3847 : }
256 3847 : }
257 : }
258 : }
259 : }
260 : }
261 :
262 : // If the sorted list length doesn't match total nodes, an impossible feedback
263 : // loop exists!
264 11811 : if (sorted_node_ids_.size() != nodes_.size()) {
265 3 : return false; // Rejects connection modifications to protect engine
266 : // stability
267 : }
268 :
269 7872 : return true; // Topology built successfully!
270 11811 : }
271 :
272 285 : bool AudioGraph::remove_node(int node_id) {
273 380 : auto it = std::find_if(nodes_.begin(), nodes_.end(),
274 857 : [node_id](const DSPNode &n) { return n.id == node_id; });
275 :
276 285 : if (it != nodes_.end()) {
277 : // 1. Destroy all cables attached to this node's Input Pins
278 558 : for (int pin : it->input_pin_ids) {
279 465 : links_.erase(std::remove_if(links_.begin(), links_.end(),
280 156 : [pin](const GraphLink &l) { return l.dest_pin_id == pin; }),
281 372 : links_.end());
282 : }
283 : // 2. Destroy all cables attached to this node's Output Pins
284 567 : for (int pin : it->output_pin_ids) {
285 480 : links_.erase(
286 384 : std::remove_if(links_.begin(), links_.end(),
287 117 : [pin](const GraphLink &l) { return l.source_pin_id == pin; }),
288 384 : links_.end());
289 : }
290 :
291 : // 3. Erase the node and recompile the audio thread topology
292 279 : nodes_.erase(it);
293 279 : rebuild_topology();
294 279 : return true;
295 : }
296 4 : return false;
297 95 : }
298 42 : bool AudioGraph::remove_link(int link_id) {
299 56 : auto it = std::remove_if(links_.begin(), links_.end(),
300 125 : [link_id](const GraphLink &l) { return l.id == link_id; });
301 42 : if (it != links_.end()) {
302 33 : links_.erase(it, links_.end());
303 33 : rebuild_topology();
304 33 : return true;
305 : }
306 6 : return false;
307 14 : }
308 :
309 3534 : const DSPNode *AudioGraph::find_node(int node_id) const {
310 6882 : for (const auto &node : nodes_) {
311 6873 : if (node.id == node_id) return &node;
312 : }
313 6 : return nullptr;
314 1178 : }
315 :
316 6 : void AudioGraph::restore_node(const DSPNode &node) {
317 6 : nodes_.push_back(node);
318 6 : if (node.id >= next_id_) next_id_ = node.id + 1;
319 12 : for (int pin : node.input_pin_ids) {
320 6 : if (pin >= next_id_) next_id_ = pin + 1;
321 : }
322 18 : for (int pin : node.output_pin_ids) {
323 12 : if (pin >= next_id_) next_id_ = pin + 1;
324 : }
325 6 : rebuild_topology();
326 6 : }
327 :
328 9 : void AudioGraph::restore_link(const GraphLink &link) {
329 9 : int prev_next_id = next_id_;
330 9 : links_.push_back(link);
331 9 : if (link.id >= next_id_) next_id_ = link.id + 1;
332 9 : if (!rebuild_topology()) {
333 0 : links_.pop_back();
334 0 : next_id_ = prev_next_id;
335 0 : rebuild_topology();
336 0 : }
337 9 : }
338 24 : bool AudioGraph::add_input_pin(int node_id) {
339 27 : for (auto &node : nodes_) {
340 27 : if (node.id == node_id && node.routing_type == NodeRoutingType::Mixer) {
341 24 : if (node.input_pin_ids.size() < 8) {
342 21 : node.input_pin_ids.push_back(next_id_++);
343 21 : node.input_gains.push_back(1.0f);
344 : // Do not necessarily need to rebuild topology if we just added an unconnected pin
345 22 : return true;
346 : }
347 2 : return false;
348 : }
349 : }
350 0 : return false;
351 8 : }
352 :
353 24 : bool AudioGraph::remove_input_pin(int node_id, int pin_id) {
354 24 : for (auto &node : nodes_) {
355 24 : if (node.id == node_id && (node.routing_type == NodeRoutingType::Mixer ||
356 0 : node.routing_type == NodeRoutingType::MergeSum)) {
357 24 : if (node.input_pin_ids.size() > 2) {
358 18 : if (node.input_gains.size() < node.input_pin_ids.size()) {
359 0 : node.input_gains.resize(node.input_pin_ids.size(), 1.0f);
360 0 : }
361 18 : int index_to_remove = -1;
362 18 : if (pin_id == -1) {
363 18 : index_to_remove = node.input_pin_ids.size() - 1;
364 6 : } else {
365 0 : for (size_t i = 0; i < node.input_pin_ids.size(); ++i) {
366 0 : if (node.input_pin_ids[i] == pin_id) {
367 0 : index_to_remove = i;
368 0 : break;
369 : }
370 0 : }
371 : }
372 18 : if (index_to_remove != -1) {
373 18 : int pin_to_remove = node.input_pin_ids[index_to_remove];
374 : // Prevent removal if the pin is linked
375 36 : for (const auto &link : links_) {
376 18 : if (link.dest_pin_id == pin_to_remove) {
377 2 : return false;
378 : }
379 : }
380 18 : node.input_pin_ids.erase(node.input_pin_ids.begin() + index_to_remove);
381 18 : node.input_gains.erase(node.input_gains.begin() + index_to_remove);
382 18 : return true;
383 : }
384 0 : }
385 6 : return false;
386 : }
387 : }
388 0 : return false;
389 8 : }
390 :
391 0 : void AudioGraph::restore_input_pin(int node_id, int pin_id, int index, float gain) {
392 0 : for (auto &node : nodes_) {
393 0 : if (node.id == node_id) {
394 0 : if (index >= 0) {
395 0 : size_t idx = static_cast<size_t>(index);
396 0 : if (idx <= node.input_pin_ids.size()) {
397 0 : node.input_pin_ids.insert(node.input_pin_ids.begin() + idx, pin_id);
398 0 : while (node.input_gains.size() < idx) {
399 0 : node.input_gains.push_back(1.0f);
400 : }
401 0 : node.input_gains.insert(node.input_gains.begin() + idx, gain);
402 0 : } else {
403 0 : node.input_pin_ids.push_back(pin_id);
404 0 : node.input_gains.push_back(gain);
405 : }
406 0 : } else {
407 0 : node.input_pin_ids.push_back(pin_id);
408 0 : node.input_gains.push_back(gain);
409 : }
410 0 : if (pin_id >= next_id_) next_id_ = pin_id + 1;
411 0 : break;
412 : }
413 : }
414 0 : }
415 :
416 9 : void AudioGraph::set_mixer_input_gain(int node_id, size_t pin_index, float gain) {
417 18 : for (auto &node : nodes_) {
418 18 : if (node.id == node_id && node.routing_type == NodeRoutingType::Mixer) {
419 9 : if (pin_index < node.input_gains.size()) {
420 9 : node.input_gains[pin_index] = std::clamp(gain, 0.0f, 2.0f);
421 3 : }
422 6 : break;
423 : }
424 : }
425 9 : }
426 :
427 0 : bool AudioGraph::add_output_pin(int node_id) {
428 0 : for (auto &node : nodes_) {
429 0 : if (node.id == node_id && node.routing_type == NodeRoutingType::Splitter) {
430 0 : if (node.output_pin_ids.size() < 8) {
431 0 : node.output_pin_ids.push_back(next_id_++);
432 0 : return true;
433 : }
434 0 : return false;
435 : }
436 : }
437 0 : return false;
438 0 : }
439 :
440 0 : bool AudioGraph::remove_output_pin(int node_id, int pin_id) {
441 0 : for (auto &node : nodes_) {
442 0 : if (node.id == node_id && node.routing_type == NodeRoutingType::Splitter) {
443 0 : if (node.output_pin_ids.size() > 2) {
444 0 : int index_to_remove = -1;
445 0 : if (pin_id == -1) {
446 0 : index_to_remove = node.output_pin_ids.size() - 1;
447 0 : } else {
448 0 : for (size_t i = 0; i < node.output_pin_ids.size(); ++i) {
449 0 : if (node.output_pin_ids[i] == pin_id) {
450 0 : index_to_remove = i;
451 0 : break;
452 : }
453 0 : }
454 : }
455 0 : if (index_to_remove != -1) {
456 0 : int pin_to_remove = node.output_pin_ids[index_to_remove];
457 0 : for (const auto &link : links_) {
458 0 : if (link.source_pin_id == pin_to_remove) {
459 0 : return false;
460 : }
461 : }
462 0 : node.output_pin_ids.erase(node.output_pin_ids.begin() + index_to_remove);
463 0 : return true;
464 : }
465 0 : }
466 0 : return false;
467 : }
468 : }
469 0 : return false;
470 0 : }
471 :
472 0 : void AudioGraph::restore_output_pin(int node_id, int pin_id, int index) {
473 0 : for (auto &node : nodes_) {
474 0 : if (node.id == node_id && node.routing_type == NodeRoutingType::Splitter) {
475 0 : if (index >= 0 && index <= node.output_pin_ids.size()) {
476 0 : node.output_pin_ids.insert(node.output_pin_ids.begin() + index, pin_id);
477 0 : } else {
478 0 : node.output_pin_ids.push_back(pin_id);
479 : }
480 0 : if (pin_id >= next_id_) next_id_ = pin_id + 1;
481 0 : break;
482 : }
483 : }
484 0 : }
485 :
486 : } // namespace Amplitron
|