Skip to content

Commit

Permalink
Due to cyclic import, changed import statement inside get_model_outpu…
Browse files Browse the repository at this point in the history
…t function.

Signed-off-by: Ahmed Umair <[email protected]>
  • Loading branch information
Umair Ahmed committed Sep 25, 2024
1 parent 5c9248c commit 476b335
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions crossfit/backend/torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
create_nested_list_series_from_3d_ar,
)
from crossfit.utils.torch_utils import cleanup_torch_cache, concat_and_pad_tensors
from crossfit.backend.torch.loader import SortedSeqLoader


class Model:
Expand Down Expand Up @@ -61,6 +60,8 @@ def max_seq_length(self) -> int:
raise NotImplementedError()

def get_model_output(self, all_outputs_ls, index, loader, pred_output_col) -> cudf.DataFrame:
from crossfit.backend.torch.loader import SortedSeqLoader

out = cudf.DataFrame(index=index)
_index = loader.sort_column(index.values) if type(loader) == SortedSeqLoader else index

Expand All @@ -85,7 +86,7 @@ def get_model_output(self, all_outputs_ls, index, loader, pred_output_col) -> cu
elif len(outputs.shape) == 3:
out[pred_output_col] = create_nested_list_series_from_3d_ar(outputs, _index)
else:
raise RuntimeError(f"Unexpected output shape: {output.shape}")
raise RuntimeError(f"Unexpected output shape: {outputs.shape}")
del outputs
del _index
cleanup_torch_cache()
Expand Down

0 comments on commit 476b335

Please sign in to comment.