diff --git a/android/src/main/java/com/voicekit/VoiceKitModule.kt b/android/src/main/java/com/voicekit/VoiceKitModule.kt index cb70d49..0bbab0c 100644 --- a/android/src/main/java/com/voicekit/VoiceKitModule.kt +++ b/android/src/main/java/com/voicekit/VoiceKitModule.kt @@ -68,7 +68,8 @@ class VoiceKitModule(reactContext: ReactApplicationContext) : try { voiceKitService.getSupportedLocales(reactApplicationContext) { locales -> val writableArray = Arguments.createArray() - locales.forEach { writableArray.pushString(it) } + locales["installed"]?.forEach { writableArray.pushString(it) } + locales["supported"]?.forEach { writableArray.pushString(it) } promise.resolve(writableArray) } } catch (e: Exception) { @@ -77,6 +78,24 @@ class VoiceKitModule(reactContext: ReactApplicationContext) : } } + @ReactMethod + fun isOnDeviceModelInstalled(locale: String, promise: Promise) { + voiceKitService.getSupportedLocales(reactApplicationContext) { locales -> + promise.resolve(locales["installed"]?.contains(locale) ?: false) + } + } + + @ReactMethod + fun downloadOnDeviceModel(locale: String, promise: Promise) { + voiceKitService.downloadOnDeviceModel(locale, { result -> + val response = Arguments.createMap().apply { + putString("status", result["status"] as String) + putBoolean("progressAvailable", result["progressAvailable"] as Boolean) + } + promise.resolve(response) + }) + } + companion object { const val NAME = "VoiceKit" private const val TAG = "VoiceKitModule" diff --git a/android/src/main/java/com/voicekit/VoiceKitService.kt b/android/src/main/java/com/voicekit/VoiceKitService.kt index 4b8eaa5..0321f11 100644 --- a/android/src/main/java/com/voicekit/VoiceKitService.kt +++ b/android/src/main/java/com/voicekit/VoiceKitService.kt @@ -10,6 +10,7 @@ import android.speech.RecognizerIntent import android.speech.SpeechRecognizer import android.speech.RecognitionSupport import android.speech.RecognitionSupportCallback +import android.speech.ModelDownloadListener import 
androidx.core.app.ActivityCompat import androidx.core.content.ContextCompat import com.facebook.react.bridge.* @@ -43,6 +44,8 @@ class VoiceKitService(private val context: ReactApplicationContext) { private var lastResultTimer: Handler? = null private var lastTranscription: String? = null + private var isDownloadingModel: Boolean = false + fun sendEvent(eventName: String, params: Any?) { context .getJSModule(DeviceEventManagerModule.RCTDeviceEventEmitter::class.java) @@ -278,20 +281,18 @@ class VoiceKitService(private val context: ReactApplicationContext) { } } - fun getSupportedLocales(context: Context, callback: (List) -> Unit) { + fun getSupportedLocales(context: Context, callback: (Map>) -> Unit) { Log.d(TAG, "Getting supported locales") if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) { - // On Android 13+, we can get the list from the on-device recognizer - // TODO: The on-device supported locales are not necessarily the ones we can use for the standard recognizer - // We need to improve the usage of the default recognizer & on-device recognizer for both Android 13+ and <13 + // On Android 13+, we can get a list of locales supported by the on-device recognizer - // On-device speech Recognizer can only be ran on main thread + // On-device speech recognizer can only be run on the main thread Handler(context.mainLooper).post { val tempSpeechRecognizer = SpeechRecognizer.createOnDeviceSpeechRecognizer(context) val intent = Intent(RecognizerIntent.ACTION_RECOGNIZE_SPEECH) - tempSpeechRecognizer?.checkRecognitionSupport( + tempSpeechRecognizer.checkRecognitionSupport( intent, Executors.newSingleThreadExecutor(), @RequiresApi(Build.VERSION_CODES.TIRAMISU) @@ -299,13 +300,13 @@ class VoiceKitService(private val context: ReactApplicationContext) { override fun onSupportResult(recognitionSupport: RecognitionSupport) { Log.d(TAG, "getSupportedLocales() onSupportResult called with recognitionSupport $recognitionSupport") - // TODO: We need a method to download 
supported but not installed locales, then we can send mergedLocales - val installedLocales = recognitionSupport.installedOnDeviceLanguages - val supportedLocales = recognitionSupport.supportedOnDeviceLanguages // not necessarily installed for on-device recognition - - val mergedLocales = (installedLocales + supportedLocales).distinct().sorted() + val installedLocales = recognitionSupport.installedOnDeviceLanguages.map { it.toString() } + val supportedLocales = recognitionSupport.supportedOnDeviceLanguages.map { it.toString() } - callback(installedLocales?.map { it.toString() } ?: emptyList()) + callback(mapOf( + "installed" to installedLocales, + "supported" to supportedLocales + )) tempSpeechRecognizer.destroy() } @@ -318,8 +319,81 @@ class VoiceKitService(private val context: ReactApplicationContext) { ) } } else { - // TODO: Implement fallback for Android <13 - callback(emptyList()) + callback(mapOf( + "installed" to emptyList(), + "supported" to emptyList() + )) + } + } + + fun downloadOnDeviceModel(locale: String, callback: (Map) -> Unit) { + if (isDownloadingModel) { + // throw VoiceError.InvalidState("A model download is already in progress") + throw VoiceError.InvalidState + } + + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.TIRAMISU) { + // throw VoiceError.InvalidState("Android version must be 13 or higher to download speech models") + throw VoiceError.InvalidState + } + + val intent = Intent(RecognizerIntent.ACTION_RECOGNIZE_SPEECH).apply { + putExtra(RecognizerIntent.EXTRA_LANGUAGE, locale) + } + + // Android 13 does not support progress tracking, simply download the model + if (Build.VERSION.SDK_INT < Build.VERSION_CODES.UPSIDE_DOWN_CAKE) { + Handler(context.mainLooper).post { + val recognizer = SpeechRecognizer.createOnDeviceSpeechRecognizer(context) + recognizer.triggerModelDownload(intent) + recognizer.destroy() + callback(mapOf( + "status" to "started", + "progressAvailable" to false + )) + } + return + } + + // Android 14+ supports progress 
tracking, track download progress + isDownloadingModel = true + Handler(context.mainLooper).post { + val recognizer = SpeechRecognizer.createOnDeviceSpeechRecognizer(context) + recognizer.triggerModelDownload( + intent, + Executors.newSingleThreadExecutor(), + @RequiresApi(Build.VERSION_CODES.UPSIDE_DOWN_CAKE) + object : ModelDownloadListener { + override fun onProgress(progress: Int) { + sendEvent("RNVoiceKit.model-download-progress", progress) + } + + override fun onSuccess() { + isDownloadingModel = false + recognizer.destroy() + } + + override fun onScheduled() { + isDownloadingModel = false + /*callback(mapOf( + "status" to "scheduled", + "progressAvailable" to false + ))*/ + recognizer.destroy() + } + + override fun onError(error: Int) { + isDownloadingModel = false + recognizer.destroy() + // throw VoiceError.RecognitionFailed("Model download failed with error code: $error") + throw VoiceError.RecognitionFailed // TODO: this doesn't reach the callback + } + } + ) + callback(mapOf( + "status" to "started", + "progressAvailable" to true + )) } } diff --git a/example/src/App.tsx b/example/src/App.tsx index 133df7c..a9cc1a5 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -1,10 +1,11 @@ -import { StyleSheet, View, Text, TouchableOpacity, TextInput } from 'react-native'; -import { VoiceError, VoiceKit, VoiceMode, useVoice } from 'react-native-voicekit'; +import { StyleSheet, View, Text, TouchableOpacity, TextInput, Platform } from 'react-native'; +import { VoiceError, VoiceEvent, VoiceKit, VoiceMode, useVoice } from 'react-native-voicekit'; import Dropdown from './components/Dropdown'; -import { useEffect, useState } from 'react'; +import { useCallback, useEffect, useState } from 'react'; export default function App() { const [locale, setLocale] = useState('en-US'); + const [isLocaleInstalled, setIsLocaleInstalled] = useState(Platform.OS !== 'android'); const [supportedLocales, setSupportedLocales] = useState([]); const { available, listening, 
transcript, startListening, stopListening, resetTranscript } = useVoice({ @@ -22,12 +23,26 @@ export default function App() { }); }, [locale]); + useEffect(() => { + VoiceKit.isOnDeviceModelInstalled(locale).then((isInstalled) => { + setIsLocaleInstalled(isInstalled); + }); + }, [locale]); + + const onModelDownloadProgress = useCallback((progress: number) => { + console.log('Model download progress:', progress); + if (progress >= 100) { + setIsLocaleInstalled(true); + VoiceKit.removeListener(VoiceEvent.ModelDownloadProgress, onModelDownloadProgress); + } + }, []); + return ( Is available: {available ? 'Yes' : 'No'} Is listening: {listening ? 'Yes' : 'No'} ({ label: l, value: l }))} maxHeight={300} value={locale} @@ -35,14 +50,34 @@ export default function App() { containerStyle={styles.dropdown} style={styles.dropdown} /> + {Platform.OS === 'android' && ( + { + VoiceKit.downloadOnDeviceModel(locale) + .then((result) => { + if (result.progressAvailable) { + VoiceKit.addListener(VoiceEvent.ModelDownloadProgress, onModelDownloadProgress); + } else { + console.log('Model download status:', result.status); + } + }) + .catch((error) => { + console.error('Error downloading model', error, error instanceof VoiceError ? error.details : null); + }); + }} + disabled={isLocaleInstalled} + style={[styles.button, isLocaleInstalled && styles.disabledButton]}> + Download "{locale}" Model + + )} { await startListening().catch((error) => { console.error('Error starting listening', error, error instanceof VoiceError ? 
error.details : null); }); }} - disabled={!available || listening} - style={[styles.button, (!available || listening) && styles.disabledButton]}> + disabled={!available || !isLocaleInstalled || listening} + style={[styles.button, (!available || !isLocaleInstalled || listening) && styles.disabledButton]}> Start Listening { + if (Platform.OS === 'ios') { + return (await this.getSupportedLocales()).includes(locale); + } + + return await nativeInstance.isOnDeviceModelInstalled(locale); + } + + /** + * Downloads the on-device speech recognizer model for the given locale. Only works on Android 13+. + * When the download was successfully started, the promise will resolve with a `started` status. + * On Android 14+, you can listen to the `VoiceEvent.ModelDownloadProgress` event to track the download progress. + * Does not have any effect on iOS and will simply return a `started` status if the locale is supported, or throw + * an error if it is not. + * + * @returns The status of the model download and whether download progress is available via the + * `VoiceEvent.ModelDownloadProgress` event. 
+ */ + async downloadOnDeviceModel( + locale: string + ): Promise<{ status: VoiceModelDownloadStatus; progressAvailable: boolean }> { + if (Platform.OS === 'ios') { + if ((await this.getSupportedLocales()).includes(locale)) { + return { status: VoiceModelDownloadStatus.Started, progressAvailable: false }; + } else { + throw new RNVoiceError('Locale is not supported', VoiceErrorCode.INVALID_STATE); // TODO: better code + } + } + + return await nativeInstance.downloadOnDeviceModel(locale); + } + addListener(event: T, listener: (...args: VoiceEventMap[T]) => void) { if (!this.listeners[event]) { this.listeners[event] = []; diff --git a/src/types/index.ts b/src/types/index.ts index 3ec2684..6d89711 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -3,6 +3,7 @@ export enum VoiceEvent { PartialResult = 'partial-result', AvailabilityChange = 'availability-change', ListeningStateChange = 'listening-state-change', + ModelDownloadProgress = 'model-download-progress', Error = 'error', } @@ -11,6 +12,7 @@ export interface VoiceEventMap extends Record { [VoiceEvent.PartialResult]: [string]; [VoiceEvent.AvailabilityChange]: [boolean]; [VoiceEvent.ListeningStateChange]: [boolean]; + [VoiceEvent.ModelDownloadProgress]: [number]; [VoiceEvent.Error]: any[]; } @@ -20,6 +22,11 @@ export enum VoiceMode { ContinuousAndStop = 'continuous-and-stop', } +export enum VoiceModelDownloadStatus { + Started = 'started', + Scheduled = 'scheduled', +} + export interface VoiceStartListeningOptions { /** * The locale to use for speech recognition. Defaults to `en-US`. @@ -57,8 +64,7 @@ export interface VoiceStartListeningOptions { * Whether to force usage of the on-device speech recognizer. Does not have any effect on iOS. Only works on Android * 13 and above. Defaults to `false`. * Note: When using the on-device recognizer, some locales returned by `getSupportedLocales()` may not be installed - * on the device yet and need to be installed first. 
- * TODO: Add a method to install locales for the Android on-device recognizer + * on the device yet and need to be installed using `downloadOnDeviceModel()` first. */ useOnDeviceRecognizer?: boolean; } diff --git a/src/types/native.ts b/src/types/native.ts index a665d00..44940d3 100644 --- a/src/types/native.ts +++ b/src/types/native.ts @@ -1,4 +1,4 @@ -import type { VoiceStartListeningOptions } from '.'; +import type { VoiceModelDownloadStatus, VoiceStartListeningOptions } from '.'; export enum VoiceErrorCode { SPEECH_RECOGNIZER_NOT_AVAILABLE = 'ERR_SPEECH_RECOGNIZER_NOT_AVAILABLE', @@ -15,5 +15,7 @@ export default interface NativeRNVoiceKit { startListening: (options: Required) => Promise; stopListening: () => Promise; isSpeechRecognitionAvailable: () => Promise; + isOnDeviceModelInstalled: (locale: string) => Promise; getSupportedLocales: () => Promise; + downloadOnDeviceModel: (locale: string) => Promise<{ status: VoiceModelDownloadStatus; progressAvailable: boolean }>; }