diff --git a/pkg/config/config.yaml b/pkg/config/config.yaml index 7de1b43d..275eea59 100644 --- a/pkg/config/config.yaml +++ b/pkg/config/config.yaml @@ -341,6 +341,9 @@ encoder: frames: - 10 - 5 + # linear (1) or nearest neighbour (0) audio resampler + # linear should sound slightly better + resampler: 1 video: # h264, vpx (vp8) or vp9 codec: h264 diff --git a/pkg/config/worker.go b/pkg/config/worker.go index 5a509b0c..014ce644 100644 --- a/pkg/config/worker.go +++ b/pkg/config/worker.go @@ -51,7 +51,8 @@ type Encoder struct { } type Audio struct { - Frames []float32 + Frames []float32 + Resampler int } type Video struct { diff --git a/pkg/worker/media/buffer.go b/pkg/worker/media/buffer.go index e80a7c82..836e1c53 100644 --- a/pkg/worker/media/buffer.go +++ b/pkg/worker/media/buffer.go @@ -1,19 +1,28 @@ package media -import ( - "errors" - "math" - "unsafe" +import "errors" + +type ResampleAlgo uint8 + +const ( + ResampleNearest ResampleAlgo = iota + ResampleLinear ) +// preallocated scratch buffer for resampling output +// size for max Opus frame: 60ms at 48kHz stereo = 48000 * 0.06 * 2 = 5760 samples +var stretchBuf = make(samples, 5760) + // buffer is a simple non-concurrent safe buffer for audio samples. type buffer struct { - stretch bool - frameHz []int + useResample bool + algo ResampleAlgo + srcHz int + + raw samples - raw samples buckets []bucket - cur *bucket + bi int } type bucket struct { @@ -25,100 +34,180 @@ type bucket struct { func newBuffer(frames []float32, hz int) (*buffer, error) { if hz < 2000 { - return nil, errors.New("hz should be > than 2000") + return nil, errors.New("hz should be > 2000") + } + if len(frames) == 0 { + return nil, errors.New("frames list is empty") } - buf := buffer{} + buf := buffer{srcHz: hz} - // preallocate continuous array - s := 0 + totalSize := 0 for _, f := range frames { - s += frame(hz, f) - } - buf.raw = make(samples, s) - - if len(buf.raw) == 0 { - return nil, errors.New("seems those params are bad and the buffer is 0") + totalSize += frameStereoSamples(hz, f) } - next := 0 + if totalSize == 0 { + return nil, errors.New("calculated buffer size is 0, check params") + } + + buf.raw = make(samples, totalSize) + + // map buckets to the raw continuous array + offset := 0 for _, f := range frames { - s := frame(hz, f) + size := frameStereoSamples(hz, f) buf.buckets = append(buf.buckets, bucket{ - mem: buf.raw[next : next+s], + mem: buf.raw[offset : offset+size], ms: f, }) - next += s + offset += size } - buf.cur = &buf.buckets[len(buf.buckets)-1] + + // start with the largest bucket (last one, assuming frames are sorted ascending) + buf.bi = len(buf.buckets) - 1 + return &buf, nil } -func (b *buffer) choose(l int) { - for _, bb := range b.buckets { - if l >= len(bb.mem) { - b.cur = &bb - break +// cur returns the current bucket pointer +func (b *buffer) cur() *bucket { return &b.buckets[b.bi] } + +// choose selects the best bucket for the remaining samples. +// It picks the largest bucket that can be completely filled. +// Buckets should be sorted by size ascending for this to work optimally. +func (b *buffer) choose(remaining int) { + // search from largest to smallest + for i := len(b.buckets) - 1; i >= 0; i-- { + if remaining >= len(b.buckets[i].mem) { + b.bi = i + return } } + // fall back to smallest bucket if remaining < all bucket sizes + b.bi = 0 } -func (b *buffer) resample(hz int) { - b.stretch = true +// resample enables resampling to target Hz with specified algorithm +func (b *buffer) resample(targetHz int, algo ResampleAlgo) { + b.useResample = true + b.algo = algo for i := range b.buckets { - b.buckets[i].dst = frame(hz, b.buckets[i].ms) + b.buckets[i].dst = frameStereoSamples(targetHz, b.buckets[i].ms) } } -// write fills the buffer until it's full and then passes the gathered data into a callback. -// -// There are two cases to consider: -// 1. Underflow, when the length of the written data is less than the buffer's available space. -// 2. Overflow, when the length exceeds the current available buffer space. -// -// We overwrite any previous values in the buffer and move the internal write pointer -// by the length of the written data. -// In the first case, we won't call the callback, but it will be called every time -// when the internal buffer overflows until all samples are read. -// It will choose between multiple internal buffers to fit remaining samples. -func (b *buffer) write(s samples, onFull func(samples, float32)) (r int) { - for r < len(s) { - buf := b.cur - w := copy(buf.mem[buf.p:], s[r:]) - r += w - buf.p += w - if buf.p == len(buf.mem) { - if b.stretch { - onFull(buf.mem.stretch(buf.dst), buf.ms) +// stretch applies the selected resampling algorithm +func (b *buffer) stretch(src samples, dstSize int) samples { + switch b.algo { + case ResampleNearest: + return stretchNearest(src, dstSize) + case ResampleLinear: + return stretchLinear(src, dstSize) + default: + return stretchLinear(src, dstSize) + } +} + +// write fills the buffer and calls onFull when a complete frame is ready. +// returns the number of samples consumed. +func (b *buffer) write(s samples, onFull func(samples, float32)) int { + read := 0 + for read < len(s) { + cur := b.cur() + + // copy all samples into current bucket + n := copy(cur.mem[cur.p:], s[read:]) + read += n + cur.p += n + + // bucket is full - emit frame + if cur.p == len(cur.mem) { + if b.useResample { + onFull(b.stretch(cur.mem, cur.dst), cur.ms) } else { - onFull(buf.mem, buf.ms) + onFull(cur.mem, cur.ms) } - b.choose(len(s) - r) - b.cur.p = 0 + + // select next bucket and reset write position + b.choose(len(s) - read) + b.cur().p = 0 } } - return + return read } -// frame calculates an audio stereo frame size, i.e. 48k*frame/1000*2 -// with round(x / 2) * 2 for the closest even number -func frame(hz int, frame float32) int { - return int(math.Round(float64(hz)*float64(frame)/1000/2) * 2 * 2) +// frameStereoSamples calculates stereo frame size in samples. +// e.g., 48000 Hz * 20ms = 960 samples/channel * 2 channels = 1920 total samples +func frameStereoSamples(hz int, ms float32) int { + samplesPerChannel := int(float32(hz)*ms/1000 + 0.5) // round to nearest + return samplesPerChannel * 2 // stereo } -// stretch does a simple stretching of audio samples. -// something like: [1,2,3,4,5,6] -> [1,2,x,x,3,4,x,x,5,6,x,x] -> [1,2,1,2,3,4,3,4,5,6,5,6] -func (s samples) stretch(size int) []int16 { - out := buf[:size] - n := len(s) - ratio := float32(size) / float32(n) - sPtr := unsafe.Pointer(&s[0]) - for i, l, r := 0, 0, 0; i < n; i += 2 { - l, r = r, int(float32((i+2)>>1)*ratio)<<1 // index in src * ratio -> approximated index in dst *2 due to int16 - for j := l; j < r; j += 2 { - *(*int32)(unsafe.Pointer(&out[j])) = *(*int32)(sPtr) // out[j] = s[i]; out[j+1] = s[i+1] - } - sPtr = unsafe.Add(sPtr, uintptr(4)) +// stretchLinear resamples stereo audio using linear interpolation. +func stretchLinear(src samples, dstSize int) samples { + srcLen := len(src) + if srcLen < 2 || dstSize < 2 { + return stretchBuf[:dstSize] } + + out := stretchBuf[:dstSize] + + srcPairs := srcLen / 2 + dstPairs := dstSize / 2 + + // Fixed-point ratio for precision (16.16 fixed point) + ratio := ((srcPairs - 1) << 16) / (dstPairs - 1) + + for i := 0; i < dstPairs; i++ { + // Calculate source position in fixed-point + pos := i * ratio + srcIdx := pos >> 16 + frac := pos & 0xFFFF + + dstIdx := i * 2 + + if srcIdx >= srcPairs-1 { + // Last sample - no interpolation + out[dstIdx] = src[srcLen-2] + out[dstIdx+1] = src[srcLen-1] + } else { + // Linear interpolation for both channels + srcBase := srcIdx * 2 + + // Left channel + l0 := int32(src[srcBase]) + l1 := int32(src[srcBase+2]) + out[dstIdx] = int16(l0 + ((l1-l0)*int32(frac))>>16) + + // Right channel + r0 := int32(src[srcBase+1]) + r1 := int32(src[srcBase+3]) + out[dstIdx+1] = int16(r0 + ((r1-r0)*int32(frac))>>16) + } + } + + return out +} + +// stretchNearest is a faster nearest-neighbor version if quality isn't critical +func stretchNearest(src samples, dstSize int) samples { + srcLen := len(src) + if srcLen < 2 || dstSize < 2 { + return stretchBuf[:dstSize] + } + + out := stretchBuf[:dstSize] + + srcPairs := srcLen / 2 + dstPairs := dstSize / 2 + + for i := 0; i < dstPairs; i++ { + srcIdx := (i * srcPairs / dstPairs) * 2 + dstIdx := i * 2 + out[dstIdx] = src[srcIdx] + out[dstIdx+1] = src[srcIdx+1] + } + return out } diff --git a/pkg/worker/media/buffer_test.go b/pkg/worker/media/buffer_test.go index 29f2fc6a..28a596ba 100644 --- a/pkg/worker/media/buffer_test.go +++ b/pkg/worker/media/buffer_test.go @@ -23,7 +23,11 @@ func TestBufferWrite(t *testing.T) { {sample: 2, len: 20}, {sample: 3, len: 30}, }, - expect: samples{3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3}, + expect: samples{ + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + }, }, { bufLen: 2000, @@ -32,7 +36,7 @@ func TestBufferWrite(t *testing.T) { {sample: 2, len: 18}, {sample: 3, len: 2}, }, - expect: samples{2, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, + expect: samples{1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, }, } @@ -48,7 +52,7 @@ func TestBufferWrite(t *testing.T) { ) } if !reflect.DeepEqual(test.expect, lastResult) { - t.Errorf("not expted buffer, %v != %v, %v", lastResult, test.expect, len(buf.cur.mem)) + t.Errorf("not expted buffer, %v != %v, %v", lastResult, test.expect, len(buf.buckets)) } } } diff --git a/pkg/worker/media/media.go b/pkg/worker/media/media.go index b08ec692..0d1407d6 100644 --- a/pkg/worker/media/media.go +++ b/pkg/worker/media/media.go @@ -12,17 +12,13 @@ import ( "github.com/giongto35/cloud-game/v3/pkg/worker/caged/app" ) -const ( - audioHz = 48000 - sampleBufLen = 1024 * 4 -) +const audioHz = 48000 type samples []int16 var ( encoderOnce = sync.Once{} opusCoder *opus.Encoder - buf = make([]int16, sampleBufLen) ) func DefaultOpus() (*opus.Encoder, error) { @@ -116,7 +112,7 @@ func (wmp *WebrtcMediaPipe) initAudio(srcHz int, frameSizes []float32) error { wmp.log.Debug().Msgf("Opus frames (ms): %v", frameSizes) dstHz, _ := au.SampleRate() if srcHz != dstHz { - buf.resample(dstHz) + buf.resample(dstHz, ResampleAlgo(wmp.aConf.Resampler)) wmp.log.Debug().Msgf("Resample %vHz -> %vHz", srcHz, dstHz) } wmp.audioBuf = buf diff --git a/pkg/worker/media/media_test.go b/pkg/worker/media/media_test.go index 4b9a431b..f754e17e 100644 --- a/pkg/worker/media/media_test.go +++ b/pkg/worker/media/media_test.go @@ -126,7 +126,7 @@ func TestResampleStretch(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - rez2 := tt.args.pcm.stretch(tt.args.size) + rez2 := stretchNearest(tt.args.pcm, tt.args.size) if rez2[0] != tt.args.pcm[0] || rez2[1] != tt.args.pcm[1] || rez2[len(rez2)-1] != tt.args.pcm[len(tt.args.pcm)-1] || rez2[len(rez2)-2] != tt.args.pcm[len(tt.args.pcm)-2] { @@ -141,7 +141,7 @@ func BenchmarkResampler(b *testing.B) { pcm := samples(gen(1764)) size := 1920 for i := 0; i < b.N; i++ { - pcm.stretch(size) + stretchLinear(pcm, size) } } @@ -170,7 +170,7 @@ func TestFrame(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := frame(tt.args.hz, tt.args.frame); got != tt.want { + if got := frameStereoSamples(tt.args.hz, tt.args.frame); got != tt.want { t.Errorf("frame() = %v, want %v", got, tt.want) } })