Skip to content

Commit

Permalink
Merge branch 'master' into dos.get_cbm_vbm-updates
Browse files Browse the repository at this point in the history
  • Loading branch information
shyuep authored Feb 18, 2025
2 parents bf6c904 + 8d6d337 commit b00d94b
Show file tree
Hide file tree
Showing 35 changed files with 848 additions and 506 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.2
rev: v0.9.4
hooks:
- id: ruff
args: [--fix, --unsafe-fixes]
Expand All @@ -27,7 +27,7 @@ repos:
- id: mypy

- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
rev: v2.4.1
hooks:
- id: codespell
stages: [pre-commit, commit-msg]
Expand All @@ -48,7 +48,7 @@ repos:
- id: blacken-docs

- repo: https://github.com/igorshubovych/markdownlint-cli
rev: v0.43.0
rev: v0.44.0
hooks:
- id: markdownlint
# MD013: line too long
Expand All @@ -65,6 +65,6 @@ repos:
args: [--drop-empty-cells, --keep-output]

- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.391
rev: v1.1.393
hooks:
- id: pyright
13 changes: 8 additions & 5 deletions dev_scripts/potcar_scrambler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

class PotcarScrambler:
"""
Takes a POTCAR and replaces its values with completely random values
Takes a POTCAR and replaces its values with completely random values.
Does type matching and attempts precision matching on floats to ensure
file is read correctly by Potcar and PotcarSingle classes.
Expand All @@ -40,14 +40,15 @@ class PotcarScrambler:

def __init__(self, potcars: Potcar | PotcarSingle) -> None:
self.PSP_list = [potcars] if isinstance(potcars, PotcarSingle) else potcars
self.scrambled_potcars_str = ""
self.scrambled_potcars_str: str = ""
for psp in self.PSP_list:
scrambled_potcar_str = self.scramble_single_potcar(psp)
self.scrambled_potcars_str += scrambled_potcar_str

def _rand_float_from_str_with_prec(self, input_str: str, bloat: float = 1.5) -> float:
n_prec = len(input_str.split(".")[1])
bd = max(1, bloat * abs(float(input_str))) # ensure we don't get 0
"""Generate a random float from str to replace true values."""
n_prec: int = len(input_str.split(".")[1])
bd: float = max(1.0, bloat * abs(float(input_str))) # ensure we don't get 0
return round(bd * np.random.default_rng().random(), n_prec)

def _read_fortran_str_and_scramble(self, input_str: str, bloat: float = 1.5):
Expand Down Expand Up @@ -124,14 +125,16 @@ def scramble_single_potcar(self, potcar: PotcarSingle) -> str:
return scrambled_potcar_str

def to_file(self, filename: str) -> None:
"""Write scrambled POTCAR to file."""
with zopen(filename, mode="wt", encoding="utf-8") as file:
file.write(self.scrambled_potcars_str)

@classmethod
def from_file(cls, input_filename: str, output_filename: str | None = None) -> Self:
"""Read a POTCAR from file and generate a scrambled version."""
psp = Potcar.from_file(input_filename)
psp_scrambled = cls(psp)
if output_filename:
if output_filename is not None:
psp_scrambled.to_file(output_filename)
return psp_scrambled

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ dependencies = [
"networkx>=2.7", # PR4116
"palettable>=3.3.3",
"pandas>=2",
"plotly>=4.5.0,<6.0.0",
"plotly>=5.0.0",
"pybtex>=0.24.0",
"requests>=2.32",
"ruamel.yaml>=0.17.0",
Expand Down
97 changes: 61 additions & 36 deletions src/pymatgen/analysis/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from pymatgen.util.plotting import add_fig_kwargs, get_ax_fig, pretty_plot

if TYPE_CHECKING:
from typing import ClassVar
from collections.abc import Sequence
from typing import Any, ClassVar

import matplotlib.pyplot as plt

Expand All @@ -40,7 +41,11 @@ class EOSBase(ABC):
implementations.
"""

