diff --git a/test/srt/test_mla.py b/test/srt/test_mla.py
index 0e4d64c9aa0..87351c18452 100644
--- a/test/srt/test_mla.py
+++ b/test/srt/test_mla.py
@@ -74,7 +74,7 @@ def setUpClass(cls):
     def tearDownClass(cls):
         kill_process_tree(cls.process.pid)
 
-    def test_mmlu(self):
+    def test_gsm8k(self):
         args = SimpleNamespace(
             num_shots=5,
             data_path=None,
@@ -90,5 +90,64 @@ def test_mmlu(self):
         self.assertGreater(metrics["accuracy"], 0.62)
 
 
+class TestDeepseekV3MTP(unittest.TestCase):
+    """Run few-shot GSM8K against a server launched with NEXTN speculation."""
+
+    @classmethod
+    def setUpClass(cls):
+        """Launch the test server once for the whole class."""
+        cls.model = "lmzheng/sglang-ci-dsv3-test"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        other_args = ["--trust-remote-code"]
+        # Compile and speculative-decoding flags are only added when CUDA is usable.
+        if torch.cuda.is_available() and torch.version.cuda:
+            other_args.extend(
+                [
+                    "--enable-torch-compile",
+                    "--cuda-graph-max-bs",
+                    "2",
+                    "--disable-radix",
+                    "--torch-compile-max-bs",
+                    "1",
+                    "--speculative-algorithm",
+                    "NEXTN",
+                    "--speculative-draft",
+                    "SGLang/sglang-ci-dsv3-test-NextN",
+                    "--speculative-num-steps",
+                    "2",
+                    "--speculative-eagle-topk",
+                    "4",
+                    "--speculative-num-draft-tokens",
+                    "4",
+                ]
+            )
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=other_args,
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        """Kill the server process tree started in setUpClass."""
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        """Run the 5-shot GSM8K eval on 200 questions and print the metrics."""
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        # NOTE(review): metrics are only printed; no accuracy threshold is asserted.
+        print(metrics)
+
+
 if __name__ == "__main__":
     unittest.main()