Fix types, add mypy to workflow (#42)
*Description of changes:* Fix some type checking issues, add mypy to the GitHub workflow, and apply black formatting.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
lostella authored Apr 5, 2024
1 parent 96cedec commit 4b1d1c8
Showing 4 changed files with 48 additions and 25 deletions.
41 changes: 31 additions & 10 deletions .github/workflows/ci.yml
@@ -3,23 +3,44 @@ name: CI
 on: [push, pull_request]
 
 jobs:
+  type-check:
+    strategy:
+      max-parallel: 4
+      fail-fast: false
+      matrix:
+        python-version: ["3.11"]
+        platform: [ubuntu-latest]
+
+    runs-on: ${{ matrix.platform }}
+
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v5
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: pip install ".[typecheck]"
+    - name: Type checks with mypy
+      run: mypy src test
+
   test:
     strategy:
       max-parallel: 4
       fail-fast: false
      matrix:
-        python-version: ['3.11']
+        python-version: ["3.11"]
         platform: [ubuntu-latest]
 
     runs-on: ${{ matrix.platform }}
 
     steps:
-    - uses: actions/checkout@v3
-    - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v4
-      with:
-        python-version: ${{ matrix.python-version }}
-    - name: Install dependencies
-      run: pip install ".[test]"
-    - name: Test with pytest
-      run: pytest
+    - uses: actions/checkout@v4
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v5
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: pip install ".[test]"
+    - name: Test with pytest
+      run: pytest
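
Side note on the new job (not part of the diff): the check it runs, "mypy src test", can also be driven from Python via mypy's programmatic API, which is handy for local scripts. A minimal sketch, assuming the typecheck extra is installed and the script is run from the repository root; the file name check_types.py is hypothetical:

# check_types.py -- hypothetical helper mirroring the CI step "mypy src test".
# Assumes `pip install ".[typecheck]"` has been run from the repository root.
from mypy import api

def main() -> int:
    # mypy.api.run returns (normal report, error report, exit status).
    stdout, stderr, exit_status = api.run(["src", "test"])
    if stdout:
        print(stdout, end="")
    if stderr:
        print(stderr, end="")
    return exit_status

if __name__ == "__main__":
    raise SystemExit(main())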
14 changes: 6 additions & 8 deletions pyproject.toml
@@ -2,18 +2,16 @@
 name = "chronos"
 version = "1.1.0"
 requires-python = ">=3.8"
-license = {file = "LICENSE"}
+license = { file = "LICENSE" }
 dependencies = [
-  "torch~=2.1",  # package was tested on 2.2
-  "transformers~=4.31",
-  "accelerate"
+    "torch~=2.1",  # package was tested on 2.2
+    "transformers~=4.31",
+    "accelerate",
 ]
 
 [project.optional-dependencies]
-test = [
-  "pytest~=8.0",
-  "numpy~=1.21"
-]
+test = ["pytest~=8.0", "numpy~=1.21"]
+typecheck = ["mypy~=1.9"]
 
 [tool.mypy]
 ignore_missing_imports = true
14 changes: 9 additions & 5 deletions src/chronos/chronos.py
@@ -367,9 +367,9 @@ def embed(
             or the length of the longest time series, if a list of 1D tensors was
             provided, and the extra 1 is for EOS.
         """
-        context = self._prepare_and_validate_context(context=context)
+        context_tensor = self._prepare_and_validate_context(context=context)
         token_ids, attention_mask, tokenizer_state = self.tokenizer.input_transform(
-            context
+            context_tensor
         )
         embeddings = self.model.encode(
             input_ids=token_ids.to(self.model.device),
@@ -424,7 +424,7 @@ def predict(
             Tensor of sample forecasts, of shape
             (batch_size, num_samples, prediction_length).
         """
-        context = self._prepare_and_validate_context(context=context)
+        context_tensor = self._prepare_and_validate_context(context=context)
 
         if prediction_length is None:
             prediction_length = self.model.config.prediction_length
@@ -443,7 +443,9 @@
         remaining = prediction_length
 
         while remaining > 0:
-            token_ids, attention_mask, scale = self.tokenizer.input_transform(context)
+            token_ids, attention_mask, scale = self.tokenizer.input_transform(
+                context_tensor
+            )
             samples = self.model(
                 token_ids.to(self.model.device),
                 attention_mask.to(self.model.device),
@@ -463,7 +465,9 @@
             if remaining <= 0:
                 break
 
-            context = torch.cat([context, prediction.median(dim=1).values], dim=-1)
+            context_tensor = torch.cat(
+                [context_tensor, prediction.median(dim=1).values], dim=-1
+            )
 
         return torch.cat(predictions, dim=-1)
 
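A note on the pattern above (my illustration, not part of the diff): embed and predict accept either a single tensor or a list of tensors, and previously rebound the parameter name context to the converted tensor. Binding the converted value to a fresh name, context_tensor, keeps the declared parameter type and the inferred tensor type separate, which is friendlier to mypy's flow analysis. A minimal sketch of the same pattern; the function batch_sum is a hypothetical stand-in, not part of the chronos API:

# Stand-alone illustration of the renaming pattern used in chronos.py.
from typing import List, Union

import torch

def batch_sum(context: Union[torch.Tensor, List[torch.Tensor]]) -> torch.Tensor:
    # Bind the normalized value to a fresh, Tensor-only name instead of
    # rebinding `context`, so every later use has an unambiguous type.
    context_tensor = (
        torch.stack(context) if isinstance(context, list) else context
    )
    return context_tensor.sum(dim=-1)

print(batch_sum([torch.ones(3), torch.zeros(3)]))  # tensor([3., 0.])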
4 changes: 2 additions & 2 deletions test/test_chronos.py
@@ -59,7 +59,7 @@ def test_tokenizer_fixed_data(
 
     samples = tokenizer.output_transform(
         torch.arange(n_special_tokens, n_tokens).unsqueeze(0).repeat(batch_size, 1, 1),
-        decoding_context=scale,
+        tokenizer_state=scale,
     )
 
     assert (samples[:, 0, [0, -1]] == context).all()
@@ -119,7 +119,7 @@ def test_tokenizer_random_data(use_eos_token: bool):
     assert samples.shape == (2, 10, 4)
 
 
-def validate_tensor(samples: torch.Tensor, shape: Tuple[int, int, int]) -> None:
+def validate_tensor(samples: torch.Tensor, shape: Tuple[int, ...]) -> None:
     assert isinstance(samples, torch.Tensor)
     assert samples.shape == shape
 
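For illustration (not part of the diff): widening the annotation from Tuple[int, int, int] to Tuple[int, ...] lets the same helper check expected shapes of any rank, e.g. both 2-D embeddings and 3-D forecast samples. A small self-contained sketch of the widened helper and how a test might call it:

# Self-contained sketch mirroring the updated validate_tensor helper.
from typing import Tuple

import torch

def validate_tensor(samples: torch.Tensor, shape: Tuple[int, ...]) -> None:
    # Tuple[int, ...] means "a tuple of ints of any length", so callers may
    # pass 2-D, 3-D, or higher-rank expected shapes.
    assert isinstance(samples, torch.Tensor)
    assert samples.shape == shape

validate_tensor(torch.zeros(2, 3), (2, 3))           # rank-2 check
validate_tensor(torch.zeros(2, 10, 4), (2, 10, 4))   # rank-3 check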
