diff --git a/cli.py b/cli.py index f1944ade4..5d11a8c04 100644 --- a/cli.py +++ b/cli.py @@ -9,7 +9,7 @@ import torch -default_device = "cpu" +default_device = "cpu" def check_args(args, name: str) -> None: pass @@ -29,8 +29,14 @@ def add_arguments_for_export(parser): _add_arguments_common(parser) def add_arguments_for_browser(parser): - # Only export specific options should be here + # Only browser specific options should be here _add_arguments_common(parser) + parser.add_argument( + "--port", + type=int, + default=5000, + help="Port for the web server for browser mode." + ) def _add_arguments_common(parser): # TODO: Refactor this so that only common options are here diff --git a/torchchat.py b/torchchat.py index 44b394c48..a3e1055bc 100644 --- a/torchchat.py +++ b/torchchat.py @@ -25,7 +25,7 @@ parser = argparse.ArgumentParser(description="Top-level command") subparsers = parser.add_subparsers( dest="subcommand", - help="Use `generate`, `eval` or `export` followed by subcommand specific options.", + help="Use `generate`, `eval`, `export` or `browser` followed by subcommand specific options.", ) parser_generate = subparsers.add_parser("generate") @@ -63,10 +63,26 @@ elif args.subcommand == "browser": # TODO: add check_args() + # Look for port from cmd args. Default to 5000 if not found. + # The port args will be passed directly to the Flask app. + port = 5000 + i = 2 + while i < len(sys.argv): + # Check if the current argument is '--port' + if sys.argv[i] == '--port': + # Check if there's a value immediately following '--port' + if i + 1 < len(sys.argv): + # Extract the value and remove '--port' and the value from sys.argv + port = sys.argv[i + 1] + del sys.argv[i:i+2] # Delete '--port' and the value + break # Exit loop since port is found + else: + i += 1 + # Assume the user wants "chat" when entering "browser". TODO: add support for "generate" as well args_plus_chat = ['"{}"'.format(s) for s in sys.argv[2:]] + ['"--chat"'] + ['"--num-samples"'] + ['"1000000"'] formatted_args = ", ".join(args_plus_chat) - command = ["flask", "--app", "chat_in_browser:create_app(" + formatted_args + ")", "run"] + command = ["flask", "--app", "chat_in_browser:create_app(" + formatted_args + ")", "run", "--port", f"{port}"] subprocess.run(command) else: raise RuntimeError("Must specify valid subcommands: generate, export, eval")