def __init__(self, volumes, energies):
def __init__(
self,
volumes: Sequence[float],
energies: Sequence[float],
) -> None:
"""
Args:
volumes (Sequence[float]): in Ang^3.
Expand All @@ -50,18 +55,28 @@ def __init__(self, volumes, energies):
self.energies = np.array(energies)
# minimum energy(e0), buk modulus(b0),
# derivative of bulk modulus w.r.t. pressure(b1), minimum volume(v0)
self._params = None
self._params: Sequence | None = None
# the eos function parameters. It is the same as _params except for
# equation of states that uses polynomial fits(delta_factor and
# numerical_eos)
self.eos_params = None
self.eos_params: Sequence | None = None

def _initial_guess(self):
def __call__(self, volume: float) -> float:
"""
Args:
volume (float | list[float]): volume(s) in Ang^3.
Returns:
Compute EOS with this volume.
"""
return self.func(volume)

def _initial_guess(self) -> tuple[float, float, float, float]:
"""
Quadratic fit to get an initial guess for the parameters.
Returns:
tuple: 4 floats for (e0, b0, b1, v0)
tuple[float, float, float, float]: e0, b0, b1, v0
"""
a, b, c = np.polyfit(self.volumes, self.energies, 2)
self.eos_params = [a, b, c]
Expand All @@ -78,7 +93,7 @@ def _initial_guess(self):

return e0, b0, b1, v0

def fit(self):
def fit(self) -> None:
"""
Do the fitting. Does least square fitting. If you want to use custom
fitting, must override this.
Expand Down Expand Up @@ -120,24 +135,20 @@ def func(self, volume):
"""
return self._func(np.array(volume), self.eos_params)

def __call__(self, volume: float) -> float:
"""
Args:
volume (float | list[float]): volume(s) in Ang^3.
Returns:
Compute EOS with this volume.
"""
return self.func(volume)

@property
def e0(self) -> float:
"""The min energy."""
if self._params is None:
raise RuntimeError("params have not be initialized.")

return self._params[0]

@property
def b0(self) -> float:
"""The bulk modulus in units of energy/unit of volume^3."""
if self._params is None:
raise RuntimeError("params have not be initialized.")

return self._params[1]

@property
Expand All @@ -156,11 +167,18 @@ def v0(self):
return self._params[3]

@property
def results(self):
def results(self) -> dict[str, Any]:
"""A summary dict."""
return {"e0": self.e0, "b0": self.b0, "b1": self.b1, "v0": self.v0}

def plot(self, width=8, height=None, ax: plt.Axes = None, dpi=None, **kwargs):
def plot(
self,
width: float = 8,
height: float | None = None,
ax: plt.Axes = None,
dpi: float | None = None,
**kwargs,
) -> plt.Axes:
"""
Plot the equation of state.
Expand All @@ -170,7 +188,7 @@ def plot(self, width=8, height=None, ax: plt.Axes = None, dpi=None, **kwargs):
golden ratio.
ax (plt.Axes): If supplied, changes will be made to the existing Axes.
Otherwise, new Axes will be created.
dpi:
dpi (float): DPI.
kwargs (dict): additional args fed to pyplot.plot.
supported keys: style, color, text, label
Expand Down Expand Up @@ -211,16 +229,18 @@ def plot(self, width=8, height=None, ax: plt.Axes = None, dpi=None, **kwargs):
return ax

@add_fig_kwargs
def plot_ax(self, ax: plt.Axes = None, fontsize=12, **kwargs):
def plot_ax(
self,
ax: plt.Axes | None = None,
fontsize: float = 12,
**kwargs,
) -> plt.Figure:
"""
Plot the equation of state on axis `ax`.
Args:
ax: matplotlib Axes or None if a new figure should be created.
fontsize: Legend fontsize.
color (str): plot color.
label (str): Plot label
text (str): Legend text (options)
Returns:
plt.Figure: matplotlib figure.
Expand Down Expand Up @@ -270,7 +290,7 @@ def plot_ax(self, ax: plt.Axes = None, fontsize=12, **kwargs):
class Murnaghan(EOSBase):
"""Murnaghan EOS."""

def _func(self, volume, params):
def _func(self, volume, params: tuple[float, float, float, float]):
"""From PRB 28,5480 (1983)."""
e0, b0, b1, v0 = tuple(params)
return e0 + b0 * volume / b1 * (((v0 / volume) ** b1) / (b1 - 1.0) + 1.0) - v0 * b0 / (b1 - 1.0)
Expand All @@ -279,7 +299,7 @@ def _func(self, volume, params):
class Birch(EOSBase):
"""Birch EOS."""

def _func(self, volume, params):
def _func(self, volume, params: tuple[float, float, float, float]):
"""From Intermetallic compounds: Principles and Practice, Vol. I:
Principles Chapter 9 pages 195-210 by M. Mehl. B. Klein,
D. Papaconstantopoulos.
Expand All @@ -296,7 +316,7 @@ def _func(self, volume, params):
class BirchMurnaghan(EOSBase):
"""BirchMurnaghan EOS."""

