
ToothGroupNetwork

3D dental surface segmentation with Tooth Group Network

  • Team CGIP
  • Ho Yeon Lim and Min Chang Kim

Notice

  • Please star this repository!
  • This repository contains the code for the tooth segmentation algorithm that won first place at the MICCAI 2022 3D Teeth Scan Segmentation and Labeling Challenge.
  • Official Paper for Challenge: 3DTeethSeg'22: 3D Teeth Scan Segmentation and Labeling Challenge.
  • We used the dataset shared in the challenge git repository.
  • If you only want the inference code, or if you want to use the same checkpoints that we used in the challenge, you can jump to the challenge_branch of this repository.
  • Most of the code has been modified and may differ significantly from the code you downloaded before June 20th. If you downloaded this repository earlier, we recommend downloading it again.
  • There may be errors in the initial code. If you encounter any, please post them on the issue board and we will promptly fix them.
  • If you have any problems with execution, please contact us via email ([email protected]). We will reply quickly.

Dataset

  • We used the dataset shared in the challenge git repository. For more information about the data, check out the link.
  • The dataset consists of dental mesh obj files and corresponding ground-truth json files.
  • You can also download the challenge training split data from Google Drive (our code is based on this data).
    • After you download and unzip these zip files, merge the contents of 3D_scans_per_patient_obj_files.zip and 3D_scans_per_patient_obj_files_b2.zip. The parent path of these obj files is data_obj_parent_directory.
    • Apply the same to the ground-truth json files (ground-truth_labels_instances.zip and ground-truth_labels_instances_b2.zip). The parent path of these json files is data_json_parent_directory.
    • The directory structure of your data should look like the following:
      --data_obj_parent_directory
      ----00OMSZGW
      ------00OMSZGW_lower.obj
      ------00OMSZGW_upper.obj
      ----0EAKT1CU
      ------0EAKT1CU_lower.obj
      ------0EAKT1CU_upper.obj
      and so on..
      
      --data_json_parent_directory
      ----00OMSZGW
      ------00OMSZGW_lower.json
      ------00OMSZGW_upper.json
      ----0EAKT1CU
      ------0EAKT1CU_lower.json
      ------0EAKT1CU_upper.json
      and so on..
      
  • If you have your own dental mesh data, you can use it.
    • In that case, you must adhere to the file naming format (casename_upper.obj or casename_lower.obj).

    • All axes must be aligned as shown in the figure below. Note that the Y-axis points toward the back (and that the lower jaw and the upper jaw share the same z-direction!). A hedged realignment sketch follows.
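    • If your scans are oriented differently, you can realign them before preprocessing. Below is a minimal sketch using trimesh (installed in the Installation section); the file name and the 90-degree rotation about the X-axis are only illustrative, and the correct transform depends on your scanner's coordinate system.
      import numpy as np
      import trimesh

      # load the scan without automatic processing so the vertex order is preserved
      mesh = trimesh.load("casename_lower.obj", process=False)

      # illustrative transform: rotate 90 degrees about X so that +Y points backwards
      R = trimesh.transformations.rotation_matrix(np.pi / 2, [1, 0, 0])
      mesh.apply_transform(R)
      mesh.export("casename_lower_aligned.obj")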

Training

Preprocessing

  • For training, you must first run preprocess_data.py to save the farthest-point-sampled vertices of the mesh (.obj) files.
  • Here is an example of how to execute preprocess_data.py (a short sketch of farthest point sampling follows the command).
    python preprocess_data.py \
     --source_obj_data_path data_obj_parent_directory \
     --source_json_data_path data_json_parent_directory \
     --save_data_path path/for/preprocessed_data
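
  • For reference, below is a minimal numpy sketch of farthest point sampling, the operation this preprocessing step performs; the repository's own implementation (and its sample count) may differ.
    import numpy as np

    def farthest_point_sampling(points: np.ndarray, n_samples: int) -> np.ndarray:
        """Greedily pick n_samples indices, each maximizing its distance to the chosen set."""
        chosen = np.zeros(n_samples, dtype=np.int64)  # chosen[0] = 0: arbitrary start point
        dist = np.full(points.shape[0], np.inf)
        for i in range(1, n_samples):
            # update each point's squared distance to the nearest already-chosen point
            diff = points - points[chosen[i - 1]]
            dist = np.minimum(dist, np.einsum("ij,ij->i", diff, diff))
            chosen[i] = int(np.argmax(dist))
        return chosen

    # usage: sampled_vertices = vertices[farthest_point_sampling(vertices, n_samples)]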
    

Training

  • We offer six models (tsegnet | tgnet (ours) | pointnet | pointnetpp | dgcnn | pointtransformer).
  • For experiment tracking, we use wandb. Please replace "entity" with your own wandb ID in the get_default_config function of train_configs/train_config_maker.py (see the illustration below).
  • Due to memory constraints, all functions are implemented for batch size 1 (a GPU with at least 11 GB of RAM is required). If you want a different batch size, you will need to modify most of the functions yourself.
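  • For reference, the entity determines the wandb account that runs are logged to; a minimal illustration of how such a value is consumed (the repository's actual call site may differ, and "your_wandb_id" is a placeholder):
    import wandb

    # entity comes from train_configs/train_config_maker.py; project/name are placeholders
    wandb.init(entity="your_wandb_id", project="ToothGroupNetwork", name="your_experiment_name")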

1. tgnet(Ours)

  • tgnet is our 3D tooth segmentation method. Please refer to the challenge paper for an explanation of the methodology.
  • You should first train the Farthest Point Sampling model and then train the Boundary Aware Point Sampling model.
  • First, train the Farthest Point Sampling model as follows.
    python start_train.py \
     --model_name "tgnet_fps" \
     --config_path "train_configs/tgnet_fps.py" \
     --experiment_name "your_experiment_name" \
     --input_data_dir_path "path/for/preprocessed_data" \
     --train_data_split_txt_path "base_name_train_fold.txt" \
     --val_data_split_txt_path "base_name_val_fold.txt"
    
    • Input the preprocessed data directory path into --input_data_dir_path.
    • You can provide the train/validation split text files through --train_data_split_txt_path and --val_data_split_txt_path. Either use the text files provided at the dataset drive link above (base_name_*_fold.txt) or create your own split files (an example appears at the end of this subsection).
  • To train the Boundary Aware Point Sampling model, please modify the following four configurations in train_configs/tgnet_bdl.py: original_data_obj_path, original_data_json_path, bdl_cache_path, and load_ckpt_path.
  • After modifying the configurations, train the Boundary Aware Point Sampling model as follows.
    python start_train.py \
     --model_name "tgnet_bdl" \
     --config_path "train_configs/tgnet_bdl.py" \
     --experiment_name "your_experiment_name" \
     --input_data_dir_path "path/to/save/preprocessed_data" \
     --train_data_split_txt_path "base_name_train_fold.txt" \
     --val_data_split_txt_path "base_name_val_fold.txt"
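
  • For reference, the split text files are assumed to be plain lists of case base names, one per line; using the casenames from the Dataset section, base_name_train_fold.txt would look something like:
    00OMSZGW
    0EAKT1CU
    ...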
    

2. tsegnet

  • This is the implementation of TSegNet. Please refer to the paper for details.
  • In tsegnet, the centroid prediction module has to be trained first. To train it, set the run_tooth_segmentation_module parameter to False in the train_configs/tsegnet.py file.
  • Then train the centroid prediction module with the following command.
    python start_train.py \
     --model_name "tsegnet" \
     --config_path "train_configs/tsegnet.py" \
     --experiment_name "your_experiment_name" \
     --input_data_dir_path "path/to/save/preprocessed_data" \
     --train_data_split_txt_path "base_name_train_fold.txt" \
     --val_data_split_txt_path "base_name_val_fold.txt"
    
  • Once training of the centroid prediction module is complete, update pretrained_centroid_model_path in train_configs/tsegnet.py with the checkpoint path of the trained centroid prediction module, and set run_tooth_segmentation_module back to True (a sketch of these fields follows the command below).
  • Then train the tsegnet model with the following command.
    python start_train.py \
     --model_name "tsegnet" \
     --config_path "train_configs/tsegnet.py" \
     --experiment_name "your_experiment_name" \
     --input_data_dir_path "path/to/save/preprocessed_data" \
     --train_data_split_txt_path "base_name_train_fold.txt" \
     --val_data_split_txt_path "base_name_val_fold.txt"
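
  • For reference, a hedged sketch of the two stages' settings in train_configs/tsegnet.py (the field names come from the steps above; the file's actual layout and the checkpoint path are illustrative):
    # stage 1: train only the centroid prediction module
    run_tooth_segmentation_module = False

    # stage 2: train the full model from the stage-1 checkpoint (path is a placeholder)
    run_tooth_segmentation_module = True
    pretrained_centroid_model_path = "path/to/centroid_module_checkpoint"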
    

3. pointnet | pointnetpp | dgcnn | pointtransformer

  • These models directly apply general point cloud segmentation methods (pointnet | pointnetpp | dgcnn | pointtransformer) to tooth segmentation.
  • Train them with the following command (shown for pointnet; substitute the model name and config path for the others).
    python start_train.py \
     --model_name "pointnet" \
     --config_path "train_configs/pointnet.py" \
     --experiment_name "your_experiment_name" \
     --input_data_dir_path "path/to/save/preprocessed_data" \
     --train_data_split_txt_path "base_name_train_fold.txt" \
     --val_data_split_txt_path "base_name_val_fold.txt"
    

Inference

  • To test the performance of the model used in the challenge, please switch to the challenge_branch (refer to the notice at the top).
  • We offer six models (tsegnet | tgnet (ours) | pointnet | pointnetpp | dgcnn | pointtransformer).
  • All of the checkpoint files for each model are available at (https://drive.google.com/drive/folders/15oP0CZM_O_-Bir18VbSM8wRUEzoyLXby?usp=sharing). Download ckpts(new).zip and unzip all of the checkpoints.
  • Inference with tsegnet / pointnet / pointnetpp / dgcnn / pointtransformer (set --model_name to the model you trained; tsegnet is shown below)
    python start_inference.py \
     --input_dir_path obj/file/parent/path \
     --split_txt_path base_name_test_fold.txt \
     --save_path path/to/save/results \
     --model_name tsegnet \
     --checkpoint_path your/model/checkpoint/path
    
  • Inference with tgnet(ours)
    python start_inference.py \
     --input_dir_path obj/file/parent/path \
     --split_txt_path base_name_test_fold.txt \
     --save_path path/to/save/results \
     --model_name tgnet_fps \
     --checkpoint_path your/tgnet_fps/checkpoint/path \
     --checkpoint_path_bdl your/tgnet_bdl/checkpoint/path
    
    • Unlike training, pass the parent path of the original mesh obj files (not the preprocessed sampling points) to --input_dir_path. The inference process performs the farthest point sampling internally.
    • For split_txt_path, provide the test split fold's casenames in the same format as used during training.
  • Predicted results are saved in save_path as shown below, in the same format as the ground-truth json files (a sketch of the json structure follows the listing).
    --save_path
    ----00OMSZGW_lower.json
    ----00OMSZGW_upper.json
    ----0EAKT1CU_lower.json
    ----0EAKT1CU_upper.json
    and so on...
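
  • The predicted json files are assumed to follow the challenge's ground-truth schema, roughly like the sketch below (per-vertex arrays truncated; labels use FDI tooth numbering, with 0 for gingiva):
    {
      "id_patient": "00OMSZGW",
      "jaw": "lower",
      "labels": [0, 0, 37, 37, ...],
      "instances": [0, 0, 1, 1, ...]
    }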
    
  • The inference config in inference_pipelines/inference_pipeline_maker.py must match the train config of the model. If you change the train config, you must change the inference config accordingly.

Test results

  • The checkpoints we provide were trained for 60 epochs using the train/validation split provided at the dataset drive link (base_name_train_fold.txt, base_name_val_fold.txt). The results obtained on the test split (base_name_test_fold.txt) are as follows.

    [Results table image]

    • IoU: Intersection over Union (TSA in the challenge); CLS: classification accuracy (TIR in the challenge).
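    • For intuition, per-label IoU over vertex label arrays can be computed as below (a generic sketch, not the challenge's official evaluation code):
      import numpy as np

      def label_iou(gt: np.ndarray, pred: np.ndarray, label: int) -> float:
          """IoU between the ground-truth and predicted vertex sets of one label."""
          g, p = gt == label, pred == label
          union = np.logical_or(g, p).sum()
          # if the label appears in neither array, count it as a perfect match
          return float(np.logical_and(g, p).sum() / union) if union else 1.0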

Evaluation & Visualization

  • We provide the evaluation and visualization code.
  • You can execute the following command to test a pair of obj / ground-truth json files (the predicted json is an output of the inference code):
    python eval_visualize_results.py \
     --mesh_path path/to/obj_file \
     --gt_json_path path/to/gt_json_file \
     --pred_json_path path/to/predicted_json_file
    
  • With just a few modifications to the provided code, you can evaluate all results at once; a minimal sketch follows.
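  • For example, a minimal sketch that evaluates every predicted json in a results directory (assuming the command-line interface above and the output layout from the Inference section; all paths are placeholders):
    import subprocess
    from pathlib import Path

    pred_dir = Path("path/to/save/results")      # --save_path used at inference time
    obj_dir = Path("data_obj_parent_directory")  # original mesh obj files
    gt_dir = Path("data_json_parent_directory")  # ground-truth json files

    for pred_json in sorted(pred_dir.glob("*.json")):
        case = pred_json.stem          # e.g. "00OMSZGW_lower"
        base = case.rsplit("_", 1)[0]  # e.g. "00OMSZGW"
        subprocess.run([
            "python", "eval_visualize_results.py",
            "--mesh_path", str(obj_dir / base / f"{case}.obj"),
            "--gt_json_path", str(gt_dir / base / f"{case}.json"),
            "--pred_json_path", str(pred_json),
        ], check=True)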

Installation

  • Installation is tested on the pytorch/pytorch:1.7.1-cuda11.0-cudnn8-devel Docker image (Ubuntu, PyTorch 1.7.1).
  • It can also be installed on other operating systems (Windows, etc.).
  • There are some known issues with RTX 40xx graphics cards; please report them on the issue board.
  • If you have any issues while installing the pointops library, please install the pointops library from (https://github.com/POSTECH-CVLab/point-transformer) instead.
    pip install wandb
    pip install --ignore-installed PyYAML
    pip install open3d
    pip install multimethod
    pip install termcolor
    pip install trimesh
    pip install easydict
    cd external_libs/pointops && python setup.py install
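
  • Before building pointops, it may help to confirm that the expected PyTorch/CUDA stack is visible (a quick sanity check, not part of the repository):
    import torch

    # expect "1.7.1" and True on the tested Docker image
    print(torch.__version__, torch.cuda.is_available())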
    
