Skip to content

Commit

Permalink
Add window size detection to OmniSR (#1933)
Browse files Browse the repository at this point in the history
Add window size detection to omnisr
  • Loading branch information
joeyballentine authored Jul 11, 2023
1 parent 5aa2060 commit 097d98e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
12 changes: 11 additions & 1 deletion backend/src/nodes/impl/pytorch/architecture/OmniSR/OmniSR.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,17 @@ def __init__(
residual_layer = []
self.res_num = res_num

self.window_size = 8 # we can just assume this for now, but there's probably a way to calculate it (just need to get the sqrt of the right layer)
if (
"residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight"
in state_dict.keys()
):
rel_pos_bias_weight = state_dict[
"residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight"
].shape[0]
self.window_size = int((math.sqrt(rel_pos_bias_weight) + 1) / 2)
else:
self.window_size = 8

self.up_scale = up_scale

for _ in range(res_num):
Expand Down
1 change: 1 addition & 0 deletions backend/src/nodes/properties/outputs/pytorch_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def _get_sizes(value: PyTorchModel) -> List[str]:
elif "OmniSR" in value.model_arch:
return [
f"{value.num_feat}nf",
f"w{value.window_size}",
f"{value.res_num}nr",
]
elif value.model_arch in [
Expand Down

0 comments on commit 097d98e

Please sign in to comment.