Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MLX inference support #41

Merged
merged 28 commits into from
Apr 8, 2024
Merged

Conversation

abdulfatir
Copy link
Contributor

@abdulfatir abdulfatir commented Apr 4, 2024

Issue #, if available: #28

Description of changes: This PR adds MLX inference support.

Summary of changes

  • Update pyproject.toml with mlx dependencies.
  • Create chronos_mlx package which will host all mlx inference code.
    • All classes from main:src/chronos/chronos.py are copy-pasted into mlx:src/chronos_mlx/chronos.py and modified to use numpy and mlx arrays instead. Note that the reason for using numpy arrays as input and output is that mlx doesn't support some operations that are required for input and output transform.
    • MLX implementation of T5 is in src/chronos_mlx/t5.py. It has been adapted from ml-explore/mlx-examples with the following main modifications:
      • Added support for attention mask.
      • Added support for top_k and top_p sampling.
    • src/chronos_mlx/translate.py translates weights from a torch HF model to mlx.
  • Add THIRD-PARTY-LICENSES.txt for third party code from mlx-examples.
  • Add tests and CI for mlx version.

Sample inference code

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from chronos_mlx import ChronosPipeline

# Load the small Chronos-T5 checkpoint with bfloat16 weights for MLX inference.
pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    dtype="bfloat16",
)

# Classic AirPassengers monthly series used as demo input.
passengers_df = pd.read_csv(
    "https://raw.githubusercontent.com/AileenNielsen/TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv"
)

# context must be either a 1D tensor, a list of 1D tensors,
# or a left-padded 2D tensor with batch as the first dimension
history = passengers_df["#Passengers"].values
horizon = 12
# samples has shape [num_series, num_samples, prediction_length]
samples = pipeline.predict(history, horizon)

# Plot the history plus 10%/50%/90% quantiles of the sampled paths.
future_index = range(len(passengers_df), len(passengers_df) + horizon)
low, median, high = np.quantile(samples[0], [0.1, 0.5, 0.9], axis=0)

plt.figure(figsize=(8, 4))
plt.plot(passengers_df["#Passengers"], color="royalblue", label="historical data")
plt.plot(future_index, median, color="tomato", label="median forecast")
plt.fill_between(
    future_index,
    low,
    high,
    color="tomato",
    alpha=0.3,
    label="80% prediction interval",
)
plt.legend()
plt.grid()
plt.show()

Benchmark

benchmark

import timeit

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from gluonts.dataset.repository import get_dataset
from gluonts.dataset.split import split
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model.forecast import SampleForecast
from tqdm.auto import tqdm

from chronos import ChronosPipeline as ChronosPipelineTorch
from chronos_mlx import ChronosPipeline as ChronosPipelineMLX


def benchmark_torch_model(
    pipeline: ChronosPipelineTorch,
    gluonts_dataset: str = "m4_hourly",
    batch_size: int = 32,
):
    """Time batched inference of a torch Chronos pipeline on a GluonTS dataset.

    Returns the ``evaluate_forecasts`` dataframe with MASE and mean weighted
    sum quantile loss, plus an ``inference_time`` column (seconds).
    """
    dataset = get_dataset(gluonts_dataset)
    prediction_length = dataset.metadata.prediction_length
    # Hold out the final `prediction_length` points of each series as labels.
    _, test_template = split(dataset.test, offset=-prediction_length)
    test_data = test_template.generate_instances(prediction_length)
    inputs = list(test_data.input)

    start = timeit.default_timer()
    batch_outputs = []
    for lo in tqdm(range(0, len(inputs), batch_size)):
        chunk = inputs[lo : lo + batch_size]
        tensors = [torch.tensor(entry["target"]) for entry in chunk]
        batch_outputs.append(pipeline.predict(tensors, prediction_length))
    forecasts = torch.cat(batch_outputs)
    elapsed = timeit.default_timer() - start

    print(f"Inference time: {elapsed:.2f}s")

    results_df = evaluate_forecasts(
        forecasts=[
            SampleForecast(fcst.numpy(), start_date=label["start"])
            for fcst, label in zip(forecasts, test_data.label)
        ],
        test_data=test_data,
        metrics=[MASE(), MeanWeightedSumQuantileLoss(np.arange(0.1, 1, 0.1))],
    )
    results_df["inference_time"] = elapsed
    return results_df


