
Can't deepcopy an xformer model with triton 2 update #290

Closed
jramapuram opened this issue May 1, 2022 · 6 comments · Fixed by #309

@jramapuram

🐛 Bug

Since the recent triton 2 update, an xFormer model cannot be deep-copied. Deep-copying is an important requirement for numerous tasks, including EMA (which clones a model without knowledge of the generating class / hyper-params).
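For context, this is roughly the EMA pattern that needs it (a minimal sketch: a plain `torch.nn.Linear` stands in for the model here, but with the xFormer built below and triton 2 installed, the `deepcopy` call is what raises the error):

```python
from copy import deepcopy

import torch


@torch.no_grad()
def ema_update(ema_model: torch.nn.Module, model: torch.nn.Module, decay: float = 0.999):
    # in-place exponential moving average of the online model's parameters
    for ema_param, param in zip(ema_model.parameters(), model.parameters()):
        ema_param.mul_(decay).add_(param, alpha=1.0 - decay)


model = torch.nn.Linear(8, 8)  # stand-in; in practice this is the xFormer built from the config below
ema_model = deepcopy(model)    # <- the call that breaks once triton 2 is installed
ema_update(ema_model, model)
```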

xformers ViT-B Config

```yaml
reversible: False
block_type: "encoder"
num_layers: 12
dim_model: 768
layer_norm_style: "pre"

multi_head_config:
  num_heads: 12
  residual_dropout: 0.1  # (1) tried without this, (2) swapping this for DropPath, (3) with regular dropout
  use_rotary_embeddings: False

  attention:
    name: "scaled_dot_product"
    dropout: 0.0
    causal: False

feedforward_config:
  name: "MLP"
  dropout: 0.0
  activation: "gelu"
  hidden_layer_multiplier: 4
```

To reproduce

```python
from copy import deepcopy

import yaml
from xformers.factory.model_factory import xFormer, xFormerConfig

# transformer_config_file points at the ViT-B config above
with open(transformer_config_file, "rb") as fileptr:
    model_config = yaml.load(fileptr, Loader=yaml.FullLoader)

model = xFormer.from_config(xFormerConfig([model_config]))
deepcopy(model)  # raises
```

Error is:

```
TypeError: cannot pickle 'PyCapsule' object
```
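The failure mode itself is generic: `deepcopy` falls back to the pickle protocol (`__reduce_ex__`), which dies on any object graph holding a `PyCapsule`, as the compiled triton kernel handles apparently do. A stdlib-only sketch that reproduces the same message, using the capsule exposed by `datetime` as a stand-in:

```python
from copy import deepcopy
import datetime

capsule = datetime.datetime_CAPI  # a PyCapsule that ships with the stdlib
print(type(capsule).__name__)     # PyCapsule
deepcopy(capsule)                 # TypeError: cannot pickle 'PyCapsule' object
```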
@blefaudeux

Oh jeez... OK, you can remove triton from your env and that should unblock you short term; I'll have a look in the meantime :)

@jramapuram

For reference, even the SWA implementation in vanilla PyTorch relies on deepcopy: https://github.com/pytorch/pytorch/blob/master/torch/optim/swa_utils.py#L100 -- not being able to EMA costs roughly 1-2% in overall performance 😬
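e.g. this is enough to hit the error (a minimal sketch; swap the `Linear` stand-in for the xFormer from the repro above), since `AveragedModel` deepcopies the wrapped module:

```python
import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(8, 8)     # stand-in for the xFormer built above
swa_model = AveragedModel(model)  # deepcopies `model` internally -> same TypeError with triton 2
```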

@blefaudeux

It should be fine with the attached PR, @jramapuram -- if you see other issues they're easy to fix, and we can augment the unit test. Basically, lazily initializing the triton parts fixes this case.
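To sketch the idea (not the actual PR, just the pattern, assuming the offending attribute is a compiled triton handle held by a module): build the handle on first forward only, and drop it from the copied/pickled state so a copy rebuilds it lazily.

```python
import torch


class FusedLayerSketch(torch.nn.Module):
    """Hypothetical module holding an unpicklable compiled-kernel handle."""

    def __init__(self):
        super().__init__()
        self._kernel = None  # nothing compiled yet, so a fresh module deep-copies fine

    def _build_kernel(self):
        # stand-in for compiling / loading the triton kernel (which holds PyCapsules)
        return object()

    def forward(self, x):
        if self._kernel is None:  # lazy init: the first call builds the kernel
            self._kernel = self._build_kernel()
        return x  # a real layer would launch the kernel here

    def __getstate__(self):
        # deepcopy and pickle both route through __getstate__ via __reduce_ex__;
        # nulling the handle keeps the PyCapsule out of the copied state
        state = self.__dict__.copy()
        state["_kernel"] = None
        return state
```

With that, `deepcopy(model)` works whether or not a forward pass has already happened; a copy taken after compilation simply recompiles on its next call.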

@blefaudeux

sorry for the delay, I should have seen that before

blefaudeux self-assigned this May 24, 2022
blefaudeux added a commit that referenced this issue May 25, 2022: "Tentatively fixing pickling issues, lazy init"
@jramapuram

Awesome @blefaudeux ! Testing now, thanks 🙏

@blefaudeux

> Awesome @blefaudeux ! Testing now, thanks 🙏

Let me know if another part fails this way; I should be able to fix it in a similar fashion. I'm also still on #219, trying to come up with a repro less expensive than full-blown IN. My current lead hypothesis is to provide means to handle various inits out of the box: right now deepnorm sets the distribution to a scaled uniform init, and that's probably not the best for all problems. It can always be done from the outside, but that kind of negates the benefit of having deepnorm out of the box.
