Welcome to the TOSS repository! This repository contains the core codebase for TOSS, including traditional methods and Graph Neural Network (GNN) implementations.
To get started with TOSS, follow the steps below:
First, create a Python environment. We recommend using Python version 3.9.0 for compatibility. You can use venv or conda to create a new environment.
# Using conda
conda create -n toss_env python=3.9.0
conda activate toss_env
After activating your environment, install the necessary dependencies listed in the requirements.txt file.
pip install -r requirements.txt
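To confirm that the packages listed in requirements.txt are actually installed in the active environment, a small check like the one below can help. This is a minimal sketch, not part of TOSS: `missing_requirements` is a hypothetical helper that only compares package names (version pins are ignored), and URL or VCS requirement lines would be reported as missing.

```python
import re
from importlib import metadata


def missing_requirements(lines):
    """Return the package names from requirements-style lines that are not installed.

    Version specifiers, extras and environment markers are stripped; comments,
    blank lines and pip options (lines starting with '-') are skipped.
    """
    missing = []
    for raw in lines:
        line = raw.split("#", 1)[0].strip()  # drop inline comments and whitespace
        if not line or line.startswith("-"):
            continue  # skip blank lines and options such as "-r other.txt"
        # Keep only the name part before any specifier, extra, marker or space.
        name = re.split(r"[<>=!~;\[\s]", line, maxsplit=1)[0]
        try:
            metadata.version(name)
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing
```

For example, `missing_requirements(open("requirements.txt"))` returns an empty list when every listed package is importable from the active environment.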
Once the dependencies are installed, you can run the main scripts:
- run.py: handles the initial setup and execution of the traditional TOSS methods.
- train.py: trains the GNN models.
# Run traditional TOSS methods
python TOSS/toss/run.py
# Train GNN models
python TOSS/toss_gnn/train.py
Repository structure:
- TOSS/toss/: contains the traditional TOSS methods and related scripts.
- TOSS/toss_gnn/: contains the code for training and evaluating Graph Neural Networks.
- Use FeTiO3.cif as an example (Fe3O4.cif can also be used as an example of a mixed-valence compound).
# Example CIF file
mid = "FeTiO3.cif" # "Fe3O4.cif" or "Modified_Prussian_Blue.cif" can be used similarly
# If you only want to try the samples, create a directory and move the CIF files into it.
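The comment above can also be carried out from Python. The sketch below moves the example CIF files into a working directory; the helper `stage_cif_files`, the destination name `samples`, and the assumption that the CIF files start out in the current directory are all hypothetical, so adjust them to your own layout.

```python
import shutil
from pathlib import Path


def stage_cif_files(cif_names, dest_dir, src_dir="."):
    """Move the named CIF files from src_dir into dest_dir (created if missing).

    Returns the destination paths. Raises FileNotFoundError if a file is absent,
    before anything else is moved for that name.
    """
    src = Path(src_dir)
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)  # make the sample directory
    moved = []
    for name in cif_names:
        source = src / name
        if not source.is_file():
            raise FileNotFoundError(source)
        target = dest / name
        shutil.move(str(source), str(target))  # move, as the comment suggests
        moved.append(target)
    return moved
```

Usage: `stage_cif_files(["FeTiO3.cif", "Fe3O4.cif"], "samples")`.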
Show results using the pre-trained models and display them in a 3D plot.
import sys

import pandas as pd
import torch  # used by torch.load when loading the pre-trained weights

# Adjust these paths to point at your local TOSS checkout
sys.path.append("D:/share/TOSS/toss_GNN")
from data_utils import *
from dataset_utils_pyg import *
from model_utils_pyg import *
from Predict import Get_OS_by_models
from GNN_vis import VIS as GNN_VIS

sys.path.append("D:/share/TOSS/toss")
from result import RESULT
from pre_set import PRE_SET
from Get_Initial_Guess import get_the_valid_t
from get_fos import GET_FOS
from Get_TOS import get_Oxidation_States
from TOSS_vis import VIS as TOSS_VIS
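The imports above rely on hard-coded absolute paths (D:/share/TOSS/...). A more portable alternative is to resolve both source directories relative to the repository root. This is a sketch only: `add_toss_paths` and its `repo_root` argument are hypothetical, and the directory names toss and toss_GNN are copied from the paths above (the run commands earlier use toss_gnn, so check the actual casing in your checkout).

```python
import sys
from pathlib import Path


def add_toss_paths(repo_root, subdirs=("toss", "toss_GNN")):
    """Append the TOSS source directories under repo_root to sys.path.

    Missing directories raise FileNotFoundError so layout mistakes fail early.
    Returns the directories that are now on sys.path, as strings.
    """
    root = Path(repo_root).resolve()
    added = []
    for name in subdirs:
        path = root / name
        if not path.is_dir():
            raise FileNotFoundError(f"expected directory not found: {path}")
        entry = str(path)
        if entry not in sys.path:  # avoid duplicate entries when re-run in a notebook
            sys.path.append(entry)
        added.append(entry)
    return added
```

Usage: call `add_toss_paths("D:/share/TOSS")` once before the `from ... import` lines.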
Load pre-trained models for link prediction (LP) and node classification (NC).
LP_model = pyg_Hetero_GCNPredictor(atom_feats=13, bond_feats=13, hidden_feats=[256, 256, 256, 256],
                                   predictor_hidden_feats=64, n_tasks=2, predictor_dropout=0.3)
NC_model = pyg_GCNPredictor(in_feats=15, hidden_feats=[256, 256, 256, 256],
                            predictor_hidden_feats=64, n_tasks=12, predictor_dropout=0.3)
LP_model.load_state_dict(torch.load("./models/pyg_Hetero_GCN_s_0608.pth"))
NC_model.load_state_dict(torch.load("./models/pyg_GCN_s_0609.pth"))
Expected output: All keys matched successfully.
toss = Get_OS_by_models(mid, LP_model, NC_model)
pred_res = toss.NC_predict()
pd.DataFrame([pred_res["ele"], pred_res["os"], pred_res["cn"]], index=["Elements", "Valence", "Coordination Number"])
vis = GNN_VIS(pred_res)
vis.draw()
vis.show_fig()
Visit the predicted 3D plot on our webpage.
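A quick plausibility check on predicted oxidation states is charge neutrality: for a neutral crystal, the formal oxidation states summed over all atoms in the cell should be zero. The helper below is a hypothetical sketch, not part of TOSS; it takes per-atom lists such as pred_res["ele"] and pred_res["os"] and assumes the oxidation states are numeric. The example values are textbook formal assignments for FeTiO3 (Fe2+, Ti4+, O2-) and the mixed-valence Fe3O4 (one Fe2+, two Fe3+), not model output.

```python
from collections import Counter


def charge_balance(elements, oxidation_states):
    """Return (total_charge, composition) for per-atom element and oxidation-state lists.

    composition maps (element, oxidation_state) -> number of atoms, which makes
    mixed valence (e.g. Fe2+ and Fe3+ in the same cell) easy to spot.
    """
    if len(elements) != len(oxidation_states):
        raise ValueError("elements and oxidation_states must have the same length")
    total = sum(oxidation_states)
    composition = Counter(zip(elements, oxidation_states))
    return total, composition


# Illustrative formal assignments, not TOSS predictions.
fetio3 = (["Fe", "Ti", "O", "O", "O"], [2, 4, -2, -2, -2])
fe3o4 = (["Fe", "Fe", "Fe", "O", "O", "O", "O"], [2, 3, 3, -2, -2, -2, -2])

for name, (ele, states) in {"FeTiO3": fetio3, "Fe3O4": fe3o4}.items():
    total, comp = charge_balance(ele, states)
    print(name, "total charge:", total, dict(comp))
```

A non-zero total for a structure that should be neutral is a hint to inspect the prediction more closely.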
import pandas as pd
import numpy as np
import sys
# Append TOSS path to system path
sys.path.append("./toss")
# Import packages from TOSS
from result import RESULT
from pre_set import PRE_SET
from Get_Initial_Guess import get_the_valid_t
from get_fos import GET_FOS
from Get_TOS import get_Oxidation_States
The modules GET_STRUCTURE and DIGEST are wrapped in the function get_the_valid_t, which returns the valid tolerances for the given structure. In this example, three tolerances are valid (1.1, 1.12, and 1.14); the first, 1.1, is used for the initial guess.
valid_t = get_the_valid_t(m_id=mid)
valid_t
This is the 0th structure with mid FeTiO3.cif, and we got 3 different valid tolerances: [1.1, 1.12, 1.14]
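This README does not spell out how a tolerance enters the structure analysis. For intuition only, the sketch below assumes a common convention in which two atoms count as bonded when their distance is at most tolerance × (r_center + r_neighbor); under that assumption a slightly larger tolerance can add neighbors and change the coordination number, which illustrates why results can depend on the chosen tolerance. The helper `coordination_number`, the radii and the distances are hypothetical values, not TOSS internals.

```python
def coordination_number(center_radius, neighbors, tolerance):
    """Count neighbors within tolerance * (r_center + r_neighbor).

    neighbors is a list of (distance, neighbor_radius) pairs in the same length unit.
    This cutoff rule is an illustrative assumption, not necessarily TOSS's definition.
    """
    return sum(
        1
        for distance, radius in neighbors
        if distance <= tolerance * (center_radius + radius)
    )


# Hypothetical cation with six neighbors: three close ones and three slightly farther away.
center_r = 0.8
neighbors = [(1.95, 1.0)] * 3 + [(2.03, 1.0)] * 3

for t in (1.1, 1.12, 1.14):
    print(t, coordination_number(center_r, neighbors, t))
```

With these made-up numbers, the cutoff grows from about 1.98 to about 2.05 across the three tolerances, so the second shell of neighbors is only counted at the largest one.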
Make the initial guess of the oxidation states (OS) and coordination numbers (CN), and display the results in a DataFrame.
GFOS = GET_FOS()
res = RESULT()
GFOS.initial_guess(m_id=mid, delta_X=0.1, tolerance=1.1, tolerance_list=valid_t, res=res)
pd.DataFrame([res.elements_list, res.sum_of_valence, res.shell_CN_list], index=["Elements", "Valence", "Coordination Number"])
Compute the final oxidation states and display the results in a DataFrame.
RES = get_Oxidation_States(m_id=mid, input_tolerance_list=valid_t)[-1]
pd.DataFrame([RES.elements_list, RES.sum_of_valence, RES.shell_CN_list], index=["Elements", "Valence", "Coordination Number"])
Got the Formal Oxidation State of the 0th structure FeTiO3.cif in 1.350132942199707 seconds.
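To see which atoms the final step changed relative to the initial guess (the condition behind the "CNs are DIFFERENT!" message printed by the visualization step), the per-atom lists of the two RESULT objects can be compared index by index. The helper below is a hypothetical sketch that works on plain lists such as res.shell_CN_list and RES.shell_CN_list; it assumes both lists follow the same atom order, and the example values are made up.

```python
def changed_sites(elements, initial, final):
    """Return (index, element, initial_value, final_value) for every atom whose value differs.

    initial and final are per-atom lists in the same atom order, e.g. coordination
    numbers or oxidation states taken from two result objects.
    """
    if not (len(elements) == len(initial) == len(final)):
        raise ValueError("all lists must describe the same atoms in the same order")
    return [
        (i, ele, before, after)
        for i, (ele, before, after) in enumerate(zip(elements, initial, final))
        if before != after
    ]


# Hypothetical values for illustration only.
elements = ["Fe", "Ti", "O", "O", "O"]
initial_cn = [6, 6, 4, 4, 4]
final_cn = [6, 5, 4, 4, 4]
print(changed_sites(elements, initial_cn, final_cn))  # [(1, 'Ti', 6, 5)]
```

An empty list means the two results agree on every atom.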
Depending on whether the coordination numbers change, the visualization shows either one 3D plot (no change) or two 3D plots that highlight the changes made during the MAP process. The opacity of each sphere represents the loss value projected onto that atom.
from TOSS_vis import VIS as TOSS_VIS

# RES: final result, res: initial guess
vis = TOSS_VIS(RES, res, loss_ratio=10)
vis.show_fig()
CNs are DIFFERENT! USE two figs!
Visit the TOSS 3D plot on our webpage.
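The sphere opacity in these plots encodes the loss projected onto each atom, but how TOSS_VIS scales it (including the role of loss_ratio) is not documented here. The following is only a hedged sketch of one common approach, min-max normalizing per-atom losses into an opacity range; `loss_to_opacity`, `min_alpha` and `max_alpha` are hypothetical names, and it assumes a larger loss should be drawn more opaque.

```python
def loss_to_opacity(losses, min_alpha=0.2, max_alpha=1.0):
    """Min-max scale per-atom losses into [min_alpha, max_alpha].

    Assumes a larger loss should be drawn more opaque. If all losses are equal,
    every atom gets max_alpha.
    """
    if not 0.0 <= min_alpha <= max_alpha <= 1.0:
        raise ValueError("require 0 <= min_alpha <= max_alpha <= 1")
    if not losses:
        return []
    lo, hi = min(losses), max(losses)
    if hi == lo:
        return [max_alpha] * len(losses)
    span = max_alpha - min_alpha
    return [min_alpha + span * (x - lo) / (hi - lo) for x in losses]
```

Keeping `min_alpha` above zero ensures that atoms with the smallest loss remain visible in the plot.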
This project is licensed under the MIT License - see the LICENSE file for details.