Skip to content

Commit

Permalink
feat(android): added on-device model download
Browse files Browse the repository at this point in the history
  • Loading branch information
mfkrause committed Nov 22, 2024
1 parent 18b2b99 commit 91b89a5
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 27 deletions.
21 changes: 20 additions & 1 deletion android/src/main/java/com/voicekit/VoiceKitModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ class VoiceKitModule(reactContext: ReactApplicationContext) :
try {
voiceKitService.getSupportedLocales(reactApplicationContext) { locales ->
val writableArray = Arguments.createArray()
locales.forEach { writableArray.pushString(it) }
locales["installed"]?.forEach { writableArray.pushString(it) }
locales["supported"]?.forEach { writableArray.pushString(it) }
promise.resolve(writableArray)
}
} catch (e: Exception) {
Expand All @@ -77,6 +78,24 @@ class VoiceKitModule(reactContext: ReactApplicationContext) :
}
}

/**
 * Resolves [promise] with whether the on-device model for [locale] is installed.
 *
 * The locale is looked up in the "installed" list reported by the service's
 * `getSupportedLocales`; if that list is absent the promise resolves with `false`.
 *
 * @param locale Locale identifier to look up.
 * @param promise React Native promise settled with a boolean, or rejected on a synchronous failure.
 */
@ReactMethod
fun isOnDeviceModelInstalled(locale: String, promise: Promise) {
  try {
    voiceKitService.getSupportedLocales(reactApplicationContext) { locales ->
      promise.resolve(locales["installed"]?.contains(locale) ?: false)
    }
  } catch (e: Exception) {
    // Same guard as getSupportedLocales(): a synchronous failure must still settle the promise.
    // NOTE(review): reject(Throwable) reports a generic error code; switch to the module's own
    // error-code mapping if one exists.
    promise.reject(e)
  }
}

/**
 * Starts downloading the on-device model for [locale] and settles [promise] with the outcome.
 *
 * On success the promise resolves with a map containing `status` (string) and
 * `progressAvailable` (boolean), copied from the service callback's result.
 *
 * @param locale Locale identifier whose model should be downloaded.
 * @param promise React Native promise; rejected when the service refuses the request.
 */
@ReactMethod
fun downloadOnDeviceModel(locale: String, promise: Promise) {
  try {
    voiceKitService.downloadOnDeviceModel(locale) { result ->
      // This callback runs after the try block has exited (the service posts it to the main looper),
      // so use safe casts: a failing hard cast here would crash instead of rejecting.
      val response = Arguments.createMap().apply {
        putString("status", result["status"] as? String)
        putBoolean("progressAvailable", (result["progressAvailable"] as? Boolean) ?: false)
      }
      promise.resolve(response)
    }
  } catch (e: Exception) {
    // The service throws VoiceError.InvalidState synchronously (download already running, or
    // Android below 13); surface that to JS as a rejection instead of leaving the promise pending.
    // NOTE(review): assumes VoiceError extends Exception, and reject(Throwable) reports a generic
    // error code — map to the module's error codes if available.
    promise.reject(e)
  }
}

