Skip to content

Commit

Permalink
Remove work done in LibGMT.__init__ (#178)
Browse files Browse the repository at this point in the history
The shared library was being loaded in `__init__`, which is not a good
design pattern. Instead, now it's loaded the first time it gets used by
the new method `get_libgmt_func` that gets the ctypes function and sets
the return and argument types. This removes some redundancy from the
other methods (setting ctypes type conversions) and also removes the
usage of a private variable `_libgmt` from other methods. Now, only
methods that set a private variable need to know of it's existence.
The same applies to `_session_id` which is entirely encapsulated by
`current_session`.
  • Loading branch information
leouieda authored Apr 29, 2018
1 parent 51c76d6 commit 26c4849
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 83 deletions.
4 changes: 2 additions & 2 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ ignore-patterns=
#init-hook=

# Use multiple processes to speed up Pylint.
jobs=1
jobs=4

# List of plugins (as comma separated values of python modules names) to load,
# usually to register additional checkers.
Expand Down Expand Up @@ -50,7 +50,7 @@ confidence=
# --enable=similarities". If you want to run only the classes checker, but have
# no Warning level messages displayed, use"--disable=all --enable=classes
# --disable=W"
disable=print-statement,parameter-unpacking,unpacking-in-except,old-raise-syntax,backtick,long-suffix,old-ne-operator,old-octal-literal,import-star-module-level,raw-checker-failed,bad-inline-option,locally-disabled,locally-enabled,file-ignored,suppressed-message,useless-suppression,deprecated-pragma,apply-builtin,basestring-builtin,buffer-builtin,cmp-builtin,coerce-builtin,execfile-builtin,file-builtin,long-builtin,raw_input-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,no-absolute-import,old-division,dict-iter-method,dict-view-method,next-method-called,metaclass-assignment,indexing-exception,raising-string,reload-builtin,oct-method,hex-method,nonzero-method,cmp-method,input-builtin,round-builtin,intern-builtin,unichr-builtin,map-builtin-not-iterating,zip-builtin-not-iterating,range-builtin-not-iterating,filter-builtin-not-iterating,using-cmp-argument,eq-without-hash,div-method,idiv-method,rdiv-method,exception-message-attribute,invalid-str-codec,sys-max-int,bad-python3-import,deprecated-string-function,deprecated-str-translate-call
disable=print-statement,parameter-unpacking,unpacking-in-except,old-raise-syntax,backtick,long-suffix,old-ne-operator,old-octal-literal,import-star-module-level,raw-checker-failed,bad-inline-option,locally-disabled,locally-enabled,file-ignored,suppressed-message,useless-suppression,deprecated-pragma,apply-builtin,basestring-builtin,buffer-builtin,cmp-builtin,coerce-builtin,execfile-builtin,file-builtin,long-builtin,raw_input-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,no-absolute-import,old-division,dict-iter-method,dict-view-method,next-method-called,metaclass-assignment,indexing-exception,raising-string,reload-builtin,oct-method,hex-method,nonzero-method,cmp-method,input-builtin,round-builtin,intern-builtin,unichr-builtin,map-builtin-not-iterating,zip-builtin-not-iterating,range-builtin-not-iterating,filter-builtin-not-iterating,using-cmp-argument,eq-without-hash,div-method,idiv-method,rdiv-method,exception-message-attribute,invalid-str-codec,sys-max-int,bad-python3-import,deprecated-string-function,deprecated-str-translate-call,attribute-defined-outside-init

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
Expand Down
190 changes: 117 additions & 73 deletions gmt/clib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,6 @@ class LibGMT(): # pylint: disable=too-many-instance-attributes
'uint32': 'GMT_UINT',
}

def __init__(self):
self._logfile = None
self._session_id = None
self._libgmt = load_libgmt()

@property
def current_session(self):
"""
Expand All @@ -120,10 +115,10 @@ def current_session(self):
outside of the context manager).
"""
if self._session_id is None:
if not hasattr(self, '_session_id') or self._session_id is None:
raise GMTCLibNoSessionError(' '.join([
"No currently open session.",
"Call methods only inside a 'with' block."]))
"No currently open GMT API session.",
"Use only inside a 'with' block."]))
return self._session_id

