Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Towards Type Stability #101

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/AtomsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ using StaticArrays
using Requires

include("interface.jl")
include("chem.jl")
include("properties.jl")
include("visualize_ascii.jl")
include("show.jl")
include("cell.jl")
include("flexible_system.jl")
include("atomview.jl")
include("atom.jl")
Expand Down
130 changes: 108 additions & 22 deletions src/atom.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,45 @@
#
# A simple and flexible atom implementation
#
export Atom, atomic_system, periodic_system, isolated_system
export Atom, FastAtom,
atomic_system, periodic_system, isolated_system


# Valid types for user-oriented atom identifiers
const AtomId = Union{Symbol, AbstractString, Integer, ChemicalElement}


# ---------------------------------------------
# Implementation of a maximally flexible Atom type

# Valid types for atom identifiers
const AtomId = Union{Symbol,AbstractString,Integer}

struct Atom{D, L<:Unitful.Length, V<:Unitful.Velocity, M<:Unitful.Mass}
position::SVector{D, L}
velocity::SVector{D, V}
atomic_symbol::Symbol
atomic_number::Int
chemical_element::ChemicalElement
atomic_mass::M
data::Dict{Symbol, Any} # Store arbitrary data about the atom.
end

velocity(atom::Atom) = atom.velocity
position(atom::Atom) = atom.position
atomic_mass(atom::Atom) = atom.atomic_mass
atomic_symbol(atom::Atom) = atom.atomic_symbol
atomic_number(atom::Atom) = atom.atomic_number
element(atom::Atom) = element(atomic_number(atom))
atomic_symbol(atom::Atom) = atom.chemical_element
atomic_number(atom::Atom) = atomic_number(atom.chemical_element)
element(atom::Atom) = element(atom.chemical_element)
n_dimensions(::Atom{D}) where {D} = D

Base.getindex(at::Atom, x::Symbol) = hasfield(Atom, x) ? getfield(at, x) : getindex(at.data, x)
Base.haskey(at::Atom, x::Symbol) = hasfield(Atom, x) || haskey(at.data, x)
function Base.get(at::Atom, x::Symbol, default)
hasfield(Atom, x) ? getfield(at, x) : get(at.data, x, default)
end
Base.getindex(at::Atom, x::Symbol) =
hasfield(Atom, x) ? getfield(at, x) : getindex(at.data, x)
Base.haskey(at::Atom, x::Symbol) =
hasfield(Atom, x) || haskey(at.data, x)
Base.get(at::Atom, x::Symbol, default) =
haskey(at, x) ? getindex(at, x) : get(at.data, x, default)

function Base.keys(at::Atom)
(:position, :velocity, :atomic_symbol, :atomic_number, :atomic_mass, keys(at.data)...)
(:position, :velocity, :chemical_element, :atomic_mass, keys(at.data)...)
end

Base.pairs(at::Atom) = (k => at[k] for k in keys(at))

"""
Expand All @@ -46,18 +56,19 @@ Supported `kwargs` include `atomic_symbol`, `atomic_number`, `atomic_mass`, `cha
function Atom(identifier::AtomId,
position::AbstractVector{L},
velocity::AbstractVector{V}=zeros(length(position))u"bohr/s";
atomic_symbol=Symbol(element(identifier).symbol),
atomic_number=element(identifier).number,
atomic_mass::M=element(identifier).atomic_mass,
chemical_element = ChemicalElement(identifier),
atomic_mass::M = element(chemical_element).atomic_mass,
kwargs...) where {L <: Unitful.Length, V <: Unitful.Velocity, M <: Unitful.Mass}
Atom{length(position), L, V, M}(position, velocity, atomic_symbol,
atomic_number, atomic_mass, Dict(kwargs...))
Atom{length(position), L, V, M}(position, velocity, chemical_element,
atomic_mass, Dict(kwargs...))
end

function Atom(id::AtomId, position::AbstractVector, velocity::Missing; kwargs...)
Atom(id, position, zeros(length(position))u"bohr/s"; kwargs...)
end
function Atom(; atomic_symbol, position, velocity=zeros(length(position))u"bohr/s", kwargs...)
Atom(atomic_symbol, position, velocity; atomic_symbol, kwargs...)

function Atom(; chemical_element, position, velocity=zeros(length(position))u"bohr/s", kwargs...)
Atom(chemical_element, position, velocity; kwargs...)
end

"""
Expand Down Expand Up @@ -88,7 +99,82 @@ Base.show(io::IO, at::Atom) = show_atom(io, at)
Base.show(io::IO, mime::MIME"text/plain", at::Atom) = show_atom(io, mime, at)


