Merge branch 'huggingface:main' into varlen
jiqing-feng authored Jan 13, 2025
2 parents c1bd7f7 + 190ae87 commit fb71c2e
Showing 18 changed files with 655 additions and 99 deletions.
23 changes: 20 additions & 3 deletions docs/source/openvino/models.mdx
@@ -16,6 +16,7 @@ Here is the list of the supported architectures :

- Albert
- Aquila
- Aquila 2
- Arctic
- Audio Spectrogram Transformer
- Baichuan 2
@@ -28,7 +29,7 @@ Here is the list of the supported architectures :
- Bloom
- CLIP
- Camembert
- ChatGLM
- ChatGLM (ChatGLM2, ChatGLM3, GLM4)
- CodeGen
- CodeGen2
- Cohere
@@ -49,6 +50,7 @@ Here is the list of the supported architectures :
- Falcon
- Flaubert
- GLM-4
- GLM-Edge
- GPT-2
- GPT-BigCode
- GPT-J
@@ -57,12 +59,18 @@ Here is the list of the supported architectures :
- GPT-NeoX-Japanese
- Gemma
- Gemma2
- Granite
- GraniteMoE
- Hubert
- IBert
- InternLM
- InternLM2
- InternVL2
- Jais
- Levit
- Llama
- Llava
- Llava-Next
- M2-M100
- MBart
- MPNet
@@ -71,6 +79,7 @@ Here is the list of the supported architectures :
- Marian
- MiniCPM
- MiniCPM3
- MiniCPMV
- Mistral
- Mixtral
- MobileBert
@@ -86,10 +95,13 @@ Here is the list of the supported architectures :
- Persimmon
- Phi
- Phi3
- Phi3Vision
- Pix2Struct
- PoolFormer
- Qwen
- Qwen2(Qwen1.5)
- Qwen2(Qwen1.5, Qwen2.5)
- Qwen2MoE
- Qwen2VL
- ResNet
- Roberta
- Roformer
@@ -119,10 +131,15 @@ Here is the list of the supported architectures :
- Stable Diffusion
- Stable Diffusion XL
- Latent Consistency
- Stable Diffusion 3
- Flux

## [Timm](https://huggingface.co/docs/timm/index)
- PiT
- ViT

## [Sentence Transformers](https://github.com/UKPLab/sentence-transformers)
- All Transformer and CLIP-based models.

## [OpenCLIP](https://github.com/mlfoundations/open_clip)
- All CLIP-based models
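
Any architecture in this list can generally be exported by loading it through the matching OV model class with `export=True`; a minimal sketch (the checkpoint ID below is illustrative only, substitute any supported model):

```python
from optimum.intel import OVModelForCausalLM

# Example only: any text-generation architecture listed above should export the same way.
model = OVModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", export=True)
model.save_pretrained("qwen2_5_ov")
```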
41 changes: 21 additions & 20 deletions optimum/exporters/ipex/cache_utils.py
@@ -49,7 +49,7 @@ def __init__(
self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
batch_size, -1
)
self.free_blocks = torch.arange(self.num_blocks, device=device)
self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=device)
self.max_cache_len = max_cache_len
self.num_kv_heads = config.num_key_value_heads
self.num_hidden_layers = config.num_hidden_layers
@@ -88,12 +88,10 @@ def update_for_prefill(
all_slot_offsets = []
num_blocks = (input_lens + self.block_size - 1) // self.block_size
for i in range(batch_size):
for b_idx in range(num_blocks[i]):
if self.block_tables[i][b_idx] == -1:
# need a free block
self.block_tables[i][b_idx] = self.free_blocks[0]
self.free_blocks = self.free_blocks[1:]

nb = num_blocks[i]
block_table = self.free_blocks.nonzero().view(-1)[0:nb]
self.block_tables[i][0:nb] = block_table
self.free_blocks[block_table] = 0
slots_range = torch.arange(input_lens[i], device=key_states.device)
block_indices = slots_range // self.block_size
slot_offsets = slots_range % self.block_size
@@ -103,7 +101,6 @@
all_block_indices = torch.cat(all_block_indices)
all_slot_offsets = torch.cat(all_slot_offsets)
self.slots = all_block_indices * self.block_size + all_slot_offsets

# Update the cache
PagedAttention.reshape_and_cache(
key_states,
@@ -127,16 +124,16 @@ def update_for_decode(
):
if layer_idx == 0:
start_block_idx = self._seen_tokens // self.block_size
num_blocks = (self._seen_tokens + self.block_size) // self.block_size
slot_offset_in_block = (self._seen_tokens) % self.block_size
self.slots = torch.zeros([batch_size], device=key_states.device, dtype=torch.int32)
for i in range(batch_size):
for b_idx in range(start_block_idx[i], num_blocks[i]):
if slot_offset_in_block[i] == 0:
# need a new block:
b_idx = start_block_idx[i]
if self.block_tables[i][b_idx] == -1:
# need a free block
self.block_tables[i][b_idx] = self.free_blocks[0]
self.free_blocks = self.free_blocks[1:]

self.block_tables[i][b_idx] = self.free_blocks.nonzero().view(-1)[0:1]
self.free_blocks[self.block_tables[i][b_idx]] = 0
self.slots[i] = self.block_tables[i][start_block_idx[i]] * self.block_size + slot_offset_in_block[i]
# Update the cache
PagedAttention.reshape_and_cache(
@@ -196,7 +193,7 @@ def reset(self):
"""Resets the cache values while preserving the objects"""
self._seen_tokens = torch.zeros([self.batch_size], dtype=torch.int32, device=self.block_tables.device)
self.block_tables.fill_(-1)
self.free_blocks = torch.arange(self.num_blocks, device=self.block_tables.device)
self.free_blocks = torch.ones([self.num_blocks], dtype=torch.int32, device=self.block_tables.device)
self.max_seq_len = 0

def reorder_cache(self, beam_idx: torch.LongTensor):
@@ -206,16 +203,18 @@ def reorder_cache(self, beam_idx: torch.LongTensor):
updated_block_tables = self.block_tables.index_select(0, beam_idx.to(device))
mask = self.block_tables.masked_fill(self.block_tables != -1, 1).masked_fill(self.block_tables == -1, 0)
num_blocks = mask.cumsum(-1)[:, -1]
updated_table = []
updated_table = torch.zeros_like(beam_idx)
for i in range(beam_idx.shape[0]):
self.block_tables[i, 0 : num_blocks[i] - 1] = updated_block_tables[i, 0 : num_blocks[i] - 1]
updated_table.append(self.block_tables[i : i + 1, num_blocks[i] - 1 : num_blocks[i]])
updated_table = torch.cat(tuple(updated_table), dim=0)
nb = num_blocks[i]
self.block_tables[i, 0 : nb - 1] = updated_block_tables[i, 0 : nb - 1]
updated_table[i] = self.block_tables[i][nb - 1]
for layer_idx in range(self.num_hidden_layers):
self.key_cache[layer_idx][updated_table] = self.key_cache[layer_idx][updated_table[beam_idx]]
self.value_cache[layer_idx][updated_table] = self.value_cache[layer_idx][updated_table[beam_idx]]
free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
self.free_blocks = torch.cat((self.free_blocks, free_table))
for i in free_table:
if not (self.block_tables == i).any():
self.free_blocks[i] = 1

def crop(self, maximum_length: int):
"""Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
@@ -235,4 +234,6 @@ def crop(self, maximum_length: int):
self._seen_tokens[bs] = new_tokens
self.max_seq_len, _ = self._seen_tokens.max(dim=0)
free_table = torch.unique((origin_table[origin_table != self.block_tables]).view(-1))
self.free_blocks = torch.cat((self.free_blocks, free_table))
for i in free_table:
if not (self.block_tables == i).any():
self.free_blocks[i] = 1
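
The cache changes above replace the old free-block queue (a shrinking `torch.arange` tensor, sliced on every allocation and re-concatenated on every release) with a fixed-size 0/1 mask: blocks are claimed via `nonzero()` and released by writing `1` back at their index, after checking that no sequence still references them. A minimal standalone sketch of that scheme, with illustrative names rather than the repository's API:

```python
import torch

# free_blocks is a fixed-size mask: 1 = free, 0 = in use.
num_blocks = 8
free_blocks = torch.ones([num_blocks], dtype=torch.int32)

def allocate(n: int) -> torch.Tensor:
    """Claim the first n free blocks by index and mark them used."""
    chosen = free_blocks.nonzero().view(-1)[:n]
    free_blocks[chosen] = 0  # in-place write; no tensor reallocation
    return chosen

def release(block_ids: torch.Tensor) -> None:
    """Return blocks to the pool. The real cache first verifies that no
    row of block_tables still references the block before freeing it."""
    free_blocks[block_ids] = 1

table = allocate(3)   # e.g. tensor([0, 1, 2])
release(table[1:])    # blocks 1 and 2 become reusable; block 0 stays claimed
```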
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/__main__.py
@@ -49,7 +49,7 @@
)


FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager"}
FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager", "gemma2": "sdpa"}

if TYPE_CHECKING:
from optimum.intel.openvino.configuration import OVConfig
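
For context, `FORCE_ATTN_MODEL_CLASSES` pins the attention implementation used when a model is loaded for export; this commit adds `gemma2` with `"sdpa"` alongside the existing `phi3-v` entry. A hedged sketch of how such a mapping is typically consumed (the exporter's actual call site may differ):

```python
from transformers import AutoModelForCausalLM

FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager", "gemma2": "sdpa"}

model_type = "gemma2"  # normally read from the model's config
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",  # example checkpoint
    # None falls through to transformers' default implementation selection
    attn_implementation=FORCE_ATTN_MODEL_CLASSES.get(model_type),
)
```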
42 changes: 29 additions & 13 deletions optimum/exporters/openvino/convert.py
@@ -43,6 +43,7 @@
_torch_version,
_transformers_version,
compare_versions,
is_diffusers_version,
is_openvino_tokenizers_version,
is_tokenizers_version,
is_transformers_version,
@@ -663,6 +664,9 @@ def export_from_model(
# Get the shapes to be used to generate dummy inputs
input_shapes = {}
for input_name in DEFAULT_DUMMY_SHAPES.keys():
if input_name in ["height", "width"]:
# use H and W from generator defaults
continue
input_shapes[input_name] = (
kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name]
)
@@ -988,24 +992,36 @@ def _get_submodels_and_export_configs(
def get_diffusion_models_for_export_ext(
pipeline: "DiffusionPipeline", int_dtype: str = "int64", float_dtype: str = "fp32", exporter: str = "openvino"
):
try:
from diffusers import (
StableDiffusion3Img2ImgPipeline,
StableDiffusion3InpaintPipeline,
StableDiffusion3Pipeline,
)
if is_diffusers_version(">=", "0.29.0"):
from diffusers import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline

is_sd3 = isinstance(
pipeline, (StableDiffusion3Pipeline, StableDiffusion3InpaintPipeline, StableDiffusion3Img2ImgPipeline)
)
except ImportError:
sd3_pipes = [StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline]
if is_diffusers_version(">=", "0.30.0"):
from diffusers import StableDiffusion3InpaintPipeline

sd3_pipes.append(StableDiffusion3InpaintPipeline)

is_sd3 = isinstance(pipeline, tuple(sd3_pipes))
else:
is_sd3 = False

try:
if is_diffusers_version(">=", "0.30.0"):
from diffusers import FluxPipeline

is_flux = isinstance(pipeline, FluxPipeline)
except ImportError:
flux_pipes = [FluxPipeline]

if is_diffusers_version(">=", "0.31.0"):
from diffusers import FluxImg2ImgPipeline, FluxInpaintPipeline

flux_pipes.extend([FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline])

if is_diffusers_version(">=", "0.32.0"):
from diffusers import FluxFillPipeline

flux_pipes.append(FluxFillPipeline)

is_flux = isinstance(pipeline, tuple(flux_pipes))
else:
is_flux = False

if not is_sd3 and not is_flux:
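
The convert.py changes swap try/except ImportError probing for explicit version gates, so each pipeline class is imported only from the diffusers release that actually ships it (FluxPipeline in 0.30.0, the img2img/inpaint variants in 0.31.0, FluxFillPipeline in 0.32.0). A condensed sketch of the pattern, assuming `is_diffusers_version` is importable from `optimum.intel.utils.import_utils` as the amended import block suggests:

```python
from optimum.intel.utils.import_utils import is_diffusers_version

flux_pipes = []
if is_diffusers_version(">=", "0.30.0"):
    from diffusers import FluxPipeline

    flux_pipes.append(FluxPipeline)
if is_diffusers_version(">=", "0.32.0"):
    from diffusers import FluxFillPipeline

    flux_pipes.append(FluxFillPipeline)

def is_flux(pipeline) -> bool:
    # An empty tuple matches nothing, so older diffusers installs return False.
    return isinstance(pipeline, tuple(flux_pipes))
```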
