Skip to content

Commit

Permalink
Pass port in from cmd. If not passed, default to 5000 (it's the Flask…
Browse files Browse the repository at this point in the history
… default) (#302)

* chat in browser

* pass port in as an arg

* more comment
  • Loading branch information
Olivia-liu authored Apr 19, 2024
1 parent bfc7a6d commit 87a561f
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
10 changes: 8 additions & 2 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch

default_device = "cpu"
default_device = "cpu"

def check_args(args, name: str) -> None:
pass
Expand All @@ -29,8 +29,14 @@ def add_arguments_for_export(parser):
_add_arguments_common(parser)

def add_arguments_for_browser(parser):
# Only export specific options should be here
# Only browser specific options should be here
_add_arguments_common(parser)
parser.add_argument(
"--port",
type=int,
default=5000,
help="Port for the web server for browser mode."
)

def _add_arguments_common(parser):
# TODO: Refactor this so that only common options are here
Expand Down
20 changes: 18 additions & 2 deletions torchchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
parser = argparse.ArgumentParser(description="Top-level command")
subparsers = parser.add_subparsers(
dest="subcommand",
help="Use `generate`, `eval` or `export` followed by subcommand specific options.",
help="Use `generate`, `eval`, `export` or `browser` followed by subcommand specific options.",
)

parser_generate = subparsers.add_parser("generate")
Expand Down Expand Up @@ -63,10 +63,26 @@
elif args.subcommand == "browser":
# TODO: add check_args()

# Look for port from cmd args. Default to 5000 if not found.
# The port args will be passed directly to the Flask app.
port = 5000
i = 2
while i < len(sys.argv):
# Check if the current argument is '--port'
if sys.argv[i] == '--port':
# Check if there's a value immediately following '--port'
if i + 1 < len(sys.argv):
# Extract the value and remove '--port' and the value from sys.argv
port = sys.argv[i + 1]
del sys.argv[i:i+2] # Delete '--port' and the value
break # Exit loop since port is found
else:
i += 1

# Assume the user wants "chat" when entering "browser". TODO: add support for "generate" as well
args_plus_chat = ['"{}"'.format(s) for s in sys.argv[2:]] + ['"--chat"'] + ['"--num-samples"'] + ['"1000000"']
formatted_args = ", ".join(args_plus_chat)
command = ["flask", "--app", "chat_in_browser:create_app(" + formatted_args + ")", "run"]
command = ["flask", "--app", "chat_in_browser:create_app(" + formatted_args + ")", "run", "--port", f"{port}"]
subprocess.run(command)
else:
raise RuntimeError("Must specify valid subcommands: generate, export, eval")

0 comments on commit 87a561f

Please sign in to comment.