@current_session.setter
Expand Down Expand Up @@ -153,6 +148,48 @@ def info(self):
}
return infodict

def get_libgmt_func(self, name, argtypes=None, restype=None):
"""
Get a ctypes function from the libgmt shared library.
Also assigns the argument and return type conversions to the function.
Parameters
----------
name : str
The name of the GMT API function.
argtypes : list
List of ctypes types used to convert the Python input arguments for
the API function.
restype : ctypes type
The ctypes type used to convert the input returned by the function
into a Python type.
Returns
-------
function
The GMT API function.
Examples
--------
>>> from ctypes import c_void_p, c_int
>>> with LibGMT() as lib:
... func = lib.get_libgmt_func('GMT_Destroy_Session',
... argtypes=[c_void_p], restype=c_int)
>>> type(func)
<class 'ctypes.CDLL.__init__.<locals>._FuncPtr'>
"""
if not hasattr(self, '_libgmt'):
self._libgmt = load_libgmt()
function = getattr(self._libgmt, name)
if argtypes is not None:
function.argtypes = argtypes
if restype is not None:
function.restype = restype
return function

def __enter__(self):
"""
Start the GMT session and keep the session argument.
Expand Down Expand Up @@ -204,10 +241,11 @@ def create_session(self, session_name):
Used by GMT C API functions.
"""
c_create_session = self._libgmt.GMT_Create_Session
c_create_session.argtypes = [ctypes.c_char_p, ctypes.c_uint,
ctypes.c_uint, ctypes.c_void_p]
c_create_session.restype = ctypes.c_void_p
c_create_session = self.get_libgmt_func(
'GMT_Create_Session',
argtypes=[ctypes.c_char_p, ctypes.c_uint, ctypes.c_uint,
ctypes.c_void_p],
restype=ctypes.c_void_p)

# None is passed in place of the print function pointer. It becomes the
# NULL pointer when passed to C, prompting the C API to use the default
Expand Down Expand Up @@ -240,9 +278,10 @@ def destroy_session(self, session):
The :py:class:`ctypes.CDLL` instance for the libgmt shared library.
"""
c_destroy_session = self._libgmt.GMT_Destroy_Session
c_destroy_session.argtypes = [ctypes.c_void_p]
c_destroy_session.restype = ctypes.c_int
c_destroy_session = self.get_libgmt_func(
'GMT_Destroy_Session',
argtypes=[ctypes.c_void_p],
restype=ctypes.c_int)

status = c_destroy_session(session)
if status:
Expand Down Expand Up @@ -272,9 +311,8 @@ def get_constant(self, name):
If the constant doesn't exist.
"""
c_get_enum = self._libgmt.GMT_Get_Enum
c_get_enum.argtypes = [ctypes.c_char_p]
c_get_enum.restype = ctypes.c_int
c_get_enum = self.get_libgmt_func(
'GMT_Get_Enum', argtypes=[ctypes.c_char_p], restype=ctypes.c_int)

value = c_get_enum(name.encode())

Expand Down Expand Up @@ -317,10 +355,10 @@ def get_default(self, name):
If the parameter doesn't exist.
"""
c_get_default = self._libgmt.GMT_Get_Default
c_get_default.argtypes = [ctypes.c_void_p, ctypes.c_char_p,
ctypes.c_char_p]
c_get_default.restype = ctypes.c_int
c_get_default = self.get_libgmt_func(
'GMT_Get_Default',
argtypes=[ctypes.c_void_p, ctypes.c_char_p, ctypes.c_char_p],
restype=ctypes.c_int)

# Make a string buffer to get a return value
value = ctypes.create_string_buffer(10000)
Expand Down Expand Up @@ -362,18 +400,19 @@ def log_to_file(self, logfile=None):
>>> with LibGMT() as lib:
... mode = lib.get_constant('GMT_MODULE_CMD')
... with lib.log_to_file() as logfile:
... status = lib._libgmt.GMT_Call_Module(
... lib.current_session, 'info'.encode(), mode,
... 'bogus-file.bla'.encode())
... call_module = lib.get_libgmt_func('GMT_Call_Module')
... status = call_module(lib.current_session, 'info'.encode(),
... mode, 'bogus-file.bla'.encode())
... with open(logfile) as flog:
... print(flog.read().strip())
gmtinfo [ERROR]: Error for input file: No such file (bogus-file.bla)
"""
c_handle_messages = self._libgmt.GMT_Handle_Messages
c_handle_messages.argtypes = [ctypes.c_void_p, ctypes.c_uint,
ctypes.c_uint, ctypes.c_char_p]
c_handle_messages.restype = ctypes.c_int
c_handle_messages = self.get_libgmt_func(
'GMT_Handle_Messages',
argtypes=[ctypes.c_void_p, ctypes.c_uint, ctypes.c_uint,
ctypes.c_char_p],
restype=ctypes.c_int)