#


# ---------------------------------------------
# Implementation of a FastAtom type which is isbits
# and provides equivalent functionality to `FastSystem`

"""
`FastAtom` : atom type compatible with `FastSystem`, i.e. providing only
atom id (chemical element), position and mass. It is isbits and hence
can be used with `FlexibleSystem` and achieve similar performance
as `FastSystem`.

### Constructors:
```
Atom(identifier::AtomId, position::AbstractVector; atomic_mass)
Atom(; identifier, position, atomic_mass)
```
If `mass` is not provided then it the default provided by `PeriodicTable.jl`
is assigned.
"""
struct FastAtom{D, L<:Unitful.Length, M<:Unitful.Mass}
chemical_element::ChemicalElement
position::SVector{D, L}
atomic_mass::M
end

Base.show(io::IO, at::FastAtom) = show_atom(io, at)
Base.show(io::IO, mime::MIME"text/plain", at::FastAtom) = show_atom(io, mime, at)


velocity(atom::FastAtom) = missing
position(atom::FastAtom) = atom.position
atomic_mass(atom::FastAtom) = atom.atomic_mass
atomic_symbol(atom::FastAtom) = atom.chemical_element
atomic_number(atom::FastAtom) = atomic_number(atom.chemical_element)
element(atom::FastAtom) = element(atom.chemical_element)
n_dimensions(::FastAtom{D}) where {D} = D

Base.getindex(at::FastAtom, x::Symbol) =
hasfield(FastAtom, x) ? getfield(at, x) : missing

Base.haskey(at::FastAtom, x::Symbol) =
hasfield(FastAtom, x)

Base.get(at::FastAtom, x::Symbol, default) =
hasfield(FastAtom, x) ? getfield(at, x) : default

Base.keys(at::FastAtom) =
(:position, :element, :atomic_mass,)

Base.pairs(at::FastAtom) =
(k => at[k] for k in keys(at))


function FastAtom(identifier::AtomId, position::AbstractVector{L};
chemical_element = ChemicalElement(identifier),
atomic_mass::M = atomic_mass(chemical_element),
) where {L <: Unitful.Length, M <: Unitful.Mass}
D = length(position)
return FastAtom(chemical_element, SVector{D}(position), atomic_mass)
end

function FastAtom(; atomic_symbol, position, kwargs...)
return FastAtom(atomic_symbol, position; kwargs...)
end


# ---------------------------------------------
# a few alternative getters

for f in (velocity, position, atomic_mass, atomic_symbol, atomic_number, element)
@eval Base.getindex(at::Union{Atom, FastAtom}, ::typeof($f)) = $f(at)
end


# ---------------------------------------------
# Special high-level functions to construct atomic systems
#

Expand Down
4 changes: 4 additions & 0 deletions src/atomview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
using `AtomView` as its species type.

