Skip to content

Commit

Permalink
Allow libaray path to be configurable (#50)
Browse files Browse the repository at this point in the history
* Allow libaray path to be configurable

* Enable partial shape inference result to be passed via shape

* fix python3

* disallow copy assign in index
  • Loading branch information
tqchen authored Sep 21, 2016
1 parent b431fec commit 79cf63b
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 9 deletions.
2 changes: 2 additions & 0 deletions include/nnvm/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ class IndexedGraph {
inline const std::vector<NodeEntry>& outputs() const {
return outputs_;
}
// disalllow copy assign
IndexedGraph(const IndexedGraph&) = delete;

private:
friend class Graph;
Expand Down
27 changes: 21 additions & 6 deletions python/nnvm/libinfo.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
# coding: utf-8
"""Information about nnvm."""
from __future__ import absolute_import
import sys
import os
import platform

if sys.version_info[0] == 3:
import builtins as __builtin__
else:
import __builtin__

def find_lib_path():
"""Find NNNet dynamic library files.
Expand All @@ -12,10 +18,19 @@ def find_lib_path():
lib_path : list(string)
List of all found path to the libraries
"""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
api_path = os.path.join(curr_path, '../../lib/')
cmake_build_path = os.path.join(curr_path, '../../build/Release/')
dll_path = [curr_path, api_path, cmake_build_path]
if hasattr(__builtin__, "NNVM_BASE_PATH"):
base_path = __builtin__.NNVM_BASE_PATH
else:
base_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))

if hasattr(__builtin__, "NNVM_LIBRARY_NAME"):
lib_name = __builtin__.NNVM_LIBRARY_NAME
else:
lib_name = "libnnvm_example"

api_path = os.path.join(base_path, '../../lib/')
cmake_build_path = os.path.join(base_path, '../../build/Release/')
dll_path = [base_path, api_path, cmake_build_path]
if os.name == 'nt':
vs_configuration = 'Release'
if platform.architecture()[0] == '64bit':
Expand All @@ -27,9 +42,9 @@ def find_lib_path():
elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
if os.name == 'nt':
dll_path = [os.path.join(p, 'libnnvm_example.dll') for p in dll_path]
dll_path = [os.path.join(p, '%s.dll' % lib_name) for p in dll_path]
else:
dll_path = [os.path.join(p, 'libnnvm_example.so') for p in dll_path]
dll_path = [os.path.join(p, '%s.so' % lib_name) for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if len(lib_path) == 0:
raise RuntimeError('Cannot find the files.\n' +
Expand Down
2 changes: 1 addition & 1 deletion python/nnvm/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __rsub__(self, other):
def __mul__(self, other):
if isinstance(other, Symbol):
return _internal.__mul_symbol__(self, other)
if isinstance(other, Number):
if isinstance(other, _Number):
return _internal.__mul_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
Expand Down
10 changes: 8 additions & 2 deletions src/pass/infer_shape_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@ Graph InferAttr(Graph &&ret,
using AttrVector = std::vector<AttrType>;
const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape =
Op::GetAttr<FInferNodeEntryAttr<AttrType>>(infer_name);
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
static auto& backward_map =
Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
// reshape shape vector
AttrVector rshape(idx.num_node_entries(), default_val);
AttrVector rshape;
if (ret.attrs.count(attr_name) != 0) {
rshape = ret.MoveCopyAttr<AttrVector>(attr_name);
} else {
rshape.resize(idx.num_node_entries(), default_val);
}

if (ret.attrs.count(input_name) != 0) {
const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
Expand All @@ -39,6 +44,7 @@ Graph InferAttr(Graph &&ret,
// erase the provided arguments
ret.attrs.erase(input_name);
}

std::string shape_attr_key;
if (ret.attrs.count(attr_key_name) != 0) {
shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
Expand Down
17 changes: 17 additions & 0 deletions tests/python/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@ def test_infer_shape():
assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4]
assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]

def test_infer_shape_known_partial():
x = sym.Variable('x', shape=(4, 2))
y = sym.add(x, x, name='add1')
y = sym.reshape(y, target=(2, 4), name="reshape1")
g = graph.create(y)
jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
shape = [[4, 2], [] , []]
g._set_json_attr("shape", shape, 'list_shape')
g = g.apply("InferShape")
jnodes = jgraph['nodes']
jnode_row_ptr = jgraph['node_row_ptr']
nindex = {n['name']: i for i, n in enumerate(jnodes)}
assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4]
assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]


def test_infer_type():
x = sym.Variable('x')
y = sym.add(x, x, name='add1')
Expand Down Expand Up @@ -116,6 +132,7 @@ def test_plan_memory():
test_graph_json_attr()
test_json_pass()
test_infer_shape()
test_infer_shape_known_partial()
test_infer_type()
test_place_device()
test_plan_memory()
Expand Down

0 comments on commit 79cf63b

Please sign in to comment.