Skip to content

Commit

Permalink
Merge pull request #140 from pbdm/main
Browse files Browse the repository at this point in the history
feat(av-cliper): concurrency download for hls-loader
  • Loading branch information
hughfenghen authored Jun 20, 2024
2 parents c207c72 + f58342b commit 8f6ec55
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 9 deletions.
32 changes: 31 additions & 1 deletion packages/av-cliper/src/data-loader/__tests__/hls-loader.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { expect, test } from 'vitest';
import { expect, test, vi } from 'vitest';
import { MP4Clip } from '../../clips/mp4-clip';
import { createHLSLoader } from '../hls-loader';

Expand Down Expand Up @@ -41,3 +41,33 @@ test('hls loader default time', async () => {
expect(video?.timestamp).toBe(10e6);
video?.close();
});

test('hls loader async load m4s files', async () => {
const loader = await createHLSLoader(m3u8Url, 5);
const [{ actualStartTime, actualEndTime, stream }] = loader.load() ?? [];
expect(stream).toBeInstanceOf(ReadableStream);
expect([actualStartTime, Math.round(actualEndTime / 1e6)]).toEqual([0, 60]);

const clip = new MP4Clip(stream);
await clip.ready;
expect(Math.round(clip.meta.duration / 1e6)).toBe(
Math.round((actualEndTime - actualStartTime) / 1e6),
);

const { video } = await clip.tick(10e6);
expect(video?.timestamp).toBe(10e6);
video?.close();
});

test('hls loader async load m4s files with error stop correctly', async () => {
const fetchSpy = vi.spyOn(globalThis, 'fetch');
try {
const loader = await createHLSLoader(m3u8Url, 4);
fetchSpy.mockRejectedValueOnce(new Error('fetch error'));
loader.load() ?? [];
} catch (e: any) {
expect(e.message).toBe('fetch error');
}
expect(fetchSpy).toHaveBeenCalledTimes(6);
fetchSpy.mockRestore();
});
67 changes: 59 additions & 8 deletions packages/av-cliper/src/data-loader/hls-loader.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { Parser } from 'm3u8-parser';
import { Log } from '../log';

/**
* 创建一个 HLS 资源加载器
*/
export async function createHLSLoader(m3u8URL: string) {
export async function createHLSLoader(m3u8URL: string, concurrency = 10) {
const parser = new Parser();
parser.push(await (await fetch(m3u8URL)).text());
parser.end();
Expand All @@ -17,6 +18,60 @@ export async function createHLSLoader(m3u8URL: string) {
);
const base = new URL(m3u8URL, location.href);

const segmentBufferFetchqueue = {} as Record<string, Promise<ArrayBuffer>>;

async function downloadSegments(
segments: Parser['manifest']['segments'],
ctrl: ReadableStreamDefaultController<Uint8Array>,
) {
function createTaskQueue(concurrency: number) {
let running = 0;
const queue = [] as Array<() => Promise<ArrayBuffer>>;

async function runTask(task: () => Promise<ArrayBuffer>) {
queue.push(task);
next();
}

async function next() {
if (running < concurrency && queue.length) {
const task = queue.shift();
running++;
try {
await task?.();
next();
} catch (err) {
queue = [];
ctrl.error(err);
Log.error(err);
}
running--;
}
}

return runTask;
}

async function fetchSegmentBufferPromise(url: string) {
return (await fetch(url)).arrayBuffer();
}

const runTask = createTaskQueue(concurrency);

for (const [, item] of segments.entries()) {
const url = new URL(item.uri, base).href;
runTask(
() => (segmentBufferFetchqueue[url] = fetchSegmentBufferPromise(url)),
);
}
}

async function getSegmentBuffer(url: string) {
const segmentBuffer = await segmentBufferFetchqueue[url];
delete segmentBufferFetchqueue[url];
return segmentBuffer;
}

return {
/**
* 下载期望时间区间的分配数据,封装成流
Expand Down Expand Up @@ -77,6 +132,7 @@ export async function createHLSLoader(m3u8URL: string) {
actualEndTime,
stream: new ReadableStream<Uint8Array>({
start: async (ctrl) => {
downloadSegments(segments, ctrl);
ctrl.enqueue(
new Uint8Array(
await (
Expand All @@ -86,13 +142,8 @@ export async function createHLSLoader(m3u8URL: string) {
);
},
pull: async (ctrl) => {
ctrl.enqueue(
new Uint8Array(
await (
await fetch(new URL(segments[segIdx].uri, base).href)
).arrayBuffer(),
),
);
const url = new URL(segments[segIdx].uri, base).href;
ctrl.enqueue(new Uint8Array(await getSegmentBuffer(url)));
segIdx += 1;
if (segIdx >= segments.length) ctrl.close();
},
Expand Down

0 comments on commit 8f6ec55

Please sign in to comment.