diff --git a/.gitignore b/.gitignore index a63de96ac6d6..fe1264c2b747 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,14 @@ dmlc-core mshadow config.mk + +*.pyc +.Rhistory +*log +Debug +*suo + +# vim +*.swp +*.swo +*.swn diff --git a/Makefile b/Makefile index 0aa41cc4fd57..8892255f4386 100644 --- a/Makefile +++ b/Makefile @@ -48,20 +48,23 @@ endif BIN = test/api_registry_test OBJ = storage.o narray_op_cpu.o operator.o operator_cpu.o -OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o +# add threaded engine after it is done +OBJCXX11 = engine.o narray.o mxnet_api.o api_registry.o engine.o CUOBJ = narray_op_gpu.o operator_gpu.o - +SLIB = api/libmxnet.so +ALIB = api/libmxnet.a LIB_DEP = $(DMLC_CORE)/libdmlc.a .PHONY: clean all -all: $(OBJ) $(OBJCXX11) $(CUOBJ) $(BIN) +all: $(ALIB) $(SLIB) $(BIN) $(DMLC_CORE)/libdmlc.a: + cd $(DMLC_CORE); make libdmlc.a config=$(ROOTDIR)/$(config); cd $(ROOTDIR) storage.o: src/storage/storage.cc engine.o: src/dag_engine/simple_engine.cc +threaded_engine.o: src/dag_engine/threaded_engine.cc src/common/concurrent_blocking_queue.h src/common/spin_lock.h narray.o: src/narray/narray.cc narray_op_cpu.o: src/narray/narray_op_cpu.cc src/narray/narray_op-inl.h narray_op_gpu.o: src/narray/narray_op_gpu.cu src/narray/narray_op-inl.h @@ -71,7 +74,10 @@ operator_gpu.o: src/operator/operator_gpu.cu api_registry.o: src/api_registry.cc mxnet_api.o: api/mxnet_api.cc -test/api_registry_test: test/api_registry_test.cc $(OBJ) $(OBJCXX11) $(CUOBJ) +api/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) +api/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ) + +test/api_registry_test: test/api_registry_test.cc api/libmxnet.a $(BIN) : $(CXX) $(CFLAGS) -std=c++11 -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) @@ -85,6 +91,9 @@ $(OBJCXX11) : $(SLIB) : $(CXX) $(CFLAGS) -shared -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) +$(ALIB): + ar cr $@ $+ + $(CUOBJ) : $(NVCC) -c -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" $(filter %.cu, $^) @@ -92,5 +101,5 @@ $(CUBIN) : $(NVCC) -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" -Xlinker "$(LDFLAGS)" $(filter %.cu %.cpp %.o, $^) clean: - $(RM) $(OBJ) $(OBJCXX11) $(BIN) $(CUBIN) $(CUOBJ) $(SLIB) *~ */*~ */*/*~ + $(RM) $(OBJ) $(OBJCXX11) $(BIN) $(CUBIN) $(CUOBJ) $(SLIB) $(ALIB) *~ */*~ */*/*~ */*/*/*~ cd $(DMLC_CORE); make clean; cd - diff --git a/README.md b/README.md index d2d7f138e26e..647e0f02d093 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,15 @@ # MXNet -This is an experimental project to put cxxnet and minerva together, nothing is working yet. +This is a project that combines lessons and ideas we learnt from [cxxnet](https://github.com/dmlc/cxxnet), [minerva](https://github.com/dmlc/minerva) and [purine2](https://github.com/purine/purine2). +- The interface is designed in collaboration by authors of three projects. 
+- Nothing is yet working # Guidelines * Use google c style * Put module header in [include](include) - - move them to ```project-name/include``` when we finalized the name * Depend on [dmlc-core](https://github.com/dmlc/dmlc-core) * Doxygen comment every function, class and variable for the module headers - Ref headers in [dmlc-core/include](https://github.com/dmlc/dmlc-core/tree/master/include/dmlc) - Use the same style as dmlc-core -* Try write some use-cases of interface in [test](test) - - They do not need to link, but need to pass compile * Minimize dependency, if possible only depend on dmlc-core * Macro Guard CXX11 code by - Try to make interface compile when c++11 was not avaialable(but with some functionalities pieces missing) diff --git a/api/mxnet_api.cc b/api/mxnet_api.cc index c38f5dc99092..5ca45257f786 100644 --- a/api/mxnet_api.cc +++ b/api/mxnet_api.cc @@ -1,4 +1,191 @@ +#include +#include #include #include +#include #include "./mxnet_api.h" +// NOTE: all functions return 0 upon success +// consider add try/catch block for user error +// handling in the future +using namespace mxnet; + +// macro to guard beginning and end section of all functions +// every function starts with API_BEGIN(); and finishes with API_END(); +#define API_BEGIN() try { +#define API_END() } catch(dmlc::Error &e) { return MXHandleException(e); } return 0; + +/*! + * \brief a helper function for error handling + * will set the last error to be str_set when it is not NULL + * \param str_set the error to set + * \return a pointer message to last error + */ +const char *MXSetGetLastError_(const char *str_set) { + // use last_error to record last error + static thread_local std::string last_error; + if (str_set != NULL) { + last_error = str_set; + } + return last_error.c_str(); +} + +/*! \brief return str message of the last error */ +const char *MXGetLastError() { + return MXSetGetLastError_(NULL); +} + +/*! 
+ * \brief handle exception throwed out + * \param e the exception + * \return the return value of API after exception is handled + */ +int MXHandleException(const dmlc::Error &e) { + MXSetGetLastError_(e.what()); + return -1; +} + +// NOTE: return value is added in API_END +int MXNArrayCreateNone(NArrayHandle *out) { + API_BEGIN(); + *out = new NArray(); + API_END(); +} + +int MXNArrayCreateShareMem(mx_float *data, + mx_uint *shape, + mx_uint ndim, + NArrayHandle *out) { + API_BEGIN(); + *out = new NArray(TBlob(data, TShape(shape, shape + ndim), + cpu::kDevMask), 0); + API_END(); +} + +int MXNArrayCreate(const mx_uint *shape, + mx_uint ndim, + int dev_mask, + int dev_id, + int delay_alloc, + NArrayHandle *out) { + API_BEGIN(); + *out = new NArray(TShape(shape, shape + ndim), + Context(dev_mask, dev_id), + delay_alloc != 0); + API_END(); +} + +int MXNArrayWait(NArrayHandle handle) { + API_BEGIN(); + static_cast(handle)->Wait(); + API_END(); +} + +int MXNArrayWaitAll() { + API_BEGIN(); + DAGEngine::Get()->WaitForAll(); + API_END(); +} + +int MXNArrayFree(NArrayHandle handle) { + API_BEGIN(); + delete static_cast(handle); + API_END(); +} + +int MXNArrayGetShape(NArrayHandle handle, + mx_uint *out_dim, + const mx_uint **out_pdata) { + API_BEGIN(); + NArray *arr = static_cast(handle); + if (!arr->is_none()) { + const TShape &s = arr->shape(); + *out_dim = s.ndim(); + *out_pdata = s.data(); + } else { + *out_dim = 0; + } + API_END(); +} + +int MXNArrayGetData(NArrayHandle handle, + mx_float **out_pdata) { + API_BEGIN(); + NArray *arr = static_cast(handle); + if (!arr->is_none()) { + CHECK(arr->ctx().dev_mask == cpu::kDevMask) + << "MXNArrayGetData can only be called for NArray on CPU"; + const TBlob &b = arr->data(); + CHECK(b.CheckContiguous()); + *out_pdata = b.FlatTo2D().dptr_; + } else { + *out_pdata = nullptr; + } + API_END(); +} + +int MXNArrayGetContext(NArrayHandle handle, + int *out_dev_mask, + int *out_dev_id) { + API_BEGIN(); + NArray *arr = static_cast(handle); + if (!arr->is_none()) { + const Context &ctx = arr->ctx(); + *out_dev_mask = ctx.dev_mask; + *out_dev_id = ctx.dev_id; + } else { + *out_dev_mask = 0; + *out_dev_id = 0; + } + API_END(); +} + +int MXListFunctions(mx_uint *out_size, + FunctionHandle **out_array) { + API_BEGIN(); + auto &vec = FunctionRegistry::List(); + *out_size = static_cast(vec.size()); + *out_array = (FunctionHandle*)(dmlc::BeginPtr(vec)); + API_END(); +} + +int MXGetFunction(const char *name, + FunctionHandle *out) { + API_BEGIN(); + *out = FunctionRegistry::Find(name); + API_END(); +} + +int MXFuncGetName(FunctionHandle fun, + const char **out_name) { + API_BEGIN(); + auto *f = static_cast(fun); + *out_name = f->name.c_str(); + API_END(); +} + +int MXFuncDescribe(FunctionHandle fun, + mx_uint *num_use_vars, + mx_uint *num_scalars, + mx_uint *num_mutate_vars, + int *type_mask) { + API_BEGIN(); + auto *f = static_cast(fun); + *num_use_vars = f->num_use_vars; + *num_scalars = f->num_scalars; + *num_mutate_vars = f->num_mutate_vars; + *type_mask = f->type_mask; + API_END(); +} + +int MXFuncInvoke(FunctionHandle fun, + NArrayHandle *use_vars, + mx_float *scalar_args, + NArrayHandle *mutate_vars) { + API_BEGIN(); + auto *f = static_cast(fun); + (*f)((NArray**)(use_vars), + scalar_args, + (NArray**)(mutate_vars)); + API_END(); +} diff --git a/api/mxnet_api.h b/api/mxnet_api.h index 7c4cb18b9f5d..0710f9cee37a 100644 --- a/api/mxnet_api.h +++ b/api/mxnet_api.h @@ -26,7 +26,7 @@ typedef float mx_float; /*! 
\brief handle to NArray */ typedef void *NArrayHandle; /*! \brief handle to a mxnet narray function that changes NArray */ -typedef void *FunctionHandle; +typedef const void *FunctionHandle; /*! \brief handle to a symbol that can be bind as operator */ typedef void *SymbolHandle; /*! \brief handle to a NArrayOperator */ @@ -34,6 +34,16 @@ typedef void *OperatorHandle; /*! \brief handle to a DataIterator */ typedef void *DataIterHandle; +/*! + * \brief return str message of the last error + * all function in this file will return 0 when success + * and -1 when an error occured, + * MXGetLastError can be called to retrieve the error + * + * this function is threadsafe and can be called by different thread + */ +MXNET_DLL const char *MXGetLastError(); + //-------------------------------- // Part 1: NArray creation and deletion //-------------------------------- @@ -71,6 +81,8 @@ MXNET_DLL int MXNArrayCreateShareMem(mx_float *data, * \param ndim the dimension of the shape * \param dev_mask device mask, specify device we want to take * \param dev_id the device id of the specific device + * \param delay_alloc whether to delay allocation until + * the narray is first mutated * \param out the returning handle * \return 0 when success, -1 when failure happens */ @@ -78,6 +90,7 @@ MXNET_DLL int MXNArrayCreate(const mx_uint *shape, mx_uint ndim, int dev_mask, int dev_id, + int delay_alloc, NArrayHandle *out); /*! * \brief wait until all the operation with respect NArray @@ -105,25 +118,27 @@ MXNET_DLL int MXNArrayFree(NArrayHandle handle); * \param out_pdata pointer holder to get data pointer of the shape * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayGetShape(NArrayHandle *handle, +MXNET_DLL int MXNArrayGetShape(NArrayHandle handle, mx_uint *out_dim, - mx_uint **out_pdata); + const mx_uint **out_pdata); /*! * \brief get the content of the data in NArray * \param handle the handle to the narray * \param out_pdata pointer holder to get pointer of data * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayGetData(NArrayHandle *handle, +MXNET_DLL int MXNArrayGetData(NArrayHandle handle, mx_float **out_pdata); /*! - * \brief get the device of the NArray + * \brief get the context of the NArray * \param handle the handle to the narray - * \param out_device the output device mask + * \param out_dev_mask the output device mask + * \param out_dev_id the output device id * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXNArrayGetDevice(NArrayHandle *handle, - int *out_device); +MXNET_DLL int MXNArrayGetContext(NArrayHandle handle, + int *out_dev_mask, + int *out_dev_id); //-------------------------------- // Part 2: functions on NArray @@ -158,13 +173,15 @@ MXNET_DLL int MXFuncGetName(FunctionHandle fun, * \param num_use_vars how many NArrays to be passed in as used_vars * \param num_scalars scalar variable is needed * \param num_mutate_vars how many NArrays to be passed in as mutate_vars + * \param type_mask the type mask of this function * \return 0 when success, -1 when failure happens * \sa MXFuncInvoke */ -MXNET_DLL int MXFuncDescribeArgs(FunctionHandle fun, - mx_uint *num_use_vars, - mx_uint *num_scalars, - mx_uint *num_mutate_vars); +MXNET_DLL int MXFuncDescribe(FunctionHandle fun, + mx_uint *num_use_vars, + mx_uint *num_scalars, + mx_uint *num_mutate_vars, + int *type_mask); /*! 
* \brief invoke a function, the array size of passed in arguments diff --git a/api/python/mxnet/__init__.py b/api/python/mxnet/__init__.py new file mode 100644 index 000000000000..c78c0d485159 --- /dev/null +++ b/api/python/mxnet/__init__.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python +# coding: utf-8 +"""MXNet: a concise, fast and flexible framework for deep learning + +MXNet is a project that evolves from cxxnet, minerva and purine2. +The interface is designed in collaboration by authors of three projects. + +Version : 0.10 +""" +from __future__ import absolute_import + +from .context import Context, current_context +from .narray import NArray, _init_function_registry +from .function import _FunctionRegistry + +# this is a global function registry that can be used to invoke functions +op = _init_function_registry(_FunctionRegistry()) diff --git a/api/python/mxnet/base.py b/api/python/mxnet/base.py new file mode 100644 index 000000000000..441c67fd092d --- /dev/null +++ b/api/python/mxnet/base.py @@ -0,0 +1,135 @@ +# coding: utf-8 +""" ctypes library of mxnet and helper functions """ +from __future__ import absolute_import + +import os +import sys +import ctypes +import platform +import numpy as np + +#---------------------------- +# library loading +#---------------------------- +if sys.version_info[0] == 3: + string_types = str, +else: + string_types = basestring, + + +class MXNetError(Exception): + """Error that will be throwed by all mxnet functions""" + pass + + +def _load_lib(): + """load libary by searching possible path""" + curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + api_path = os.path.join(curr_path, '../../') + dll_path = [api_path, curr_path] + if os.name == 'nt': + if platform.architecture()[0] == '64bit': + dll_path.append(os.path.join(api_path, '../windows/x64/Release/')) + else: + dll_path.append(os.path.join(api_path, '../windows/Release/')) + if os.name == 'nt': + dll_path = [os.path.join(p, 'mxnet.dll') for p in dll_path] + else: + dll_path = [os.path.join(p, 'libmxnet.so') for p in dll_path] + lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)] + if len(dll_path) == 0: + raise MXNetError('cannot find find the files in the candicate path ' + str(dll_path)) + lib = ctypes.cdll.LoadLibrary(lib_path[0]) + + # DMatrix functions + lib.MXGetLastError.restype = ctypes.c_char_p + return lib + + +# library instance of mxnet +lib = _load_lib() + +# type definitions +mx_uint = ctypes.c_uint +mx_float = ctypes.c_float +NArrayHandle = ctypes.c_void_p +FunctionHandle = ctypes.c_void_p + +#---------------------------- +# helper function definition +#---------------------------- + +def check_call(ret): + """Check the return value of C API call + + This function will raise exception when error occurs. 
+ Wrap every API call with this function + + Parameters + ---------- + ret : int + return value from API calls + """ + if ret != 0: + raise MXNetError(lib.MXGetLastError()); + + +def c_str(string): + """Create ctypes char * from a python string + + Parameters + ---------- + string : string type + python string + + Returns + ------- + a char pointer that can be passed to C API + """ + + return ctypes.c_char_p(string.encode('utf-8')) + + +def c_array(ctype, values): + """Create ctypes array from a python array + + Parameters + ---------- + ctype : ctypes data type + data type of the array we want to convert to + + values : tuple or list + data content + + Returns + ------- + created ctypes array + """ + return (ctype * len(values))(*values) + + +def ctypes2numpy_shared(cptr, shape): + """Convert a ctypes pointer to a numpy array + + The result numpy array shares the memory with the pointer + + Parameters + ---------- + cptr : ctypes.POINTER(mx_float) + pointer to the memory region + + shape : tuple + shape of target narray + + Returns + ------- + a numpy array : numpy array + """ + if not isinstance(cptr, ctypes.POINTER(mx_float)): + raise RuntimeError('expected float pointer') + size = 1 + for s in shape: + size *= s + dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents)) + return np.frombuffer(dbuffer, dtype = np.float32).reshape(shape) + diff --git a/api/python/mxnet/context.py b/api/python/mxnet/context.py new file mode 100644 index 000000000000..a440e310ce30 --- /dev/null +++ b/api/python/mxnet/context.py @@ -0,0 +1,50 @@ +# coding: utf-8 +""" code for context management """ +from __future__ import absolute_import + +class Context: + """Context representing device and device id in mxnet""" + # static class variable + default_ctx = None + devmask2type = { 1: 'cpu', 2: 'gpu'} + devtype2mask = {'cpu': 1, 'gpu': 2 } + + def __init__(self, device_type, device_id = 0): + """Constructing a context + + Parameters + ---------- + device_type : str (can be 'cpu' or 'gpu') + a string representing the device type + + device_id : int (default=0) + the device id of the device, needed for GPU + """ + self.device_mask = Context.devtype2mask[device_type] + self.device_id = device_id + + @property + def device_type(self): + return Context.devmask2type[self.device_mask] + + def __str__(self): + return 'Context(device_type=%s, device_id=%d)' % ( + self.device_type, self.device_id) + + def __repr__(self): + return self.__str__() + + def __enter__(self): + self._old_ctx = Context.default_ctx + Context.default_ctx = self + return self + + def __exit__(self, type, value, trace): + Context.default_ctx= self._old_ctx + +# initialize the default context in Context +Context.default_ctx = Context('cpu', 0) + +def current_context(): + """Return the current context""" + return Context.default_ctx diff --git a/api/python/mxnet/function.py b/api/python/mxnet/function.py new file mode 100644 index 000000000000..149d88a0f450 --- /dev/null +++ b/api/python/mxnet/function.py @@ -0,0 +1,131 @@ +# coding: utf-8 +"""NArray functions support of mxnet""" +from __future__ import absolute_import + +import ctypes +from .base import lib +from .base import c_array +from .base import mx_uint, mx_float, NArrayHandle, FunctionHandle +from .base import check_call, MXNetError +from .narray import NArray, _new_empty_handle + +class _Function: + # constants for type masks + NARRAY_ARG_BEFORE_SCALAR = 1 + SCALAR_ARG_BEFORE_NARRAY = 1 << 1 + ACCEPT_EMPTY_MUTATE_TARGET = 1 << 2 + + def __init__(self, handle, name): + 
"""Initialize the function with handle + + Parameters + ---------- + handle : FunctionHandle + the function handle of the function + + name : string + the name of the function + """ + self.handle = handle + self.name = name + n_used_vars = mx_uint() + n_scalars = mx_uint() + n_mutate_vars = mx_uint() + type_mask = ctypes.c_int() + check_call(lib.MXFuncDescribe( + self.handle, + ctypes.byref(n_used_vars), + ctypes.byref(n_scalars), + ctypes.byref(n_mutate_vars), + ctypes.byref(type_mask))) + self.n_used_vars = n_used_vars.value + self.n_scalars = n_scalars.value + self.n_mutate_vars = n_mutate_vars.value + self.type_mask = type_mask.value + # infer type of the function + if (self.type_mask & _Function.NARRAY_ARG_BEFORE_SCALAR) != 0: + self.use_vars_range = range(0, self.n_used_vars) + self.scalar_range = range(self.n_used_vars, + self.n_used_vars + self.n_scalars) + else: + self.scalar_range = range(0, self.n_scalars) + self.use_vars_range = range(self.n_scalars, + self.n_scalars + self.n_used_vars) + self.accept_empty_mutate = (self.type_mask & + _Function.ACCEPT_EMPTY_MUTATE_TARGET) != 0 + + def __call__(self, *args, **kwargs): + """Invoke this function by passing in parameters + + Parameters + ---------- + *args: positional arguments + positional arguments of input scalars and NArray + + mutate_vars: kwarg(optional) + provide the NArray to store the result of the operation + + Returns + ------- + the result NArrays of mutated result + """ + if 'mutate_vars' in kwargs: + mutate_vars = kwargs['mutate_vars'] + if isinstance(mutate_vars, NArray): + mutate_vars = (mutate_vars,) + if len(mutate_vars) != self.n_mutate_vars: + raise MXNetError('expect %d mutate_vars in op.%s', self.n_mutate_vars, self.name) + else: + if self.accept_empty_mutate: + mutate_vars = tuple( + NArray(_new_empty_handle()) for i in range(self.n_mutate_vars)) + else: + raise MXNetError('mutate_vars argument is required to call op.%s' % self.name) + + self.invoke_with_handle_([args[i].handle for i in self.use_vars_range], + [args[i] for i in self.scalar_range], + [v.handle for v in mutate_vars]) + if self.n_mutate_vars == 1: + return mutate_vars[0] + else: + return mutate_vars + + def invoke_with_handle_(self, use_vars, scalars, mutate_vars): + """Invoke this function by passing in arguments as tuples + + This is a very primitive call to the function handle that + involves passing in a C handle + + Parameters + ---------- + fhandle : FunctionHandle + function handle of C API + + use_vars : tuple + tuple of NArray handles + + scalars : tuple + tuple of real number arguments + + mutate_vars : tuple + tuple of NArray handles to mutate + """ + check_call(lib.MXFuncInvoke( + self.handle, + c_array(NArrayHandle, use_vars), + c_array(mx_float, scalars), + c_array(NArrayHandle, mutate_vars))) + +class _FunctionRegistry: + def __init__(self): + plist = ctypes.POINTER(ctypes.c_void_p)() + size = ctypes.c_uint() + check_call(lib.MXListFunctions(ctypes.byref(size), + ctypes.byref(plist))) + hmap = {} + for i in range(size.value): + h = plist[i] + name = ctypes.c_char_p() + check_call(lib.MXFuncGetName(h, ctypes.byref(name))) + hmap[name.value] = _Function(h, name.value) + self.__dict__.update(hmap) diff --git a/api/python/mxnet/narray.py b/api/python/mxnet/narray.py new file mode 100644 index 000000000000..b16270cdba61 --- /dev/null +++ b/api/python/mxnet/narray.py @@ -0,0 +1,207 @@ +# coding: utf-8 +"""NArray interface of mxnet""" +from __future__ import absolute_import + +import ctypes +import numpy as np +from .base import lib 
+from .base import c_array +from .base import mx_uint, mx_float, NArrayHandle +from .base import ctypes2numpy_shared +from .base import check_call +from .base import MXNetError +from .context import Context + +# op is implicitly imported from .function +# as a singleton of _FunctionRegistry +global op + +def _new_empty_handle(): + """Return a new empty handle + + Empty handle can be used to hold result + + Returns + ------- + a new empty narray handle + """ + h = NArrayHandle() + check_call(lib.MXNArrayCreateNone(ctypes.byref(h))) + return h + +def _new_alloc_handle(shape, ctx, delay_alloc): + """Return a new handle with specified shape, context + + Empty handle is only used to hold results + Returns + ------- + a new empty narray handle + """ + h = NArrayHandle() + check_call(lib.MXNArrayCreate( + c_array(mx_uint, shape), + len(shape), + ctx.device_mask, + ctx.device_id, + int(delay_alloc), + ctypes.byref(h))) + return h + +class NArray(object): + """NArray object in mxnet + + NArray is basic ndarray like data structure in mxnet + """ + def __init__(self, handle): + """initialize a new NArray + + Parameters + ---------- + handle : NArrayHandle + NArray handle of C API + """ + assert isinstance(handle, NArrayHandle) + self.handle = handle + + def __del__(self): + check_call(lib.MXNArrayFree(self.handle)) + + def __add__(self, other): + hret = _new_empty_handle() + if isinstance(other, NArray): + op.plus.invoke_with_handle_((other.handle, self.handle), (), (hret,)) + else: + raise MXNetError('type %s not supported' % str(type(other))) + return NArray(handle = hret) + + def __radd__(self, other): + return self.__add__(other) + + def __sub__(self, other): + hret = _new_empty_handle() + if isinstance(other, NArray): + op.minus.invoke_with_handle_((other.handle, self.handle), (), (hret,)) + else: + raise MXNetError('type %s not supported' % str(type(other))) + return NArray(handle = hret) + + def __mul__(self, other): + hret = _new_empty_handle() + if isinstance(other, NArray): + op.mul.invoke_with_handle_((other.handle, self.handle), (), (hret,)) + else: + raise MXNetError('type %s not supported' % str(type(other))) + return NArray(handle = hret) + + def __rmul__(self, other): + return self.__mul__(other) + + def __div__(self, other): + hret = _new_empty_handle() + if isinstance(other, NArray): + op.div.invoke_with_handle_((other.handle, self.handle), (), (hret,)) + else: + raise MXNetError('type %s not supported' % str(type(other))) + return NArray(handle = hret) + + def wait(self): + """Wait until the data on current NArray is available""" + check_call(lib.MXNArrayWait(self.handle)) + + @property + def shape(self): + """Get shape of current NArray + + Returns + ------- + a tuple representing shape of current narray + """ + ndim = mx_uint() + pdata = ctypes.POINTER(mx_uint)() + check_call(lib.MXNArrayGetShape( + self.handle, ctypes.byref(ndim), ctypes.byref(pdata))) + return tuple(pdata[i] for i in range(ndim.value)) + + @property + def context(self): + """Get context of current NArray + + Returns + ------- + the context of current NArray + """ + dev_mask = ctypes.c_int() + dev_id = ctypes.c_int() + check_call(lib.MXNArrayGetContext( + self.handle, ctypes.byref(dev_mask), ctypes.byref(dev_id))) + return Context(Context.devmask2type[dev_mask.value], dev_id.value) + + @property + def numpy(self): + """Return a numpy representation of current array + + This array have to sit on CPU + + Returns + ------- + a numpy array view + """ + self.wait() + pdata = ctypes.POINTER(mx_float)() + 
check_call(lib.MXNArrayGetData(self.handle, ctypes.byref(pdata))) + return ctypes2numpy_shared(pdata, self.shape) + + def copyto(self, other): + """copy the content of current array to othe + + When other is NArray, the content is copied over. + When other is a Context, a new NArray in the context + will be created as target + + Parameters + ---------- + other : NArray or Context + another narray we want to copy to, + or target context we want copy the data to + + Returns + ------- + the copy target NArray + """ + if isinstance(other, NArray): + op.copy.invoke_with_handle_((self.handle,), (), (other.handle,)) + return other + elif isinstance(other, Context): + hret = _new_alloc_handle(self.shape, other, True) + op.copy.invoke_with_handle_((self.handle,), (), (hret,)) + return NArray(handle = hret) + else: + raise MXNetError('copyto do not support type ' + type(other)) + +def create(shape, ctx = Context.default_ctx): + """Create a new NArray, with specified shape + + Parameters + ---------- + shape : tuple + shape of the NArray + + Returns + ------- + a new NArray + """ + return NArray(handle = _new_alloc_handle(shape, ctx, False)) + +def _init_function_registry(new_op): + """Initialize the global variable op with new_op + + This function is used to resolve cyclic dependency of .narray on function + + Parameters + ---------- + new_op : function._FunctionRegistry + a FunctionRegistry to pass in in startup + """ + global op + op = new_op + return op diff --git a/api/python/test_python.py b/api/python/test_python.py new file mode 100644 index 000000000000..4edab1247e1a --- /dev/null +++ b/api/python/test_python.py @@ -0,0 +1,29 @@ +import mxnet as mx + +a = mx.narray.create((3000,4000)) +b = mx.narray.create((3000,4000)) +a.numpy[:] = 10 +b.numpy[:] = 11 +print(a.numpy) + +c = b * a + +cc = mx.op.mul(b, a) + +print(c.context) +print(cc.numpy) +d = c.copyto(mx.Context('cpu', 0)) + +print(d.numpy) + +with mx.Context('gpu', 0) as ctx: + # gpu operations + print mx.current_context() + print ctx + a_gpu = a.copyto(ctx) + b_gpu = b.copyto(ctx) + c_gpu = b * a + +d_cpu = c_gpu.copyto(mx.current_context()) +print d_cpu.numpy + diff --git a/include/mxnet/api_registry.h b/include/mxnet/api_registry.h index 408601006910..e08ab03547fb 100644 --- a/include/mxnet/api_registry.h +++ b/include/mxnet/api_registry.h @@ -17,13 +17,31 @@ #include "./narray.h" namespace mxnet { + +/*! \brief mask information on how functions can be exposed */ +enum FunctionTypeMask { + /*! \brief all the use_vars should go before scalar */ + kNArrayArgBeforeScalar = 1, + /*! \brief all the scalar should go before use_vars */ + kScalarArgBeforeNArray = 1 << 1, + /*! + * \brief whether this function allows the handles in the target to + * be empty NArray that are not yet initialized, and will initialize + * them when the function is invoked. + * + * most function should support this, except copy between different + * devices, which requires the NArray to be pre-initialized with context + */ + kAcceptEmptyMutateTarget = 1 << 2 +}; + /*! \brief registry of NArray functions */ -class NArrayFunRegistry { +class FunctionRegistry { public: /*! \brief definition of NArray function */ typedef std::function NArrayFun; + NArray **mutate_vars)> Function; /*! \brief registry entry */ struct Entry { /*! \brief function name */ @@ -34,8 +52,10 @@ class NArrayFunRegistry { unsigned num_mutate_vars; /*! \brief number of scalars used by this function */ unsigned num_scalars; + /*! 
\brief information on how function should be called from API */ + int type_mask; /*! \brief the real function */ - NArrayFun body; + Function body; /*! * \brief constructor * \param name name of the function @@ -45,6 +65,7 @@ class NArrayFunRegistry { num_use_vars(0), num_mutate_vars(0), num_scalars(0), + type_mask(0), body(nullptr) {} /*! * \brief set the number of mutate variables @@ -75,9 +96,17 @@ class NArrayFunRegistry { * \param f function body to set * \return ref to the registered entry, used to set properties */ - inline Entry &set_body(NArrayFun f) { + inline Entry &set_body(Function f) { body = f; return *this; } + /*! + * \brief set the function body + * \param f function body to set + * \return ref to the registered entry, used to set properties + */ + inline Entry &set_type_mask(int tmask) { + type_mask = tmask; return *this; + } /*! * \brief set the function body to a binary NArray function * this will also auto set the parameters correctly @@ -87,12 +116,30 @@ class NArrayFunRegistry { inline Entry &set_function(void fbinary(const NArray &lhs, const NArray &rhs, NArray *out)) { - body = [fbinary] (NArray **used_vars, real_t *s, NArray **mutate_vars) { + body = [fbinary] (NArray **used_vars, + real_t *s, NArray **mutate_vars) { fbinary(*used_vars[0], *used_vars[1], mutate_vars[0]); }; num_use_vars = 2; num_mutate_vars = 1; + type_mask = kNArrayArgBeforeScalar | kAcceptEmptyMutateTarget; return *this; - } + } + /*! + * \brief set the function body to a unary NArray function + * this will also auto set the parameters correctly + * \param unary function body to set + * \return ref to the registered entry, used to set properties + */ + inline Entry &set_function(void funary(const NArray &src, + NArray *out)) { + body = [funary] (NArray **used_vars, + real_t *s, NArray **mutate_vars) { + funary(*used_vars[0], mutate_vars[0]); + }; + num_use_vars = 1; num_mutate_vars = 1; + type_mask = kNArrayArgBeforeScalar | kAcceptEmptyMutateTarget; + return *this; + } /*! * \brief invoke the function * \param use_vars variables used by the function @@ -106,7 +153,7 @@ class NArrayFunRegistry { } }; // Entry /*! \return get a singleton */ - static NArrayFunRegistry *Get(); + static FunctionRegistry *Get(); /*! * \brief register a name function under name * \param name name of the function @@ -114,17 +161,18 @@ class NArrayFunRegistry { */ Entry &Register(const std::string name); /*! \return list of functions in the registry */ - inline const std::vector &List() const { - return fun_list_; + inline static const std::vector &List() { + return Get()->fun_list_; } /*! * \brief find an function entry with corresponding name * \param name name of the function * \return the corresponding function, can be NULL */ - inline const Entry *Find(const std::string &name) const { - auto p = fmap_.find(name); - if (p != fmap_.end()) { + inline static const Entry *Find(const std::string &name) { + auto &fmap = Get()->fmap_; + auto p = fmap.find(name); + if (p != fmap.end()) { return p->second; } else { return nullptr; @@ -137,9 +185,9 @@ class NArrayFunRegistry { /*! \brief map of name->function */ std::map fmap_; /*! \brief constructor */ - NArrayFunRegistry() {} + FunctionRegistry() {} /*! \brief destructor */ - ~NArrayFunRegistry(); + ~FunctionRegistry(); }; /*! 
@@ -159,7 +207,7 @@ class NArrayFunRegistry { */ #define REGISTER_NARRAY_FUN(name) \ static auto __ ## name ## _narray_fun__ = \ - ::mxnet::NArrayFunRegistry::Get()->Register("" # name) + ::mxnet::FunctionRegistry::Get()->Register("" # name) } // namespace mxnet #endif // MXNET_API_REGISTRY_H_ diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 0bc636161ee1..f58a7f263f60 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -5,6 +5,7 @@ */ #ifndef MXNET_BASE_H_ #define MXNET_BASE_H_ +#include #include #include diff --git a/include/mxnet/dag_engine.h b/include/mxnet/dag_engine.h index 2dd8682aadfc..0f2ed61b71bc 100644 --- a/include/mxnet/dag_engine.h +++ b/include/mxnet/dag_engine.h @@ -78,9 +78,12 @@ class DAGEngine { * depending on var is completed * * \param delete_fun a function that will be called after var is deleted + * \param exec_ctx execution context * \param var the variable to be deleted */ - virtual void PushDelete(Op delete_fun, Variable var) = 0; + virtual void PushDelete(Op delete_fun, + Context exec_ctx, + Variable var) = 0; /*! * \brief allocate a new variable, the variable can then * be used to schedul the operation concurrently via dependency patterns diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index 0fd40512c35e..177018c9fcaf 100644 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -29,9 +29,11 @@ class NArray { * \brief constructing a new dynamic NArray * \param shape the shape of array * \param ctx context of NArray + * \param delay_alloc whether delay the allocation */ - NArray(const TShape &shape, Context ctx) - : ptr_(new Chunk(shape, ctx, false)) { + NArray(const TShape &shape, Context ctx, + bool delay_alloc = false) + : ptr_(new Chunk(shape, ctx, delay_alloc)) { } /*! * \brief constructing a static NArray that shares data with TBlob @@ -49,6 +51,12 @@ class NArray { inline const TShape &shape() const { return ptr_->data.shape_; } + /*! + * \return the data TBlob + */ + inline const TBlob &data() const { + return ptr_->data; + } /*! * \return the context of NArray, this function is only valid when the NArray is not empty */ @@ -59,6 +67,56 @@ class NArray { inline bool is_none() const { return ptr_.get() == nullptr; } + /*! \brief wait until the result of the NArray is computed */ + inline void Wait() const { + if (is_none()) return; + DAGEngine::Get()->WaitForVar(ptr_->var); + } + /*! + * \brief set all the elements in narray to be scalar + * \param scalar the scalar to set + * \return reference of self + */ + NArray &operator=(real_t scalar); + /*! + * \brief elementwise add to current space + * this mutate the current NArray + * \param src the data to add + * \return reference of self + */ + NArray &operator+=(const NArray &src); + /*! + * \brief elementwise subtract from current narray + * this mutate the current NArray + * \param src the data to substract + * \return reference of self + */ + NArray &operator-=(const NArray &src); + /*! + * \brief elementwise multiplication to current narray + * this mutate the current NArray + * \param src the data to substract + * \return reference of self + */ + NArray &operator*=(const NArray &src); + /*! + * \brief elementwise division from current narray + * this mutate the current NArray + * \param src the data to substract + * \return reference of self + */ + NArray &operator/=(const NArray &src); + /*! + * \brief return transpose of current NArray + * \return a new transposed NArray + */ + NArray T() const; + /*! 
+ * \brief return a new copy this NArray + * \param ctx the new context of this NArray + * \return the new copy + */ + NArray Copy(Context ctx) const; private: /*! \brief the real data chunk that backs NArray */ @@ -107,31 +165,38 @@ class NArray { /*! \brief destructor */ ~Chunk() { if (static_data) { - DAGEngine::Get()->PushDelete([](RunContext s) {}, var); + DAGEngine::Get()->PushDelete([](RunContext s) {}, shandle.ctx, var); } else { CHECK(!delay_alloc) << "deleted before allocation"; StorageManager::Handle h = this->shandle; DAGEngine::Get()->PushDelete([h](RunContext s) { StorageManager::Get()->Free(h); - }, var); + }, shandle.ctx, var); } } }; /*! \brief internal data of NArray */ std::shared_ptr ptr_; - /*! - * \brief constructing a new dynamic NArray - * \param shape the shape of array - * \param ctx context of NArray - * \param delay_alloc whether delay the allocation - */ - NArray(const TShape &shape, Context ctx, bool delay_alloc) - : ptr_(new Chunk(shape, ctx, delay_alloc)) { - } // add friend to helper functions + friend void CopyFromTo(const NArray &from, NArray *to); + template + friend void BinaryOp(const NArray &lhs, const NArray &rhs, NArray *out); template - friend void BinaryEWise(const NArray &lhs, const NArray &rhs, NArray *out); + friend void UnaryOp(const NArray &lhs, const NArray &rhs, NArray *out); }; + +/*! + * \brief issue an copy operation from one NArray to another + * the two narray can sit on different devices + * this operation will be scheduled by the engine + * + * NOTE: this function name explicitly marks the order of from and to + * due to different possible convention carried by copy function + * \param from the narray we want to copy data from + * \param to the target narray + */ +void CopyFromTo(const NArray &from, NArray *to); + /*! 
* \brief elementwise add * \param lhs left operand diff --git a/src/api_registry.cc b/src/api_registry.cc index 43549be55376..93e0bd0ee678 100644 --- a/src/api_registry.cc +++ b/src/api_registry.cc @@ -4,25 +4,23 @@ namespace mxnet { -NArrayFunRegistry::Entry & -NArrayFunRegistry::Register(const std::string name) { +FunctionRegistry::Entry & +FunctionRegistry::Register(const std::string name) { CHECK(fmap_.count(name) == 0); Entry *e = new Entry(name); fmap_[name] = e; fun_list_.push_back(e); - // delete me later - LOG(INFO) << "register function " << name; return *e; } -NArrayFunRegistry::~NArrayFunRegistry() { +FunctionRegistry::~FunctionRegistry() { for (auto p = fmap_.begin(); p != fmap_.end(); ++p) { delete p->second; } } -NArrayFunRegistry *NArrayFunRegistry::Get() { - static NArrayFunRegistry instance; +FunctionRegistry *FunctionRegistry::Get() { + static FunctionRegistry instance; return &instance; } diff --git a/src/common/concurrent_blocking_queue.h b/src/common/concurrent_blocking_queue.h new file mode 100644 index 000000000000..aab39895b119 --- /dev/null +++ b/src/common/concurrent_blocking_queue.h @@ -0,0 +1,79 @@ +#pragma once +#include +#include +#include +#include +#include +#include + +template class ConcurrentBlockingQueue { + const static int BUSY_LOOP = 1000; + public: + ConcurrentBlockingQueue() : has_elmt_(false), exit_now_(false) { + } + void Push(const T& e) { + std::lock_guard lock(mutex_); + has_elmt_ = true; + queue_.push_back(e); + if (queue_.size() == 1) { + cv_.notify_all(); + } + } + bool Pop(T& rv) { + for (int i = 0; i < BUSY_LOOP; i++) { + if (has_elmt_) { + std::lock_guard lock(mutex_); + if (!has_elmt_) { + assert(queue_.empty()); + continue; + } + rv = queue_.front(); + queue_.pop_front(); + if (queue_.empty()) + has_elmt_ = false; + return false; + } + } + { + std::unique_lock lock(mutex_); + while (queue_.empty() && !exit_now_) { + cv_.wait(lock); + } + if (!exit_now_) { + rv = queue_.front(); + queue_.pop_front(); + if (queue_.empty()) + has_elmt_ = false; + return false; + } else { + return true; + } + } + } + std::list PopAll() { + std::lock_guard lock(mutex_); + std::list rv; + rv.swap(queue_); + return rv; + } + // Call `SignalForKill` before destruction + void SignalForKill() { + std::unique_lock lock(mutex_); + exit_now_ = true; + cv_.notify_all(); + } + size_t QueueSize() { + std::unique_lock lock(mutex_); + return queue_.size(); + } + + private: + std::atomic has_elmt_; + std::list queue_; + std::mutex mutex_; + std::condition_variable cv_; + std::atomic exit_now_; + + ConcurrentBlockingQueue(const ConcurrentBlockingQueue&) = delete; + ConcurrentBlockingQueue& operator=(const ConcurrentBlockingQueue&) = delete; +}; diff --git a/src/common/spin_lock.h b/src/common/spin_lock.h new file mode 100644 index 000000000000..5a0cc3f786e6 --- /dev/null +++ b/src/common/spin_lock.h @@ -0,0 +1,45 @@ +#ifndef _SPINLOCK_XCHG_H +#define _SPINLOCK_XCHG_H + +/* Spin lock using xchg. 
+ * Copied from http://locklessinc.com/articles/locks/ + */ + +/* Compile read-write barrier */ +#define barrier() asm volatile("": : :"memory") + +/* Pause instruction to prevent excess processor bus usage */ +#define cpu_relax() asm volatile("pause\n": : :"memory") + +static inline unsigned short xchg_8(void *ptr, unsigned char x) { + __asm__ __volatile__("xchgb %0,%1" + :"=r" (x) + :"m" (*(volatile unsigned char *)ptr), "0" (x) + :"memory"); + + return x; +} + +#define BUSY 1 +typedef unsigned char spinlock; + +#define SPINLOCK_INITIALIZER 0 + +static inline void spin_lock(spinlock *lock) { + while (1) { + if (!xchg_8(lock, BUSY)) return; + + while (*lock) cpu_relax(); + } +} + +static inline void spin_unlock(spinlock *lock) { + barrier(); + *lock = 0; +} + +static inline int spin_trylock(spinlock *lock) { + return xchg_8(lock, BUSY); +} + +#endif /* _SPINLOCK_XCHG_H */ diff --git a/src/dag_engine/simple_engine.cc b/src/dag_engine/simple_engine.cc index 9ea42e979735..d38a2daba63a 100644 --- a/src/dag_engine/simple_engine.cc +++ b/src/dag_engine/simple_engine.cc @@ -3,6 +3,7 @@ namespace mxnet { class SimpleEngine : public DAGEngine { public: + virtual void Push(AsyncOp exec_fun, Context exec_ctx, const std::vector &use_vars, @@ -14,10 +15,18 @@ class SimpleEngine : public DAGEngine { Context exec_ctx, const std::vector &use_vars, const std::vector &mutate_vars) { - exec_fun(RunContext()); + if (exec_ctx.dev_mask == gpu::kDevMask) { + ctx_.stream = &stream; + mshadow::SetDevice(exec_ctx.dev_id); + exec_fun(ctx_); + } else { + exec_fun(ctx_); + } } - virtual void PushDelete(Op delete_fun, Variable var) { - delete_fun(RunContext()); + virtual void PushDelete(Op delete_fun, + Context exec_ctx, + Variable var) { + this->Push(delete_fun, exec_ctx, {}, {var}); } virtual Variable NewVar() { // in practice return a ptr to a cell @@ -25,6 +34,10 @@ class SimpleEngine : public DAGEngine { // use ptr directly instead of ID because this avoids an indirect mapping return NULL; } + + private: + RunContext ctx_; + mshadow::Stream stream; }; // implements the singleton factory DAGEngine* DAGEngine::Get() { diff --git a/src/dag_engine/threaded_engine.cc b/src/dag_engine/threaded_engine.cc new file mode 100644 index 000000000000..143b5e72f413 --- /dev/null +++ b/src/dag_engine/threaded_engine.cc @@ -0,0 +1,179 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include "../common/spin_lock.h" +#include "../common/concurrent_blocking_queue.h" + +using namespace std; + +namespace mxnet { + +#define DEFAULT_NUM_WORKER_THREADS 4 + +class ThreadedEngine : public DAGEngine { + public: + ThreadedEngine(int numthreads = DEFAULT_NUM_WORKER_THREADS): numthreads_(numthreads) { + for(int i = 0; i < numthreads; ++i) { + worker_queues_.push_back(new ConcurrentBlockingQueue()); + workers_.emplace_back(&ThreadedEngine::WorkerRoutine, this, i); + } + } + ~ThreadedEngine() { + for(int i = 0; i < numthreads_; ++i) { + worker_queues_[i]->SignalForKill(); + delete worker_queues_[i]; + workers_[i].join(); + } + } + void Push(AsyncOp exec_fun, + Context exec_ctx, + const vector &use_vars, + const vector &mutate_vars) override { + shared_ptr opd( new OpDescr{exec_fun, exec_ctx, use_vars, mutate_vars}, + [this] (OpDescr* o) { this->OnDepsResolved(o); } ); + for( Variable v : use_vars ) { // read + VarDescr* vard = static_cast(v); // safe to cast here + spin_lock(&vard->lock); + if (vard->rw < 0) { + vard->waitings.push(make_pair(opd, DepType::kRead)); + } else { + ++vard->rw; + } + 
spin_unlock(&vard->lock); + } + for( Variable v : mutate_vars ) { // write + VarDescr* vard = static_cast(v); // safe to cast here + spin_lock(&vard->lock); + if (vard->rw != 0) { + vard->waitings.push(make_pair(opd, DepType::kWrite)); + } else { + vard->rw = -1; + } + spin_unlock(&vard->lock); + } + } + void Push(Op exec_fun, + Context exec_ctx, + const vector &use_vars, + const vector &mutate_vars) override { + this->Push([exec_fun](RunContext ctx, Callback on_complete) { + exec_fun(ctx); on_complete(); + }, exec_ctx, use_vars, mutate_vars); + } + void PushDelete(Op delete_fun, Variable var) override { + // TODO + this->Push([delete_fun, var] (RunContext ctx) { + delete_fun(ctx); + delete static_cast(var); + }, Context()/* TODO exec_ctx is missing?*/, {}, {var}); + } + Variable NewVar() override { + // in practice return a ptr to a cell + // that have the info about the variable + // use ptr directly instead of ID because this avoids an indirect mapping + VarDescr* vd = new VarDescr; + vd->lock = SPINLOCK_INITIALIZER; + vd->rw = 0; + return vd; + } + void WaitForVar(Variable var) override { + // TODO + } + void WaitForAll() override { + // TODO + } + private: + enum class DepType { + kRead = 0, + kWrite, + kDelete, + }; + struct OpDescr { + AsyncOp op; + Context exec_ctx; + vector read_vars; + vector write_vars; + }; + struct VarDescr { + spinlock lock; + int rw; // a semaphore-like count + // if rw > 0, the variable has several readers and the number + // means how many operators are currently reading it; + // if rw < 0, the varaible has one writer (should be -1) + queue, DepType>> waitings; + }; + void TriggerWaiting(VarDescr* vard) { + // ATTENTION: this function should be called with vard->lock held. + CHECK(vard->rw == 0) << "the variable should be free during triggering"; + if(!vard->waitings.empty()) { + // pop all reads first + while(vard->waitings.front().second == DepType::kRead) { + vard->waitings.pop(); + ++vard->rw; + } + if (vard->rw == 0) { + // if the next one is a delete + // pop the next write + vard->waitings.pop(); + vard->rw = -1; + } + } + } + void OnOpFinished(OpDescr* opd) { + CHECK(opd) << "completing a nullptr op!"; + for(Variable v : opd->read_vars) { + VarDescr* vard = static_cast(v); // safe to cast here + spin_lock(&vard->lock); + CHECK(vard->rw > 0) << "incorrect rw count (reader):" << vard->rw; + if(--vard->rw == 0) { + TriggerWaiting(vard); + } + spin_unlock(&vard->lock); + } + for(Variable v : opd->write_vars) { + VarDescr* vard = static_cast(v); // safe to cast here + spin_lock(&vard->lock); + CHECK(vard->rw == -1) << "incorrect rw count (writer):" << vard->rw; + vard->rw = 0; + TriggerWaiting(vard); + spin_unlock(&vard->lock); + } + delete opd; // delete the operator + } + RunContext GetRunContext(const Context& ctx) { + // TODO + return RunContext(); + } + void OnDepsResolved(OpDescr* opd) { + static default_random_engine generator; + static uniform_int_distribution distribution(0, numthreads_); + int thrid = distribution(generator); + worker_queues_[thrid]->Push(opd); + } + void WorkerRoutine(int thrid) { + OpDescr* opd = nullptr; + while(! 
worker_queues_[thrid]->Pop(opd)) { + LOG(INFO) << "worker thread #" << thrid << " got operator " << opd; + opd->op(GetRunContext(opd->exec_ctx), [this, opd] () { this->OnOpFinished(opd); }); + opd = nullptr; + } + } + private: + const int numthreads_; + vector*> worker_queues_; + vector workers_; +}; + +// implements the singleton factory +DAGEngine* DAGEngine::Get() { + static ThreadedEngine engine; + return &engine; +} +} // namespace mxnet diff --git a/src/narray/narray.cc b/src/narray/narray.cc index abc9b499b993..e03a2c374190 100644 --- a/src/narray/narray.cc +++ b/src/narray/narray.cc @@ -6,16 +6,16 @@ namespace mxnet { /*! - * \brief run a binary operation, returning a new dynamically allocated NArray + * \brief run a binary operation * \param lhs left operand * \param rhs right operand * \param out the output narray * \param binary_op the real */ template -inline void BinaryEWise(const NArray &lhs, - const NArray &rhs, - NArray *out) { +inline void BinaryOp(const NArray &lhs, + const NArray &rhs, + NArray *out) { CHECK(lhs.ctx() == rhs.ctx()) << "operands context mismatch"; // if out is none, allocate space if (out->is_none()) { @@ -28,46 +28,126 @@ inline void BinaryEWise(const NArray &lhs, // important: callback must always capture by value NArray ret = *out; // redirect everything to mshadow operations - DAGEngine::Get()->Push([lhs, rhs, ret](RunContext ctx) { - ret.ptr_->CheckAndAlloc(); - switch (lhs.ctx().dev_mask) { - case cpu::kDevMask: - narray::Eval(lhs.ptr_->data, rhs.ptr_->data, ret.ptr_->data, ctx); - break; + switch (lhs.ctx().dev_mask) { + case cpu::kDevMask: + DAGEngine::Get()->Push([lhs, rhs, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + narray::Eval(lhs.ptr_->data, rhs.ptr_->data, &ret.ptr_->data, ctx); + }, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); + break; #if MXNET_USE_CUDA - case gpu::kDevMask: - narray::Eval(lhs.ptr_->data, rhs.ptr_->data, ret.ptr_->data, ctx); - break; + case gpu::kDevMask: + DAGEngine::Get()->Push([lhs, rhs, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + narray::Eval(lhs.ptr_->data, rhs.ptr_->data, &ret.ptr_->data, ctx); + }, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); + break; #endif - default: LOG(FATAL) << "GPU is not enabled"; - } - }, lhs.ctx(), {lhs.ptr_->var, rhs.ptr_->var}, {ret.ptr_->var}); + default: LOG(FATAL) << "GPU is not enabled"; + } +} + +void CopyFromTo(const NArray &from, NArray *to) { + CHECK(from.shape() == to->shape()) + << "operands shape mismatch"; + CHECK(from.shape().ndim() != 0) + << "source operands have zero dimension shape"; + // important: callback must always capture by value + NArray ret = *to; + int a = from.ctx().dev_mask; + int b = to->ctx().dev_mask; + if (a == cpu::kDevMask && b == cpu::kDevMask) { + DAGEngine::Get()->Push([from, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + narray::Copy(from.ptr_->data, &ret.ptr_->data, + from.ctx(), ret.ctx(), ctx); + }, from.ctx(), {from.ptr_->var}, {ret.ptr_->var}); + } else if (a == cpu::kDevMask && b == gpu::kDevMask) { +#if MXNET_USE_CUDA + DAGEngine::Get()->Push([from, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + narray::Copy(from.ptr_->data, &ret.ptr_->data, + from.ctx(), ret.ctx(), ctx); + }, ret.ctx(), {from.ptr_->var}, {ret.ptr_->var}); +#else + LOG(FATAL) << "GPU is not enabled"; +#endif + } else if (a == gpu::kDevMask && b == cpu::kDevMask) { +#if MXNET_USE_CUDA + DAGEngine::Get()->Push([from, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + narray::Copy(from.ptr_->data, 
&ret.ptr_->data, + from.ctx(), ret.ctx(), ctx); + }, from.ctx(), {from.ptr_->var}, {ret.ptr_->var}); +#else + LOG(FATAL) << "GPU is not enabled"; +#endif + } else if (a == gpu::kDevMask && b == gpu::kDevMask) { +#if MXNET_USE_CUDA + DAGEngine::Get()->Push([from, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + narray::Copy(from.ptr_->data, &ret.ptr_->data, + from.ctx(), ret.ctx(), ctx); + }, from.ctx(), {from.ptr_->var}, {ret.ptr_->var}); +#else + LOG(FATAL) << "GPU is not enabled"; +#endif + } else { + LOG(FATAL) << "unknown device mask"; + } } template -inline NArray BinaryEWiseRet(const NArray &lhs, - const NArray &rhs) { +inline NArray BinaryOpRet(const NArray &lhs, + const NArray &rhs) { NArray ret; - BinaryEWise(lhs, rhs, &ret); + BinaryOp(lhs, rhs, &ret); return ret; } +template +inline NArray &BinaryOpApply(NArray *dst, + const NArray &src) { + BinaryOp(*dst, src, dst); + return *dst; +} + NArray operator+(const NArray &lhs, const NArray &rhs) { - return BinaryEWiseRet(lhs, rhs); + return BinaryOpRet(lhs, rhs); } NArray operator-(const NArray &lhs, const NArray &rhs) { - return BinaryEWiseRet(lhs, rhs); + return BinaryOpRet(lhs, rhs); } NArray operator*(const NArray &lhs, const NArray &rhs) { - return BinaryEWiseRet(lhs, rhs); + return BinaryOpRet(lhs, rhs); } NArray operator/(const NArray &lhs, const NArray &rhs) { - return BinaryEWiseRet(lhs, rhs); + return BinaryOpRet(lhs, rhs); +} + +NArray &NArray::operator+=(const NArray &src) { + return BinaryOpApply(this, src); +} +NArray &NArray::operator-=(const NArray &src) { + return BinaryOpApply(this, src); +} +NArray &NArray::operator*=(const NArray &src) { + return BinaryOpApply(this, src); } +NArray &NArray::operator/=(const NArray &src) { + return BinaryOpApply(this, src); +} + +// register API function +REGISTER_NARRAY_FUN(plus).set_function(BinaryOp); +REGISTER_NARRAY_FUN(minus).set_function(BinaryOp); +REGISTER_NARRAY_FUN(mul).set_function(BinaryOp); +REGISTER_NARRAY_FUN(div).set_function(BinaryOp); +// copy function is special +//that we need to remove kAcceptEmptyMutateTarget from it +REGISTER_NARRAY_FUN(copy) +.set_function(CopyFromTo) +.set_type_mask(kNArrayArgBeforeScalar); -REGISTER_NARRAY_FUN(Plus).set_function(BinaryEWise); -REGISTER_NARRAY_FUN(Minus).set_function(BinaryEWise); -REGISTER_NARRAY_FUN(Mul).set_function(BinaryEWise); -REGISTER_NARRAY_FUN(Div).set_function(BinaryEWise); } // namespace mxnet diff --git a/src/narray/narray_op-inl.h b/src/narray/narray_op-inl.h index 9891d9a993d0..dd0660336dbe 100644 --- a/src/narray/narray_op-inl.h +++ b/src/narray/narray_op-inl.h @@ -4,8 +4,8 @@ #ifndef DECL_BINARY #define DECL_BINARY(XPU, OP, FUN) \ template<> \ - void Eval(const TBlob &lhs, const TBlob &rhs, TBlob ret, RunContext ctx) { \ - FUN(lhs, rhs, ret, ctx); \ + void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx) { \ + FUN(lhs, rhs, ret, ctx); \ } #endif @@ -19,10 +19,11 @@ namespace mxnet { namespace narray { // true implementation template -inline void Eval_(const TBlob &lhs, const TBlob &rhs, TBlob ret, RunContext ctx) { +inline void Eval_(const TBlob &lhs, const TBlob &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = static_cast*>(ctx.stream); - ret.FlatTo2D(s) + ret->FlatTo2D(s) = F(lhs.FlatTo2D(s), rhs.FlatTo2D(s)); } diff --git a/src/narray/narray_op.h b/src/narray/narray_op.h index bbdb3e1e53b3..cf827268254f 100644 --- a/src/narray/narray_op.h +++ b/src/narray/narray_op.h @@ -16,6 +16,7 @@ namespace narray { struct BinaryBase { inline static 
TShape GetShape(const TShape &lshape, const TShape &rshape) { CHECK(lshape == rshape) << "operands shape mismatch"; + CHECK(lshape.ndim() != 0) << "source operand have zero dimension shape"; return lshape; } }; @@ -33,7 +34,13 @@ struct Div : public BinaryBase { typedef mshadow::op::div mshadow_op; }; template -void Eval(const TBlob &lhs, const TBlob &rhs, TBlob ret, RunContext ctx); +void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx); + +// copy function when only cpu is involved +template +void Copy(const TBlob &from, TBlob *to, + Context from_ctx, Context to_ctx, + RunContext ctx); } // namespace narray } // namespace mxnet diff --git a/src/narray/narray_op_cpu.cc b/src/narray/narray_op_cpu.cc index 9e59be609688..b6c7014964ad 100644 --- a/src/narray/narray_op_cpu.cc +++ b/src/narray/narray_op_cpu.cc @@ -1,3 +1,15 @@ // this will be invoked by gcc and compile CPU version #include "./narray_op.h" #include "./narray_op-inl.h" + +namespace mxnet { +namespace narray { +template<> +void Copy(const TBlob &from, TBlob *to, + Context from_ctx, Context to_ctx, + RunContext ctx) { + mshadow::Copy(to->FlatTo2D(), + from.FlatTo2D()); +} +} // namespace narray +} // namespace mxnet diff --git a/src/narray/narray_op_gpu.cu b/src/narray/narray_op_gpu.cu index 335be54c27ca..571757e41ee8 100644 --- a/src/narray/narray_op_gpu.cu +++ b/src/narray/narray_op_gpu.cu @@ -1,3 +1,48 @@ // this will be invoked by nvcc and compile GPU version +#include #include "./narray_op.h" #include "./narray_op-inl.h" + +namespace mxnet { +namespace narray { +template<> +void Copy(const TBlob &from, TBlob *to, + Context from_ctx, Context to_ctx, + RunContext ctx) { + mshadow::Copy(to->FlatTo2D(), + from.FlatTo2D(), + static_cast*>(ctx.stream)); +} + +template<> +void Copy(const TBlob &from, TBlob *to, + Context from_ctx, Context to_ctx, + RunContext ctx) { + mshadow::Copy(to->FlatTo2D(), + from.FlatTo2D(), + static_cast*>(ctx.stream)); +} + +template<> +void Copy(const TBlob &from, TBlob *to, + Context from_ctx, Context to_ctx, + RunContext ctx) { + if (from_ctx.dev_id == to_ctx.dev_id) { + mshadow::Copy(to->FlatTo2D(), + from.FlatTo2D(), + static_cast*>(ctx.stream)); + } else { + CHECK(from.CheckContiguous() && to->CheckContiguous()) + << "copy across only support continugous memory"; + mshadow::Stream *s = static_cast*>(ctx.stream); + CHECK(s != NULL) << "need stream in GPU context"; + cudaMemcpyPeerAsync(to->dptr_, + to_ctx.dev_id, + from.dptr_, + from_ctx.dev_id, + from.shape_.Size() * sizeof(real_t), + s->stream_); + } +} +} // namespace narray +} // namespace mxnet diff --git a/src/storage/storage.cc b/src/storage/storage.cc index 342c898801b7..ce24a94e8ade 100644 --- a/src/storage/storage.cc +++ b/src/storage/storage.cc @@ -1,3 +1,4 @@ +#include #include namespace mxnet { class NaiveStorageManager : public StorageManager { @@ -9,14 +10,25 @@ class NaiveStorageManager : public StorageManager { StorageManager::Handle NaiveStorageManager::Alloc(size_t size, Context ctx) { Handle hd; - hd.dptr = new char[size]; hd.ctx = ctx; - hd.handle_ = NULL; + hd.handle_ = NULL; + if (ctx.dev_mask == cpu::kDevMask) { + cudaMallocHost(&hd.dptr, size); + } else { +#if MXNET_USE_CUDA + cudaMalloc(&hd.dptr, size); +#endif + } return hd; } void NaiveStorageManager::Free(StorageManager::Handle handle) { - char *dptr = static_cast(handle.dptr); - delete [] dptr; + if (handle.ctx.dev_mask == cpu::kDevMask) { + cudaFreeHost(handle.dptr); + } else { +#if MXNET_USE_CUDA + cudaFree(handle.dptr); +#endif + } } StorageManager 
*StorageManager::Get() {
  static NaiveStorageManager inst;
diff --git a/test/api_registry_test.cc b/test/api_registry_test.cc
index b361ca2242e0..8e82fad7dc56 100644
--- a/test/api_registry_test.cc
+++ b/test/api_registry_test.cc
@@ -3,7 +3,7 @@
 #include <mxnet/api_registry.h>
 int main(int argc, char *argv[]) {
-  auto fadd = mxnet::NArrayFunRegistry::Get()->Find("Plus");
-  printf("f.name=%s\n", fadd->name.c_str());
+  auto fadd = mxnet::FunctionRegistry::Find("plus");
+  printf("f.name=%s\n", fadd->name.c_str());
   return 0;
 }
diff --git a/test/test_threaded_engine.cc b/test/test_threaded_engine.cc
new file mode 100644
index 000000000000..40dea029cf6e
--- /dev/null
+++ b/test/test_threaded_engine.cc
@@ -0,0 +1,9 @@
+#include <mxnet/dag_engine.h>
+
+using namespace std;
+using namespace mxnet;
+
+int main() {
+  DAGEngine* engine = DAGEngine::Get();
+  return 0;
+}
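Below is a minimal usage sketch, not part of the patch itself, showing how the C API declared in api/mxnet_api.h can be driven directly; the Python package in api/python/mxnet goes through the same calls via ctypes. It is illustrative only: the file name and build command are assumptions, it assumes the library has been built into api/libmxnet.a as set up in the Makefile changes above, it uses only functions declared in this diff, and it hard-codes the CPU device mask as 1 (matching devtype2mask in context.py).

```cpp
// minimal_capi_demo.cc -- illustrative sketch, not part of this change.
// Assumed build: g++ -std=c++11 -Iapi minimal_capi_demo.cc api/libmxnet.a -o demo
#include <cstdio>
#include "mxnet_api.h"

int main() {
  mx_uint shape[] = {2, 3};
  NArrayHandle a, b, out;
  // Two CPU arrays (dev_mask = 1, dev_id = 0, no delayed allocation) and one
  // empty handle that "plus" may fill, since it accepts empty mutate targets.
  // Data is left uninitialized here; a real program would fill it through
  // the pointer returned by MXNArrayGetData.
  if (MXNArrayCreate(shape, 2, 1, 0, 0, &a) != 0 ||
      MXNArrayCreate(shape, 2, 1, 0, 0, &b) != 0 ||
      MXNArrayCreateNone(&out) != 0) {
    std::fprintf(stderr, "create failed: %s\n", MXGetLastError());
    return 1;
  }
  // Look up the registered "plus" function by name.
  FunctionHandle fplus;
  if (MXGetFunction("plus", &fplus) != 0 || fplus == NULL) {
    std::fprintf(stderr, "plus not registered: %s\n", MXGetLastError());
    return 1;
  }
  NArrayHandle use_vars[] = {a, b};
  NArrayHandle mutate_vars[] = {out};
  // out = a + b, scheduled on the DAG engine; no scalar arguments.
  if (MXFuncInvoke(fplus, use_vars, NULL, mutate_vars) != 0) {
    std::fprintf(stderr, "invoke failed: %s\n", MXGetLastError());
    return 1;
  }
  MXNArrayWaitAll();  // block until all pending operations finish
  MXNArrayFree(a);
  MXNArrayFree(b);
  MXNArrayFree(out);
  return 0;
}
```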