Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #22 from antinucleon/master
Browse files Browse the repository at this point in the history
simplify symbol creator as discussed
  • Loading branch information
antinucleon committed Aug 21, 2015
2 parents 83b8788 + ef7bb06 commit 0635103
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 156 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# MXNet

[![Build Status](https://travis-ci.org/dmlc/mxnet.svg?branch=master)](https://travis-ci.org/dmlc/mxnet)
[![Documentation Status](https://readthedocs.org/projects/mxnet/badge/?version=latest)](https://readthedocs.org/projects/mxnet/?badge=latest)

This is a project that combines lessons and ideas we learnt from [cxxnet](https://github.com/dmlc/cxxnet), [minerva](https://github.com/dmlc/minerva) and [purine2](https://github.com/purine/purine2).
- The interface is designed in collaboration by authors of three projects.
Expand Down
5 changes: 2 additions & 3 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
from .context import Context, current_context
from .narray import NArray
from .function import _FunctionRegistry
from .symbol import Symbol
from .symbol_creator import _SymbolCreatorRegistry
from . import symbol

__version__ = "0.1.0"

# this is a global function registry that can be used to invoke functions
op = NArray._init_function_registry(_FunctionRegistry())
sym = Symbol._init_symbol_creator_registry(_SymbolCreatorRegistry())

130 changes: 113 additions & 17 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,16 @@
from __future__ import absolute_import

import ctypes
import sys
from .base import _LIB
from .base import c_array, c_str, mx_uint, NArrayHandle, ExecutorHandle, SymbolHandle
from .base import c_array, c_str, mx_uint, string_types
from .base import NArrayHandle, ExecutorHandle, SymbolHandle
from .base import check_call
from .context import Context
from .executor import Executor

class Symbol(object):
"""Symbol is symbolic graph of the mxnet."""
_registry = None

@staticmethod
def _init_symbol_creator_registry(symbol_creator_registry):
"""Initialize symbol creator registry
Parameters
----------
symbol_creator_registry:
pass in symbol_creator_registry
Returns
-------
the passed in registry
"""
_registry = symbol_creator_registry
return _registry

def __init__(self, handle):
"""Initialize the function with handle
Expand Down Expand Up @@ -257,3 +243,113 @@ def bind(self, ctx, args, args_grad, reqs):
reqs_array,
ctypes.byref(handle)))
return Executor(handle)


def Variable(name):
"""Create a symbolic variable with specified name.
Parameters
----------
name : str
Name of the variable.
Returns
-------
variable : Symbol
The created variable symbol.
"""
if not isinstance(name, string_types):
raise TypeError('Expect a string for variable `name`')
handle = SymbolHandle()
check_call(_LIB.MXSymbolCreateVariable(name, ctypes.byref(handle)))
return Symbol(handle)


def Group(symbols):
"""Create a symbolic variable that groups several symbols together.
Parameters
----------
symbols : list
List of symbols to be grouped.
Returns
-------
sym : Symbol
The created group symbol.
"""
ihandles = []
for sym in symbols:
if not isinstance(sym, Symbol):
raise TypeError('Expect Symbols in the list input')
ihandles.append(sym.handle)
handle = SymbolHandle()
check_call(_LIB.MXSymbolCreateGroup(
len(ihandles), c_array(SymbolHandle, ihandles), ctypes.byref(handle)))
return Symbol(handle)


def _make_atomic_symbol_function(handle, func_name):
"""Create an atomic symbol function by handle and funciton name."""
def creator(*args, **kwargs):
"""Activation Operator of Neural Net.
The parameters listed below can be passed in as keyword arguments.
Parameters
----------
name : string, required.
Name of the resulting symbol.
Returns
-------
symbol: Symbol
the resulting symbol
"""
param_keys = []
param_vals = []
symbol_kwargs = {}
name = kwargs.pop('name', None)

for k, v in kwargs.items():
if isinstance(v, Symbol):
symbol_kwargs[k] = v
else:
param_keys.append(c_str(k))
param_vals.append(c_str(str(v)))
# create atomic symbol
param_keys = c_array(ctypes.c_char_p, param_keys)
param_vals = c_array(ctypes.c_char_p, param_vals)
sym_handle = SymbolHandle()
check_call(_LIB.MXSymbolCreateAtomicSymbol(
handle, len(param_keys),
param_keys, param_vals,
ctypes.byref(sym_handle)))

if len(args) != 0 and len(symbol_kwargs) != 0:
raise TypeError('%s can only accept input \
Symbols either as positional or keyword arguments, not both' % func_name)

s = Symbol(sym_handle)
s._compose(*args, name=name, **symbol_kwargs)
return s
creator.__name__ = func_name
return creator


def _init_module_functions():
"""List and add all the atomic symbol functions to current module."""
plist = ctypes.POINTER(ctypes.c_void_p)()
size = ctypes.c_uint()
check_call(_LIB.MXSymbolListAtomicSymbolCreators(ctypes.byref(size),
ctypes.byref(plist)))
module_obj = sys.modules[__name__]
for i in range(size.value):
hdl = ctypes.c_void_p(plist[i])
name = ctypes.c_char_p()
check_call(_LIB.MXSymbolGetAtomicSymbolName(hdl, ctypes.byref(name)))
function = _make_atomic_symbol_function(hdl, name.value)
setattr(module_obj, function.__name__, function)

# Initialize the atomic symbo in startups
_init_module_functions()

132 changes: 0 additions & 132 deletions python/mxnet/symbol_creator.py

This file was deleted.

8 changes: 4 additions & 4 deletions python/test_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ def Get(self):

# symbol net
batch_size = 100
data = mx.sym.Variable('data')
fc1 = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=160)
act1 = mx.sym.Activation(data = fc1, name='relu1', type="relu")
fc2 = mx.sym.FullyConnected(data = act1, name='fc2', num_hidden=10)
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=160)
act1 = mx.symbol.Activation(data = fc1, name='relu1', type="relu")
fc2 = mx.symbol.FullyConnected(data = act1, name='fc2', num_hidden=10)
args_list = fc2.list_arguments()
# infer shape
data_shape = (batch_size, 784)
Expand Down

0 comments on commit 0635103

Please sign in to comment.