Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Download button in demo app; upgraded gradio to latest #59

Merged
merged 4 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion demo/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
raidionicsrads@git+https://github.com/dbouget/raidionics_rads_lib
gradio==3.50.2
gradio==4.29.0
6 changes: 5 additions & 1 deletion demo/src/css_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
#upload {
height: 110px;
}
#download {
height: 47px;
width: 150px;
}
#run-button {
height: 110px;
height: 47px;
width: 150px;
}
#toggle-button {
Expand Down
67 changes: 45 additions & 22 deletions demo/src/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def __init__(
visible=True,
elem_id="model-3d",
camera_position=[90, 180, 768],
).style(height=512)
height=512,
)

def set_class_name(self, value):
LOGGER.info(f"Changed task to: {value}")
Expand All @@ -75,30 +76,44 @@ def upload_file(self, file):

def process(self, mesh_file_name):
path = mesh_file_name.name
curr = path.split("/")[-1]
self.extension = ".".join(curr.split(".")[1:])
self.filename = (
curr.split(".")[0] + "-" + self.class_names[self.class_name]
)
run_model(
path,
model_path=os.path.join(self.cwd, "resources/models/"),
task=self.class_names[self.class_name],
name=self.result_names[self.class_name],
output_filename=self.filename + "." + self.extension,
)
LOGGER.info("Converting prediction NIfTI to OBJ...")
nifti_to_obj("prediction.nii.gz")
nifti_to_obj(path=self.filename + "." + self.extension)

LOGGER.info("Loading CT to numpy...")
self.images = load_ct_to_numpy(path)

LOGGER.info("Loading prediction volume to numpy..")
self.pred_images = load_pred_volume_to_numpy("./prediction.nii.gz")
self.pred_images = load_pred_volume_to_numpy(
self.filename + "." + self.extension
)

return "./prediction.obj"

def download_prediction(self):
if (not self.filename) or (not self.extension):
LOGGER.error(
"The prediction is not available or ready to download. Wait until the result is available in the 3D viewer."
)
return self.filename + "." + self.extension

def get_img_pred_pair(self, k):
k = int(k)
out = gr.AnnotatedImage(
self.combine_ct_and_seg(self.images[k], self.pred_images[k]),
visible=True,
elem_id="model-2d",
).style(
color_map={self.class_name: "#ffae00"},
height=512,
width=512,
Expand All @@ -117,20 +132,18 @@ def run(self):
placeholder="\n" * 16,
label="Logs",
info="Verbose from inference will be displayed below.",
lines=38,
max_lines=38,
lines=36,
max_lines=36,
autoscroll=True,
elem_id="logs",
show_copy_button=True,
scroll_to_output=False,
container=True,
line_breaks=True,
)
demo.load(read_logs, None, logs, every=1)

with gr.Column():
with gr.Row():
with gr.Column(scale=0.2, min_width=150):
with gr.Column(scale=1, min_width=150):
sidebar_state = gr.State(True)

btn_toggle_sidebar = gr.Button(
Expand All @@ -149,7 +162,9 @@ def run(self):
btn_clear_logs.click(flush_logs, [], [])

file_output = gr.File(
file_count="single", elem_id="upload"
file_count="single",
elem_id="upload",
scale=3,
)
file_output.upload(
self.upload_file, file_output, file_output
Expand All @@ -160,29 +175,38 @@ def run(self):
label="Task",
info="Which structure to segment.",
multiselect=False,
size="sm",
scale=1,
)
model_selector.input(
fn=lambda x: self.set_class_name(x),
inputs=model_selector,
outputs=None,
)

with gr.Column(scale=0.2, min_width=150):
with gr.Column(scale=1, min_width=150):
run_btn = gr.Button(
"Run analysis",
variant="primary",
elem_id="run-button",
).style(
full_width=False,
size="lg",
)
run_btn.click(
fn=lambda x: self.process(x),
inputs=file_output,
outputs=self.volume_renderer,
)

download_btn = gr.DownloadButton(
"Download prediction",
visible=True,
variant="secondary",
elem_id="download",
)
download_btn.click(
fn=self.download_prediction,
inputs=None,
outputs=download_btn,
)

with gr.Row():
gr.Examples(
examples=[
Expand All @@ -202,17 +226,16 @@ def run(self):
)

with gr.Row():
with gr.Box():
with gr.Group():
with gr.Column():
# create dummy image to be replaced by loaded images
t = gr.AnnotatedImage(
visible=True, elem_id="model-2d"
).style(
visible=True,
elem_id="model-2d",
color_map={self.class_name: "#ffae00"},
height=512,
width=512,
# height=512,
# width=512,
)

self.slider.input(
self.get_img_pred_pair,
self.slider,
Expand All @@ -221,7 +244,7 @@ def run(self):

self.slider.render()

with gr.Box():
with gr.Group(): # gr.Box():
self.volume_renderer.render()

# sharing app publicly -> share=True:
Expand Down
6 changes: 5 additions & 1 deletion demo/src/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def run_model(
verbose: str = "info",
task: str = "CT_Airways",
name: str = "Airways",
output_filename: str = None,
):
if verbose == "debug":
logging.getLogger().setLevel(logging.DEBUG)
Expand All @@ -27,6 +28,9 @@ def run_model(
if os.path.exists("./result/"):
shutil.rmtree("./result/")

if output_filename is None:
raise ValueError("Please, set output_filename.")

patient_directory = ""
output_path = ""
try:
Expand Down Expand Up @@ -84,7 +88,7 @@ def run_model(
+ "-t1gd_annotation-"
+ name
+ ".nii.gz",
"./prediction.nii.gz",
output_filename,
)
# Clean-up
if os.path.exists(patient_directory):
Expand Down
Loading