Skip to content

Commit

Permalink
Do cross-correlation based matching
Browse files Browse the repository at this point in the history
Also adds a 2 channel test and goes back to sinc resampler
  • Loading branch information
xd009642 committed Sep 4, 2024
1 parent f843d84 commit 69e5d80
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 30 deletions.
10 changes: 0 additions & 10 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ tracing-opentelemetry = "0.24.0"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }

[dev-dependencies]
approx = "0.5.1"
dasp = { version = "0.11.0", features = ["signal"] }
hound = "3.5.1"
tracing-test = { version = "0.2.4", features = ["no-env-filter"] }
Expand Down
113 changes: 94 additions & 19 deletions src/audio.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use crate::api_types::AudioFormat;
use crate::AudioChannel;
use bytes::Bytes;
use rubato::{FftFixedIn, Resampler};
use rubato::{
calculate_cutoff, Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType,
WindowFunction,
};
use tokio::sync::mpsc;
use tracing::{instrument, trace};

Expand All @@ -22,20 +25,29 @@ pub async fn decode_audio(
anyhow::bail!("No output sinks for channel data");
}

const RESAMPLER_SIZE: usize = 2048;
const RESAMPLER_SIZE: usize = 4086;

let resample_ratio = 16000.0 / audio_format.sample_rate as f64;

trace!("Resampler ratio: {}", resample_ratio);

let mut resampler = if audio_format.sample_rate != 16000 {
let resampler = FftFixedIn::new(
audio_format.sample_rate as usize,
16000,
let window = WindowFunction::Blackman;
let params = SincInterpolationParameters {
sinc_len: 256,
f_cutoff: calculate_cutoff(256, window),
oversampling_factor: 128,
interpolation: SincInterpolationType::Cubic,
window,
};
let resampler = SincFixedIn::new(
16000.0 / audio_format.sample_rate as f64,
1.0,
params,
RESAMPLER_SIZE,
1024,
audio_format.channels,
)?;

trace!(
input_frames_max = resampler.input_frames_max(),
output_frames_max = resampler.output_frames_next(),
Expand Down Expand Up @@ -69,7 +81,6 @@ pub async fn decode_audio(
current_buffer.append(&mut samples);

if current_buffer.len() >= resample_trigger_len || resampler.is_none() {
let len = current_buffer.len();
let capacity = RESAMPLER_SIZE.min(current_buffer.len() / audio_format.channels);
let mut channels = vec![Vec::with_capacity(RESAMPLER_SIZE); audio_format.channels];
for (chan, data) in (0..channel_data_tx.len())
Expand All @@ -85,7 +96,7 @@ pub async fn decode_audio(
channels
};

for (i, (data, sink)) in channels.drain(..).zip(&channel_data_tx).enumerate() {
for (data, sink) in channels.drain(..).zip(&channel_data_tx) {
//trace!("Emitting {} samples for channel {}", data.len(), i);
sent_samples += data.len();
sink.send(data.into()).await?;
Expand Down Expand Up @@ -121,7 +132,7 @@ pub async fn decode_audio(
} else {
channels
};
for (i, (mut data, sink)) in channels.drain(..).zip(&channel_data_tx).enumerate() {
for (mut data, sink) in channels.drain(..).zip(&channel_data_tx) {
if let Some(new_len) = new_len {
trace!(
"Downsizing to avoid trailing silence to {} bytes from {}",
Expand Down Expand Up @@ -181,11 +192,10 @@ impl Sample for f32 {
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
use bytes::{Buf, BytesMut};
use dasp::{signal, Signal};
use futures::stream::{FuturesOrdered, StreamExt};
use std::{fs, path::Path};
use std::fs;
use tracing_test::traced_test;

fn write_wav(path: &str, samples: &Vec<Vec<f32>>) {
Expand All @@ -207,6 +217,29 @@ mod tests {
writer.finalize().unwrap();
}

/// Equation taken from: https://paulbourke.net/miscellaneous/correlate/
///
/// We do this without a delay because we're not looking for when a signal overlaps it should
/// be overlapping at d=0.
///
/// A score of 1 means the signals are the same. A signal of -1 would mean they're exact
/// opposites. So if we set of a tolerance of 0.95 I guess we can roughly refer to that as the
/// signals being 97.5% equivalent.
fn xcorr(a: &[f32], b: &[f32]) -> f32 {
let mean_a = a.iter().sum::<f32>() / a.len() as f32;
let mean_b = b.iter().sum::<f32>() / b.len() as f32;

let num = a
.iter()
.zip(b.iter())
.map(|(a, b)| (a - mean_a) * (b - mean_b))
.sum::<f32>();
let dn_a = a.iter().map(|x| (x - mean_a).powi(2)).sum::<f32>().sqrt();
let dn_b = b.iter().map(|x| (x - mean_b).powi(2)).sum::<f32>().sqrt();

num / (dn_a * dn_b)
}

/// Given an input audio format, the bytes for this audio, a chunk size to stream the bytes
/// into the encoder and an expected output run the audio through the decoding pipeline and
/// compare it.
Expand Down Expand Up @@ -261,14 +294,12 @@ mod tests {
for (channel_index, (expected_channel, actual_channel)) in
expected.iter().zip(resampled.iter()).enumerate()
{
assert_eq!(expected_channel.len(), actual_channel.len());
for (sample_index, (expected, actual)) in expected_channel
.iter()
.zip(actual_channel.iter())
.enumerate()
{
assert_abs_diff_eq!(expected, actual, epsilon = 0.1); // This would be a 5% error
}
let similarity = xcorr(&expected_channel, &actual_channel);
println!(
"Channel {} cross correlation is {}",
channel_index, similarity
);
assert!(similarity > 0.95);
}

let _ = fs::remove_file(&expected_name);
Expand Down Expand Up @@ -402,4 +433,48 @@ mod tests {

test_audio(format, input, 300, vec![expected_output], "downsample_s16").await;
}

#[tokio::test]
#[traced_test]
async fn multichannel_audio() {
let format = AudioFormat {
sample_rate: 16000,
channels: 2,
bit_depth: 32,
is_float: true,
};

let channel_1 = signal::rate(16000.0)
.const_hz(800.0)
.noise_simplex()
.take(32000)
.map(|x| x as f32)
.collect::<Vec<f32>>();

let channel_2 = signal::rate(16000.0)
.const_hz(900.0)
.noise_simplex()
.take(32000)
.map(|x| x as f32)
.collect::<Vec<f32>>();

let input = channel_1
.iter()
.zip(channel_2.iter())
.flat_map(|(c_1, c_2)| {
let c1 = c_1.to_le_bytes();
let c2 = c_2.to_le_bytes();
[c1[0], c1[1], c1[2], c1[3], c2[0], c2[1], c2[2], c2[3]]
})
.collect::<BytesMut>();

test_audio(
format,
input,
600,
vec![channel_1, channel_2],
"multichannel",
)
.await;
}
}

0 comments on commit 69e5d80

Please sign in to comment.