Add audio resampling option

You can now select between linear interpolation and nearest-neighbor resampling algorithms.
This commit is contained in:
sergystepanov 2025-12-13 23:56:38 +03:00
parent 129690e901
commit 84ad0a4cac
6 changed files with 178 additions and 85 deletions

View file

@ -341,6 +341,9 @@ encoder:
frames:
- 10
- 5
# linear (1) or nearest neighbour (0) audio resampler
# linear should sound slightly better
resampler: 1
video:
# h264, vpx (vp8) or vp9
codec: h264

View file

@ -51,7 +51,8 @@ type Encoder struct {
}
type Audio struct {
Frames []float32
Frames []float32
Resampler int
}
type Video struct {

View file

@ -1,19 +1,28 @@
package media
import (
"errors"
"math"
"unsafe"
import "errors"
type ResampleAlgo uint8
const (
ResampleNearest ResampleAlgo = iota
ResampleLinear
)
// preallocated scratch buffer for resampling output
// size for max Opus frame: 60ms at 48kHz stereo = 48000 * 0.06 * 2 = 5760 samples
var stretchBuf = make(samples, 5760)
// buffer is a simple non-concurrent safe buffer for audio samples.
type buffer struct {
stretch bool
frameHz []int
useResample bool
algo ResampleAlgo
srcHz int
raw samples
raw samples
buckets []bucket
cur *bucket
bi int
}
type bucket struct {
@ -25,100 +34,180 @@ type bucket struct {
func newBuffer(frames []float32, hz int) (*buffer, error) {
if hz < 2000 {
return nil, errors.New("hz should be > than 2000")
return nil, errors.New("hz should be > 2000")
}
if len(frames) == 0 {
return nil, errors.New("frames list is empty")
}
buf := buffer{}
buf := buffer{srcHz: hz}
// preallocate continuous array
s := 0
totalSize := 0
for _, f := range frames {
s += frame(hz, f)
}
buf.raw = make(samples, s)
if len(buf.raw) == 0 {
return nil, errors.New("seems those params are bad and the buffer is 0")
totalSize += frameStereoSamples(hz, f)
}
next := 0
if totalSize == 0 {
return nil, errors.New("calculated buffer size is 0, check params")
}
buf.raw = make(samples, totalSize)
// map buckets to the raw continuous array
offset := 0
for _, f := range frames {
s := frame(hz, f)
size := frameStereoSamples(hz, f)
buf.buckets = append(buf.buckets, bucket{
mem: buf.raw[next : next+s],
mem: buf.raw[offset : offset+size],
ms: f,
})
next += s
offset += size
}
buf.cur = &buf.buckets[len(buf.buckets)-1]
// start with the largest bucket (last one, assuming frames are sorted ascending)
buf.bi = len(buf.buckets) - 1
return &buf, nil
}
func (b *buffer) choose(l int) {
for _, bb := range b.buckets {
if l >= len(bb.mem) {
b.cur = &bb
break
// cur returns the current bucket pointer
func (b *buffer) cur() *bucket { return &b.buckets[b.bi] }
// choose selects the best bucket for the remaining samples.
// It picks the largest bucket that can be completely filled.
// Buckets should be sorted by size ascending for this to work optimally.
func (b *buffer) choose(remaining int) {
// search from largest to smallest
for i := len(b.buckets) - 1; i >= 0; i-- {
if remaining >= len(b.buckets[i].mem) {
b.bi = i
return
}
}
// fall back to smallest bucket if remaining < all bucket sizes
b.bi = 0
}
func (b *buffer) resample(hz int) {
b.stretch = true
// resample enables resampling to target Hz with specified algorithm
func (b *buffer) resample(targetHz int, algo ResampleAlgo) {
b.useResample = true
b.algo = algo
for i := range b.buckets {
b.buckets[i].dst = frame(hz, b.buckets[i].ms)
b.buckets[i].dst = frameStereoSamples(targetHz, b.buckets[i].ms)
}
}
// write fills the buffer until it's full and then passes the gathered data into a callback.
//
// There are two cases to consider:
// 1. Underflow, when the length of the written data is less than the buffer's available space.
// 2. Overflow, when the length exceeds the current available buffer space.
//
// We overwrite any previous values in the buffer and move the internal write pointer
// by the length of the written data.
// In the first case, we won't call the callback, but it will be called every time
// when the internal buffer overflows until all samples are read.
// It will choose between multiple internal buffers to fit remaining samples.
func (b *buffer) write(s samples, onFull func(samples, float32)) (r int) {
for r < len(s) {
buf := b.cur
w := copy(buf.mem[buf.p:], s[r:])
r += w
buf.p += w
if buf.p == len(buf.mem) {
if b.stretch {
onFull(buf.mem.stretch(buf.dst), buf.ms)
// stretch applies the selected resampling algorithm
func (b *buffer) stretch(src samples, dstSize int) samples {
switch b.algo {
case ResampleNearest:
return stretchNearest(src, dstSize)
case ResampleLinear:
return stretchLinear(src, dstSize)
default:
return stretchLinear(src, dstSize)
}
}
// write fills the buffer and calls onFull when a complete frame is ready.
// returns the number of samples consumed.
func (b *buffer) write(s samples, onFull func(samples, float32)) int {
read := 0
for read < len(s) {
cur := b.cur()
// copy all samples into current bucket
n := copy(cur.mem[cur.p:], s[read:])
read += n
cur.p += n
// bucket is full - emit frame
if cur.p == len(cur.mem) {
if b.useResample {
onFull(b.stretch(cur.mem, cur.dst), cur.ms)
} else {
onFull(buf.mem, buf.ms)
onFull(cur.mem, cur.ms)
}
b.choose(len(s) - r)
b.cur.p = 0
// select next bucket and reset write position
b.choose(len(s) - read)
b.cur().p = 0
}
}
return
return read
}
// frame calculates an audio stereo frame size, i.e. 48k*frame/1000*2
// with round(x / 2) * 2 for the closest even number
func frame(hz int, frame float32) int {
return int(math.Round(float64(hz)*float64(frame)/1000/2) * 2 * 2)
// frameStereoSamples calculates stereo frame size in samples.
// e.g., 48000 Hz * 20ms = 960 samples/channel * 2 channels = 1920 total samples
func frameStereoSamples(hz int, ms float32) int {
samplesPerChannel := int(float32(hz)*ms/1000 + 0.5) // round to nearest
return samplesPerChannel * 2 // stereo
}
// stretch does a simple stretching of audio samples.
// something like: [1,2,3,4,5,6] -> [1,2,x,x,3,4,x,x,5,6,x,x] -> [1,2,1,2,3,4,3,4,5,6,5,6]
func (s samples) stretch(size int) []int16 {
out := buf[:size]
n := len(s)
ratio := float32(size) / float32(n)
sPtr := unsafe.Pointer(&s[0])
for i, l, r := 0, 0, 0; i < n; i += 2 {
l, r = r, int(float32((i+2)>>1)*ratio)<<1 // index in src * ratio -> approximated index in dst *2 due to int16
for j := l; j < r; j += 2 {
*(*int32)(unsafe.Pointer(&out[j])) = *(*int32)(sPtr) // out[j] = s[i]; out[j+1] = s[i+1]
}
sPtr = unsafe.Add(sPtr, uintptr(4))
// stretchLinear resamples stereo audio using linear interpolation.
func stretchLinear(src samples, dstSize int) samples {
srcLen := len(src)
if srcLen < 2 || dstSize < 2 {
return stretchBuf[:dstSize]
}
out := stretchBuf[:dstSize]
srcPairs := srcLen / 2
dstPairs := dstSize / 2
// Fixed-point ratio for precision (16.16 fixed point)
ratio := ((srcPairs - 1) << 16) / (dstPairs - 1)
for i := 0; i < dstPairs; i++ {
// Calculate source position in fixed-point
pos := i * ratio
srcIdx := pos >> 16
frac := pos & 0xFFFF
dstIdx := i * 2
if srcIdx >= srcPairs-1 {
// Last sample - no interpolation
out[dstIdx] = src[srcLen-2]
out[dstIdx+1] = src[srcLen-1]
} else {
// Linear interpolation for both channels
srcBase := srcIdx * 2
// Left channel
l0 := int32(src[srcBase])
l1 := int32(src[srcBase+2])
out[dstIdx] = int16(l0 + ((l1-l0)*int32(frac))>>16)
// Right channel
r0 := int32(src[srcBase+1])
r1 := int32(src[srcBase+3])
out[dstIdx+1] = int16(r0 + ((r1-r0)*int32(frac))>>16)
}
}
return out
}
// stretchNearest is a faster nearest-neighbor version if quality isn't critical
func stretchNearest(src samples, dstSize int) samples {
srcLen := len(src)
if srcLen < 2 || dstSize < 2 {
return stretchBuf[:dstSize]
}
out := stretchBuf[:dstSize]
srcPairs := srcLen / 2
dstPairs := dstSize / 2
for i := 0; i < dstPairs; i++ {
srcIdx := (i * srcPairs / dstPairs) * 2
dstIdx := i * 2
out[dstIdx] = src[srcIdx]
out[dstIdx+1] = src[srcIdx+1]
}
return out
}

View file

@ -23,7 +23,11 @@ func TestBufferWrite(t *testing.T) {
{sample: 2, len: 20},
{sample: 3, len: 30},
},
expect: samples{3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3},
expect: samples{
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
},
},
{
bufLen: 2000,
@ -32,7 +36,7 @@ func TestBufferWrite(t *testing.T) {
{sample: 2, len: 18},
{sample: 3, len: 2},
},
expect: samples{2, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
expect: samples{1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
},
}
@ -48,7 +52,7 @@ func TestBufferWrite(t *testing.T) {
)
}
if !reflect.DeepEqual(test.expect, lastResult) {
t.Errorf("not expted buffer, %v != %v, %v", lastResult, test.expect, len(buf.cur.mem))
t.Errorf("not expted buffer, %v != %v, %v", lastResult, test.expect, len(buf.buckets))
}
}
}

View file

@ -12,17 +12,13 @@ import (
"github.com/giongto35/cloud-game/v3/pkg/worker/caged/app"
)
const (
audioHz = 48000
sampleBufLen = 1024 * 4
)
const audioHz = 48000
type samples []int16
var (
encoderOnce = sync.Once{}
opusCoder *opus.Encoder
buf = make([]int16, sampleBufLen)
)
func DefaultOpus() (*opus.Encoder, error) {
@ -116,7 +112,7 @@ func (wmp *WebrtcMediaPipe) initAudio(srcHz int, frameSizes []float32) error {
wmp.log.Debug().Msgf("Opus frames (ms): %v", frameSizes)
dstHz, _ := au.SampleRate()
if srcHz != dstHz {
buf.resample(dstHz)
buf.resample(dstHz, ResampleAlgo(wmp.aConf.Resampler))
wmp.log.Debug().Msgf("Resample %vHz -> %vHz", srcHz, dstHz)
}
wmp.audioBuf = buf

View file

@ -126,7 +126,7 @@ func TestResampleStretch(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rez2 := tt.args.pcm.stretch(tt.args.size)
rez2 := stretchNearest(tt.args.pcm, tt.args.size)
if rez2[0] != tt.args.pcm[0] || rez2[1] != tt.args.pcm[1] ||
rez2[len(rez2)-1] != tt.args.pcm[len(tt.args.pcm)-1] ||
rez2[len(rez2)-2] != tt.args.pcm[len(tt.args.pcm)-2] {
@ -141,7 +141,7 @@ func BenchmarkResampler(b *testing.B) {
pcm := samples(gen(1764))
size := 1920
for i := 0; i < b.N; i++ {
pcm.stretch(size)
stretchLinear(pcm, size)
}
}
@ -170,7 +170,7 @@ func TestFrame(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := frame(tt.args.hz, tt.args.frame); got != tt.want {
if got := frameStereoSamples(tt.args.hz, tt.args.frame); got != tt.want {
t.Errorf("frame() = %v, want %v", got, tt.want)
}
})