companion object {
const val NAME = "VoiceKit"
private const val TAG = "VoiceKitModule"
Expand Down
102 changes: 88 additions & 14 deletions android/src/main/java/com/voicekit/VoiceKitService.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import android.speech.RecognizerIntent
import android.speech.SpeechRecognizer
import android.speech.RecognitionSupport
import android.speech.RecognitionSupportCallback
import android.speech.ModelDownloadListener
import androidx.core.app.ActivityCompat
import androidx.core.content.ContextCompat
import com.facebook.react.bridge.*
Expand Down Expand Up @@ -43,6 +44,8 @@ class VoiceKitService(private val context: ReactApplicationContext) {
private var lastResultTimer: Handler? = null
private var lastTranscription: String? = null

private var isDownloadingModel: Boolean = false

fun sendEvent(eventName: String, params: Any?) {
context
.getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter::class.java)
Expand Down Expand Up @@ -278,34 +281,32 @@ class VoiceKitService(private val context: ReactApplicationContext) {
}
}

fun getSupportedLocales(context: Context, callback: (List<String>) -> Unit) {
fun getSupportedLocales(context: Context, callback: (Map<String, List<String>>) -> Unit) {
Log.d(TAG, "Getting supported locales")

if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
// On Android 13+, we can get the list from the on-device recognizer
// TODO: The on-device supported locales are not necessarily the ones we can use for the standard recognizer
// We need to improve the usage of the default recognizer & on-device recognizer for both Android 13+ and <13
// On Android 13+, we can get a list of locales supported by the on-device recognizer

// On-device speech Recognizer can only be ran on main thread
// The on-device speech recognizer can only be run on the main thread
Handler(context.mainLooper).post {
val tempSpeechRecognizer = SpeechRecognizer.createOnDeviceSpeechRecognizer(context)
val intent = Intent(RecognizerIntent.ACTION_RECOGNIZE_SPEECH)

tempSpeechRecognizer?.checkRecognitionSupport(
tempSpeechRecognizer.checkRecognitionSupport(
intent,
Executors.newSingleThreadExecutor(),
@RequiresApi(Build.VERSION_CODES.TIRAMISU)
object : RecognitionSupportCallback {
override fun onSupportResult(recognitionSupport: RecognitionSupport) {
Log.d(TAG, "getSupportedLocales() onSupportResult called with recognitionSupport $recognitionSupport")

// TODO: We need a method to download supported but not installed locales, then we can send mergedLocales
val installedLocales = recognitionSupport.installedOnDeviceLanguages
val supportedLocales = recognitionSupport.supportedOnDeviceLanguages // not necessarily installed for on-device recognition

val mergedLocales = (installedLocales + supportedLocales).distinct().sorted()
val installedLocales = recognitionSupport.installedOnDeviceLanguages.map { it.toString() }
val supportedLocales = recognitionSupport.supportedOnDeviceLanguages.map { it.toString() }

callback(installedLocales?.map { it.toString() } ?: emptyList())
callback(mapOf(
"installed" to installedLocales,
"supported" to supportedLocales
))

tempSpeechRecognizer.destroy()
}
Expand All @@ -318,8 +319,81 @@ class VoiceKitService(private val context: ReactApplicationContext) {
)
}
} else {
// TODO: Implement fallback for Android <13
callback(emptyList())
callback(mapOf(
"installed" to emptyList(),
"supported" to emptyList()
))
}
}

/**
 * Triggers a download of the on-device speech recognition model for [locale].
 *
 * - Android 13: the download is triggered without progress tracking and [callback] receives
 *   `status = "started"`, `progressAvailable = false`.
 * - Android 14+: the download is tracked; [callback] receives `status = "started"`,
 *   `progressAvailable = true`, and progress values are emitted through the
 *   `RNVoiceKit.model-download-progress` event.
 *
 * [callback] is invoked on the main thread, right after the download has been triggered.
 *
 * @param locale Locale identifier passed to the recognizer as `RecognizerIntent.EXTRA_LANGUAGE`.
 * @param callback Receives a map with the keys `status` and `progressAvailable`.
 * @throws VoiceError.InvalidState if a tracked download is already in progress, or if the device
 *   runs an Android version below 13.
 */
fun downloadOnDeviceModel(locale: String, callback: (Map<String, Any>) -> Unit) {
  if (isDownloadingModel) {
    // A model download is already in progress
    throw VoiceError.InvalidState
  }

  if (Build.VERSION.SDK_INT < Build.VERSION_CODES.TIRAMISU) {
    // Android version must be 13 or higher to download speech models
    throw VoiceError.InvalidState
  }

  val intent = Intent(RecognizerIntent.ACTION_RECOGNIZE_SPEECH).apply {
    putExtra(RecognizerIntent.EXTRA_LANGUAGE, locale)
  }

  // The on-device speech recognizer may only be used from the main thread
  val mainHandler = Handler(context.mainLooper)

  // Android 13 does not support progress tracking, simply download the model
  if (Build.VERSION.SDK_INT < Build.VERSION_CODES.UPSIDE_DOWN_CAKE) {
    mainHandler.post {
      val recognizer = SpeechRecognizer.createOnDeviceSpeechRecognizer(context)
      recognizer.triggerModelDownload(intent)
      recognizer.destroy()
      callback(mapOf(
        "status" to "started",
        "progressAvailable" to false
      ))
    }
    return
  }

  // Android 14+ supports progress tracking, track download progress
  isDownloadingModel = true
  mainHandler.post {
    val recognizer = SpeechRecognizer.createOnDeviceSpeechRecognizer(context)
    val listenerExecutor = Executors.newSingleThreadExecutor()

    // Shared cleanup for every terminal listener callback.
    fun release() {
      isDownloadingModel = false
      // Listener callbacks arrive on listenerExecutor; hop back to the main thread to destroy.
      mainHandler.post { recognizer.destroy() }
      // Without this, every download request would leave an idle executor thread behind.
      listenerExecutor.shutdown()
    }

    recognizer.triggerModelDownload(
      intent,
      listenerExecutor,
      @RequiresApi(Build.VERSION_CODES.UPSIDE_DOWN_CAKE)
      object : ModelDownloadListener {
        override fun onProgress(progress: Int) {
          sendEvent("RNVoiceKit.model-download-progress", progress)
        }

        override fun onSuccess() {
          release()
        }

        override fun onScheduled() {
          // The callback below always reports "started", so "scheduled" is not reported separately.
          release()
        }

        override fun onError(error: Int) {
          release()
          // Throwing here would only raise an uncaught exception on the executor thread and never
          // reach the caller, so the failure is logged instead.
          // TODO: surface this failure to JS (e.g. through an error event).
          Log.e(TAG, "Model download failed with error code: $error")
        }
      }
    )
    callback(mapOf(
      "status" to "started",
      "progressAvailable" to true
    ))
  }
}

Expand Down
47 changes: 41 additions & 6 deletions example/src/App.tsx
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import { StyleSheet, View, Text, TouchableOpacity, TextInput } from 'react-native';
import { VoiceError, VoiceKit, VoiceMode, useVoice } from 'react-native-voicekit';
import { StyleSheet, View, Text, TouchableOpacity, TextInput, Platform } from 'react-native';
import { VoiceError, VoiceEvent, VoiceKit, VoiceMode, useVoice } from 'react-native-voicekit';
import Dropdown from './components/Dropdown';
import { useEffect, useState } from 'react';
import { useCallback, useEffect, useState } from 'react';

export default function App() {
const [locale, setLocale] = useState('en-US');
const [isLocaleInstalled, setIsLocaleInstalled] = useState(Platform.OS !== 'android');
const [supportedLocales, setSupportedLocales] = useState<string[]>([]);

const { available, listening, transcript, startListening, stopListening, resetTranscript } = useVoice({
Expand All @@ -22,27 +23,61 @@ export default function App() {
});
}, [locale]);

// Re-check whether the on-device model is installed every time the selected locale changes.
useEffect(() => {
  let isCurrent = true;

  VoiceKit.isOnDeviceModelInstalled(locale)
    .then((isInstalled) => {
      // Drop answers that belong to a previously selected locale (or arrive after unmount).
      if (isCurrent) {
        setIsLocaleInstalled(isInstalled);
      }
    })
    .catch((error) => {
      // Without a handler a failed check would surface as an unhandled promise rejection.
      console.error('Error checking model installation', error, error instanceof VoiceError ? error.details : null);
    });

  return () => {
    isCurrent = false;
  };
}, [locale]);

// Progress listener for model downloads: logs each update and, once the reported value reaches
// 100, marks the locale as installed and unsubscribes itself.
const onModelDownloadProgress = useCallback((progress: number) => {
  console.log('Model download progress:', progress);

  const isComplete = progress >= 100;
  if (!isComplete) {
    return;
  }

  setIsLocaleInstalled(true);
  VoiceKit.removeListener(VoiceEvent.ModelDownloadProgress, onModelDownloadProgress);
}, []);

return (
<View style={styles.container}>
<Text>Is available: {available ? 'Yes' : 'No'}</Text>
<Text style={{ marginBottom: 30 }}>Is listening: {listening ? 'Yes' : 'No'}</Text>
<Dropdown
label="Locale"
label={`Locale${Platform.OS === 'android' ? ` (is installed: ${isLocaleInstalled ? 'yes' : 'no'})` : ''}`}
data={supportedLocales.map((l) => ({ label: l, value: l }))}
maxHeight={300}
value={locale}
onChange={(item) => setLocale(item.value)}
containerStyle={styles.dropdown}
style={styles.dropdown}
/>
{Platform.OS === 'android' && (
<TouchableOpacity
onPress={() => {
VoiceKit.downloadOnDeviceModel(locale)
.then((result) => {
if (result.progressAvailable) {
VoiceKit.addListener(VoiceEvent.ModelDownloadProgress, onModelDownloadProgress);
} else {
console.log('Model download status:', result.status);
}
})
.catch((error) => {
console.error('Error downloading model', error, error instanceof VoiceError ? error.details : null);
});
}}
disabled={isLocaleInstalled}
style={[styles.button, isLocaleInstalled && styles.disabledButton]}>
<Text style={styles.buttonText}>Download "{locale}" Model</Text>
</TouchableOpacity>
)}
<TouchableOpacity
onPress={async () => {
await startListening().catch((error) => {
console.error('Error starting listening', error, error instanceof VoiceError ? error.details : null);
});
}}
disabled={!available || listening}
style={[styles.button, (!available || listening) && styles.disabledButton]}>
disabled={!available || !isLocaleInstalled || listening}
style={[styles.button, (!available || !isLocaleInstalled || listening) && styles.disabledButton]}>
<Text style={styles.buttonText}>Start Listening</Text>
</TouchableOpacity>
<TouchableOpacity
Expand Down
47 changes: 44 additions & 3 deletions src/RNVoiceKit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { NativeEventEmitter, NativeModules, Platform } from 'react-native';
import RNVoiceError from './utils/VoiceError';
import { VoiceErrorCode } from './types/native';
import { VoiceEvent, VoiceMode } from './types';
import type { VoiceEventMap, VoiceStartListeningOptions } from './types';
import { VoiceModelDownloadStatus, type VoiceEventMap, type VoiceStartListeningOptions } from './types';

const LINKING_ERROR =
`The package 'react-native-voicekit' doesn't seem to be linked. Make sure: \n\n` +
Expand Down Expand Up @@ -94,8 +94,9 @@ class RNVoiceKit {
}

/**
* Gets the list of supported locales for speech recognition. On Android, this gets the list of
* supported locales for the on-device speech recognizer.
* Gets the list of supported locales for speech recognition. On Android, this gets the list of supported locales for
* the on-device speech recognizer. Note that this does not check if the model is installed already. Use
* `isOnDeviceModelInstalled()` to check if the model for a given locale is installed before using it.
* Does not work on Android versions below 13 and will return an empty array for those versions.
*
* @returns The list of supported locales.
Expand All @@ -104,6 +105,46 @@ class RNVoiceKit {
return await nativeInstance.getSupportedLocales();
}

/**
 * Reports whether the on-device speech recognizer model for a locale is already installed.
 * If it is not, `downloadOnDeviceModel()` can be used to fetch it. Only works on Android 13+.
 * On iOS there is nothing to install, so this simply reports whether the locale is among the
 * supported locales.
 *
 * @param locale - The locale to check.
 * @returns Whether the model is installed.
 */
async isOnDeviceModelInstalled(locale: string): Promise<boolean> {
  if (Platform.OS !== 'ios') {
    return await nativeInstance.isOnDeviceModelInstalled(locale);
  }

  const supportedLocales = await this.getSupportedLocales();
  return supportedLocales.includes(locale);
}

/**
 * Starts downloading the on-device speech recognizer model for a locale. Only works on Android 13+.
 * Once the download has been started successfully, the promise resolves with a `started` status.
 * On Android 14+, you can listen to the `VoiceEvent.ModelDownloadProgress` event to follow the
 * download progress.
 * On iOS nothing is downloaded: a `started` status is returned when the locale is supported, and an
 * error is thrown when it is not.
 *
 * @param locale - The locale whose model should be downloaded.
 * @returns The status of the model download and whether download progress is available via the
 * `VoiceEvent.ModelDownloadProgress` event.
 */
async downloadOnDeviceModel(
  locale: string
): Promise<{ status: VoiceModelDownloadStatus; progressAvailable: boolean }> {
  if (Platform.OS !== 'ios') {
    return await nativeInstance.downloadOnDeviceModel(locale);
  }

  const supportedLocales = await this.getSupportedLocales();
  if (!supportedLocales.includes(locale)) {
    throw new RNVoiceError('Locale is not supported', VoiceErrorCode.INVALID_STATE); // TODO: better code
  }

  return { status: VoiceModelDownloadStatus.Started, progressAvailable: false };
}

addListener<T extends VoiceEvent>(event: T, listener: (...args: VoiceEventMap[T]) => void) {
if (!this.listeners[event]) {
this.listeners[event] = [];
Expand Down
10 changes: 8 additions & 2 deletions src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ export enum VoiceEvent {
PartialResult = 'partial-result',
AvailabilityChange = 'availability-change',
ListeningStateChange = 'listening-state-change',
ModelDownloadProgress = 'model-download-progress',
Error = 'error',
}

Expand All @@ -11,6 +12,7 @@ export interface VoiceEventMap extends Record<VoiceEvent, any[]> {
[VoiceEvent.PartialResult]: [string];
[VoiceEvent.AvailabilityChange]: [boolean];
[VoiceEvent.ListeningStateChange]: [boolean];
[VoiceEvent.ModelDownloadProgress]: [number];
[VoiceEvent.Error]: any[];
}

Expand All @@ -20,6 +22,11 @@ export enum VoiceMode {
ContinuousAndStop = 'continuous-and-stop',
}

/**
 * Status values that `downloadOnDeviceModel()` can resolve with.
 */
export enum VoiceModelDownloadStatus {
// The model download was requested.
Started = 'started',
// NOTE(review): presumably for downloads that are scheduled rather than started immediately —
// confirm that the native side actually reports this value.
Scheduled = 'scheduled',
}

export interface VoiceStartListeningOptions {
/**
* The locale to use for speech recognition. Defaults to `en-US`.
Expand Down Expand Up @@ -57,8 +64,7 @@ export interface VoiceStartListeningOptions {
* Whether to force usage of the on-device speech recognizer. Does not have any effect on iOS. Only works on Android
* 13 and above. Defaults to `false`.
* Note: When using the on-device recognizer, some locales returned by `getSupportedLocales()` may not be installed
* on the device yet and need to be installed first.
* TODO: Add a method to install locales for the Android on-device recognizer
* on the device yet and need to be installed using `downloadOnDeviceModel()` first.
*/
useOnDeviceRecognizer?: boolean;
}
4 changes: 3 additions & 1 deletion src/types/native.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { VoiceStartListeningOptions } from '.';
import type { VoiceModelDownloadStatus, VoiceStartListeningOptions } from '.';

export enum VoiceErrorCode {
SPEECH_RECOGNIZER_NOT_AVAILABLE = 'ERR_SPEECH_RECOGNIZER_NOT_AVAILABLE',
Expand All @@ -15,5 +15,7 @@ export default interface NativeRNVoiceKit {
startListening: (options: Required<VoiceStartListeningOptions>) => Promise<void>;
stopListening: () => Promise<void>;
isSpeechRecognitionAvailable: () => Promise<boolean>;
isOnDeviceModelInstalled: (locale: string) => Promise<boolean>;
getSupportedLocales: () => Promise<string[]>;
downloadOnDeviceModel: (locale: string) => Promise<{ status: VoiceModelDownloadStatus; progressAvailable: boolean }>;
}

0 comments on commit 91b89a5

Please sign in to comment.