From 1b9ea13c9b751f60b0a8aea6d14e86842d0c08f4 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Ansari Date: Mon, 25 Mar 2024 12:02:59 +0000 Subject: [PATCH] Update test --- test/test_chronos.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/test_chronos.py b/test/test_chronos.py index ea4f562..0fba658 100644 --- a/test/test_chronos.py +++ b/test/test_chronos.py @@ -186,12 +186,9 @@ def test_pipeline_embed(torch_dtype: str): device_map="cpu", torch_dtype=torch_dtype, ) - model_context_length = pipeline.model.config.context_length - expected_embed_length = model_context_length + ( - 1 if pipeline.model.config.use_eos_token else 0 - ) d_model = pipeline.model.model.config.d_model context = 10 * torch.rand(size=(4, 16)) + 10 + expected_embed_length = 16 + (1 if pipeline.model.config.use_eos_token else 0) # input: tensor of shape (batch_size, context_length)