state_dict in pytorch isn't compatible with params with theano #5
Comments
Yes, the provided models in this repo were trained with Theano. If you want to load them in PyTorch, you can load the parameter list and rewrite it as a dict that follows the state_dict format. The file policy_value_net_numpy.py may help you figure out how the params are originally stored in the list.
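To see how the params are laid out before rewriting them, it can help to print the type and shape of each list entry. The sketch below is only an illustration: `summarize_params` and `FakeArray` are hypothetical names, and the shapes in `demo` are made up so the snippet runs without the model file. With the real file you would pass the unpickled list instead.

```python
def summarize_params(params):
    """Return one (index, type name, shape) tuple per entry of a parameter list."""
    summary = []
    for i, p in enumerate(params):
        # NumPy arrays expose .shape; anything else is reported with shape None.
        shape = getattr(p, "shape", None)
        summary.append((i, type(p).__name__,
                        tuple(shape) if shape is not None else None))
    return summary


class FakeArray:
    """Hypothetical stand-in for a NumPy array (illustration only)."""

    def __init__(self, shape):
        self.shape = shape


# Made-up shapes, not read from the real model file.
demo = [FakeArray((32, 4, 3, 3)), FakeArray((32,))]
for row in summarize_params(demo):
    print(row)
```

Reading the shapes in order is usually enough to match each entry to a layer's weight or bias.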
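The params of the pretrained Theano models can be converted to the state_dict format of a PyTorch model with the following script. Note that Theano's conv2d flips the filters (rotates them by 180 degrees) before doing the calculation, so the conv weights have to be flipped along their last two axes.

    import pickle
    from collections import OrderedDict

    import torch

    # The Theano model file is a pickled list of arrays, one per weight or bias.
    # If the file was pickled under Python 2, pickle.load(f, encoding='bytes') may be needed.
    with open('best_policy_6_6_4.model', 'rb') as f:
        param_theano = pickle.load(f)

    # Parameter names of the PyTorch network, in the same order as the list.
    keys = ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias',
            'act_conv1.weight', 'act_conv1.bias', 'act_fc1.weight', 'act_fc1.bias',
            'val_conv1.weight', 'val_conv1.bias', 'val_fc1.weight', 'val_fc1.bias', 'val_fc2.weight', 'val_fc2.bias']

    param_pytorch = OrderedDict()
    for key, value in zip(keys, param_theano):
        if 'fc' in key and 'weight' in key:
            # Transpose to PyTorch's (out_features, in_features) layout.
            param_pytorch[key] = torch.FloatTensor(value.T)
        elif 'conv' in key and 'weight' in key:
            # Flip both spatial axes; copy() because torch cannot wrap negative strides.
            param_pytorch[key] = torch.FloatTensor(value[:, :, ::-1, ::-1].copy())
        else:
            param_pytorch[key] = torch.FloatTensor(value)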
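The reason for the 180 degree flip can be checked on a tiny example. The sketch below uses plain Python lists rather than Theano or PyTorch, and all function names in it are hypothetical. It shows that a true convolution (which flips the kernel, as the note above says Theano does) gives the same result as a cross-correlation (what PyTorch's Conv2d computes) applied with the flipped kernel.

```python
def cross_correlate(image, kernel):
    """Valid-mode 2-D cross-correlation: the kernel is applied without flipping."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [[sum(image[i + a][j + b] * kernel[a][b]
                 for a in range(kh) for b in range(kw))
             for j in range(out_w)]
            for i in range(out_h)]


def true_convolve(image, kernel):
    """Valid-mode 2-D convolution: the kernel is indexed in reverse on both axes."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [[sum(image[i + a][j + b] * kernel[kh - 1 - a][kw - 1 - b]
                 for a in range(kh) for b in range(kw))
             for j in range(out_w)]
            for i in range(out_h)]


def flip(kernel):
    """Rotate a 2-D kernel by 180 degrees (reverse rows and columns)."""
    return [row[::-1] for row in kernel[::-1]]


image = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
kernel = [[1, 0], [0, 2]]

conv = true_convolve(image, kernel)
corr_flipped = cross_correlate(image, flip(kernel))
print(conv == corr_flipped)                    # the two agree
print(cross_correlate(image, kernel) == conv)  # without the flip they differ here
```

This is why the script reverses the last two axes of every conv weight: loading the unflipped filters would make the PyTorch network compute a different function.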
Sorry for replying so late. I just didn't know exactly how a state_dict is structured, so I didn't know how to rewrite the params. Anyway, thanks a lot for your code!
I also tried to load this model file with PyTorch, adding the code above to
When I tried to load a model trained with PyTorch, I came across the same error. I tried the following code in human_play.py, line 65:
And it worked.
A state_dict in PyTorch is a dict, while the params trained with Theano are dumped as a list.
When you want to retrain a model that was trained with Theano, it seems the model can't be loaded properly.
Is there any way to solve this?
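The mismatch described above can be detected before loading anything into a network. The following is a minimal sketch, assuming a hypothetical helper `checkpoint_kind` that is not part of this repo. It only looks at the container type of the unpickled object: a mapping is treated as a PyTorch-style state_dict, a list or tuple as a Theano-style parameter list.

```python
from collections import OrderedDict
from collections.abc import Mapping


def checkpoint_kind(obj):
    """Classify an unpickled checkpoint by its container type (hypothetical helper)."""
    if isinstance(obj, Mapping):
        # state_dict style: parameter names mapped to tensors.
        return "state_dict"
    if isinstance(obj, (list, tuple)):
        # Theano style: bare parameters in a fixed order, no names.
        return "param_list"
    return "unknown"


print(checkpoint_kind(OrderedDict([("conv1.weight", None)])))
print(checkpoint_kind([None, None]))
```

A loader could branch on this result and run the list-to-dict conversion only for the "param_list" case.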