## Example
```jldoctest; setup=:(using AtomsBase, Unitful; atoms = Atom[:C => [0.25, 0.25, 0.25]u"Å", :C => [0.75, 0.75, 0.75]u"Å"]; box = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]u"Å"; boundary_conditions = [Periodic(), Periodic(), DirichletZero()])

Check failure on line 16 in src/atomview.jl

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in ~/work/AtomsBase.jl/AtomsBase.jl/src/atomview.jl:16-28 ```jldoctest; setup=:(using AtomsBase, Unitful; atoms = Atom[:C => [0.25, 0.25, 0.25]u"Å", :C => [0.75, 0.75, 0.75]u"Å"]; box = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]u"Å"; boundary_conditions = [Periodic(), Periodic(), DirichletZero()]) julia> system = FastSystem(atoms, box, boundary_conditions); julia> atom = system[2] AtomView(C, atomic_number = 6, atomic_mass = 12.011 u): position : [0.75,0.75,0.75]u"Å" julia> atom isa AtomView{typeof(system)} true julia> atomic_symbol(atom) :C ``` Subexpression: atom = system[2] Evaluated output: AtomView(C, atomic_mass = 12.011 u): position : [0.75,0.75,0.75]u"Å" Expected output: AtomView(C, atomic_number = 6, atomic_mass = 12.011 u): position : [0.75,0.75,0.75]u"Å" diff = Warning: Diff output requires color. AtomView(C, atomic_number = 6, atomic_mass = 12.011 u): u): position : [0.75,0.75,0.75]u"Å"

Check failure on line 16 in src/atomview.jl

View workflow job for this annotation

GitHub Actions / Documentation

doctest failure in ~/work/AtomsBase.jl/AtomsBase.jl/src/atomview.jl:16-28 ```jldoctest; setup=:(using AtomsBase, Unitful; atoms = Atom[:C => [0.25, 0.25, 0.25]u"Å", :C => [0.75, 0.75, 0.75]u"Å"]; box = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]u"Å"; boundary_conditions = [Periodic(), Periodic(), DirichletZero()]) julia> system = FastSystem(atoms, box, boundary_conditions); julia> atom = system[2] AtomView(C, atomic_number = 6, atomic_mass = 12.011 u): position : [0.75,0.75,0.75]u"Å" julia> atom isa AtomView{typeof(system)} true julia> atomic_symbol(atom) :C ``` Subexpression: atomic_symbol(atom) Evaluated output: C Expected output: :C diff = Warning: Diff output requires color. :CC
julia> system = FastSystem(atoms, box, boundary_conditions);

julia> atom = system[2]
Expand All @@ -37,6 +37,7 @@
ismissing(vel) && return missing
return vel[v.index]
end

position(v::AtomView) = position(v.system)[v.index]
atomic_mass(v::AtomView) = atomic_mass(v.system)[v.index]
atomic_symbol(v::AtomView) = atomic_symbol(v.system)[v.index]
Expand All @@ -48,9 +49,12 @@
Base.show(io::IO, mime::MIME"text/plain", at::AtomView) = show_atom(io, mime, at)

Base.getindex(v::AtomView, x::Symbol) = getindex(v.system, v.index, x)

Base.haskey(v::AtomView, x::Symbol) = hasatomkey(v.system, x)

function Base.get(v::AtomView, x::Symbol, default)
hasatomkey(v.system, x) ? v[x] : default
end

Base.keys(v::AtomView) = atomkeys(v.system)
Base.pairs(at::AtomView) = (k => at[k] for k in keys(at))
100 changes: 100 additions & 0 deletions src/cell.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@

"""
Implementation of a computational cell for particle systems
within AtomsBase.jl. `PCell` specifies a parallepiped shaped cell
with
"""
struct PCell{D, T}
cell_vectors::NTuple{D, SVector{D, T}}
pbc::NTuple{D, Bool}
end

bounding_box(cell::PCell) = cell.cell_vectors

boundary_conditions(cell::PCell) = map(p -> p ? Periodic() : OpenBC(), cell.pbc)

periodicity(cell::PCell) = cell.pbc

isinfinite(cell::PCell) = map(!, cell.pbc)

n_dimensions(::PCell{D}) where {D} = D

# ---------------------- pretty printing

function Base.show(io::IO, cell::PCell{D}) where {D}
u = unit(first(cell.cell_vectors[1][1]))
print(io, "PCell(", prod(p -> p ? "T" : "F", cell.pbc), ", ")
for d = 1:D
print(io, ustrip.(cell.cell_vectors[d]), u)
if d < D; print(io, ", "); end
end
println(")")
end

function Base.show(io::IO, ::MIME"text/plain", cell::PCell{D}) where {D}
u = unit(first(cell.cell_vectors[1][1]))
println(io, " PCell(", prod(p -> p ? "T" : "F", cell.pbc), ",")
for d = 1:D
print(io, " ", ustrip.(cell.cell_vectors[d]), u)
d < D && println(",")
end
println(")")
end


# ---------------------- Constructors

# different ways to construct cell vectors

function _auto_cell_vectors(vecs::Tuple)
D = length(vecs)
@assert all(length.(vecs) .== D) "All cell vectors must have the same length"
return ntuple(i -> SVector{D}(vecs[i]), D)
end

_auto_cell_vectors(vecs::AbstractVector) =
_auto_cell_vectors(tuple(vecs...))

# .... could consider allowing construction from a matrix but
# that introduced an ambiguity (transpose?) that we may
# wish to avoid.

# different ways to construct PBC

_auto_pbc1(bc::Bool) = bc
_auto_pbc1(::Periodic) = true
_auto_pbc1(::OpenBC) = false
_auto_pbc1(::Nothing) = false
_auto_pbc1(::DirichletZero) = false

_auto_pbc(bc::Tuple, cell_vectors = nothing) =
map(_auto_pbc1, bc)

_auto_pbc(bc::AbstractVector, cell_vectors = nothing) =
_auto_pbc(tuple(bc...))

_auto_pbc(bc::Union{Bool, Nothing, BoundaryCondition}, cell_vectors) =
ntuple(i -> _auto_pbc1(bc), length(cell_vectors))

# kwarg constructor for PCell

PCell(; cell_vectors, boundary_conditions) =
PCell(_auto_cell_vectors(cell_vectors),
_auto_pbc(boundary_conditions, cell_vectors))



# ----------------------
# interface functions to connect Systems and cells

bounding_box(system::SystemWithCell{D, <: PCell}) where {D} =
bounding_box(system.cell)

boundary_conditions(system::SystemWithCell{D, <: PCell}) where {D} =
boundary_conditions(system.cell)

periodicity(system::SystemWithCell{D, <: PCell}) where {D} =
periodicity(system.cell)

isinfinite(system::SystemWithCell{D, <: PCell}) where {D} =
isinfinite(system.cell)
55 changes: 55 additions & 0 deletions src/chem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# ---------------------------------------------
# Simple wrapper for chemical element type

import PeriodicTable
using Unitful


"""
Encodes a chemical element by wrapping an integer that represents the atomic
number.
"""
struct ChemicalElement
atomic_number::UInt8
end

Base.show(io::IO, element::ChemicalElement) =
print(io, Symbol(element))

ChemicalElement(sym::Symbol) = ChemicalElement(_sym2z[sym])
ChemicalElement(sym::ChemicalElement) = sym

import Base.==

==(a::ChemicalElement, sym::Symbol) = (Symbol(a) == sym)

# -------- fast access to the periodic table

@assert length(PeriodicTable.elements) == maximum(el.number for el in PeriodicTable.elements)
@assert all(el.number == i for (i, el) in enumerate(PeriodicTable.elements))

const _chem_el_info = [
(symbol = Symbol(PeriodicTable.elements[z].symbol),
atomic_mass = PeriodicTable.elements[z].atomic_mass, )
for z in 1:length(PeriodicTable.elements)
]

const _sym2z = Dict{Symbol, UInt8}()
for z in 1:length(_chem_el_info)
_sym2z[_chem_el_info[z].symbol] = z
end

# -------- accessor functions

atomic_number(element::ChemicalElement) = element.atomic_number

atomic_symbol(element::ChemicalElement) = element

Base.convert(::Type{Symbol}, element::ChemicalElement) = Symbol(element)
Symbol(element::ChemicalElement) = _chem_el_info[element.atomic_number].symbol

atomic_mass(element::ChemicalElement) = _chem_el_info[element.atomic_number].atomic_mass

rich_info(element::ChemicalElement) = PeriodicTable.elements[element.atomic_number]

element(element::ChemicalElement) = rich_info(element)
Loading
Loading