Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Downsizing sample for some language in speech to text #7835

Merged
merged 11 commits into from
May 12, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,31 @@ describe('Predictions convert provider test', () => {
predictionsProvider.convert(validSpeechToTextInput)
).rejects.toMatch('region not configured for transcription');
});
test('Error languageCode not configured ', () => {
SebSchwartz marked this conversation as resolved.
Show resolved Hide resolved
AmazonAIConvertPredictionsProvider.serializeDataFromTranscribe = jest.fn(
() => {
return 'Hello how are you';
}
);

const predictionsProvider = new AmazonAIConvertPredictionsProvider();
const speechGenOptions = {
transcription: {
region: 'us-west-2',
proxy: false,
},
};
predictionsProvider.configure(speechGenOptions);
jest.spyOn(Credentials, 'get').mockImplementationOnce(() => {
return Promise.resolve(credentials);
});

return expect(
predictionsProvider.convert(validSpeechToTextInput)
).rejects.toMatch(
'languageCode not configured or provided for transcription'
);
});
test('Happy case ', () => {
AmazonAIConvertPredictionsProvider.serializeDataFromTranscribe = jest.fn(
() => {
Expand Down Expand Up @@ -247,5 +272,35 @@ describe('Predictions convert provider test', () => {
},
} as SpeechToTextOutput);
});
test('Downsized Happy case ', async () => {
AmazonAIConvertPredictionsProvider.serializeDataFromTranscribe = jest.fn(
() => {
return 'Bonjour, comment vas tu?';
}
);
const downsampleBufferSpyon = jest.spyOn(
AmazonAIConvertPredictionsProvider.prototype as any,
'downsampleBuffer'
);

const predictionsProvider = new AmazonAIConvertPredictionsProvider();
const speechGenOptions = {
transcription: {
region: 'us-west-2',
proxy: false,
defaults: {
language: 'fr-FR',
},
},
};
predictionsProvider.configure(speechGenOptions);
jest.spyOn(Credentials, 'get').mockImplementationOnce(() => {
return Promise.resolve(credentials);
});

await predictionsProvider.convert(validSpeechToTextInput);
expect(downsampleBufferSpyon).toBeCalled();
SebSchwartz marked this conversation as resolved.
Show resolved Hide resolved
downsampleBufferSpyon.mockClear();
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import { fromUtf8, toUtf8 } from '@aws-sdk/util-utf8-node';
const logger = new Logger('AmazonAIConvertPredictionsProvider');
const eventBuilder = new EventStreamMarshaller(toUtf8, fromUtf8);

const LANGUAGES_CODE_IN_8KHZ = ['fr-FR', 'en-AU', 'en-GB', 'fr-CA'];

export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictionsProvider {
private translateClient: TranslateClient;
private pollyClient: PollyClient;
Expand Down Expand Up @@ -178,10 +180,17 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
languageCode: language,
});

let sourceBytes = source.bytes;
if (LANGUAGES_CODE_IN_8KHZ.includes(languageCode)) {
sourceBytes = this.downsampleBuffer({
buffer: source.bytes,
outputSampleRate: 8000,
});
}
SebSchwartz marked this conversation as resolved.
Show resolved Hide resolved
try {
const fullText = await this.sendDataToTranscribe({
connection,
raw: source.bytes,
raw: sourceBytes,
});
return {
transcription: {
Expand All @@ -206,9 +215,7 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
const transcribeMessage = eventBuilder.unmarshall(
Buffer.from(message.data)
);
const transcribeMessageJson = JSON.parse(
toUtf8(transcribeMessage.body)
);
const transcribeMessageJson = JSON.parse(toUtf8(transcribeMessage.body));
if (transcribeMessage.headers[':message-type'].value === 'exception') {
logger.debug(
'exception',
Expand Down Expand Up @@ -327,14 +334,13 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
}

private inputSampleRate = 44100;
private outputSampleRate = 16000;

private downsampleBuffer({ buffer }) {
if (this.outputSampleRate === this.inputSampleRate) {
private downsampleBuffer({ buffer, outputSampleRate = 16000 }) {
if (outputSampleRate === this.inputSampleRate) {
return buffer;
}

const sampleRateRatio = this.inputSampleRate / this.outputSampleRate;
const sampleRateRatio = this.inputSampleRate / outputSampleRate;
const newLength = Math.round(buffer.length / sampleRateRatio);
const result = new Float32Array(newLength);
let offsetResult = 0;
Expand Down Expand Up @@ -399,7 +405,9 @@ export class AmazonAIConvertPredictionsProvider extends AbstractConvertPredictio
`wss://transcribestreaming.${region}.amazonaws.com:8443`,
'/stream-transcription-websocket?',
`media-encoding=pcm&`,
`sample-rate=16000&`,
`sample-rate=${
LANGUAGES_CODE_IN_8KHZ.includes(languageCode) ? '8000' : '16000'
}&`,
`language-code=${languageCode}`,
].join('');

Expand Down