Skip to content

Commit

Permalink
feat: add support for multiple model files (#907)
Browse files Browse the repository at this point in the history
* Add new proto message ModelFile

* Update all .pbtxt file with new model_file field

* Update tests

* Update Dart code to work with new ModelFile message

* Create symlinks for multiple models and put them to one cache dir

* Increase download idleTimeout
  • Loading branch information
anhappdev authored Aug 30, 2024
1 parent c453aa5 commit 2e6931e
Show file tree
Hide file tree
Showing 40 changed files with 1,370 additions and 668 deletions.
16 changes: 11 additions & 5 deletions flutter/cpp/proto/backend_setting.proto
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ message BenchmarkSetting {
}

// Config of a delegate.
// Next ID: 8
// Next ID: 10
message DelegateSetting {
// Priority of the delegate. Used for sorting delegate choices in frontend.
optional int32 priority = 1 [default = 0];
Expand All @@ -61,10 +61,8 @@ message DelegateSetting {
required string accelerator_name = 3;
// Human-readable name of the accelerator (hardware)
required string accelerator_desc = 4;
// URL or local path of the model file
required string model_path = 5;
// MD5 checksum to validate the model file
required string model_checksum = 6;
// The model file to be used when using this delegate
repeated ModelFile model_file = 9;
// The batch size to be used when running the model. Default to 1.
optional int32 batch_size = 7 [default = 1];
// Custom setting for this delegate.
Expand Down Expand Up @@ -110,3 +108,11 @@ message SettingList {
repeated CommonSetting setting = 1;
optional BenchmarkSetting benchmark_setting = 2;
}

// ModelFile will downloaded by the app and passed to the vendor backend
message ModelFile {
// URL or local path of the model file
required string model_path = 5;
// MD5 checksum to validate the model file
required string model_checksum = 6;
}
10 changes: 6 additions & 4 deletions flutter/integration_test/utils.dart
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ Future<void> validateSettings(WidgetTester tester) async {
for (var benchmark in benchmarkState.benchmarks) {
expect(benchmark.selectedDelegate.batchSize, greaterThanOrEqualTo(0),
reason: 'batchSize must >= 0');
expect(benchmark.selectedDelegate.modelPath.isNotEmpty, isTrue,
reason: 'modelPath cannot be empty');
expect(benchmark.selectedDelegate.modelChecksum.isNotEmpty, isTrue,
reason: 'modelChecksum cannot be empty');
for (var modelFile in benchmark.selectedDelegate.modelFile) {
expect(modelFile.modelPath.isNotEmpty, isTrue,
reason: 'modelPath cannot be empty');
expect(modelFile.modelChecksum.isNotEmpty, isTrue,
reason: 'modelChecksum cannot be empty');
}
expect(benchmark.selectedDelegate.acceleratorName.isNotEmpty, isTrue,
reason: 'acceleratorName cannot be empty');
expect(benchmark.selectedDelegate.acceleratorDesc.isNotEmpty, isTrue,
Expand Down
25 changes: 15 additions & 10 deletions flutter/lib/benchmark/benchmark.dart
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,15 @@ class Benchmark {
return delegate;
}

RunSettings createRunSettings({
Future<RunSettings> createRunSettings({
required BenchmarkRunMode runMode,
required ResourceManager resourceManager,
required List<pb.CommonSetting> commonSettings,
required String backendLibName,
required String logDir,
required int testMinDuration,
required int testMinQueryCount,
}) {
}) async {
final dataset = runMode.chooseDataset(taskConfig);

int minQueryCount;
Expand All @@ -92,9 +92,12 @@ class Benchmark {
setting: commonSettings,
benchmarkSetting: benchmarkSettings,
);

final uris = selectedDelegate.modelFile.map((e) => e.modelPath).toList();
final modelDirName = selectedDelegate.delegateName.replaceAll(' ', '_');
final backendModelPath =
await resourceManager.getModelPath(uris, modelDirName);
return RunSettings(
backend_model_path: resourceManager.get(selectedDelegate.modelPath),
backend_model_path: backendModelPath,
backend_lib_name: backendLibName,
backend_settings: settings,
backend_native_lib_path: DeviceInfo.instance.nativeLibraryPath,
Expand Down Expand Up @@ -167,12 +170,14 @@ class BenchmarkStore {
}

for (final delegate in b.benchmarkSettings.delegateChoice) {
final model = Resource(
path: delegate.modelPath,
type: ResourceTypeEnum.model,
md5Checksum: delegate.modelChecksum,
);
result.add(model);
for (final modelFile in delegate.modelFile) {
final model = Resource(
path: modelFile.modelPath,
type: ResourceTypeEnum.model,
md5Checksum: modelFile.modelChecksum,
);
result.add(model);
}
}
}

Expand Down
5 changes: 4 additions & 1 deletion flutter/lib/resources/export_result_helper.dart
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,15 @@ class ResultHelper {
final delegate = benchmark.selectedDelegate;
final extraSettings = _extraSettingsFromCommon(commonSettings) +
_extraSettingsFromCustom(benchmark.selectedDelegate.customSetting);
final modelPathsJoined =
delegate.modelFile.map((e) => e.modelPath).join(', ');
final modelPathString = '[$modelPathsJoined]';
return BackendSettingsInfo(
acceleratorCode: delegate.acceleratorName,
acceleratorDesc: delegate.acceleratorDesc,
delegate: delegate.delegateName,
framework: benchmark.benchmarkSettings.framework,
modelPath: delegate.modelPath,
modelPath: modelPathString,
batchSize: delegate.batchSize,
extraSettings: extraSettings,
);
Expand Down
2 changes: 1 addition & 1 deletion flutter/lib/resources/file_cache_helper.dart
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class FileCacheHelper {
Future<File> _download(String url) async {
print('downloading $url');
const successStatusCode = 200;

_httpClient.idleTimeout = const Duration(seconds: 60);
final response = await _httpClient
.getUrl(Uri.parse(url))
.then((request) => request.close());
Expand Down
40 changes: 40 additions & 0 deletions flutter/lib/resources/resource_manager.dart
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import 'package:mlperfbench/store.dart';
class ResourceManager {
static const _dataPrefix = 'local://';
static const _loadedResourcesDirName = 'loaded_resources';
static const _symlinksDirName = 'symlinks';

final VoidCallback _onUpdate;
final Store store;
Expand Down Expand Up @@ -57,6 +58,27 @@ class ResourceManager {
throw 'invalid resource path: $uri';
}

// Creates symlinks to the given file paths and put them in one cache directory.
Future<String> getModelPath(List<String> paths, String dirName) async {
String modelPath;
if (paths.isEmpty) {
throw 'List of URIs cannot be empty';
}
if (dirName.contains(' ')) {
throw 'Directory name cannot contain spaces';
}
final cacheDir = await getApplicationCacheDirectory();
final modelDir = Directory('${cacheDir.path}/$_symlinksDirName/$dirName');
final files = paths.map((uri) => File(get(uri))).toList();
final symlinks = await _createSymlinks(files, modelDir);
if (paths.length == 1) {
modelPath = symlinks.first;
} else {
modelPath = modelDir.path;
}
return modelPath;
}

String getDataFolder() {
return applicationDirectory;
}
Expand Down Expand Up @@ -180,4 +202,22 @@ class ResourceManager {
}
return checksumFailedResources;
}

Future<List<String>> _createSymlinks(
List<File> files, Directory cacheDir) async {
if (!await cacheDir.exists()) {
await cacheDir.create(recursive: true);
}
List<String> symlinkPaths = [];
for (final file in files) {
final symlinkPath = '${cacheDir.path}/${file.uri.pathSegments.last}';
symlinkPaths.add(symlinkPath);
final link = Link(symlinkPath);
if (await link.exists()) {
await link.delete();
}
await link.create(file.path);
}
return symlinkPaths;
}
}
34 changes: 20 additions & 14 deletions flutter/lib/state/task_runner.dart
Original file line number Diff line number Diff line change
Expand Up @@ -215,19 +215,21 @@ class TaskRunner {
};
notifyListeners();

final performanceRunInfo = await _NativeRunHelper(
final runHelper = _NativeRunHelper(
enableArtificialLoad: store.artificialCPULoadEnabled,
isTestMode: store.testMode,
resourceManager: resourceManager,
backendBridge: backendBridge,
benchmark: benchmark,
runMode: perfMode,
logParentDir: currentLogDir,
);
await runHelper.initRunSettings(
resourceManager: resourceManager,
commonSettings: backendInfo.settings.commonSetting,
backendLibName: backendInfo.libName,
logParentDir: currentLogDir,
testMinQueryCount: store.testMinQueryCount,
testMinDuration: store.testMinDuration,
).run();
);
final performanceRunInfo = await runHelper.run();
perfTimer.stop();
performanceRunInfo.loadgenInfo!;

Expand All @@ -253,19 +255,21 @@ class TaskRunner {
return queryProgress;
};
notifyListeners();
final accuracyRunInfo = await _NativeRunHelper(
final runHelper = _NativeRunHelper(
enableArtificialLoad: store.artificialCPULoadEnabled,
isTestMode: store.testMode,
resourceManager: resourceManager,
backendBridge: backendBridge,
benchmark: benchmark,
runMode: accuracyMode,
logParentDir: currentLogDir,
);
await runHelper.initRunSettings(
resourceManager: resourceManager,
commonSettings: backendInfo.settings.commonSetting,
backendLibName: backendInfo.libName,
logParentDir: currentLogDir,
testMinQueryCount: store.testMinQueryCount,
testMinDuration: store.testMinDuration,
).run();
);
final accuracyRunInfo = await runHelper.run();
resultHelper.accuracyRunInfo = accuracyRunInfo;
final accuracyResult = accuracyRunInfo.result;
benchmark.accuracyModeResult = BenchmarkResult(
Expand Down Expand Up @@ -302,15 +306,17 @@ class _NativeRunHelper {
required this.backendBridge,
required this.benchmark,
required this.runMode,
required bool isTestMode,
required String logParentDir,
}) : logDir = '$logParentDir/${benchmark.id}-${runMode.readable}';

Future<void> initRunSettings({
required ResourceManager resourceManager,
required List<pb.CommonSetting> commonSettings,
required String backendLibName,
required String logParentDir,
required int testMinQueryCount,
required int testMinDuration,
}) : logDir = '$logParentDir/${benchmark.id}-${runMode.readable}' {
runSettings = benchmark.createRunSettings(
}) async {
runSettings = await benchmark.createRunSettings(
runMode: runMode,
resourceManager: resourceManager,
commonSettings: commonSettings,
Expand Down
14 changes: 10 additions & 4 deletions flutter/unit_test/benchmark/benchmark_store_test.dart
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ void main() {
tiny: pb.OneDatasetConfig(inputPath: 'tiny-inputPath'),
),
);
final model1 = pb.ModelFile(
modelPath: 'model1-path',
);
final choice1 = pb.DelegateSetting(
delegateName: 'delegate1',
modelPath: 'model1-path',
modelFile: [model1],
);
final backendSettings1 = pb.BenchmarkSetting(
benchmarkId: 'task1',
Expand Down Expand Up @@ -115,7 +118,8 @@ void main() {
expect(
resources,
contains(Resource(
path: backendSettings1.delegateChoice.first.modelPath,
path:
backendSettings1.delegateChoice.first.modelFile.first.modelPath,
type: ResourceTypeEnum.model,
md5Checksum: '',
)));
Expand All @@ -140,7 +144,8 @@ void main() {
expect(
resources,
contains(Resource(
path: backendSettings1.delegateChoice.first.modelPath,
path:
backendSettings1.delegateChoice.first.modelFile.first.modelPath,
type: ResourceTypeEnum.model,
md5Checksum: '',
)));
Expand Down Expand Up @@ -168,7 +173,8 @@ void main() {
expect(
resources,
contains(Resource(
path: backendSettings1.delegateChoice.first.modelPath,
path:
backendSettings1.delegateChoice.first.modelFile.first.modelPath,
type: ResourceTypeEnum.model,
md5Checksum: '',
)));
Expand Down
Loading

0 comments on commit 2e6931e

Please sign in to comment.