Skip to content

Commit

Permalink
Using generate_model_group_names() API in model server test
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanMullins committed Oct 11, 2024
1 parent d442048 commit 2488aa7
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions lit_nlp/examples/gcp/model_server_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import os
from unittest import mock

from absl.testing import absltest
from lit_nlp.examples.gcp import model_server
from lit_nlp.examples.prompt_debugging import utils as pd_utils
import webtest


class TestWSGIApp(absltest.TestCase):

@mock.patch('lit_nlp.examples.prompt_debugging.models.get_models')
def test_predict_endpoint(self, mock_get_models):
test_model_name = 'lit_on_gcp_test_model'
test_model_config = f'{test_model_name}:test_model_path'
os.environ['MODEL_CONFIG'] = test_model_config

mock_model = mock.MagicMock()
mock_model.predict.side_effect = [[{'response': 'test output text'}]]
Expand All @@ -24,10 +30,12 @@ def test_predict_endpoint(self, mock_get_models):
[{'tokens': ['test', 'output', 'text']}]
]

sal_name, tok_name = pd_utils.generate_model_group_names(test_model_name)

mock_get_models.return_value = {
'gemma_1.1_2b_IT': mock_model,
'_gemma_1.1_2b_IT_salience': salience_model,
'_gemma_1.1_2b_IT_tokenize': tokenize_model,
test_model_name: mock_model,
sal_name: salience_model,
tok_name: tokenize_model,
}
app = webtest.TestApp(model_server.get_wsgi_app())

Expand Down

0 comments on commit 2488aa7

Please sign in to comment.