Skip to content

Commit

Permalink
CI: run examples (#763)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored Jan 17, 2022
1 parent e95c717 commit 0d18bda
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
12 changes: 12 additions & 0 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ jobs:
# testRunTitle: '$(Agent.OS) - $(Build.BuildNumber)[$(Agent.JobName)] - Python $(python.version)'
# condition: succeededOrFailed()

- bash: |
set -e
pip install .
FILES="tm_examples/*.py"
for fn in $FILES
do
echo "Processing $fn example..."
python $fn
done
pip uninstall -y torchmetrics
displayName: 'Examples'
- bash: |
pip install -r requirements/integrate.txt --quiet --upgrade-strategy only-if-needed
pip uninstall -y torchmetrics
Expand Down
9 changes: 4 additions & 5 deletions tm_examples/bert_score-own_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,10 @@ def user_forward_fn(model: torch.nn.Module, batch: Dict[str, torch.Tensor]) -> t
tokenizer = UserTokenizer()
model = get_user_model_encoder()

Scorer = BERTScore(
bs = BERTScore(
model=model, user_tokenizer=tokenizer, user_forward_fn=user_forward_fn, max_length=_MAX_LEN, return_hash=False
)
Scorer.update(_PREDS, _REFS)
print("Predictions")
pprint(Scorer.predictions)
bs.update(_PREDS, _REFS)
print(f"Predictions:\n {bs.preds_input_ids}\n {bs.preds_attention_mask}")

pprint(Scorer.compute())
pprint(bs.compute())

0 comments on commit 0d18bda

Please sign in to comment.