This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[RFC] Need to export all HybridBlocks in a Gluon model #19535

samskalicky opened this issue Nov 14, 2020 · 5 comments


samskalicky commented Nov 14, 2020


Gluon can have any hierarchy of Blocks where a top-level Block is not a HybridBlock, and lower-level blocks are HybridBlocks. These HybridBlocks can be dispersed throughout the model architecture. Users can optimize the parts of their model that support it by calling hybridize on the top-level block, which triggers a recursive call throughout all child blocks. But there is currently no way to export all HybridBlocks without going through and manually calling export on each one. Further, there's no way to reload those exported symbol files without changing the model design: the HybridBlocks have to be swapped for SymbolBlocks and then reloaded one by one with imports.

I would like to ask for suggestions from the community and have an active discussion on the different ways to address this problem. To kick things off, here is a proposal to start the discussion around:

def save_cached_graphs(block):
    def _save_cached_graphs(blk, index):
        if isinstance(blk, mx.gluon.nn.HybridBlock):
            blk.export( + str(index[0]))
        for child in blk._children.values():
            index[0] += 1
            _save_cached_graphs(child, index)
    #save top-level block                                                                                  
    index = [0]
    _save_cached_graphs(block, index)

def load_cached_graphs(block):
    def _load_cached_graphs(blk, index):
        if isinstance(blk, mx.gluon.nn.HybridBlock):
            sym = symbol.load( + str(index[0]) + '-symbol.json')
            blk._cached_graph = sym
        for child in blk._children.values():
            index[0] += 1
            _load_cached_graphs(child, index)
    #load top-level block                                                                                  
    index = [0]
    _load_cached_graphs(block, index)

With these two functions, we can recursively export each hybrid block and then reload the symbols. Obviously the code is not complete or even functional (_cached_graph is actually a tuple of symbols and sym.var inputs), but it should serve as a point of reference.
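The index bookkeeping in the proposal can be tried without MXNet. The following is a minimal sketch, assuming hypothetical stand-in classes `ToyBlock` and `ToyHybrid` (they are not Gluon classes). A one-element list acts as a mutable counter shared across the recursion, so a save pass and a load pass over the same architecture visit the blocks in the same order and derive the same numeric suffixes.

```python
class ToyBlock:
    """Hypothetical stand-in for a non-hybrid Gluon Block."""
    def __init__(self, *children):
        self.children = list(children)


class ToyHybrid(ToyBlock):
    """Hypothetical stand-in for a HybridBlock (the exportable kind)."""


def number_hybrids(root):
    """Return [(index, block)] for every ToyHybrid, numbered in pre-order."""
    found = []
    index = [0]  # one-element list: a counter the nested function can mutate

    def visit(blk):
        if isinstance(blk, ToyHybrid):
            found.append((index[0], blk))
        for child in blk.children:
            index[0] += 1
            visit(child)

    visit(root)
    return found


# A small hierarchy: Block -> (Block -> Hybrid), Hybrid
net = ToyBlock(ToyBlock(ToyHybrid()), ToyHybrid())
save_order = [i for i, _ in number_hybrids(net)]
load_order = [i for i, _ in number_hybrids(net)]
```

Because the counter advances for every child, hybrid or not, the suffixes only line up if the reloaded model has exactly the same structure as the saved one.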

Items on v1.x

  • Since a Block's children are stored in a dictionary, need to save/restore their unique names
  • Since parameters are mapped to their block's name, need to synchronize names after reloading the model architecture to match the saved params
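To illustrate the second item: on v1.x a block's parameters carry the block's name as a prefix (for example `dense0_weight`), so a freshly rebuilt model with a different auto-generated name no longer matches the saved parameter names. A minimal sketch of the renaming step, using plain dictionaries and a hypothetical helper `rename_params` rather than any Gluon API:

```python
def rename_params(params, old_prefix, new_prefix):
    """Return a new dict with `old_prefix` swapped for `new_prefix` in each key."""
    renamed = {}
    for name, value in params.items():
        if not name.startswith(old_prefix):
            raise ValueError(f"{name!r} does not start with {old_prefix!r}")
        renamed[new_prefix + name[len(old_prefix):]] = value
    return renamed


# A rebuilt model may get a fresh auto-generated name such as 'dense2_',
# while the saved parameter file still uses 'dense0_'.
fresh = {"dense2_weight": "W", "dense2_bias": "b"}
synced = rename_params(fresh, old_prefix="dense2_", new_prefix="dense0_")
```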

Items on master

  • Since parameters have a UUID, need to save/restore mapping of a Block's params to UUID
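A minimal sketch of that mapping, assuming a hypothetical `ToyParam` class with a public `uuid` attribute (the actual Gluon attribute and storage may differ). The saved model records name → UUID for each block, and the reload step writes those UUIDs onto the newly constructed parameters so they match the saved parameter file.

```python
import json
import uuid


class ToyParam:
    """Hypothetical parameter identified by a UUID rather than by a name."""
    def __init__(self):
        self.uuid = str(uuid.uuid4())


def save_uuid_map(params):
    """Serialize the name -> UUID mapping of one block's parameters."""
    return json.dumps({name: p.uuid for name, p in params.items()})


def restore_uuid_map(params, blob):
    """Overwrite the UUIDs of freshly built parameters with the saved ones."""
    for name, saved in json.loads(blob).items():
        params[name].uuid = saved


original = {"weight": ToyParam(), "bias": ToyParam()}
blob = save_uuid_map(original)

rebuilt = {"weight": ToyParam(), "bias": ToyParam()}  # new random UUIDs
restore_uuid_map(rebuilt, blob)
```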

General Approach

  1. Recursively create a dictionary structure of blocks mimicking the model architecture
  2. Be able to uniquely identify the block in the model
  3. For HybridBlocks, save/restore the cached graph (symbol + inputs) and in/out formats
  4. Save the model architecture dictionary & parameters
  5. Restore the model architecture with unique identifiers to synchronize with parameter naming/IDs
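The five steps above can be sketched end to end without MXNet. This is only an illustration under stated assumptions: `ToyBlock`, `ToyHybrid`, `describe`, `restore` and `build` are hypothetical names, a plain string stands in for the cached graph, and children are nested under a `children` key, which is one possible layout rather than the layout of any actual implementation.

```python
import json


class ToyBlock:
    """Hypothetical stand-in for a Block with a name and children."""
    def __init__(self, name, *children):
        self.name = name
        self.children = list(children)


class ToyHybrid(ToyBlock):
    graph = None  # stand-in for the cached graph (symbol + inputs)


def describe(root):
    """Steps 1-3: nested dict keyed by block type + pre-order index."""
    index = [0]

    def visit(blk, parent):
        entry = {"orig_name": blk.name, "children": {}}
        parent[type(blk).__name__.lower() + str(index[0])] = entry
        if isinstance(blk, ToyHybrid):
            entry["graph"] = blk.graph
        for child in blk.children:
            index[0] += 1
            visit(child, entry["children"])

    top = {}
    visit(root, top)
    return top


def restore(root, structure):
    """Step 5: walk a rebuilt model in the same order and copy state back."""
    index = [0]

    def visit(blk, parent):
        entry = parent[type(blk).__name__.lower() + str(index[0])]
        blk.name = entry["orig_name"]
        if isinstance(blk, ToyHybrid):
            blk.graph = entry["graph"]
        for child in blk.children:
            index[0] += 1
            visit(child, entry["children"])

    visit(root, structure)


def build(suffix):
    # Names differ between builds, mimicking auto-generated block names.
    return ToyBlock("net" + suffix,
                    ToyBlock("inner" + suffix, ToyHybrid("dense_a" + suffix)),
                    ToyHybrid("dense_b" + suffix))


saved = build("0")
saved.children[1].graph = "graph-for-dense_b"
blob = json.dumps(describe(saved))   # step 4: persist the structure

fresh = build("7")
restore(fresh, json.loads(blob))     # step 5: names and graphs restored
```

Keying each entry by type plus traversal index gives a lookup that does not depend on the auto-generated names, which is what lets the original names be restored afterwards.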
Contributor Author

samskalicky commented Nov 14, 2020

Here's a silly example of a hierarchical Block/HybridBlock to play around with:

class MyBlock(mx.gluon.nn.Block):
    def __init__(self, **kwargs):
        super(MyBlock, self).__init__(**kwargs)
    def add(self, block):
        self._children[ + str(len(self._children))] = block
    def forward(self, x, *args):
        out = (x,) + args
        for block in self._children.values():
            out = block(*out)
        return out

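# create the Model
# (the add/initialize/hybridize/save calls are missing from the source; they are
# inferred from the comments and from the hierarchy described after this snippet)
inside = MyBlock()
inside.add(mx.gluon.nn.Dense(10))
net = MyBlock()
net.add(inside)
net.add(mx.gluon.nn.Dense(10))
net.initialize()
x = mx.nd.empty((1,10))
out = net(x)

# hybridize and create cached_graphs
net.hybridize()
out = net(x)

# save cached_graphs
save_cached_graphs(net)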

The hierarchy should look like this, where a top level Block(0) has a child Block(1) and a HybridBlock(1) and the child Block(1) has a child HybridBlock(0).

- MyBlock(0)
         - MyBlock(1)
         |     |
         |     - Dense(0)
         - Dense(1)

Contributor Author

samskalicky commented Nov 14, 2020

Here's a complete, working solution on the v1.8.x branch. Note that the new save and load functions perform the full model export/reload of both the model architecture and the params.

import mxnet as mx
import json

class MyBlock(mx.gluon.nn.Block):
    def __init__(self, **kwargs):
        super(MyBlock, self).__init__(**kwargs)
    def add(self, block):
        self._children[ + str(len(self._children))] = block
    def forward(self, x, *args):
        out = (x,) + args
        for block in self._children.values():
            out = block(*out)
        return out
    def save(self, prefix):
        # create empty model structure
        model = {}
        def _save_cached_graphs(blk, index, structure):
            # create new entry for this block
            mdl = {'orig_name':}
            # encode unique name based on block type and ID
            name = type(blk).__name__.lower()
            structure[name+str(index[0])] = mdl
            if isinstance(blk, mx.gluon.nn.HybridBlock):
                # save in/out formats
                mdl['in_format'] = blk._in_format
                mdl['out_format'] = blk._out_format
                # save cached graph & input symbols
                syms, out = blk._cached_graph
                mdl_syms = []
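                for sym in syms:
                    # loop body missing from the source; serializing each input to JSON is assumed
                    mdl_syms.append(sym.tojson())
                mdl['inputs'] = mdl_syms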
                mdl['symbol'] = out.tojson()
            children = dict()
            mdl['children'] = children
            # recursively save children
            for ch_name, child in blk._children.items():
                index[0] += 1
                # save child's original name in this block's map
                children[] = ch_name
                _save_cached_graphs(child, index, mdl)
        # save top-level block
        index = [0]
        _save_cached_graphs(self, index, model)
        # save model
        with open(prefix+'-model.json','w') as fp:
            json.dump(model, fp)
        # save params
    def load(self, prefix):
        # load model json from file
        with open(prefix+'-model.json') as fp:
            model = json.load(fp)
        def _load_cached_graphs(blk, index, log):
            # get block name
            name = type(blk).__name__.lower()
            # lookup previous encoded name based on block type and ID
            mdl = log[name+str(index[0])]
            # rename block to what it was when saved
            blk._name = mdl['orig_name']
            if isinstance(blk, mx.gluon.nn.HybridBlock):
                # restore in/out formats
                blk._in_format = mdl['in_format']
                blk._out_format = mdl['out_format']
                # get saved symbol
                out = mx.sym.load_json(mdl['symbol'])
                syms = []
                # recreate inputs for this symbol
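                for inp in mdl['inputs']:
                    # loop body missing from the source; assumed to be the inverse of the save step
                    syms.append(mx.sym.load_json(inp))
                # reset cached_graph and active status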
                blk._cached_graph = (syms, out)
                blk._active = True
            # rename params with updated block name
            pnames = list(blk.params.keys())
            for p in pnames:
                param = blk.params._params[p]
                new_name = +'_'+ p[len(blk.params._prefix):]
                blk.params._params[new_name] = param            
            # recursively reload children
            for ch_name, child in blk._children.items():
                index[0] += 1
                _load_cached_graphs(child, index, mdl)
            # current set of child names
            ch_names = list(blk._children.keys())
            # original child names
            children = mdl['children']
            # loop and remap children with original names
            for ch_name in ch_names:
                child = blk._children[ch_name]
                orig_name = children[]
                blk._children[orig_name] = child
        # load top-level block
        index = [0]
        _load_cached_graphs(self, index, model)
        # load params

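def createNet():
    inside = MyBlock()
    dense = mx.gluon.nn.Dense(10)
    # add() calls are missing from the source; inferred from the hierarchy shown earlier
    inside.add(dense)
    net = MyBlock()
    net.add(inside)
    net.add(mx.gluon.nn.Dense(10))
    return net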

# NOTE: the net.* calls below are missing from the source and are inferred from the comments
# create and initialize model
net = createNet()
net.initialize()
# run first inference to test
x = mx.nd.empty((1,10))
out = net(x)
# hybridize (the hybridizable blocks, i.e. the Dense layers)
net.hybridize()
out = net(x)

# save hybridized model
net.save('MyModel')

# create a new model, uninitialized
net = createNet()
# reload hybridized model
net.load('MyModel')
# run inference again
out = net(x)

And here's a complete, working solution on the master branch:

import mxnet as mx
import json

class MyBlock(mx.gluon.Block):
    def __init__(self, **kwargs):
        super(MyBlock, self).__init__(**kwargs)
        self.layers = []
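    def add(self, block):
        # body missing from the source; assumed: keep a strong reference and register the child
        self.layers.append(block)
        self.register_child(block)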
    def forward(self, x, *args):
        out = (x,) + args
        for block in self._children.values():
            out = block()(*out)
        return out                                    
    def save(self, prefix):
        # create empty model structure
        model = {}
        def _save_cached_graphs(blk, index, structure):
            # create new entry for this block
            mdl = {}
            # encode unique name based on block type and ID
            name = type(blk).__name__.lower()
            structure[name+str(index[0])] = mdl
            if isinstance(blk, mx.gluon.nn.HybridBlock):
                # save in/out formats
                mdl['in_format'] = blk._in_format
                mdl['out_format'] = blk._out_format
                # save cached graph & input symbols
                syms, out = blk._cached_graph
                mdl_syms = []
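                for sym in syms:
                    # loop body missing from the source; serializing each input to JSON is assumed
                    mdl_syms.append(sym.tojson())
                mdl['inputs'] = mdl_syms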
                mdl['symbol'] = out.tojson()
            # save param uuids
            pmap = {}
            mdl['params'] = pmap
            pnames = list(blk.params.keys())
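            for p in pnames:
                param = blk.params[p]
                # assignment missing from the source; inferred from the reload step in load()
                pmap[p] = param._uuid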
            # recursively save children
            for ch_name, child in blk._children.items():
                index[0] += 1
                _save_cached_graphs(child(), index, mdl)
        # save top-level block
        index = [0]
        _save_cached_graphs(self, index, model)
        # save model
        with open(prefix+'-model.json','w') as fp:
            json.dump(model, fp)
        # save params
    def load(self, prefix):
        # load model json from file
        with open(prefix+'-model.json') as fp:
            model = json.load(fp)
        def _load_cached_graphs(blk, index, structure):
            # get block name
            name = type(blk).__name__.lower()
            # lookup previous encoded name based on block type and ID
            mdl = structure[name+str(index[0])]
            if isinstance(blk, mx.gluon.nn.HybridBlock):
                # restore in/out formats
                blk._in_format = mdl['in_format']
                blk._out_format = mdl['out_format']
                # get saved symbol
                out = mx.sym.load_json(mdl['symbol'])
                syms = []
                # recreate inputs for this symbol
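                for inp in mdl['inputs']:
                    # loop body missing from the source; assumed to be the inverse of the save step
                    syms.append(mx.sym.load_json(inp))
                # reset cached_graph and active status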
                blk._cached_graph = (syms, out)
                blk._active = True
            # reload param uuids
            pmap = mdl['params']
            for p, uuid in pmap.items():
                param = blk.params[p]
                param._uuid = pmap[p]
            # recursively reload children
            for ch_name, child in blk._children.items():
                index[0] += 1
                _load_cached_graphs(child(), index, mdl)
        # load top-level block
        index = [0]
        _load_cached_graphs(self, index, model)
        # load params
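def createNet():
    inside = MyBlock()
    dense = mx.gluon.nn.Dense(10)
    # add() calls are missing from the source; inferred from the hierarchy shown earlier
    inside.add(dense)
    net = MyBlock()
    net.add(inside)
    net.add(mx.gluon.nn.Dense(10))
    return net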

# NOTE: the net.* calls below are missing from the source and are inferred from the comments
# create and initialize model
net = createNet()
net.initialize()
# run first inference to test
x = mx.nd.random.randn(1,10)
out = net(x)
# hybridize (the hybridizable blocks, i.e. the Dense layers)
net.hybridize()
out = net(x)

# save hybridized model
net.save('MyModel')

# create a new model, uninitialized
net = createNet()
# reload hybridized model
net.load('MyModel')
# run inference again
out = net(x)


fhieber commented Nov 16, 2020

Thanks @samskalicky, this sounds like an interesting feature for Sockeye as well. In Sockeye, we use a mix of Hybrid and non-Hybrid blocks, both during training and inference. We currently do not use the export mechanism of Gluon; instead we keep our own configuration file and build the model graph at inference time according to its specifications, which ties code and model artifacts closely together. An export functionality that would allow us to export the inference computation after training sounds useful.

My main question about this though would be how exported HybridBlocks (i.e. SymbolBlocks) behave w.r.t changing input shapes? Before Gluon, we had to make heavy use of the BucketingModule in MXNet to account for varying shapes at inference time.

Contributor Author

Thanks @fhieber. Is your approach for Sockeye's configuration file generalizable to something we could build into MXNet? Can you point me to an example where you do this?

> My main question about this though would be how exported HybridBlocks (i.e. SymbolBlocks) behave w.r.t changing input shapes? Before Gluon, we had to make heavy use of the BucketingModule in MXNet to account for varying shapes at inference time.

Nothing would change for that problem with exporting models. The restrictions around input shapes are specific to the MXNet engine's memory allocation scheme, where reshaping causes re-allocation. It's much lower level in the stack. If you want to discuss that, let's start a separate GitHub issue.

Contributor Author

Created PRs on v1.x (#19565) and master (#19564) branches. Closing this RFC issue and migrating to PRs for further discussion.