def benchmark_mlx_model(
    pipeline: ChronosPipelineMLX,
    gluonts_dataset: str = "m4_hourly",
    batch_size: int = 32,
):
    """Time batched inference of an MLX Chronos pipeline on a GluonTS dataset.

    Mirrors ``benchmark_torch_model`` but feeds raw numpy targets, since the
    MLX pipeline consumes and returns numpy arrays. Returns the
    ``evaluate_forecasts`` dataframe with an extra ``inference_time`` column.
    """
    dataset = get_dataset(gluonts_dataset)
    prediction_length = dataset.metadata.prediction_length
    # Hold out the final `prediction_length` points of each series as labels.
    _, test_template = split(dataset.test, offset=-prediction_length)
    test_data = test_template.generate_instances(prediction_length)
    inputs = list(test_data.input)

    start = timeit.default_timer()
    batch_outputs = []
    for lo in tqdm(range(0, len(inputs), batch_size)):
        chunk = [entry["target"] for entry in inputs[lo : lo + batch_size]]
        batch_outputs.append(pipeline.predict(chunk, prediction_length))
    forecasts = np.concatenate(batch_outputs)
    elapsed = timeit.default_timer() - start

    print(f"Inference time: {elapsed:.2f}s")

    results_df = evaluate_forecasts(
        forecasts=[
            SampleForecast(fcst, start_date=label["start"])
            for fcst, label in zip(forecasts, test_data.label)
        ],
        test_data=test_data,
        metrics=[MASE(), MeanWeightedSumQuantileLoss(np.arange(0.1, 1, 0.1))],
    )
    results_df["inference_time"] = elapsed
    return results_df


def main(
    version: str = "cpu",  # cpu, mps, mlx
    dtype: str = "bfloat16",
    gluonts_dataset: str = "australian_electricity_demand",
    model_name: str = "amazon/chronos-t5-small",
    batch_size: int = 4,
):
    """Run one benchmark configuration and return its results dataframe.

    ``version`` selects the backend: "cpu" and "mps" use the torch pipeline
    (with ``version`` as the device map); anything else uses the MLX pipeline.
    """
    use_torch = version in ("cpu", "mps")
    if use_torch:
        pipeline = ChronosPipelineTorch.from_pretrained(
            model_name,
            device_map=version,
            torch_dtype=getattr(torch, dtype),
        )
    else:
        pipeline = ChronosPipelineMLX.from_pretrained(model_name, dtype=dtype)
    benchmark_fn = benchmark_torch_model if use_torch else benchmark_mlx_model

    result_df = benchmark_fn(
        pipeline, gluonts_dataset=gluonts_dataset, batch_size=batch_size
    )
    result_df["model"] = model_name
    return result_df


if __name__ == "__main__":
    gluonts_dataset: str = "m4_hourly"
    model_name: str = "amazon/chronos-t5-mini"
    batch_size: int = 8
    dfs = []
    for version in ["cpu", "mps", "mlx"]:
        for dtype in ["float32"]:
            try:
                df = main(
                    version=version,
                    dtype=dtype,
                    model_name=model_name,
                    gluonts_dataset=gluonts_dataset,
                    batch_size=batch_size,
                )
                df["version"] = version
                df["dtype"] = dtype
                dfs.append(df)
            except TypeError as e:
                # Some backend/dtype combinations are unsupported; keep the
                # sweep going but report the skip instead of swallowing it
                # silently.
                print(f"Skipping version={version}, dtype={dtype}: {e}")

    # Fail with a clear message rather than an opaque pd.concat ValueError
    # when every configuration was skipped.
    if not dfs:
        raise SystemExit("All benchmark configurations failed; nothing to plot.")

    result_df = pd.concat(dfs).reset_index(drop=True)
    result_df.to_csv("benchmark.csv", index=False)

    # Human-readable backend labels for the plot legend.
    result_df["version"] = result_df["version"].map(
        {"cpu": "Torch (CPU)", "mps": "Torch (MPS)", "mlx": "MLX"}
    )
    plt.figure(figsize=(8, 5))
    sns.barplot(
        data=result_df,
        x="dtype",
        y="inference_time",
        hue="version",
        alpha=0.6,
    )
    plt.ylabel("Inference Time (on M1 Pro)")
    plt.title(f"{model_name} inference times on {gluonts_dataset} dataset")
    plt.savefig("benchmark.png", dpi=200)

TODOs:

  • Implement top_p sampling.
  • Add tests.
  • Add CI.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@abdulfatir abdulfatir marked this pull request as draft April 4, 2024 19:38
@abdulfatir abdulfatir added the enhancement New feature or request label Apr 5, 2024
@abdulfatir abdulfatir marked this pull request as ready for review April 5, 2024 15:22
@abdulfatir abdulfatir changed the base branch from main to mlx April 5, 2024 16:26
.github/workflows/ci.yml Outdated Show resolved Hide resolved
Copy link
Contributor

@lostella lostella left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, thanks @abdulfatir!

@abdulfatir abdulfatir merged commit 159ea36 into amazon-science:mlx Apr 8, 2024
1 check passed
@abdulfatir abdulfatir deleted the add-mlx-support branch April 8, 2024 13:03
@abdulfatir abdulfatir mentioned this pull request Apr 8, 2024
abdulfatir added a commit that referenced this pull request Apr 8, 2024
*Issue #, if available:* #28 (also, PR #41)

*Description of changes:* This PR updates the README with information on
MLX support.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants