state_dict in pytorch isn't compatible with params with theano #5

Closed
GeneZC opened this issue Jan 19, 2018 · 5 comments

Comments

@GeneZC

GeneZC commented Jan 19, 2018

In PyTorch, state_dict is a dict, while the params trained with Theano are dumped as a list.
When you want to retrain a model that was trained with Theano, it seems the model can't be loaded properly.
Is there any way to solve this?

@GeneZC GeneZC changed the title from "state_dict in pytorch doesn't compatible with implementation using numpy" to "state_dict in pytorch isn't compatible with params with theano" Jan 19, 2018
@junxiaosong
Owner

Yes, the provided models in this repo were trained with Theano. If you want to load the models in PyTorch, you can load the list and rewrite it as a dict that follows the state_dict format. The file policy_value_net_numpy.py may help you figure out how the params are originally stored in the list.
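
One way to follow this suggestion is to pair the list entries with parameter names in order, checking the count first because `zip` silently drops unmatched items. The sketch below is only an illustration: `names_to_params`, `fake_params`, and the array shapes are hypothetical stand-ins, not taken from the repo. The real list would come from unpickling the model file.

```python
from collections import OrderedDict

import numpy as np


def names_to_params(names, params):
    """Pair parameter names with a Theano-style list of arrays, in order."""
    if len(names) != len(params):
        raise ValueError(
            "expected %d arrays, got %d" % (len(names), len(params))
        )
    return OrderedDict(zip(names, params))


# Hypothetical stand-in data: one conv layer and one dense layer.
fake_params = [
    np.zeros((32, 4, 3, 3), dtype=np.float32),  # conv weight
    np.zeros(32, dtype=np.float32),             # conv bias
    np.zeros((144, 36), dtype=np.float32),      # dense weight
    np.zeros(36, dtype=np.float32),             # dense bias
]
names = ["conv1.weight", "conv1.bias", "act_fc1.weight", "act_fc1.bias"]

named = names_to_params(names, fake_params)
for name, array in named.items():
    # Printing the shapes makes it easy to compare against the PyTorch model.
    print(name, array.shape)
```

Comparing the printed shapes with those in the PyTorch model's own `state_dict()` is a quick way to spot entries that still need to be transposed or reordered.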

@junxiaosong
Owner

The params of the pretrained Theano models can be converted to the state_dict format of a PyTorch model with the following script. Note that Theano's conv2d flips the filters (rotates them 180 degrees) before doing the calculation, so the conv weights have to be flipped back for PyTorch.

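import pickle
from collections import OrderedDict

import torch

# Load the Theano params, which are stored as a list of arrays.
# If the file was pickled under Python 2, pickle.load(f, encoding='bytes') may be needed on Python 3.
with open('best_policy_6_6_4.model', 'rb') as f:
    param_theano = pickle.load(f)

# state_dict keys, in the same order as the arrays in the Theano list.
keys = ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias',
        'act_conv1.weight', 'act_conv1.bias', 'act_fc1.weight', 'act_fc1.bias',
        'val_conv1.weight', 'val_conv1.bias', 'val_fc1.weight', 'val_fc1.bias', 'val_fc2.weight', 'val_fc2.bias']

param_pytorch = OrderedDict()
for key, value in zip(keys, param_theano):
    if 'fc' in key and 'weight' in key:
        # Fully connected weights: transpose to PyTorch's (out_features, in_features) layout.
        param_pytorch[key] = torch.FloatTensor(value.T)
    elif 'conv' in key and 'weight' in key:
        # Conv filters: undo Theano's 180-degree flip; copy() removes the negative strides.
        param_pytorch[key] = torch.FloatTensor(value[:, :, ::-1, ::-1].copy())
    else:
        param_pytorch[key] = torch.FloatTensor(value)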
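
The two weight transformations in the script can be checked numerically. The sketch below uses NumPy only and made-up toy arrays (`image`, `kernel`, `x`, `w_theano`, `b` are all hypothetical). It assumes what the script implies: that the Theano filters were used in a true convolution, and that the dense weights are stored as (inputs, outputs). PyTorch's `nn.Conv2d` computes a cross-correlation and `nn.Linear` stores its weight as (out_features, in_features).

```python
import numpy as np


def convolve2d_valid(image, kernel):
    """True 2-D convolution ('valid' region): the kernel index runs backwards."""
    kh, kw = kernel.shape
    out = np.zeros((image.shape[0] - kh + 1, image.shape[1] - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            for m in range(kh):
                for n in range(kw):
                    out[i, j] += kernel[m, n] * image[i + kh - 1 - m, j + kw - 1 - n]
    return out


def cross_correlate2d_valid(image, kernel):
    """Cross-correlation ('valid' region): the kernel is applied as stored."""
    kh, kw = kernel.shape
    out = np.zeros((image.shape[0] - kh + 1, image.shape[1] - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out


# Toy data, chosen to be asymmetric so that flipping matters.
image = np.arange(36.0).reshape(6, 6)
kernel = np.arange(9.0).reshape(3, 3)

conv_out = convolve2d_valid(image, kernel)
corr_flipped = cross_correlate2d_valid(image, kernel[::-1, ::-1])
corr_plain = cross_correlate2d_valid(image, kernel)

print("flipped kernel matches:", np.allclose(conv_out, corr_flipped))
print("unflipped kernel matches:", np.allclose(conv_out, corr_plain))

# Dense layer: an (inputs, outputs) matrix versus a (outputs, inputs) matrix.
x = np.arange(8.0).reshape(2, 4)
w_theano = np.arange(12.0).reshape(4, 3)   # (inputs, outputs)
b = np.ones(3)
w_pytorch = w_theano.T                     # (outputs, inputs)

y_theano = x @ w_theano + b
y_pytorch = x @ w_pytorch.T + b            # what a Linear-style layer computes
print("dense outputs match:", np.allclose(y_theano, y_pytorch))
```

Only the flipped kernel reproduces the convolution output, which is why the script applies `[:, :, ::-1, ::-1]` to the last two (spatial) axes of each conv weight.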

@GeneZC
Author

GeneZC commented Jan 22, 2018

Sorry for replying so late. I just didn't know exactly how state_dict is composed, so I didn't know how to rewrite it. Anyway, thanks a lot for your code!

@dofish

dofish commented Apr 21, 2019

I also tried to load this model file with PyTorch. I added the conversion code above right before
best_policy = PolicyValueNet(width, height, model_file = param_pytorch)
but PyTorch still raises an error when loading the converted params:
File "/home/dofish/.local/lib/python3.7/site-packages/torch/serialization.py", line 189, in _check_seekable
raise_err_msg(["seek", "tell"], e)
File "/home/dofish/.local/lib/python3.7/site-packages/torch/serialization.py", line 182, in raise_err_msg
raise type(e)(msg)
AttributeError: 'collections.OrderedDict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.
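
The traceback shows that `torch.load` received the OrderedDict itself, while it expects a file path or a seekable file-like object. Two possible ways around this are sketched below: save the converted dict to a file and pass that path, or skip the file and call `load_state_dict` directly. This is a minimal sketch, not the repo's code: the tiny `nn.Linear` network, the file name `converted.model`, and the tensor values are hypothetical stand-ins for the real PolicyValueNet and its params.

```python
import os
import tempfile
from collections import OrderedDict

import torch
import torch.nn as nn

# Tiny stand-in network; the real policy-value net has more layers.
net = nn.Linear(4, 2)

# Stand-in for the converted params (an in-memory OrderedDict).
param_pytorch = OrderedDict(
    [("weight", torch.ones(2, 4)), ("bias", torch.zeros(2))]
)

# Option 1: write the dict to disk, then load it back from the file path.
with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = os.path.join(tmp_dir, "converted.model")  # hypothetical name
    torch.save(param_pytorch, model_path)
    reloaded = torch.load(model_path)
net.load_state_dict(reloaded)

# Option 2: no file at all, load the dict straight into the module.
net2 = nn.Linear(4, 2)
net2.load_state_dict(param_pytorch)

print(net.weight.data)
print(net2.bias.data)
```

With option 1, the saved file path (rather than the dict) would be what gets passed as `model_file`.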

@lzwbit

lzwbit commented Dec 8, 2020

When I tried to load a model trained with PyTorch, I came across the same error. I changed line 65 of human_play.py to the following:

policy_param = model_file

And it worked.