if logfile is None:
tmp_file = NamedTemporaryFile(prefix='gmt-python-', suffix='.log',
Expand Down Expand Up @@ -419,10 +458,11 @@ def call_module(self, module, args):
If the returned status code of the function is non-zero.
"""
c_call_module = self._libgmt.GMT_Call_Module
c_call_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p,
ctypes.c_int, ctypes.c_void_p]
c_call_module.restype = ctypes.c_int
c_call_module = self.get_libgmt_func(
'GMT_Call_Module',
argtypes=[ctypes.c_void_p, ctypes.c_char_p, ctypes.c_int,
ctypes.c_void_p],
restype=ctypes.c_int)

mode = self.get_constant('GMT_MODULE_CMD')
# If there is no open session, this will raise an exception. Can' let
Expand Down Expand Up @@ -490,20 +530,19 @@ def create_data(self, family, geometry, mode, **kwargs):
object.
"""
c_create_data = self._libgmt.GMT_Create_Data
c_create_data.argtypes = [
ctypes.c_void_p, # API
ctypes.c_uint, # family
ctypes.c_uint, # geometry
ctypes.c_uint, # mode
ctypes.POINTER(ctypes.c_uint64), # dim
ctypes.POINTER(ctypes.c_double), # range
ctypes.POINTER(ctypes.c_double), # inc
ctypes.c_uint, # registration
ctypes.c_int, # pad
ctypes.c_void_p, # data
]
c_create_data.restype = ctypes.c_void_p
c_create_data = self.get_libgmt_func(
'GMT_Create_Data',
argtypes=[ctypes.c_void_p, # API
ctypes.c_uint, # family
ctypes.c_uint, # geometry
ctypes.c_uint, # mode
ctypes.POINTER(ctypes.c_uint64), # dim
ctypes.POINTER(ctypes.c_double), # range
ctypes.POINTER(ctypes.c_double), # inc
ctypes.c_uint, # registration
ctypes.c_int, # pad
ctypes.c_void_p], # data
restype=ctypes.c_void_p)

family_int = self._parse_constant(family, valid=self.data_families,
valid_modifiers=self.data_vias)
Expand Down Expand Up @@ -690,10 +729,11 @@ def put_vector(self, dataset, column, vector):
0.
"""
c_put_vector = self._libgmt.GMT_Put_Vector
c_put_vector.argtypes = [ctypes.c_void_p, ctypes.c_void_p,
ctypes.c_uint, ctypes.c_uint, ctypes.c_void_p]
c_put_vector.restype = ctypes.c_int
c_put_vector = self.get_libgmt_func(
'GMT_Put_Vector',
argtypes=[ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint,
ctypes.c_uint, ctypes.c_void_p],
restype=ctypes.c_int)

gmt_type = self._check_dtype_and_dim(vector, ndim=1)
vector_pointer = vector.ctypes.data_as(ctypes.c_void_p)
Expand Down Expand Up @@ -744,10 +784,11 @@ def put_matrix(self, dataset, matrix, pad=0):
0.
"""
c_put_matrix = self._libgmt.GMT_Put_Matrix
c_put_matrix.argtypes = [ctypes.c_void_p, ctypes.c_void_p,
ctypes.c_uint, ctypes.c_int, ctypes.c_void_p]
c_put_matrix.restype = ctypes.c_int
c_put_matrix = self.get_libgmt_func(
'GMT_Put_Matrix',
argtypes=[ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint,
ctypes.c_int, ctypes.c_void_p],
restype=ctypes.c_int)

gmt_type = self._check_dtype_and_dim(matrix, ndim=2)
matrix_pointer = matrix.ctypes.data_as(ctypes.c_void_p)
Expand Down Expand Up @@ -797,12 +838,13 @@ def write_data(self, family, geometry, mode, wesn, output, data):
non-zero status code.
"""
c_write_data = self._libgmt.GMT_Write_Data
c_write_data.argtypes = [ctypes.c_void_p, ctypes.c_uint, ctypes.c_uint,
ctypes.c_uint, ctypes.c_uint,
ctypes.POINTER(ctypes.c_double),
ctypes.c_char_p, ctypes.c_void_p]
c_write_data.restype = ctypes.c_int
c_write_data = self.get_libgmt_func(
'GMT_Write_Data',
argtypes=[ctypes.c_void_p, ctypes.c_uint, ctypes.c_uint,
ctypes.c_uint, ctypes.c_uint,
ctypes.POINTER(ctypes.c_double), ctypes.c_char_p,
ctypes.c_void_p],
restype=ctypes.c_int)

family_int = self._parse_constant(family, valid=self.data_families,
valid_modifiers=self.data_vias)
Expand Down Expand Up @@ -880,15 +922,16 @@ def open_virtual_file(self, family, geometry, direction, data):
<vector memory>: N = 5 <0/4> <5/9>
"""
c_open_virtualfile = self._libgmt.GMT_Open_VirtualFile
c_open_virtualfile.argtypes = [ctypes.c_void_p, ctypes.c_uint,
ctypes.c_uint, ctypes.c_uint,
ctypes.c_void_p, ctypes.c_char_p]
c_open_virtualfile.restype = ctypes.c_int
c_open_virtualfile = self.get_libgmt_func(
'GMT_Open_VirtualFile',
argtypes=[ctypes.c_void_p, ctypes.c_uint, ctypes.c_uint,
ctypes.c_uint, ctypes.c_void_p, ctypes.c_char_p],
restype=ctypes.c_int)

c_close_virtualfile = self._libgmt.GMT_Close_VirtualFile
c_close_virtualfile.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
c_close_virtualfile.restype = ctypes.c_int
c_close_virtualfile = self.get_libgmt_func(
'GMT_Close_VirtualFile',
argtypes=[ctypes.c_void_p, ctypes.c_char_p],
restype=ctypes.c_int)

family_int = self._parse_constant(family, valid=self.data_families,
valid_modifiers=self.data_vias)
Expand Down Expand Up @@ -1200,10 +1243,11 @@ def extract_region(self):
-165.00, -150.00, 15.00, 25.00
"""
c_extract_region = self._libgmt.GMT_Extract_Region
c_extract_region.argtypes = [ctypes.c_void_p, ctypes.c_char_p,
ctypes.POINTER(ctypes.c_double)]
c_extract_region.restype = ctypes.c_int
c_extract_region = self.get_libgmt_func(
'GMT_Extract_Region',
argtypes=[ctypes.c_void_p, ctypes.c_char_p,
ctypes.POINTER(ctypes.c_double)],
restype=ctypes.c_int)

wesn = np.empty(4, dtype=np.float64)
# Use NaNs so that we can know if GMT didn't change the array
Expand Down
21 changes: 13 additions & 8 deletions gmt/tests/test_clib.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,19 @@ def mock_api_function(*args): # pylint: disable=unused-argument

mock_func = mock_api_function

backup = getattr(lib._libgmt, func)
setattr(lib._libgmt, func, mock_func)
try:
yield
finally:
# Need to restore the original method to please pylint. Make sure it
# always happens by putting it in this finally block.
setattr(lib._libgmt, func, backup)
get_libgmt_func = lib.get_libgmt_func

def mock_get_libgmt_func(name, argtypes=None, restype=None):
"""
Return our mock function.
"""
if name == func:
return mock_func
return get_libgmt_func(name, argtypes, restype)

setattr(lib, 'get_libgmt_func', mock_get_libgmt_func)

yield


def test_load_libgmt():
Expand Down

0 comments on commit 26c4849

Please sign in to comment.