def _func(self, volume, params):
def _func(self, volume, params: tuple[float, float, float, float]):
"""BirchMurnaghan equation from PRB 70, 224107."""
e0, b0, b1, v0 = tuple(params)
eta = (v0 / volume) ** (1 / 3)
Expand All @@ -306,7 +326,7 @@ def _func(self, volume, params):
class PourierTarantola(EOSBase):
"""Pourier-Tarantola EOS."""

def _func(self, volume, params):
def _func(self, volume, params: tuple[float, float, float, float]):
"""Pourier-Tarantola equation from PRB 70, 224107."""
e0, b0, b1, v0 = tuple(params)
eta = (volume / v0) ** (1 / 3)
Expand All @@ -317,7 +337,7 @@ def _func(self, volume, params):
class Vinet(EOSBase):
"""Vinet EOS."""

def _func(self, volume, params):
def _func(self, volume, params: tuple[float, float, float, float]):
"""Vinet equation from PRB 70, 224107."""
e0, b0, b1, v0 = tuple(params)
eta = (volume / v0) ** (1 / 3)
Expand All @@ -335,7 +355,7 @@ class PolynomialEOS(EOSBase):
def _func(self, volume, params):
return np.poly1d(list(params))(volume)

def fit(self, order):
def fit(self, order: int) -> None:
"""
Do polynomial fitting and set the parameters. Uses numpy polyfit.
Expand All @@ -345,7 +365,7 @@ def fit(self, order):
self.eos_params = np.polyfit(self.volumes, self.energies, order)
self._set_params()

def _set_params(self):
def _set_params(self) -> None:
"""
Use the fit polynomial to compute the parameter e0, b0, b1 and v0
and set to the _params attribute.
Expand All @@ -372,7 +392,7 @@ def _func(self, volume, params):
x = volume ** (-2 / 3.0)
return np.poly1d(list(params))(x)

def fit(self, order=3):
def fit(self, order: int = 3) -> None:
"""Overridden since this eos works with volume**(2/3) instead of volume."""
x = self.volumes ** (-2 / 3.0)
self.eos_params = np.polyfit(x, self.energies, order)
Expand Down Expand Up @@ -407,7 +427,12 @@ def _set_params(self):
class NumericalEOS(PolynomialEOS):
"""A numerical EOS."""

def fit(self, min_ndata_factor=3, max_poly_order_factor=5, min_poly_order=2):
def fit(
self,
min_ndata_factor: int = 3,
max_poly_order_factor: int = 5,
min_poly_order: int = 2,
) -> None:
"""Fit the input data to the 'numerical eos', the equation of state employed
in the quasiharmonic Debye model described in the paper:
10.1103/PhysRevB.90.174107.
Expand Down Expand Up @@ -539,7 +564,7 @@ class EOS:
eos_fit.plot()
"""

MODELS: ClassVar = {
MODELS: ClassVar[dict[str, Any]] = {
"murnaghan": Murnaghan,
"birch": Birch,
"birch_murnaghan": BirchMurnaghan,
Expand All @@ -549,7 +574,7 @@ class EOS:
"numerical_eos": NumericalEOS,
}

def __init__(self, eos_name="murnaghan"):
def __init__(self, eos_name: str = "murnaghan") -> None:
"""
Args:
eos_name (str): Type of EOS to fit.
Expand All @@ -562,7 +587,7 @@ def __init__(self, eos_name="murnaghan"):
self._eos_name = eos_name
self.model = self.MODELS[eos_name]

def fit(self, volumes, energies):
def fit(self, volumes: Sequence[float], energies: Sequence[float]) -> EOSBase:
"""Fit energies as function of volumes.
Args:
Expand Down
Loading

0 comments on commit b00d94b

Please sign in to comment.