diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/config/mixed_precision_config/exceptions.json b/microsoft-Phi-4-mini-instruct/QAIRT/config/mixed_precision_config/exceptions.json
new file mode 100644
index 000000000..0cb0ed1ce
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/config/mixed_precision_config/exceptions.json
@@ -0,0 +1,108 @@
+{
+ "module_list":
+ [
+ {
+ "module_name": "QuantizedRmsNorm",
+ "exceptions": {
+ "param_exceptions": {
+ "asymmetric": true,
+ "bitwidth": 16
+ },
+ "input_exceptions": null,
+ "output_exceptions": null
+ }
+ }
+ ],
+ "name_list":[
+ {
+ "module_name": "\\w*model_embed_tokens_Gather",
+ "exceptions": {
+ "param_exceptions": {
+ "bitwidth": 16,
+ "asymmetric": true
+ },
+ "input_exceptions": null,
+ "output_exceptions": null
+ }
+ },
+ {
+ "module_name": "\\w*lm_head_(MatMul|conv_Conv|conv2d_Conv|Conv)",
+ "exceptions": {
+ "param_exceptions": {
+ "bitwidth": 16
+ },
+ "input_exceptions": null,
+ "output_exceptions": null
+ }
+ },
+ {
+ "module_name": "\\w*norm_(Mul_1|Mul_1.module)",
+ "exceptions": {
+ "param_exceptions": null,
+ "input_exceptions": [
+ {
+ "input_index": 0,
+ "bitwidth": 16,
+ "asymmetric": true
+ }
+ ],
+ "output_exceptions": null
+ }
+ },
+ {
+ "module_name": "\\w*norm_(Pow|Pow.module|ReduceMean|Add|Sqrt|Div|Mul)",
+ "exceptions": {
+ "param_exceptions": null,
+ "input_exceptions": null,
+ "output_exceptions": [
+ {
+ "output_index": 0,
+ "enabled": false
+ }
+ ]
+ }
+ },
+ {
+ "module_name": "\\w*self_attn_Concat_1",
+ "exceptions": {
+ "param_exceptions": null,
+ "input_exceptions": null,
+ "output_exceptions": [
+ {
+ "output_index": 0,
+ "bitwidth": 16,
+ "asymmetric": false
+ }
+ ]
+ }
+ },
+ {
+ "module_name": "\\w*self_attn_Concat_4",
+ "exceptions": {
+ "param_exceptions": null,
+ "input_exceptions": null,
+ "output_exceptions": [
+ {
+ "output_index": 0,
+ "bitwidth": 16,
+ "asymmetric": false
+ }
+ ]
+ }
+ },
+ {
+ "module_name": "\\w*v_proj_(MatMul|conv_Conv|conv2d_Conv|Conv)(\\.base_layer)?",
+ "exceptions": {
+ "param_exceptions": null,
+ "input_exceptions": null,
+ "output_exceptions": [
+ {
+ "output_index": 0,
+ "bitwidth": 16,
+ "asymmetric": false
+ }
+ ]
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/concat_adapter.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/concat_adapter.py
new file mode 100644
index 000000000..9b2b933b8
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/concat_adapter.py
@@ -0,0 +1,217 @@
+import operator
+import collections
+from typing import Dict, Sequence
+from functools import reduce
+from dataclasses import replace
+
+import torch
+from peft.tuners.lora import LoraLayer
+from peft import get_peft_model, PeftModel, PeftMixedModel
+
+
+def concat_adapter(
+ model,
+ adapters: Sequence[str],
+ weights: Sequence[int],
+ keep_rank: bool = True,
+ fold_scaling: bool = False,
+ delete_adapters: bool = True,
+ concat_adapter_name: str = "concat_set"
+):
+ """
+ Concatenate adapters in the given model. This is a variant implementation of 'cat' mode of the below API:
+ https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/model.py
+ It differs in that the ranks of concatenated adapters will be customized by each LoRA layer instead of same ranks
+ across all layers.
+ :param model: A model attached with LoRA adapters
+ :param adapters: Adapters attached to the model
+ :param weights: Weights applied to each the corresponding adapters list, must be consistent with length
+ :param keep_rank: Whether to keep ranks across layers. When set to False, no zero paddings are applied
+ :param fold_scaling: Whether to fold scaling into LoRA weights. When set to False, the LoRA weights will not
+ be scaled by lora_scaling
+ :param delete_adapters: Whether to delete adapters. When set to True, all adapters in the adapters list are deleted
+ :param concat_adapter_name: Name of the concatenated adapter. Default is "concat_set"
+ :return: None. Adapters are concatenated in place and concatenated adapter will be set by default
+ """
+ for adapter in adapters:
+ if adapter not in model.peft_config.keys():
+ assert ValueError(f'Cannot find adapter adapter "{adapter}" in the given model!')
+
+ if keep_rank:
+ if not isinstance(model, (PeftModel, PeftMixedModel)):
+ # dummy_config created only to get a peft model instance
+ dummy_config = replace(
+ model.peft_config[adapters[0]],
+ r = 1,
+ lora_alpha=1,
+ target_modules=[list(model.peft_config[adapters[0]].target_modules)[0]]
+ )
+ model = get_peft_model(model, dummy_config, adapter_name='dummy')
+ model.delete_adapter('dummy')
+
+ # PEFT API to concat adapters
+ if not fold_scaling:
+ original_scalings = set_scaling(model, adapters, 1.0)
+ model.add_weighted_adapter(
+ adapters=adapters,
+ weights=weights,
+ adapter_name=concat_adapter_name,
+ combination_type='cat'
+ )
+
+ # Set lora scaling for targeted adapters' LoRA layers
+ for layer_name, target in model.named_modules():
+ if isinstance(target, LoraLayer):
+ # NOTE: there might be risk of directly writing into lora_scaling
+ loras_scaling = []
+
+ for adapter in adapters:
+ if adapter in target.lora_A:
+ current_scaling = torch.full((target.r[adapter],), original_scalings[adapter][layer_name])
+ elif adapter in target.lora_embedding_A:
+ raise RuntimeError(
+ f'We only support LoRA weights for now but LoRA embeddings are found in {adapter} instead!'
+ )
+ else:
+ continue
+
+ loras_scaling.append(current_scaling)
+
+ if len(loras_scaling) == 0:
+ raise ValueError(
+ f'No matching LoRAs found for {layer_name}, please check your adapter configuration!')
+
+ target.scaling[concat_adapter_name] = torch.zeros(
+ model.peft_config[concat_adapter_name].r, dtype=target.weight.dtype, device=target.weight.device
+ )
+ loras_scaling = torch.cat(loras_scaling)
+ target.scaling[concat_adapter_name][: loras_scaling.shape[0]].copy_(loras_scaling)
+ else:
+ concat_rank = sum([model.peft_config[adapter].r for adapter in adapters])
+ concat_modules = reduce(operator.or_, [model.peft_config[adapter].target_modules for adapter in adapters])
+ # We use the config to initialize an adapter but ranks/scalings will change for some target_modules
+ # So this config is never meant to be used or accessed external to this function
+ concat_config = replace(
+ model.peft_config[adapters[0]],
+ r=concat_rank,
+ lora_alpha=concat_rank,
+ target_modules=concat_modules
+ )
+
+ model = get_peft_model(model, concat_config, adapter_name=concat_adapter_name)
+
+ for layer_name, target in model.named_modules():
+ if isinstance(target, LoraLayer):
+ if concat_adapter_name in target.lora_A:
+ concat_lora_A = target.lora_A[concat_adapter_name]
+ concat_lora_B = target.lora_B[concat_adapter_name]
+ elif concat_adapter_name in target.lora_embedding_A:
+ raise RuntimeError(
+ f'We only support LoRA weights for now but LoRA embeddings are found in {concat_adapter_name} instead!'
+ )
+ else:
+ continue
+
+ # NOTE: there might be risk of directly writing into lora_scaling
+ loras_A, loras_B, loras_scaling = [], [], []
+ for adapter in adapters:
+ if adapter in target.lora_A:
+ current_adapter_lora_A = target.lora_A[adapter].weight
+ current_adapter_lora_B = target.lora_B[adapter].weight
+ elif adapter in target.lora_embedding_A:
+ raise RuntimeError(
+ f'We only support LoRA weights for now but LoRA embeddings are found in {adapter} instead!'
+ )
+ else:
+ continue
+
+ loras_A.append(current_adapter_lora_A.data)
+ loras_B.append(current_adapter_lora_B.data)
+ loras_scaling.append(torch.full((target.r[adapter], ), target.scaling[adapter]))
+
+ if len(loras_A) == 0:
+ raise ValueError(f'No matching LoRAs found for {layer_name}, please check your adapter configuration!')
+
+ loras_A = torch.cat(loras_A, dim=0)
+ loras_B = torch.cat(loras_B, dim=1)
+
+ # This might be risky since we're replacing the whole parameter
+ if isinstance(concat_lora_A, torch.nn.Conv2d):
+ target.lora_A[concat_adapter_name] = torch.nn.Conv2d(
+ loras_A.shape[1],
+ loras_A.shape[0],
+ kernel_size=concat_lora_A.kernel_size,
+ stride=concat_lora_A.stride,
+ padding=concat_lora_A.padding,
+ bias=concat_lora_A.bias is not None,
+ device=concat_lora_A.weight.device,
+ dtype=concat_lora_A.weight.dtype,
+ )
+ target.lora_B[concat_adapter_name] = torch.nn.Conv2d(
+ loras_B.shape[1],
+ loras_B.shape[0],
+ kernel_size=concat_lora_B.kernel_size,
+ stride=concat_lora_B.stride,
+ padding=concat_lora_B.padding,
+ bias=concat_lora_B.bias is not None,
+ device=concat_lora_B.weight.device,
+ dtype=concat_lora_B.weight.dtype,
+ )
+ elif isinstance(concat_lora_A, torch.nn.Linear):
+ target.lora_A[concat_adapter_name] = torch.nn.Linear(
+ loras_A.shape[1],
+ loras_A.shape[0],
+ bias=concat_lora_A.bias is not None,
+ device=concat_lora_A.weight.device,
+ dtype=concat_lora_A.weight.dtype,
+ )
+ target.lora_B[concat_adapter_name] = torch.nn.Linear(
+ loras_B.shape[1],
+ loras_B.shape[0],
+ bias=concat_lora_B.bias is not None,
+ device=concat_lora_B.weight.device,
+ dtype=concat_lora_B.weight.dtype,
+ )
+
+ target.lora_A[concat_adapter_name].weight.data.copy_(loras_A)
+ target.lora_B[concat_adapter_name].weight.data.copy_(loras_B)
+ target.scaling[concat_adapter_name] = torch.cat(loras_scaling).to(model.device)
+
+ model.set_adapter(concat_adapter_name)
+ if delete_adapters:
+ for adapter in adapters:
+ model.delete_adapter(adapter)
+
+def set_scaling(
+ model,
+ adapters: Sequence[str],
+ scaling: float = 1.
+) -> Dict[str, Dict[str, float]]:
+ """
+ Set scaling of LoRA layers for adapters to scaling
+ :param model: The LoRA model to be rescaled
+ :param adapters: Adapters of which to rescale
+ :param scaling: The scaling factor to be applied to all target adapters
+ :return: The original scaling for each adapter's target modules
+ """
+ original_scalings = collections.defaultdict(dict)
+
+ for layer_name, target in model.named_modules():
+ if isinstance(target, LoraLayer):
+ # NOTE: there might be risk of directly writing into lora_scaling
+ for adapter in adapters:
+ if adapter in target.scaling:
+ original_scalings[adapter][layer_name] = target.scaling[adapter]
+ target.scaling[adapter] = scaling
+ elif adapter in target.lora_embedding_A:
+ raise RuntimeError(
+ f'We only support LoRA weights for now but LoRA embeddings are found in {adapter} instead!'
+ )
+ else:
+ continue
+
+ for adapter in adapters:
+ if not original_scalings[adapter]:
+ raise RuntimeError(f'No matching LoRAs found for {adapter}, please check your adapter configuration!')
+
+ return original_scalings
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/debug/profiler.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/debug/profiler.py
new file mode 100644
index 000000000..09619d9c4
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/debug/profiler.py
@@ -0,0 +1,350 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+""" common utilities and class implementation for time/GPU/RAM profiling """
+import argparse
+import contextlib
+import gc
+import io
+import json
+import os
+import time
+from dataclasses import dataclass
+from datetime import timedelta
+from multiprocessing import Process, RLock, Value
+from typing import Dict, List, Optional, Union
+
+import psutil
+import torch
+from aimet_common.utils import AimetLogger
+from torch.types import Device
+
+logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils)
+
+WATERMARK_THREAD_POLLING_INTERVAL_IN_MS = 100
+
+
+def convert_bytes(size: int):
+ """
+ :return bytes in human-readable format
+ """
+ sign = ''
+ if size < 0:
+ sign = '-'
+ size = abs(size)
+
+ for x in [' B', 'KB', 'MB', 'GB', 'TB']:
+ if size < 1024.0:
+ break
+ size /= 1024.0
+ return "%s%3.1f %s" % (sign, size, x)
+
+
+@dataclass
+class ProfileMarker:
+ """ Implements methods to capture profiling data """
+ event: str
+ device: Optional[Union[Device, int]]
+ gpu_memory_usage: int
+ cpu_memory_usage: int
+ delta_gpu_memory_usage: int
+ delta_cpu_memory_usage: int
+ time: float
+
+ def __str__(self):
+ """ string representation of the marker event """
+ return f'Event {self.event} : time={timedelta(seconds=self.time)}, ' \
+ f'GPU={convert_bytes(self.delta_gpu_memory_usage)}(+ {convert_bytes(self.gpu_memory_usage)}), ' \
+ f'RAM={convert_bytes(self.delta_cpu_memory_usage)}(+ {convert_bytes(self.cpu_memory_usage)})'
+
+ def to_dict(self) -> dict:
+ """return data in dict """
+ return {
+ 'event': self.event,
+ 'device': self.device,
+ 'gpu_memory_usage': self.gpu_memory_usage,
+ 'cpu_memory_usage': self.cpu_memory_usage,
+ 'delta_gpu_memory_usage': self.delta_gpu_memory_usage,
+ 'delta_cpu_memory_usage': self.delta_cpu_memory_usage,
+ 'time': self.time
+ }
+
+
+def ram_usage(pid):
+ """ :return RAM usage for given process Id. """
+ return psutil.Process(pid).memory_info().rss
+
+
+@dataclass
+class EventMarker:
+ """ Implements methods to capture system stats during an event """
+ event: str
+ device: Optional[Union[Device, int]]
+ _gpu_memory_usage: int
+ _cpu_memory_usage: int
+ time = time.time()
+
+ def __init__(self, event: str = None, cpu_memory_usage: int = 0, device: Union[Device, int] = None):
+ self._gpu_memory_usage = torch.cuda.max_memory_allocated(device)
+ self._cpu_memory_usage: int = cpu_memory_usage
+ self.device = device
+ self.time = time.time()
+ if event is None:
+ self.event = f'@ {self.time}'
+ else:
+ self.event = event
+
+ def delta(self, event: str, start_marker: 'EventMarker') -> ProfileMarker:
+ """computes diff between two event marker and returns profile marker"""
+ # pylint: disable=protected-access
+ return ProfileMarker(
+ event,
+ self.device,
+ start_marker._gpu_memory_usage,
+ start_marker._cpu_memory_usage,
+ self._gpu_memory_usage - start_marker._gpu_memory_usage,
+ self._cpu_memory_usage - start_marker._cpu_memory_usage,
+ int(self.time - start_marker.time)
+ )
+
+ def __str__(self):
+ """ string representation of the marker event """
+ return f'Event {self.event} : time={self.time}, GPU={self.gpu_memory_usage}, RAM={self.cpu_memory_usage}'
+
+ @property
+ def gpu_memory_usage(self) -> str:
+ """returns a string representation of GPU usage """
+ device_str = 'default' if self.device is None \
+ else f'cuda:{self.device}' if isinstance(self.device, int) \
+ else str(self.device)
+
+ return f'{device_str}:{convert_bytes(self._gpu_memory_usage)}'
+
+ @property
+ def cpu_memory_usage(self) -> str:
+ """returns a string representation of RAM usage """
+ return convert_bytes(self._cpu_memory_usage)
+
+ def to_dict(self) -> dict:
+ """return data in dict """
+ return {
+ 'event': self.event,
+ 'device': self.device,
+ 'gpu_memory_usage': self._gpu_memory_usage,
+ 'cpu_memory_usage': self._cpu_memory_usage,
+ 'time': self.time
+ }
+
+
+def ram_watermark_function(ram_allocated: Value, pid: int, polling_interval_in_ms: float):
+ """
+ observing process to reflect current peak RAM usage.
+ :param ram_allocated: shared variable used by profiler to reset to new allocation and the observing(this) process
+ to track max allocation.
+ :param pid: parent process pid for tracking mem allocation
+ :param polling_interval_in_ms: interval between polling for memory usage
+ """
+ logger.info('Created RAM watermark daemon process(pid=%d) for pid=%d, polling at %.1f ms',
+ os.getpid(), pid, polling_interval_in_ms)
+ while psutil.pid_exists(pid):
+ new_usage = ram_usage(pid)
+ with ram_allocated.get_lock():
+ ram_allocated.value = max(new_usage, ram_allocated.value)
+ time.sleep(polling_interval_in_ms / 1000.0)
+
+
+# pylint: disable=no-member
+class EventProfiler:
+ """ Implements methods to profile latency and RAM/GPU memory usage """
+ _instance = None
+
+ def __new__(cls):
+ """ Implements the Global Object Pattern (Singleton) """
+ if cls._instance is None:
+ cls._instance = super(EventProfiler, cls).__new__(cls)
+ cls._instance._empty_cache = False # pylint: disable=protected-access
+ cls._instance._markers = [] # pylint: disable=protected-access
+
+ if WATERMARK_THREAD_POLLING_INTERVAL_IN_MS:
+ cls._instance._ram_allocated = Value('q', 0, lock=RLock()) # pylint: disable=protected-access
+ p = Process(
+ target=ram_watermark_function,
+ args=(cls._instance._ram_allocated,
+ os.getpid(),
+ WATERMARK_THREAD_POLLING_INTERVAL_IN_MS
+ ))
+ p.daemon = True # < will terminate the watermark process when this process exits
+ cls._instance.reset_peak_memory_stats()
+ p.start()
+
+ logger.info('Created Latency/Memory profiler: empty_cache=%s',
+ cls._instance._empty_cache) # pylint: disable=protected-access
+
+ return cls._instance
+
+ def reset_peak_memory_stats(self):
+ """ reset RAM usage to current RAM allocation. """
+ if WATERMARK_THREAD_POLLING_INTERVAL_IN_MS:
+ with self._ram_allocated.get_lock():
+ self._ram_allocated.value = ram_usage(os.getpid())
+
+ @property
+ def max_memory_allocated(self):
+ """ getter for current peak RAM usage since last reset. """
+ if WATERMARK_THREAD_POLLING_INTERVAL_IN_MS:
+ with self._ram_allocated.get_lock():
+ return self._ram_allocated.value
+ else:
+ return ram_usage(os.getpid())
+
+ @property
+ def empty_cache(self):
+ """ getter for empty_cache if set to True snapshot calls would flush unused CUDA memory. """
+ return self._empty_cache
+
+ @empty_cache.setter
+ def empty_cache(self, enable: bool = False):
+ """ setter for empty_cache if set to True snapshot calls would flush unused CUDA memory. """
+ if self._empty_cache != enable:
+ if self._empty_cache:
+ logger.warning('enabling cache clear might impact latency, avoid excessive calls in tight loop')
+ self._markers.append(f"empty_cache:{self._empty_cache}")
+
+ def snapshot(self, snapshot_marker: str = None,
+ device: Union[Device, int] = None,
+ append: bool = True) -> EventMarker:
+ """
+ logs the current time and memory usage across all CUDA devices
+ :param snapshot_marker: text to capture with the GPU marker.
+ :param device: (torch.device or int, optional): selected device.
+ :param append: if True, added it to report logs
+ """
+ marker = EventMarker(snapshot_marker, self.max_memory_allocated, device)
+ if self._empty_cache:
+ torch.cuda.empty_cache()
+
+ if torch.cuda.is_available():
+ torch.cuda.reset_peak_memory_stats(device)
+
+ logger.info("memory usage @ '%s' : GPU %s, RAM %s",
+ snapshot_marker, marker.gpu_memory_usage, marker.cpu_memory_usage)
+ if append:
+ self._markers.append(marker)
+
+ return marker
+
+ def report(self):
+ """ dumps the collected memory usage logs """
+ logger.info("Profiling report :- %s",
+ generate_event_report([m.to_dict() for m in self._markers], max_memory_threshold=0.9))
+
+ def json_dump(self, filepath: str):
+ """ dumps the collected memory usage into a json file """
+ markers = [m.to_dict() for m in self._markers]
+ with open(filepath, 'w') as f:
+ json.dump(markers, f, sort_keys=True, indent=4)
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ if exc_type == torch.cuda.OutOfMemoryError: # pylint: disable=no-member
+ self.report()
+
+
+@contextlib.contextmanager
+def event_marker(event: str, device: Union[Device, int] = None, flush_ram: bool = False):
+ """
+ utility to mark time taken and memory usage before and after executing a section of code.
+ :param event: marker string to use to identify the context.
+ :param device: (torch.device or int, optional): selected device.
+ :param flush_ram: invoke garbage collect for true estimates before profiling.
+ """
+ profiler = EventProfiler()
+ # reset for start low-watermark
+ if flush_ram:
+ gc.collect()
+ event = f'{event}[gc]'
+
+ if torch.cuda.is_available():
+ torch.cuda.reset_peak_memory_stats(device)
+ profiler.reset_peak_memory_stats()
+ start_marker = profiler.snapshot(f'{event} >> ', device, append=False)
+ yield
+ end_marker = profiler.snapshot(f'{event} << ', device, append=False)
+ profile_marker = end_marker.delta(event, start_marker)
+ logger.info('%s', profile_marker)
+ profiler._markers.append(profile_marker) # pylint: disable=protected-access
+
+
+def generate_event_report(event_list: List[Dict[str, Union[int, str]]], max_memory_threshold: float) -> str:
+ """
+ utility to create a format event list report with additional statistics
+ :param event_list: a list of event entries(dict).
+ :param max_memory_threshold: threshold to mark all event(s) with range of max event e.g. 0.9 would log all event
+ with 90% of max usage event.
+ """
+
+ gpu_usage = lambda event: event['gpu_memory_usage'] + event['delta_gpu_memory_usage']
+ cpu_usage = lambda event: event['cpu_memory_usage'] + event['delta_cpu_memory_usage']
+
+ tot_time = sum(event['time'] for event in event_list)
+ max_gpu = gpu_usage(max(event_list, key=gpu_usage)), []
+ max_cpu = cpu_usage(max(event_list, key=cpu_usage)), []
+
+ stream = io.StringIO(newline='\n')
+
+ stream.write("\n" + "-" * 150)
+ stream.write("\n{:>90} | {:>18} | {:>18} |".format("", "GPU", "RAM"))
+ stream.write("\n{:<65} {:>22} | {:>12} {:>11} | {:>11} {:>11} |".format(
+ "Event", "Time", "delta", "agg", "delta", "agg"))
+ stream.write("\n" + "-" * 150)
+
+ for event in event_list:
+
+ event_desc = event['event']
+ duration = event['time']
+ gpu_mem = gpu_usage(event)
+ cpu_mem = cpu_usage(event)
+
+ cpu_marker = gpu_marker = ' '
+ if max_gpu[0] * max_memory_threshold < gpu_mem:
+ max_gpu[1].append(event_desc)
+ gpu_marker = '*'
+ if max_cpu[0] * max_memory_threshold < cpu_mem:
+ max_cpu[1].append(event_desc)
+ cpu_marker = '*'
+
+ stream.write("\n{:<65} {:>20}{:>5} | {:>12} {:>12}{}| {:>12} {:>12}{}|".format(
+ event_desc, str(timedelta(seconds=duration)), '{:.0%}'.format(duration / tot_time),
+ convert_bytes(event['delta_gpu_memory_usage']), convert_bytes(gpu_mem), gpu_marker,
+ convert_bytes(event['delta_cpu_memory_usage']), convert_bytes(cpu_mem), cpu_marker))
+
+ stream.write("\n" + "-" * 150)
+ stream.write("\nSummary:")
+ stream.write("\n\tTime(*under profiling*): {:>10}".format(str(timedelta(seconds=tot_time))))
+ stream.write("\n\tMax RAM: {:>10} => [ > {:.0%} : {} ]".format(
+ convert_bytes(max_cpu[0]), max_memory_threshold, ', '.join(max_cpu[1])))
+ stream.write("\n\tMax GPU memory: {:>10} => [ > {:.0%} : {} ]".format(
+ convert_bytes(max_gpu[0]), max_memory_threshold, ', '.join(max_gpu[1])))
+ stream.write("\n" + "-" * 150)
+ return stream.getvalue()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--profiling_log", required=True)
+ parser.add_argument("--max_memory_threshold", type=float, default=0.9)
+ args, unk = parser.parse_known_args()
+ if len(unk) > 0:
+ raise ValueError(f'[ERROR] unknown args: {unk}')
+
+ with open(args.profiling_log, 'r') as file:
+ events = json.load(file)
+ print("\nProfiling logs from {}, ts={}, {}".format(
+ os.path.abspath(args.profiling_log),
+ time.ctime(os.path.getmtime(args.profiling_log)),
+ generate_event_report(events, args.max_memory_threshold)
+ ))
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/debug/recipe_checker.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/debug/recipe_checker.py
new file mode 100644
index 000000000..8ad0acbb4
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/debug/recipe_checker.py
@@ -0,0 +1,136 @@
+import json
+import csv
+import argparse
+csv_writer = None
+def load_json(file_path):
+ with open(file_path, 'r') as file:
+ return json.load(file)
+
+def compare_dicts(dict1, dict2, context):
+ global csv_writer
+ '''
+ the compare_dict function compares two dictionaries dict1 and dict2 recursively
+ a. it iterates over the keys in dict1, if the key is in dict2, it compares the corresponding values
+ b. if the key is not present in dict2, we report it
+ c. we also iterate over the keys in dict2 and check for their presence in dict1 and report if not present.
+ params:
+ 1. dict1: dictionary representing the payload of a given PTQ technique/API
+ 2. dict2: dictionary representing the payload of a given PTQ technique/API
+ 3. context: A string describing the hierarchical path to the point of difference between the payloads
+ '''
+ for key in dict1:
+ if key not in dict2:
+ csv_writer.writerow([f"{context}: {key}", dict1[key], ""])
+ else:
+ compare_values(dict1[key], dict2[key], f"{context} -> {key}")
+
+ for key in dict2:
+ if key not in dict1:
+ csv_writer.writerow([f"{context}: {key}", "", dict2[key]])
+
+def compare_values(val1, val2, context):
+ '''
+ compares two values val1, and val2. it checks for the types of val1 and val2 and if it is a dictionary it
+ invokes the compare_dict function recursively, else we compare the values if they match or not.
+ params:
+ 1. val1: The value of a specific property in the PTQ technique/API payload being compared.
+ 2. val2: value of the same property in the PTQ technique/API payload being compared
+ 3. context: A string describing the hierarchical path to the point of difference between the payloads
+ '''
+ global csv_writer
+ if isinstance(val1, dict) and isinstance(val2, dict):
+ compare_dicts(val1, val2, context)
+ elif val1 != val2:
+ csv_writer.writerow([f"{context}", val1, val2])
+
+def compare_json(json1, json2):
+ '''
+ This function is primarily responsible for comparing the PTQ techniques/ AIMET calls as captured in the recipe dumps.
+ As the strict checking, we ensure that the two dumps contain the same number of techniques else, we report to user and stop.
+ Next, we check if the order of techniques (determined by the Module_name+Operation name is aligned) is same,
+ if not, we do a second check to ensure that the set of techniques is same across the two dumps for comparison.
+ Once we have verified that either the order or the set matches, we then iterate over each function call captured and
+ compare the payload within the Parameters of algo and the Additional Properties.
+ params:
+ 1. json1: the JSON corresponding to the recipe dump1
+ 2. json2: the JSON corresponding to the recipe dump2
+ '''
+ global csv_writer
+ report = []
+ file = open('mismatches.csv', mode='w', newline='')
+ csv_writer = csv.writer(file)
+ # Write the header
+ csv_writer.writerow(['Property', 'Value in API1', 'Value in API2'])
+
+ api1 = json1.get('Quantization_PTQ_API', [])
+ api2 = json2.get('Quantization_PTQ_API', [])
+
+ environment_1 = json1.get('Environment', [])[0]
+ environment_2 = json2.get('Environment', [])[0]
+ compare_dicts(environment_1, environment_2,
+ "Environment")
+ # Order and Element Matching
+ ordered = True
+ if len(api1) != len(api2):
+ report.append("Number of PTQ techniques applied do not match.")
+ return report
+ else:
+ for i, (dict1, dict2) in enumerate(zip(api1, api2)):
+ if dict1.get('Module_name') != dict2.get('Module_name') and dict1.get('Operation_name') != dict2.get('Operation_name'):
+ ordered=False
+ break
+
+ if not ordered:
+ report.append(
+ f"Warning: the order of PTQ is different between the two dumps. Proceeding with verify if the set of PTQ is same or not..")
+ for dict1 in api1:
+ dict2 = next((d for d in api2 if
+ d['Module_name'] == dict1['Module_name'] and d['Operation_name'] == dict1['Operation_name']),
+ None)
+ if dict2 is None:
+ report.append(
+ f"PTQs do not match across the two dumps, please verify if the set of techniques applied matches.")
+ return report
+ report.append(
+ f"PTQs order does not match across the two dumps, but the set of techniques applied matches.")
+
+ for dict1 in api1:
+ dict2 = next((d for d in api2 if d['Module_name'] == dict1['Module_name'] and d['Operation_name'] == dict1['Operation_name']), None)
+ # Compare Parameters_of_algo and Additional_properties
+ compare_dicts(dict1.get('Parameters_of_algo', {}), dict2.get('Parameters_of_algo', {}),
+ f"Parameters_of_algo in {str(dict1.get('Operation_name'))}")
+ compare_dicts(dict1.get('Additional_properties', {}), dict2.get('Additional_properties', {}),
+ f"Additional_properties in {str(dict1.get('Operation_name'))}")
+ file.close()
+ report.append("\nNote: Please refer to the mismatches.csv file for details about the mismatches.")
+ return report
+
+def generate_report(report):
+ '''
+ This function reads in the report and writes it to the comparison_report.txt file.
+ '''
+ with open('comparison_report.txt', 'w') as file:
+ for line in report:
+ file.write(line + '\n')
+
+if __name__ == "__main__":
+ '''
+ The script compares the JSON dumps generated from the recipe_logger.
+ '''
+ parser = argparse.ArgumentParser(description='''
+ This script, when executed with two recipe dumps in JSON format, compares the environment and the Quantization PTQ API payload information between the two dumps.
+ It generates two files: `mismatches.csv`, which lists the mismatches along with the corresponding API,
+ and `report.txt`, which provides additional details on whether there is a discrepancy in the number of PTQ techniques applied between the two dumps.
+ Usage:
+ python recipe_checker.py path_to_recipe_dump1 path_to_recipe_dump2
+ ''')
+ parser.add_argument('json1_path', type=str, help='Path to the first JSON file')
+ parser.add_argument('json2_path', type=str, help='Path to the second JSON file')
+
+ args = parser.parse_args()
+
+ json1 = load_json(args.json1_path)
+ json2 = load_json(args.json2_path)
+
+ report = compare_json(json1, json2)
+ generate_report(report)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/debug/recipe_logger.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/debug/recipe_logger.py
new file mode 100644
index 000000000..36f4416a9
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/debug/recipe_logger.py
@@ -0,0 +1,269 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+import json
+import logging
+from datetime import datetime
+import os
+import torch
+import types
+import inspect
+from importlib.metadata import version as impLib_version
+from packaging import version
+from enum import Enum
+filter_list = (torch.utils.data.dataloader.DataLoader, torch.nn.Module, torch.Tensor)
+logger = None
+json_dump = {'Environment': [], 'Property':[], 'Metric': [], "Quantization_PTQ_API":[]}
+log_json_path = None
+class CustomEncoder(json.JSONEncoder):
+ def default(self, obj):
+ try:
+ return super().default(obj)
+ except TypeError:
+ if isinstance(obj, types.FunctionType):
+ return "{Function}:" + obj.__name__
+ if obj.__class__.__name__ == "SeqMseParams":
+ return obj.__dict__
+ return repr(obj)
+
+class RecipeLogger:
+ '''
+ Defines a singleton class for initializing the logger and avoiding re-initializing the logger.
+ '''
+ _instance = None
+
+ def __new__(cls, output_dir, log_path_suffix=None):
+ if cls._instance is None:
+ cls._instance = super(RecipeLogger, cls).__new__(cls)
+ cls._instance._initialize(output_dir, log_path_suffix)
+ return cls._instance
+
+ def _initialize(self, output_dir, log_path_suffix=None):
+ global logger, log_json_path
+ if log_path_suffix:
+ log_file_path_prefix = os.path.join(output_dir, "_" + str(log_path_suffix))
+ else:
+ log_file_path_prefix = os.path.join(output_dir, "_" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
+ log_file_path = log_file_path_prefix + ".log"
+ log_json_path = log_file_path_prefix + ".json"
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s \n %(message)s')
+ logger = logging.getLogger("LLMDebug")
+ logger.setLevel(logging.DEBUG)
+ fileHandler = logging.FileHandler(log_file_path)
+ fileHandler.setLevel(logging.DEBUG)
+ fileHandler.setFormatter(formatter)
+ logger.addHandler(fileHandler)
+ logger.info("Logger initialized \n")
+
+def recipe_dump_init(output_dir, log_path_suffix=None):
+ """
+ This function is responsible for setting up the logger to log the recipe into the specified output directory.
+ :param output_dir: output dir to save the output log to
+ :param log_path_suffix: optional suffix for the log file name
+ """
+ return RecipeLogger(output_dir, log_path_suffix)
+
+
+class Property(Enum):
+ """
+ Defines the properties for the logging
+ dataset_name: The name of the dataset used
+ dataset_split: usually refers to "train" or "test" or "validation"
+ batch_size: number of samples in a single batch
+ num_batches: number of bathces used for a technique
+ mask_neg: masking value for the causal attention mask
+ ARN : maximum number of tokens that can be consumed by the model at each inference
+ context_len: maximum number of tokens that the model can consume in total
+ model_precision: represents whether the model was run in float16 or float32 precision
+ model_config: model config dictionary
+ model_id_or_path: path to model weights or the model_id
+ tokens_per_sample: number of tokens per sample can vary.
+ Currently, we set the block_size during dataset preprocessing to match the context length.
+ However, each sample might be shorter or longer than the model’s context length (Long Context)
+ """
+
+ dataset_name = "dataset_name"
+ dataset_split = "dataset_split"
+ batch_size = "batch_size"
+ num_batches = "num_batches"
+ mask_neg = "mask_neg"
+ context_length = "context_length"
+ model_precision = "model_precision"
+ model_config = "model_config"
+ model_id_or_path = "model_id_or_path"
+ ARN = "ARN"
+ tokens_per_sample = "tokens_per_sample"
+
+
+class Metric(Enum):
+ """ Defines the various metrics we can log """
+ ppl = "ppl"
+ mmlu = "mmlu"
+ sqnr = "sqnr"
+
+
+class ModelType(Enum):
+ """ Defines the model type we can log the metric/ property against"""
+ hf_model = "hf_model"
+ prepared_model = "prepared_model"
+ qsim_model = "qsim_model"
+ adapted_model = "adapted_model"
+
+
+def llm_lib_log_env_info(additional_env_info = None):
+ """
+ This function logs environment information, such as the versions of `aimet_torch` and `transformers`.
+ It should ideally be invoked after the logger is initially set up.
+
+ :param additional_env_info: a dictionary which may take in the additional environment information which a
+ user wants to log in addition to the transformers, torch and aimet version
+ """
+ env_var_list = ['AimetTorch', 'transformers', 'torch', 'aimet-torch']
+ found_vars={}
+ for env_var in env_var_list:
+ try:
+ env_version = version.parse(impLib_version(env_var))
+ found_vars[env_var] = env_version
+ except:
+ logger.warning(f"package {env_var} not found in the environment variables")
+
+ if additional_env_info is not None:
+ for env_var, env_version in additional_env_info.items():
+ found_vars[env_var] = env_version
+
+ log_message = "Environment variables found:\n" + "\n".join([f"{key}: {value}" for key, value in found_vars.items()])
+ logger.info(log_message)
+ json_dump['Environment'].append(found_vars)
+
+
+def llm_lib_log_and_execute(target_func, additional_properties=dict()):
+ """
+ This function creates a wrapper around the func, and logs the arguments, kwargs and additional properties before invoking func.
+ Since the arguments might include models or input tensors that aren't useful for logging, we filter them out before logging.
+
+ :param
+ target_func: the function whose args and kwargs we want to log
+ additional_properties: an optional dictionary which contains the keys from the pre-defined Property Enum.
+ """
+ def wrapper(*args, **kwargs):
+ bound_args = inspect.signature(target_func).bind(*args, **kwargs)
+ bound_args.apply_defaults()
+ args_ = _filter_collection(dict(bound_args.arguments))
+ for k, v in additional_properties.items():
+ if k not in Property:
+ raise KeyError("Property {} not defined".format(k))
+
+ property_payload = {str(k): v for k, v in additional_properties.items()}
+ payload = {"Module_name": target_func.__module__, "Operation_name": target_func.__name__, "Parameters_of_algo": args_, "Additional_properties": property_payload}
+ json_dump['Quantization_PTQ_API'].append(payload)
+ logger.info(f"Module_name : {target_func.__module__} \n Operation_name: {target_func.__name__} \n Parameters_of_algo: {args_}, \n Additional properties: {additional_properties} \n")
+ return target_func(*args, **kwargs)
+
+ return wrapper
+
+
+def llm_lib_log_property(property_dict = None):
+ """
+ This function is responsible for logging the property if it exists in the Property Enum
+
+ params:
+ property_dict: a dictionary containing the keys from one of the pre-defined property Enum with it's corresponding value.
+ """
+ payload = {}
+ if property_dict:
+ for property_name, value in property_dict.items():
+ if property_name not in Property:
+ raise KeyError("given Property is not supported")
+
+ log_message = f"{property_name.value} : {value} \n"
+ payload[property_name.value] = value
+
+ logger.info(log_message)
+ json_dump['Property'].append(payload)
+
+
+def llm_lib_log_metric(model_type, metric_name, value, model_name=None):
+ """
+ This function is responsible for logging the property if it exists in the Metric Enum
+
+ params:
+ model_type: a free form user passed model type, hf_model, QSim_model
+ metric_name: property passed which comes from the predefined Metric Enum
+ value: value corresponding to the Metric
+ model_name: name of the model whose property is being logged. Useful to differentiate models with different adapters in Lora.
+ """
+ if metric_name in Metric and model_type in ModelType:
+ metric_payload = {"Model_type": model_type.value, "Metric_name": metric_name.value, "Value": value}
+
+ if model_name is not None:
+ metric_payload["Model_name"]=model_name
+
+ log_message = "\n".join([f"{key}: {value}" for key, value in metric_payload.items()])
+ logger.info(log_message)
+ json_dump['Metric'].append(metric_payload)
+ else:
+ raise KeyError("given Metric or ModelType is not supported")
+
+
+def dump_logs_to_json():
+ '''
+ This function dumps the logs into a JSON file.
+ '''
+ with open(log_json_path, "w") as json_dump_file:
+ json.dump(json_dump, json_dump_file, indent=4,cls=CustomEncoder)
+ return log_json_path
+
+def _filter_collection(args):
+ """
+ This function filters the arguments by removing any noisy or verbose types specified in the filter_list,
+ or if the value corresponding to some key contains the object of QuantizationSimModel or it removes the
+ iterable (list) if it contains elements from the filter list. Additionally, to keep the logs clean,
+ we also remove the nested empty list/ tuple for instance arising from the dummy input.
+ :param args: function args or kwargs passed.
+ """
+ if isinstance(args, dict):
+ filtered_dict = {}
+ for k, v in args.items():
+ if not (isinstance(v, filter_list) or v.__class__.__name__ == 'QuantizationSimModel'):
+ filtered_value = _filter_collection(v)
+ # filter out if the filtered_value is either an empty list or tuple, for instance in dummy input
+ if filtered_value not in ([], ()):
+ filtered_dict[k] = filtered_value
+ return filtered_dict
+
+ if isinstance(args, (tuple, list)):
+ filtered= _filter_iterable(args)
+ return _remove_nested_empty_elements(filtered)
+
+ return args
+
+def _remove_nested_empty_elements(data):
+ """
+ Recursively removes all empty lists or tuples from a nested list or tuple structure.
+ :param data: The original list or tuple containing nested lists or tuples.
+ :return: The modified list or tuple with empty lists or tuples removed.
+ """
+
+ if isinstance(data, (list, tuple)):
+ # Recursively process each item and filter out empty lists or tuples
+ filtered = [_remove_nested_empty_elements(item) for item in data]
+
+ # Filter out any lists or tuples that are now empty after processing
+ filtered = [item for item in filtered if item]
+
+ return type(data)(filtered)
+ return data
+
+def _filter_iterable(args):
+ '''
+ This function recursively iterates over the iterables and filters out any elements that are part of the filter list.
+ '''
+ if isinstance(args, (tuple, list)):
+ return type(args)(_filter_iterable(arg) for arg in args if not isinstance(arg, filter_list))
+ return args
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/dev/hadamard_utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/dev/hadamard_utils.py
new file mode 100644
index 000000000..db25de70e
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/dev/hadamard_utils.py
@@ -0,0 +1,256 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# Cornell-RelaxML Quip-Sharp Hadamard Utility Functions
+# Copyright (C) 2023 Cornell RelaxML
+#
+# This file includes portions of code derived from the Cornell-RelaxML/quip-sharp project:
+# https://github.com/Cornell-RelaxML/quip-sharp/blob/main/lib/utils/matmul_had.py
+#
+# This program is distributed under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+# =============================================================================
+
+import torch
+from torch import Tensor
+from typing import Optional, Tuple
+import math
+
+
+class HadamardTransform(torch.nn.Module):
+ """
+ Applies the Hadamard transform to the input tensor.
+ If `randomized` is True, the Hadamard matrix is randomized by multiplying with a seed vector.
+ If `seed` is provided, it is used as the seed vector for the randomized Hadamard matrix.
+ If `linear` is True, the Hadamard matrix is applied as a linear transformation
+ using a precomputed hadamard as weight matrix.
+ Args:
+ size (int): Size of the Hadamard matrix. Only powers of two are supported.
+ randomized (bool, optional): Whether to use a randomized Hadamard matrix. Defaults to False.
+ seed (Tensor, optional): Seed vector for the randomized Hadamard matrix. Defaults to None.
+ linear (bool, optional): Whether to apply the Hadamard matrix as a linear transformation. Defaults to False.
+ device (torch.device, optional): Device to use for the Hadamard matrix. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ randomized: bool = False,
+ seed: Optional[Tensor] = None,
+ linear: bool = False,
+ device=None,
+ ):
+ if not _is_pow2(size):
+ raise ValueError("size must be a power of 2")
+ if randomized and seed is not None:
+ raise ValueError("seed must be None if randomized is True")
+ super().__init__()
+ if seed is None and randomized:
+ seed = torch.bernoulli(0.5 * torch.ones(size, device=device)) * 2 - 1
+ self.register_buffer("seed", seed)
+ self.size = size
+ self.scale = 1 / math.sqrt(size)
+ if linear:
+ self.hadamard = torch.nn.Linear(
+ in_features=size,
+ out_features=size,
+ bias=False,
+ device=device,
+ )
+ self.hadamard.weight.requires_grad = False
+ self.hadamard.weight.data.copy_(hadamard_transform(torch.eye(size, device=device)))
+ self.eps = 1e-8
+
+ def forward(self, x: Tensor) -> Tensor:
+ y = x * self.seed if self.seed is not None else x
+ if hasattr(self, "hadamard"):
+ # If linear, apply the Hadamard matrix as a linear transformation.
+ return self.hadamard(y) * self.scale
+ else:
+ return hadamard_transform(y, scale=self.scale)
+
+ def apply_inverse(self, x: Tensor, transpose: bool = False) -> Tensor:
+ """
+ Applies the inverse of the Hadamard transform to the input tensor x.
+ For no transpose, equivalent to (X @ H^-1) @ diag(1/s) = X @ H^T @ diag(1/s), where H is the Hadamard matrix.
+ For transpose, equivalent to X @ (H^-1 @ diag(1/s))^T = X @ diag(1/s) @ H, where H is the Hadamard matrix.
+
+ Args:
+ x (Tensor): Input tensor of shape (..., size).
+ Returns:
+ Tensor: Output tensor of shape (..., size).
+ """
+ seed = self.seed if self.seed is not None else torch.ones(self.size, device=x.device)
+ if transpose:
+ # Use H^-1 = H^T, hence fact H^-T = H.
+ return hadamard_transform(x / (seed + self.eps), scale=self.scale)
+ return hadamard_transform(x, scale=self.scale)[..., : self.size] / (seed + self.eps)
+
+
+class GroupedHadamardTransform(HadamardTransform):
+ """
+ Applies the Hadamard transform to a tensor of shape (..., size) in groups of largest power of 2.
+ Args:
+ size (int): Size of input tensor. This should not be a power of two, but should be decomposable into groups of powers of two.
+ randomized (bool, optional): Whether to use a randomized Hadamard matrix. Defaults to False.
+ seed (Tensor, optional): Seed vector for the randomized Hadamard matrix. Defaults to None.
+ linear (bool, optional): Whether to apply the Hadamard matrix as a linear transformation. Defaults to False.
+ device (torch.device, optional): Device to use for the Hadamard matrix. Defaults to None.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ randomized: bool = False,
+ seed: Optional[Tensor] = None,
+ linear: bool = False,
+ device=None,
+ ):
+ if randomized and seed is not None:
+ raise ValueError("seed must be None if randomized is True")
+
+ if seed is None and randomized:
+ seed = torch.bernoulli(0.5 * torch.ones(size, device=device)) * 2 - 1
+ n_groups, groupsize = decompose_for_hadamard_grouped(size)
+
+ super().__init__(size=groupsize, seed=seed, linear=linear, device=device)
+
+ self.n_groups = n_groups
+ self.groupsize = groupsize
+ self.size = size
+
+ def forward(self, x: Tensor) -> Tensor:
+ y = x * self.seed if self.seed is not None else x
+ if hasattr(self, "hadamard"):
+ # If linear, apply the Hadamard matrix as a linear transformation.
+ y = y.view(-1, self.n_groups, self.groupsize)
+ return (self.hadamard(y) * self.scale).reshape(x.shape)
+ else:
+ return hadamard_grouped_transform(y, self.n_groups, self.groupsize, scale=self.scale)
+
+ def apply_inverse(self, x: Tensor, transpose: bool = False) -> Tensor:
+ """
+ Applies the inverse of the Hadamard transform to the input tensor x reshaped to groups.
+ For no transpose, equivalent to (X @ H^-1) @ diag(1/s) = X @ H^T @ diag(1/s), where H is the Hadamard matrix.
+ For transpose, equivalent to X @ (H^-1 @ diag(1/s))^T = X @ diag(1/s) @ H, where H is the Hadamard matrix.
+
+ Args:
+ x (Tensor): Input tensor of shape (..., size).
+ Returns:
+ Tensor: Output tensor of shape (..., size).
+ """
+ seed = self.seed if self.seed is not None else torch.ones(self.size, device=x.device)
+ if transpose:
+ # Use H^-1 = H^T, hence fact H^-T = H.
+ return hadamard_grouped_transform(
+ x / (seed + self.eps),
+ self.n_groups,
+ self.groupsize,
+ scale=self.scale,
+ )
+ return hadamard_grouped_transform(
+ x,
+ self.n_groups,
+ self.groupsize,
+ scale=self.scale,
+ ) / (seed + self.eps)
+
+
+def hadamard_transform(
+ X: Tensor,
+ scale: float = 1.0,
+) -> Tensor:
+ """
+ Apply Hadamard transform to a tensor and scale it.
+ e.g. X @ H.T / sqrt(n), where H is Hadamard matrix and n is the last dimension of X.
+
+ Caution: This function is meant to validate alignment with `fast_hadamard_transform`.
+ The einsum-based implementation shows differences up to rtol=1e-2 and is not fully aligned.
+
+ Args:
+ X (Tensor): Input tensor.
+ scale (float): Scaling factor to apply to the Hadamard transformed tensor.
+
+ Returns:
+ Tensor: Hadamard transformed tensor.
+ """
+ n = X.shape[-1]
+ if not _is_pow2(n):
+ raise ValueError(
+ f"Hadamard transform requires the last dimension to be a power of 2, got {n}."
+ )
+ input = X.clone().view(-1, n, 1)
+ output = input.clone()
+ while input.shape[1] > 1:
+ input = input.view(input.shape[0], input.shape[1] // 2, 2, input.shape[2])
+ output = output.view(input.shape)
+ output[:, :, 0, :] = input[:, :, 0, :] + input[:, :, 1, :]
+ output[:, :, 1, :] = input[:, :, 0, :] - input[:, :, 1, :]
+ output = output.view(input.shape[0], input.shape[1], -1)
+ (input, output) = (output, input)
+ del output
+
+ return input.view(X.shape) * scale
+
+
+def hadamard_grouped_transform(x: Tensor, n_groups: int, groupsize: int, scale: float = 1.0) -> Tensor:
+ """
+ Applied a grouped Hadamard to the last dimension.
+ Groups are defined as the largest power of 2 that the dimension decomposes into, e.g. if
+ d = 2^n * k, with k non-divisible by 2, then there are k groups of size 2^n.
+ Hadamard is applied to each group independently (i.e. the Hadamard matrix is block diagonal)
+
+ To get the sizes, call `decompose_for_hadamard_grouped(x.shape[-1])`
+ """
+
+ y = x.reshape(*x.shape[:-1], n_groups, groupsize)
+ return hadamard_transform(y, scale=scale).reshape(x.shape)
+
+
+def _is_pow2(n: int) -> bool:
+ """
+ Check if a number is a power of 2.
+
+ Args:
+ n (int): Number to check.
+
+ Returns:
+ bool: True if n is a power of 2, False otherwise.
+ """
+ return (n & (n - 1) == 0) and (n > 0)
+
+
+def decompose_for_hadamard_grouped(size: int) -> Tuple[int, int]:
+ """
+ Decomposes the size into groups of largest power of 2.
+ Returns the number of groups and the size of each group.
+ e.g. 12 -> 3 groups of size 4
+ Args:
+ size (int): Size to decompose.
+ Returns:
+ Tuple[int, int]: Number of groups and size of each group.
+ """
+ if size <= 0:
+ raise ValueError("size must be positive")
+ # Find the largest power of two that divides size
+ groupsize = 1 << (size.bit_length() - 1)
+ while size % groupsize != 0:
+ groupsize >>= 1
+ n_groups = size // groupsize
+ return n_groups, groupsize
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/dev/model_adaptation/linear_to_conv.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/dev/model_adaptation/linear_to_conv.py
new file mode 100644
index 000000000..af2b0cd59
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/dev/model_adaptation/linear_to_conv.py
@@ -0,0 +1,74 @@
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+""" This file provides common utilities to replace Linear layers with Conv2d layers. """
+
+import warnings
+import torch
+from genai_lib.common.dev.utils import rsetattr
+from transformers.pytorch_utils import Conv1D
+
+class ConvInplaceLinear(torch.nn.Conv2d):
+ """ Convolution module that replaces a Linear layer inplace
+ We inherit from the torch.nn.Conv2d so that the resulting module can still behave like a leaf module when adding Lora Adapters.
+ This will equip us to use the same Lora Config for the adapted model as the original model config.
+ """
+
+ def __init__(self, mod):
+ if isinstance(mod, torch.nn.Linear):
+ weight, bias = mod.weight, mod.bias
+ elif isinstance(mod, Conv1D):
+ weight, bias = mod.weight.T, mod.bias
+
+ self.out_features, self.in_features = weight.shape
+
+ super(ConvInplaceLinear, self).__init__(
+ self.in_features,
+ self.out_features,
+ 1,
+ dtype=mod.weight.dtype,
+ bias=True if bias is not None else False
+ )
+
+ self.weight.data.copy_(weight.data[:, :, None, None])
+ if bias is not None:
+ self.bias.data.copy_(bias.data)
+ self.to(mod.weight.data.device)
+
+ def forward(self, x: torch.Tensor, scale: float = 1.0):
+ ndim = x.ndim
+ if ndim == 2:
+ x = x.unsqueeze(0).unsqueeze(-1).permute(0, 2, 3, 1) # (emb_dim, C) -> (1, C, 1, emb_dim)
+ elif ndim == 3:
+ x = x.unsqueeze(-1).permute(0, 2, 3, 1) # (B, emb_dim, C) -> (B, C, 1, emb_dim)
+ elif ndim == 4:
+ x = x.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W)
+ warnings.warn(f"{self.__class__.__name__} received an unexpected 4d input, assuming channels-last and proceeding.")
+ else:
+ raise NotImplementedError(f"{self.__class__.__name__} could not handle input with shape {x.shape}")
+
+ x = super().forward(x)
+
+ if ndim == 2:
+ return x.permute(0, 3, 1, 2).squeeze(-1).squeeze(0) # (1, C, 1, emb_dim) -> # (emb_dim, C)
+ elif ndim == 3:
+ return x.permute(0, 3, 1, 2).squeeze(-1) # (1, C, 1, emb_dim) -> # (B, emb_dim, C)
+ elif ndim == 4:
+ x = x.permute(0, 2, 3, 1) # (B, C, H, W) -> (B, H, W, C)
+ return x
+
+
+def replace_linears_with_convs(model: torch.nn.Module) -> torch.nn.Module:
+ """
+ Helper function to replace all linear modules with equivalent conv modules
+ """
+ for name, module in model.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ conv_layer = ConvInplaceLinear(module)
+ rsetattr(model, name, conv_layer)
+
+ return model
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/dev/oset_utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/dev/oset_utils.py
new file mode 100644
index 000000000..a34b3df8b
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/dev/oset_utils.py
@@ -0,0 +1,117 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+"""OSET utilities for GenAI Lib"""
+
+from typing import Iterable, Optional, Type, TypeAlias
+
+import torch
+from aimet_common.utils import AimetLogger
+from aimet_torch.v2 import nn
+from aimet_torch.v2.nn.modules import custom as custom_ops
+from aimet_torch.v2.quantsim import QuantizationSimModel
+from aimet_torch.v2.utils import _ContextManager
+
+from genai_lib.common.dev.utils import extract_qmodules
+
+_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils)
+
+try:
+ from aimet.aimet_oset import OsetLib
+except ImportError:
+ _logger.error(
+ 'In order to import OsetLib, the OSET repository must be added to module search path such as PYTHONPATH'
+ )
+ raise
+else:
+ _ModuleType: TypeAlias = Type[torch.nn.Module]
+ _QMODULE_TYPE_TO_KERNEL = {
+ # Linear, Conv, MatMul ops
+ nn.QuantizedLinear: OsetLib.linear,
+ nn.QuantizedConv2d: OsetLib.conv2d,
+ nn.QuantizedConvTranspose2d: OsetLib.conv_transpose2d,
+ custom_ops.QuantizedMatMul: OsetLib.matmul,
+ # Norm ops
+ nn.QuantizedLayerNorm: OsetLib.layer_norm,
+ nn.QuantizedInstanceNorm2d: OsetLib.instance_norm,
+ nn.QuantizedGroupNorm: OsetLib.group_norm,
+ # Elementwise ops
+ custom_ops.QuantizedAdd: OsetLib.add,
+ custom_ops.QuantizedSubtract: OsetLib.sub,
+ custom_ops.QuantizedMultiply: OsetLib.mul,
+ custom_ops.QuantizedDivide: OsetLib.div,
+ custom_ops.QuantizedConcat: OsetLib.cat,
+ custom_ops.QuantizedSin: OsetLib.sin,
+ custom_ops.QuantizedCos: OsetLib.cos,
+ # Activation ops
+ nn.QuantizedGELU: OsetLib.gelu,
+ nn.QuantizedSigmoid: OsetLib.sigmoid,
+ nn.QuantizedSoftmax: OsetLib.softmax,
+ }
+
+ def register_oset_kernels(
+ sim: QuantizationSimModel,
+ module_types_to_exclude: Optional[Iterable[_ModuleType]] = None,
+ modules_to_exclude: Optional[Iterable[torch.nn.Module]] = None,
+ ) -> _ContextManager:
+ """
+ Registers OSET kernels for quantization modules within a QuantizationSimModel, with options to exclude certain types or instances of modules.
+ Ensures original kernels can be restored after registration with context manager block.
+
+ :param sim: QuantizationSimModel instance.
+ :param module_types_to_exclude: Optional iterable of module types to exclude from kernel registration.
+ :param modules_to_exclude: Optional iterable of specific module instances to exclude from kernel registration.
+ :return: A context manager that handles the registration and restoration of kernels.
+ """
+ if module_types_to_exclude is None:
+ module_types_to_exclude = tuple()
+ else:
+ module_types_to_exclude = tuple(module_types_to_exclude)
+
+ if modules_to_exclude is None:
+ modules_to_exclude = set()
+ else:
+ modules_to_exclude = set(modules_to_exclude)
+
+ def register_kernel_if_applicable():
+ for qmodule in extract_qmodules(sim, nn.QuantizationMixin):
+ if isinstance(qmodule, module_types_to_exclude):
+ continue
+
+ if qmodule in modules_to_exclude:
+ continue
+
+ kernel = _QMODULE_TYPE_TO_KERNEL.get(type(qmodule))
+ if kernel is not None:
+ qmodule.set_kernel(kernel)
+ _logger.info_once(
+ f'{type(qmodule)} will be executed with the OSET kernel'
+ )
+ else:
+ _logger.warning_once(
+ f'There is no OSET kernel corresponding to {type(qmodule)}. It will be executed with the PyTorch kernel'
+ )
+
+ orig = {
+ qmodule: qmodule.get_kernel()
+ for qmodule in extract_qmodules(sim, nn.QuantizationMixin)
+ }
+
+ def restore_kernels():
+ for qmodule, kernel in orig.items():
+ qmodule.set_kernel(kernel)
+
+ ctx = _ContextManager(action=lambda: None, cleanup=restore_kernels)
+
+ try:
+ register_kernel_if_applicable()
+ except Exception:
+ ctx._cleanup()
+ raise
+ else:
+ return ctx
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/dev/peft/peft.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/dev/peft/peft.py
new file mode 100644
index 000000000..7f15d88db
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/dev/peft/peft.py
@@ -0,0 +1,201 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+""" Implementation for handling LoRA adapters added using PEFT """
+
+import torch
+
+from packaging import version
+
+from peft.tuners.lora.layer import LoraLayer as PeftLoraLayer
+from peft.tuners.lora.layer import Conv2d as PeftConv2d
+import aimet_torch.utils as aimet_torch_utils
+from aimet_torch.peft import LoraLayer as AimetLoraLayer
+from aimet_torch.elementwise_ops import Add, Multiply
+from genai_lib.common.dev.utils import get_aimet_version
+
+class LoraLayerNoCast(AimetLoraLayer):
+ """
+ Quantizable lora layer
+
+ This is the same implementation with the AIMET LoRA layers except that it removes
+ data type casting in the forward function
+ """
+ # pylint: disable=too-many-instance-attributes
+ def __init__(self, lora_layer: PeftLoraLayer):
+ """
+ :param lora_layer: Lora layer we want to replace
+ """
+ super().__init__(lora_layer)
+
+ self.add_lora_to_res = torch.nn.ModuleList([Add() for _ in range(len(self.lora_A))])
+ self.mul_scale = torch.nn.ModuleList([Multiply() for _ in range(len(self.lora_A))])
+
+ # reshape lora scaling to be broadcastable when it's a vector
+ for i, scaling in enumerate(self.scaling):
+ if isinstance(scaling, torch.Tensor) and scaling.size():
+ self.scaling[i] = torch.nn.Parameter(
+ scaling.view(1, scaling.size()[0], 1, 1), requires_grad=False
+ ).to(scaling.device)
+
+ # self.scaling = torch.nn.ParameterList(self.scaling) if isinstance(self.scaling, list) else self.scaling
+
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ """ Forward pass for replaced layer"""
+ result = self.base_layer(x, *args, **kwargs)
+ # torch_result_dtype = result.dtype
+ for active_adapter in self.active_adapters:
+ if not self.active_adapters[active_adapter]:
+ continue
+
+ adapter_index = self.adapter_name_to_index[active_adapter]
+
+ lora_A = self.lora_A[adapter_index]
+ lora_B = self.lora_B[adapter_index]
+ dropout = self.lora_dropout[adapter_index]
+ # TODO: remove the to(device) once MPP error fixed
+ scaling = self.scaling[adapter_index].to(lora_A.weight.device)
+ # x = x.to(lora_A.weight.dtype)
+
+ result = self.add_lora_to_res[adapter_index](
+ result, lora_B(self.mul_scale[adapter_index](lora_A(dropout(x)), scaling.detach()))
+ )
+
+ # result = result.to(torch_result_dtype)
+ return result
+
+class LoraLayer4d(AimetLoraLayer):
+ """
+ Quantizable lora layer
+ """
+ # pylint: disable=too-many-instance-attributes
+ def __init__(self, lora_layer: PeftLoraLayer):
+ """
+ :param lora_layer: Lora layer we want to replace
+ """
+ super().__init__(lora_layer)
+
+ self.add_lora_to_res = torch.nn.ModuleList([Add() for _ in range(len(self.lora_A))])
+ self.mul_scale = torch.nn.ModuleList([Multiply() for _ in range(len(self.lora_A))])
+
+ # reshape lora scaling to be broadcastable when it's a vector
+ for i, scaling in enumerate(self.scaling):
+ if isinstance(scaling, torch.Tensor) and scaling.size():
+ self.scaling[i].data = scaling.data.view(1, scaling.size()[0], 1, 1)
+ # self.register_parameter(f'scaling_{i}', self.scaling[i])
+
+ # TODO: Eventually we want to enable the below snippet but MPP errors out for now. See more in MORPH-19557
+ # self.scaling = []
+ # for i, scaling in enumerate(lora_layer.scaling.values()):
+ # if isinstance(scaling, torch.Tensor) and scaling.size():
+ # self.scaling.append(torch.nn.Parameter(
+ # torch.as_tensor(scaling).view(1, scaling.size()[0], 1, 1), requires_grad=False
+ # ).to(self.base_layer.weight.device))
+
+ # self.scaling = torch.nn.ParameterList(self.scaling) if isinstance(self.scaling, list) else self.scaling
+
+ self._become_4d()
+
+ def _become_4d(self):
+ base_layer_conv = torch.nn.Conv2d(self.base_layer.in_channels if hasattr(self.base_layer, 'in_channels') else self.base_layer.in_features,
+ self.base_layer.out_channels if hasattr(self.base_layer, 'out_channels') else self.base_layer.out_features,
+ 1, bias=True if self.base_layer.bias is not None else False)
+ base_weight = self.base_layer.weight.reshape(self.base_layer.weight.shape[:2])
+ base_layer_conv.weight.data.copy_(base_weight[:,:,None,None])
+ if self.base_layer.bias is not None:
+ base_layer_conv.bias.data.copy_(self.base_layer.bias.data)
+ base_layer_conv.to(self.base_layer.weight.data.device)
+ self.base_layer = base_layer_conv
+
+ for idx in range(len(self.lora_A)):
+ lora_A_conv = torch.nn.Conv2d(self.lora_A[idx].in_channels if hasattr(self.lora_A[idx], 'in_channels') else self.lora_A[idx].in_features,
+ self.lora_A[idx].out_channels if hasattr(self.lora_A[idx], 'out_channels') else self.lora_A[idx].out_features,
+ 1, bias=True if self.lora_A[idx].bias is not None else False)
+ lora_B_conv = torch.nn.Conv2d(self.lora_B[idx].in_channels if hasattr(self.lora_B[idx], 'in_channels') else self.lora_B[idx].in_features,
+ self.lora_B[idx].out_channels if hasattr(self.lora_B[idx], 'out_channels') else self.lora_B[idx].out_features,
+ 1, bias=True if self.lora_B[idx].bias is not None else False)
+ lora_A_weight = self.lora_A[idx].weight.reshape(self.lora_A[idx].weight.shape[:2])
+ lora_B_weight = self.lora_B[idx].weight.reshape(self.lora_B[idx].weight.shape[:2])
+ lora_A_conv.weight.data.copy_(lora_A_weight[:,:,None,None])
+ lora_B_conv.weight.data.copy_(lora_B_weight[:,:,None,None])
+ if self.lora_A[idx].bias is not None: lora_A_conv.bias.data.copy_(self.lora_A[0].bias.data)
+ if self.lora_B[idx].bias is not None: lora_B_conv.bias.data.copy_(self.lora_B[0].bias.data)
+ lora_A_conv.to(self.lora_A[idx].weight.data.device)
+ lora_B_conv.to(self.lora_B[idx].weight.data.device)
+ self.lora_A[idx] = lora_A_conv
+ self.lora_B[idx] = lora_B_conv
+
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ """ Forward pass for replaced layer """
+ if not x.ndim == 4:
+ raise ValueError(f"{self.__class__.__name__} expects 4D inputs formatted (NCHW), received inputs of shape {x.shape}")
+
+ result = self.base_layer(x, *args, **kwargs)
+ torch_result_dtype = result.dtype
+ for active_adapter in self.active_adapters:
+ if not self.active_adapters[active_adapter]:
+ continue
+
+ adapter_index = self.adapter_name_to_index[active_adapter]
+
+ lora_A = self.lora_A[adapter_index]
+ lora_B = self.lora_B[adapter_index]
+ dropout = self.lora_dropout[adapter_index]
+ # TODO: remove the to(device) once MPP error fixed
+ scaling = self.scaling[adapter_index].to(lora_A.weight.device)
+ x = x.to(lora_A.weight.dtype)
+
+ result = self.add_lora_to_res[adapter_index](
+ result, lora_B(self.mul_scale[adapter_index](lora_A(dropout(x)), scaling.detach()))
+ )
+
+ result = result.to(torch_result_dtype)
+ return result
+
+def _is_peft_lora_layer(layer: torch.nn.Module) -> bool:
+ """
+ Check if the layer is instance of peft.tuners.lora.layer.LoraLayer
+ :param layer: The torch layer to be checked
+ :return: If the layer is instance of peft.tuners.lora.layer.LoraLayer
+ """
+ return isinstance(layer, PeftLoraLayer)
+
+def _replace_peft_lora_modules(model: torch.nn.Module, replace_layer: torch.nn.Module) -> None:
+ """
+ Replace PEFT LoRA modules with the target replace_layer
+ :param model: The model on which to replace with LoRA modules
+ :param replace_layer: The target layer to replace with
+ :return: None. Replacement happens in-place
+ """
+
+ aimet_version = get_aimet_version()
+
+ if aimet_version >= version.Version('2.1'):
+ # For aimet_torch 2.1.x or greater
+ aimet_torch_utils.replace_modules(model, _is_peft_lora_layer, replace_layer)
+ else:
+ # For aimet_torch 2.0.x or earlier
+ aimet_torch_utils.replace_modules_of_type1_using_constructor(model, PeftLoraLayer, replace_layer)
+ aimet_torch_utils.replace_modules_of_type1_using_constructor(model, PeftConv2d, replace_layer)
+
+def replace_lora_layers_with_4d_quantizable_layers(model: torch.nn.Module):
+ """
+ Utility to replace lora layers with Quantizable Lora layers
+
+ :param model: PEFT model
+ """
+ _replace_peft_lora_modules(model, LoraLayer4d)
+
+
+def replace_lora_layers_with_no_cast_quantizable_layers(model: torch.nn.Module):
+ """
+ Utility to replace lora layers with Quantizable Lora layers
+
+ :param model: PEFT model
+ """
+ _replace_peft_lora_modules(model, LoraLayerNoCast)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/dev/utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/dev/utils.py
new file mode 100644
index 000000000..4df63d828
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/dev/utils.py
@@ -0,0 +1,220 @@
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+# =============================================================================
+# Copyright 2020 Optuna, Hugging Face
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+""" Common utilities for GenAI Lib """
+
+import functools
+import inspect
+import logging
+from importlib import metadata, util
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+
+import torch
+
+from packaging import version
+
+AIMET_TORCH_AVAILABLE = util.find_spec("aimet_torch") is not None
+
+
+@functools.cache
+def is_package_greater_or_equal(package_name: str, package_version: str) -> bool:
+ return version.parse(metadata.version(package_name)) >= version.parse(
+ package_version
+ )
+
+
+def is_transformers_greater_or_equal(package_version: str) -> bool:
+ return is_package_greater_or_equal('transformers', package_version)
+
+
+is_transformers_greater_or_equal_than_4_48 = is_transformers_greater_or_equal('4.48.0')
+is_transformers_greater_or_equal_than_4_51 = is_transformers_greater_or_equal('4.51.0')
+is_transformers_greater_or_equal_than_4_53 = is_transformers_greater_or_equal('4.53.0')
+
+
+def get_aimet_version():
+ # Here we are fetching the AIMET version by giving precedence to the open source version (if exists)
+ # This function returns the o/p of type packaging.version.Version, and not str.
+ package_names = ["AimetTorch", "aimet-torch"]
+ aimet_version = None
+ for package in package_names:
+ try:
+ aimet_version = version.parse(metadata.version(package))
+ except:
+ continue
+ assert aimet_version is not None, "Could not find any AIMET version"
+ return aimet_version
+
+def rgetattr(obj, attr, *args):
+ def _getattr(obj, attr):
+ return getattr(obj, attr, *args)
+ return functools.reduce(_getattr, [obj] + attr.split('.'))
+
+
+def rsetattr(obj, attr, val):
+ pre, _, post = attr.rpartition('.')
+ return setattr(rgetattr(obj, pre) if pre else obj, post, val)
+
+
+def change_signature_defaults(func: Callable, defaults_dict: Dict[str, Any]) -> Callable:
+ """
+ Utility that changes default values on a function's signature.
+ This is useful for boolean inputs that cannot be part of a dummy input for preparation
+ as they disappear on the traced graph.
+
+ Args:
+ func (Callable): The function whose signature defaults are to be changed.
+ defaults_dict (Dict[str, Any]): A dictionary where keys are parameter names and values are the new default values.
+
+ Returns:
+ Callable: A new function with the updated signature defaults.
+ """
+ sig = inspect.signature(func)
+ params = list(sig.parameters.values())
+
+ for i, param in enumerate(params):
+ if param.name in defaults_dict:
+ params[i] = param.replace(default=defaults_dict[param.name])
+
+ new_sig = sig.replace(parameters=params)
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ bound_args = new_sig.bind(*args, **kwargs)
+ bound_args.apply_defaults()
+ return func(*bound_args.args, **bound_args.kwargs)
+
+ wrapper.__signature__ = new_sig
+ return wrapper
+
+
+@functools.lru_cache(None)
+def warning_once(self, *args, **kwargs):
+ """
+ This method is identical to `logger.warning()`, but will emit the warning with the same message only once
+
+ Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
+ The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to
+ another type of cache that includes the caller frame information in the hashing function.
+ """
+ self.warning(*args, **kwargs)
+
+
+logging.Logger.warning_once = warning_once
+
+
+@functools.lru_cache(None)
+def info_once(self, *args, **kwargs):
+ """
+ This method is identical to `logger.info()`, but will emit the info with the same message only once
+
+ Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
+ The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to
+ another type of cache that includes the caller frame information in the hashing function.
+ """
+ self.info(*args, **kwargs)
+
+
+logging.Logger.info_once = info_once
+
+
+def filter_outputs(outputs: List[Union[torch.Tensor, Tuple]],
+ output_index_filter: List[Optional[Union[int, str]]]) -> Tuple:
+ """
+ Filters the outputs based on the provided index filter.
+
+ Args:
+ outputs (List[Union[torch.Tensor, Tuple]]): The list of model outputs.
+ output_index_filter (List[Optional[Union[int, str]]]): The list of index filters.
+ Each element can be an integer, a string (':'), or None.
+ If integer, that index is kept.
+ If string ":", all indexes are kept.
+ If None, the output is removed.
+
+ Returns:
+ Tuple: A tuple of filtered outputs.
+
+ Raises:
+ AssertionError: If the length of outputs does not match the length of output_index_filter.
+ """
+ assert len(outputs) == len(output_index_filter), \
+ f'output_index_filter must match expected model output length, ' \
+ f'got {len(outputs)} model outputs, but {len(output_index_filter)} index filters'
+
+ filtered_outputs = []
+ for idx, output_filter in enumerate(output_index_filter):
+ if output_filter is None:
+ continue
+ if output_filter == ':':
+ chosen_output = outputs[idx]
+ else:
+ chosen_output = outputs[idx][output_filter]
+ # unsqueeze to make up for the lost dimension when we access at [output_filter]
+ if isinstance(outputs[idx], torch.Tensor):
+ chosen_output = chosen_output.unsqueeze(0)
+ else:
+ chosen_output = (chosen_output, )
+ filtered_outputs.append(chosen_output)
+
+ return tuple(filtered_outputs)
+
+if AIMET_TORCH_AVAILABLE:
+ from aimet_torch.v2.nn import QuantizationMixin
+ from aimet_torch.v2.quantsim import QuantizationSimModel
+ from aimet_torch.v2.utils import _ContextManager
+
+ def extract_qmodules(sim, check_type) -> Generator:
+ """
+ Extracts and returns a generator of qmodules from the given simulation object that are instances of the specified type.
+
+ :param sim: QuantizationSimModel instance.
+ :param check_type: The type to check each qmodule against.
+ :return: A generator of qmodules that are instances of check_type.
+ """
+ return (x for x in sim.qmodules() if isinstance(x, check_type))
+
+ def reset_kernels(sim: QuantizationSimModel) -> _ContextManager:
+ """
+ Resets the kernels of all quantization modules within a QuantizationSimModel to None.
+
+ :param sim: QuantizationSimModel instance.
+ :return: A context manager that handles the registration and restoration of kernels.
+ """
+ orig = {
+ qmodule: qmodule.get_kernel()
+ for qmodule in extract_qmodules(sim, QuantizationMixin)
+ }
+
+ def restore_kernels():
+ for qmodule, kernel in orig.items():
+ qmodule.set_kernel(kernel)
+
+ ctx = _ContextManager(action=lambda: None, cleanup=restore_kernels)
+
+ try:
+ for qmodule in extract_qmodules(sim, QuantizationMixin):
+ qmodule.set_kernel(None)
+ except Exception:
+ ctx._cleanup()
+ raise
+ else:
+ return ctx
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/onnxruntime_utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/onnxruntime_utils.py
new file mode 100644
index 000000000..e866b2e99
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/common/onnxruntime_utils.py
@@ -0,0 +1,177 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+import torch
+from torch.utils._pytree import tree_map_only
+
+import numpy as np
+
+import onnxruntime as ort
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Protocol, Optional, List
+
+class ONNXNameMapper(Protocol):
+ """
+ Callable signature that maps a flattened ONNX input/output name to the
+ corresponding PyTorch tensor stored in an input/ouput dict.
+
+ A function that matches this protocol is responsible for looking up the
+ correct tensor when we require a memory pointer for I/O-binding.
+
+ params:
+ onnx_name (str): The flattened name of the ONNX input/output.
+ tensor_dict (Dict[str, Any]): The prepared_inputs/output_buffer.
+ is_output (bool): Whether the ONNX name is an output, in case different logic is necessary.
+ delim (str): The delimiter used in the flattened name.
+ returns:
+ torch.Tensor: The corresponding tensor in the prepared_inputs/output_buffer.
+ """
+ def __call__(
+ self,
+ onnx_name: str,
+ tensor_dict: Dict[str, Any],
+ is_output: bool,
+ delim: str = "_"
+ ) -> torch.Tensor:
+ ...
+
+class OutputBufferCreator(ABC):
+ """
+ Interface that builds an output buffer for ONNX Runtime IO Binding.
+
+ Every concrete implementation can customize the parameters required
+ in its constructor. Caller then keeps a reference to the instance and invokes
+ create_buffer() or calls the instance itself, with no need for arguments.
+ """
+ @abstractmethod
+ def create_buffer(self) -> Dict[str, Any]:
+ """
+ Allocate and return the buffer that will hold the model outputs.
+ """
+ ...
+
+ def __call__(self) -> Dict[str, Any]:
+ return self.create_buffer()
+
+class ORTInferenceModule(torch.nn.Module):
+ """
+ A PyTorch module that wraps an ONNX Runtime inference session. It applies I/O binding and creates an output buffer.
+
+ params:
+ ort_session (ort.InferenceSession): The ONNX Runtime inference session.
+ device (torch.device): device to place module on
+ onnx_name_to_tensor (ONNXNameMapper): A function that maps flattened ONNX names (str) to PyTorch Tensors.
+ prepare_output_buffer (Callable[..., Dict[str, Any]]): A function that prepares the output buffer.
+ """
+ def __init__(self,
+ ort_session:ort.InferenceSession,
+ device: torch.device,
+ onnx_name_to_tensor: ONNXNameMapper,
+ output_buffer_creator: OutputBufferCreator):
+ """
+ Constructor
+ """
+ super().__init__()
+ self.session = ort_session
+
+ self.onnx_name_to_tensor = onnx_name_to_tensor
+
+ self.output_buffer_creator = output_buffer_creator
+
+ # register a single dummy parameter so Transformers pipeline() is able to know where the module is located
+ self.register_parameter("_dummy", torch.nn.Parameter(torch.empty(0, device=device), requires_grad=False))
+
+ @staticmethod
+ def bind_io(session:ort.InferenceSession,
+ onnx_name_to_tensor: ONNXNameMapper,
+ input_buffer: Dict[str, Any],
+ output_buffer: Dict[str, Any],
+ output_names: Optional[List[str]] = None) -> ort.IOBinding:
+ """
+ This function produces an IOBinding that binds the input buffer and output buffer to the ONNX Runtime Session.
+
+ params:
+ session: (ort.InferenceSession): ONNX Runtime inference session to create IO Binding for
+ onnx_name_to_tensor (ONNXNameMapper): A function that maps flattened ONNX names (str) to PyTorch Tensors.
+ input_buffer (Dict[str, Any]): Dict that contains prepared_inputs.
+ output_buffer (Dict[str, Any]): Dict that contains empty memory allocated for outputs.
+ output_names (Optional[List[str]]): List of output names, in case user doesn't want to bind all outputs of the session
+
+ returns:
+ ort.IOBinding: IOBinding that binds the input buffer and output buffer to the ONNX Runtime Session
+ """
+ io_binding = session.io_binding()
+ pt_to_np = {
+ "torch.int32": np.int32,
+ "torch.int64": np.int64,
+ "torch.float32": np.float32,
+ "torch.float16": np.float16
+ }
+
+ for input in session.get_inputs():
+ # Retrive the tensor corresponding to the ONNX Runtime input name from the input_buffer
+ tensor_to_bind = onnx_name_to_tensor(input.name, input_buffer, is_output=False)
+ io_binding.bind_input(
+ name = input.name,
+ device_type = tensor_to_bind.device.type,
+ device_id = 0 if tensor_to_bind.device.type == "cpu" else tensor_to_bind.device.index,
+ element_type = pt_to_np[repr(tensor_to_bind.dtype)],
+ shape = tuple(tensor_to_bind.shape),
+ buffer_ptr = tensor_to_bind.data_ptr()
+ )
+
+ output_names = (
+ output_names
+ if output_names is not None
+ else [o.name for o in session.get_outputs()]
+ )
+
+ for output in output_names:
+ # Retrive the tensor corresponding to the ONNX Runtime output name from the output_buffer
+ tensor_to_bind = onnx_name_to_tensor(output, output_buffer, is_output=True)
+ io_binding.bind_output(
+ name = output,
+ device_type = tensor_to_bind.device.type,
+ device_id = 0 if tensor_to_bind.device.type == "cpu" else tensor_to_bind.device.index,
+ element_type = pt_to_np[repr(tensor_to_bind.dtype)],
+ shape = tuple(tensor_to_bind.shape),
+ buffer_ptr = tensor_to_bind.data_ptr()
+ )
+
+ return io_binding
+
+ @torch.no_grad()
+ def forward(self, output_names: Optional[List[str]] = None, **prepared_inputs: Any) -> Dict[str, Any]:
+ """
+ Forward function of ORTInferenceModule
+
+ params:
+ output_names (Optional[List[str]]): List of output names, in case user doesn't want to bind all outputs of the session
+ **prepared_inputs (Any): inputs to model
+
+ returns:
+ ort.IOBinding: IOBinding that binds the input buffer and output buffer to the ONNX Runtime Session
+ """
+ # ensure torch tensors are contiguous, for compatibility with IO Bindings
+ prepared_inputs = tree_map_only(torch.Tensor, lambda t: t.contiguous(), prepared_inputs)
+ # allocate new memory for output buffer, so there is no overlap with input
+ output_buffer = self.output_buffer_creator.create_buffer()
+ # Bind IO
+ io_binding = self.bind_io(session=self.session,
+ onnx_name_to_tensor=self.onnx_name_to_tensor,
+ input_buffer=prepared_inputs,
+ output_buffer=output_buffer,
+ output_names=output_names)
+
+ # binds extra outputs
+
+ self.session.run_with_iobinding(io_binding)
+
+ return output_buffer
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/baichuan/adaptation.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/baichuan/adaptation.py
new file mode 100644
index 000000000..3652ef9c5
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/baichuan/adaptation.py
@@ -0,0 +1,351 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# ==============================================================================
+# Copyright 2023 Baichuan Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+""" This file provides adaptations to the Baichuan model. These adaptations are being done to optimize the model execution on the HTP backend. https://github.com/baichuan-inc/baichuan-7B/blob/main/models/modeling_baichuan.py"""
+
+import math
+from typing import List, Optional, Tuple, Union
+from threading import Thread
+import importlib
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn import functional as F
+from transformers import PreTrainedModel, PretrainedConfig
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.generation.utils import GenerationConfig
+from transformers.utils import logging, ContextManagers
+import aimet_torch.elementwise_ops as op
+
+module_name = 'transformers_modules.HF-Baichuan2-7B-Instruct.modeling_baichuan'
+modeling_baichuan = importlib.import_module(module_name)
+for attribute_name in dir(modeling_baichuan):
+ if not attribute_name.startswith('__'):
+ globals()[attribute_name] = getattr(modeling_baichuan, attribute_name)
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2:]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos_, sin_, position_ids):
+ cos = cos_.squeeze(1).squeeze(0) # [seq_len, dim]
+ sin = sin_.squeeze(1).squeeze(0) # [seq_len, dim]
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
+ q_embed = (q.float() * cos) + (rotate_half(q.float()) * sin)
+ k_embed = (k.float() * cos) + (rotate_half(k.float()) * sin)
+ return q_embed.to(q.dtype), k_embed.to(k.dtype)
+
+
+def apply_rope_single(x, rope_vals: Tuple[torch.Tensor, torch.Tensor]):
+ '''
+ Based on FacebookResearch's llama, provided by Carl
+ '''
+ rope_real = rope_vals[0] # shape should be 1, 1, seqlen, head_dim/2
+ rope_im = rope_vals[1] # shape should be 1, 1, seqlen, head_dim/2
+
+ # TODO: Why HF uses different coordinates from the paper
+ x_real = x[:, :, :, :x.shape[-1] // 2] # extract first half elements
+ x_im = x[:, :, :, x.shape[-1] // 2:] # extract second half elements
+
+ x_prod_real = x_real * rope_real - x_im * rope_im
+ x_prod_im = x_real * rope_im + x_im * rope_real
+
+ # TODO: HF need to uses different interleaving
+ x = torch.cat((x_prod_real, x_prod_im), dim=3)
+ return x
+
+
+class QCAttention(Attention):
+ def __init__(self, config: PretrainedConfig):
+ Attention.__init__(self, config)
+ self.use_unpack_qkv = False
+
+ def unpack_qkv(self):
+ self.use_unpack_qkv = True
+
+ device = self.W_pack.weight.device
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False).to(device)
+ self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False).to(device)
+ self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False).to(device)
+
+ # Adjust the slicing to match the shapes
+ self.q_proj.weight.data.copy_(self.W_pack.weight[:self.hidden_size, :])
+ self.k_proj.weight.data.copy_(self.W_pack.weight[self.hidden_size:2 * self.hidden_size, :])
+ self.v_proj.weight.data.copy_(self.W_pack.weight[2 * self.hidden_size:, :])
+
+ def scaled_dot_product_attention(self, query, key, value, attn_mask, transposed_key_cache):
+ scale_factor = 1 / math.sqrt(query.size(-1))
+
+ if transposed_key_cache:
+ attn_weight = torch.matmul(query, key) * scale_factor
+ else:
+ attn_weight = torch.matmul(query, key.transpose(-2, -1)) * scale_factor
+
+ attn_weight += attn_mask
+ attn_weight = torch.softmax(attn_weight, dim=-1)
+ return torch.matmul(attn_weight, value)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # QC
+ transposed_key_cache = (
+ self.config.transposed_key_cache if hasattr(self.config, "transposed_key_cache") else False
+ )
+ return_new_key_value_only = (
+ self.config.return_new_key_value_only if hasattr(self.config, "return_new_key_value_only") else False
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+ if self.use_unpack_qkv:
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ else:
+ # Added original fused QKV implementation just to be safe
+ proj = self.W_pack(hidden_states)
+ proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
+ query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[1].shape[-2]
+
+ if isinstance(position_ids, (tuple, list)):
+ rope_embedding = position_ids
+ query_states = apply_rope_single(query_states, rope_embedding)
+ key_states = apply_rope_single(key_states, rope_embedding)
+ else:
+ cos, sin = self.rotary_emb(value_states, kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ # QC
+ if transposed_key_cache:
+ key_states = key_states.transpose(2, 3)
+
+ if return_new_key_value_only:
+ present_key_value = (key_states, value_states) if use_cache else None
+
+ if past_key_value is not None:
+ dim = 3 if transposed_key_cache else 2
+ key_states = torch.cat([past_key_value[0], key_states], dim=dim)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ if not return_new_key_value_only:
+ present_key_value = (key_states, value_states) if use_cache else None
+
+ attn_output = self.scaled_dot_product_attention(query_states, key_states, value_states, attention_mask,
+ transposed_key_cache)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, present_key_value
+
+
+_prepare_decoder_attention_mask = modeling_baichuan.BaichuanModel._prepare_decoder_attention_mask
+def adapted_prepare_decoder_attention_mask(self, attention_mask, *args, **kwargs):
+ if attention_mask is not None and attention_mask.dim() == 4:
+ return attention_mask
+ else:
+ return _prepare_decoder_attention_mask(self, attention_mask, *args, **kwargs)
+
+
+def QCBaichuanModel_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ elif isinstance(position_ids, (tuple, list)):
+ pass
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ )
+ attention_mask = self._prepare_decoder_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, output_attentions, None)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(decoder_layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ None,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+def update_attr(cls, attr_name, new_attr):
+ attr_backup_name = f'_original_{attr_name}'
+ if hasattr(cls, attr_name):
+ if not hasattr(cls, attr_backup_name):
+ setattr(cls, attr_backup_name, getattr(cls, attr_name))
+ setattr(cls, attr_name, new_attr)
+ return True
+ return False
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/baichuan/utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/baichuan/utils.py
new file mode 100644
index 000000000..9a0f72079
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/baichuan/utils.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+""" This file provides utilities that the pipeline would need to work with the adaptations made to Mistral model. """
+
+import torch
+import functools
+import importlib
+
+
+def llm_update_causal_mask(prepared_1d_attn_mask, input_tensor, max_input_tokens, model_context_len, model_id_or_path, mask_neg = -100.0):
+ '''
+
+ This function creates a causal mask (2D) from the 1D attention mask
+
+ params:
+ 1. prepared_1d_attn_mask: attention mask of shape (batch_size, model_context_length)
+ 2. input_tensor : input_ids/ input_embeddings
+ 3. max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ 4. model_context_len: maximum number of tokens that the model can consume in total
+ 5. model_id_or_path: #module name of baichuan under transformers_modules
+ 6. mask_neg: proxy for minus infinity since minus infinity is not quantization friendly. This value should be large
+ enough to drown out tokens that should not be attended to
+ '''
+ baichuan_model = _get_model(model_id_or_path)
+ input_embeds = torch.ones((input_tensor.shape[0], input_tensor.shape[1], 1), device = input_tensor.device)
+ prepared_attention_mask = baichuan_model._prepare_decoder_attention_mask(attention_mask=prepared_1d_attn_mask, input_shape = (input_tensor.shape[0], input_tensor.shape[1]), inputs_embeds=input_embeds,past_key_values_length = model_context_len-max_input_tokens)
+ prepared_attention_mask = prepared_attention_mask.clamp_min(mask_neg)
+ return prepared_attention_mask
+
+def llm_create_position_embeddings(config, model_id_or_path, position_ids=None):
+ '''
+ This function creates position embedding (RoPE) from the position ids.
+ params:
+ 1. config: model configuration to create the LLamaRotaryEmbedding object, expect config for backward compatibility, in future transformers, we only expect to pass the config, the one we have in docker, takes in the req argument
+ 2. model_id_or_path: #module name of baichuan under transformers_modules
+ 3. position_ids: required position ids passed into the model
+ '''
+ max_position_embeddings = config.max_position_embeddings
+ dim = config.head_dim if hasattr(config, 'head_dim') else config.hidden_size // config.num_attention_heads
+ device = position_ids.device
+ x = torch.ones(1, device=device)
+ rotary_emb = _get_rotary_embedding(dim=dim, max_position_embeddings=max_position_embeddings, device=device, model_id_or_path=model_id_or_path)
+ cos, sin = rotary_emb(x, seq_len=max_position_embeddings)
+ # the cos and sin returned are of shape (1, 1, max_position_emb, dim), in order to index into the max_position_emb, we remove the first two dimensions.
+ cos, sin = cos.squeeze(dim=1).squeeze(dim=0), sin.squeeze(dim=1).squeeze(dim=0)
+ cos, sin = (torch.stack([cos[[position_ids[i]]] for i in range(position_ids.shape[0])]),
+ torch.stack([sin[[position_ids[i]]] for i in range(position_ids.shape[0])]))
+ cos, sin = cos.unsqueeze(dim = 1), sin.unsqueeze(dim = 1)
+ cos = cos[:,:,:,:dim//2]
+ sin = sin[:,:,:, :dim//2]
+ return cos, sin
+
+@functools.cache
+def _get_rotary_embedding(dim, max_position_embeddings, device, model_id_or_path):
+ # here, model_id_or_path represents the module name of Baichuan model under transformers_modules
+ modeling_baichuan_name = model_id_or_path + '.modeling_baichuan'
+ modeling_baichuan = importlib.import_module(modeling_baichuan_name)
+ rotary_emb = modeling_baichuan.RotaryEmbedding(dim=dim, max_position_embeddings=max_position_embeddings, base=10000, device=device)
+ return rotary_emb
+
+@functools.cache
+def _get_model(model_id_or_path):
+ # here, model_id_or_path represents the module name of Baichuan model under transformers_modules
+ modeling_baichuan_name = model_id_or_path + '.modeling_baichuan'
+ modeling_baichuan = importlib.import_module(modeling_baichuan_name)
+ config_path_name = model_id_or_path + '.configuration_baichuan'
+ config_path = importlib.import_module(config_path_name)
+ config = config_path.BaichuanConfig()
+ config.num_hidden_layers = 1
+ model = modeling_baichuan.BaichuanModel(config)
+ return model
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/common/spinquant.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/common/spinquant.py
new file mode 100644
index 000000000..637003f33
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/common/spinquant.py
@@ -0,0 +1,38 @@
+
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+""" Implementation of concepts from the SpinQuant paper: https://arxiv.org/pdf/2405.16406"""
+
+import torch
+from torch import nn
+import scipy
+import math
+
+class R3Hadamard(nn.Linear):
+ def __init__(self, head_dim: int):
+ super(R3Hadamard, self).__init__(in_features=head_dim, out_features=head_dim, bias=False, dtype=torch.float)
+ self.head_dim = head_dim
+ self.is_initialized = False
+
+ def initialize_r3_hadamard(self):
+ r3_weight = torch.tensor(scipy.linalg.hadamard(self.head_dim) / math.sqrt(self.head_dim), dtype=torch.float)
+ if isinstance(self, nn.Linear):
+ self.weight.data.copy_(r3_weight.T)
+ elif isinstance(self, nn.Conv2d): # For support to ConvInplaceLinear
+ self.weight.data.copy_(r3_weight.T[:,:,None,None])
+ else:
+ raise TypeError(f"Class {self.__class__.__name__} became an instance of {type(self)}, \
+ but only nn.Linear and nn.Conv2d are supported types")
+ self.is_initialized = True
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if not self.is_initialized:
+ raise AssertionError(f"{self.__class__.__name__} has not been fully initialized. \
+ Invoke class method `initialize_r3_hadamard()` before doing a forward pass with this object")
+ return super.forward(x)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/common/utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/common/utils.py
new file mode 100644
index 000000000..2b60874cd
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/common/utils.py
@@ -0,0 +1,122 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+""" This file provides utilities that are common across models. These utilities are needed because of model adaptations """
+
+from transformers.utils import ModelOutput
+
+from genai_lib.llm.utils import _concat, _shift
+
+KEY_CONCAT_AXIS = 3
+VALUE_CONCAT_AXIS = 2
+
+def llm_update_kv_cache(unpadded_past_kv, current_key_values, key_concat_axis=KEY_CONCAT_AXIS, value_concat_axis=VALUE_CONCAT_AXIS, input_ids_slice = None, inputs_embeds_slice=None, pad_to_left=True, skip_pad_layers=()):
+ """
+ This function concats the KV cache that the model outputs in the current iteration (unpadded_past_kv) with the KV$ that the model has accumulated so far(unpadded_past_kv)
+ 1. remove the non-useful padding kv from the current_key_values depending on whether it was padded to left or to right
+ 2. concatenate the stripped current kv with past useful kv if it exists
+
+ params:
+ 1. unpadded_past_kv: the unpadded useful kv that is accumulated from the previous model invokations
+ 2. current_key_values: current padded kv returned from the model
+ 3. key_concat_axis: the axis to which we want to append the keys
+ 4. value_concat_axis: the axis to which we want to append the values
+ 5. input_ids_slice: the slice of inputs returned from the iterator (this is before any padding has been applied to meet the static shape requirement)
+ 6. inputs_embeds_slice: the slice of inputs returned from the iterator (this is before any padding has been applied to meet the static shape requirement)
+ 7. pad_to_left: boolean value indicating whether padding is done towards the left or right.
+ 8. skip_pad_layers: Layers not to pad, such as layers which are not updated each LLM inference (e.g. cross-attention layers)
+
+ """
+ input = input_ids_slice if input_ids_slice is not None else inputs_embeds_slice
+
+ # TODO determine whether we need to trim or not based on the current_pad_len, if negative do not trim. min set to 0
+ trimmed_current_key_values = trim_current_kv(current_key_values, input , key_concat_axis, value_concat_axis, pad_to_left)
+ # slicing in place before sending to concat function to avoid memory spiking.
+ if unpadded_past_kv:
+ concatenated_key_values = tuple(
+ (
+ _concat(unpadded_key, current_key, key_concat_axis) if i not in skip_pad_layers else current_key_values[i][0],
+ _concat(unpadded_value, current_value, value_concat_axis) if i not in skip_pad_layers else current_key_values[i][1]
+ ) for i, ((unpadded_key, unpadded_value), (current_key, current_value)) in enumerate(zip(unpadded_past_kv, trimmed_current_key_values))
+ )
+ return concatenated_key_values
+
+ return trimmed_current_key_values
+
+def trim_current_kv(current_key_values, input, key_concat_axis, value_concat_axis=2, pad_to_left=True, layer_indices_to_perform_trimming = None):
+ """
+ params:
+ 1. current_key_values: current padded/ unpadded kv returned from the model
+ 2. input: tensor with the shape (at dimension 1) of the post trimmed KV$
+ 3. key_concat_axis: the axis to which we want to append the keys
+ 4. value_concat_axis: the axis to which we want to append the values
+ 5. pad_to_left: whether to trim from left or right
+ 6. layer_idx_to_perform_trimming: None if trimming to be done to all layers, else pass a list representing indices on which we want to perform trimming.
+ if the current_pad_length we compute is positive, means the current keys have padding which need to be removed. But if the current_pad_length is negative, we do not remove anything since the user has already trimmed the kv before. (This is possible when this API gets invoked from the update_kv_cache API where the current_kv only refers to the selected draft and valid token, this will be smaller than the input_ids_slice size.
+ """
+ input_length = input.shape[1]
+ # limit the value of this to be non-negative, if it is negative, we assign it to be 0, hence no shifting.
+ if layer_indices_to_perform_trimming is None:
+ current_pad_length = max(0, (current_key_values[0][1].shape[2] - input_length))
+ else:
+ current_pad_length = max(0, (current_key_values[layer_indices_to_perform_trimming[0]][1].shape[2] - input_length))
+ trimmed_kv = []
+ for layer_idx, (current_key, current_value) in enumerate(current_key_values):
+ if layer_indices_to_perform_trimming is None or layer_idx in layer_indices_to_perform_trimming:
+ trimmed_key = _shift(current_key, key_concat_axis, current_pad_length, pad_to_left)
+ trimmed_value = _shift(current_value, value_concat_axis, current_pad_length, pad_to_left)
+ trimmed_kv.append((trimmed_key, trimmed_value))
+ elif layer_idx not in layer_indices_to_perform_trimming:
+ trimmed_kv.append((current_key, current_value))
+ return tuple(trimmed_kv)
+
+def llm_extract_past_kv_at_idx(accepted_token_index_list, current_kv, is_key_transposed=True):
+ '''
+ This function is responsible for extracting the kv cache at indices in the accepted tokens index list.
+ params:
+ 1. accepted_token_index_list: the list containing indices of the accepted tokens
+ 2. current_kv: current KV from which the subset KV is extracted.
+ 3. is_key_transposed: true if key cache is transposed
+ '''
+ n_hidden_layers = len(current_kv)
+ pruned_keys = [[] for _ in range(n_hidden_layers)]
+ pruned_values = [[] for _ in range(n_hidden_layers)]
+
+ for n_layer in range(n_hidden_layers):
+ if is_key_transposed:
+ pruned_keys[n_layer] = current_kv[n_layer][0][:, :, :, accepted_token_index_list]
+ else:
+ pruned_keys[n_layer] = current_kv[n_layer][0][:, :, accepted_token_index_list, :]
+ pruned_values[n_layer] = current_kv[n_layer][1][:, :, accepted_token_index_list, :]
+
+ return tuple((pruned_keys[i], pruned_values[i]) for i in range(n_hidden_layers))
+
+
+class QcUtilityMixin:
+ """
+ UtilityMixin intended to maintain backward compatibility with various versions of transformers package
+ """
+
+ def _extract_past_from_model_output(self, outputs: ModelOutput):
+ """
+ Copied this method from generation/utils.py to ensure compatibility with transformers>=4.49, as modeling_glm.py depends on it
+ """
+ past_key_values = None
+ cache_name = 'past_key_values'
+ if 'past_key_values' in outputs:
+ past_key_values = outputs.past_key_values
+ elif 'mems' in outputs:
+ past_key_values = outputs.mems
+ elif 'past_buckets_states' in outputs:
+ past_key_values = outputs.past_buckets_states
+ elif 'cache_params' in outputs:
+ past_key_values = outputs.cache_params
+ cache_name = 'cache_params'
+
+ return cache_name, past_key_values
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/glm4v/adaptation.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/glm4v/adaptation.py
new file mode 100644
index 000000000..af3a5c9bf
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/glm4v/adaptation.py
@@ -0,0 +1,314 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# coding=utf-8
+# Copyright 2024 The GLM & ZhipuAI team and HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+""" This file provides adaptations to the GLM-4v model. These adaptations are being done to optimize the model execution on the HTP backend. https://huggingface.co/THUDM/glm-edge-v-2b, https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm/modeling_glm.py"""
+
+import torch
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch.utils.checkpoint
+from torch import nn
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from genai_lib.llm.dev.model_adaptation.common.utils import QcUtilityMixin
+
+# TODO: import the modeling_glm from transformers package when available
+try:
+ from transformers_modules.modeling_glm import (
+ repeat_kv,
+ Cache,
+ DynamicCache,
+ GlmAttention,
+ GlmForCausalLM,
+ GlmModel,
+ GlmRotaryEmbedding,
+ GlmConfig,
+ apply_rotary_pos_emb
+ )
+except ImportError as e:
+ print(f"{e} \n Please instantiate Glm4v with AutoModelForCausalLM first to have transformers_modules cache")
+
+from genai_lib.common.dev.utils import filter_outputs
+
+def _apply_rope_single(x, rope_vals: Tuple[torch.Tensor, torch.Tensor]):
+ """
+ Perform rotary embedding based on the rope cos and sin values and return the embedded tensor
+ Inputs:
+ x: tensor, the tensor to be rotary-embedded
+ rope_vals: tuple, a tuple with rope values (cos, sin)
+ """
+
+ rope_real = rope_vals[0] # shape should be 1, 1, seqlen, head_dim/2
+ rope_im = rope_vals[1] # shape should be 1, 1, seqlen, head_dim/2
+
+ x_real = x[:, :, :, 0::2] # extract first half elements
+ x_im = x[:, :, :, 1::2] # extract second half elements
+
+ x_prod_real = x_real * rope_real - x_im * rope_im
+ x_prod_im = x_real * rope_im + x_im * rope_real
+
+ x = torch.cat((x_prod_real, x_prod_im), dim=3).view(*x.shape)
+ return x
+
+
+class QcGlmAttention(GlmAttention):
+ def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
+ super(QcGlmAttention, self).__init__(config, layer_idx)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # QC
+ return_new_key_value_only = self.config.return_new_key_value_only if hasattr(self.config,
+ 'return_new_key_value_only') else False
+ transposed_key_cache = self.config.transposed_key_cache if hasattr(self.config,
+ 'transposed_key_cache') else False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+
+ # QC
+ if cos.shape[-1] == query_states.shape[-1]:
+ # cos and sin haven't split by half yet
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, partial_rotary_factor=self.partial_rotary_factor
+ )
+ else:
+ query_states = _apply_rope_single(query_states, position_embeddings)
+ key_states = _apply_rope_single(key_states, position_embeddings)
+
+ # QC
+ if transposed_key_cache:
+ key_states = key_states.transpose(2, 3)
+
+ # QC
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position,
+ "return_new_key_value_only": return_new_key_value_only,
+ "transposed_key_cache": transposed_key_cache}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # QC
+ if transposed_key_cache:
+ attn_weights = torch.matmul(query_states, key_states) * self.scaling
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
+
+ # QC
+ if attention_mask is not None: # no matter the length, we just slice it
+ if attention_mask.shape[-1] != value_states.shape[-2]:
+ attention_mask = attention_mask[:, :, :, :value_states.shape[-2]]
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ attn_output = attn_output.view(bsz, q_len, -1)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class QcGlmForCausalLM(GlmForCausalLM, QcUtilityMixin):
+ def __init__(self, config: GlmConfig):
+ super().__init__(config)
+
+ # QC: replace the forward pass order
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ pixel_values: Optional[torch.Tensor] = None, # remove default value for MPP
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ num_logits_to_keep: Optional[int] = None, # remove default value for MPP
+ **loss_kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ num_logits_to_keep = num_logits_to_keep if num_logits_to_keep else getattr(self.config, "num_logits_to_keep", 0)
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # QC
+ if pixel_values is not None:
+ batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape
+ pixel_values = pixel_values.reshape(batch_size * num_concurrent_media * num_tiles, num_channels, height, width)
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ images=pixel_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+ unfiltered = (logits,) + outputs[1:]
+
+ idx_filter = getattr(self.config, "output_index_filter", None)
+ if idx_filter is not None:
+ filtered = filter_outputs(list(unfiltered), idx_filter)
+ else:
+ filtered = unfiltered
+
+ if not return_dict:
+ return ((loss,) + filtered) if loss is not None else filtered
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+orig_causal_mask = GlmModel._update_causal_mask
+def adapted_update_causal_mask(self, attention_mask, *args, **kwargs):
+ if attention_mask is not None and attention_mask.dim() == 4:
+ return attention_mask
+ else:
+ return orig_causal_mask(self, attention_mask, *args, **kwargs)
+
+
+orig_embedding_fwd = GlmRotaryEmbedding.forward
+def adapted_RotaryEmbedding(self, x, position_ids, *args, **kwargs):
+ if isinstance(position_ids, tuple) and len(position_ids) == 2:
+ return position_ids
+ else:
+ return orig_embedding_fwd(self, x, position_ids, *args, **kwargs)
+
+
+def DynamicCache_update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += value_states.shape[-2]
+
+ # Update the cache
+ if len(self.key_cache) <= layer_idx:
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+ else:
+ return_new_key_value_only = cache_kwargs.get('return_new_key_value_only', False)
+ transposed_key_cache = cache_kwargs.get('transposed_key_cache', False)
+ key_cat_dim = -1 if transposed_key_cache else -2
+
+ key_cache = torch.cat([self.key_cache[layer_idx], key_states], dim=key_cat_dim)
+ value_cache = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+ if return_new_key_value_only:
+ self.key_cache[layer_idx] = key_states
+ self.value_cache[layer_idx] = value_states
+ else:
+ self.key_cache[layer_idx] = key_cache
+ self.value_cache[layer_idx] = value_cache
+
+ return key_cache, value_cache
+
+
+def DynamicCache_get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ # TODO: deprecate this function in favor of `cache_position`
+ if len(self.value_cache) <= layer_idx:
+ return 0
+ return self.value_cache[layer_idx].shape[-2]
+
+
+def update_attr(cls, attr_name, new_attr):
+ attr_backup_name = f'_original_{attr_name}'
+ if hasattr(cls, attr_name):
+ if not hasattr(cls, attr_backup_name):
+ setattr(cls, attr_backup_name, getattr(cls, attr_name))
+ setattr(cls, attr_name, new_attr)
+ return True
+ return False
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/glm4v/modeling_eaglet_glm4v.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/glm4v/modeling_eaglet_glm4v.py
new file mode 100644
index 000000000..163625955
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/glm4v/modeling_eaglet_glm4v.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# coding=utf-8
+# Copyright 2024 The GLM & ZhipuAI team and HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+from torch import nn
+from genai_lib.llm.dev.model_adaptation.glm4v.adaptation import adapted_update_causal_mask, QcGlmAttention
+from genai_lib.llm.eaglet.base_draft_model import BaseDraftModel
+
+try:
+ from transformers_modules import modeling_glm
+except ImportError as e:
+ print(f"{e} \n Please instantiate Glm4v with AutoModelForCausalLM first to have transformers_modules cache")
+
+__all__ = ["GlmEagletDraftModel", "GlmEagletDecoderLayer"]
+
+
+class GlmEagletDraftModel(BaseDraftModel):
+ def __init__(self, config):
+ super().__init__(
+ config,
+ decoder_cls=GlmEagletDecoderLayer,
+ )
+
+ _update_causal_mask = adapted_update_causal_mask
+
+
+class GlmEagletDecoderLayer(modeling_glm.GlmDecoderLayer):
+ def __init__(self, config, layer_idx: int):
+ super().__init__(config=config, layer_idx=layer_idx)
+ if layer_idx == 0:
+ self.input_layernorm = nn.Identity()
+ self.self_attn = QcGlmAttention(config=config, layer_idx=layer_idx)
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/glm4v/utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/glm4v/utils.py
new file mode 100644
index 000000000..73db00f65
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/glm4v/utils.py
@@ -0,0 +1,149 @@
+import torch
+from typing import Optional, List
+
+try:
+ from transformers_modules.modeling_glm import GlmRotaryEmbedding
+except ImportError as e:
+ print(f"{e} \n Please instantiate Glm4v with AutoModelForCausalLM first to have transformers_modules cache")
+
+import functools
+
+
+def lmm_update_causal_mask(prepared_1d_attention_mask, input_embeds, max_input_tokens, model_context_len, model_id_or_path, mask_neg):
+ """
+ Precompute causal mask
+ Inputs:
+ prepared_1d_attention_mask: attention mask of shape (batch_size, model_context_length)
+ inputs_embeds: inputs_embeds sent to the model
+ max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ model_context_len: maximum number of tokens that the model can consume in total
+ model_id_or_path: model name or path to pretrained model
+ mask_neg: proxy for minus infinity since minus infinity is not quantization friendly. This value should be
+ large enough to drown out tokens that should not be attended to
+ """
+ model = _get_model(model_id_or_path=model_id_or_path)
+
+ cache_position = torch.arange(model_context_len-max_input_tokens, model_context_len, device = input_embeds.device)
+
+ causal_mask = model._update_causal_mask(attention_mask=prepared_1d_attention_mask, input_tensor=input_embeds, output_attentions=True,
+ cache_position=cache_position, past_key_values=None)
+
+ causal_mask = causal_mask.clamp_min(mask_neg)
+ return causal_mask.to(dtype=input_embeds.dtype)
+
+
+def llm_create_position_embeddings(config, position_ids):
+ """
+ Precompute rotary embeddings
+ Inputs:
+ config: model config
+ position_ids: position_ids sent to the model
+ max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ """
+ assert position_ids is not None
+ device = position_ids.device
+ dim = int(config.head_dim * config.partial_rotary_factor)
+ rotary_embed = _get_rotary_embedding(dim, config.max_position_embeddings, config.rope_theta, device)
+
+ cos, sin = rotary_embed(torch.ones(1, device=device), position_ids.cpu())
+ cos, sin = cos.unsqueeze(dim=1), sin.unsqueeze(dim=1)
+ cos = cos[:, :, :, :dim // 2]
+ sin = sin[:, :, :, :dim // 2]
+
+ return (cos.to(device=device), sin.to(device=device))
+
+
+@functools.cache
+def _get_rotary_embedding(dim, max_position_embeddings, base, device):
+ return GlmRotaryEmbedding(dim=dim, max_position_embeddings=max_position_embeddings, base=base, device=device)
+
+
+def lmm_preprocess_inputs(input_ids=None, images=None, inputs_embeds=None, past_key_values=None, boi_token_id=None,
+ embedding_layer=None, vision_model=None):
+ """
+ Preprocess the inputs to accommodate image features to inputs_embeds, by default the first inference shouldn't
+ contain past_key_values to run multimodal generation
+ Inputs:
+ input_ids: input_ids sent to the model
+ images: images sent to the model
+ inputs_embeds: embedded inputs sent to the model
+ past_key_values: past_key_valyes sent to the model
+ boi_token_id: begin of image token id defined in config
+ embedding_layer: the embedding layer used to embed input_ids
+ vision_model: the model used to generate image features from given images
+ Output:
+ input_embeds: embedded inputs
+
+ """
+
+ if past_key_values is None or len(past_key_values) == 0:
+ assert inputs_embeds is None, "inputs_embeds should be generated from the input_ids and image_embeds " \
+ "(if provided) for the first inference (none kvcache mode)"
+ assert input_ids is not None, f"input_ids should be provided for the first inference"
+ if not _is_empty(images): # multi-modality
+ assert embedding_layer is not None
+ assert vision_model is not None
+
+ assert len(input_ids) == len(images), f"{len(input_ids)} {len(images)}"
+ inputs_embeds = embedding_layer(input_ids)
+ multi_flags = [True if boi_token_id in input_id.tolist() else False for input_id in input_ids]
+
+ images = images.to(dtype=inputs_embeds.dtype)
+ batch_size, num_concurrent_media, num_tiles, num_channels, height, width = images.shape
+ images = images.reshape(batch_size * num_concurrent_media * num_tiles, num_channels, height, width)
+ images_features = vision_model(images)
+
+ new_input_embeds = []
+
+ """
+ Process inputs_embeds for each sample and batch them back to tensors
+ boi_token and eoi_token will be trimmed and the position id for the image features will be repeated to
+ the size of the image features
+ """
+ image_count = 0
+ for i in range(len(input_ids)):
+ input_id = input_ids[i].tolist()
+ if multi_flags[i]:
+ boi_token_pos = input_id.index(boi_token_id)
+ assert boi_token_pos >= 0, "begin_of_image not found!"
+ num_image_padding_tokens = input_id.count(boi_token_id)
+ assert (
+ num_image_padding_tokens == images_features[image_count].shape[0]
+ ), f"Wrong image padding token number: {num_image_padding_tokens}"
+
+ new_input_embeds.append(torch.cat(
+ (inputs_embeds[i, :boi_token_pos], images_features[image_count].to(inputs_embeds.device),
+ inputs_embeds[i, boi_token_pos + num_image_padding_tokens:])))
+ image_count += 1
+ else:
+ new_input_embeds.append(inputs_embeds[i])
+ inputs_embeds = torch.stack(new_input_embeds, dim=0)
+ return inputs_embeds
+
+ return embedding_layer(input_ids)
+
+
+def _is_empty(images_list: Optional[List[List[torch.Tensor]]]):
+ if images_list is None or len(images_list) == 0:
+ return True
+ for image_list in images_list:
+ if image_list is not None:
+ return False
+ return True
+
+
+def _get_position_ids(input_ids, device):
+ batch_size, seq_length = input_ids.shape
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
+ return position_ids
+
+
+@functools.cache
+def _get_model(model_id_or_path):
+ from transformers import AutoConfig
+ from transformers_modules.modeling_glm import GlmModel
+ config = AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=True)
+ config.num_layers = 1
+ config.vision_config['num_hidden_layers'] = 1
+ model = GlmModel(config)
+ return model
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/gpt2/utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/gpt2/utils.py
new file mode 100644
index 000000000..c19efa99b
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/gpt2/utils.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+import torch
+
+def _update_causal_mask(attention_mask, input_tensor, cache_position, past_key_values):
+ """
+ The helper function which creates the 4d causal mask given a 2d attention mask.
+ The below implementation is a simplified version of the _update_causal_mask found in huggingface transformers for Llama.
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+ 1. prepared_1d_attn_mask: attention mask of shape (batch_size, model_context_length)
+ 2. input_tensor : input_ids/ input_embeddings
+ 3. cache_position : tensor that captures the positions of the valid input ids (query) in the given attention mask
+ 4. past_key_values : past kv sent into the model
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ return attention_mask
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ batch_size = input_tensor.shape[0]
+ min_dtype = torch.finfo(dtype).min
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+
+ sequence_length = input_tensor.shape[1]
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+ return causal_mask
+
+def llm_update_causal_mask(prepared_1d_attn_mask, input_tensor, max_input_tokens, model_context_len, mask_neg = -100.0):
+ '''
+
+ This function creates a causal mask (2D) from the 1D attention mask
+
+ params:
+ 1. prepared_1d_attn_mask: attention mask of shape (batch_size, model_context_length)
+ 2. input_tensor : input_ids/ input_embeddings
+ 3. max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ 4. model_context_len: maximum number of tokens that the model can consume in total
+ 5. mask_neg: proxy for minus infinity since minus infinity is not quantization friendly. This value should be large
+ enough to drown out tokens that should not be attended to
+ '''
+ cache_position = torch.arange(model_context_len - max_input_tokens, model_context_len, device=input_tensor.device)
+ input_embeds = torch.ones((input_tensor.shape[0], input_tensor.shape[1], 1), device=input_tensor.device)
+ prepared_attention_mask = _update_causal_mask(attention_mask=prepared_1d_attn_mask,
+ input_tensor=input_embeds,
+ cache_position=cache_position,
+ past_key_values=None)
+ prepared_attention_mask = prepared_attention_mask.clamp_min(mask_neg)
+ return prepared_attention_mask
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/indus/adaptation.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/indus/adaptation.py
new file mode 100644
index 000000000..c130c5320
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/indus/adaptation.py
@@ -0,0 +1,453 @@
+#!/usr/bin/env python3
+# ======================================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# ======================================================================================
+
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" This file provides adaptations to the Indus model. These adaptations are being done to optimize the model execution on the HTP backend. https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py"""
+
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from transformers.modeling_attn_mask_utils import (
+ _prepare_4d_attention_mask_for_sdpa,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from transformers.pytorch_utils import Conv1D
+from transformers.models.gpt2 import modeling_gpt2
+from transformers.models.gpt2.modeling_gpt2 import (
+ GPT2Attention,
+ BaseModelOutputWithPastAndCrossAttentions,
+)
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+def snr(signal, noisy, eps=1e-10):
+ s = signal.pow(2).mean()
+ n = (signal-noisy).pow(2).mean() + eps
+ return 10 * torch.log10(s / n)
+
+class QcGpt2Attention(GPT2Attention):
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
+ GPT2Attention.__init__(self, config, is_cross_attention, layer_idx)
+ self.use_unpack_qkv = False
+
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+ #QC
+ use_combined_mask_input = self.config.use_combined_mask_input if hasattr(self.config, 'use_combined_mask_input') else False
+ transposed_key_cache = self.config.transposed_key_cache if hasattr(self.config, 'transposed_key_cache') else False
+
+ #QC
+ attn_weights = torch.matmul(query, key if transposed_key_cache else key.transpose(-1, -2))
+
+ if self.scale_attn_weights:
+ attn_weights = attn_weights / torch.full(
+ [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+ )
+
+ # Layer-wise attention scaling
+ if self.scale_attn_by_inverse_layer_idx:
+ attn_weights = attn_weights / float(self.layer_idx + 1)
+
+ if not use_combined_mask_input and not self.is_cross_attention:
+ # if only "normal" attention layer implements causal mask
+ query_length, key_length = query.size(-2), key.size(-2)
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
+ attn_weights = attn_weights.type(value.dtype)
+ attn_weights = self.attn_dropout(attn_weights)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+
+ return attn_output, attn_weights
+
+ def forward(
+ self,
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+
+ #QC
+ return_new_key_value_only = self.config.return_new_key_value_only if hasattr(self.config, 'return_new_key_value_only') else False
+ transposed_key_cache = self.config.transposed_key_cache if hasattr(self.config, 'transposed_key_cache') else False
+
+ #QC
+
+ if self.use_unpack_qkv:
+ query = self.q_proj(hidden_states)
+ key = self.k_proj(hidden_states)
+ value = self.v_proj(hidden_states)
+ else:
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+ shape_q = (*query.shape[:-1], -1, self.head_dim)
+ shape_kv = (*key.shape[:-1], -1, self.head_dim)
+
+ query = query.view(shape_q).transpose(1, 2)
+ key = key.view(shape_kv).transpose(1, 2)
+ value = value.view(shape_kv).transpose(1, 2)
+
+ #QC
+ if transposed_key_cache:
+ key = key.transpose(-1, -2)
+
+ #QC
+ if use_cache and return_new_key_value_only:
+ present = (key, value)
+
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ key = torch.cat((past_key, key), dim=-1 if transposed_key_cache else -2)
+ value = torch.cat((past_value, value), dim=-2)
+
+ #QC
+ if not return_new_key_value_only:
+ if use_cache is True:
+ present = (key, value)
+ else:
+ present = None
+
+ if self.reorder_and_upcast_attn:
+ attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
+ else:
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+ attn_output = self.c_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+
+ outputs = (attn_output, present)
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs # a, present, (attentions)
+
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
+ """
+ :param tensor: attn_output tensor and its shape is (bsz, num_attn_heads, seq_len, head_dim)
+
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
+ """
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
+ return tensor.view(new_shape)
+
+ # QC
+ def unpack_qkv(self):
+ self.use_unpack_qkv = True
+
+ device = self.c_attn.weight.device
+ self.q_proj = Conv1D(self.embed_dim, self.embed_dim).to(device)
+ self.k_proj = Conv1D(self.embed_dim, self.embed_dim).to(device)
+ self.v_proj = Conv1D(self.embed_dim, self.embed_dim).to(device)
+
+ self.q_proj.weight.data.copy_(self.c_attn.weight[:, : self.embed_dim])
+ self.k_proj.weight.data.copy_(self.c_attn.weight[:, self.embed_dim : 2 * self.embed_dim])
+ self.v_proj.weight.data.copy_(self.c_attn.weight[:, 2 * self.embed_dim :])
+
+ self.q_proj.bias.data.copy_(self.c_attn.bias[: self.embed_dim])
+ self.k_proj.bias.data.copy_(self.c_attn.bias[self.embed_dim : 2 * self.embed_dim])
+ self.v_proj.bias.data.copy_(self.c_attn.bias[2 * self.embed_dim :])
+
+ #del self.c_attn
+
+
+def GPT2Model_forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+
+ # QC
+ use_combined_mask_input = self.config.use_combined_mask_input if hasattr(self.config, 'use_combined_mask_input') else False
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ batch_size = input_ids.shape[0]
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size = inputs_embeds.shape[0]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+ if past_key_values is None:
+ past_length = 0
+ past_key_values = tuple([None] * len(self.h))
+ else:
+ past_length = past_key_values[0][0].size(-2)
+ if position_ids is None:
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+ position_ids = position_ids.unsqueeze(0)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+ position_embeds = self.wpe(position_ids)
+ hidden_states = inputs_embeds + position_embeds
+
+ # Attention mask.
+ _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
+ attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None and not use_combined_mask_input else attention_mask
+ if self._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif _use_sdpa:
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask=attention_mask,
+ input_shape=(batch_size, input_shape[-1]),
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_length,
+ )
+ else:
+ if attention_mask is not None and not use_combined_mask_input:
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ if _use_sdpa:
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+ elif not self._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # head_mask has shape n_layer x batch x n_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if token_type_ids is not None:
+ token_type_embeds = self.wte(token_type_ids)
+ hidden_states = hidden_states + token_type_embeds
+
+ hidden_states = self.drop(hidden_states)
+
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ # Model parallel
+ if self.model_parallel:
+ torch.cuda.set_device(hidden_states.device)
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
+ if layer_past is not None:
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+ # Ensure that attention_mask is always on the same device as hidden_states
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(hidden_states.device)
+ if isinstance(head_mask, torch.Tensor):
+ head_mask = head_mask.to(hidden_states.device)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ outputs = self._gradient_checkpointing_func(
+ block.__call__,
+ hidden_states,
+ None,
+ attention_mask,
+ head_mask[i],
+ encoder_hidden_states,
+ encoder_attention_mask,
+ use_cache,
+ output_attentions,
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask[i],
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+
+ # Model Parallel: If it's the last layer for that device, put things on the next device
+ if self.model_parallel:
+ for k, v in self.device_map.items():
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+ hidden_states = self.ln_f(hidden_states)
+
+ hidden_states = hidden_states.view(output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+ if v is not None
+ )
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+def DynamicCache_update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += value_states.shape[-2]
+
+ # Update the cache
+ if len(self.key_cache) <= layer_idx:
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+ else:
+ return_new_key_value_only = cache_kwargs.get('return_new_key_value_only', False)
+ transposed_key_cache = cache_kwargs.get('transposed_key_cache', False)
+ key_cat_dim = -1 if transposed_key_cache else -2
+
+ key_cache = torch.cat([self.key_cache[layer_idx], key_states], dim=key_cat_dim)
+ value_cache = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+ if return_new_key_value_only:
+ self.key_cache[layer_idx] = key_states
+ self.value_cache[layer_idx] = value_states
+ else:
+ self.key_cache[layer_idx] = key_cache
+ self.value_cache[layer_idx] = value_cache
+ return key_cache, value_cache
+
+
+def DynamicCache_get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ # TODO: deprecate this function in favor of `cache_position`
+ if len(self.value_cache) <= layer_idx:
+ return 0
+ return self.value_cache[layer_idx].shape[-2]
+
+
+def update_attr(cls, attr_name, new_attr):
+ attr_backup_name = f'_original_{attr_name}'
+ if hasattr(cls, attr_name):
+ if not hasattr(cls, attr_backup_name):
+ setattr(cls, attr_backup_name, getattr(cls, attr_name))
+ setattr(cls, attr_name, new_attr)
+ return True
+ return False
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/indus/utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/indus/utils.py
new file mode 100644
index 000000000..e57eaba9f
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/indus/utils.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+""" This file provides utilities that the pipeline would need to work with the adaptations made to LLaMa model. """
+
+from genai_lib.llm.dev.model_adaptation.gpt2.utils import llm_update_causal_mask
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/jais/adaptation.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/jais/adaptation.py
new file mode 100644
index 000000000..18a2ce52d
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/jais/adaptation.py
@@ -0,0 +1,457 @@
+#!/usr/bin/env python3
+# ======================================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# ======================================================================================
+
+# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" This file provides adaptations to the Jais model. These adaptations are being done to optimize the model execution on the HTP backend. https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py"""
+
+import math
+import os
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+)
+from transformers.pytorch_utils import Conv1D
+from transformers_modules.modeling_jais import JAISBlock, JAISModel, JAISAttention, AlibiPositionEmbeddingLayer, \
+ JAISPreTrainedModel
+
+from genai_lib.llm.dev.model_adaptation.jais.utils import llm_create_position_embeddings
+
+
+def QCAlibiPositionEmbeddingLayer_forward(
+ self,
+ position_id,
+ cached_qk_len,
+ use_relative_position_ids=True
+):
+ if use_relative_position_ids:
+ relative_position = position_id
+ else:
+ relative_position = llm_create_position_embeddings(position_id, cached_qk_len)
+
+ relative_position = relative_position.unsqueeze(0).expand(self.num_heads, -1, -1)
+
+ if self.alibi_scaling is None:
+ scale = 1.0
+ elif self.alibi_scaling.get("factor") is not None:
+ scale = self.alibi_scaling["factor"]
+ elif relative_position.shape[-1] > self.alibi_scaling["train_seq_len"]:
+ scale = relative_position.shape[-1] / self.alibi_scaling["train_seq_len"]
+ else:
+ scale = 1.0
+
+ alibi = (self.slopes.to(position_id.device) / -scale).unsqueeze(1) * relative_position
+ return alibi
+
+
+class QCJAISAttention(JAISAttention):
+ def __init__(self, config, is_cross_attention=False, layer_idx=None):
+ JAISAttention.__init__(self, config, is_cross_attention, layer_idx)
+
+ # QC
+ self.transposed_key_cache = (
+ config.transposed_key_cache if hasattr(config, "transposed_key_cache") else False
+ )
+ self.return_new_key_value_only = (
+ config.return_new_key_value_only if hasattr(config, "return_new_key_value_only") else False
+ )
+ self.use_combined_mask_input = (
+ config.use_combined_mask_input if hasattr(config, "use_combined_mask_input") else False
+ )
+ self.use_unpack_qkv = False
+
+ # QC
+ def unpack_qkv(self):
+ self.use_unpack_qkv = True
+
+ device = self.c_attn.weight.device
+ self.q_proj = Conv1D(self.embed_dim, self.embed_dim).to(device)
+ self.k_proj = Conv1D(self.embed_dim, self.embed_dim).to(device)
+ self.v_proj = Conv1D(self.embed_dim, self.embed_dim).to(device)
+
+ self.q_proj.weight.data.copy_(self.c_attn.weight[:, : self.embed_dim])
+ self.k_proj.weight.data.copy_(self.c_attn.weight[:, self.embed_dim: 2 * self.embed_dim])
+ self.v_proj.weight.data.copy_(self.c_attn.weight[:, 2 * self.embed_dim:])
+
+ self.q_proj.bias.data.copy_(self.c_attn.bias[: self.embed_dim])
+ self.k_proj.bias.data.copy_(self.c_attn.bias[self.embed_dim: 2 * self.embed_dim])
+ self.v_proj.bias.data.copy_(self.c_attn.bias[2 * self.embed_dim:])
+
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None, position_bias=None):
+ # QC
+ if self.transposed_key_cache:
+ attn_weights = torch.matmul(query, key)
+ else:
+ attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+ if self.scale_attn_weights:
+ attn_weights = attn_weights / torch.full(
+ [], value.size(-1) ** self.attn_scale_power, dtype=attn_weights.dtype, device=attn_weights.device
+ )
+
+ # Layer-wise attention scaling
+ if self.scale_attn_by_inverse_layer_idx:
+ attn_weights = attn_weights / float(self.layer_idx + 1)
+
+ # QC
+ if not self.use_combined_mask_input and not self.is_cross_attention:
+ # if only "normal" attention layer implements causal mask
+ query_length, key_length = query.size(-2), key.size(-2)
+ causal_mask = self.bias[:, :, key_length - query_length: key_length, :key_length]
+ mask_value = torch.finfo(attn_weights.dtype).min
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+ attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ attn_weights = attn_weights + attention_mask
+
+ if position_bias is not None:
+ attn_weights += position_bias.type_as(attn_weights).unsqueeze(0)
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
+ attn_weights = attn_weights.type(value.dtype)
+ attn_weights = self.attn_dropout(attn_weights)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+
+ return attn_output, attn_weights
+
+ def forward(
+ self,
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ position_bias: Optional[torch.FloatTensor] = None,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+ if encoder_hidden_states is not None:
+ if not hasattr(self, "q_attn"):
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `JAISAttention(..., is_cross_attention=True)`."
+ )
+
+ query = self.q_attn(hidden_states)
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+ attention_mask = encoder_attention_mask
+ else: # QC
+ if self.use_unpack_qkv:
+ query = self.q_proj(hidden_states)
+ key = self.k_proj(hidden_states)
+ value = self.v_proj(hidden_states)
+ else:
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+ query = self._split_heads(query, self.num_heads, self.head_dim)
+ key = self._split_heads(key, self.num_heads, self.head_dim)
+ value = self._split_heads(value, self.num_heads, self.head_dim)
+
+ # QC
+ if self.transposed_key_cache:
+ key = key.transpose(2, 3)
+
+ # QC
+ if self.return_new_key_value_only:
+ present = (key, value) if use_cache is True else None
+
+ if layer_past is not None:
+ key_seq_dim = -1 if self.transposed_key_cache else -2 # QC
+ past_key, past_value = layer_past
+ key = torch.cat((past_key, key), dim=key_seq_dim) # QC
+ value = torch.cat((past_value, value), dim=-2)
+
+ if not self.return_new_key_value_only: # QC
+ if use_cache is True:
+ present = (key, value)
+ else:
+ present = None
+
+ if self.reorder_and_upcast_attn:
+ attn_output, attn_weights = self._upcast_and_reordered_attn(
+ query, key, value, attention_mask, head_mask, position_bias
+ )
+ else:
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask, position_bias)
+
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+ attn_output = self.c_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+
+ outputs = (attn_output, present)
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs # a, present, (attentions)
+
+
+def QCJAISModel_forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ # QC
+ use_relative_position_ids = self.config.use_relative_position_ids if hasattr(self.config, 'use_relative_position_ids') else False
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # QC
+ use_combined_mask_input = self.config.use_combined_mask_input if hasattr(self.config,
+ 'use_combined_mask_input') else False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ batch_size = input_ids.shape[0]
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size = inputs_embeds.shape[0]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+ # QC
+ if not use_relative_position_ids and position_ids is not None:
+ position_ids = position_ids.view(-1, input_shape[-1])
+
+ if past_key_values is None:
+ past_length = 0
+ past_key_values = tuple([None] * len(self.h))
+ else:
+ past_length = past_key_values[0][0].size(-2)
+ if position_ids is None:
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+ # QC: JAISAttention mask.
+ if attention_mask is not None and not use_combined_mask_input:
+ if batch_size <= 0:
+ raise ValueError("batch_size has to be defined and > 0")
+ attention_mask = attention_mask.view(batch_size, -1)
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # head_mask has shape n_layer x batch x n_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+ if self.wpe is not None:
+ position_embeds = self.wpe(position_ids)
+ hidden_states = inputs_embeds + position_embeds
+ else:
+ hidden_states = inputs_embeds
+ hidden_states *= torch.tensor(
+ float(self.embeddings_scale), dtype=hidden_states.dtype, device=hidden_states.device
+ )
+
+ if token_type_ids is not None:
+ token_type_embeds = self.wte(token_type_ids)
+ hidden_states = hidden_states + token_type_embeds
+
+ hidden_states = self.drop(hidden_states)
+
+ if self.relative_pe is not None:
+ # QC
+ cached_kv_length = 0
+ cached_kv = past_key_values[0]
+ if cached_kv is not None:
+ cached_kv_length = cached_kv[1].shape[-2] # QC
+ # QC
+ position_bias = self.relative_pe(position_ids, cached_kv_length, use_relative_position_ids)
+ else:
+ position_bias = None
+
+ output_shape = input_shape + (hidden_states.size(-1),)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+ # Model parallel
+ if self.model_parallel:
+ torch.cuda.set_device(hidden_states.device)
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
+ if layer_past is not None:
+ layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
+ # Ensure that attention_mask is always on the same device as hidden_states
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(hidden_states.device)
+ if isinstance(head_mask, torch.Tensor):
+ head_mask = head_mask.to(hidden_states.device)
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(*inputs, use_cache, output_attentions)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ None,
+ attention_mask,
+ head_mask[i],
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=attention_mask,
+ head_mask=head_mask[i],
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ position_bias=position_bias,
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
+
+ # Model Parallel: If it's the last layer for that device, put things on the next device
+ if self.model_parallel:
+ for k, v in self.device_map.items():
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+ hidden_states = self.ln_f(hidden_states)
+ hidden_states = hidden_states.view(output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
+ if v is not None
+ )
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+def update_attr(cls, attr_name, new_attr):
+ attr_backup_name = f'_original_{attr_name}'
+ if hasattr(cls, attr_name):
+ if not hasattr(cls, attr_backup_name):
+ setattr(cls, attr_backup_name, getattr(cls, attr_name))
+ setattr(cls, attr_name, new_attr)
+ return True
+ return False
+
+
+orig_embedding_fwd = AlibiPositionEmbeddingLayer.forward
+def adapted_AlibiPositionEmbedding(self, *args, **kwargs):
+ if isinstance(args[0], torch.Tensor):
+ return QCAlibiPositionEmbeddingLayer_forward(self, *args, **kwargs)
+ else:
+ return orig_embedding_fwd(self, *args, **kwargs)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/jais/utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/jais/utils.py
new file mode 100644
index 000000000..546062cf3
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/jais/utils.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+import torch
+from genai_lib.llm.dev.model_adaptation.gpt2.utils import llm_update_causal_mask
+
+
+def llm_create_position_embeddings(position_id, past_length):
+ position_id = position_id - position_id.min() + past_length
+ context_position = position_id.view(-1, 1)
+ memory_position = torch.concat([torch.arange(past_length, device=position_id.device).view(1, -1), position_id], dim=1)
+ relative_position = memory_position - context_position
+ relative_position = torch.abs(relative_position)
+ return relative_position
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/llama/adaptation.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/llama/adaptation.py
new file mode 100644
index 000000000..62d14a5ef
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/llama/adaptation.py
@@ -0,0 +1,473 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+""" This file provides adaptations to the LLaMa model. These adaptations are being done to
+optimize the model execution on the HTP backend.
+https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py"""
+
+""" This file provides adaptations to the LLaMa model. These adaptations are being done to optimize the model execution on the HTP backend. """
+
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+from importlib.metadata import version
+from importlib import util
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from transformers import LlamaForCausalLM
+from transformers import cache_utils
+from transformers.models.llama import modeling_llama
+from transformers.models.llama.modeling_llama import (
+ repeat_kv,
+ Cache,
+ DynamicCache,
+ LlamaAttention,
+ LlamaConfig,
+ apply_rotary_pos_emb,
+)
+from genai_lib.common.dev.utils import AIMET_TORCH_AVAILABLE
+
+if AIMET_TORCH_AVAILABLE:
+ from genai_lib.llm.long_context_utils import AnchorUpdaterKeySecond
+
+from genai_lib.common.dev.utils import filter_outputs
+from genai_lib.common.dev.utils import is_transformers_greater_or_equal
+
+from transformers.utils import logging
+logger = logging.get_logger(__name__)
+
+
+def _apply_rope_single(x, rope_vals: Tuple[torch.Tensor, torch.Tensor]):
+ '''
+ Based on FacebookResearch's llama, provided by Carl
+ '''
+ rope_real = rope_vals[0] # shape should be 1, 1, seqlen, head_dim/2
+ rope_im = rope_vals[1] # shape should be 1, 1, seqlen, head_dim/2
+
+ # TODO: Why HF uses different coordinates from the paper
+ x_real = x[:,:,:,:x.shape[-1]//2] # extract first half elements
+ x_im = x[:,:,:,x.shape[-1]//2:] # extract second half elements
+
+ x_prod_real = x_real*rope_real - x_im * rope_im
+ x_prod_im = x_real*rope_im + x_im*rope_real
+
+ # TODO: HF need to uses different interleaving
+ x = torch.cat((x_prod_real,x_prod_im),dim=3).view(*x.shape)
+ return x
+
+
+class QcLlamaAttention(LlamaAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper
+ We override the init and initialize the anchor updater.
+ The user can optionally pass in the alpha value to initialize the anchor_updater through the model config.
+ """
+
+ def __init__(self, config: LlamaConfig, layer_idx: int):
+ super(QcLlamaAttention, self).__init__(config, layer_idx)
+
+ # We only initialize anchor_updater when the anchor_alpha is present in the config
+ if getattr(config, "anchor_alpha", None) is not None:
+ if not AIMET_TORCH_AVAILABLE:
+ raise ValueError("Long Context is currently only supported in AIMET Torch")
+ self.anchor_updater = AnchorUpdaterKeySecond(alpha=config.anchor_alpha)
+
+ # We only use "torch.where(attention_mask, input, min(input)-20)" sequence when the enable_masked_softmax is present in the config
+ self.enable_masked_softmax = getattr(config, "enable_masked_softmax", False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[torch.Tensor]] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ #QC Adaptation
+ # HF made `past_key_value` --> `past_key_values`; Release notes: https://github.com/huggingface/transformers/releases/tag/v4.56.0#:~:text=Harmonize%20past_key_value%20to%20past_key_valueS%20everywhere, corresponding PR page- https://github.com/huggingface/transformers/pull/39956
+ past_key_values = past_key_value if past_key_value is not None else past_key_values
+ valid_token_mask = kwargs.get('valid_token_mask', None)
+ anchor_buffer = kwargs.get('anchor_buffer', None)
+ return_new_key_value_only = self.config.return_new_key_value_only if hasattr(self.config, 'return_new_key_value_only') else False
+ transposed_key_cache = self.config.transposed_key_cache if hasattr(self.config, 'transposed_key_cache') else False
+
+ bsz, q_len, _ = hidden_states.size()
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2)
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ if isinstance(position_ids, (tuple, list)): # QC
+ position_embeddings = position_ids
+ else:
+ position_embeddings = self.rotary_emb(value_states, position_ids)
+ cos, sin = position_embeddings
+ if cos.shape[-1] == query_states.shape[-1]:
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+ else:
+ query_states = _apply_rope_single(query_states, position_embeddings)
+ key_states = _apply_rope_single(key_states, position_embeddings)
+
+ # we cache the un-transposed keys before we pass into the combined scoring model.
+ untransposed_keys = key_states
+
+ if transposed_key_cache:
+ key_states = key_states.transpose(2, 3)
+
+ if past_key_values is not None:
+ assert isinstance(past_key_values, DynamicCache)
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position,
+ "return_new_key_value_only": return_new_key_value_only,
+ "transposed_key_cache": transposed_key_cache,
+ "num_key_value_heads": self.config.num_key_value_heads,
+ "head_dim": self.head_dim,
+ }
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # key_states is the concatenated keys
+ # past_key_values.key_cache[layer_idx] is the new key states
+ # we invoke the scoring model and insert the outputs, anchor and the evict indices into the past_key_values (DynamicCache object) as additional attributes.
+ if valid_token_mask is not None and anchor_buffer is not None:
+ anchor_buffer = anchor_buffer[self.layer_idx]
+ anchor = self.anchor_updater(new_keys = untransposed_keys, valid_token_mask=valid_token_mask, old_anchor = anchor_buffer)
+
+ insert_meta_info_to_pastkv(past_key_values = past_key_values, meta_info_bundle={"anchor_buffer":anchor}, layer_idx=self.layer_idx)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if transposed_key_cache:
+ attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim)
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != value_states.shape[-2]:
+ attention_mask = attention_mask[:, :, :, : value_states.shape[-2]]
+ if self.enable_masked_softmax:
+ attn_weights_min, _ = torch.min(attn_weights, dim=-1, keepdim=True)
+ minus_value = -20
+ attn_weights = torch.where(attention_mask==0, attn_weights, attn_weights_min + minus_value)
+ else:
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.config.num_attention_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.config.num_attention_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ # handle version-specific return
+ if version('transformers') >= '4.48.0':
+ return attn_output, attn_weights
+ else:
+ return attn_output, attn_weights, past_key_values
+
+def DynamicLayer_update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Update the cache as: https://github.com/huggingface/transformers/blob/d79b2d981f28b2730d402244ac3c2e9a8c054eee/src/transformers/cache_utils.py#L98
+ if self.keys is None:
+ self.keys = key_states
+ self.values = value_states
+ return self.keys , self.values
+ else:
+ return_new_key_value_only = cache_kwargs.get('return_new_key_value_only', False)
+ transposed_key_cache = cache_kwargs.get('transposed_key_cache', False)
+ cache_position = cache_kwargs.get('cache_position')
+ num_key_value_heads = cache_kwargs.get('num_key_value_heads')
+ head_dim = cache_kwargs.get('head_dim')
+ key_cat_dim = -1 if transposed_key_cache else -2
+ # if the size of past key cache passed is smaller in value than the last position where the new kv is to be inserted
+ # [in case when Cache position determined automatically by HF] (Ctx_len+ARN), then we want to perform concat and not do scattering.
+ if self.values.shape[-2] <= cache_position[-1]:
+ key_cache = torch.cat([self.keys, key_states], dim=key_cat_dim)
+ value_cache = torch.cat([self.values, value_states], dim=-2)
+ else:
+ # the cache_position passed in as model i/p by user is a 1d tensor reflecting the positions
+ # from valid_kv_end to valid_kv_end+ARN, we convert this into the indices for scattering. [# bsz, num_key_value_heads, head_dim, seq_len]-> works for transposed keys
+ indices = cache_position.view(1, 1, 1, -1).expand(value_states.shape[0], num_key_value_heads, head_dim, cache_position.shape[-1])
+
+ value_cache = self.values.scatter(dim=-2, index=indices.transpose(-1,-2), src=value_states)
+
+ indices = indices.transpose(-1, -2) if key_cat_dim== -2 else indices
+ key_cache = self.keys.scatter(dim=key_cat_dim, index=indices, src=key_states)
+
+
+ if return_new_key_value_only:
+ self.keys = key_states
+ self.values = value_states
+ else:
+ self.keys = key_cache
+ self.values = value_cache
+ return key_cache, value_cache
+
+def DynamicCache_update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ if layer_idx == 0:
+ self._seen_tokens += value_states.shape[-2]
+ # Update the cache
+ if len(self.key_cache) <= layer_idx:
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+ else:
+ return_new_key_value_only = cache_kwargs.get('return_new_key_value_only', False)
+ transposed_key_cache = cache_kwargs.get('transposed_key_cache', False)
+ cache_position = cache_kwargs.get('cache_position')
+ num_key_value_heads = cache_kwargs.get('num_key_value_heads')
+ head_dim = cache_kwargs.get('head_dim')
+ key_cat_dim = -1 if transposed_key_cache else -2
+
+ # if the size of past key cache passed is smaller in value than the last position where the new kv is to be inserted
+ # [in case when Cache position determined automatically by HF] (Ctx_len+ARN), then we want to perform concat and not do scattering.
+ if self.value_cache[layer_idx].shape[-2] <= cache_position[-1]:
+ key_cache = torch.cat([self.key_cache[layer_idx], key_states], dim=key_cat_dim)
+ value_cache = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+ else:
+ # the cache_position passed in as model i/p by user is a 1d tensor reflecting the positions
+ # from valid_kv_end to valid_kv_end+ARN, we convert this into the indices for scattering. [# bsz, num_key_value_heads, head_dim, seq_len]-> works for transposed keys
+ indices = cache_position.view(1, 1, 1, -1).expand(value_states.shape[0], num_key_value_heads, head_dim, cache_position.shape[-1])
+
+ value_cache = self.value_cache[layer_idx].scatter(dim=-2, index=indices.transpose(-1,-2), src=value_states)
+
+ indices = indices.transpose(-1, -2) if key_cat_dim== -2 else indices
+ key_cache = self.key_cache[layer_idx].scatter(dim=key_cat_dim, index=indices, src=key_states)
+
+
+ if return_new_key_value_only:
+ self.key_cache[layer_idx] = key_states
+ self.value_cache[layer_idx] = value_states
+ else:
+ self.key_cache[layer_idx] = key_cache
+ self.value_cache[layer_idx] = value_cache
+ return key_cache, value_cache
+
+
+def DynamicCache_get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ # TODO: deprecate this function in favor of `cache_position`
+ if len(self.value_cache) <= layer_idx:
+ return 0
+ return self.value_cache[layer_idx].shape[-2]
+
+
+def update_attr(cls, attr_name, new_attr):
+ attr_backup_name = f'_original_{attr_name}'
+ if hasattr(cls, attr_name):
+ if not hasattr(cls, attr_backup_name):
+ setattr(cls, attr_backup_name, getattr(cls, attr_name))
+ setattr(cls, attr_name, new_attr)
+ return True
+ return False
+
+if not is_transformers_greater_or_equal("4.53.0"):
+ orig_causal_mask = modeling_llama.LlamaModel._update_causal_mask
+ def adapted_update_causal_mask(self, attention_mask, *args, **kwargs):
+ if attention_mask is not None and attention_mask.dim() == 4:
+ return attention_mask
+ else:
+ return orig_causal_mask(self, attention_mask, *args, **kwargs)
+
+orig_embedding_fwd = modeling_llama.LlamaRotaryEmbedding.forward
+def adapted_RotaryEmbedding(self, x, position_ids, *args, **kwargs):
+ if isinstance(position_ids, tuple) and len(position_ids)==2:
+ return position_ids
+ else:
+ return orig_embedding_fwd(self, x, position_ids, *args, **kwargs)
+
+class QcLlamaForCausalLM(LlamaForCausalLM):
+ """
+ Subclass of original LlamaForCausalLM. This is needed to serve two purposes:
+
+ 1. Starting from transformers version 4.45.0, the num_logits_to_keep argument is now required argument.
+ Consequently, the prepared static graph will always include this additional argument.
+ To maintain compatibility with our existing pipelines, we create a new class that inherits from
+ LlamaForCausalLM. In this new class, we redefine the forward method without the num_logits_to_keep
+ argument and in inside the forward we infer the num_logits_to_keep from the config and then call the superclass's forward method.
+
+ 2. For the Long Context scoring model within the LLM, we need to pass two additional arguments:
+ anchor_buffer and valid_token_mask. These can be provided as keyword arguments (introduced in transformers version 4.47.0).
+ However, this approach is incompatible with Onnx export, as Onnx does not support keyword arguments when creating the onnx graph.
+ Therefore, we pass valid_token_mask and anchor_buffer as model inputs to LlamaForCausalLM,
+ which in turn get recognized through keyword arguments in the downstream blocks.
+ This is similar to how the DynamicCache object is not traced by jit.trace at the topmost level in LlamaForCausalLM.
+ """
+ def __init__(self, config):
+ super().__init__(config)
+ if getattr(config, "input_tokens_per_inference", None) is not None:
+ self.register_buffer(name='cache_tensor', tensor=torch.arange(config.input_tokens_per_inference))
+
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ num_logits_to_keep: Optional[int] = None,
+ valid_token_mask: Optional[torch.Tensor]=None,
+ anchor_buffer: Optional[torch.Tensor]=None,
+ cache_index: Optional[torch.Tensor]=None,
+ **kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ num_logits_to_keep = num_logits_to_keep if num_logits_to_keep else getattr(self.config, "num_logits_to_keep", 0)
+ return_dict = return_dict if return_dict else False
+ if cache_index is not None:
+ assert hasattr(self, "cache_tensor"), "QcLlamaForCausal doesn't have attribute \"cache_tensor\", " \
+ "check if \"input_tokens_per_inference\" is specified in model config"
+ cache_position = cache_index + self.cache_tensor
+
+ if type(past_key_values) == tuple and version('transformers') >= '4.48.0':
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ outputs = super().forward(
+ input_ids = input_ids,
+ attention_mask= attention_mask,
+ position_ids= position_ids,
+ past_key_values= past_key_values,
+ inputs_embeds= inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ num_logits_to_keep=num_logits_to_keep,
+ valid_token_mask=valid_token_mask,
+ anchor_buffer=anchor_buffer,
+ **kwargs)
+
+ if version('transformers') >= '4.48.0':
+ if return_dict:
+ assert type(outputs.past_key_values) != tuple
+ past_key_values_output = DynamicCache.to_legacy_cache(outputs.past_key_values)
+ outputs.past_key_values = past_key_values_output
+ else:
+ new_outputs = []
+ for item in outputs:
+ if isinstance(item, DynamicCache):
+ past_key_values_output = DynamicCache.to_legacy_cache(item)
+ new_outputs.append(past_key_values_output)
+ else:
+ new_outputs.append(item)
+ outputs = tuple(new_outputs)
+
+ if hasattr(self.config, "output_index_filter"):
+ return filter_outputs(outputs, self.config.output_index_filter)
+ return outputs
+
+def insert_meta_info_to_pastkv(
+ past_key_values,
+ meta_info_bundle: {},
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ This function adds two new model outputs to the DynamicCache object, eliminating the need to write a new adaptation for the additional output.
+
+ params:
+ past_key_values: The DynamicCache object
+ evict_info_bundle: A dictionary containing the additional information to attach to the DynamicCache object.
+ layer_idx: An attribute pointing to a list, where layer_idx corresponds to the data of the given layer.
+ cache_kwargs: Additional keyword arguments for the cache.
+
+ """
+ for key, value in meta_info_bundle.items():
+ if getattr(past_key_values, key, None) is None:
+ setattr(past_key_values, key, [])
+
+ key_attr = getattr(past_key_values, key, None)
+ # Update the cache
+ if len(key_attr) <= layer_idx:
+ key_attr.append(value)
+ else:
+ key_attr[layer_idx] = value
+
+def DynamicCache_to_legacy_cache(self):
+
+ """
+ Converts the DynamicCache instance to its equivalent in the legacy cache format for backward compatibility.
+
+ The past_key_values passed into the model as input is a tuple.
+ The LlamaModel converts it into a Cache object if it isn't one already. Within the model, past_key_values flow as a DynamicCache object.
+ Just before returning the output, the LlamaModel converts the DynamicCache back to the legacy cache (tuple format).
+ Since we added new attributes to our Cache object, we need to ensure they are included as additional entries in the returned tuple."""
+
+ legacy_cache = ()
+ for layer_idx in range(len(self)):
+ legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+ if "anchor_buffer" in dir(self):
+ return (legacy_cache, self.anchor_buffer)
+ return legacy_cache
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/llama/lorav1_forward_passes.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/llama/lorav1_forward_passes.py
new file mode 100644
index 000000000..2599f36e1
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/llama/lorav1_forward_passes.py
@@ -0,0 +1,599 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+""" This file provides adaptations to the llama model. These adaptations are being
+done to optimize the model execution on the HTP backend.
+https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py"""
+
+import types
+import math
+from typing import List, Optional, Tuple, Union, Dict
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+from transformers.modeling_flash_attention_utils import _flash_attention_forward
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from transformers.modeling_utils import PreTrainedModel
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
+from transformers.utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+from transformers.models.llama.configuration_llama import LlamaConfig
+
+from transformers.models.llama import modeling_llama
+from transformers.models.llama.modeling_llama import (
+ repeat_kv,
+ Cache,
+ DynamicCache,
+ LlamaAttention,
+ LlamaMLP,
+ LlamaConfig,
+ apply_rotary_pos_emb,
+)
+
+from genai_lib.llm.dev.model_adaptation.llama.adaptation import _apply_rope_single
+from aimet_torch.nn.modules import custom as aimet_ops
+
+logger = logging.get_logger(__name__)
+
+class LlamaMLPLora(LlamaMLP):
+ def __init__(self, config: LlamaConfig):
+ super().__init__(config)
+ self.lora_add_gate = aimet_ops.Add()
+ self.lora_add_up = aimet_ops.Add()
+ self.lora_add_down = aimet_ops.Add()
+ self.lora_matmul_a_gate = aimet_ops.MatMul()
+ self.lora_matmul_b_gate = aimet_ops.MatMul()
+ self.lora_matmul_a_up = aimet_ops.MatMul()
+ self.lora_matmul_b_up = aimet_ops.MatMul()
+ self.lora_matmul_a_down = aimet_ops.MatMul()
+ self.lora_matmul_b_down = aimet_ops.MatMul()
+ self.lora_multiply_gate = aimet_ops.Multiply()
+ self.lora_multiply_up = aimet_ops.Multiply()
+ self.lora_multiply_down = aimet_ops.Multiply()
+ self.lora_transpose_a_gate = aimet_ops.Permute()
+ self.lora_transpose_b_gate = aimet_ops.Permute()
+ self.lora_transpose_a_up = aimet_ops.Permute()
+ self.lora_transpose_b_up = aimet_ops.Permute()
+ self.lora_transpose_a_down = aimet_ops.Permute()
+ self.lora_transpose_b_down = aimet_ops.Permute()
+
+ def forward(self, x,
+ lora_scale: float= None,
+ lora_weights: Optional[Dict[str, torch.Tensor]] = None,
+ layer_idx: int = None
+ ):
+
+ gate_lora_A = lora_weights.get(
+ f"base_model.model.model.layers.{layer_idx}.mlp.gate_proj.lora_A.weight") if lora_weights else None
+ gate_lora_B = lora_weights.get(
+ f"base_model.model.model.layers.{layer_idx}.mlp.gate_proj.lora_B.weight") if lora_weights else None
+ down_lora_A = lora_weights.get(
+ f"base_model.model.model.layers.{layer_idx}.mlp.down_proj.lora_A.weight") if lora_weights else None
+ down_lora_B = lora_weights.get(
+ f"base_model.model.model.layers.{layer_idx}.mlp.down_proj.lora_B.weight") if lora_weights else None
+ up_lora_A = lora_weights.get(
+ f"base_model.model.model.layers.{layer_idx}.mlp.up_proj.lora_A.weight") if lora_weights else None
+ up_lora_B = lora_weights.get(
+ f"base_model.model.model.layers.{layer_idx}.mlp.up_proj.lora_B.weight") if lora_weights else None
+
+ gate_proj_out = self.gate_proj(x)
+ if gate_lora_A is not None:
+ lora_gate_proj = self.lora_matmul_a_gate(x, self.lora_transpose_a_gate(gate_lora_A, [1, 0]))
+ lora_gate_proj = self.lora_multiply_gate(lora_gate_proj, lora_scale)
+ lora_gate_proj = self.lora_matmul_b_gate(lora_gate_proj, self.lora_transpose_b_gate(gate_lora_B, [1, 0]))
+ gate_proj_out = self.lora_add_gate(gate_proj_out, lora_gate_proj)
+
+ up_proj_out = self.up_proj(x)
+ if up_lora_A is not None:
+ lora_up_proj = self.lora_matmul_a_up(x, self.lora_transpose_a_up(up_lora_A, [1, 0]))
+ lora_up_proj = self.lora_multiply_up(lora_up_proj, lora_scale)
+ lora_up_proj = self.lora_matmul_b_up(lora_up_proj, self.lora_transpose_b_up(up_lora_B, [1, 0]))
+ up_proj_out = self.lora_add_up(up_proj_out, lora_up_proj)
+
+ act_fn_out = self.act_fn(gate_proj_out)
+ pre_down_proj = act_fn_out * up_proj_out
+ down_proj_out = self.down_proj(pre_down_proj)
+ if up_lora_A is not None:
+ lora_down_proj = self.lora_matmul_a_down(pre_down_proj, self.lora_transpose_a_down(down_lora_A, [1, 0]))
+ lora_down_proj = self.lora_multiply_down(lora_down_proj, lora_scale)
+ lora_down_proj = self.lora_matmul_b_down(lora_down_proj, self.lora_transpose_b_down(down_lora_B, [1, 0]))
+ down_proj_out = self.lora_add_down(down_proj_out, lora_down_proj)
+
+ return down_proj_out
+
+class LlamaAttentionLora(LlamaAttention):
+ def __init__(self, config: LlamaConfig, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.lora_add_q = aimet_ops.Add()
+ self.lora_add_k = aimet_ops.Add()
+ self.lora_add_v = aimet_ops.Add()
+ self.lora_add_o = aimet_ops.Add()
+ self.lora_matmul_a_q = aimet_ops.MatMul()
+ self.lora_matmul_b_q = aimet_ops.MatMul()
+ self.lora_matmul_a_k = aimet_ops.MatMul()
+ self.lora_matmul_b_k = aimet_ops.MatMul()
+ self.lora_matmul_a_v = aimet_ops.MatMul()
+ self.lora_matmul_b_v = aimet_ops.MatMul()
+ self.lora_matmul_a_o = aimet_ops.MatMul()
+ self.lora_matmul_b_o = aimet_ops.MatMul()
+ self.lora_multiply_q = aimet_ops.Multiply()
+ self.lora_multiply_k = aimet_ops.Multiply()
+ self.lora_multiply_v = aimet_ops.Multiply()
+ self.lora_multiply_o = aimet_ops.Multiply()
+ self.lora_transpose_a_q = aimet_ops.Permute()
+ self.lora_transpose_b_q = aimet_ops.Permute()
+ self.lora_transpose_a_k = aimet_ops.Permute()
+ self.lora_transpose_b_k = aimet_ops.Permute()
+ self.lora_transpose_a_v = aimet_ops.Permute()
+ self.lora_transpose_b_v = aimet_ops.Permute()
+ self.lora_transpose_a_o = aimet_ops.Permute()
+ self.lora_transpose_b_o = aimet_ops.Permute()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ lora_scale: float= None,
+ lora_weights: Optional[Dict[str, torch.Tensor]] = None,
+ layer_idx: int = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ #QC
+ return_new_key_value_only = self.config.return_new_key_value_only if hasattr(self.config, 'return_new_key_value_only') else False
+ transposed_key_cache = self.config.transposed_key_cache if hasattr(self.config, 'transposed_key_cache') else False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ q_lora_A = lora_weights.get(
+ f"base_model.model.model.layers.{layer_idx}.self_attn.q_proj.lora_A.weight") if lora_weights else None
+ q_lora_B = lora_weights.get(
+ f"base_model.model.model.layers.{layer_idx}.self_attn.q_proj.lora_B.weight") if lora_weights else None
+ k_lora_A = lora_weights.get(
+ f"base_model.model.model.layers.{layer_idx}.self_attn.k_proj.lora_A.weight") if lora_weights else None
+ k_lora_B = lora_weights.get(
+ f"base_model.model.model.layers.{layer_idx}.self_attn.k_proj.lora_B.weight") if lora_weights else None
+ v_lora_A = lora_weights.get(
+ f"base_model.model.model.layers.{layer_idx}.self_attn.v_proj.lora_A.weight") if lora_weights else None
+ v_lora_B = lora_weights.get(
+ f"base_model.model.model.layers.{layer_idx}.self_attn.v_proj.lora_B.weight") if lora_weights else None
+ o_lora_A = lora_weights.get(
+ f"base_model.model.model.layers.{layer_idx}.self_attn.o_proj.lora_A.weight") if lora_weights else None
+ o_lora_B = lora_weights.get(
+ f"base_model.model.model.layers.{layer_idx}.self_attn.o_proj.lora_B.weight") if lora_weights else None
+
+ if q_lora_A is not None:
+ lora_q = self.lora_matmul_a_q(hidden_states, self.lora_transpose_a_q(q_lora_A, [1, 0]))
+ lora_q = self.lora_multiply_q(lora_q, lora_scale)
+ lora_q = self.lora_matmul_b_q(lora_q, self.lora_transpose_b_q(q_lora_B, [1, 0]))
+ query_states = self.lora_add_q(query_states, lora_q)
+
+ if k_lora_A is not None:
+ lora_k = self.lora_matmul_a_k(hidden_states, self.lora_transpose_a_k(k_lora_A, [1, 0]))
+ lora_k = self.lora_multiply_k(lora_k, lora_scale)
+ lora_k = self.lora_matmul_b_k(lora_k, self.lora_transpose_b_k(k_lora_B, [1, 0]))
+ key_states = self.lora_add_k(key_states, lora_k)
+
+ if v_lora_A is not None:
+ lora_v = self.lora_matmul_a_v(hidden_states, self.lora_transpose_a_v(v_lora_A, [1, 0]))
+ lora_v = self.lora_multiply_v(lora_v, lora_scale)
+ lora_v = self.lora_matmul_b_v(lora_v, self.lora_transpose_b_v(v_lora_B, [1, 0]))
+ value_states = self.lora_add_v(value_states, lora_v)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ if isinstance(position_ids, (tuple, list)): # QC
+ position_embeddings = position_ids
+ else:
+ position_embeddings = self.rotary_emb(value_states, position_ids)
+ cos, sin = position_embeddings
+ query_states = _apply_rope_single(query_states, position_embeddings)
+ key_states = _apply_rope_single(key_states, position_embeddings)
+
+ if transposed_key_cache:
+ key_states = key_states.transpose(2, 3)
+
+ if past_key_value is not None:
+ assert isinstance(past_key_value, DynamicCache)
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position,
+ "return_new_key_value_only": return_new_key_value_only,
+ "transposed_key_cache": transposed_key_cache,
+ }
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if transposed_key_cache:
+ attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim)
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+
+ if o_lora_A is not None:
+ lora_o = self.lora_matmul_a_o(attn_output, self.lora_transpose_a_o(o_lora_A, [1, 0]))
+ lora_o = self.lora_multiply_o(lora_o, lora_scale)
+ lora_o = self.lora_matmul_b_o(lora_o, self.lora_transpose_b_o(o_lora_B, [1, 0]))
+ attn_output = self.lora_add_o(attn_output, lora_o)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+def forward_decoder(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ lora_scale: float= None,
+ lora_weights: Optional[Dict[str, torch.Tensor]] = None,
+ layer_idx: int = None,
+ **kwargs,
+) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ kwargs (`dict`, *optional*):
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
+ into the model
+ """
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ lora_scale=lora_scale,
+ lora_weights=lora_weights,
+ layer_idx=layer_idx,
+ **kwargs,
+ )
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states,
+ lora_scale=lora_scale,
+ lora_weights=lora_weights,
+ layer_idx=layer_idx)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+def forward_model(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ lora_scale: Optional[float] = None,
+ lora_weights: Optional[Dict[str, torch.Tensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ return_legacy_cache = False
+ if (
+ use_cache and not isinstance(past_key_values, Cache) and not self.training
+ ): # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ logger.warning_once(
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
+ )
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+ hidden_states = inputs_embeds
+
+ # create position embeddings to be shared across the decoder layers
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for layer_idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ lora_scale=lora_scale,
+ lora_weights=lora_weights,
+ layer_idx=layer_idx
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ lora_scale=lora_scale,
+ lora_weights=lora_weights,
+ layer_idx=layer_idx
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+def forward_llama(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ lora_weights: Optional[Dict[str, torch.Tensor]] = None,
+ **kwargs
+) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+ >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
+ >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ lora_scale=lora_scale,
+ lora_weights=lora_weights
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ logits_to_keep = logits_to_keep if logits_to_keep else getattr(self.config, "logits_to_keep", 0)
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/llama/onnx_utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/llama/onnx_utils.py
new file mode 100644
index 000000000..c484c5748
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/llama/onnx_utils.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+import torch
+from typing import Dict, Any
+from genai_lib.common.onnxruntime_utils import OutputBufferCreator
+from transformers import PretrainedConfig
+
+def map_llama_onnx_flattened_name_to_tensor(onnx_name: str,
+ tensor_dict: Dict[str, Any],
+ is_output: bool,
+ delim: str = "_") -> torch.Tensor:
+ """
+ Maps a flattened ONNX input/output name to the corresponding PyTorch tensor in a condensed prepared_inputs/output_buffer, for llama-like models.
+
+ params:
+ onnx_name (str): The flattened name of the ONNX input/output.
+ tensor_dict (Dict[str, Any]): The prepared_inputs/output_buffer.
+ is_output (bool): Whether the ONNX name is an output. (Unused for llama-like models)
+ delim (str): The delimiter used in the flattened name.
+ returns:
+ torch.Tensor: The corresponding tensor in the prepared_inputs/output_buffer.
+ """
+ parts = onnx_name.split(delim)
+ if parts[0] == "position":
+ return tensor_dict["position_ids"][0 if parts[2] == "cos" else 1]
+
+ if parts[0] == "past":
+ # past_{key/value}_{layer_idx}_{in/out}
+ kv_idx = 1 if parts[1] == "value" else 0
+ return tensor_dict["past_key_values"][int(parts[2])][kv_idx]
+
+ return tensor_dict[onnx_name]
+
+class LlamaOutputBufferCreator(OutputBufferCreator):
+ """
+ Allocates an empty output buffer for a vanilla Llama-like model.
+
+ params:
+ batch_size (int): The batch size.
+ max_input_tokens (int): The max input tokens.
+ config (PretrainedConfig): The model configuration.
+ device (torch.device): The device to allocate the output buffer.
+ """
+ def __init__(
+ self,
+ batch_size:int,
+ max_input_tokens:int,
+ config:PretrainedConfig,
+ device:torch.device):
+ self.batch_size = batch_size
+ self.max_input_tokens = max_input_tokens
+ self.config = config
+ self.device = device
+
+ def create_buffer(self) -> Dict[str, Any]:
+ """
+ Method to build the output buffer assuming logits and past_key_values are the only outputs.
+
+ returns:
+ Dict[str, Any]: The output buffer.
+ """
+ head_dim = self.config.hidden_size // self.config.num_attention_heads
+ output = {
+ "logits": torch.empty(self.batch_size, self.max_input_tokens, self.config.vocab_size, device=self.device, dtype=torch.float32).contiguous(),
+ "past_key_values": tuple((
+ torch.empty(self.batch_size, self.config.num_key_value_heads, head_dim, self.max_input_tokens, device=self.device,
+ dtype=torch.float32).contiguous(),
+ torch.empty(self.batch_size, self.config.num_key_value_heads, self.max_input_tokens,head_dim, device=self.device,
+ dtype=torch.float32).contiguous()) for _ in range(self.config.num_hidden_layers))
+ }
+ return output
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/llama/utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/llama/utils.py
new file mode 100644
index 000000000..f3be74c78
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/llama/utils.py
@@ -0,0 +1,125 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+""" This file provides utilities that the pipeline would need to work with the adaptations made to LLaMa model. """
+
+import torch
+import functools
+from importlib.metadata import version
+from genai_lib.common.dev.utils import is_transformers_greater_or_equal
+
+def llm_update_causal_mask(prepared_1d_attn_mask, input_tensor, max_input_tokens, model_context_len, model_id_or_path, mask_neg = -100, cache_index=None, pad_to_left=True):
+ '''
+
+ This function creates a 4D causal mask from the 2D attention mask.
+ We either call the BaseModel's `_update_causal_mask` (for versions earlier than 4.53.0) or use the `create_causal_mask` API for later versions, which is not specific to any one model.
+ Since these are two different APIs, HF a=has updated the expected inputs to these (text_config, input_embeds) as shown below.
+
+ params:
+ 1. prepared_1d_attn_mask: attention mask of shape (batch_size, model_context_length)
+ 2. input_tensor : input_ids/ input_embeddings
+ 3. max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ 4. model_context_len: maximum number of tokens that the model can consume in total
+ 5. model_id: Model name or path to pretrained model
+ 6. mask_neg: proxy for minus infinity since minus infinity is not quantization friendly. This value should be large
+ enough to drown out tokens that should not be attended to
+ 7. cache_index: the index for the starting position of kvcaches
+ 8. pad_to_left: determines if the KV cache is padded to the left or right
+ '''
+
+ # if the cache position is None, then we assume that the current input ids will be concatenated to the right end and
+ # #hence we construct the cache position accordingly to be sent into the create_causal_mask
+ if pad_to_left:
+ # Cache index should not be passed. Concat op is used in doing the KV cache update
+ assert cache_index is None, "Invalid argument error: we do not support the combination of performing left padding and doing scatter for KV cache update."
+ else:
+ # if the user is doing right padding, it is necessary to pass the cache_index.
+ assert cache_index is not None, "Invalid argument error: we do not support the combination of performing right padding and doing concat for KV cache update"
+
+ if cache_index is None:
+ cache_position = torch.arange(model_context_len - max_input_tokens, model_context_len, device=input_tensor.device)
+ else:
+ cache_position = torch.arange(max_input_tokens, dtype=torch.float32, device=input_tensor.device) + cache_index.to(input_tensor.device)
+
+ if is_transformers_greater_or_equal("4.53.0"):
+ from transformers.masking_utils import create_causal_mask
+
+ config = _get_config(model_id_or_path)
+ config._attn_implementation = "eager"
+
+ # For the create causal mask API, the input_embeds second shape must reflect the total KV$ len (including the current tokens KV length)-https://github.com/huggingface/transformers/blob/cbb290ec23ccd9b5c1d1ff4d333477449891debb/src/transformers/masking_utils.py#L729C32-L729C53
+ input_embeds = torch.ones((input_tensor.shape[0], model_context_len, 1), device=input_tensor.device)
+ mask_kwargs = {
+ "config": config,
+ "input_embeds": input_embeds,
+ "attention_mask": prepared_1d_attn_mask,
+ "cache_position": cache_position,
+ "past_key_values": None,
+ "position_ids": None,
+ }
+ prepared_attention_mask = create_causal_mask(**mask_kwargs)
+ prepared_attention_mask = prepared_attention_mask.clamp_min(mask_neg)
+ return prepared_attention_mask
+ else:
+ model = _get_model(model_id_or_path)
+
+ # When we invoked the update_causal_mask of the BaseModel, then the second shape in the input_embeds (input_tensor argument) represents the sequence_length- https://github.com/huggingface/transformers/blob/5f4ecf2d9f867a1255131d2461d75793c0cf1db2/src/transformers/models/llama/modeling_llama.py#L636
+ input_embeds = torch.ones((input_tensor.shape[0], input_tensor.shape[1], 1), device=input_tensor.device)
+ prepared_attention_mask = model._update_causal_mask(attention_mask=prepared_1d_attn_mask,
+ input_tensor=input_embeds, output_attentions=True,
+ cache_position=cache_position, past_key_values=None)
+ prepared_attention_mask = prepared_attention_mask.clamp_min(mask_neg)
+ return prepared_attention_mask
+
+
+def llm_create_position_embeddings(config, dtype=torch.float32, position_ids=None):
+ '''
+ This function creates position embedding (RoPE) from the position ids.
+ params:
+ 1. config: model configuration to create the LLamaRotaryEmbedding object, expect config for backward compatibility, in future transformers, we only expect to pass the config, the one we have in docker, takes in the req argument
+ 2. position_ids: required position ids passed into the model
+ '''
+
+ hidden_size = config.hidden_size
+ max_position_embeddings = config.max_position_embeddings
+ num_attention_heads = config.num_attention_heads
+ rope_theta = config.rope_theta
+ dim = int((hidden_size // num_attention_heads))
+ device = position_ids.device
+ x = torch.ones(1, device=device, dtype=dtype)
+ rotary_emb = _get_rotary_embedding(dim=dim, max_position_embeddings=max_position_embeddings, rope_theta=rope_theta, device=device, config=config)
+ cos, sin = rotary_emb(x, position_ids=position_ids)
+ cos, sin = cos.unsqueeze(dim = 1), sin.unsqueeze(dim = 1)
+ cos = cos[:,:,:,:dim//2]
+ sin = sin[:,:,:, :dim//2]
+ return cos, sin
+
+
+def _get_rotary_embedding(dim, max_position_embeddings, rope_theta, device, config=None):
+ from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
+ if version('transformers') >= '4.48.0':
+ rotary_emb = LlamaRotaryEmbedding(config).to(device)
+ else:
+ rotary_emb = LlamaRotaryEmbedding(dim=dim, max_position_embeddings=max_position_embeddings, base=rope_theta, device=device, config=config)
+ return rotary_emb
+
+@functools.cache
+def _get_model(model_id_or_path):
+ from transformers import AutoConfig
+ from transformers.models.llama.modeling_llama import LlamaModel
+ config = AutoConfig.from_pretrained(model_id_or_path)
+ config.num_hidden_layers = 1
+ model = LlamaModel(config)
+ return model
+
+@functools.cache
+def _get_config(model_id_or_path):
+ from transformers import AutoConfig
+ config = AutoConfig.from_pretrained(model_id_or_path)
+ return config
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/mistral/adaptation.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/mistral/adaptation.py
new file mode 100644
index 000000000..9b454db3a
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/mistral/adaptation.py
@@ -0,0 +1,319 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# coding=utf-8
+# Copyright 2024 Mistral AI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+""" This file provides adaptations to the Mistral model. These adaptations are being
+done to optimize the model execution on the HTP backend.
+https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/models/mistral/modeling_mistral.py"""
+
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from transformers import MistralForCausalLM
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.mistral import modeling_mistral
+from transformers.models.mistral.modeling_mistral import (
+ Cache,
+ DynamicCache,
+ MistralAttention,
+ apply_rotary_pos_emb,
+ repeat_kv,
+)
+
+from genai_lib.common.dev.utils import (
+ is_transformers_greater_or_equal_than_4_48,
+ is_transformers_greater_or_equal_than_4_51,
+)
+
+
+def _apply_rope_single(x, rope_vals: Tuple[torch.Tensor, torch.Tensor]):
+ '''
+ Based on FacebookResearch's llama, provided by Carl
+ '''
+ rope_real = rope_vals[0] # shape should be 1, 1, seqlen, head_dim/2
+ rope_im = rope_vals[1] # shape should be 1, 1, seqlen, head_dim/2
+
+ # TODO: Why HF uses different coordinates from the paper
+ x_real = x[:,:,:,:x.shape[-1]//2] # extract first half elements
+ x_im = x[:,:,:,x.shape[-1]//2:] # extract second half elements
+
+ x_prod_real = x_real*rope_real - x_im * rope_im
+ x_prod_im = x_real*rope_im + x_im*rope_real
+
+ # TODO: HF need to uses different interleaving
+ x = torch.cat((x_prod_real,x_prod_im),dim=3).view(*x.shape)
+ return x
+
+orig_embedding_fwd = modeling_mistral.MistralRotaryEmbedding.forward
+def adapted_RotaryEmbedding(self, x, position_ids, *args, **kwargs):
+ if isinstance(position_ids, tuple) and len(position_ids) == 2:
+ return position_ids
+ else:
+ return orig_embedding_fwd(self, x, position_ids, *args, **kwargs)
+
+class QcMistralAttention(MistralAttention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ #QC
+ return_new_key_value_only = self.config.return_new_key_value_only if hasattr(self.config, 'return_new_key_value_only') else False
+ transposed_key_cache = self.config.transposed_key_cache if hasattr(self.config, 'transposed_key_cache') else False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ if isinstance(position_ids, (tuple, list)): # QC
+ position_embeddings = position_ids
+ else:
+ position_embeddings = self.rotary_emb(value_states, position_ids)
+ cos, sin = position_embeddings
+ query_states = _apply_rope_single(query_states, position_embeddings)
+ key_states = _apply_rope_single(key_states, position_embeddings)
+
+ if transposed_key_cache:
+ key_states = key_states.transpose(2, 3)
+
+ if past_key_value is not None:
+ assert isinstance(past_key_value, DynamicCache)
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position,
+ "return_new_key_value_only": return_new_key_value_only,
+ "transposed_key_cache": transposed_key_cache,
+ }
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if transposed_key_cache:
+ attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim)
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.config.num_attention_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.config.num_attention_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ if is_transformers_greater_or_equal_than_4_48:
+ return attn_output, attn_weights
+
+ return attn_output, attn_weights, past_key_value
+
+orig_causal_mask = modeling_mistral.MistralModel._update_causal_mask
+def adapted_update_causal_mask(self, attention_mask, *args, **kwargs):
+ if attention_mask is not None and attention_mask.dim() == 4:
+ return attention_mask
+ else:
+ return orig_causal_mask(self, attention_mask, *args, **kwargs)
+
+
+class QcMistralForCausalLM(MistralForCausalLM):
+ def __init__(self, config):
+ super().__init__(config)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ num_logits_to_keep: Optional[int] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ if isinstance(past_key_values, tuple) and is_transformers_greater_or_equal_than_4_48:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ kwargs = self._construct_kwargs(
+ input_ids,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ inputs_embeds,
+ labels,
+ use_cache,
+ output_attentions,
+ output_hidden_states,
+ return_dict,
+ cache_position,
+ num_logits_to_keep,
+ )
+
+ outputs = super().forward(**kwargs)
+
+ if is_transformers_greater_or_equal_than_4_48:
+ if return_dict or self.config.use_return_dict:
+ assert not isinstance(outputs.past_key_values, tuple)
+ outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
+ else:
+ new_outputs = []
+ for item in outputs:
+ if isinstance(item, DynamicCache):
+ new_outputs.append(item.to_legacy_cache())
+ else:
+ new_outputs.append(item)
+ outputs = tuple(new_outputs)
+
+ return outputs
+
+ def _construct_kwargs(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ num_logits_to_keep: Optional[int] = None,
+ ):
+ num_logits_to_keep = num_logits_to_keep or getattr(
+ self.config, 'num_logits_to_keep', 0
+ )
+
+ kwargs = {
+ 'input_ids': input_ids,
+ 'attention_mask': attention_mask,
+ 'position_ids': position_ids,
+ 'past_key_values': past_key_values,
+ 'inputs_embeds': inputs_embeds,
+ 'labels': labels,
+ 'use_cache': use_cache,
+ 'output_attentions': output_attentions,
+ 'output_hidden_states': output_hidden_states,
+ 'return_dict': return_dict,
+ 'cache_position': cache_position,
+ 'num_logits_to_keep': num_logits_to_keep,
+ }
+
+ if is_transformers_greater_or_equal_than_4_51:
+ kwargs.pop('return_dict')
+ kwargs['logits_to_keep'] = kwargs.pop('num_logits_to_keep')
+
+ return kwargs
+
+
+def DynamicCache_update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += value_states.shape[-2]
+
+ # Update the cache
+ if len(self.key_cache) <= layer_idx:
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+ else:
+ return_new_key_value_only = cache_kwargs.get('return_new_key_value_only', False)
+ transposed_key_cache = cache_kwargs.get('transposed_key_cache', False)
+ key_cat_dim = -1 if transposed_key_cache else -2
+
+ key_cache = torch.cat([self.key_cache[layer_idx], key_states], dim=key_cat_dim)
+ value_cache = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+ if return_new_key_value_only:
+ self.key_cache[layer_idx] = key_states
+ self.value_cache[layer_idx] = value_states
+ else:
+ self.key_cache[layer_idx] = key_cache
+ self.value_cache[layer_idx] = value_cache
+ return key_cache, value_cache
+
+
+def DynamicCache_get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ # TODO: deprecate this function in favor of `cache_position`
+ if len(self.value_cache) <= layer_idx:
+ return 0
+ return self.value_cache[layer_idx].shape[-2]
+
+
+def update_attr(cls, attr_name, new_attr):
+ attr_backup_name = f'_original_{attr_name}'
+ if hasattr(cls, attr_name):
+ if not hasattr(cls, attr_backup_name):
+ setattr(cls, attr_backup_name, getattr(cls, attr_name))
+ setattr(cls, attr_name, new_attr)
+ return True
+ return False
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/mistral/utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/mistral/utils.py
new file mode 100644
index 000000000..43cac4040
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/mistral/utils.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+""" This file provides utilities that the pipeline would need to work with the adaptations made to Mistral model. """
+
+import functools
+
+import torch
+
+from genai_lib.common.dev.utils import is_transformers_greater_or_equal_than_4_48
+
+
+def llm_update_causal_mask(prepared_1d_attn_mask, input_tensor, max_input_tokens, model_context_len, model_id_or_path, mask_neg = -100.0):
+ '''
+
+ This function creates a causal mask (2D) from the 1D attention mask
+
+ params:
+ 1. prepared_1d_attn_mask: attention mask of shape (batch_size, model_context_length)
+ 2. input_tensor : input_ids/ input_embeddings
+ 3. max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ 4. model_context_len: maximum number of tokens that the model can consume in total
+ 5. model_id_or_path: Model name or path to pretrained model
+ 6. mask_neg: proxy for minus infinity since minus infinity is not quantization friendly. This value should be large
+ enough to drown out tokens that should not be attended to
+ '''
+ mistral_model = _get_model(model_id_or_path)
+ cache_position = torch.arange(model_context_len-max_input_tokens, model_context_len, device = input_tensor.device)
+ input_embeds = torch.ones((input_tensor.shape[0], input_tensor.shape[1], 1), device = input_tensor.device)
+
+ kwargs = {
+ 'attention_mask': prepared_1d_attn_mask,
+ 'input_tensor': input_embeds,
+ 'cache_position': cache_position,
+ 'past_key_values': None,
+ 'use_cache': False,
+ 'output_attentions': True,
+ }
+ if is_transformers_greater_or_equal_than_4_48:
+ # `use_cache` argument is not valid since transformers>=4.48.0
+ kwargs.pop('use_cache')
+
+ prepared_attention_mask = mistral_model._update_causal_mask(**kwargs)
+ prepared_attention_mask = prepared_attention_mask.clamp_min(mask_neg)
+ return prepared_attention_mask
+
+def llm_create_position_embeddings(config, position_ids=None):
+ '''
+ This function creates position embedding (RoPE) from the position ids.
+ params:
+ 1. config: model configuration to create the LLamaRotaryEmbedding object, expect config for backward compatibility, in future transformers, we only expect to pass the config, the one we have in docker, takes in the req argument
+ 2. position_ids: required position ids passed into the model
+ '''
+
+ max_position_embeddings = config.max_position_embeddings
+ rope_theta = config.rope_theta
+
+ # NOTE: transformers==4.52.4, config has head_dim attribute but `None` value
+ # https://github.com/huggingface/transformers/blob/51f94ea06d19a6308c61bbb4dc97c40aabd12bad/src/transformers/models/mistral/configuration_mistral.py#L124-L146
+ if hasattr(config, 'head_dim') and config.head_dim is not None:
+ dim = config.head_dim
+ else:
+ dim = config.hidden_size // config.num_attention_heads
+
+ device = position_ids.device
+ x = torch.ones(1, device=device)
+
+ if is_transformers_greater_or_equal_than_4_48:
+ rotary_emb = _get_rotary_embedding_from_config(config, device)
+ else:
+ rotary_emb = _get_rotary_embedding(dim=dim, max_position_embeddings=max_position_embeddings, rope_theta=rope_theta, device=device)
+
+ cos, sin = rotary_emb(x, position_ids=position_ids)
+ cos, sin = cos.unsqueeze(dim = 1), sin.unsqueeze(dim = 1)
+ cos = cos[:,:,:,:dim//2]
+ sin = sin[:,:,:, :dim//2]
+ return cos, sin
+
+@functools.cache
+def _get_rotary_embedding(dim, max_position_embeddings, rope_theta, device):
+ from transformers.models.mistral.modeling_mistral import MistralRotaryEmbedding
+ rotary_emb = MistralRotaryEmbedding(dim=dim, max_position_embeddings=max_position_embeddings, base=rope_theta, device=device)
+ return rotary_emb
+
+
+def _get_rotary_embedding_from_config(config, device):
+ from transformers.models.mistral.modeling_mistral import MistralRotaryEmbedding
+ return MistralRotaryEmbedding(config, device)
+
+
+@functools.cache
+def _get_model(model_id_or_path):
+ from transformers import AutoConfig
+ from transformers.models.mistral.modeling_mistral import MistralModel
+ config = AutoConfig.from_pretrained(model_id_or_path)
+ config.num_hidden_layers = 1
+ model = MistralModel(config)
+ return model
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/phi/adaptation.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/phi/adaptation.py
new file mode 100644
index 000000000..e39384705
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/phi/adaptation.py
@@ -0,0 +1,480 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+""" This file provides adaptations to the Phi3 model. These adaptations are being done to
+optimize the model execution on the HTP backend.
+https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py"""
+
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.modeling_outputs import BaseModelOutputWithPast
+
+
+from transformers import Phi3ForCausalLM
+from transformers import cache_utils
+from transformers.models.phi3 import modeling_phi3
+from transformers.models.phi3.modeling_phi3 import (
+ repeat_kv,
+ Cache,
+ DynamicCache,
+ Phi3Attention,
+ Phi3Config,
+ apply_rotary_pos_emb,
+ Phi3Model
+)
+
+from genai_lib.llm.long_context_utils import AnchorUpdaterKeySecond
+from genai_lib.common.dev.utils import filter_outputs
+from importlib.metadata import version
+
+from transformers.utils import logging
+logger = logging.get_logger(__name__)
+
+def _apply_rope_single(x, rope_vals: Tuple[torch.Tensor, torch.Tensor]):
+ '''
+ Based on FacebookResearch's llama, provided by Carl
+ '''
+ rope_real = rope_vals[0] # shape should be 1, 1, seqlen, head_dim/2
+ rope_im = rope_vals[1] # shape should be 1, 1, seqlen, head_dim/2
+
+ # TODO: Why HF uses different coordinates from the paper
+ x_real = x[:,:,:,:x.shape[-1]//2] # extract first half elements
+ x_im = x[:,:,:,x.shape[-1]//2:] # extract second half elements
+
+ x_prod_real = x_real*rope_real - x_im * rope_im
+ x_prod_im = x_real*rope_im + x_im*rope_real
+
+ # TODO: HF need to uses different interleaving
+ x = torch.cat((x_prod_real,x_prod_im),dim=3).view(*x.shape)
+ return x
+
+
+class QcPhiAttention(Phi3Attention):
+ """Multi-headed attention from 'Attention Is All You Need' paper
+ We override the init and initialize the anchor updater.
+ The user can optionally pass in the alpha value to initialize the anchor_updater through the model config.
+ """
+
+ def __init__(self, config: Phi3Config, layer_idx: int):
+ super(QcPhiAttention, self).__init__(config, layer_idx)
+ self.use_unpack_qkv = False
+
+ # We only initialize anchor_updater when the anchor_alpha is present in the config
+ if getattr(config, "anchor_alpha", None) is not None:
+ self.anchor_updater = AnchorUpdaterKeySecond(alpha=config.anchor_alpha)
+
+
+
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+ def unpack_qkv(self):
+ self.use_unpack_qkv = True
+
+ device = self.qkv_proj.weight.device
+
+ self.q_proj = nn.Linear(self.config.hidden_size, self.config.num_attention_heads * self.head_dim, bias=False).to(device)
+ self.k_proj = nn.Linear(self.config.hidden_size, self.config.num_key_value_heads * self.head_dim, bias=False).to(device)
+ self.v_proj = nn.Linear(self.config.hidden_size, self.config.num_key_value_heads * self.head_dim, bias=False).to(device)
+
+ # Calculate the positions for slicing
+ total_hidden_size = self.config.num_attention_heads * self.head_dim # query size
+ key_value_size = self.config.num_key_value_heads * self.head_dim # key and value size
+
+ # Slicing and copying weights to q_proj, k_proj, v_proj
+ self.q_proj.weight.data.copy_(self.qkv_proj.weight[:total_hidden_size, :])
+ self.k_proj.weight.data.copy_(self.qkv_proj.weight[total_hidden_size: total_hidden_size + key_value_size, :])
+ self.v_proj.weight.data.copy_(self.qkv_proj.weight[total_hidden_size + key_value_size:, :])
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ #QC
+ valid_token_mask = kwargs.get('valid_token_mask', None)
+ anchor_buffer = kwargs.get('anchor_buffer', None)
+ return_new_key_value_only = self.config.return_new_key_value_only if hasattr(self.config, 'return_new_key_value_only') else False
+ transposed_key_cache = self.config.transposed_key_cache if hasattr(self.config, 'transposed_key_cache') else False
+
+ bsz, q_len, _ = hidden_states.size()
+ input_shape = hidden_states.shape[:-1]
+ hidden_shape = (*input_shape, -1, self.head_dim)
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(hidden_shape).transpose(1, 2)
+ key_states = key_states.view(hidden_shape).transpose(1, 2)
+ value_states = value_states.view(hidden_shape).transpose(1, 2)
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ if isinstance(position_ids, (tuple, list)): # QC
+ position_embeddings = position_ids
+ else:
+ position_embeddings = self.rotary_emb(value_states, position_ids)
+ cos, sin = position_embeddings
+ query_states = _apply_rope_single(query_states, position_embeddings)
+ key_states = _apply_rope_single(key_states, position_embeddings)
+
+ # we cache the un-transposed keys before we pass into the combined scoring model.
+ untransposed_keys = key_states
+
+ if transposed_key_cache:
+ key_states = key_states.transpose(2, 3)
+
+ if past_key_value is not None:
+ assert isinstance(past_key_value, DynamicCache)
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position,
+ "return_new_key_value_only": return_new_key_value_only,
+ "transposed_key_cache": transposed_key_cache,
+ "num_key_value_heads": self.num_key_value_heads,
+ "head_dim": self.head_dim,
+ }
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # key_states is the concatenated keys
+ # past_key_value.key_cache[layer_idx] is the new key states
+ # we invoke the scoring model and insert the outputs, anchor and the evict indices into the past_key_values (DynamicCache object) as additional attributes.
+ if valid_token_mask is not None and anchor_buffer is not None:
+ anchor_buffer = anchor_buffer[self.layer_idx]
+ anchor = self.anchor_updater(new_keys = untransposed_keys, valid_token_mask=valid_token_mask, old_anchor = anchor_buffer)
+
+ insert_meta_info_to_pastkv(past_key_values = past_key_value, meta_info_bundle={"anchor_buffer":anchor}, layer_idx=self.layer_idx)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if transposed_key_cache:
+ attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim)
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.config.num_attention_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.config.num_attention_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ # handle version-specific return
+ if version('transformers') >= '4.48.0':
+ return attn_output, attn_weights
+ else:
+ return attn_output, attn_weights, past_key_value
+
+def DynamicCache_update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += value_states.shape[-2]
+
+ # Update the cache
+ if len(self.key_cache) <= layer_idx:
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+ else:
+ return_new_key_value_only = cache_kwargs.get('return_new_key_value_only', False)
+ transposed_key_cache = cache_kwargs.get('transposed_key_cache', False)
+ cache_position = cache_kwargs.get('cache_position')
+ num_key_value_heads = cache_kwargs.get('num_key_value_heads')
+ head_dim = cache_kwargs.get('head_dim')
+ key_cat_dim = -1 if transposed_key_cache else -2
+
+ # if the size of past key cache passed is smaller in value than the last position where the new kv is to be inserted
+ # [in case when Cache position determined automatically by HF] (Ctx_len+ARN), then we want to perform concat and not do scattering.
+ if self.value_cache[layer_idx].shape[-2] <= cache_position[-1]:
+ key_cache = torch.cat([self.key_cache[layer_idx], key_states], dim=key_cat_dim)
+ value_cache = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+ else:
+ # the cache_position passed in as model i/p by user is a 1d tensor reflecting the positions
+ # from valid_kv_end to valid_kv_end+ARN, we convert this into the indices for scattering. [# bsz, num_key_value_heads, head_dim, seq_len]-> works for transposed keys
+ indices = cache_position.view(1, 1, 1, -1).expand(value_states.shape[0], num_key_value_heads, head_dim, cache_position.shape[-1])
+
+ value_cache = self.value_cache[layer_idx].scatter(dim=-2, index=indices.transpose(-1,-2), src=value_states)
+
+ indices = indices.transpose(-1, -2) if key_cat_dim== -2 else indices
+ key_cache = self.key_cache[layer_idx].scatter(dim=key_cat_dim, index=indices, src=key_states)
+
+
+ if return_new_key_value_only:
+ self.key_cache[layer_idx] = key_states
+ self.value_cache[layer_idx] = value_states
+ else:
+ self.key_cache[layer_idx] = key_cache
+ self.value_cache[layer_idx] = value_cache
+ return key_cache, value_cache
+
+
+def DynamicCache_get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ # TODO: deprecate this function in favor of `cache_position`
+ if len(self.value_cache) <= layer_idx:
+ return 0
+ return self.value_cache[layer_idx].shape[-2]
+
+
+def update_attr(cls, attr_name, new_attr):
+ attr_backup_name = f'_original_{attr_name}'
+ if hasattr(cls, attr_name):
+ if not hasattr(cls, attr_backup_name):
+ setattr(cls, attr_backup_name, getattr(cls, attr_name))
+ setattr(cls, attr_name, new_attr)
+ return True
+ return False
+
+
+orig_causal_mask = modeling_phi3.Phi3Model._update_causal_mask
+def adapted_update_causal_mask(self, attention_mask, *args, **kwargs):
+ if attention_mask is not None and attention_mask.dim() == 4:
+ return attention_mask
+ else:
+ return orig_causal_mask(self, attention_mask, *args, **kwargs)
+
+orig_embedding_fwd = modeling_phi3.Phi3RotaryEmbedding.forward
+def adapted_RotaryEmbedding(self, x, position_ids, *args, **kwargs):
+ if isinstance(position_ids, tuple) and len(position_ids)==2:
+ return position_ids
+ else:
+ return orig_embedding_fwd(self, x, position_ids, *args, **kwargs)
+
+class QcPhi3ForCausalLM(Phi3ForCausalLM):
+ """
+ Subclass of original Phi3ForCausalLM. This is needed to serve two purposes:
+
+ 1. Starting from transformers version 4.45.0, the num_logits_to_keep argument is now required argument.
+ Consequently, the prepared static graph will always include this additional argument.
+ To maintain compatibility with our existing pipelines, we create a new class that inherits from
+ Phi3ForCausalLM. In this new class, we redefine the forward method without the num_logits_to_keep
+ argument and in inside the forward we infer the num_logits_to_keep from the config and then call the superclass's forward method.
+
+ 2. For the Long Context scoring model within the LLM, we need to pass two additional arguments:
+ anchor_buffer and valid_token_mask. These can be provided as keyword arguments (introduced in transformers version 4.47.0).
+ However, this approach is incompatible with Onnx export, as Onnx does not support keyword arguments when creating the onnx graph.
+ Therefore, we pass valid_token_mask and anchor_buffer as model inputs to Phi3ForCausalLM,
+ which in turn get recognized through keyword arguments in the downstream blocks.
+ This is similar to how the DynamicCache object is not traced by jit.trace at the topmost level in Phi3ForCausalLM.
+ """
+ def __init__(self, config):
+ super().__init__(config)
+ if getattr(config, "input_tokens_per_inference", None) is not None:
+ self.register_buffer(name='cache_tensor', tensor=torch.arange(config.input_tokens_per_inference))
+
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Optional[int] = None,
+ valid_token_mask: Optional[torch.Tensor]=None,
+ anchor_buffer: Optional[torch.Tensor]=None,
+ cache_index: Optional[torch.Tensor]=None,
+ **kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+
+ logits_to_keep = logits_to_keep if logits_to_keep else getattr(self.config, "logits_to_keep", 0)
+
+ if cache_index is not None:
+ assert hasattr(self, "cache_tensor"), "QcPhi3ForCausal doesn't have attribute \"cache_tensor\", " \
+ "check if \"input_tokens_per_inference\" is specified in model config"
+ cache_position = cache_index + self.cache_tensor
+
+ if type(past_key_values) == tuple and version('transformers') >= '4.48.0':
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ outputs = super().forward(
+ input_ids = input_ids,
+ attention_mask= attention_mask,
+ position_ids= position_ids,
+ past_key_values= past_key_values,
+ inputs_embeds= inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ valid_token_mask=valid_token_mask,
+ anchor_buffer=anchor_buffer,
+ **kwargs)
+
+ if version('transformers') >= '4.48.0':
+ if return_dict:
+ assert type(outputs.past_key_values) != tuple
+ past_key_values_output = DynamicCache.to_legacy_cache(outputs.past_key_values)
+ outputs.past_key_values = past_key_values_output
+ else:
+ new_outputs = []
+ for item in outputs:
+ if isinstance(item, DynamicCache):
+ past_key_values_output = DynamicCache.to_legacy_cache(item)
+ new_outputs.append(past_key_values_output)
+ else:
+ new_outputs.append(item)
+ outputs = tuple(new_outputs)
+
+ if hasattr(self.config, "output_index_filter"):
+ return filter_outputs(outputs, self.config.output_index_filter)
+ return outputs
+
+def insert_meta_info_to_pastkv(
+ past_key_values,
+ meta_info_bundle: {},
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ This function adds two new model outputs to the DynamicCache object, eliminating the need to write a new adaptation for the additional output.
+
+ params:
+ past_key_values: The DynamicCache object
+ evict_info_bundle: A dictionary containing the additional information to attach to the DynamicCache object.
+ layer_idx: An attribute pointing to a list, where layer_idx corresponds to the data of the given layer.
+ cache_kwargs: Additional keyword arguments for the cache.
+
+ """
+ for key, value in meta_info_bundle.items():
+ if getattr(past_key_values, key, None) is None:
+ setattr(past_key_values, key, [])
+
+ key_attr = getattr(past_key_values, key, None)
+ # Update the cache
+ if len(key_attr) <= layer_idx:
+ key_attr.append(value)
+ else:
+ key_attr[layer_idx] = value
+
+def DynamicCache_to_legacy_cache(self):
+
+ """
+ Converts the DynamicCache instance to its equivalent in the legacy cache format for backward compatibility.
+
+ The past_key_values passed into the model as input is a tuple.
+ The Phi3Model converts it into a Cache object if it isn't one already. Within the model, past_key_values flow as a DynamicCache object.
+ Just before returning the output, the Phi3Model converts the DynamicCache back to the legacy cache (tuple format).
+ Since we added new attributes to our Cache object, we need to ensure they are included as additional entries in the returned tuple."""
+
+ legacy_cache = ()
+ for layer_idx in range(len(self)):
+ legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+ if "anchor_buffer" in dir(self):
+ return (legacy_cache, self.anchor_buffer)
+ return legacy_cache
+
+class QcPhi3Model(Phi3Model):
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ):
+ if type(past_key_values) == tuple and version('transformers') >= '4.48.0':
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ outputs = super().forward(
+ input_ids = input_ids,
+ attention_mask = attention_mask,
+ position_ids = position_ids,
+ past_key_values = past_key_values,
+ inputs_embeds = inputs_embeds,
+ use_cache = use_cache,
+ cache_position = cache_position,
+ **kwargs
+ )
+
+ if isinstance(outputs, BaseModelOutputWithPast) and isinstance(outputs.past_key_values, DynamicCache):
+ outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
+ elif isinstance(outputs, tuple):
+ new_outputs = []
+ for item in outputs:
+ if isinstance(item, DynamicCache):
+ new_outputs.append(item.to_legacy_cache())
+ else:
+ new_outputs.append(item)
+ outputs = tuple(new_outputs)
+ else:
+ raise ValueError(f"Model output is expected to be an instance of BaseModelOutputWithPast or Tuple, got {type(outputs)}")
+
+ return outputs
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/phi/utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/phi/utils.py
new file mode 100644
index 000000000..7fcd950d0
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/phi/utils.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+""" This file provides utilities that the pipeline would need to work with the adaptations made to Phi3 model. """
+
+import torch
+import functools
+
+def llm_update_causal_mask(prepared_1d_attn_mask, input_tensor, max_input_tokens, model_context_len, model_id_or_path, mask_neg = -100.0, cache_index=None, pad_to_left = True):
+ '''
+
+ This function creates a causal mask (2D) from the 1D attention mask
+
+ params:
+ 1. prepared_1d_attn_mask: attention mask of shape (batch_size, model_context_length)
+ 2. input_tensor : input_ids/ input_embeddings
+ 3. max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ 4. model_context_len: maximum number of tokens that the model can consume in total
+ 5. model_id: Model name or path to pretrained model
+ 6. mask_neg: proxy for minus infinity since minus infinity is not quantization friendly. This value should be large
+ enough to drown out tokens that should not be attended to
+ 7. cache_index: the index for the starting position of kvcaches
+ 6. pad_to_left: determines if the KV cache is padded to the left or right
+ '''
+ phi_model = _get_model(model_id_or_path)
+
+ # if the cache position is None, then we assume that the current input ids will be concatenated to the right end and
+ # #hence we construct the cache position accordingly to be sent into the update_causal_mask
+
+ if pad_to_left:
+ # Cache index should not be passed. Concat op is used in doing the KV cache update
+ assert cache_index is None, "Invalid argument error: we do not support the combination of performing left padding and doing scatter for KV cache update."
+ else:
+ # if the user is doing right padding, it is necessary to pass the cache_index.
+ assert cache_index is not None, "Invalid argument error: we do not support the combination of performing right padding and doing concat for KV cache update"
+
+ if cache_index is None:
+ cache_position = torch.arange(model_context_len-max_input_tokens, model_context_len, device = input_tensor.device)
+ else:
+ cache_position = torch.arange(max_input_tokens, dtype=torch.float32, device=input_tensor.device) + cache_index.to(input_tensor.device)
+
+ input_embeds = torch.ones((input_tensor.shape[0], input_tensor.shape[1], 1), device = input_tensor.device)
+ prepared_attention_mask = phi_model._update_causal_mask(attention_mask=prepared_1d_attn_mask, input_tensor=input_embeds, output_attentions=True,
+ cache_position=cache_position, past_key_values=None)
+ prepared_attention_mask = prepared_attention_mask.clamp_min(mask_neg)
+ return prepared_attention_mask
+
+def llm_create_position_embeddings(config, position_ids=None):
+ '''
+ This function creates position embedding (RoPE) from the position ids.
+ params:
+ 1. config: model configuration to create the Phi3RotaryEmbedding object, expect config for backward compatibility, in future transformers, we only expect to pass the config, the one we have in docker, takes in the req argument
+ 2. position_ids: required position ids passed into the model
+ '''
+
+ hidden_size = config.hidden_size
+ max_position_embeddings = config.max_position_embeddings
+ num_attention_heads = config.num_attention_heads
+ rope_theta = config.rope_theta
+ dim = int((hidden_size // num_attention_heads))
+ device = position_ids.device
+ x = torch.ones(1, device=device)
+ rotary_emb = _get_rotary_embedding(dim=dim, max_position_embeddings=max_position_embeddings, rope_theta=rope_theta, device=device, config=config)
+ # print(rotary_emb)
+ # import inspect
+ # print(inspect.getsource(rotary_emb.forward))
+
+ cos, sin = rotary_emb(x, position_ids=position_ids)
+ # Change
+ # print(x.shape, position_ids.shape)
+ # emb = rotary_emb(x, position_ids=position_ids)
+ # print(type(emb))
+ # print(emb)
+ # import transformers
+ # print(transformers.__version__)
+ # cos = emb[0]
+ # sin = emb[1]
+ cos, sin = cos.unsqueeze(dim = 1), sin.unsqueeze(dim = 1)
+ cos = cos[:,:,:,:dim//2]
+ sin = sin[:,:,:, :dim//2]
+ return cos, sin
+
+def _get_rotary_embedding(dim, max_position_embeddings, rope_theta, device, config=None):
+ from transformers.models.phi3.modeling_phi3 import Phi3RotaryEmbedding
+ rotary_emb = Phi3RotaryEmbedding(config=config)
+ return rotary_emb
+
+@functools.cache
+def _get_model(model_id_or_path):
+ from transformers import AutoConfig
+ from transformers.models.phi3.modeling_phi3 import Phi3Model
+ config = AutoConfig.from_pretrained(model_id_or_path)
+ config.num_hidden_layers = 1
+ model = Phi3Model(config)
+ return model
+
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen2/adaptation.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen2/adaptation.py
new file mode 100644
index 000000000..f686d94c6
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen2/adaptation.py
@@ -0,0 +1,454 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+""" This file provides adaptations to the Mistral model. These adaptations are being
+done to optimize the model execution on the HTP backend.
+https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/models/qwen2/modeling_qwen2.py"""
+
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.qwen2 import modeling_qwen2
+from transformers.models.qwen2.modeling_qwen2 import (
+ repeat_kv,
+ Cache,
+ Qwen2Model,
+ DynamicCache,
+ Qwen2Attention,
+ apply_rotary_pos_emb,
+ Qwen2ForCausalLM,
+)
+from genai_lib.llm.dev.model_adaptation.common.spinquant import R3Hadamard
+from importlib.metadata import version
+from genai_lib.llm.long_context_utils import AnchorUpdaterKeySecond
+from genai_lib.common.dev.utils import filter_outputs
+
+from transformers.utils import logging
+logger = logging.get_logger(__name__)
+
+def _apply_rope_single(x, rope_vals: Tuple[torch.Tensor, torch.Tensor]):
+ '''
+ Based on FacebookResearch's llama, provided by Carl
+ '''
+ rope_real = rope_vals[0] # shape should be 1, 1, seqlen, head_dim/2
+ rope_im = rope_vals[1] # shape should be 1, 1, seqlen, head_dim/2
+
+ # TODO: Why HF uses different coordinates from the paper
+ x_real = x[:,:,:,:x.shape[-1]//2] # extract first half elements
+ x_im = x[:,:,:,x.shape[-1]//2:] # extract second half elements
+
+ x_prod_real = x_real*rope_real - x_im * rope_im
+ x_prod_im = x_real*rope_im + x_im*rope_real
+
+ # TODO: HF need to uses different interleaving
+ x = torch.cat((x_prod_real,x_prod_im),dim=3).view(*x.shape)
+ return x
+
+orig_embedding_fwd = modeling_qwen2.Qwen2RotaryEmbedding.forward
+def adapted_RotaryEmbedding(self, x, position_ids, seq_len=None, *args, **kwargs):
+ if isinstance(position_ids, tuple) and len(position_ids) == 2:
+ return position_ids
+ else:
+ # transformers will move to taking in position_ids like llama moving forward for qwen2. Qwen2RotaryEmbedding -> https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2/modeling_qwen2.py
+ #In transformers version > 4.44.2, the rotary_embedding's forward function returns position_embeddings (cos, sin) for each position_id in the batch.
+ if version('transformers') >= '4.44.2':
+ return orig_embedding_fwd(self, x, position_ids=position_ids, *args, **kwargs)
+
+ cos, sin= orig_embedding_fwd(self, x, seq_len=seq_len, *args, **kwargs)
+ # For versions < 4.44.2, it returns a single view of cos, sin using the provided seq_len. Therefore, we need to stack the position embeddings for each position_id in the batch.
+ cos, sin = (torch.stack([cos[[position_ids[i]]] for i in range(position_ids.shape[0])]),
+ torch.stack([sin[[position_ids[i]]] for i in range(position_ids.shape[0])]))
+ return cos, sin
+
+class QcQwen2Attention(Qwen2Attention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+ def __init__(self, config, layer_idx: Optional[int] = None):
+ super(QcQwen2Attention, self).__init__(config, layer_idx)
+
+ if getattr(self.config, "enable_r3_hadamard", False):
+ """ R3 Hadamard sourced from SpinQuant paper: https://arxiv.org/pdf/2405.16406"""
+ self.q_R3 = R3Hadamard(head_dim = self.head_dim)
+ self.k_R3 = R3Hadamard(head_dim = self.head_dim)
+
+ # LONG_CONTEXT : We only initialize anchor_updater when the anchor_alpha is present in the config
+ if getattr(config, "anchor_alpha", None) is not None:
+ self.anchor_updater = AnchorUpdaterKeySecond(alpha=config.anchor_alpha)
+
+ # We only use "torch.where(attention_mask, input, min(input)-20)" sequence when the enable_masked_softmax is present in the config
+ self.enable_masked_softmax = getattr(config, "enable_masked_softmax", False)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ #QC
+
+ # LONG_CONTEXT
+ valid_token_mask = kwargs.get('valid_token_mask', None)
+ anchor_buffer = kwargs.get('anchor_buffer', None)
+
+ return_new_key_value_only = self.config.return_new_key_value_only if hasattr(self.config, 'return_new_key_value_only') else False
+ transposed_key_cache = self.config.transposed_key_cache if hasattr(self.config, 'transposed_key_cache') else False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ if cos.shape[-1] == query_states.shape[-1]:
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+ else:
+ query_states = _apply_rope_single(query_states, position_embeddings)
+ key_states = _apply_rope_single(key_states, position_embeddings)
+
+ if getattr(self.config, "enable_r3_hadamard", False):
+ query_states = self.q_R3(query_states)
+ key_states = self.k_R3(key_states)
+
+ # we cache the un-transposed keys before we pass into the combined scoring model.
+ untransposed_keys = key_states
+
+ if transposed_key_cache:
+ key_states = key_states.transpose(2, 3)
+
+ if past_key_value is not None:
+ assert isinstance(past_key_value, DynamicCache)
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position,
+ "return_new_key_value_only": return_new_key_value_only,
+ "transposed_key_cache": transposed_key_cache,
+ "num_key_value_heads": self.config.num_key_value_heads,
+ "head_dim": self.head_dim
+ }
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # key_states is the concatenated keys
+ # past_key_value.key_cache[layer_idx] is the new key states
+ # we invoke the scoring model and insert the outputs, anchor and the evict indices into the past_key_values (DynamicCache object) as additional attributes.
+ if valid_token_mask is not None and anchor_buffer is not None:
+ anchor_buffer = anchor_buffer[self.layer_idx]
+ anchor = self.anchor_updater(new_keys = untransposed_keys, valid_token_mask=valid_token_mask, old_anchor = anchor_buffer)
+
+ insert_meta_info_to_pastkv(past_key_values = past_key_value, meta_info_bundle={"anchor_buffer":anchor}, layer_idx=self.layer_idx)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if transposed_key_cache:
+ attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim)
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1]!=value_states.shape[-2]:
+ attention_mask = attention_mask[:, :, :, : value_states.shape[-2]]
+ if self.enable_masked_softmax:
+ attn_weights_min, _ = torch.min(attn_weights, dim=-1, keepdim=True)
+ minus_value = -20
+ attn_weights = torch.where(attention_mask==0, attn_weights, attn_weights_min + minus_value)
+ else:
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.config.num_attention_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.config.num_attention_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ # handle version-specific return
+ if version('transformers') >= '4.48.0':
+ return attn_output, attn_weights
+ else:
+ return attn_output, attn_weights, past_key_value
+
+
+orig_causal_mask = modeling_qwen2.Qwen2Model._update_causal_mask
+def adapted_update_causal_mask(self, attention_mask, *args, **kwargs):
+ if attention_mask is not None and attention_mask.dim() == 4:
+ return attention_mask
+ else:
+ return orig_causal_mask(self, attention_mask, *args, **kwargs)
+
+
+def DynamicCache_update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += value_states.shape[-2]
+
+ # Update the cache
+ if len(self.key_cache) <= layer_idx:
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+ else:
+ return_new_key_value_only = cache_kwargs.get('return_new_key_value_only', False)
+ transposed_key_cache = cache_kwargs.get('transposed_key_cache', False)
+ cache_position = cache_kwargs.get('cache_position')
+ num_key_value_heads = cache_kwargs.get('num_key_value_heads')
+ head_dim = cache_kwargs.get('head_dim')
+ key_cat_dim = -1 if transposed_key_cache else -2
+
+ # if the size of past key cache passed is smaller in value than the last position where the new kv is to be inserted
+ # [in case when Cache position determined automatically by HF] (Ctx_len+ARN), then we want to perform concat and not do scattering.
+ if self.value_cache[layer_idx].shape[-2] <= cache_position[-1]:
+ key_cache = torch.cat([self.key_cache[layer_idx], key_states], dim=key_cat_dim)
+ value_cache = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+ else:
+ # the cache_position passed in as model i/p by user is a 1d tensor reflecting the positions
+ # from valid_kv_end to valid_kv_end+ARN, we convert this into the indices for scattering. [# bsz, num_key_value_heads, head_dim, seq_len]-> works for transposed keys
+ indices = cache_position.view(1, 1, 1, -1).expand(value_states.shape[0], num_key_value_heads, head_dim, cache_position.shape[-1])
+
+ value_cache = self.value_cache[layer_idx].scatter(dim=-2, index=indices.transpose(-1,-2), src=value_states)
+
+ indices = indices.transpose(-1, -2) if key_cat_dim== -2 else indices
+ key_cache = self.key_cache[layer_idx].scatter(dim=key_cat_dim, index=indices, src=key_states)
+
+
+ if return_new_key_value_only:
+ self.key_cache[layer_idx] = key_states
+ self.value_cache[layer_idx] = value_states
+ else:
+ self.key_cache[layer_idx] = key_cache
+ self.value_cache[layer_idx] = value_cache
+ return key_cache, value_cache
+
+
+def DynamicCache_get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ # TODO: deprecate this function in favor of `cache_position`
+ if len(self.value_cache) <= layer_idx:
+ return 0
+ return self.value_cache[layer_idx].shape[-2]
+
+
+def update_attr(cls, attr_name, new_attr):
+ attr_backup_name = f'_original_{attr_name}'
+ if hasattr(cls, attr_name):
+ if not hasattr(cls, attr_backup_name):
+ setattr(cls, attr_backup_name, getattr(cls, attr_name))
+ setattr(cls, attr_name, new_attr)
+ return True
+ return False
+
+def insert_meta_info_to_pastkv(
+ past_key_values,
+ meta_info_bundle: {},
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ This function adds two new model outputs to the DynamicCache object, eliminating the need to write a new adaptation for the additional output.
+
+ params:
+ past_key_values: The DynamicCache object
+ evict_info_bundle: A dictionary containing the additional information to attach to the DynamicCache object.
+ layer_idx: An attribute pointing to a list, where layer_idx corresponds to the data of the given layer.
+ cache_kwargs: Additional keyword arguments for the cache.
+
+ """
+ for key, value in meta_info_bundle.items():
+ if getattr(past_key_values, key, None) is None:
+ setattr(past_key_values, key, [])
+
+ key_attr = getattr(past_key_values, key, None)
+ # Update the cache
+ if len(key_attr) <= layer_idx:
+ key_attr.append(value)
+ else:
+ key_attr[layer_idx] = value
+
+def DynamicCache_to_legacy_cache(self):
+
+ """
+ Converts the DynamicCache instance to its equivalent in the legacy cache format for backward compatibility.
+
+ The past_key_values passed into the model as input is a tuple.
+ The LlamaModel converts it into a Cache object if it isn't one already. Within the model, past_key_values flow as a DynamicCache object.
+ Just before returning the output, the LlamaModel converts the DynamicCache back to the legacy cache (tuple format).
+ Since we added new attributes to our Cache object, we need to ensure they are included as additional entries in the returned tuple."""
+
+ legacy_cache = ()
+ for layer_idx in range(len(self)):
+ legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+ if "anchor_buffer" in dir(self):
+ return (legacy_cache, self.anchor_buffer)
+ return legacy_cache
+
+
+class QcQwen2ForCausalLM(Qwen2ForCausalLM):
+ def __init__(self, config):
+ super().__init__(config)
+ if getattr(config, "input_tokens_per_inference", None) is not None:
+ self.register_buffer(name='cache_tensor', tensor=torch.arange(config.input_tokens_per_inference), persistent=False)
+
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Optional[int] = None,
+ valid_token_mask: Optional[torch.Tensor]=None,
+ anchor_buffer: Optional[torch.Tensor]=None,
+ cache_index: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+ logits_to_keep = logits_to_keep if logits_to_keep else getattr(self.config, "logits_to_keep", 0)
+ if cache_index is not None:
+ assert hasattr(self, "cache_tensor"), "QcQwen2ForCausalLM doesn't have attribute \"cache_tensor\", " \
+ "check if \"input_tokens_per_inference\" is specified in model config"
+ cache_position = cache_index + self.cache_tensor
+
+ if type(past_key_values) == tuple and version('transformers') >= '4.48.0':
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ outputs = super().forward(
+ input_ids = input_ids,
+ attention_mask= attention_mask,
+ position_ids= position_ids,
+ past_key_values= past_key_values,
+ inputs_embeds= inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ valid_token_mask=valid_token_mask,
+ anchor_buffer=anchor_buffer,
+ **kwargs)
+
+ if version('transformers') >= '4.48.0':
+ if return_dict:
+ assert type(outputs.past_key_values) != tuple
+ past_key_values_output = DynamicCache.to_legacy_cache(outputs.past_key_values)
+ outputs.past_key_values = past_key_values_output
+ else:
+ new_outputs = []
+ for item in outputs:
+ if isinstance(item, DynamicCache):
+ past_key_values_output = DynamicCache.to_legacy_cache(item)
+ new_outputs.append(past_key_values_output)
+ else:
+ new_outputs.append(item)
+ outputs = tuple(new_outputs)
+
+ if hasattr(self.config, "output_index_filter"):
+ return filter_outputs(outputs, self.config.output_index_filter)
+ return outputs
+
+
+class QcQwen2Model(Qwen2Model):
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ):
+
+ outputs = super().forward(
+ input_ids = input_ids,
+ attention_mask = attention_mask,
+ position_ids = position_ids,
+ past_key_values = past_key_values,
+ inputs_embeds = inputs_embeds,
+ use_cache = use_cache,
+ cache_position = cache_position,
+ **kwargs
+ )
+
+ if isinstance(outputs, BaseModelOutputWithPast) and isinstance(outputs.past_key_values, DynamicCache):
+ outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
+ elif isinstance(outputs, tuple):
+ new_outputs = []
+ for item in outputs:
+ if isinstance(item, DynamicCache):
+ new_outputs.append(item.to_legacy_cache())
+ else:
+ new_outputs.append(item)
+ outputs = tuple(new_outputs)
+ else:
+ raise ValueError(f"Model output is expected to be an instance of BaseModelOutputWithPast or Tuple, got {type(outputs)}")
+
+ return outputs
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen2/modeling_eaglet_qwen2.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen2/modeling_eaglet_qwen2.py
new file mode 100644
index 000000000..cfa7ba120
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen2/modeling_eaglet_qwen2.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+from torch import nn
+from genai_lib.llm.dev.model_adaptation.qwen2.adaptation import adapted_update_causal_mask, QcQwen2Attention
+from genai_lib.llm.eaglet.base_draft_model import BaseDraftModel
+from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm, Qwen2DecoderLayer
+
+__all__ = ["Qwen2EagletDraftModel", "Qwen2EagletDecoderLayer"]
+
+
+class Qwen2EagletDraftModel(BaseDraftModel):
+ def __init__(self, config, *args, **kwargs):
+ super().__init__(
+ config,
+ decoder_cls=Qwen2EagletDecoderLayer,
+ norm_cls=Qwen2RMSNorm,
+ dual_fc=True,
+ )
+
+ _update_causal_mask = adapted_update_causal_mask
+
+
+class Qwen2EagletDecoderLayer(Qwen2DecoderLayer):
+ def __init__(self, config, layer_idx: int):
+ super().__init__(config=config, layer_idx=layer_idx)
+ if layer_idx == 0:
+ self.input_layernorm = nn.Identity()
+ self.self_attn = QcQwen2Attention(config=config, layer_idx=layer_idx)
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen2/utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen2/utils.py
new file mode 100644
index 000000000..003fbe924
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen2/utils.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+""" This file provides utilities that the pipeline would need to work with the adaptations made to Qwen2 model. """
+
+import torch
+import functools
+from importlib.metadata import version
+
+def llm_update_causal_mask(prepared_1d_attn_mask, input_tensor, max_input_tokens, model_context_len, model_id_or_path, mask_neg = -100.0, cache_index=None, pad_to_left = True):
+ '''
+ This function creates a causal mask (2D) from the 1D attention mask
+
+ params:
+ 1. prepared_1d_attn_mask: attention mask of shape (batch_size, model_context_length)
+ 2. input_tensor : input_ids/ input_embeddings
+ 3. max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ 4. model_context_len: maximum number of tokens that the model can consume in total
+ 5. model_id_or_path: Model name or path to pretrained model
+ 6. mask_neg: proxy for minus infinity since minus infinity is not quantization friendly. This value should be large
+ enough to drown out tokens that should not be attended to
+ 7. cache_index: the index for the starting position of kvcaches
+ 6. pad_to_left: determines if the KV cache is padded to the left or right
+ '''
+ qwen2_model = _get_model(model_id_or_path)
+
+ # if the cache position is None, then we assume that the current input ids will be concatenated to the right end and
+ # #hence we construct the cache position accordingly to be sent into the update_causal_mask
+
+ if pad_to_left:
+ # Cache index should not be passed. Concat op is used in doing the KV cache update
+ assert cache_index is None, "Invalid argument error: we do not support the combination of performing left padding and doing scatter for KV cache update."
+ else:
+ # if the user is doing right padding, it is necessary to pass the cache_index.
+ assert cache_index is not None, "Invalid argument error: we do not support the combination of performing right padding and doing concat for KV cache update"
+
+ if cache_index is None:
+ cache_position = torch.arange(model_context_len-max_input_tokens, model_context_len, device = input_tensor.device)
+ else:
+ cache_position = torch.arange(max_input_tokens, dtype=torch.float32, device=input_tensor.device) + cache_index.to(input_tensor.device)
+
+ input_embeds = torch.ones((input_tensor.shape[0], input_tensor.shape[1], 1), device = input_tensor.device)
+ prepared_attention_mask = qwen2_model._update_causal_mask(attention_mask=prepared_1d_attn_mask, input_tensor=input_embeds, output_attentions=True,
+ cache_position=cache_position, past_key_values=None)
+ prepared_attention_mask = prepared_attention_mask.clamp_min(mask_neg)
+ return prepared_attention_mask
+
+def llm_create_position_embeddings(config, position_ids=None):
+ '''
+ This function creates position embedding (RoPE) from the position ids.
+ params:
+ 1. config: model configuration to create the LLamaRotaryEmbedding object, expect config for backward compatibility, in future transformers, we only expect to pass the config, the one we have in docker, takes in the req argument
+ 2. position_ids: required position ids passed into the model
+ '''
+ max_position_embeddings = config.max_position_embeddings
+ rope_theta = config.rope_theta
+ dim = config.head_dim if hasattr(config, 'head_dim') else config.hidden_size // config.num_attention_heads
+ device = position_ids.device
+ x = torch.ones(1, device=device)
+ rotary_emb = _get_rotary_embedding(dim=dim, max_position_embeddings=max_position_embeddings, rope_theta=rope_theta, device=device, config=config)
+ cos, sin = rotary_emb(x, position_ids=position_ids, seq_len=max_position_embeddings)
+ cos, sin = cos.unsqueeze(dim = 1), sin.unsqueeze(dim = 1)
+ cos = cos[:,:,:,:dim//2]
+ sin = sin[:,:,:, :dim//2]
+ return cos, sin
+
+def _get_rotary_embedding(dim, max_position_embeddings, rope_theta, device, config=None):
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2RotaryEmbedding
+ if version('transformers') >= '4.48.0':
+ assert config!=None, "Starting 4.48.0, HF have changed the init to only take the config which is required parameter https://github.com/huggingface/transformers/blob/6bc0fbcfa7acb6ac4937e7456a76c2f7975fefec/src/transformers/models/qwen2/modeling_qwen2.py#L285"
+ rotary_emb = Qwen2RotaryEmbedding(config).to(device)
+ else:
+ rotary_emb = Qwen2RotaryEmbedding(dim=dim, max_position_embeddings=max_position_embeddings, base=rope_theta,
+ device=device, config=config)
+ return rotary_emb
+@functools.cache
+def _get_model(model_id_or_path):
+ from transformers import AutoConfig
+ from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
+ config = AutoConfig.from_pretrained(model_id_or_path)
+ config.num_hidden_layers = 1
+ model = Qwen2Model(config)
+ return model
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen3/adaptation.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen3/adaptation.py
new file mode 100644
index 000000000..e15084699
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen3/adaptation.py
@@ -0,0 +1,389 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+""" This file provides adaptations to the Qwen3 4B model. These adaptations are being
+done to optimize the model execution on the HTP backend.
+https://github.com/huggingface/transformers/blob/v4.51.0/src/transformers/models/qwen3/modeling_qwen3.py"""
+
+import math
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from transformers.cache_utils import Cache
+from transformers.models.qwen3 import modeling_qwen3
+from transformers.models.qwen3.modeling_qwen3 import (
+ repeat_kv,
+ DynamicCache,
+ Qwen3Attention,
+ Qwen3ForCausalLM,
+ apply_rotary_pos_emb
+)
+from genai_lib.llm.dev.model_adaptation.common.spinquant import R3Hadamard
+from importlib.metadata import version
+from genai_lib.llm.long_context_utils import AnchorUpdaterKeySecond
+from genai_lib.common.dev.utils import filter_outputs
+
+
+from transformers.utils import logging
+logger = logging.get_logger(__name__)
+
+def _apply_rope_single(x, rope_vals: Tuple[torch.Tensor, torch.Tensor]):
+ '''
+ Based on FacebookResearch's llama, provided by Carl
+ '''
+ rope_real = rope_vals[0] # shape should be 1, 1, seqlen, head_dim/2
+ rope_im = rope_vals[1] # shape should be 1, 1, seqlen, head_dim/2
+
+ # TODO: Why HF uses different coordinates from the paper
+ x_real = x[:,:,:,:x.shape[-1]//2] # extract first half elements
+ x_im = x[:,:,:,x.shape[-1]//2:] # extract second half elements
+
+ x_prod_real = x_real*rope_real - x_im * rope_im
+ x_prod_im = x_real*rope_im + x_im*rope_real
+
+ # TODO: HF need to uses different interleaving
+ x = torch.cat((x_prod_real,x_prod_im),dim=3).view(*x.shape)
+ return x
+
+orig_embedding_fwd = modeling_qwen3.Qwen3RotaryEmbedding.forward
+def adapted_RotaryEmbedding(self, x, position_ids, seq_len=None, *args, **kwargs):
+ if isinstance(position_ids, tuple) and len(position_ids) == 2:
+ return position_ids
+ # transformers will move to taking in position_ids like llama moving forward for qwen3. Qwen3RotaryEmbedding -> https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3/modeling_qwen3.py
+ #In transformers version > 4.44.2, the rotary_embedding's forward function returns position_embeddings (cos, sin) for each position_id in the batch.
+ return orig_embedding_fwd(self, x, position_ids=position_ids, *args, **kwargs)
+
+
+class QcQwen3Attention(Qwen3Attention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+ def __init__(self, config, layer_idx: Optional[int] = None):
+ super(QcQwen3Attention, self).__init__(config, layer_idx)
+
+ if getattr(self.config, "enable_r3_hadamard", False):
+ """ R3 Hadamard sourced from SpinQuant paper: https://arxiv.org/pdf/2405.16406"""
+ self.q_R3 = R3Hadamard(head_dim = self.head_dim)
+ self.k_R3 = R3Hadamard(head_dim = self.head_dim)
+
+ # LONG_CONTEXT : We only initialize anchor_updater when the anchor_alpha is present in the config
+ if getattr(config, "anchor_alpha", None) is not None:
+ self.anchor_updater = AnchorUpdaterKeySecond(alpha=config.anchor_alpha)
+
+ # We only use "torch.where(attention_mask, input, min(input)-20)" sequence when the enable_masked_softmax is present in the config
+ self.enable_masked_softmax = getattr(config, "enable_masked_softmax", False)
+
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ #QC
+
+ # LONG_CONTEXT
+ valid_token_mask = kwargs.get('valid_token_mask', None)
+ anchor_buffer = kwargs.get('anchor_buffer', None)
+
+ return_new_key_value_only = self.config.return_new_key_value_only if hasattr(self.config, 'return_new_key_value_only') else False
+ transposed_key_cache = self.config.transposed_key_cache if hasattr(self.config, 'transposed_key_cache') else False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = self.q_norm(query_states.view(bsz, q_len, self.config.num_attention_heads, self.head_dim)).transpose(1, 2)
+ key_states = self.k_norm(key_states.view(bsz, q_len, self.config.num_key_value_heads, self.head_dim)).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.config.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ if cos.shape[-1] == query_states.shape[-1]:
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+ else:
+ query_states = _apply_rope_single(query_states, position_embeddings)
+ key_states = _apply_rope_single(key_states, position_embeddings)
+
+ if getattr(self.config, "enable_r3_hadamard", False):
+ query_states = self.q_R3(query_states)
+ key_states = self.k_R3(key_states)
+
+ # we cache the un-transposed keys before we pass into the combined scoring model.
+ untransposed_keys = key_states
+
+ if transposed_key_cache:
+ key_states = key_states.transpose(2, 3)
+
+ if past_key_value is not None:
+ assert isinstance(past_key_value, DynamicCache)
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position,
+ "return_new_key_value_only": return_new_key_value_only,
+ "transposed_key_cache": transposed_key_cache,
+ "num_key_value_heads": self.config.num_key_value_heads,
+ "head_dim": self.head_dim
+ }
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # key_states is the concatenated keys
+ # past_key_value.key_cache[layer_idx] is the new key states
+ # we invoke the scoring model and insert the outputs, anchor and the evict indices into the past_key_values (DynamicCache object) as additional attributes.
+ if valid_token_mask is not None and anchor_buffer is not None:
+ anchor_buffer = anchor_buffer[self.layer_idx]
+ anchor = self.anchor_updater(new_keys = untransposed_keys, valid_token_mask=valid_token_mask, old_anchor = anchor_buffer)
+
+ insert_meta_info_to_pastkv(past_key_values = past_key_value, meta_info_bundle={"anchor_buffer":anchor}, layer_idx=self.layer_idx)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if transposed_key_cache:
+ attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim)
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1]!=value_states.shape[-2]:
+ attention_mask = attention_mask[:, :, :, : value_states.shape[-2]]
+ if self.enable_masked_softmax:
+ attn_weights_min, _ = torch.min(attn_weights, dim=-1, keepdim=True)
+ minus_value = -20
+ attn_weights = torch.where(attention_mask==0, attn_weights, attn_weights_min + minus_value)
+ else:
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.config.num_attention_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.config.num_attention_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.config.num_attention_heads * self.head_dim )
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+
+ return attn_output, attn_weights
+
+
+
+def DynamicCache_update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += value_states.shape[-2]
+
+ # Update the cache
+ if len(self.key_cache) <= layer_idx:
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+ else:
+ return_new_key_value_only = cache_kwargs.get('return_new_key_value_only', False)
+ transposed_key_cache = cache_kwargs.get('transposed_key_cache', False)
+ cache_position = cache_kwargs.get('cache_position')
+ num_key_value_heads = cache_kwargs.get('num_key_value_heads')
+ head_dim = cache_kwargs.get('head_dim')
+ key_cat_dim = -1 if transposed_key_cache else -2
+
+ # if the size of past key cache passed is smaller in value than the last position where the new kv is to be inserted
+ # [in case when Cache position determined automatically by HF] (Ctx_len+ARN), then we want to perform concat and not do scattering.
+ if self.value_cache[layer_idx].shape[-2] <= cache_position[-1]:
+ key_cache = torch.cat([self.key_cache[layer_idx], key_states], dim=key_cat_dim)
+ value_cache = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+ else:
+ # the cache_position passed in as model i/p by user is a 1d tensor reflecting the positions
+ # from valid_kv_end to valid_kv_end+ARN, we convert this into the indices for scattering. [# bsz, num_key_value_heads, head_dim, seq_len]-> works for transposed keys
+ indices = cache_position.view(1, 1, 1, -1).expand(value_states.shape[0], num_key_value_heads, head_dim, cache_position.shape[-1])
+
+ value_cache = self.value_cache[layer_idx].scatter(dim=-2, index=indices.transpose(-1,-2), src=value_states)
+
+ indices = indices.transpose(-1, -2) if key_cat_dim== -2 else indices
+ key_cache = self.key_cache[layer_idx].scatter(dim=key_cat_dim, index=indices, src=key_states)
+
+
+ if return_new_key_value_only:
+ self.key_cache[layer_idx] = key_states
+ self.value_cache[layer_idx] = value_states
+ else:
+ self.key_cache[layer_idx] = key_cache
+ self.value_cache[layer_idx] = value_cache
+ return key_cache, value_cache
+
+
+def DynamicCache_get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ # TODO: deprecate this function in favor of `cache_position`
+ if len(self.value_cache) <= layer_idx:
+ return 0
+ return self.value_cache[layer_idx].shape[-2]
+
+
+def update_attr(cls, attr_name, new_attr):
+ attr_backup_name = f'_original_{attr_name}'
+ if hasattr(cls, attr_name):
+ if not hasattr(cls, attr_backup_name):
+ setattr(cls, attr_backup_name, getattr(cls, attr_name))
+ setattr(cls, attr_name, new_attr)
+ return True
+ return False
+
+def insert_meta_info_to_pastkv(
+ past_key_values,
+ meta_info_bundle: {},
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+) -> None:
+ """
+ This function adds two new model outputs to the DynamicCache object, eliminating the need to write a new adaptation for the additional output.
+
+ params:
+ past_key_values: The DynamicCache object
+ evict_info_bundle: A dictionary containing the additional information to attach to the DynamicCache object.
+ layer_idx: An attribute pointing to a list, where layer_idx corresponds to the data of the given layer.
+ cache_kwargs: Additional keyword arguments for the cache.
+
+ """
+ for key, value in meta_info_bundle.items():
+ if getattr(past_key_values, key, None) is None:
+ setattr(past_key_values, key, [])
+
+ key_attr = getattr(past_key_values, key, None)
+ # Update the cache
+ if len(key_attr) <= layer_idx:
+ key_attr.append(value)
+ else:
+ key_attr[layer_idx] = value
+
+def DynamicCache_to_legacy_cache(self):
+
+ """
+ Converts the DynamicCache instance to its equivalent in the legacy cache format for backward compatibility.
+
+ The past_key_values passed into the model as input is a tuple.
+ The LlamaModel converts it into a Cache object if it isn't one already. Within the model, past_key_values flow as a DynamicCache object.
+ Just before returning the output, the LlamaModel converts the DynamicCache back to the legacy cache (tuple format).
+ Since we added new attributes to our Cache object, we need to ensure they are included as additional entries in the returned tuple."""
+
+ legacy_cache = ()
+ for layer_idx in range(len(self)):
+ legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+ if "anchor_buffer" in dir(self):
+ return legacy_cache, self.anchor_buffer
+ return legacy_cache
+
+
+class QcQwen3ForCausalLM(Qwen3ForCausalLM):
+ def __init__(self, config):
+ super().__init__(config)
+ if getattr(config, "input_tokens_per_inference", None) is not None:
+ self.register_buffer(name='cache_tensor', tensor=torch.arange(config.input_tokens_per_inference), persistent=False)
+
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Optional[int] = None,
+ cache_index: Optional[torch.Tensor] = None,
+ **kwargs,
+ ):
+ logits_to_keep = logits_to_keep if logits_to_keep else getattr(self.config, "logits_to_keep", 0)
+ if cache_index is not None:
+ assert hasattr(self, "cache_tensor"), "QcQwen3ForCausalLM doesn't have attribute \"cache_tensor\", " \
+ "check if \"input_tokens_per_inference\" is specified in model config"
+ cache_position = cache_index + self.cache_tensor
+
+
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ outputs = super().forward(
+ input_ids = input_ids,
+ attention_mask= attention_mask,
+ position_ids= position_ids,
+ past_key_values= past_key_values,
+ inputs_embeds= inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ logits_to_keep=logits_to_keep,
+ **kwargs)
+
+
+ if return_dict:
+ assert type(outputs.past_key_values) != tuple
+ past_key_values_output = DynamicCache.to_legacy_cache(outputs.past_key_values)
+ outputs.past_key_values = past_key_values_output
+ else:
+ new_outputs = []
+ for item in outputs:
+ if isinstance(item, DynamicCache):
+ past_key_values_output = DynamicCache.to_legacy_cache(item)
+ new_outputs.append(past_key_values_output)
+ else:
+ new_outputs.append(item)
+ outputs = tuple(new_outputs)
+
+ if hasattr(self.config, "output_index_filter"):
+ return filter_outputs(outputs, self.config.output_index_filter)
+ return outputs
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen3/modeling_eaglet_qwen3.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen3/modeling_eaglet_qwen3.py
new file mode 100644
index 000000000..49329ebc3
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen3/modeling_eaglet_qwen3.py
@@ -0,0 +1,31 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+"""
+Eaglet model for Qwen3
+"""
+
+from torch import nn
+from genai_lib.llm.dev.model_adaptation.qwen3.adaptation import QcQwen3Attention
+from genai_lib.llm.eaglet.base_draft_model import Eaglet2BaseDraftModel
+from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer, Qwen3RMSNorm
+
+
+__all__ = ["Qwen3Eaglet2DraftModel", "Qwen3EagletDecoderLayer"]
+
+class Qwen3Eaglet2DraftModel(Eaglet2BaseDraftModel):
+ def __init__(self, config):
+ super().__init__(config, decoder_cls=Qwen3EagletDecoderLayer, norm_cls=Qwen3RMSNorm)
+
+
+class Qwen3EagletDecoderLayer(Qwen3DecoderLayer):
+ def __init__(self, config, layer_idx: int):
+ super().__init__(config=config, layer_idx=layer_idx)
+ if layer_idx == 0:
+ self.input_layernorm = nn.Identity()
+ self.self_attn = QcQwen3Attention(config=config, layer_idx=layer_idx)
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen3/utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen3/utils.py
new file mode 100644
index 000000000..e1f9f5bfd
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/dev/model_adaptation/qwen3/utils.py
@@ -0,0 +1,135 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+""" This file provides utilities that the pipeline would need to work with the adaptations made to Qwen3 model. """
+
+import torch
+import functools
+from importlib.metadata import version
+from genai_lib.common.dev.utils import is_transformers_greater_or_equal_than_4_51,is_transformers_greater_or_equal_than_4_53
+
+assert is_transformers_greater_or_equal_than_4_51, "Qwen3 is only supported >= 4.51.0 of transformers."
+if is_transformers_greater_or_equal_than_4_53:
+ from transformers.masking_utils import create_causal_mask
+
+
+def llm_create_causal_mask(prepared_1d_attn_mask, input_tensor, max_input_tokens, model_context_len, config, mask_neg = -1e3, cache_index = None, pad_to_left = True):
+ '''
+
+ This function creates a causal mask (2D) from the 1D attention mask
+
+ params:
+ 1. prepared_1d_attn_mask: attention mask of shape (batch_size, model_context_length)
+ 2. input_tensor : input_ids/ input_embeddings
+ 3. max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ 4. model_context_len: maximum number of tokens that the model can consume in total
+ 5. config: config of the model
+ 6. mask_neg: proxy for minus infinity since minus infinity is not quantization friendly. This value should be large
+ enough to drown out tokens that should not be attended to
+ 7. cache_index: the index for the starting position of kvcaches
+ 8. pad_to_left: determines if the KV cache is padded to the left or right
+ '''
+
+ # if the cache position is None, then we assume that the current input ids will be concatenated to the right end and
+ # #hence we construct the cache position accordingly to be sent into the update_causal_mask
+
+ if pad_to_left:
+ # Cache index should not be passed. Concat op is used in doing the KV cache update
+ assert cache_index is None, "Invalid argument error: we do not support the combination of performing left padding and doing scatter for KV cache update."
+ else:
+ # if the user is doing right padding, it is necessary to pass the cache_index.
+ assert cache_index is not None, "Invalid argument error: we do not support the combination of performing right padding and doing concat for KV cache update"
+
+ if cache_index is None:
+ cache_position = torch.arange(model_context_len-max_input_tokens, model_context_len, device = input_tensor.device)
+ else:
+ cache_position = torch.arange(max_input_tokens, dtype=torch.float32, device=input_tensor.device) + cache_index.to(input_tensor.device)
+
+ input_embeds = torch.ones((input_tensor.shape[0], model_context_len, 1), device = input_tensor.device)
+ mask_kwargs = {
+ "config": config,
+ "input_embeds": input_embeds,
+ "attention_mask": prepared_1d_attn_mask,
+ "cache_position": cache_position,
+ "past_key_values": None,
+ }
+ prepared_attention_mask = create_causal_mask(**mask_kwargs)
+ prepared_attention_mask = prepared_attention_mask.clamp_min(mask_neg)
+ return prepared_attention_mask
+
+def llm_update_causal_mask(prepared_1d_attn_mask, input_tensor, max_input_tokens, model_context_len, model_id_or_path, mask_neg = -100.0, cache_index=None, pad_to_left = True):
+ """
+ This function creates a causal mask (2D) from the 1D attention mask
+
+ params:
+ 1. prepared_1d_attn_mask: attention mask of shape (batch_size, model_context_length)
+ 2. input_tensor : input_ids/ input_embeddings
+ 3. max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ 4. model_context_len: maximum number of tokens that the model can consume in total
+ 5. model_id_or_path: Model name or path to pretrained model
+ 6. mask_neg: proxy for minus infinity since minus infinity is not quantization friendly. This value should be large
+ enough to drown out tokens that should not be attended to
+ 7. cache_index: the index for the starting position of kvcaches
+ 8. pad_to_left: determines if the KV cache is padded to the left or right
+ """
+ qwen3_model = _get_model(model_id_or_path)
+
+ # if the cache position is None, then we assume that the current input ids will be concatenated to the right end and
+ # #hence we construct the cache position accordingly to be sent into the update_causal_mask
+
+ if pad_to_left:
+ # Cache index should not be passed. Concat op is used in doing the KV cache update
+ assert cache_index is None, "Invalid argument error: we do not support the combination of performing left padding and doing scatter for KV cache update."
+ else:
+ # if the user is doing right padding, it is necessary to pass the cache_index.
+ assert cache_index is not None, "Invalid argument error: we do not support the combination of performing right padding and doing concat for KV cache update"
+
+ if cache_index is None:
+ cache_position = torch.arange(model_context_len-max_input_tokens, model_context_len, device = input_tensor.device)
+ else:
+ cache_position = torch.arange(max_input_tokens, dtype=torch.float32, device=input_tensor.device) + cache_index.to(input_tensor.device)
+
+ input_embeds = torch.ones((input_tensor.shape[0], input_tensor.shape[1], 1), device = input_tensor.device)
+ prepared_attention_mask = qwen3_model._update_causal_mask(attention_mask=prepared_1d_attn_mask, input_tensor=input_embeds, output_attentions=True,
+ cache_position=cache_position, past_key_values=None)
+ prepared_attention_mask = prepared_attention_mask.clamp_min(mask_neg)
+ return prepared_attention_mask
+
+def llm_create_position_embeddings(config, position_ids=None):
+ """
+ This function creates position embedding (RoPE) from the position ids.
+ params:
+ 1. config: model configuration to create the Qwen3RotaryEmbedding object, expect config for backward compatibility, in future transformers, we only expect to pass the config, the one we have in docker, takes in the req argument
+ 2. position_ids: required position ids passed into the model
+ """
+ max_position_embeddings = config.max_position_embeddings
+ dim = config.head_dim if hasattr(config, 'head_dim') else config.hidden_size // config.num_attention_heads
+ device = position_ids.device
+ x = torch.ones(1, device=device)
+ rotary_emb = _get_rotary_embedding(device=device, config=config)
+ cos, sin = rotary_emb(x, position_ids=position_ids, seq_len=max_position_embeddings)
+ cos, sin = cos.unsqueeze(dim = 1), sin.unsqueeze(dim = 1)
+ cos = cos[:,:,:,:dim//2]
+ sin = sin[:,:,:, :dim//2]
+ return cos, sin
+
+def _get_rotary_embedding(device, config=None):
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding
+ assert config!=None, "HF have changed the init to only take the config which is required parameter https://github.com/huggingface/transformers/blob/v4.51.0/src/transformers/models/qwen3/modeling_qwen3.py#L316"
+ rotary_emb = Qwen3RotaryEmbedding(config, device=device)
+
+ return rotary_emb
+@functools.cache
+def _get_model(model_id_or_path):
+ from transformers import AutoConfig
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3Model
+ config = AutoConfig.from_pretrained(model_id_or_path)
+ config.num_hidden_layers = 1
+ model = Qwen3Model(config)
+ return model
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/eaglet/base_draft_model.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/eaglet/base_draft_model.py
new file mode 100644
index 000000000..bccd5e771
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/eaglet/base_draft_model.py
@@ -0,0 +1,430 @@
+
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+from abc import ABC
+from typing import List, Optional, Tuple, Union
+import aimet_torch.v2.nn.modules.custom as aimet_ops
+import torch
+from torch import nn
+from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.utils import add_start_docstrings_to_model_forward
+from genai_lib.common.dev.utils import filter_outputs
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+__all__ = ["BaseDraftModel"]
+
+class BaseDraftModel(nn.Module, ABC):
+ """
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`EagletDecoderLayer`]
+ """
+
+ def __init__(self, config=None, decoder_cls=None, norm_cls=None, dual_fc=False):
+ """
+ Args:
+ config: pretrained configuration for the model.
+ decoder_cls: decoder class.
+ decoder layer's forward is expected to return a tuple of
+ (hidden_states, attention_weights, next_cache) where:
+ - `hidden_states` is the output of the layer
+ - `attention_weights` exists if `output_attentions` is True else not in the tuple
+ - `next_cache` exists if `use_cache` is True else not in the tuple
+ norm_cls: normalization layer class for final layer norm. If None, no final layer norm is applied.
+ dual_fc: whether to use dual fc layers to combine input embeddings and hidden states.
+ Raises:
+ TypeError: if config or decoder_cls are not of the correct type.
+ """
+ super().__init__()
+ self.config = config
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.gradient_checkpointing = False
+ self.training = False
+ self.layers = nn.ModuleList(
+ [decoder_cls(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
+ )
+ if dual_fc:
+ self.fce = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+ self.fch = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+ else:
+ self.embedding_concat = aimet_ops.Concat(axis=-1)
+ self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=config.fc_bias)
+ self.down_proj = nn.Linear(config.hidden_size, config.downsample_size, bias=False)
+ self.up_proj = nn.Linear(config.downsample_size, config.hidden_size, bias=False)
+ self.act = ACT2FN[config.hidden_act]
+ self.merged_head = nn.Linear(config.downsample_size, getattr(config, "trimmed_vocabulary_length", config.vocab_size), bias=False)
+ self.logsoftmax = nn.LogSoftmax(dim=-1)
+ if norm_cls is not None:
+ self.norm = norm_cls(config.hidden_size, eps=config.rms_norm_eps)
+ for param in self.embed_tokens.parameters():
+ param.requires_grad = False
+
+ if getattr(config, "input_tokens_per_inference", None) is not None:
+ self.register_buffer(name='cache_tensor', tensor=torch.arange(config.input_tokens_per_inference))
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward('Draft model forward')
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ hidden_states: torch.Tensor = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Optional[int] = None,
+ cache_index: Optional[torch.Tensor]=None,
+ **kwargs,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ logits_to_keep = logits_to_keep if logits_to_keep else getattr(self.config, "logits_to_keep", 0)
+
+ if cache_index is not None:
+ assert hasattr(self, "cache_tensor"), f"{self.__class__.__name__} doesn't have attribute \"cache_tensor\", " \
+ "check if \"input_tokens_per_inference\" is specified in model config"
+ cache_position = cache_index + self.cache_tensor
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ if cache_position is None:
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ cache_position = torch.arange(
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+ )
+
+ # embed positions
+ if hasattr(self, "fce") and hasattr(self, "fch"):
+ hidden_states = self.fce(inputs_embeds) + self.fch(hidden_states)
+ else:
+ concat_embeddings = self.embedding_concat(inputs_embeds, hidden_states)
+ hidden_states = self.fc(concat_embeddings)
+
+ # QC Adaptation: handle pre-computed position embeddings contained in position_ids instead of rotary_emb
+ position_embeddings = (
+ position_ids
+ if isinstance(position_ids, (tuple, list)) and len(position_ids) == 2
+ else None
+ )
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if hasattr(self, "norm"):
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ # hidden_up
+ hidden_states = self.down_proj(hidden_states)
+ hidden_states = self.act(hidden_states)
+
+ # add hidden states from the down proj
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ logits = self.merged_head(hidden_states[:, slice_indices, :])
+
+ if not return_dict:
+ outputs = tuple(v for v in [logits, None, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ if hasattr(self.config, "output_index_filter"):
+ outputs = filter_outputs(outputs, self.config.output_index_filter)
+ return outputs
+
+ return CausalLMOutputWithPast(
+ loss=None,
+ logits=logits,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(
+ self,
+ attention_mask: torch.Tensor,
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ output_attentions: bool,
+ ):
+ raise NotImplementedError(f"{self.__class__.__name__}._update_causal_mask must be implemented before being invoked")
+
+
+class Eaglet2BaseDraftModel(nn.Module, ABC):
+ def __init__(self, config, decoder_cls, norm_cls=None):
+ super().__init__()
+
+ self.config = config
+ self.gradient_checkpointing = False
+ self.training = False
+
+ self.embed_tokens = nn.Embedding(
+ config.vocab_size, config.original_hidden_size, config.pad_token_id
+ )
+
+ # fc layer that takes the feature and down projects to downsample_size
+ self.feature_fc = nn.Linear(config.original_hidden_size, config.downsample_size, bias=False)
+
+ # fc layer to down proj the embeddings to the bottleneck size
+ self.embedding_fc = nn.Linear(
+ config.original_hidden_size, config.downsample_size, bias=False
+ )
+
+ self.layers = nn.ModuleList(
+ [decoder_cls(config, index) for index in range(config.num_hidden_layers)]
+ )
+
+ if norm_cls is not None:
+ self.norm = norm_cls(config.hidden_size, eps=config.rms_norm_eps)
+
+ # take the output of the decoder layer and project it back to the original hidden state size
+ self.up_proj = nn.Linear(config.downsample_size, config.original_hidden_size, bias=False)
+
+ self.merged_head = nn.Linear(
+ config.downsample_size,
+ getattr(config, "trimmed_vocabulary_length", config.vocab_size),
+ bias=False,
+ )
+ for param in self.embed_tokens.parameters():
+ param.requires_grad = False
+
+ if getattr(config, "input_tokens_per_inference", None) is not None:
+ self.register_buffer(name='cache_tensor', tensor=torch.arange(config.input_tokens_per_inference))
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ next_hidden_state: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ hidden_states: torch.Tensor = None,
+ target_hidden_states: torch.Tensor = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Optional[int] = None,
+ cache_index: Optional[torch.Tensor]=None,
+ **kwargs,
+ ):
+ logits_to_keep = (
+ logits_to_keep if logits_to_keep else getattr(self.config, "logits_to_keep", 0)
+ )
+
+ if cache_index is not None:
+ assert hasattr(self, "cache_tensor"), (
+ f'{self.__class__.__name__} doesn\'t have attribute "cache_tensor", '
+ 'check if "input_tokens_per_inference" is specified in model config'
+ )
+ cache_position = cache_index + self.cache_tensor
+
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ # kept for BC (non `Cache` `past_key_values` inputs)
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+ if cache_position is None:
+ past_seen_tokens = (
+ past_key_values.get_seq_length() if past_key_values is not None else 0
+ )
+ cache_position = torch.arange(
+ past_seen_tokens,
+ past_seen_tokens + inputs_embeds.shape[1],
+ device=inputs_embeds.device,
+ )
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ # QC Adaptation: Assign pre-computed position embeddings if available
+ if isinstance(position_ids, (tuple, list)) and len(position_ids) == 2:
+ position_embeddings = position_ids
+ else:
+ raise ValueError("`position_ids` should be a pre-computed position embedding.")
+
+ # QC Adaptation: Assign pre-computed causal mask if available
+ if attention_mask is not None and attention_mask.dim() == 4:
+ causal_mask = attention_mask
+ else:
+ raise ValueError("Please provide a pre-computed 4D causal mask within `attention_mask`.")
+
+ # QC Adaptation: target_hidden_states comes from the target model.
+ # Either hidden_states or target_hidden_states must be zero.
+ assert torch.all(hidden_states == 0) or torch.all(target_hidden_states == 0)
+ target_hidden_states = self.feature_fc(target_hidden_states)
+ hidden_states = hidden_states + target_hidden_states
+
+ hidden_states = self.embedding_fc(inputs_embeds) + hidden_states
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if hasattr(self, "norm"):
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+
+ if return_legacy_cache:
+ next_cache = next_cache.to_legacy_cache()
+
+ slice_indices = (
+ slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+ )
+ logits = self.merged_head(hidden_states[:, slice_indices, :])
+
+ if not return_dict:
+ outputs = tuple(
+ v
+ for v in [None, logits, next_cache, all_hidden_states, all_self_attns]
+ if v is not None
+ )
+ if hasattr(self.config, "output_index_filter"):
+ outputs = filter_outputs(outputs, self.config.output_index_filter)
+ return outputs
+
+ return CausalLMOutputWithPast(
+ loss=None,
+ logits=logits,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/eaglet/inference_utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/eaglet/inference_utils.py
new file mode 100644
index 000000000..7135e4234
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/eaglet/inference_utils.py
@@ -0,0 +1,470 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+from typing import Tuple
+import torch
+
+def _set_tree_mask(model, tree_mask):
+ model.tree_mask = tree_mask
+
+def llm_verify_draft_tree(
+ target_model: torch.nn.Module,
+ draft_tokens: torch.Tensor,
+ tree_mask: torch.Tensor,
+ tree_position_ids: torch.Tensor,
+ retrieve_indices: torch.Tensor,
+ target_past_key_values: Tuple[Tuple[torch.Tensor]],
+ prev_input_len: int):
+ """
+ Verify draft tree using target model
+ Inputs:
+ target_model: target model for verification
+ draft_tokens: draft tokens to verify
+ tree_mask: attention mask for draft tree
+ tree_position_ids: position id for draft tree
+ retrieve_indices: indices of candidate tokens of draft tree paths
+ target_past_key_values: target model KV cache
+ prev_input_len: length of previous input tokens
+ Output:
+ accepted_tokens: accepted tokens verified by target model
+ accepted_hidden_states: accepted hidden states verified by target model
+ accepted_logits: accepted logits verified by target model with shape (#accepted_tokens, #vocab)
+ accepted_token_indices: accepted token indices
+ target_past_key_values: updated target model KV cache
+ """
+
+ _set_tree_mask(target_model, tree_mask)
+
+ tree_position_ids = tree_position_ids[None] + prev_input_len
+
+ tree_logits, target_past_key_values, (hidden_state,) = target_model(
+ draft_tokens,
+ past_key_values=target_past_key_values,
+ position_ids=tree_position_ids,
+ output_hidden_states=True,
+ )
+ if isinstance(hidden_state, (tuple, list)):
+ last_hidden_state = hidden_state[-1]
+ else:
+ last_hidden_state = hidden_state
+
+ tree_logits = tree_logits[0, retrieve_indices]
+ padding = torch.full((1, 1), -1, dtype=torch.long, device=draft_tokens.device)
+ padded_draft_tokens = torch.cat((draft_tokens, padding), dim=1)
+ draft_tokens_2d = padded_draft_tokens[0, retrieve_indices]
+
+ # Find the tokens that match the maximum logits for each position in the sequence
+ posterior_mask = (
+ draft_tokens_2d[:, 1:].to(tree_logits.device) == torch.argmax(tree_logits[:, :-1], dim=-1)
+ ).int()
+
+ # posterior_mask value is 1 when a token in a tree path is accepted or not
+ # so we need to apply cumprod to mask out the latter tokens after the first rejected tokens
+ # e.g., [1, 1, 0, 1, 0, 0] -> [1, 1, 0, 0, 0, 0]
+ candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
+ accept_length = candidates_accept_length.max()
+
+ # Choose the best candidate
+ if accept_length == 0:
+ # Default to the first candidate if none are accepted
+ best_candidate = torch.tensor(0, dtype=torch.long, device=draft_tokens_2d.device)
+ else:
+ best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
+
+ # Postprocessing
+ accepted_token_indices = retrieve_indices[best_candidate, :accept_length + 1]
+ accepted_hidden_states = last_hidden_state[:, accepted_token_indices]
+ accepted_logits = tree_logits[best_candidate, :accept_length + 1]
+ accepted_tokens = draft_tokens_2d[best_candidate, :accept_length + 1]
+
+ _set_tree_mask(target_model, None)
+ return accepted_tokens, accepted_hidden_states, accepted_logits, accepted_token_indices, target_past_key_values,
+
+
+def llm_rerank_and_prune_draft_tree(
+ root_token, draft_tokens, scores_list, parents_list, total_tokens
+):
+ """
+ Rerank and prune the draft tree.
+
+ Input:
+ root_token: root of draft tree. concated before pruned draft_tokens.
+ draft_tokens: list of draft tokens for each depth.
+ scores_list: list of scores for each depth.
+ parents_list: list of parents for each depth.
+ total_tokens: total tokens to be kept after pruning draft tree.
+
+ Output:
+ pruned_tokens: pruned tokens after reranking and pruning.
+ shape=(1, total_tokens+1)
+ [[root_token_id, ...`totla_tokens` draft tokens after pruning]]
+
+ retrieve_indices: list of indices for reconstructing path to leaves.
+ The indices after the leaf are indicated as -1.
+ shape=(#leaf, max_depth)
+ [
+ [0, ...indices of the path to leaf node in the tree, -1, ..., -1],
+ ...size of leaves in the pruned tree
+ ]
+
+ tree_mask: causal mask for pruned_tokens.
+ 1 for attendable, 0 for not attendable. The tokens on the path to current token are attendable.
+ shape=(1, 1, total_tokens+1, total_tokens+1)
+ [
+ [1, 1, 0, 0, ..., 0]
+ [1, 0, 1, 0, ..., 0]
+ ...
+ [1, 0, 0, 1, ..., 1]
+ ]
+
+ tree_position: position ids of each token in the pruned tree.
+ Tokens at same depth have the same position id.
+ shape=(total_tokens+1,)
+ [0, 1, 1, 1, ..., max_depth]
+ """
+
+ #
+ # Referred topK_generate of Eaglet reference code
+ #
+
+ top_k = draft_tokens[0].shape[-1]
+ flat_scores_list = torch.cat(scores_list, dim=0).view(-1)
+ flat_draft_tokens = torch.cat(draft_tokens, dim=0).view(-1)
+
+ top_scores = torch.topk(flat_scores_list, total_tokens, dim=-1)
+ top_scores_index = top_scores.indices
+ top_scores_index = torch.sort(top_scores_index).values
+
+ # Pick `total_tokens` tokens with highest scores
+ pruned_tokens = flat_draft_tokens[top_scores_index]
+ pruned_tokens = pruned_tokens[None]
+ # Prepend `root_token` to pruned draft tree
+ pruned_tokens = torch.cat((root_token, pruned_tokens), dim=1)
+
+ # `draft_parents` is the parent of each token in `pruned_tokens`
+ draft_parents = torch.cat(parents_list, dim=0)[top_scores_index // top_k].long()
+ # `mask_index` is the index of mask for parent
+ mask_index = torch.searchsorted(top_scores_index, draft_parents - 1, right=False)
+
+ # Update the index of `mask_index` considering the prepended root
+ mask_index[draft_parents == 0] = -1
+ mask_index = mask_index + 1
+ mask_index_list = mask_index.tolist()
+
+ tree_mask = torch.eye(total_tokens + 1).bool()
+ # The root token is visible to all draft tokens
+ tree_mask[:, 0] = True
+ # Update `tree_mask` by accumulating the mask of parent
+ for i in range(total_tokens):
+ tree_mask[i + 1].add_(tree_mask[mask_index_list[i]])
+
+ tree_position_ids = torch.sum(tree_mask, dim=1) - 1
+
+ tree_mask = tree_mask.float()[None, None]
+
+ max_depth = torch.max(tree_position_ids) + 1
+ # `mask_index` has the index of parents.
+ noleaf_index = torch.unique(mask_index).tolist()
+ # Decrease 1 due to the prepended root token.
+ noleaf_num = len(noleaf_index) - 1
+ leaf_num = total_tokens - noleaf_num
+
+ retrieve_indices = torch.zeros(leaf_num, max_depth.item(), dtype=torch.long) - 1
+ retrieve_indices = retrieve_indices.tolist()
+
+ rid = 0
+ position_ids_list = tree_position_ids.tolist()
+
+ # Contruct `retrieve_indices` for reconstructing path to leaves
+ for i in range(total_tokens + 1):
+ if i not in noleaf_index:
+ cid = i
+ depth = position_ids_list[i]
+ # Update path by walking from the leaf to the root.
+ for j in reversed(range(depth + 1)):
+ retrieve_indices[rid][j] = cid
+ cid = mask_index_list[cid - 1]
+ rid += 1
+
+ retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
+
+ return pruned_tokens, retrieve_indices, tree_mask, tree_position_ids
+
+
+def llm_create_draft_tree(
+ draft_model,
+ draft_token_logits,
+ hidden_state,
+ past_key_values,
+ depth,
+ top_k,
+ tokens_seen_so_far,
+ forward_kwargs,
+ trimmed_vocab_map=None
+):
+ """
+ Create the draft tree with draft model.
+
+ Input:
+ draft_model: model to generate the draft tokens to be added to the draft tree.
+ draft_token_logits: logits for constructing first depth of draft tree. logits are come from root token.
+ hidden_state: hidden state of root token from draft model.
+ past_key_values: KV cache for draft model
+ depth: depth of draft tree
+ top_k: how many draft tokens are picked for draft tree
+ tokens_seen_so_far: the length added to position ids.
+ forward_kwargs: other arguments for draft model forward function.
+ trimmed_vocab_map: vocabulary map to get original token index from trimmed token index
+
+ Output:
+ draft_tokens: tuple of draft tokens for each depth.
+ [
+ [...`top_k` tokens ids at 1st depth],
+ [...`top_k`**2 tokens ids at 2nd depth],
+ ...,
+ [...`top_k`**2 tokens ids at {`depth` + 1}-th depth]
+ ]
+
+ scores_list: tuple of scores for each depth.
+ [
+ [...scores for `top_k` tokens at 1st depth],
+ [...cumulative scores for `top_k`**2 tokens at 2nd depth],
+ ...,
+ [...cumulative scores for `top_k`**2 tokens at {`depth` + 1}-th depth]
+ ]
+
+ parents_list: tuple of parents for each depth.
+ [
+ [0], # the parent index for tokens at 1st depth which indicates the root token.
+ [...indices for `top_k` tokens at 2nd depth],
+ ...,
+ [...indices for `top_k` tokens at {`depth` + 1}-th depth],
+ ]
+ """
+
+ #
+ # Referred topK_generate of Eaglet reference code
+ #
+ draft_tokens = []
+ scores_list = []
+ parents_list = []
+
+ device = hidden_state.device
+
+ # Pick `top_k` from `draft_token_logits` to construct first depth of draft tree.
+ last_p = torch.log_softmax(draft_token_logits[:, -1], dim=-1)
+ top = torch.topk(last_p, top_k, dim=-1)
+ topk_index, topk_p = top.indices, top.values
+
+ if trimmed_vocab_map is not None:
+ topk_index = trimmed_vocab_map[topk_index]
+
+ scores = topk_p[0]
+ # Store scores of the draft tokens for cummulative scores.
+ scores_list.append(scores[None])
+ # Initialize the parent index of the first depth as the root token.
+ parents_list.append(torch.zeros(1, dtype=torch.long, device=device))
+ # Initialize draft tree with `top_k` tokens in 1st depth.
+ draft_tokens.append(topk_index)
+
+ # Prepare inputs for forward of draft model.
+ input_ids = topk_index
+ input_hidden = hidden_state.repeat(1, top_k, 1)
+ tree_mask_init = torch.eye(top_k, device=device)[None, None]
+ tree_mask = tree_mask_init
+ cum_topk_index = torch.arange(top_k, device=device)
+ position_ids_init = torch.zeros(top_k, dtype=torch.long, device=device)
+ len_posi = tokens_seen_so_far
+
+ for i in range(depth):
+ _set_tree_mask(draft_model, tree_mask)
+ position_ids = len_posi + position_ids_init
+ position_ids = position_ids.expand((input_ids.shape[0], -1))
+
+ (last_headout, past_key_values, out_hidden) = draft_model(
+ hidden_states=input_hidden,
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ position_ids=position_ids,
+ output_hidden_states=True,
+ **forward_kwargs,
+ )
+
+ # Update position offset for next inference
+ len_posi += 1
+
+ # Update `parents_list` considering depth with `bias`
+ bias1 = top_k if i > 0 else 0
+ bias2 = max(0, i - 1)
+ # The flattened draft tree looks like [root_token, ...`top_k` tokens at 1st depth, ...`top_k`**2 tokens at n-th depth, ...]
+ # `bias` indicates the number of tokens in the previous depth:
+ # - `1` for the root token
+ # - `bias1` for `top_k` tokens at 1st depth.
+ # - `bias2` for `top_k`**2 tokens at n-th depth (n > 1).
+ bias = 1 + top_k**2 * bias2 + bias1
+ # `cum_topk_index` indicates the offset of `top_k` tokens in the current depth
+ parents = cum_topk_index + bias
+ parents_list.append(parents)
+
+ # Applied log scale to use addition when accumulating probabilities
+ # and to prevent the cumulative magnitude from becoming too small.
+ last_p = torch.log_softmax(last_headout[0], dim=-1)
+
+ top = torch.topk(last_p, top_k, dim=-1)
+ topk_index, topk_p = top.indices, top.values
+
+ if trimmed_vocab_map is not None:
+ topk_index = trimmed_vocab_map[topk_index]
+
+ # Update cumulative scores of the last `top_k`**2 draft tokens.
+ cum_scores = topk_p + scores[:, None]
+
+ # Pick `top_k` tokens from the last `top_k`**2 draft tokens.
+ cum_topk = torch.topk(cum_scores.view(-1), top_k, dim=-1)
+ cum_topk_index, cum_topk_p = cum_topk.indices, cum_topk.values
+ scores = cum_topk_p
+
+ # Pick hidden state of parents of top_k tokens for next inference.
+ # By dividing by `top_k`, we can get the indices of parent among the input tokens.
+ out_ids = cum_topk_index // top_k
+ input_hidden = out_hidden[-1][:, out_ids]
+
+ input_ids = topk_index.view(-1)[cum_topk_index][None]
+
+ # Update `draft_tokens` with all `top_k` tokens per each draft token.
+ draft_tokens.append(topk_index)
+ scores_list.append(cum_scores)
+ tree_mask = torch.cat((tree_mask[:, :, out_ids], tree_mask_init), dim=3)
+
+ _set_tree_mask(draft_model, None)
+
+ return tuple(draft_tokens), tuple(scores_list), tuple(parents_list)
+
+
+def llm_apply_tree_mask(
+ causal_mask, tree_mask, max_input_tokens, mask_neg=-100.0, pad_to_left=True, cache_index=None
+):
+ """
+ Mask out causal mask wiht tree mask.
+
+ Inputs:
+ causal_mask: causal mask with shape (batch_size, 1, input_length, context_length)
+ tree_mask: mask consisting of 0 and 1 with shape (input_length, tree_size)
+ max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ mask_neg: proxy for minus infinity since minus infinity is not quantization friendly.
+ This value should be large enough to drown out tokens that should not be attended to
+ pad_to_left: boolean value indicating whether padding is done towards the left or right.
+ cache_index: cache_index determines where the attn_mask_input was placed.
+ If None, the input_attention mask was placed towards the end.
+
+ Outputs:
+ causal_mask: masked out `causal_mask` with `tree_mask`
+
+ Example: max_input_tokens = 4
+ Left padding case)
+ cache_index = 6
+ causak_mask: shape=(1,1,4,10) like [2 paddings, 4 valid kv$, 1 padding, 3 valid inputs]
+ [[[
+ [mask_neg, mask_neg, 0, 0, 0, 0, mask_neg, mask_neg, mask_neg, mask_neg],
+ [mask_neg, mask_neg, 0, 0, 0, 0, mask_neg, 0, mask_neg, mask_neg],
+ [mask_neg, mask_neg, 0, 0, 0, 0, mask_neg, 0, 0, mask_neg],
+ [mask_neg, mask_neg, 0, 0, 0, 0, mask_neg, 0, 0, 0],
+ ]]]
+
+ tree_mask: shape=(1,1,3,6)
+ [[[
+ [1, 0, 0, 1, 0, 0],
+ [1, 0, 1, 0, 1, 0],
+ [0, 1, 0, 0, 0, 1],
+ ]]]
+
+ causal_mask after applying tree_mask: *mask_neg*s are masked out by tree_mask
+ [[[
+ [mask_neg, mask_neg, 0, 0, 0, 0, mask_neg, mask_neg, mask_neg, mask_neg],
+ [mask_neg, mask_neg, 0, 0, *mask_neg*, *mask_neg*, mask_neg, 0, mask_neg, mask_neg],
+ [mask_neg, mask_neg, 0, 0, *mask_neg*, 0, mask_neg, *mask_neg*, 0, mask_neg],
+ [mask_neg, mask_neg, 0, *mask_neg*, 0, *mask_neg*, mask_neg, *mask_neg*, *mask_neg*, 0],
+ ]]]
+
+ Right padding case)
+ cache_index = 4
+ causak_mask: shape=(1,1,4,10) like [4 valid kv$, 3 valid inputs, 3 paddings]
+ [[[
+ [0, 0, 0, 0, 0, mask_neg, mask_neg, mask_neg, mask_neg, mask_neg],
+ [0, 0, 0, 0, 0, 0, mask_neg, mask_neg, mask_neg, mask_neg],
+ [0, 0, 0, 0, 0, 0, 0, mask_neg, mask_neg, mask_neg],
+ [0, 0, 0, 0, 0, 0, 0, mask_neg, mask_neg, mask_neg],
+ ]]]
+
+ tree_mask: shape=(1,1,3,6)
+ [[[
+ [1, 0, 0, 1, 0, 0],
+ [1, 0, 1, 0, 1, 0],
+ [0, 1, 0, 0, 0, 1],
+ ]]]
+
+ causal_mask after applying tree_mask: *mask_neg*s are masked out by tree_mask
+ [[[
+ [0, 0, *mask_neg*, *mask_neg*, 0, mask_neg, mask_neg, mask_neg, mask_neg, mask_neg],
+ [0, 0, *mask_neg*, 0, *mask_neg*, 0, mask_neg, mask_neg, mask_neg, mask_neg],
+ [0, *mask_neg*, 0, *mask_neg*, *mask_neg*, *mask_neg*, 0, mask_neg, mask_neg, mask_neg],
+ [0, 0, 0, 0, 0, 0, 0, mask_neg, mask_neg, mask_neg],
+ ]]]
+
+ """
+ _, _, input_length, tree_mask_target_len = tree_mask.shape
+ remained_target_length = tree_mask_target_len - input_length
+
+ cache_index = causal_mask.shape[-1] - max_input_tokens if cache_index is None else cache_index
+ if pad_to_left:
+ causal_mask[:, :, -input_length:, cache_index+max_input_tokens-input_length:cache_index+max_input_tokens][
+ tree_mask[:, :, :, -input_length:] == 0
+ ] = mask_neg
+ causal_mask[:, :, -input_length:, cache_index - remained_target_length : cache_index][
+ tree_mask[:, :, :, :-input_length] == 0
+ ] = mask_neg
+ else:
+ causal_mask[:, :, :input_length, cache_index-remained_target_length:cache_index+input_length][
+ tree_mask == 0
+ ] = mask_neg
+
+ return causal_mask
+
+
+def llm_evaluate_block_efficiency(output_ids, accept_lengths, stop_token_ids=None):
+ """
+ Evaluate block eficiency about output before the stope sequence.
+
+ Inputs:
+ output_ids: the generated ids for evaluation Block Efficiency.
+ accept_lengths: list of lengths of accepted tokens for each iteration.
+ stop_sequences: stop sequences used to stop generation.
+
+ Outputs:
+ block_efficiency: average accepted length for each iteration.
+ generation_length: length of the output excluding stop sequences.
+ output_ids: output_ids before the stop sequence.
+ """
+ iter_count = len(accept_lengths)
+ if stop_token_ids is not None:
+ if not isinstance(stop_token_ids, (tuple, list)):
+ stop_token_ids = [stop_token_ids]
+ stop_indices = [i for i, id in enumerate(output_ids) if id in stop_token_ids]
+ if len(stop_indices) > 0:
+ stop_index = stop_indices[0]
+ if accept_lengths[-1] == 1:
+ # Ignore last iteration if the last accepted token is solely stop token.
+ iter_count = iter_count - 1
+
+ output_ids = output_ids[:stop_index]
+
+ block_efficiency = len(output_ids) / (iter_count)
+ generation_length = len(output_ids)
+
+ return block_efficiency, generation_length, output_ids
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/eaglet/modeling_eaglet.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/eaglet/modeling_eaglet.py
new file mode 100644
index 000000000..7167b16d4
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/eaglet/modeling_eaglet.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+""" Eaglet speculative decoding algorithm https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py """
+from typing import Optional, Tuple
+import torch
+from torch import nn
+from transformers.cache_utils import Cache
+from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm
+from genai_lib.llm.dev.model_adaptation.llama.adaptation import QcLlamaAttention
+from genai_lib.common.dev.utils import is_transformers_greater_or_equal_than_4_48
+from transformers.models.llama.configuration_llama import LlamaConfig
+from .base_draft_model import BaseDraftModel, Eaglet2BaseDraftModel
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+__all__ = ["EagletDraftModel", "EagletDecoderLayer", "Eaglet2DraftModel"]
+
+
+class EagletDraftModel(BaseDraftModel):
+ def __init__(self, config):
+ super().__init__(config=config, decoder_cls=EagletDecoderLayer)
+
+
+class EagletDecoderLayer(nn.Module):
+ def __init__(self, config: LlamaConfig, layer_idx: int):
+ super().__init__()
+ self.self_attn = QcLlamaAttention(config=config, layer_idx=layer_idx)
+ self.mlp = LlamaMLP(config)
+ if layer_idx > 0:
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ residual = hidden_states
+
+ if hasattr(self, "input_layernorm"):
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ self_attn_outputs = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ if is_transformers_greater_or_equal_than_4_48:
+ hidden_states, self_attn_weights = self_attn_outputs
+ # past_key_value is an instance of Cache having new KV cache after the attention
+ present_key_value = past_key_value
+ else:
+ hidden_states, self_attn_weights, present_key_value = self_attn_outputs
+
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+ if output_attentions:
+ outputs += (self_attn_weights,)
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class Eaglet2DraftModel(Eaglet2BaseDraftModel):
+ def __init__(self, config):
+ super().__init__(config=config, decoder_cls=EagletDecoderLayer)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/evaluation_utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/evaluation_utils.py
new file mode 100644
index 000000000..8213ca72e
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/evaluation_utils.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+""" This file provides evaluation utilities for LLM Lib """
+
+import contextlib
+from torch.nn import CrossEntropyLoss
+import torch
+from tqdm import tqdm
+from torch.utils._pytree import tree_map
+
+
+def change_tensor_device_placement(input_data, device: torch.device):
+ """
+ Change the tensor_data's device placement
+
+ :param input_data: torch.tensor , list of torch.tensors, or tuple of torch.tensors
+ :param device: device
+ :return: tensor_data with modified device placement
+
+ Duplicated code with AIMET_Torch to remove dependency
+ """
+ return tree_map(
+ lambda x: x.to(device) if isinstance(x, torch.Tensor) else x, input_data
+ )
+
+@contextlib.contextmanager
+def _place_model_in_eval_mode(model):
+ '''Temporarily switch to evaluation mode.'''
+ istrain = model.training
+ try:
+ model.eval()
+ yield model
+ finally:
+ if istrain:
+ model.train()
+def llm_compute_loss_from_logits(outputs, labels):
+ '''
+ This function computes the loss from the logits and the labels passed
+ '''
+ #Get the outputs and move it to CPU. Assumes that index 0 is logits as
+ lm_logits = outputs[0].cpu()
+ shift_logits = lm_logits[..., :-1, :].contiguous().to(dtype=torch.float32)
+ shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
+
+ #Compute the loss
+ loss_fn = CrossEntropyLoss()
+ neg_log_likelihood = loss_fn(
+ shift_logits.view(-1, shift_logits.size(-1)),
+ shift_labels.view(-1),
+ )
+ return neg_log_likelihood
+
+@torch.no_grad()
+def llm_evaluate_ppl(model, encoded_dataset, max_length, stride=2048) -> float:
+ '''
+ This function takes in an encoded dataset string and computes ppl
+ params:
+ model: the model to evaluate
+ encoded_dataset: the encoded dataset
+ max_length: represents the max length of inputs model can take in
+ stride: the stride used for sliding window over the dataset
+ '''
+ nlls = []
+ prev_end_loc = 0
+ seq_len = encoded_dataset.input_ids.size(1)
+ device = model.device
+ with _place_model_in_eval_mode(model):
+
+ for begin_loc in tqdm(range(0, seq_len, stride)):
+ end_loc = min(begin_loc + max_length, seq_len)
+ trg_len = end_loc - prev_end_loc # may be different from stride on last loop
+ input_ids = encoded_dataset.input_ids[:, begin_loc:end_loc].to(device)
+ labels = input_ids.clone()
+ labels[:, :-trg_len] = -100
+
+ #Labels are not passed into the model. Loss computation happens outside the mode
+ outputs = model(input_ids)
+
+ nlls.append(llm_compute_loss_from_logits(outputs,labels))
+ del outputs
+ #Set up variables for the next batch of data
+ prev_end_loc = end_loc
+ if end_loc == seq_len:
+ break
+ ppl = torch.exp(torch.stack(nlls).mean())
+ return float(ppl)
+
+@torch.no_grad()
+def llm_evaluate_ppl_with_dataloader(model, dataloader, num_batches=None, model_forward_kwargs={}):
+ '''
+ This function takes in a dada loader and a model and computes ppl score
+ params:
+ model: the model to evaluate
+ dataloader: dataset loader
+ num_batches: number of batches to run evaluation on
+ '''
+ num_batches = num_batches if num_batches else len(dataloader)
+ nlls=[]
+ device=model.device
+ model_forward_kwargs = change_tensor_device_placement(model_forward_kwargs, device)
+
+ for batch_id, batch in enumerate(tqdm(dataloader, total=num_batches, desc="Evaluating")):
+ if batch_id >= num_batches:
+ break
+ if "inputs_embeds" in batch:
+ batch["input_ids"] = batch["labels"]
+ batch["inputs_embeds"] = batch["inputs_embeds"].to(device)
+ outputs = model(inputs_embeds=batch["inputs_embeds"], **model_forward_kwargs)
+ else:
+ batch["input_ids"] = batch["input_ids"].to(device)
+ outputs = model(input_ids=batch["input_ids"], **model_forward_kwargs)
+
+ nlls.append(llm_compute_loss_from_logits(outputs, batch["input_ids"]))
+ del outputs
+ ppl = torch.exp(torch.stack(nlls).mean())
+ return float(ppl)
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/long_context_utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/long_context_utils.py
new file mode 100644
index 000000000..d6d88ab3f
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/long_context_utils.py
@@ -0,0 +1,530 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+"""
+Utility APIs for supporting Key Similarity Long Context and Sliding Window
+"""
+# as per design the most updated design considering pytorch & multihead all bmm.
+import os
+import pickle
+import torch
+import torch.nn as nn
+import aimet_torch.nn.modules.custom as modules
+from genai_lib.llm.test_vectors import to_torch_tensor,to_cpu
+from typing import Tuple, List, Optional
+
+MatMul = modules.MatMul
+class AnchorUpdaterKeySecond(nn.Module):
+ def __init__(self, alpha: float = 0.001):
+ super().__init__()
+ self.alpha = alpha
+ self.one_minus_alpha = 1 - alpha
+ self.mul = MatMul()
+
+ def forward(
+ self,
+ new_keys: torch.Tensor,
+ valid_token_mask: torch.Tensor,
+ old_anchor: torch.Tensor,
+ ):
+ """
+ inputs:
+ new_keys: [bsz, heads, ar_n, head_dim]
+ valid_token_mask: [bsz, heads, 1, ar_n]
+ old_anchor: [bsz, heads, 1, head_dim]
+ outputs:
+ new_anchor: [bsz, heads, 1, head_dim]
+ """
+ new_anchor = self.mul(
+ valid_token_mask, # [bsz, heads, 1, ar_n]
+ new_keys, # [bsz, heads, ar_n, head_dim]
+ ) # [bsz, heads, 1, head_dim]
+ new_anchor = new_anchor + self.one_minus_alpha * old_anchor
+ return new_anchor
+
+
+class ScorerKeySecond(nn.Module):
+
+ def __init__(self, num_keys):
+ super().__init__()
+ self.mul = nn.ModuleList([MatMul() for _ in range(num_keys)])
+
+ def forward(self, keys: tuple, anchor_buffer: tuple):
+ """
+ inputs:
+ keys: tuple of length config.num_hidden_layers
+ where each item is of shape [bsz, heads, head_dim, context_len]
+
+ anchor: tuple of length config.num_hidden_layers
+ where each item is of shape [bsz, heads, 1, head_dim]
+ outputs:
+ score: tuple of length config.num_hidden_layers
+ where each item is of shape [bsz, heads, 1, contex_len]
+ """
+ score = ()
+ for i in range(len(keys)):
+ # anchor[i] shape [bsz, heads, 1, head_dim]
+ # keys[i] shape [bsz, heads, head_dim, context_len]
+ # curr_score shape [bsz, heads, 1, context_len]
+ curr_score = self.mul[i](anchor_buffer[i], keys[i])
+ score += (curr_score,)
+
+ return score
+
+
+def get_scorer_input_output_names(num_hidden_layers):
+
+ """
+ This function is responsible for returning a list of the model input and output names based on the number of hidden layers for scoring network
+ num_hidden_layers: number of hidden layers of the languge model
+
+ :params
+ 1. num_hidden_layers : num_hidden_layers in language model config
+
+ """
+ def _get_names(pfx, sfx, n_layers):
+ all = []
+ for i in range(n_layers):
+ all.append(f'{pfx}_{i}_{sfx}')
+ return all
+
+ input_names=[]
+ input_names += _get_names("keys", "in", num_hidden_layers)
+ input_names += _get_names("anchor_buffer", "in", num_hidden_layers)
+
+ output_names=[]
+ output_names += _get_names("score", "out", num_hidden_layers)
+ return input_names, output_names
+
+
+
+
+def generate_scorer_test_vectors(fp_language_model, fp_scoring_model, quantsim, num_hidden_layers, input_ids, batch_index, output_dir):
+ """
+ This function is responsible for generating test vectors for fp,qt model in output_dir
+ :params
+ 1. fp_language_model: fp adapted language model
+ 2. fp_scoring_model: fp scoring network model
+ 3. scorer quantsim model : quantsim of scoring network model
+ 4. num_hidden_layers: num_hidden_layers of language model
+ 5. input_ids: input_ids inputs for language model
+ 6. batch_index: batch index of dataloader
+ 7. output_dir: output directory for test vector generation
+ """
+ test_vectors_dir = os.path.join(output_dir,'test_vectors')
+ os.makedirs(test_vectors_dir, exist_ok=True)
+ fp_vectors_pickle_dict = {}
+ qt_vectors_pickle_dict = {}
+ language_model_outputs = fp_language_model(input_ids=input_ids)
+ model_input = {}
+ model_input["keys"] = ()
+ for n_layer in range(num_hidden_layers):
+ keys = language_model_outputs[1][n_layer][0]
+ model_input["keys"] += (keys,)
+
+ model_input["anchor_buffer"] = fp_language_model.anchor_buffer
+ fp_scoring_model_out = fp_scoring_model(model_input["keys"], model_input["anchor_buffer"])
+ fp_vectors_pickle_dict[str(batch_index)]= model_input
+ fp_vectors_pickle_dict[str(batch_index)]["score"] = fp_scoring_model_out
+
+ qsim_scoring_model_out = quantsim.model(model_input["keys"], model_input["anchor_buffer"])
+ qt_vectors_pickle_dict[str(batch_index)] = model_input
+ qt_vectors_pickle_dict[str(batch_index)]["score"] = qsim_scoring_model_out
+ with open(os.path.join(test_vectors_dir, f"fp_{batch_index}.pkl"), "wb") as fp_pickle_file:
+ pickle.dump(to_cpu(to_torch_tensor(fp_vectors_pickle_dict)), fp_pickle_file)
+ with open(os.path.join(test_vectors_dir, f"qt_{batch_index}.pkl"), "wb") as qt_pickle_file:
+ pickle.dump(to_cpu(to_torch_tensor(qt_vectors_pickle_dict)), qt_pickle_file)
+
+
+def llm_compute_scores(scorer, past_key_values, anchor, valid_kv_len=None, pad_to_left=True):
+ """
+ The function calculates the scores, where a higher score indicates greater similarity between the anchor vector and the given key
+ It calls the scorer for each layer and returns a tuple of scores in the format (1, n_heads, 1, ctx_len)*n_layers (assuming keys are second input to matmul).
+
+ We consider two scenarios for the past_key_values:
+
+ 1. Static Shape Input in the Scorer: The past_key_values contain padded key-value pairs, assuming the scorer requires a static shape input.
+ The valid_kv_length is used to identify true non-padded key-value pairs. To ensure padding values are always evicted, we update the scores by
+ maximizing/ inflating the scores corresponding to padding.
+
+ 2. Dynamic Shape Input in the Scorer: The scorer can accept dynamic shape input, in which case the past_key_values are not padded.
+
+ :params
+ 1. scorer: a callable which computes scores given the keys and the anchor vector
+ 2. past_key_values: past_key_values for all the layers
+ 3. anchor: the anchor vector tuple representing the anchor vector for each layer, this is essentially the exponential moving average of the keys seen so far
+ 4. valid_kv_len: valid kv length represents the amount of valid kv in the past_key_values in case of padding
+ 5. pad_to_left: boolean value indicating whether padding is done towards the left or right
+
+ """
+ updated_scores = ()
+ all_keys = tuple(k for k,_ in past_key_values)
+ scores = scorer(all_keys, anchor)
+
+ for score in scores:
+ if valid_kv_len is not None:
+ max_values, _ = torch.max(score, dim=3, keepdim=True)
+
+ if pad_to_left:
+ score[:, :, : ,:-valid_kv_len] = max_values
+ else:
+ score[:, :, :, valid_kv_len:] = max_values
+
+ updated_scores += (score,)
+
+ # the assertion ensures that there is parity between the shape of scores and past_kv shape along the sequence dimension.
+ # dim represents the dimension along the sequence_length
+ assert updated_scores[0].shape[-1] == past_key_values[0][1].shape[2]
+ return updated_scores
+
+
+def llm_compute_indices_to_keep(scores, model_context_len, max_input_tokens,
+ new_KV_per_inference, eviction_frequency = 1, num_kv_to_evict=None):
+ '''
+ The function is responsible for computing the indices to either keep or evict based on the scores.
+ If the user does not specify how many key-value pairs (kv), we determine it dynamically using the maximum allowed kv cache budget,
+ the eviction frequency, and the eviction amount at each frequency (which varies between the prefill and inference stages).
+
+ params:
+
+ 1. scores: Tensor representing the similarity of the keys with the anchor.
+ 2. model_context_len : the maximum context length that can be sent to the model (this is the HF maximum length of the context)
+ 3. max_input_tokens: the maximum tokens that can be sent to the model, in our context represents the AR length
+ 4. eviction_frequency: The eviction frequency refers to how often (after how many inferences) this API is called.
+ If the frequency is 1, it implies that this API is called after every inference.
+ If the frequency is more than 1, it implies that eviction will be called once every "frequency" inferences
+ in which case we should ensure that we create enough room to accommodate new KV until the next opportunity
+ to evict.
+ 5. new_KV_per_inference: This is the number of tokens whose KV we expect will get added at each inference. Ideally,
+ for the prefill stage, we set it to ARN, and 1 for inference (for SSD, this is the forecast tokens).
+ 6. num_kv_to_evict: Optionally, the user can specify the exact amount of kv to evict. If this is provided, we do
+ not compute the num_kv_to_evict. If it is not provided we assume that the new_KV that was output by the language
+ model is not in the candidate list for eviction.
+ '''
+ #score_per_layer = (bsz, num_kv_heads, 1, seq_len)
+ current_len = scores[0].shape[-1] # 4051 if dynamic tensor/ 4096 if static
+
+ if num_kv_to_evict is None:
+ # Reduce the budget by the sum of max_input_tokens and new_KV_per_inference.
+ # This is because we have new KV from the latest inference that is pending addition.
+ # In addition we need to create room for the next max_input_tokens that will be sent in during next inf.
+ budget = model_context_len - (max_input_tokens + new_KV_per_inference)
+
+ num_kv_to_evict = current_len - budget
+ num_kv_to_evict += (eviction_frequency - 1) * new_KV_per_inference
+
+ keep_indices_tuple = ()
+ num_kv_to_keep = current_len - num_kv_to_evict
+ assert num_kv_to_keep > 0, "The num_kv_to_evict should not exceed the current kv length"
+ for score in scores:
+ # in order to run topk, we flip the scores by multiplying with a -1 so we get the actual lowest scores
+ keep_indices = torch.topk(-1 * score, k=num_kv_to_keep, dim=-1)[1]
+ # needed to move the indices in the third dimension (aligned with shape of value cache) where seq_len is in third dimension
+ keep_indices = keep_indices.transpose(-1, -2)
+ keep_indices_tuple += (keep_indices,)
+ return keep_indices_tuple
+
+
+def reindexer_default(keep_indices, past_key_values):
+ reindexer = Reindexer(keep_indices, past_key_values[1].shape[-1])
+
+ updated_key = reindexer.reindex_tensor(past_key_values[0].transpose(-1, -2)).transpose(-1, -2)
+ updated_value = reindexer.reindex_tensor(past_key_values[1])
+
+ return (updated_key, updated_value)
+
+
+def llm_compress_kv_cache(past_key_values, keep_indices, compression_algorithm=reindexer_default):
+ '''
+ The function is responsible for compressing the kv cache whenever the past_kv exceeds the budget.
+ It uses the gather op internally.
+
+ :params
+ 1. past_key_values: past_key_values to be compressed
+ 2. keep_indices: tensor representing the keep indices of shape [bsz, num_heads, indices, 1]
+ 3. compression_algorithm: the caller takes in the keep_indices and the kv for a given layer and returns the compressed kv.
+
+ '''
+ max_kv_idx = past_key_values[0][1].shape[2]-1
+
+ # following check ensures that the keep_indices don't exceed the kv length.
+ for layer_idx in range(len(past_key_values)):
+ # (bsz, num_heads, keep_indices, 1) -> (num_heads)
+ max_indices_per_head = torch.max(keep_indices[layer_idx], dim =2)[0].squeeze(dim=0).squeeze(dim=-1)
+ assert torch.all(max_indices_per_head <= max_kv_idx), "keep_indices should fall within the input past_key_values range"
+
+ updated_past_key_values = ()
+ for layer_idx in range(len(past_key_values)):
+ updated_key_value = compression_algorithm(keep_indices[layer_idx], past_key_values[layer_idx], )
+ updated_past_key_values += (updated_key_value,)
+
+ return updated_past_key_values
+
+class Reindexer:
+ """
+ Reindexer takes in the tensor and the keep_indices and performs gather to retain the values corresponding to keep_indices.
+ """
+ def __init__(self, keep_indices: torch.Tensor, head_dim: int):
+ # keep_idx.shape = [bsz, heads, window, 1]
+ self.keep_idx = keep_indices.sort(-2).values
+
+ # idx.shape = [bsz, heads, window, dim]
+ self.idx = self.keep_idx.expand(-1, -1, -1, head_dim)
+
+ def reindex_tensor(self, states_tensor: torch.Tensor, sequence_dim=-2) -> torch.Tensor:
+ return torch.gather(states_tensor, sequence_dim, self.idx)
+
+def llm_update_overwriting_cache(model, scores, num_extra_kvs):
+ '''
+ The API is responsible for adding the necessary indices to evict from the scores into the overwriting_index_cache. We first determine if we need to add indices to the overwriting_index_cache.
+
+ 1. If this is the initial instance of populating the cache, we add the indices to the overwriting_index_cache.
+
+ 2. If the remaining indices in the cache are fewer than num_extra_kvs, we retrieve the indices from the scores. Note that these scores do not take in the the new Key Cache in their computation.
+
+ We iterate over the layers and fill the overwriting_index_cache.
+
+ params:
+ model: The object that holds the overwriting index cache and the configuration head dimension.
+ scores: A tuple of scores for each layer, with the shape [batch_size, num_heads, 1, window].
+ num_extra_kvs: This value indicates the extent to which the concatenation of accumulated and new key-value pairs exceeds the budget. It represents the number of indices required from the overwriting_index_cache.
+ '''
+
+ if not hasattr(model, 'overwriting_index_cache'):
+ raise AttributeError("The passed model does not support lazy eviction")
+ overwriting_index_cache = model.overwriting_index_cache.get(0, None)
+
+ # the flag to check whether we need to write the indices into the overwriting_index_cache
+ overwrite_cache = False
+ if overwriting_index_cache is None:
+ overwrite_cache = True
+ else:
+ if overwriting_index_cache.shape[-2] < num_extra_kvs:
+ overwrite_cache = True
+ if overwrite_cache:
+ for layer_idx in range(len(scores)):
+ head_dim = model.config.head_dim
+
+ # this was taken from the systems implementation, I believe they want to rule out any scenarios where the model.overwriting_index_len (perhaps set incorrectly by the user) is less than num_extra_kvs
+ k = max(model.overwriting_index_len, num_extra_kvs)
+
+ #(bsz, n_heads, 1, past_kv_len) -> (bsz, n_heads, 1, k) -> (bsz, n_head, k, 1)
+ top_k_idx = torch.topk(scores[layer_idx], k, dim=-1)[1].transpose(-1, -2)
+
+ # Expand indices in the last dim to match head_dim
+ #(bsz, n_head, seq_len, 1) -> (bsz, n_head, seq_len, head_dim) similar to value cache shape
+ model.overwriting_index_cache[layer_idx] = top_k_idx.expand(
+ -1, -1, -1, head_dim
+ )
+
+def llm_scatter_exceeded_kv_using_lazy_eviction(model, past_key_values, num_extra_kvs, key_concat_axis,
+ value_concat_axis,):
+ '''
+ This API is responsible for scattering the exceeded KV cache into the indices of the old KV that will undergo eviction.
+ params:
+ model: this object contains the overwriting_index_cache
+ past_key_values: the accumulated old and new KV cache
+ num_extra_kvs: This value indicates the extent to which the concatenation of accumulated and new key-value pairs exceeds the budget. It represents the number of indices required from the overwriting_index_cache.
+ key_concat_axis: the axis to which we want to append the keys
+ value_concat_axis: the axis to which we want to append the values
+
+ '''
+ if not hasattr(model, 'overwriting_index_cache'):
+ raise AttributeError("The passed model does not support lazy eviction")
+
+ # container to store the updated KV cache
+ updated_past_key_values = ()
+ total_num_kvs = past_key_values[0][1].shape[2]
+ for layer_idx in range(len(past_key_values)):
+ # extract the portion of KV which equals the budget size
+ cached_key = past_key_values[layer_idx][0][..., :total_num_kvs-num_extra_kvs]
+ cached_value =past_key_values[layer_idx][1][..., :total_num_kvs-num_extra_kvs, :]
+
+ # extract the exceeded_KV portion
+ exceed_key = past_key_values[layer_idx][0][..., total_num_kvs-num_extra_kvs:]
+ exceed_value =past_key_values[layer_idx][1][..., total_num_kvs-num_extra_kvs:, :]
+
+ # Extract the top / front most num_extra_kvs indices from the overwriting_index_cache
+ #idx (bsz, n_head, window, head_dim) -> (bsz, n_head, num_extra_kvs, head_dim)
+ scatter_idx = model.overwriting_index_cache[layer_idx][..., :num_extra_kvs, :]
+
+
+ # Scattering new exceeded values into old value cache at scatter index positions
+ # the shape of scatter indices align with the value cache dimesnion, seq_len in the dim=-2
+ updated_value = torch.scatter(cached_value, value_concat_axis, scatter_idx, exceed_value)
+
+ # Scattering new exceeded keys into old key cache at scatter index positions
+ # transpose scatter_idx if keys are transposed
+ updated_key = torch.scatter(cached_key, key_concat_axis, scatter_idx.transpose(-1, -2) if key_concat_axis==3 else scatter_idx, exceed_key)
+
+
+ # update the overwriting_index_cache to remove the top num_extra_kvs values
+ model.overwriting_index_cache[layer_idx] = model.overwriting_index_cache[layer_idx][
+ ..., num_extra_kvs:, :
+ ]
+
+ updated_past_key_values += ((updated_key, updated_value),)
+
+ return updated_past_key_values
+
+def replenish_rotating_index_cache(
+ cache_length: int,
+ num_kv_heads: int,
+ head_dim: int,
+ bsz: int = 1,
+ rotating_eviction_cache: Optional[torch.Tensor] = None,
+ num_sinked_kvs: int = 0,
+) -> torch.Tensor:
+ """
+ Infer missing indices and replenish them at the end.
+ params:
+ cache_length: this represents the size of the queue in window dimension, for sliding window this is equal to sliding_window_context_len - ARN
+ num_kv_heads: num_kv_heads in the model config
+ head_dim: head_dim in the model_config.
+ bsz: batch size, default is 1
+ rotating_eviction_cache: the queue of shape [bsz, num_kv_heads, window, head_dim]
+ num_sinked_kvs: number of sinked kvs, default is 0
+
+ returns:
+ torch.Tensor: The replenished rotating eviction cache.
+ - If num_sinked_kvs > 0, the returned cache shape is [bsz, num_kv_heads, cache_length - num_sinked_kvs, head_dim].
+ - Otherwise, the shape is [bsz, num_kv_heads, cache_length, head_dim].
+ """
+
+ # Full range of possible indices
+ full_indices = torch.arange(cache_length)
+
+ # Find missing (evicted) indices
+ mask = torch.ones(cache_length, dtype=torch.bool)
+
+ if rotating_eviction_cache is not None:
+ # Filter out indices that are outside the current cache length
+ rotating_eviction_cache = rotating_eviction_cache[:, :, rotating_eviction_cache[0, 0, :, 0] < cache_length, :]
+ mask[rotating_eviction_cache[0,0,:,0]] = False
+ if num_sinked_kvs > 0:
+ # If there are sinked kvs, we need to remove them from the mask
+ # This makes the rotating index cache shorter by num_sinked_kvs
+ mask[:num_sinked_kvs] = False
+ evicted = full_indices[mask]
+ evicted = evicted[None, None, :, None]
+ evicted = evicted.expand(bsz, num_kv_heads, -1, head_dim)
+
+ # Append them back
+ if rotating_eviction_cache is not None:
+ rotating_eviction_cache = torch.cat([rotating_eviction_cache, evicted], dim=2)
+ else:
+ rotating_eviction_cache = evicted
+ return rotating_eviction_cache
+
+
+def llm_scatter_exceeded_kv_using_rotating_eviction(
+ rotating_eviction_cache: torch.Tensor,
+ past_key_values: Tuple[Tuple[torch.Tensor]],
+ num_extra_kvs: int,
+ key_concat_axis: int,
+ value_concat_axis: int,
+ layer_indices_to_perform_eviction: Optional[List[int]] = None,
+) -> Tuple[Tuple[Tuple[torch.Tensor]], torch.Tensor]:
+ """
+ This API is responsible for scattering the exceeded KV cache into the indices of the old KV that will undergo eviction.
+ params:
+ rotating_eviction_cache: the queue of shape [bsz, num_kv_heads, window, head_dim]
+ past_key_values: the accumulated old and new KV cache
+ num_extra_kvs: This value indicates the extent to which the concatenation of accumulated and new key-value pairs exceeds the budget. It represents the number of indices required from the overwriting_index_cache.
+ key_concat_axis: the axis to which we want to append the keys
+ value_concat_axis: the axis to which we want to append the values
+ layer_idx_to_perform_eviction: a list indicating which layers should undergo eviction, the ones not in the list are kept as is in the returned KV$. If None, all layers are assumed to undergo eviction.
+
+ """
+ if layer_indices_to_perform_eviction is None:
+ # If no layers are specified, we assume all layers should undergo eviction
+ layer_indices_to_perform_eviction = list(range(len(past_key_values)))
+
+ assert len(layer_indices_to_perform_eviction) > 0, "At least one layer should be specified for eviction"
+ assert num_extra_kvs <= rotating_eviction_cache.shape[-2], "The number of extra KVs exceeds the size of the rotating eviction cache"
+
+ # container to store the updated KV cache
+ updated_past_key_values = []
+ total_num_kvs = past_key_values[layer_indices_to_perform_eviction[0]][1].shape[2]
+ assert total_num_kvs >= num_extra_kvs, "The total number of KVs should be greater than or equal to the number of extra KVs"
+
+ # Extract the top / front most num_extra_kvs indices from the rotating_eviction_cache
+ # idx (bsz, n_head, window, head_dim) -> (bsz, n_head, num_extra_kvs, head_dim)
+ scatter_idx = rotating_eviction_cache[..., :num_extra_kvs, :]
+ scatter_idx = scatter_idx.to(past_key_values[0][0].device)
+ for layer_idx, (key_cache, value_cache) in enumerate(past_key_values):
+ if layer_idx in layer_indices_to_perform_eviction:
+ # extract the portion of KV which equals the budget size
+ cached_key = key_cache[..., :total_num_kvs - num_extra_kvs]
+ cached_value = value_cache[..., :total_num_kvs - num_extra_kvs, :]
+
+ # extract the exceeded_KV portion
+ exceed_key = key_cache[..., total_num_kvs - num_extra_kvs:]
+ exceed_value = value_cache[..., total_num_kvs - num_extra_kvs:, :]
+
+ # Scattering new exceeded values into old value cache at scatter index positions
+ # the shape of scatter indices align with the value cache dimension, seq_len in the dim=-2
+ updated_value = torch.scatter(cached_value, value_concat_axis, scatter_idx, exceed_value)
+
+ # Scattering new exceeded keys into old key cache at scatter index positions
+ # transpose scatter_idx if keys are transposed
+ updated_key = torch.scatter(cached_key, key_concat_axis,
+ scatter_idx.transpose(-1, -2) if key_concat_axis == 3 else scatter_idx,
+ exceed_key)
+
+ else:
+ updated_key = key_cache
+ updated_value = value_cache
+
+ updated_past_key_values += ((updated_key, updated_value),)
+
+ # update the rotating_eviction_cache to remove the top num_extra_kvs values
+
+ rotating_eviction_cache = rotating_eviction_cache[..., num_extra_kvs:, :]
+
+ # Add the scatter_idx back into the rotating_eviction_cache in the end to preserve the temporal ordering of indices.
+ rotating_eviction_cache = torch.cat([rotating_eviction_cache, scatter_idx.to(rotating_eviction_cache.device)], dim=2)
+
+ return tuple(updated_past_key_values), rotating_eviction_cache
+
+
+def llm_mask_lower_diagonal_swa_attention(swa_attention_mask: torch.Tensor,
+ sliding_window:int,
+ sliding_kv_length: int,
+ prefix_kv_len: int = 0,
+ mask_neg: int = -100.0,
+ pad_to_left: bool = True) -> torch.Tensor:
+ """
+ This API is responsible for masking out the lower triangular area of the SWA attention mask
+
+ params:
+ swa_attention_mask: The input sliding window attention mask to be modified
+ sliding_window: The window size of the sliding window attention
+ sliding_kv_length: the length of kvcache for the sliding window attention
+ mask_neg: mask value for the positions we don't want to attend to
+
+ """
+ # Get ARN from swa_attention_mask
+ max_input_tokens = swa_attention_mask.shape[-2]
+
+ # Lower diagonal masking offset
+ diagonal_offset = -1 if pad_to_left else (sliding_kv_length - sliding_window - 1 + max_input_tokens)
+ diagonal_offset += prefix_kv_len
+
+ # Mask out lower diagonal part of SWA attention to simulate HF SWA behavior
+ left_masking_area = torch.tril(
+ torch.ones_like(swa_attention_mask, dtype=bool),
+ diagonal=diagonal_offset
+ )
+
+ # Don't mask prefix KV
+ left_masking_area[..., :prefix_kv_len] = False
+
+ # Set the mask value
+ swa_attention_mask.masked_fill_(left_masking_area, mask_neg)
+ return swa_attention_mask
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/model_preparation_utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/model_preparation_utils.py
new file mode 100644
index 000000000..60d8dc590
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/model_preparation_utils.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+""" This file provides model preparation utilities """
+
+import warnings
+
+def llm_build_preparer_converter_args(num_hidden_layers, model_input_names, separate_position_ids=True, use_qairt_mpp=False, attention_indices=None):
+ '''
+ This function builds converter args used in the model preparation step
+ params:
+ num_hidden_layers: number of hidden hidden layers
+ model_input_names: list of input names of the model
+ attention_indices: list of layer indices that are attention layer
+ '''
+ if attention_indices is None:
+ attention_indices = range(num_hidden_layers)
+
+ if not use_qairt_mpp:
+ warnings.warn(
+ "We are not supporting any further updates to the non-qairt converter args workflow"
+ )
+ converter_args_param = ['--input_layout']
+ converter_args_value = 'NONTRIVIAL'
+ converter_args = []
+ for input_param in converter_args_param:
+ for input_name in model_input_names:
+ if input_name == 'position_ids' and separate_position_ids:
+ converter_args += [input_param, 'position_ids[0]', converter_args_value]
+ converter_args += [input_param, 'position_ids[1]', converter_args_value]
+ elif input_name == 'past_key_values':
+ for i in range(num_hidden_layers):
+ if i in attention_indices:
+ converter_args += [input_param, f'past_key_values[{i}][0]', converter_args_value]
+ converter_args += [input_param, f'past_key_values[{i}][1]', converter_args_value]
+ else:
+ converter_args += [input_param, f'past_conv_cache[{i}]', converter_args_value]
+ elif input_name == 'anchor_buffer':
+ for i in range(num_hidden_layers):
+ converter_args += [input_param, f'anchor_buffer[{i}]', converter_args_value]
+ elif input_name == 'keys':
+ for i in range(num_hidden_layers):
+ converter_args += [input_param, f'keys[{i}]', converter_args_value]
+ else:
+ converter_args += [input_param, input_name, converter_args_value]
+ else:
+ from qti.aisw.tools.core.modules.converter.converter_module import InputTensorConfig
+ input_tensors = []
+ for input_name in model_input_names:
+ if input_name in {'position_ids', 'swa_position_ids'} and separate_position_ids:
+ for i in range(2):
+ input_tensor_config=InputTensorConfig(name=f'{input_name}[{i}]',
+ source_model_input_layout='NONTRIVIAL', source_model_input_datatype='float32')
+ input_tensors.append(input_tensor_config)
+ elif input_name == 'past_key_values':
+ for i in range(num_hidden_layers):
+ if i in attention_indices:
+ rangelen=2
+ else:
+ rangelen=1
+ for j in range(rangelen):
+ input_tensor_config=InputTensorConfig(name=f'past_key_values[{i}][{j}]',
+ source_model_input_layout='NONTRIVIAL', source_model_input_datatype='float32')
+ input_tensors.append(input_tensor_config)
+ elif input_name == 'anchor_buffer':
+ for i in range(num_hidden_layers):
+ input_tensor_config=InputTensorConfig(name=f'anchor_buffer[{i}]',
+ source_model_input_layout='NONTRIVIAL', source_model_input_datatype='float32')
+ input_tensors.append(input_tensor_config)
+ elif input_name == 'keys':
+ for i in range(num_hidden_layers):
+ input_tensor_config=InputTensorConfig(name=f'keys[{i}]',
+ source_model_input_layout='NONTRIVIAL', source_model_input_datatype='float32')
+ input_tensors.append(input_tensor_config)
+ elif input_name in {'cache_index', 'swa_cache_index', 'conv_cache_position'}:
+ input_tensor_config=InputTensorConfig(name=input_name,
+ source_model_input_layout='NONTRIVIAL', source_model_input_datatype='int64')
+ input_tensors.append(input_tensor_config)
+ elif input_name == 'input_ids':
+ input_tensor_config=InputTensorConfig(name=input_name,
+ source_model_input_layout='NONTRIVIAL', source_model_input_datatype='int64')
+ input_tensors.append(input_tensor_config)
+ else:
+ input_tensor_config=InputTensorConfig(name=input_name,
+ source_model_input_layout='NONTRIVIAL')
+ input_tensors.append(input_tensor_config)
+
+ converter_args = {'input_tensors': input_tensors}
+
+ return converter_args
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/sliding_window_utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/sliding_window_utils.py
new file mode 100644
index 000000000..8617d6493
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/sliding_window_utils.py
@@ -0,0 +1,97 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+import torch
+
+def create_swa_mask_from_global_mask(global_mask, sliding_window_length, lookback_window, ssd_prefix_kv_len, global_cache_index=None):
+ """
+ This API creates the swa_mask from the global causal mask. We expect the global_cache_index and sliding_cache_index are aligned (not in values, but they should progress in the same manner) for API correctness.
+ params:
+ global_mask: this is the global causal mask of the shape [1, 1, ARN, context_length].
+ We fetch a subset of values from this to create the sliding causal mask of shape [1, 1, ARN, sliding_window_length]
+ sliding_window_length: this is the expanded sliding window shape [could be inclusive of prefix KV$ or not, could be +1/-1, or even have more than ARN]->
+ all we want to check if the global_cache_index + ARN fall within or outside the sliding_window_length
+ lookback_window: Consider this as the lookback_window, every token in our mask should ideally be looking at exactly lookback_window worth of actual tokens
+ global_cache_index: this integer tells the amount of past KV$ [the new_kv$ will be added right after the old KV$, starting from global_cache_index position]
+ ssd_prefix_kv_len: the length of forecast prefix KV$
+
+ Note: we only support the right padding scenario for SWA+SSD computation in simulation
+ """
+
+ max_input_tokens = global_mask.shape[2]
+ if global_cache_index!=None:
+ # Right padding case, if ssd_prefix_kv_len == 0 then it is non-ssd right padding use-case.
+
+ swa_causal_mask = torch.ones((1, 1, max_input_tokens, sliding_window_length-ssd_prefix_kv_len)).to(global_mask.device)*global_mask.min().item()
+
+ # Step 1:
+ TARGET_VALUE = 0.0
+ # We know the boundary is around global_cache_index..
+
+ # max is because when you have a case when the updated sliding window is more than where your cache index is, then we don't want to simply return the first updated sliding window values from the global mask because we want to make sure that the actual tokens in that are still attending to lookback_window tokens.
+ left_end = max(0, global_cache_index + max_input_tokens - sliding_window_length)
+ right_end = left_end + sliding_window_length
+ # Shift left_end due to prefix_kv_len
+ left_end += ssd_prefix_kv_len
+
+ # The following loop ensures that for every element in ARN, we need to ensure that as we travel back, we pick lookback amount of indices..
+ for i in range(max_input_tokens):
+ # get the indices which equal to zero first, we only want the lookback_window worth of indices
+ mask = (global_mask[:,:,i, left_end:right_end]==TARGET_VALUE).int()
+
+ # Compute prefix sum from right to left
+ cumulative_sum = torch.flip(torch.cumsum(torch.flip(mask, dims=[-1]), dim=-1), dims=[-1])
+
+ # Keep only positions where prefix sum ≤ max_matches
+ final_mask = (mask.bool()) & (cumulative_sum <= lookback_window)
+
+
+ # Use final_mask to copy values from global_mask to swa_causal_mask
+ # We need to broadcast final_mask to match the shape of global_mask[:, :, i, left_end:right_end]
+ selected_values = torch.where(final_mask, global_mask[:, :, i, left_end:right_end], torch.tensor(global_mask.min().item(), device=global_mask.device))
+
+ # Copy into the correct slice of swa_causal_mask
+ swa_causal_mask[:, :, i, :] = selected_values
+
+ # Step 2: We concat the extracted causal mask with the prefix KV$ portion on the left
+ swa_causal_mask = torch.cat((global_mask[:,:,:,:ssd_prefix_kv_len], swa_causal_mask.to(global_mask.device)), dim=-1)
+ return swa_causal_mask
+
+ else:
+ # For left padding, we do not support SSD for now.
+ swa_causal_mask = torch.ones((1, 1, max_input_tokens, sliding_window_length-ssd_prefix_kv_len)).to(global_mask.device)*global_mask.min().item()
+
+ # Step 1:
+ #the TARGET_VALUE represents the value in the causal_mask where the tokens attend to
+ TARGET_VALUE = 0.0
+ left_end = global_mask.shape[-1] - sliding_window_length + ssd_prefix_kv_len
+ right_end = global_mask.shape[-1]
+
+ # The for loop ensures that for every element in ARN, we need to ensure that as we travel back, we pick lookback amount of indices..
+ for i in range(max_input_tokens):
+ # get the indices which equal to zero first, we only want the lookback_window worth of indices
+ mask = (global_mask[:,:,i, left_end:right_end]==TARGET_VALUE).int()
+
+ # Compute prefix sum from right to left
+ cumulative_sum = torch.flip(torch.cumsum(torch.flip(mask, dims=[-1]), dim=-1), dims=[-1])
+
+ # Keep only positions where prefix sum ≤ max_matches
+ final_mask = (mask.bool()) & (cumulative_sum <= lookback_window)
+
+ # Use final_mask to copy values from global_mask to swa_causal_mask
+ # We need to broadcast final_mask to match the shape of global_mask[:, :, i, left_end:right_end]
+ selected_values = torch.where(final_mask, global_mask[:, :, i, left_end:right_end], torch.tensor(global_mask.min().item(), device=global_mask.device))
+
+ # Copy into the correct slice of swa_causal_mask
+ swa_causal_mask[:, :, i, :] = selected_values
+
+ # Step 2: We concat the extracted causal mask with the prefix KV$ portion on the left
+ swa_causal_mask = torch.cat((global_mask[:,:,:,:ssd_prefix_kv_len], swa_causal_mask.to(global_mask.device)), dim=-1)
+
+ return swa_causal_mask
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/ssd/graph_utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/ssd/graph_utils.py
new file mode 100644
index 000000000..2cbe07ac1
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/ssd/graph_utils.py
@@ -0,0 +1,16 @@
+import torch
+def llm_concat_forecast_embedding(inputs_embeds, path_to_ssd_pt):
+ '''
+ This API concatenate inputs_embeds and SSD forecast_embeddings
+ inputs_embeds | forecast_embeddings
+
+ params:
+ inputs_embeds: inputs_embeds is the torch.nn.Embedding layer
+ path_to_ssd_pt: the path where the SSD forecast embeddings are stored
+ '''
+ ssd_params= torch.load(path_to_ssd_pt)
+ forecast_embeddings = ssd_params['forecast_embedding'].to(device = inputs_embeds.weight.device, dtype = inputs_embeds.weight.dtype)
+ # inputs_embeds.weight is of type nn.Parameter, hence we concat with .data and assign to .data which is a tensor, otherwise we cannot assign a tensor to a nn.Pramater object,
+ inputs_embeds.weight.data = torch.cat([inputs_embeds.weight.data, forecast_embeddings], dim=0)
+
+ return inputs_embeds
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/ssd/inference_utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/ssd/inference_utils.py
new file mode 100644
index 000000000..093dab8d3
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/ssd/inference_utils.py
@@ -0,0 +1,424 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+""" This file provides utilities that are needed for token generation with SSD """
+import torch
+from genai_lib.llm.utils import _concat
+import functools
+from transformers import TopKLogitsWarper
+import numpy as np
+from copy import deepcopy
+DEFAULT_STOP_SEQUENCES = ["###", "<|endoftext|>", "<|begin_of_text|>", \
+ "<|eot_id|>", "<|end_of_text|>", "<|im_start|>", \
+ "<|im_end|>", "<|end|>", "<\s>", "", "", \
+ "\n\n\n\n\n"]
+
+def llm_generate_forecast_tokens(num_draft_tokens, num_forecast_per_token, vocab_size, batch_size = 1, device = "cpu"):
+ '''
+ This API is responsible for generating the num_forecast_per_token tokens. The token IDs are determined assuming they are appended to the right of the existing vocabulary.
+ For SSD, the forecast tokens are appended to the valid_token as well as each of the draft tokens in the input.
+ params:
+ 1. num_draft_tokens: the number of draft tokens in the input
+ 2. num_forecast_per_token: number of forecast tokens that get associated with each draft and valid token
+ 3. vocab_size: the vocabulary size of the given model
+ 4. batch_size: this is needed to create the correct shape of the returned forecast tokens which are further appended to the input/ draft tree which is of shape [bsz, input_ids_length]
+ '''
+ vocab_size = vocab_size
+ total_input_tokens = (1+num_draft_tokens)
+ # we repeat to create the forecast tokens for valid input and each draft token
+ forecast_tokens = torch.arange(vocab_size, vocab_size + num_forecast_per_token, device = device).repeat(batch_size, total_input_tokens)
+ return forecast_tokens
+
+def llm_generate_position_ids_for_draft_and_forecast_tokens(n_branch_list, pos_id_for_valid_token, num_forecast_per_token, num_draft_tokens, use_top1_expand=False, batch_size = 1, device = "cpu"):
+ '''
+ This API generates position IDs for both draft and forecast tokens.
+ Draft tokens at the same level share the same position IDs.
+ Forecast tokens get position IDs that follow the IDs of their associated valid or draft tokens.
+ params:
+ 1. n_branch_list : the list shows the number of draft tokens at each level
+ 2. pos_id_for_valid_token : the position ID corresponding to the valid token in the input
+ 3. num_forecast_per_token : number of forecast tokens that get associated with each draft and valid token
+ 4. num_draft_tokens: the number of draft tokens in the input
+ 5. use_top1_expand: whether to use top1-expand style draft tree expansion
+ '''
+ pos_ids = [pos_id_for_valid_token]
+ if num_draft_tokens>0:
+ for n_depth in range(1, len(n_branch_list)+1):
+ n_tokens_for_current_depth = int(np.prod(n_branch_list[:n_depth]))
+ if use_top1_expand:
+ n_tokens_for_current_depth = int(n_branch_list[n_depth-1])
+ pos_ids += [pos_id_for_valid_token+n_depth] * n_tokens_for_current_depth
+
+ forecast_pos_ids = []
+ for pid in pos_ids:
+ forecast_pos_ids += [pid+i for i in range(1, num_forecast_per_token + 1)]
+
+ pos_ids += forecast_pos_ids
+ return torch.tensor(pos_ids, dtype=torch.long, device = device).repeat(batch_size, 1)
+
+def llm_build_ssd_causal_mask(input_slice, max_input_tokens, num_draft_tokens, n_branch_list,
+ num_forecast_per_token, prefix_kv_len, valid_kv_len, model_context_len, mask_neg=-100, pad_left=True, cache_index=None, use_top1_expand=False):
+ '''
+ This API is responsible for building the causal mask for SSD.
+ It assumes that the prefix KV cache is always appended to the left of the valid KV cache.
+ Additionally, the API assumes the same padding direction for both the input and the past KV.
+
+ The valid input, draft tokens, and forecast tokens all attend to the valid past KV. Forecast tokens also attend to the prefix KV.
+
+ Draft tokens at the first level attend to the valid KV/ root token in the draft tree, while tokens at subsequent levels attend to the valid token as well as their parent draft tokens. .
+ Forecast tokens follow the same routine as their corresponding draft tokens and, in addition, attend to themselves in an autoregressive manner.
+ params:
+ 1. input_slice = refers to the current input slice without any padding
+ 2. max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ 3. num_draft_tokens: the number of draft tokens in the input
+ 4. n_branch_list: the list shows the number of draft tokens at each level
+ 5. num_forecast_per_token: number of forecast tokens that get associated with each draft and valid token
+ 6. prefix_kv_len: length of prefix kv cache
+ 7. valid_kv_len: length of valid KV cache accumulated so far
+ 8. model_context_len : the maximum context length that can be sent to the model (this is the HF maximum length of the context)
+ 9. mask_neg: mask value for the positions we don't want to attend to
+ 10. pad_left: whether the current input and the KV cache follow the left padding or the right padding (API assumes same for both)
+ 11. use_top1_expand: whether to use top1-expand style draft tree expansion
+ '''
+ # copied frm llama causal mask utils
+ total_kv_len = model_context_len-max_input_tokens
+ if pad_left:
+ # Cache index should not be passed. Concat op is used in doing the KV cache update
+ assert cache_index is None, "Invalid argument error: we do not support the combination of performing left padding and doing scatter for KV cache update."
+ else:
+ # if the user is doing right padding, it is necessary to pass the cache_index.
+ assert cache_index is not None, "Invalid argument error: we do not support the combination of performing right padding and doing concat for KV cache update"
+ #look at build_attention_mask_infer_bfs
+ # Get batch size and seq len of input
+ batch_size, input_len, = input_slice.shape[:2]
+ pad_tokens = max_input_tokens - input_len
+ device = input_slice.device
+ dtype = input_slice.dtype
+
+ #determine the length of padding KV
+ pad_kv_len=total_kv_len - prefix_kv_len - valid_kv_len
+
+ #Determine the number of valid inputs (i.e exculding the draft & forecast tokens)
+ num_valid_inputs = input_len - num_draft_tokens - (num_draft_tokens+1)*num_forecast_per_token
+ #Create a blank causal mask that has mask_neg in all positions
+ causal_mask = torch.full((batch_size, 1, max_input_tokens, total_kv_len+max_input_tokens), mask_neg, device=device, dtype=dtype)
+
+ #Create a blank attention mask with all zeros.
+ #We will set ones in the positions where we would want a token to attend
+ attn_mask = torch.zeros((batch_size,1,max_input_tokens,total_kv_len+max_input_tokens),dtype=torch.int, device=device)
+
+ #All valid tokens in input attend to valid_past_kv. Set these entries to 1.
+ #Ensure that we do not set these for pad tokens
+ # here we assume that the padding direction for input & the kv cache is same.
+ # also we assume that the prefix kv is always append next to valid kv & towards it's left.
+ if pad_left:
+ attn_mask[:, :, pad_tokens:, pad_kv_len+prefix_kv_len:total_kv_len] = 1
+ else:
+ attn_mask[:,:,:input_len,prefix_kv_len:prefix_kv_len+valid_kv_len] = 1
+
+ #create attention mask with lower triangle of 1 - this corresponds to the valid inputs
+ valid_input_attn_mask = torch.tril(torch.ones(num_valid_inputs,num_valid_inputs),)
+
+ #Set the valid_input_attn_mask in attn_mask.
+ #The location of this would be:
+ # Row : 0 to num_valid_inputs-1
+ # Column : total_kv_len to total_kv_len+num_valid_inputs-1
+ if pad_left:
+ attn_mask[:, :, pad_tokens:pad_tokens+num_valid_inputs, total_kv_len+pad_tokens:total_kv_len+pad_tokens+num_valid_inputs] = valid_input_attn_mask
+ else:
+ attn_mask[:,:,:num_valid_inputs, total_kv_len:total_kv_len+num_valid_inputs] = valid_input_attn_mask
+
+ #All draft and forecast tokens should attend to valid tokens. Set those positions to 1 also
+ if pad_left:
+ attn_mask[:,:, pad_tokens+num_valid_inputs:max_input_tokens, total_kv_len+pad_tokens:total_kv_len+pad_tokens+num_valid_inputs] = 1
+ else:
+ attn_mask[:,:,num_valid_inputs:input_len, total_kv_len:total_kv_len+num_valid_inputs] = 1
+
+ #Now we have to identify which tokens the draft tokens need to attend to
+ if pad_left:
+ offset_q = pad_tokens + num_valid_inputs -1
+ offset_kv = total_kv_len + pad_tokens + num_valid_inputs - 1
+ else:
+ offset_q = num_valid_inputs - 1
+ offset_kv = total_kv_len + num_valid_inputs - 1
+
+ mask_prev_node = [[0]]
+ len_verify = len(n_branch_list) if num_draft_tokens else 0
+
+ if use_top1_expand:
+ parent = mask_prev_node
+ for ndx in range(len_verify):
+ new_connection = [parent[0] + [int(sum(n_branch_list[:ndx])) + 1 + i] for i in range(n_branch_list[ndx])]
+ mask_prev_node += new_connection
+ parent = new_connection
+ else:
+ prev_all = [[0]]
+ offset = 0
+ for ndx in range(len_verify):
+ curr = []
+ for prev in prev_all:
+ for _ in range(n_branch_list[ndx]):
+ offset += 1
+ curr += [prev + [offset]]
+ prev_all = deepcopy(curr)
+ mask_prev_node += prev_all
+
+
+ for ndx, cmask in enumerate(mask_prev_node):
+ attn_mask[:, :, offset_q + ndx, offset_kv + torch.tensor(cmask)] = 1
+
+ # Now we identify which tokens the forecast tokens need to attend to
+ mask_next_draft = []
+ offset_q += len(mask_prev_node)
+ for ndx_forecast, cmask in enumerate(mask_prev_node):
+ offset_forecast = len(mask_prev_node) + ndx_forecast * num_forecast_per_token
+ for ndx in range(num_forecast_per_token):
+ mask_next_draft += [cmask + list(range(offset_forecast, offset_forecast + ndx + 1))]
+ for ndx, cmask in enumerate(mask_next_draft):
+ attn_mask[:, :, offset_q + ndx, offset_kv + torch.tensor(cmask)] = 1
+ # prefix for next forecasting
+ if pad_left:
+ attn_mask[:, :, offset_q + ndx, pad_kv_len:pad_kv_len+prefix_kv_len] = 1
+ else:
+ attn_mask[:, :, offset_q + ndx, :prefix_kv_len] = 1
+
+ causal_mask.masked_fill_(attn_mask.bool(), 0)
+ # TODO: If pad_left is False, i.e if right padding is there, we need to adjust the mask to reflect the post KV cache update layout
+ # old good KV | new KV | padding KV, the above computation assumes we have old KV + (padding) | new KV + (padding) always.
+ # Causal mask is of shape [1, 1, new_KV, total_KV], we need to move around the total KV dimension.
+ # Scatter the new KV cache after the cache index and make everything else, -100.
+ if pad_left is False:
+ cache_tensor = torch.arange(max_input_tokens).to(device = cache_index.device)
+ cache_position = cache_index + cache_tensor
+ indices = cache_position.view(1, 1, 1, -1).expand(causal_mask.shape[0], causal_mask.shape[1], max_input_tokens,
+ cache_position.shape[-1])
+ causal_mask = causal_mask.scatter(dim=-1, index=indices, src=causal_mask[:,:,:,total_kv_len:])
+ causal_mask[:,:,:,prefix_kv_len+valid_kv_len+max_input_tokens:] = mask_neg
+ return causal_mask
+
+def llm_add_prefix_kv_cache(prefix_kv, current_kv, model_context_len, max_input_tokens,key_concat_axis, value_concat_axis=-2):
+ '''
+ This API is responsible for adding the prefix kv cache to the left of accumulated past kv (current kv to be sent into the model).
+ It checks whether the addition is feasible as it should not exceed the cache budget. (model_context_len - max_input_tokens)
+ # TODO: revisit to see if this budget needs to be updated for the Long Context feature.
+
+ params:
+ 1. prefix_kv: prefix KV cache for the ssd
+ 2. current_kv: thw current accumulated unpadded past kv cache
+ 3. model_context_len : the maximum context length that can be sent to the model (this is the HF maximum length of the context)
+ 4. max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ 5. key_concat_axis: the axis to which we want to append the keys
+ 6. value_concat_axis: the axis to which we want to concatenate the values
+ '''
+ #Extract the len from values
+ # prefix_kv is of shape [n_layer,2, bsz, num_kv_heads, seq_len, head_dim]
+ prefix_kv_len = prefix_kv[0][1].shape[-2]
+
+ if current_kv:
+ current_kv_len = current_kv[0][1].shape[-2]
+
+ # Check if we would overflow the KV going into the model if we add prefix KV to it
+ assert prefix_kv_len + current_kv_len <= model_context_len - max_input_tokens
+
+ concatenated_key_values = tuple((_concat(prefix_key, current_key, key_concat_axis),
+ _concat(prefix_value, current_value, value_concat_axis))
+ for (prefix_key, prefix_value), (current_key, current_value) in
+ zip(prefix_kv, current_kv))
+ return concatenated_key_values
+
+ return prefix_kv
+
+def llm_verify_and_accept_draft_tokens(sample_tree, logp, n_branch_list, use_top1_expand=False) -> list:
+ '''
+ This function verifies and accepts draft tokens at the current step. It assumes that the token sampled from the logits of the valid token (the first token) is always accepted.
+ The function then recursively checks if this accepted token matches any draft tokens at the next level, repeating the process until no match is found.
+
+ params:
+ 1. sample_tree: the draft tree with the (valid token + draft tokens) + forecast tokens
+ 2. logp: the logits corresponding to the tokens in the sample tree
+ 3. n_branch_list: the list shows the number of draft tokens at each level
+ 4. use_top1_expand: whether to use top1-expand style draft tree expansion
+ '''
+
+ depth_max = len(n_branch_list)
+ # offsets_with_children is a collection of offsets of note
+ # that has children when we're doing top1-expand style tree expansion.
+ # For instance, when n_branch_list=[2, 3, 2, 3] then offsets_with_children should be [1, 3, 6],
+ # which can be calculated using cumulative sum of [1] + n_branch_list
+ n_branch_with_root = [1] + n_branch_list
+ offsets_with_children = [sum(n_branch_with_root[:idx+1]) for idx in range(len(n_branch_with_root))][:-2]
+ def _verify_recursive(accepted_all, node_ids_all, accepted, node_ids, depth_current, offset):
+ if depth_current <= depth_max:
+ n_branch = n_branch_list[depth_current-1]
+ target = accepted[0, -1]
+ for branch in range(1, n_branch+1):
+ ndx_node = offset + branch
+ if target == sample_tree[0, ndx_node]:
+ target_next = torch.topk(logp[:, ndx_node], k=1, dim=-1).indices
+ accepted = torch.concat([accepted, target_next], dim=-1)
+ node_ids = deepcopy(node_ids) + [ndx_node]
+ accepted_all += [accepted]
+ node_ids_all += [node_ids]
+ if use_top1_expand:
+ # If current node does not have children,
+ # then we need to stop verification there.
+ # Otherwise call _verify_recursive on its children
+ if ndx_node not in offsets_with_children:
+ return accepted_all, node_ids_all
+ else:
+ offset_next = int(np.sum(n_branch_list[:depth_current]))
+ else:
+ n_branch_next = n_branch_list[depth_current] if depth_current+1 <= depth_max else 0
+ offset_next = int(np.sum(np.cumprod(n_branch_list[:depth_current]))) + (branch-1) * n_branch_next
+ accepted_all, node_ids_all = _verify_recursive(accepted_all, node_ids_all, accepted, node_ids, depth_current+1, offset_next)
+ return accepted_all, node_ids_all
+ target = torch.topk(logp[:, 0], k=1, dim=-1).indices
+ # initial
+ accepted = target
+ node_ids = [0]
+ depth_next = 1
+ offset_next = 0
+ accepted_all, node_ids_all = _verify_recursive([accepted], [node_ids], target, node_ids, depth_next, offset_next)
+ return accepted_all[-1], node_ids_all[-1]
+
+
+def llm_build_draft_tree_for_next_inference(valid_token, logits_of_selected_forecast_tokens, n_branch_list, use_top1_expand=False, ):
+ #refer build_tree_dfs
+ '''
+ This function is responsible for building the draft tree for next inference using the valid token and the forecast tokens (sampled from the logits using topk)
+
+ params:
+ valid_token: previous last token
+ logits_of_selected_forecast_tokens: logits outputs of the two forecast tokens of acceptted trajectory
+ n_branch_list: the list shows the number of draft tokens at each level, [3,2] example , topk sampling parameter k in different levels
+ use_top1_expand: whether to use top1-expand style draft tree expansion
+ '''
+ samples = [valid_token]
+ tree_bfs = [valid_token]
+ len_draft = len(n_branch_list)
+ logp = torch.log_softmax(logits_of_selected_forecast_tokens, dim=-1)
+ if use_top1_expand:
+ for ndx in range(len_draft):
+ ## remove repeated tokens, sample 1 extra token, use extra token if repeated token present
+ samples = torch.topk(logp[:, ndx:ndx + 1], k=n_branch_list[ndx] + 1, dim=-1).indices
+ # check if top-1 token is not repeated in next stage of draft
+ last_top1_sample = tree_bfs[-1][0,0]
+ idx = torch.where(last_top1_sample == samples.squeeze()[:-1])[0]
+ if len(idx) > 0:
+ # shift tokens
+ samples = torch.cat([samples[:,:,:idx[0]],samples[:,:,idx[0]+1:]], dim=-1)
+ else:
+ samples = samples[:,:,:-1]
+
+ tree_bfs += samples
+ else:
+ for ndx in range(len_draft):
+ samples += torch.topk(logp[:, ndx:ndx + 1], k=n_branch_list[ndx], dim=-1).indices
+ # refer build_tree_dfs
+ for ndx in range(1, len_draft + 1):
+ tree_bfs += [samples[ndx]] * int(np.prod(n_branch_list[:ndx - 1]))
+ tree_bfs = torch.concat(tree_bfs, dim=1)
+ return tree_bfs
+
+
+
+def llm_get_next_step_forecast_logits(logits, last_accepted_token_index, n_branch_list, use_top1_expand=False):
+ '''
+ From the current ids, once we select the last draft acceptance token position, we then need to determine where do the corresponding
+ forecast logits for this reside and return the forecast logits
+
+ params:
+ 1. logits: the logits that correspond to the valid+draft and the associated forecast tokens
+ 2. last_accepted_token_index: the index of the last accepted token
+ 3. n_branch_list : the list shows the number of draft tokens at each level, [3,2] example , topk sampling parameter k in different levels
+ '''
+ len_draft = len(n_branch_list)
+ offset = llm_len_flat_sample_tree(n_branch_list, use_top1_expand=use_top1_expand)
+ return logits[:,offset+(last_accepted_token_index*len_draft):offset+ (last_accepted_token_index+1)*len_draft,:]
+
+def llm_verify_ssd_arn(max_input_tokens, n_branch_list, use_top1_expand=False):
+ '''
+ This function determines if the ARN is sufficient to hold the valid_token+ draft_tokens + forecast tokens associated with the valid_token and each of the draft_tokens
+
+ params:
+ max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ n_branch_list : the list shows the number of draft tokens at each level, [3,2] example , topk sampling parameter k in different levels
+ use_top1_expand: whether to use top1-expand style draft tree expansion
+ '''
+ len_flat_sample_tree_ = llm_len_flat_sample_tree(n_branch_list, use_top1_expand)
+ draft_tree_length = len_flat_sample_tree_ + len_flat_sample_tree_*len(n_branch_list)
+ return max_input_tokens>=draft_tree_length
+
+def llm_len_flat_sample_tree(n_branch_list, use_top1_expand=False):
+ '''
+ This function returns the length of the flat sample tree (this only corresponds to the draft tokens length, excluding any associated forecast tokens)
+ for example, for [2,3], the flat length is 2+ 2*3 = 8.=, 2 draft tokens in the first level, and 6 in the next
+
+ params:
+ 1. n_branch_list : the list shows the number of draft tokens at each level, [3,2] example , topk sampling parameter k in different levels
+ 2. use_top1_expand: whether to use top1-expand style draft tree expansion
+ '''
+ if not isinstance(n_branch_list, torch.Tensor):
+ n_branch_list = torch.Tensor(n_branch_list)
+ if use_top1_expand:
+ return 1 + int(n_branch_list.sum().item())
+ return 1+int(torch.cumprod(n_branch_list, dim=0).sum().item())
+
+
+def llm_preprocess_prefix_kv(prefix_kv, key_concat_axis, num_layers):
+ '''
+ The function is responsible for prefix keys and values in the tupled format expected by the prepared graph
+
+ params:
+ prefix_kv: the learned prefix kv with shape [28, 2, 1, 8, 16, 128], [num_layers, 2, bsz, num_kv_heads, seq_len, head_dim]
+ key_concat_axis: the axis along which seq_len is present for keys
+ num_layers: num_layers keys to extract
+ '''
+ prefix_kv_tuple = ()
+ for idx in range(num_layers):
+ key_states, value_states = prefix_kv[idx][0], prefix_kv[idx][1]
+ if key_concat_axis==3 or key_concat_axis==-1:
+ key_states = key_states.transpose(2, 3)
+ prefix_kv_tuple += ((key_states, value_states),)
+ return prefix_kv_tuple
+
+def llm_check_if_end_generation(generated_tokens, tokenizer):
+ '''
+ The function checks if the currently generated tokens contain the eos_token.
+ If it does, it sets the break_generation flag to true and extracts all tokens that appear before the eos_token from the generated tokens.
+
+ params:
+ 1. generated_tokens: the tokens generated so far. [bsz, output tokens]
+ 2. tokenizer: the tokenizer used to decode the generated tokens
+ '''
+ break_generation = False
+ eos_token_id = tokenizer.eos_token_id
+ if eos_token_id in generated_tokens[0]:
+ stop_index = (generated_tokens[0] == eos_token_id).nonzero(as_tuple=True)
+ stop_index = stop_index[0][0].item()
+ break_generation = True
+ return (break_generation, generated_tokens[:, :stop_index])
+ return (break_generation, generated_tokens)
+
+def llm_capture_ssd_stats(n_accepted_all, input_text, output_text, output_ids = None):
+ '''
+ This function is used to capture statistics like acceptance rate, input_text, etc.
+ It returns a dictionary containing the computed stats as available.
+ '''
+ stat={}
+ n_tokens = np.array(n_accepted_all) + 1
+ stat['n_accepted'] = n_accepted_all
+ stat['sum_n_tokens'] = int(n_tokens.sum())
+ stat['avg_n_tokens'] = float(n_tokens.mean())
+ stat['mem_bound_speedup'] = stat['avg_n_tokens']
+ stat['input_text'] = input_text
+ stat['output_text'] = output_text
+ if output_ids is not None:
+ stat['output_ids'] = output_ids
+ return stat
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/static_graph_utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/static_graph_utils.py
new file mode 100644
index 000000000..b68043d07
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/static_graph_utils.py
@@ -0,0 +1,561 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+""" This file provides utilities for the pipeline to work with static shape requirements for inputs that go into the model """
+
+import torch
+from genai_lib.llm.utils import _shift, _concat
+
+def llm_slice_inputs_for_inference(max_input_tokens, model_context_len, input_ids=None, inputs_embeds=None, attention_mask=None, position_ids=None, past_seen_tokens=None, hidden_states=None, remainder_first=True):
+ """
+ This function is responsible for slicing the inputs based on the AR and yield them to the user.
+ params:
+ 1. max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ 2. model_context_len: maximum number of tokens that the model can consume in total
+ 3. input_ids: input ids sent to the model
+ 4. inputs_embeds: input embeds sent to the model
+ 5. attention_mask: attention mask sent to the model
+ 6. position_ids: position ids sent to the model
+ 7. hidden_states: hidden states sent to the model
+ 8. remainder_first: boolean flag which indicates whether we are slicing such that the remainder is in beginning or in end. for an input of size 10 and ARN=3, If true, the remainder is in the beginning [1, 3, 3, 3] else it is in the end [3, 3, 3, 1]
+
+ Note: To be able to ingest all the model_context_len tokens, we need to slice using the left padding and chunk the input into chunks of max_input_tokens
+ """
+
+ input_count = 0
+ for input in (input_ids, inputs_embeds):
+ if input is not None:
+ input_count = input_count + 1
+
+ assert input_count == 1, "Should pass either input ids or input embeddings, not both"
+
+ if input_ids is not None:
+ input_length = input_ids.shape[1]
+ batch_size = input_ids.shape[0]
+ device = input_ids.device
+ else:
+ input_length = inputs_embeds.shape[1]
+ batch_size = inputs_embeds.shape[0]
+ device = inputs_embeds.device
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, input_length), dtype = torch.long, device = device)
+
+ if position_ids is None:
+ position_ids = torch.cumsum(attention_mask, dim=1) - 1
+ if past_seen_tokens is not None:
+ position_ids += past_seen_tokens
+
+ # If suppose we have an input chunk of size 10, and max_input_tokens (ARN) is 3, then we can either do [1, 3, 3, 3] (remainder_first) or we can do [3, 3, 3, 1] (remainder_last)
+ if remainder_first:
+ """
+ As an example consider:
+ ctx_len: 10
+ input_len: 10
+ max_input_tokens: 3
+ KV$ that can be sent into the model is ctx_len-max_input_tokens = 10-3 = 7
+ Chunks we will send [1, 3, 3, 3]
+ After 1st iteration: accumulated KV$ = 1
+ After 2nd iteration: accumulated KV$ = 4
+ After 3rd iteration: accumulated KV$ = 7
+
+ Now, when sending the last slice of 3, we will either pad it left or right, irrespective of that, the past KV$
+ that can flow into the model/ or the KV$ that the current input slice will attend to can only be ctx_len-ARN, hence
+ we will only look at 7 which is accumulated accurately until this point.
+
+ Hence, we can pass ctx_len worth of input chunk into the model without needing any eviction logic here.
+ This is the default behavior.
+ """
+ for idx in range(0, input_length, max_input_tokens)[::-1]:
+ idx = input_length - idx
+ slice_beginning = max(0, idx-max_input_tokens)
+ output_slice = {
+ 'attn_mask_slice':attention_mask[:, slice_beginning:idx],
+ 'position_ids_slice': position_ids[:, slice_beginning:idx],
+ }
+
+ if input_ids is not None:
+ output_slice['input_ids_slice'] = input_ids[:, slice_beginning:idx]
+ else:
+ output_slice['inputs_embeds_slice'] = inputs_embeds[:, slice_beginning:idx, :]
+
+ if hidden_states is not None:
+ output_slice['hidden_states_slice'] = hidden_states[:, slice_beginning:idx]
+
+ yield output_slice
+ else:
+ """
+ This is the default behavior for Qualla/ on-target
+ As an example consider:
+ ctx_len: 10
+ input_len: 10
+ max_input_tokens: 3
+ KV$ that can be sent into the model is ctx_len-max_input_tokens = 10-3 = 7
+ Chunks we will send [3, 3, 3, 1]
+ After 1st iteration: accumulated KV$ = 3
+ After 2nd iteration: accumulated KV$ = 6
+ After 3rd iteration: accumulated KV$ = 9
+
+ Now, when sending the last slice of 1, we will either pad it left or right, irrespective of that, the past KV$
+ that can flow into the model/ or the KV$ that the current input slice will attent to can only be ctx_len-ARN, hence
+ we will only look at 7 (instead of 9 KV$) and loose information as we need to evict 2 KV$
+
+ More importantly, we will have to evict this extra KV$ otherwise we will run into issues.
+ """
+ for idx in range(0, input_length, max_input_tokens):
+ slice_end = min(idx+max_input_tokens, input_length)
+ output_slice = {
+ 'attn_mask_slice': attention_mask[:, idx: slice_end],
+ 'position_ids_slice': position_ids[:, idx: slice_end],
+ }
+
+ if input_ids is not None:
+ output_slice['input_ids_slice'] = input_ids[:, idx: slice_end]
+ else:
+ output_slice['inputs_embeds_slice'] = inputs_embeds[:, idx: slice_end, :]
+
+ if hidden_states is not None:
+ output_slice['hidden_states_slice'] = hidden_states[:, idx: slice_end]
+
+ yield output_slice
+
+
+def llm_pad_inputs(max_input_tokens, input_ids_slice=None, inputs_embeds_slice=None, pad_token=0, pad_embeds=None, pad_to_left=True):
+ '''
+ This function pads the input_ids/ inputs_embeds since slice may return input_ids/ inputs_embeds that is smaller in length
+ than what the model accepts (AR len)
+
+ params:
+ 1. max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ 2. input_ids_slice: the current input ids slice that is passed into the model in the current invocation
+ 3. inputs_embeds_slice: the current input embeds slice that is passed into the model in the current invocation
+ 4. pad_token: padding token, this is defaulted to 0 to avoid impacting the range of values in the activation tensor
+ 5. pad_embeds: Tensor with which we pad the inputs_embeds_slice. This is optional and will be used if provided.
+ If this is not provided, and we are working with input embeddings, the inputs_embeds_slice tensor will be padded
+ with zero and not the pad_token. The reason for this is that the pad_token could be a large non-zero value which
+ will impact the range of values in the padded tensor.
+ 6. pad_to_left: boolean value indicating whether padding is done towards the left or right.
+
+ '''
+ input = input_ids_slice if input_ids_slice is not None else inputs_embeds_slice
+ device = input.device
+ input_length = input.shape[1]
+ batch_size = input.shape[0]
+ shape = (batch_size, max_input_tokens - input_length)
+
+ if pad_embeds is None:
+ if inputs_embeds_slice is not None:
+ shape += (input.shape[-1],)
+ pad_token = 0
+
+ input_extensions = torch.full(
+ shape,
+ fill_value=pad_token,
+ dtype=input.dtype,
+ device=device
+ )
+ else:
+ assert input.shape[-1] == pad_embeds.shape[-1]
+ # we only want to extract the embeddings dimension from the passed pad_embeddings
+ pad_embeds = pad_embeds[-1]
+ input_extensions = pad_embeds.view(1, 1, -1).repeat(batch_size, max_input_tokens - input_length, 1).to(dtype=input.dtype, device=device)
+
+ # left padding
+ if pad_to_left:
+ input = torch.cat((input_extensions, input), dim=1)
+ # right padding
+ else:
+ input = torch.cat((input, input_extensions), dim=1)
+
+ return input
+
+def llm_pad_hidden_states(max_input_tokens: int, hidden_states_slice: torch.Tensor, pad_token=0, pad_to_left=True):
+ '''
+ This function pads the hidden states since slice may return hidden states that is smaller in length
+ than what the model accepts (AR len)
+
+ params:
+ 1. max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ 2. hidden_states_slice: the current hidden state slice that is passed into the model in the current invocation
+ 3. pad_token: padding token, this is defaulted to 0 to avoid impacting the range of values in the activation tensor
+ 4. pad_to_left: boolean value indicating whether padding is done towards the left or right.
+
+ '''
+ pad_shape = list(hidden_states_slice.shape)
+ pad_shape[1] = max_input_tokens - hidden_states_slice.shape[1]
+ pad = torch.full(
+ pad_shape,
+ fill_value=pad_token,
+ dtype=hidden_states_slice.dtype,
+ device=hidden_states_slice.device
+ )
+
+ # left padding
+ if pad_to_left:
+ padded_hidden_states_slice = torch.cat((pad, hidden_states_slice), dim=1)
+ # right padding
+ else:
+ padded_hidden_states_slice = torch.cat((hidden_states_slice, pad), dim=1)
+
+ return padded_hidden_states_slice
+
+def llm_pad_input_attn_mask(attn_mask_slice, max_input_tokens, pad_to_left=True):
+ """
+ This function pads the 1d attention mask to make it of shape (batch_size, max_input_tokens),
+
+ A: padded current input (0s)
+ B: current valid input (1s)
+
+ If the pad_to_left argument is set to True, it means we perform left_padding & produce the attention mask as A|B
+ else, pad_to_left=False means we do right padding, & produce the attention mask as B|A
+
+ params:
+ attn_mask_slice: the attention mask which corresponds to the current slice of inputs
+ max_input_tokens: the maximum tokens that can be sent to the model, in our context represents the AR length
+ pad_to_left: boolean value indicating whether padding is done towards the left or right.
+ """
+ batch_size = attn_mask_slice.shape[0]
+ input_padding_length = max_input_tokens - attn_mask_slice.shape[1]
+
+ padded_input_attn_mask = torch.zeros((batch_size, input_padding_length), dtype=torch.long,
+ device=attn_mask_slice.device)
+
+ #left padding
+ if pad_to_left:
+ attention_mask = torch.cat((
+ padded_input_attn_mask,
+ attn_mask_slice
+ ),
+ dim=1
+ )
+ # right padding
+ else:
+ attention_mask = torch.cat((
+ attn_mask_slice,
+ padded_input_attn_mask
+ ),
+ dim=1
+ )
+
+ return attention_mask
+
+
+def llm_create_kv_attn_mask(unpadded_past_kv, model_context_len, max_input_tokens, batch_size, device, pad_to_left=True, global_layer_idx = 0):
+ """
+ This function prepares the 1d attention mask based on the useful past key values seen so far.
+ This can be visualized into two sections.
+ A | B
+ A: padded past kv length (0s)
+ B: useful past kv length (1s)
+
+ params:
+ unpadded_past_kv: this is the useful accumulated past kv
+ max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ model_context_len: maximum number of tokens that the model can consume in total
+ batch_size: batch size of current input
+ device: device to place the attention mask
+ pad_to_left: boolean value indicating whether padding is done towards the left or right.
+ global_layer_idx: an integer representing the index of the global layer (for models like Gauss, where we have sliding window attention, the first layer could be the sliding layer whose shape may not reflect the correct context_len)
+ """
+
+ useful_past_kv_length = unpadded_past_kv[global_layer_idx][1].shape[-2] if unpadded_past_kv else 0
+ padded_kv_length = (model_context_len - max_input_tokens) - useful_past_kv_length
+
+ useful_past_kv_attn_mask = torch.ones((batch_size, useful_past_kv_length), dtype=torch.long,
+ device=device)
+ padded_kv_attn_mask = torch.zeros((batch_size, padded_kv_length), dtype=torch.long, device=device)
+
+ # left padding
+ if pad_to_left:
+ attention_mask = torch.cat((
+ padded_kv_attn_mask,
+ useful_past_kv_attn_mask
+ ),
+ dim=1
+ )
+ #right padding
+ else:
+ attention_mask = torch.cat((
+ useful_past_kv_attn_mask,
+ padded_kv_attn_mask
+ ),
+ dim=1
+ )
+ return attention_mask
+
+def llm_create_1d_attn_mask(attn_mask_past_kv, attn_mask_input, cache_index=None):
+ '''
+ This function concatenates the attention mask corresponding to the input ids and the past kv together
+ params:
+ attn_mask_past_kv: the attention mask corresponding to the past kv
+ attn_mask_input: the attention mask corresponding to the input (max_input tokens that the model takes)
+ cache_index: cache_index determines where should the attn_mask_input be placed. If None, the input_attention mask
+ is placed towards the end (assuming concat in the kv update within attention) else it is placed right after the valid kv mask.
+ '''
+ if cache_index is None:
+ attention_mask = torch.cat((attn_mask_past_kv,
+ attn_mask_input
+ ),
+ dim=1
+ )
+ else:
+ attention_mask_post_valid_kv = attn_mask_past_kv[:, cache_index:]
+ attention_mask_valid_kv = attn_mask_past_kv[:, :cache_index]
+ attention_mask = torch.cat((
+ attention_mask_valid_kv,
+ attn_mask_input,
+ attention_mask_post_valid_kv
+ ),
+ dim=1)
+
+
+ return attention_mask
+
+def llm_pad_past_kv(dummy_past_kv, unpadded_past_kv, num_hidden_layers, key_concat_axis, value_concat_axis=2, pad_to_left=True):
+ """
+ This function is responsible taking in current past kv and pad it using dummy kv to meet the static shape
+ requirements for past kv.
+ We compute the padding kv length as (Context Length - AR length) - (valid kv length).
+ The shape after we pad past kv is (Context Length - AR length)
+
+ params:
+ dummy_past_kv: this corresponds to the dummy kv for one hidden layer, it is same for all the layers or a list for where each entry is a tuple, which is the dummy kv for the particular layer
+ unpadded_past_kv: this is the useful accumulated past kv (require this to obtain the length of useful past kv)
+ num_hidden_layers: The number of decoder blocks in the model
+ key_concat_axis: the axis to which we want to append the keys
+ value_concat_axis: the axis to which we want to append the values
+ pad_to_left: boolean value indicating whether padding is done towards the left or right.
+
+ """
+
+ # if the dummy kv is a list, then we will iterate over that and pad the KV for that layer accordingly. This gives us the flexibility to pad different layers according to it's own ctx_len.
+ if isinstance(dummy_past_kv, list):
+ assert len(dummy_past_kv) == num_hidden_layers, "Please make sure you pass the dummy KV for each layer"
+ padded_key_values = tuple()
+ for i in range(num_hidden_layers):
+ dummy_past_kv_i = dummy_past_kv[i]
+ # get the useful past kv length of the ith layer
+ useful_past_kv_length = unpadded_past_kv[i][1].shape[-2] if unpadded_past_kv else 0
+
+ # trim the dummy kv corresponding to that particular layer based on it's useful kv length
+ # trimmed dummy kv is the final length dummy kv that will be concatenated to the unpadded_past_kv either to the left or to the right.
+ trimmed_dummy_kv = (_shift(dummy_past_kv_i[0], key_concat_axis, useful_past_kv_length),
+ _shift(dummy_past_kv_i[1], value_concat_axis, useful_past_kv_length))
+ if unpadded_past_kv:
+ if pad_to_left:
+ padded_key_values_i = (_concat(trimmed_dummy_kv[0], unpadded_past_kv[i][0], key_concat_axis),
+ _concat(trimmed_dummy_kv[1], unpadded_past_kv[i][1], value_concat_axis))
+ else:
+ padded_key_values_i = (_concat(unpadded_past_kv[i][0], trimmed_dummy_kv[0], key_concat_axis),
+ _concat(unpadded_past_kv[i][1], trimmed_dummy_kv[1], value_concat_axis))
+
+ padded_key_values += (padded_key_values_i, )
+ else:
+ padded_key_values += (trimmed_dummy_kv, )
+ return padded_key_values
+
+ else:
+ useful_past_kv_length = unpadded_past_kv[0][1].shape[-2] if unpadded_past_kv else 0
+
+ # trimmed dummy kv is the final length dummy kv that will be concatenated to the unpadded_past_kv either to the left or to the right.
+ trimmed_dummy_kv = (_shift(dummy_past_kv[0], key_concat_axis, useful_past_kv_length), _shift(dummy_past_kv[1], value_concat_axis, useful_past_kv_length))
+ if unpadded_past_kv:
+ if pad_to_left:
+ padded_key_values = tuple((_concat(trimmed_dummy_kv[0], unpadded_past_kv[i][0], key_concat_axis),
+ _concat(trimmed_dummy_kv[1], unpadded_past_kv[i][1], value_concat_axis)) for i in range(num_hidden_layers))
+ else:
+ padded_key_values = tuple((_concat(unpadded_past_kv[i][0], trimmed_dummy_kv[0], key_concat_axis),
+ _concat(unpadded_past_kv[i][1], trimmed_dummy_kv[1], value_concat_axis)) for i in
+ range(num_hidden_layers))
+ return padded_key_values
+ return tuple(trimmed_dummy_kv for _ in range(num_hidden_layers))
+
+def llm_get_dummy_kv(batch_size,num_key_value_heads, head_dim, key_concat_axis, device, dtype=torch.float32, cache_len = None, model_context_len=None, max_input_tokens=None ):
+ """
+ This function determines the shape of the dummy kv using the required arguments which reflect model config
+ Returns the dummy kv of fixed size each time (for a single layer). This will be used for padding the passed past kv
+
+ params:
+ batch_size: the batch size needed to create dummy kv
+ model_context_len : model_context_len: maximum number of tokens that the model can consume in total
+ max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ num_key_value_heads: the number of key value heads
+ head_dim: dimension at each head
+ key_concat_axis: the axis to which we want to append the keys
+ device: the device to place dummy kv on, this is inferred from the unpadded_past_kv tensor if it is not None
+ """
+
+ def _cache(shape):
+ return torch.zeros(shape, device=device, dtype=dtype)
+
+ if cache_len is None:
+ cache_len = model_context_len-max_input_tokens
+
+ value = (batch_size, num_key_value_heads,cache_len , head_dim)
+ key = (value[0], value[1], value[3], value[2]) if key_concat_axis == 3 else tuple(value)
+ return (_cache(key), _cache(value))
+
+def _llm_trim_padded_tensor(tensor, input_length, pad_axis=1, pad_to_left=True):
+ """
+ This function is responsible for stripping the non-useful values from the returned tensor (e.g., logits or hidden states)
+ since our prepared model returns fixed length tensor
+ params:
+ tensor: current tensor returned from the model (e.g., logits or hidden states)
+ input_length: length of the valid portion of tensor
+ pad_axis: dimension index of padding
+ pad_to_left: boolean value indicating whether padding is done towards the left or right.
+ """
+ # left padding so we remove the logits from the left & return the valid input length from the end
+ if pad_to_left:
+ trimmed_tensor = torch.narrow(tensor, pad_axis, tensor.shape[pad_axis] - input_length, input_length)
+ # right padding, so we extract the valid input_length from the beginning
+ else:
+ trimmed_tensor = torch.narrow(tensor, pad_axis, 0, input_length)
+ return trimmed_tensor
+
+def llm_trim_pad_logits(cur_logits, input_ids_slice=None, inputs_embeds_slice=None, pad_to_left=True):
+ """
+ This function is responsible for stripping the non-useful logits from the returned logits since our prepared model returns fixed length logits
+ params:
+ cur_logits: current logits returned from the model
+ input_ids_slice: current input ids slice which is not padded
+ pad_to_left: boolean value indicating whether padding is done towards the left or right.
+ """
+ input = input_ids_slice if input_ids_slice is not None else inputs_embeds_slice
+ input_length = input.shape[1]
+
+ return _llm_trim_padded_tensor(cur_logits, input_length=input_length, pad_to_left=pad_to_left)
+
+def llm_trim_padded_hidden_states(hidden_states, input_ids_slice=None, inputs_embeds_slice=None, pad_to_left=True):
+ """
+ This function is responsible for stripping the non-useful hidden states from the returned hidden states
+ since our prepared model returns fixed length hidden states
+ params:
+ hidden_states: current hidden states returned from the model
+ input_ids_slice: current input ids slice which is not padded
+ pad_to_left: boolean value indicating whether padding is done towards the left or right.
+ """
+ input = input_ids_slice if input_ids_slice is not None else inputs_embeds_slice
+ input_length = input.shape[1]
+
+ return _llm_trim_padded_tensor(hidden_states, input_length=input_length, pad_to_left=pad_to_left)
+
+def llm_get_position_ids_from_attention_mask(attention_mask, max_input_tokens, model_context_len, cache_index=None):
+ """
+ This function computes the position ids for the tokens being fed into the model from the 1d_attn_mask.
+
+ params:
+ attention_mask: takes in the prepared attention mask needed to deduce the position ids
+ max_input_tokens: the maximum tokens that can be sent to the model, in our context represents the AR length
+ model_context_len : the maximum context length that can be sent to the model (this is the HF maximum length of the context)
+ cache_index: the index for the starting position of kvcaches
+ """
+
+ position_ids = torch.cumsum(attention_mask, dim=1) - 1
+ position_ids = position_ids.clip(0, model_context_len - 1)
+ if cache_index is None:
+ position_ids = position_ids[..., -max_input_tokens:]
+ else:
+ position_ids = position_ids[..., cache_index:cache_index+max_input_tokens]
+ return position_ids
+
+def llm_pad_position_ids(position_ids_slice, max_input_tokens, pad_value=0, pad_to_left = True):
+ """
+ This function pads the position_ids since slice may return position_ids that is smaller than what the model accepts (AR len)
+
+ params:
+ position_ids_slice: the current position_ids slice that is passed into the model in the current invocation
+ max_input_tokens: maximum number of tokens that can be consumed by the model at each inference (equals ARN)
+ pad_value: padding value, this is defaulted to 0
+ pad_to_left: boolean value indicating whether padding is done towards the left or right.
+
+ """
+
+ assert position_ids_slice is not None
+ assert position_ids_slice.dim() == 2
+
+ batch_size, pos_ids_len = position_ids_slice.shape
+
+ if pos_ids_len < max_input_tokens:
+ pad_pos_ids = torch.full((batch_size, max_input_tokens-pos_ids_len), pad_value,
+ dtype=position_ids_slice.dtype, device=position_ids_slice.device)
+
+ if pad_to_left:
+ position_ids = torch.cat((pad_pos_ids, position_ids_slice), dim=-1)
+ else:
+ position_ids = torch.cat((position_ids_slice, pad_pos_ids), dim=-1)
+
+ return position_ids
+ else:
+ return position_ids_slice
+
+
+def slice_tensors(slice_length, max_length, tensor_dict, remainder_first=True, **kwargs):
+ """
+ Slices tensors in a dictionary along specified dimensions into smaller chunks.
+
+ Parameters:
+ -----------
+ slice_length : int
+ The length of each slice.
+ max_length : int
+ The total length to be sliced from each tensor.
+ tensor_dict : dict
+ A dictionary where keys are variable names and values are tuples of the form (tensor, slice_dim),
+ where `tensor` is a PyTorch tensor and `slice_dim` is the dimension along which to slice.
+ remainder_first : bool, optional (default=True)
+ If True, the remainder (if any) is included in the first slice. Otherwise, it's included in the last slice.
+ **kwargs : dict
+ Additional keyword arguments (not used in this function but included for extensibility).
+
+ Yields:
+ -------
+ dict
+ A dictionary of sliced tensors corresponding to each slice.
+ """
+ remainder = max_length % slice_length
+ num_full_slices = max_length // slice_length
+ num_slices = num_full_slices + (1 if remainder > 0 else 0)
+
+ for i in range(num_slices):
+ sliced_dict = {}
+
+ # Determine start and end indices for the current slice
+ if remainder_first:
+ if i == 0 and remainder > 0:
+ start_idx = 0
+ end_idx = remainder
+ else:
+ start_idx = remainder + (i - 1) * slice_length if remainder > 0 else i * slice_length
+ end_idx = start_idx + slice_length
+ else:
+ if i < num_full_slices:
+ start_idx = i * slice_length
+ end_idx = start_idx + slice_length
+ else:
+ start_idx = num_full_slices * slice_length
+ end_idx = max_length
+
+ # Slice each tensor in the dictionary
+ for var_name, (tensor, slice_dim) in tensor_dict.items():
+ assert isinstance(tensor, torch.Tensor), \
+ f"Input {var_name} is not a tensor, but a {type(tensor)}"
+
+ # Adjust end index if it exceeds tensor size
+ if end_idx > tensor.size(slice_dim):
+ end_idx = tensor.size(slice_dim)
+
+ # Skip if the slice would be empty or invalid
+ if end_idx - start_idx <= 0:
+ return
+
+ # Perform the slicing using torch.narrow
+ sliced_dict[var_name] = tensor.narrow(slice_dim, start_idx, end_idx - start_idx)
+
+ yield sliced_dict
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/test_vectors.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/test_vectors.py
new file mode 100644
index 000000000..90e55c8be
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/test_vectors.py
@@ -0,0 +1,274 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+""" This file provides utilities for sample input and test vector recording """
+
+import contextlib
+import os
+import pickle
+import re
+from typing import Dict, Tuple, Union, Any
+
+import numpy as np
+import torch
+import torch.nn
+from aimet_torch.layer_output_utils import LayerOutput, LayerOutputUtil, NamingScheme
+from aimet_torch.onnx_utils import OnnxExportApiArgs
+from aimet_torch.quantsim import ExportableQuantModule
+from aimet_torch.utils import (
+ change_tensor_device_placement,
+ in_eval_mode,
+ is_leaf_module,
+ nested_map,
+)
+from aimet_torch.v2.nn.base import BaseQuantizationMixin
+from aimet_torch.v2.quantization import QuantizedTensorBase
+
+from genai_lib.common.dev.utils import reset_kernels
+
+MODULE_TYPE_FOR_ATTACHING_HOOK = (ExportableQuantModule,)
+modules_to_treat_as_leaf = []
+
+def to_torch_tensor(t):
+ """ utilty to move test vectors from DequantizedTensor to torch.Tensor """
+ return nested_map(t, lambda x: torch.tensor(x) if isinstance(x, QuantizedTensorBase) else x)
+
+def to_cpu(t):
+ return change_tensor_device_placement(t, torch.device('cpu'))
+
+def quantizers_state(sim, disabled) -> contextlib.ExitStack:
+ exit_stack = contextlib.ExitStack()
+ if disabled:
+ for _, module in sim.model.named_modules():
+ if isinstance(module, BaseQuantizationMixin):
+ exit_stack.enter_context(module._remove_all_quantizers())
+ return exit_stack
+
+def run_hook_for_layers_with_given_input_get_output(model: torch.nn.Module,
+ input_tensor: Union[torch.Tensor, Tuple, Dict], hook,
+ module_type_for_attaching_hook=None, module_regex_to_include=None,
+ leaf_node_only=True, fwd_func=None):
+ """
+ Register the given hook function for all layers in the model
+ :param model: Model
+ :param input_tensor: Input tensor to the model. If more than one model inputs, use a tuple
+ :param hook: Hook function to register
+ :param module_type_for_attaching_hook: Tuple of torch.nn module types for which hook has to be attached
+ :param leaf_node_only: Set to False if all modules are required
+ :param fwd_func: forward function for model inference
+ :return: None
+ """
+ # ------------------------
+ # Register hook function
+ # ------------------------
+ hooks = []
+ # All leaf modules
+ modules = []
+
+ # Based on the modules in modules_to_treat_as_leaf, we do not want to further continue searching for next level
+ # of modules present in modules_to_treat_as_leaf. To achieve this, save them in modules_to_skip
+ modules_to_skip = set()
+
+ if module_regex_to_include:
+ patterns = [re.compile(pattern) for pattern in module_regex_to_include]
+ name_match_modules = [module for name, module in model.named_modules() if any (re.match(pattern, name) for pattern in patterns)]
+ else:
+ name_match_modules = model.modules()
+
+ for module in name_match_modules:
+ if module not in modules_to_skip:
+ # pylint: disable=protected-access
+ if isinstance(module, tuple(modules_to_treat_as_leaf)):
+ modules.append(module)
+ # check for modules inside the 'module' and add them to modules_to_skip
+ for sub_module in module._modules.values():
+ modules_to_skip.add(sub_module)
+ else:
+ if leaf_node_only:
+ if is_leaf_module(module):
+ modules.append(module)
+ else:
+ modules.append(module)
+
+ if module_type_for_attaching_hook:
+ # if needed, filter by module types specified by caller
+ modules = [module for module in modules if isinstance(module, module_type_for_attaching_hook)]
+
+ try:
+ for module in modules:
+ hooks.append(module.register_forward_hook(hook))
+
+ # ------------------------------------------------
+ # Run forward pass to execute the hook functions
+ # ------------------------------------------------
+ with in_eval_mode(model), torch.no_grad():
+ if fwd_func:
+ output = fwd_func(model, input_tensor)
+ else:
+ if isinstance(input_tensor, (list, tuple)):
+ output = model(*input_tensor)
+ elif isinstance(input_tensor, dict):
+ output = model(**input_tensor)
+ else:
+ output = model(input_tensor)
+
+ finally:
+ # --------------------------
+ # Remove all hooks we added
+ # --------------------------
+ for h in hooks:
+ h.remove()
+
+ return output
+
+
+class LLMLayerOutput(LayerOutput):
+ def __init__(self, model: torch.nn.Module, dir_path: str, naming_scheme: NamingScheme = NamingScheme.PYTORCH,
+ dummy_input = None, onnx_export_args: Union[OnnxExportApiArgs, Dict] = None, regex_patterns = None):
+ super().__init__(model, dir_path, naming_scheme, dummy_input, onnx_export_args)
+ self.regex_patterns = regex_patterns
+
+ def record_outputs(self, module: torch.nn.Module, input: Tuple, output: Any):
+ """
+ Hook function to capture output of a layer.
+
+ :param module: Layer-module in consideration.
+ :param input: Input of the layer-module.
+ :param output: Output of the layer-module.
+ :return: None
+ """
+ layer_name = self.module_to_name_dict[module]
+ self.layer_name_to_layer_output_dict[layer_name] = {"input": to_cpu(input), "output": to_cpu(output)}
+
+ def get_outputs(self, input_batch) -> Dict[str, torch.Tensor]:
+ """
+ This function captures layer-outputs and renames them as per the AIMET exported pytorch/onnx/torchscript model.
+
+ :param input_batch: Batch of inputs for which we want to obtain layer-outputs.
+ :return: layer-name to layer-output batch dict
+ """
+
+ # Fetch outputs of all the layers
+ self.layer_name_to_layer_output_dict = {}
+ if self.is_quantsim_model:
+ # Apply record-output hook to QuantizeWrapper modules (one node above leaf node in model graph)
+ model_output = run_hook_for_layers_with_given_input_get_output(self.model, input_batch, self.record_outputs,
+ module_type_for_attaching_hook=MODULE_TYPE_FOR_ATTACHING_HOOK,
+ leaf_node_only=False, module_regex_to_include=self.regex_patterns)
+ else:
+ # Apply record-output hook to Original modules (leaf node in model graph)
+ model_output = run_hook_for_layers_with_given_input_get_output(self.model, input_batch, self.record_outputs,
+ leaf_node_only=True, module_regex_to_include=self.regex_patterns)
+
+ # Rename outputs according to pytorch/onnx/torchscript model
+ layer_output_name_to_layer_output_dict = LayerOutput.rename_layer_outputs(self.layer_name_to_layer_output_dict,
+ self.layer_name_to_layer_output_name_dict)
+
+ return layer_output_name_to_layer_output_dict, model_output
+
+
+class LLMLayerOutputUtil(LayerOutputUtil):
+ def __init__(self, model: torch.nn.Module, dir_path: str, file_prefix: str, naming_scheme: NamingScheme = NamingScheme.PYTORCH,
+ dummy_input = None, onnx_export_args: Union[OnnxExportApiArgs, Dict] = None, regex_patterns = None):
+ """
+ Constructor for LayerOutputUtil.
+
+ :param model: Model whose layer-outputs are needed.
+ :param dir_path: Directory wherein layer-outputs will be saved.
+ :param naming_scheme: Naming scheme to be followed to name layer-outputs. There are multiple schemes as per
+ the exported model (pytorch, onnx or torchscript). Refer the NamingScheme enum definition.
+ :param dummy_input: Dummy input to model. Required if naming_scheme is 'NamingScheme.ONNX' or 'NamingScheme.TORCHSCRIPT'.
+ :param onnx_export_args: Should be same as that passed to quantsim export API to have consistency between
+ layer-output names present in exported onnx model and generated layer-outputs. Required if naming_scheme is
+ 'NamingScheme.ONNX'.
+ """
+ super().__init__(model, dir_path, naming_scheme, dummy_input, onnx_export_args)
+ self.output_dir = dir_path
+ self.file_prefix = file_prefix
+
+ # Utility to capture layer-outputs
+ self.layer_output = LLMLayerOutput(model=model, naming_scheme=naming_scheme, dir_path=dir_path, dummy_input=dummy_input,
+ onnx_export_args=onnx_export_args, regex_patterns=regex_patterns)
+
+ def generate_layer_outputs(self, input_batch, batch_idx):
+ """
+ This method captures output of every layer of a model & saves the inputs and corresponding layer-outputs to disk.
+
+ :param input_batch: Batch of inputs for which we want to obtain layer-outputs.
+ :return: None
+ """
+
+ # Obtain layer-output name to output dictionary
+ layer_output_batch_dict, model_outputs = self.layer_output.get_outputs(input_batch)
+
+ test_vectors = {f"{batch_idx}": {**to_cpu(to_torch_tensor(input_batch)),
+ **to_cpu(to_torch_tensor(layer_output_batch_dict))}}
+
+ assert os.path.exists(self.output_dir), "output_dir for test vectors doesn't exist"
+
+ for key, value in test_vectors.items():
+ filename = os.path.join(self.output_dir, self.file_prefix + f"_{batch_idx}.pkl")
+ with open(filename, 'wb') as file:
+ pickle.dump({key: value}, file)
+
+ return model_outputs
+
+def generate_test_vectors(sim, model_inputs, output_dir, batch_index, test_vector_layers, idx_to_name_output_dict=None):
+ """
+ Generates the test vectors in fp and sim on the sim model using the inputs for the test_vector_layers.
+ :param sim: QuantSim model to generate the test vectors on
+ :param model_inputs: model inputs to trace the model and tap inputs at intermediate layers
+ :param output_dir: directory to save the generated test vectors
+ :param batch_index: batch index to differentiate multiple test_vectors
+ :param test_vector_layers: regex layer which corresponds to the layer expression whose input and output we need to extract
+ :param idx_to_name_output_dict: a dict mapping the output index to the corresponding name, by default we assume in LLMs we have first output as logits and second as past_key_values.
+ """
+
+ if idx_to_name_output_dict is None:
+ idx_to_name_output_dict = {0: 'logits', 1: 'past_key_values'}
+
+ vector_output_dir = os.path.join(output_dir, "test_vectors")
+ os.makedirs(vector_output_dir, exist_ok=True)
+
+ def _sanitize_and_update_test_vectors(test_vectors, test_outputs):
+ if "past_key_values" in test_outputs:
+ test_outputs["output_key_values"] = test_outputs.pop("past_key_values")
+
+ test_vectors.update(to_cpu(test_outputs))
+
+ for vector_type in ['fp', 'qt']:
+
+ recorder = LLMLayerOutputUtil(sim.model, dir_path=vector_output_dir,
+ file_prefix=vector_type, regex_patterns=test_vector_layers)
+
+ ctx_managers = contextlib.ExitStack()
+ ctx_managers.enter_context(quantizers_state(sim, disabled=(vector_type == 'fp')))
+ if vector_type == 'fp':
+ # Note: We need to reset the OSET kernel to the PyTorch kernel because the OSET kernel requires encoding,
+ # which will not be available if we remove the quantizers.
+ ctx_managers.enter_context(reset_kernels(sim))
+
+ with ctx_managers:
+ model_outputs = recorder.generate_layer_outputs(model_inputs, batch_index)
+ outputs = {}
+ if len(model_outputs) != len(idx_to_name_output_dict):
+ raise ValueError("please specify the correct mapping to map the output to their names")
+
+ # iterate over the model_outputs and fetch the corresponding name from the idx_to_name_output_dict.
+ for idx, out in enumerate(model_outputs):
+ outputs[idx_to_name_output_dict[idx]] = out
+
+ filename = os.path.join(vector_output_dir, f"{vector_type}_{batch_index}.pkl")
+ test_vector_dict = np.load(filename, allow_pickle=True)
+
+ _sanitize_and_update_test_vectors(test_vector_dict[f"{batch_index}"], outputs)
+ test_vector_dict = to_cpu(to_torch_tensor(test_vector_dict))
+
+ with open(filename, 'wb') as file:
+ pickle.dump(test_vector_dict, file)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/test_vectors_onnx.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/test_vectors_onnx.py
new file mode 100644
index 000000000..c8c0dfa8e
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/test_vectors_onnx.py
@@ -0,0 +1,367 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+import os
+import pickle
+import re
+from typing import List, Dict, Tuple, Set, Optional, Any, Sequence
+import numpy as np
+
+from contextlib import contextmanager, ExitStack
+
+import torch
+from torch.utils._pytree import tree_map_only
+
+from aimet_onnx.quantsim import QuantizationSimModel
+from aimet_onnx.layer_output_utils import LayerOutput
+
+from aimet_common.layer_output_utils import save_layer_output_names
+
+from genai_lib.common.onnxruntime_utils import ONNXNameMapper, OutputBufferCreator, ORTInferenceModule
+
+from onnx import ModelProto
+
+@contextmanager
+def disable_quantizers(sim, quantizer_names: Set[str]):
+ """
+ Disables all quantizers in quantizer_names inside the context
+
+ Copied from AIMET-ONNX repo: https://github.qualcomm.com/qualcomm-ai/aimet/blob/develop/TrainingExtensions/onnx/src/python/aimet_onnx/utils.py
+ """
+ if not isinstance(quantizer_names, set):
+ quantizer_names = set(quantizer_names)
+
+ if not quantizer_names.issubset(sim.qc_quantize_op_dict.keys()):
+ raise RuntimeError(
+ f"quantizer_names contains non-existent quantizers: {quantizer_names - sim.qc_quantize_op_dict.keys()}"
+ )
+
+ is_enabled = {
+ name: sim.qc_quantize_op_dict[name].enabled for name in quantizer_names
+ }
+
+ try:
+ for name in quantizer_names:
+ sim.qc_quantize_op_dict[name].enabled = False
+
+ yield
+
+ finally:
+ for name in quantizer_names:
+ sim.qc_quantize_op_dict[name].enabled = is_enabled[name]
+
+
+class LLMLayerOutput:
+ """
+ This class creates a dictionary of node name to node-input/output, and model-input name to model-input. It also produces model output.
+ """
+
+ def __init__(self,
+ model: ModelProto,
+ providers: Sequence[str | Tuple[str, Dict[Any, Any]]],
+ dir_path: str,
+ onnx_name_to_tensor: ONNXNameMapper,
+ output_buffer_creator: OutputBufferCreator,
+ regex_patterns:Optional[List[str]] = None):
+ """
+ Constructor - Initializes lists required for capturing and naming layer-outputs
+
+ params:
+ model (ModelProto): Onnx Model
+ providers (List): List of onnxruntime execution providers
+ dir_path (str): directory path to save the layer-inputs/outputs
+ onnx_name_to_tensor (ONNXNameMapper): ONNXNameMapper object which maps between the onnx names and pytorch tensors
+ output_buffer_creator (OutputBufferCreator): OutputBufferCreator object that creates an Output Buffer for IO Binding
+ regex_patterns (Optional[List[str]]): list of regex patterns to match the layer names
+ """
+ self.model = model
+ self.output_buffer_creator = output_buffer_creator
+ self.onnx_name_to_tensor = onnx_name_to_tensor
+ self.original_output_names = [output.name for output in model.graph.output]
+
+ self.node_name_to_io_names = {}
+
+ # retrieve list of inputs to nodes, list of outputs of nodes, a dict of node name to inputs/outputs, and a list of all the activations (regardless of regex matching)
+ self.node_input_activation_names, self.node_output_activation_names, self.node_name_to_io_names, all_layer_output_names = LLMLayerOutput.get_activation_names(
+ self.model, regex_patterns)
+ # mark any activations that are quantized
+ quantized_activation_names = [
+ n for n in self.node_input_activation_names + self.node_output_activation_names if n.endswith("_updated")
+ ]
+ # add the the non-quantized versions of those activations to the list to remove
+ # remove the "_updated" at the end of the quantized activation name to get the original
+ activations_to_remove = {n[:-8] for n in quantized_activation_names}
+
+ # add any qdq activations to the list to remove as well
+ activations_to_remove.update(
+ n for n in (self.node_input_activation_names + self.node_output_activation_names) if n.endswith("_qdq")
+ )
+
+ # remove in-place the activations marked earlier, by changing the values inside the list, while keeping the list in-place
+ for lst in (self.node_input_activation_names, self.node_output_activation_names):
+ lst[:] = [n for n in lst if n not in activations_to_remove]
+
+ # remove those activations from the node_name_to_io_names dict as well
+ for key in self.node_name_to_io_names:
+ for io in ["input", "output"]:
+ self.node_name_to_io_names[key][io] = [name for name in self.node_name_to_io_names[key][io] if
+ name not in activations_to_remove]
+
+ # recompile a full list of activations to hook
+ self.activation_names = self.node_input_activation_names + self.node_output_activation_names
+
+ # ONNX "hook" for activations
+ LayerOutput.register_activations(self.model, self.activation_names)
+
+ # Build ONNX runtime inference session
+ self.session = QuantizationSimModel.build_session(self.model, providers)
+
+ # Reverting the added outputs until the output length matches the original state
+ # This does not affect the inference session, as self.session does not have a live reference to self.model.
+ while len(self.model.graph.output) > len(self.original_output_names):
+ self.model.graph.output.pop()
+
+ # Sanitize layer output names to save
+ sanitized_all_layer_output_names = [LLMLayerOutput.sanitize_activation_name(name) for name in all_layer_output_names if not (name.endswith("_updated") or name.endswith("qdq"))]
+
+ # Save all model activations in topological order of model graph, for use when comparing layer-outputs
+ # NOTE: This only saves layer output names, not layer input names, to match behavior of torch-based test vectors.
+ save_layer_output_names(sanitized_all_layer_output_names, dir_path)
+
+ @staticmethod
+ def sanitize_activation_name(activation_name: str) -> str:
+ """
+ This function sanitizes the activation name by replacing non-alphanumeric characters with underscores.
+
+ params:
+ activation_name (str): activation name
+ returns:
+ str: sanitized activation name
+ """
+ return re.sub(r"\W+", "_", activation_name.replace("_updated", "")).strip("_")
+
+ @staticmethod
+ def get_activation_names(model: ModelProto, regex_patterns) -> Tuple[
+ List[str], List[str], Dict[str, Dict[str, List[str]]], List[str]]:
+ """
+ This function fetches the activation names (model_input, node_input, node_output names) of the given onnx model, that match the provided regex patterns.
+
+ params:
+ model (ModelProto): ONNX model
+ regex_patterns (list): list of regex patterns
+ return:
+ tuple: Tuple containing lists of activation names for node inputs and node outputs, a Dict of node name to Dicts of node input/output names, and a list of all activation names
+ """
+ patterns = [re.compile(pattern) for pattern in (regex_patterns if regex_patterns is not None else [])]
+
+ node_name_to_io_names = {}
+ node_input_activation_names = []
+ node_output_activation_names = []
+ # we build a list of all_activation names for save_layer_output_names function
+ all_layer_output_names = []
+
+ for node in model.graph.node:
+ # Regardless of regex expression matching, collect layer output name to save
+ for output_name in node.output:
+ all_layer_output_names.append(output_name)
+ # Ensure that node is one of the desired nodes
+ if not any(re.match(pattern, node.name) for pattern in patterns):
+ continue
+ # Construct node name dict
+ node_name_to_io_names[node.name] = {"input": [], "output": []}
+ # Add each input name to the list and dict
+ for input_name in node.input:
+ node_input_activation_names.append(input_name)
+ node_name_to_io_names[node.name]["input"].append(input_name)
+ # Add each output name to the list and dict
+ for output_name in node.output:
+ node_output_activation_names.append(output_name)
+ node_name_to_io_names[node.name]["output"].append(output_name)
+ return node_input_activation_names, node_output_activation_names, node_name_to_io_names, all_layer_output_names
+
+ def get_outputs(self, input_dict: Dict[str, Any]) -> Tuple[Dict[str, Dict[str, Tuple]], Dict[str, torch.Tensor]]:
+ """
+ This function creates node-input/output dict, and also returns model output
+
+ params:
+ input_dict (Dict[str, Any]): Dictionary that contains inputs to model
+ returns:
+ Tuple[Dict[str, Dict[str, Tuple]], Dict[str, torch.Tensor]]: Tuple of node input/output dict and model output
+ """
+
+ device = "cuda" if "CUDAExecutionProvider" in self.session.get_providers() else "cpu"
+
+ # Ensure input tensors are contiguous
+ input_dict = tree_map_only(torch.Tensor, lambda t: t.contiguous(), input_dict)
+
+ # Create output buffer
+ output_buffer = self.output_buffer_creator.create_buffer()
+
+ # Bind the original outputs to the output buffer
+ io_binding = ORTInferenceModule.bind_io(session=self.session,
+ onnx_name_to_tensor=self.onnx_name_to_tensor,
+ input_buffer = input_dict,
+ output_buffer=output_buffer,
+ output_names = self.original_output_names)
+
+ # Bind all the newly hooked outputs
+ for output_name in self.activation_names:
+ io_binding.bind_output(output_name, device)
+
+ # Inference
+ self.session.run_with_iobinding(io_binding)
+
+ # This produces a list of the outputs, first the original outputs, and then the newly hooked outputs
+ all_outputs = io_binding.copy_outputs_to_cpu()
+
+ # Fix the start index
+ start_idx = len(self.original_output_names)
+
+ # Zip the node input values
+ node_input_values_dict = dict(zip(self.node_input_activation_names,
+ all_outputs[start_idx:start_idx + len(self.node_input_activation_names)]))
+
+ # Zip the node output values
+ start_idx += len(self.node_input_activation_names)
+ node_output_values_dict = dict(zip(self.node_output_activation_names, all_outputs[start_idx:]))
+
+ # Initialize layer name to layer io dict
+ layer_name_to_layer_io_dict = {}
+
+ # Create {"input": (input_values), "output": (output_values)} Entries
+ for node_name, node_io_names in self.node_name_to_io_names.items():
+ # Sanitize each name individually
+ sanitized_node_name = LLMLayerOutput.sanitize_activation_name(node_name)
+ layer_name_to_layer_io_dict[sanitized_node_name] = {"input": [], "output": []}
+
+ input_values = []
+ for input_name in node_io_names["input"]:
+ input_values.append(node_input_values_dict[input_name])
+ output_values = []
+ for output_name in node_io_names["output"]:
+ output_values.append(node_output_values_dict[output_name])
+ layer_name_to_layer_io_dict[sanitized_node_name] = {"input": tuple(input_values),
+ "output": tuple(output_values)}
+
+ return layer_name_to_layer_io_dict, output_buffer
+
+
+class LLMLayerOutputUtil:
+ """Class to capture and save inputs and outputs of intermediate nodes of an onnx model"""
+
+ def __init__(
+ self,
+ model: ModelProto,
+ dir_path: str,
+ file_prefix: str,
+ providers: Sequence[str | Tuple[str, Dict[Any, Any]]],
+ onnx_name_to_tensor: ONNXNameMapper,
+ output_buffer_creator: OutputBufferCreator,
+ regex_patterns:Optional[List[str]]=None):
+ """
+ Constructor - It initializes the utility classes that captures and saves the layer-inputs/outputs of an onnx model
+
+ params:
+ model (ModelProto): Onnx Model
+ dir_path (str): directory path to save the layer-inputs/outputs
+ file_prefix (str): file prefix to save the layer-inputs/outputs
+ onnx_name_to_tensor (ONNXNameMapper): ONNXNameMapper object which maps between the onnx names and pytorch tensors
+ output_buffer_creator (OutputBufferCreator): OutputBufferCreator object that creates an Output Buffer for IO Binding
+ device (int): device id to run the model on
+ regex_patterns (Optional[List[str]]): list of regex patterns to match the layer names
+ """
+ self.output_dir = dir_path
+ self.file_prefix = file_prefix
+
+ self.layer_output = LLMLayerOutput(model=model,
+ providers=providers,
+ dir_path=dir_path,
+ onnx_name_to_tensor=onnx_name_to_tensor,
+ output_buffer_creator=output_buffer_creator,
+ regex_patterns=regex_patterns)
+
+ def generate_layer_outputs(self, input_dict: Dict[str, Any], batch_idx: int) -> Dict[str, torch.Tensor]:
+ """
+ This function captures input/output of model, as well as every node that has been matched with the regex patterns. Then it saves these values to disk.
+
+ params:
+ input_batch (Dict[str, Any]): input batch to the model
+ batch_idx (int): batch index
+ returns:
+ Dict[str, torch.Tensor]: dictionary containing the original model outputs
+ """
+ layer_output_dict, model_outputs = self.layer_output.get_outputs(input_dict)
+
+ # Convert to torch to match torch-based test vectors export
+ test_vectors = {f"{batch_idx}": tree_map_only(np.ndarray, torch.from_numpy, {**input_dict, **layer_output_dict})}
+
+ assert os.path.exists(self.output_dir), "output_dir for test vectors doesn't exist"
+
+ for key, value in test_vectors.items():
+ filename = os.path.join(self.output_dir, f"{self.file_prefix}_{batch_idx}.pkl")
+ with open(filename, 'wb') as file:
+ pickle.dump({key: value}, file)
+
+ return model_outputs
+
+
+def generate_test_vectors(sim: QuantizationSimModel, model_inputs: Dict[str, Any], output_dir: str,
+ batch_index: int, test_vector_layers: List[str], onnx_name_to_tensor: ONNXNameMapper, output_buffer_creator: OutputBufferCreator):
+ """
+ This function captures inputs/outputs to the model and also inputs/outputs to the nodes that are matched by the provided list of regex patterns.
+ It does this for both Floating point and Quantized models and saves the values to disk.
+
+ params:
+ sim (QuantizationSimModel): QuantizationSimModel object
+ model_inputs (Dict[str, Any]): input batch to the model
+ output_dir (str): output directory to save test vectors
+ batch_index (int): batch index
+ test_vector_layers (List[str]): list of regex patterns to match the nodes to capture the layer-inputs/outputs
+ onnx_name_to_tensor (ONNXNameMapper): ONNXNameMapper object which maps between the onnx names and pytorch tensors
+ output_buffer_creator (OutputBufferCreator): OutputBufferCreator object that creates an Output Buffer for IO Binding
+ returns:
+ None
+ """
+ vector_output_dir = os.path.join(output_dir, "test_vectors")
+ os.makedirs(vector_output_dir, exist_ok=True)
+
+ def _convert_and_update_test_vectors(test_vectors, test_outputs):
+ if "past_key_values" in test_outputs:
+ test_outputs["output_key_values"] = test_outputs.pop("past_key_values")
+
+ test_vectors.update(test_outputs)
+
+ for vector_type in ['fp', 'qt']:
+ recorder = LLMLayerOutputUtil(sim.model.model,
+ dir_path=vector_output_dir,
+ file_prefix=vector_type,
+ providers=sim.providers,
+ onnx_name_to_tensor=onnx_name_to_tensor,
+ output_buffer_creator=output_buffer_creator,
+ regex_patterns=test_vector_layers)
+
+ ctx_managers = ExitStack()
+ if vector_type == 'fp':
+ ctx_managers.enter_context(disable_quantizers(sim, sim.qc_quantize_op_dict.keys()))
+
+ with ctx_managers:
+ model_outputs = recorder.generate_layer_outputs(model_inputs, batch_index)
+
+ filename = os.path.join(vector_output_dir, f"{vector_type}_{batch_index}.pkl")
+ test_vector_dict = np.load(filename, allow_pickle=True)
+
+ _convert_and_update_test_vectors(test_vector_dict[f"{batch_index}"], model_outputs)
+
+ with open(filename, 'wb') as file:
+ pickle.dump(test_vector_dict, file)
+
+ # Delete the inference session manually to free up memory
+ recorder.layer_output.session = None
+ del recorder
+ torch.cuda.empty_cache()
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/utils.py
new file mode 100644
index 000000000..7a9e74803
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/llm/utils.py
@@ -0,0 +1,122 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+""" This file provides helper functions for LLM Lib """
+
+from typing import List, Tuple, Union
+
+import torch
+
+
+def _concat(a, b, dim):
+ if isinstance(a, tuple):
+ assert len(a) == len(b), 'Unexpected key/value pair'
+ return tuple(_concat(ai, bi, dim) for ai, bi in zip(a, b))
+ if a is None:
+ return b
+ if b is None:
+ return a
+ return torch.cat((a, b), dim=dim)
+
+
+def _do_concat(a, b, key_dim, value_dim):
+ return tuple((_concat(ak, bk, key_dim), _concat(av, bv, value_dim)) for (ak, av), (bk, bv) in zip(a, b))
+
+
+def _shift(a, dim, shift_size, shift_to_left=True):
+ if isinstance(a, tuple):
+ return tuple(_shift(ai, dim) for ai in a)
+ assert dim in (2, 3), 'Unexpected shift axis'
+ if shift_to_left:
+ return a[:, :, shift_size:, :] if dim == 2 else a[:, :, :, shift_size:]
+ else:
+ #this fails when the shift size is 0, in that case, it returns empty tensor, which is opposite of what we want
+ #return a[:, :, :-shift_size, :] if dim == 2 else a[:, :, :, :-shift_size]
+ orig_len = a.shape[-2] if dim == 2 else a.shape[-1]
+ keep_size = orig_len - shift_size
+ return a[:, :, :keep_size, :] if dim == 2 else a[:, :, :, :keep_size]
+
+
+def _do_shift(a, key_dim, value_dim, shift_size):
+ return tuple((_shift(k, key_dim, shift_size), _shift(v, value_dim, shift_size)) for k, v in a)
+
+def _get_past_key_value_names(sfx, n_layers, separate_tuple_input_output):
+ if not separate_tuple_input_output:
+ return ["past_key_values"]
+ all = []
+ for i in range(n_layers):
+ all.append(f'past_key_{i}_{sfx}')
+ all.append(f'past_value_{i}_{sfx}')
+ return all
+
+
+def _get_position_emb_names(use_position_embedding_input=True, separate_tuple_input_output=False):
+ if separate_tuple_input_output:
+ if use_position_embedding_input:
+ return ['position_ids_cos', 'position_ids_sin']
+ return ['position_ids']
+
+
+def llm_model_input_output_names(
+ num_hidden_layers: int,
+ use_position_embedding_input: bool = True,
+ separate_tuple_input_output: Union[bool, Tuple[bool, bool]] = False,
+ use_input_embedding: bool = False,
+) -> Tuple[List[str], List[str]]:
+ """
+ This function is responsible for returning a list of the model input and output names based on the number of hidden layers for a LLama like signature of inputs
+
+ :param num_hidden_layers: number of hidden layers of the model
+ :param use_position_embedding_input: are the position ids supplied to the model in embeddings form (assume sin and cos embedding, if yes)
+ :param separate_tuple_input_output: are the inputs passed into the model in tupled format or not
+ If passed value is a bool, the value is applied to both input and output
+ If passed value tuple of size 2, each bool value is applied separately to input and output
+ :param use_input_embedding: do we pass input ids or input embeddings
+ :return: model input names and model output names
+ """
+ if isinstance(separate_tuple_input_output, bool):
+ separate_tuple_input, separate_tuple_output = separate_tuple_input_output, separate_tuple_input_output
+ elif (
+ isinstance(separate_tuple_input_output, tuple)
+ and len(separate_tuple_input_output) == 2
+ ):
+ separate_tuple_input, separate_tuple_output = separate_tuple_input_output
+ else:
+ raise ValueError('separate_tuple_input_output must be bool or tuple with size 2')
+
+ input_names=['input_ids', 'attention_mask']
+ input_names += _get_position_emb_names(use_position_embedding_input=use_position_embedding_input, separate_tuple_input_output=separate_tuple_input)
+ output_names = ['logits']
+ input_names += _get_past_key_value_names("in", num_hidden_layers, separate_tuple_input)
+ output_names += _get_past_key_value_names("out", num_hidden_layers, separate_tuple_output)
+ if use_input_embedding:
+ input_names += ['inputs_embeds']
+ input_names.pop(0)
+ return input_names, output_names
+
+
+def llm_exporter_input_output_names(num_hidden_layers, use_position_embedding_input=True , separate_tuple_input_output=True, use_input_embedding = False):
+ '''
+ This function returns the input output names passed as input during model export. For downstream use cases, the encodings names are preferred in the untupled format
+ params:
+ num_hidden_layers: number of hidden layers of the model
+ use_position_embedding_input: are the position ids supplied to the model in embeddings form (assume sin and cos embedding, if yes)
+ separate_tuple_input_output: are the inputs passed into the model in tupled format or not
+ use_input_embedding: do we pass input ids or input embeddings
+ '''
+ input_names, output_names = llm_model_input_output_names(num_hidden_layers=num_hidden_layers, use_position_embedding_input=use_position_embedding_input , separate_tuple_input_output=separate_tuple_input_output, use_input_embedding = use_input_embedding)
+ return input_names, output_names
+
+def llm_search_layers_by_type(model, module_type):
+ embedding_layers = []
+ for name, module in model.named_modules():
+ if isinstance(module, module_type):
+ embedding_layers.append(module)
+ return embedding_layers
+
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/calibration.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/calibration.py
new file mode 100644
index 000000000..d67cd9144
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/calibration.py
@@ -0,0 +1,318 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+""" common utilities and class implementation for calibration """
+
+import contextlib
+import glob
+import inspect
+import os
+import pickle
+from typing import Tuple, Any, Dict, Optional, List, Union, Callable
+from packaging import version
+
+import numpy
+import torch
+from aimet_torch.adaround import adaround_optimizer
+import logging
+logger = logging.getLogger(__name__)
+
+FILENAME_PREFIX = 'cal_'
+FWD_SIGNATURE_FILE = 'fwd_signature.pkl'
+PREP_FWD_MAP_FILE = 'prepared_fwd_map.pkl'
+
+def _flatten_model_inputs(inputs: Dict, prefix: str = '', input_dict: Dict = None, tensors_only: bool = True):
+ """ recursive function to flatten flattens a dict or nested dict into flat dict """
+
+ if isinstance(inputs, dict):
+ for name, t in inputs.items():
+ new_prefix = name if prefix == '' else f'{prefix}_{name}'
+ _flatten_model_inputs(t, new_prefix, input_dict, tensors_only)
+
+ elif tensors_only is False or isinstance(inputs, torch.Tensor):
+ input_dict[prefix] = inputs
+
+def flatten_model_inputs(inputs: Dict, tensors_only: bool = True) -> Dict[str, Any]:
+ """
+ function to flatten flattens a dict or nested dict into flat dict
+ :param inputs: dict of inputs
+ :param tensors_only: if True, filters out entry which do not have Tensor as value
+ :return: flat dict.
+ """
+ input_dict = {}
+ _flatten_model_inputs(inputs, '', input_dict, tensors_only)
+ return input_dict
+
+
+# TODO split reader into data_loader (Calibration) and i/o data reader
+class ModelDataReader:
+ """ Implements a reader for recorded calibration data. """
+
+ def __init__(self, saved_path: str, input_only: bool, is_prepared_model: bool, num_samples: int):
+ self.saved_path = saved_path
+ self._len = len(glob.glob1(self.saved_path, FILENAME_PREFIX + "*.npy"))
+ if num_samples > 0:
+ self._len = min(self._len, num_samples)
+ assert self._len > 0, f'No sample model inputs found in {self.saved_path}'
+ self._input_only = input_only
+ self._is_prepared_model = is_prepared_model
+ with open(os.path.join(self.saved_path, FWD_SIGNATURE_FILE), 'rb') as f:
+ self.model_parameter_defaults = pickle.load(f)
+ if is_prepared_model:
+ prepared_fwd_path = os.path.join(saved_path, PREP_FWD_MAP_FILE)
+ assert os.path.exists(prepared_fwd_path), f'{prepared_fwd_path} does not exist, did you remember to call `recorder.extract_prepare_model_info(prepared_model)` ?'
+ with open(os.path.join(self.saved_path, PREP_FWD_MAP_FILE), 'rb') as f:
+ self.prep_model_input_map = pickle.load(f)
+
+ def get_calibration_data(self, item: int):
+ """Returns the indexed item from the dataset"""
+ file_path = f'{self.saved_path}/cal_{item}.npy'
+ if item < self._len:
+ return numpy.load(file_path, allow_pickle=True)[0]
+
+ raise IndexError
+
+ def __getitem__(self, item: int):
+ """Returns the indexed item from the dataset formatted based on config """
+ data = self.get_calibration_data(item)
+
+ if self._is_prepared_model:
+ recoded_flatten_inputs = self.get_flattened_inputs(inputs=data['inputs'])
+ inputs = tuple([recoded_flatten_inputs[input_key] for input_key in self.prep_model_input_map])
+ else:
+ inputs = self.format_inputs(data['inputs'], self.model_parameter_defaults)
+
+ if self._input_only:
+ return inputs
+
+ return inputs, data['outputs']
+
+ def get_flattened_inputs(self, item: int = None, inputs: Optional[Tuple[Any,]] = None, tensor_only: bool = True):
+ """Returns the indexed item from the dataset as flattened dict"""
+ if inputs is None:
+ inputs = self.get_calibration_data(item)['inputs']
+ inputs = self.format_inputs(inputs, self.model_parameter_defaults)
+ return flatten_model_inputs(dict(zip(self.model_parameter_defaults.keys(), inputs)), tensors_only=tensor_only)
+
+ @staticmethod
+ def format_inputs(inputs: Tuple[Tuple, Dict], parameter_defaults):
+ """
+ flattens input in the format of '*args, **kwargs' to tuple along with adding defaults when value not provided.
+ :param inputs: input tuple in *args, **kwargs format
+ :param parameter_defaults: forward pass parameter /w default values
+ :return: input as tuple.
+ """
+ *positional_inputs, kwargs = inputs
+ if kwargs:
+ add_args = []
+ param_keys = list(parameter_defaults.keys())
+ if param_keys[0] == "self":
+ param_keys = param_keys[1:]
+ for k in param_keys[len(positional_inputs):]:
+ if k in kwargs:
+ add_args.append(kwargs[k])
+ else:
+ default = parameter_defaults[k]
+ if default != inspect.Parameter.empty:
+ add_args.append(default)
+ else:
+ ValueError(f'no value provided for kwarg={k}, which has no defaults')
+ inputs = tuple([*positional_inputs, *add_args])
+ else:
+ inputs = tuple(positional_inputs)
+ return inputs
+
+ def __len__(self):
+ """Returns length of calibration data"""
+ return self._len
+
+
+class ModelDataRecorder:
+ """
+ Implements mechanism to capture model input and output as calibration data.
+ """
+ def __init__(self, save_path: str, model: Optional[torch.nn.Module] = None,
+ clean_start: bool = False, num_samples: int = -1):
+ self.index = 0
+ self.hook_handler = []
+ self.input_args_type = None
+ os.makedirs(save_path, exist_ok=True)
+ self.save_path = save_path
+ self.num_samples = num_samples
+
+ if model is not None:
+ self._register(model)
+
+ calibration_files = glob.glob(f'{self.save_path}/{FILENAME_PREFIX}*.npy')
+ if len(calibration_files) > 0:
+ logger.info("calibration data (num=%d) found at %s, %s ..",
+ len(calibration_files), self.save_path, 'deleting' if clean_start else 'exiting')
+ if clean_start:
+ for filename in calibration_files:
+ os.remove(filename)
+ else:
+ raise ValueError("if you intend to collect calibration data again, please set clean_start=True")
+ else:
+ fwd_signature_path = os.path.join(self.save_path, FWD_SIGNATURE_FILE)
+ assert os.path.exists(fwd_signature_path), f'forward signature file not found in {fwd_signature_path}'
+
+ def _register(self, model: torch.nn.Module):
+ """
+ Register forward hook for model input/output.
+ """
+ with open(os.path.join(self.save_path, FWD_SIGNATURE_FILE), 'wb') as f:
+ params_defaults = {param: value.default for param, value in
+ inspect.signature(model.forward).parameters.items()}
+ pickle.dump(params_defaults, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+ def hook(_, inputs, outputs):
+ """ Save model input tuple of tensors and output tensor"""
+
+ fn_context = [f.function for f in inspect.stack()[1:]] # skipping innermost stack-frame[0] i.e. hook(..)
+ if version.parse(torch.__version__) < version.parse("2.5"):
+ assert fn_context[0] == '_call_impl',\
+ f'unexpected context={fn_context}, hook expected to be invoked from torch.nn.module._call_impl(..) proxy fn'
+ else:
+ assert fn_context[:2] == ['inner', '_call_impl'], \
+ f'unexpected context={fn_context} , hook expected to be invoked from torch.nn.module._call_impl.inner(..) proxy fn'
+
+ forward_fn = inspect.currentframe().f_back
+ inputs = *inputs, forward_fn.f_locals['kwargs']
+ self._verify_model_inputs(inputs)
+ numpy.save(
+ f'{self.save_path}/{FILENAME_PREFIX}{self.index}',
+ [{'inputs' : inputs, 'outputs': outputs}], allow_pickle=True)
+ self.index += 1
+
+ if self.index == self.num_samples:
+ self.stop_recording()
+
+ self.hook_handler.append(model.register_forward_hook(hook))
+
+ def _verify_model_inputs(self, inputs):
+ """
+ Implements a check to ensure the inputs args are not changing between model invocation.
+ :param inputs: inputs captured with current invocation of hook
+ """
+ *positional_inputs, kwargs = inputs
+ input_args_type = [type(x) for x in positional_inputs] + [(n,type(v)) for n,v in kwargs.items()]
+ if self.input_args_type is None:
+ self.input_args_type = input_args_type
+ else:
+ assert self.input_args_type == input_args_type, f'unsupported option, inputs arguments are changing, expected `{self.input_args_type}`, got `{input_args_type}`'
+
+ def stop_recording(self):
+ """ helper method to remove all registered hook in the model being observed. """
+ if self.hook_handler:
+ for hook in self.hook_handler:
+ hook.remove()
+ self.hook_handler = []
+ logger.info("removed hooks for model, calibration data (num=%d) saved at:%s",
+ self.index, self.save_path)
+
+ def extract_prepare_model_info(self, model: torch.nn.Module) -> List[str]:
+ """
+ Register forward hook for model input/output.
+ """
+
+ param_keys = inspect.signature(model.forward).parameters.keys()
+ file_path = f'{self.save_path}/cal_0.npy'
+ assert os.path.isfile(file_path), f'{file_path} does not exist, did you observe data flowing through the model with this recorder?'
+ inputs = numpy.load(file_path, allow_pickle=True)[0]['inputs']
+ with open(os.path.join(self.save_path, FWD_SIGNATURE_FILE), 'rb') as f:
+ model_parameter_defaults = pickle.load(f)
+ inputs = ModelDataReader.format_inputs(inputs, model_parameter_defaults)
+ input_keys = flatten_model_inputs(dict(zip(model_parameter_defaults.keys(), inputs)), tensors_only=True).keys()
+ assert len(param_keys) == len(input_keys), f'Prepared model forward expected {len(param_keys)} inputs=`{param_keys}`, \
+ but observed calibration data contains {len(input_keys)} inputs=`{input_keys}`'
+
+ with open(os.path.join(self.save_path, PREP_FWD_MAP_FILE), 'wb') as f:
+ input_param_amp = dict((zip(input_keys, param_keys)))
+ logger.info("calibration inputs to args mapping: %s", input_param_amp)
+ pickle.dump(input_param_amp, f, protocol=pickle.HIGHEST_PROTOCOL)
+ return list(param_keys)
+
+
+ @property
+ def reader(self) -> ModelDataReader:
+ """ return an input and output data reader """
+ return ModelDataReader(self.save_path, input_only=False, is_prepared_model=False, num_samples=self.num_samples)
+
+ @property
+ def calibration_data(self) -> ModelDataReader:
+ """ return an input data reader """
+ return ModelDataReader(self.save_path, input_only=True, is_prepared_model=False, num_samples=self.num_samples)
+
+ @property
+ def prepared_model_calibration_data(self) -> ModelDataReader:
+ """ return an input data reader """
+ return ModelDataReader(self.save_path, input_only=True, is_prepared_model=True, num_samples=self.num_samples)
+
+ @property
+ def prepared_model_reader(self) -> ModelDataReader:
+ """ return an input data reader """
+ return ModelDataReader(self.save_path, input_only=False, is_prepared_model=True, num_samples=self.num_samples)
+
+
+def get_output_names(outputs: Any) -> Optional[List[str]]:
+ """
+ provides default output names for LVM models
+ :param outputs: FP32 model output
+ :return: return list of output names or None
+ """
+ if hasattr(outputs, 'to_tuple'):
+ outputs = outputs.to_tuple()
+ if isinstance(outputs, tuple) and isinstance(outputs[0], torch.Tensor):
+ return [f'output{i}' for i in range(len(outputs))]
+ elif isinstance(outputs, torch.Tensor):
+ return ['output']
+ return None
+
+def nested_map(data, fn: Callable[[torch.Tensor], torch.Tensor]):
+ """
+ Apply a function to a nested tuple, list, or dict of tensors.
+ :param data: Tensor, or a nested tuple, list, or dict of tensors.
+ :param fn: Function to apply to the tensors
+ :return: Nested structure of tensors with function applied
+ """
+ if isinstance(data, torch.Tensor):
+ return fn(data)
+
+ if isinstance(data, (tuple, list)):
+ cls = tuple if isinstance(data, tuple) else list
+ return cls(nested_map(x, fn) for x in data)
+
+ if isinstance(data, dict):
+ return {
+ key: nested_map(value, fn) for key, value in data.items()
+ }
+
+ logger.debug('unexpected input type=%s, expecting torch.Tensor, tuple, list, or dict. skipping..', type(data))
+ return data
+
+def change_tensor_dtype(data: Union[torch.Tensor, List, Tuple, Dict], dtype: torch.dtype):
+ """
+ Change the data's dtype.
+ :param data: Tensor, or a nested tuple, list, or dict of tensors.
+ :param dtype: dtype
+ :return: data with modified dtype.
+ """
+ return nested_map(data, lambda x: x.to(dtype=dtype) if x.is_floating_point() else x)
+
+
+@contextlib.contextmanager
+def get_num_adaround_iterations(dtype: torch.dtype, num_iterations: int):
+ """
+ Sets adaround config for fp16 execution from default fp32.
+ """
+ adaround_batch_size = adaround_optimizer.BATCH_SIZE
+ if dtype == torch.half:
+ adaround_optimizer.BATCH_SIZE = adaround_batch_size * 2
+ num_iterations = int(num_iterations / 2)
+ yield num_iterations
+ adaround_optimizer.BATCH_SIZE = adaround_batch_size
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/defs.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/defs.py
new file mode 100644
index 000000000..c2a947977
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/defs.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+""" common definitions used by aimet extension and test utils etc """
+from dataclasses import dataclass
+from typing import Callable, Dict, Any, Optional, List
+import torch
+
+
+@dataclass
+class ModelCfg:
+ """
+ Configuration for model that needs to be quantized.
+ """
+ model_path: str # path for model within the diffuser pipeline e.g. 'vae.decoder'
+ model: torch.nn.Module # model to be quantized.
+ model_input_format_fn: Callable[[Dict[str, Any]], Dict[str, Any]] = None # input format/filtering function
+ compute_sqnr_fn: Callable = None # method used for computing SQNR between FP and sim o/p.
+ out_mapper_fn: Callable = None # method used to map Tensor/Tuple(Tensor,) to FP model o/p.
+ adaround_num_iterations: int = None # number of adaround interation, 0 is to reuse, None is to disable.
+ input_names: Optional[List[str]] = None
+ output_names: Optional[List[str]] = None
+ intermediate_modules_names: Optional[List[str]] = None
+ checkpoints_config_file: Optional[str] = None # path for JSON file to provide checkpoints config
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/activations.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/activations.py
new file mode 100644
index 000000000..86b3bde15
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/activations.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+""" This file contains model adaptations for activations. These adaptations are being done to optimize the model execution on the HTP backend. https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/activations.py"""
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import deprecate
+
+
+ACTIVATION_FUNCTIONS = {
+ "swish": nn.SiLU(),
+ "silu": nn.SiLU(),
+ "mish": nn.Mish(),
+ "gelu": nn.GELU(),
+ "relu": nn.ReLU(),
+}
+
+
+def get_activation(act_fn: str) -> nn.Module:
+ """Helper function to get activation function from string.
+
+ Args:
+ act_fn (str): Name of activation function.
+
+ Returns:
+ nn.Module: Activation function.
+ """
+
+ act_fn = act_fn.lower()
+ if act_fn in ACTIVATION_FUNCTIONS:
+ return ACTIVATION_FUNCTIONS[act_fn]
+ else:
+ raise ValueError(f"Unsupported activation function: {act_fn}")
+
+
+class GELU(nn.Module):
+ r"""
+ GELU activation function with tanh approximation support with `approximate="tanh"`.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
+ super().__init__()
+ #self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+ self.proj = nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias) # Qualcomm
+ self.approximate = approximate
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ if gate.device.type != "mps":
+ return F.gelu(gate, approximate=self.approximate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+ return hidden_states
+
+
+class GEGLU(nn.Module):
+ r"""
+ A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+ super().__init__()
+ #self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+ self.proj = nn.Conv2d(dim_in, dim_out * 2, kernel_size=1, bias=bias) # Qualcomm
+
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+ if gate.device.type != "mps":
+ return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states, *args, **kwargs):
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=1) # Qualcomm change from dim=-1 which is C on Linear, to dim=1 which is C on Conv2d
+ return hidden_states * self.gelu(gate)
+
+
+class ApproximateGELU(nn.Module):
+ r"""
+ The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this
+ [paper](https://arxiv.org/abs/1606.08415).
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
+ super().__init__()
+ #self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+ self.proj = nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias) # Qualcomm
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.proj(x)
+ return x * torch.sigmoid(1.702 * x)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/attention.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/attention.py
new file mode 100644
index 000000000..00b75c92f
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/attention.py
@@ -0,0 +1,685 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+""" This file provides adaptations to attention. These adaptations are being done to optimize the model execution on the HTP backend. https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py"""
+
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import deprecate, logging
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from genai_lib.lvm.dev.diffusers.activations import GEGLU, GELU, ApproximateGELU
+from diffusers.models.attention_processor import Attention
+from diffusers.models.embeddings import SinusoidalPositionalEmbedding
+from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
+
+
+logger = logging.get_logger(__name__)
+
+
+def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
+ # "feed_forward_chunk_size" can be used to save memory
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
+ ff_output = torch.cat(
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+ dim=chunk_dim,
+ )
+ return ff_output
+
+
+@maybe_allow_in_graph
+class GatedSelfAttentionDense(nn.Module):
+ r"""
+ A gated self-attention dense layer that combines visual features and object features.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ context_dim (`int`): The number of channels in the context.
+ n_heads (`int`): The number of heads to use for attention.
+ d_head (`int`): The number of channels in each head.
+ """
+
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
+ super().__init__()
+
+ # we need a linear projection since we need cat visual feature and obj feature
+ self.linear = nn.Linear(context_dim, query_dim)
+
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+ self.norm1 = nn.LayerNorm(query_dim)
+ self.norm2 = nn.LayerNorm(query_dim)
+
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+ self.enabled = True
+
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
+ if not self.enabled:
+ return x
+
+ n_visual = x.shape[1]
+ objs = self.linear(objs)
+
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+ return x
+
+
+@maybe_allow_in_graph
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm (:
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
+ attention_bias (:
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
+ only_cross_attention (`bool`, *optional*):
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
+ double_self_attention (`bool`, *optional*):
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
+ upcast_attention (`bool`, *optional*):
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
+ Whether to use learnable elementwise affine parameters for normalization.
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
+ final_dropout (`bool` *optional*, defaults to False):
+ Whether to apply a final dropout after the last feed-forward layer.
+ attention_type (`str`, *optional*, defaults to `"default"`):
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+ positional_embeddings (`str`, *optional*, defaults to `None`):
+ The type of positional embeddings to apply to.
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
+ The maximum number of positional embeddings to apply.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_elementwise_affine: bool = True,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
+ norm_eps: float = 1e-5,
+ final_dropout: bool = False,
+ attention_type: str = "default",
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
+ ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
+ ada_norm_bias: Optional[int] = None,
+ ff_inner_dim: Optional[int] = None,
+ ff_bias: bool = True,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+
+ # We keep these boolean flags for backward-compatibility.
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
+ self.use_layer_norm = norm_type == "layer_norm"
+ self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
+
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
+ )
+
+ self.norm_type = norm_type
+ self.num_embeds_ada_norm = num_embeds_ada_norm
+
+ if positional_embeddings and (num_positional_embeddings is None):
+ raise ValueError(
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
+ )
+
+ if positional_embeddings == "sinusoidal":
+ self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+ else:
+ self.pos_embed = None
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ if norm_type == "ada_norm":
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif norm_type == "ada_norm_zero":
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
+ elif norm_type == "ada_norm_continuous":
+ self.norm1 = AdaLayerNormContinuous(
+ dim,
+ ada_norm_continous_conditioning_embedding_dim,
+ norm_elementwise_affine,
+ norm_eps,
+ ada_norm_bias,
+ "rms_norm",
+ )
+ else:
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ out_bias=attention_out_bias,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None or double_self_attention:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ if norm_type == "ada_norm":
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
+ elif norm_type == "ada_norm_continuous":
+ self.norm2 = AdaLayerNormContinuous(
+ dim,
+ ada_norm_continous_conditioning_embedding_dim,
+ norm_elementwise_affine,
+ norm_eps,
+ ada_norm_bias,
+ "rms_norm",
+ )
+ else:
+ self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ out_bias=attention_out_bias,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ if norm_type == "ada_norm_continuous":
+ self.norm3 = AdaLayerNormContinuous(
+ dim,
+ ada_norm_continous_conditioning_embedding_dim,
+ norm_elementwise_affine,
+ norm_eps,
+ ada_norm_bias,
+ "layer_norm",
+ )
+
+ elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]:
+ self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
+ elif norm_type == "layer_norm_i2vgen":
+ self.norm3 = None
+
+ self.ff = FeedForward(
+ dim,
+ dropout=dropout,
+ activation_fn=activation_fn,
+ final_dropout=final_dropout,
+ inner_dim=ff_inner_dim,
+ bias=ff_bias,
+ )
+
+ # 4. Fuser
+ if attention_type == "gated" or attention_type == "gated-text-image":
+ self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
+
+ # 5. Scale-shift for PixArt-Alpha.
+ if norm_type == "ada_norm_single":
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = 0
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ self._chunk_dim = dim
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ ) -> torch.FloatTensor:
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
+ if self.norm_type == "ada_norm":
+ norm_hidden_states = self.norm1(hidden_states, timestep)
+ elif self.norm_type == "ada_norm_zero":
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
+ # norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = self.norm1(hidden_states.permute(0,2,3,1)).permute(0,3,1,2) # Qualcomm change for 4d inputs
+ elif self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ elif self.norm_type == "ada_norm_single":
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ norm_hidden_states = norm_hidden_states.squeeze(1)
+ else:
+ raise ValueError("Incorrect norm used")
+
+ if self.pos_embed is not None:
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ # 1. Prepare GLIGEN inputs
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ if self.norm_type == "ada_norm_zero":
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ elif self.norm_type == "ada_norm_single":
+ attn_output = gate_msa * attn_output
+
+ hidden_states = attn_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ # 1.2 GLIGEN Control
+ if gligen_kwargs is not None:
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+
+ # 3. Cross-Attention
+ if self.attn2 is not None:
+ if self.norm_type == "ada_norm":
+ norm_hidden_states = self.norm2(hidden_states, timestep)
+ elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
+ # norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = self.norm2(hidden_states.permute(0,2,3,1)).permute(0,3,1,2) # Qualcomm change for 4d inputs
+ elif self.norm_type == "ada_norm_single":
+ # For PixArt norm2 isn't applied here:
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+ norm_hidden_states = hidden_states
+ elif self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ else:
+ raise ValueError("Incorrect norm")
+
+ if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ # i2vgen doesn't have this norm 🤷♂️
+ if self.norm_type == "ada_norm_continuous":
+ norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
+ elif not self.norm_type == "ada_norm_single":
+ # norm_hidden_states = self.norm3(hidden_states)
+ norm_hidden_states = self.norm3(hidden_states.permute(0,2,3,1)).permute(0,3,1,2) # Qualcomm change for 4d inputs
+ if self.norm_type == "ada_norm_zero":
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ if self.norm_type == "ada_norm_single":
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ if self._chunk_size is not None:
+ # "feed_forward_chunk_size" can be used to save memory
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ ### Qualcomm modification to handle expected 3d input ###
+ #B, C, H, W = norm_hidden_states.shape
+ #ff_output = self.ff(norm_hidden_states.reshape(B, C, H*W).permute(0,2,1)).permute(0,2,1).reshape(B, C, H, W) # Qualcomm change to handle 3d layer
+ ff_output = self.ff(norm_hidden_states)
+ ### end ###
+
+ if self.norm_type == "ada_norm_zero":
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+ elif self.norm_type == "ada_norm_single":
+ ff_output = gate_mlp * ff_output
+
+ hidden_states = ff_output + hidden_states
+ if hidden_states.ndim == 4:
+ hidden_states = hidden_states.squeeze(1)
+
+ return hidden_states
+
+
+@maybe_allow_in_graph
+class TemporalBasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block for video like data.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ time_mix_inner_dim (`int`): The number of channels for temporal attention.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ time_mix_inner_dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.is_res = dim == time_mix_inner_dim
+
+ self.norm_in = nn.LayerNorm(dim)
+
+ # Define 3 blocks. Each block has its own normalization layer.
+ # 1. Self-Attn
+ self.ff_in = FeedForward(
+ dim,
+ dim_out=time_mix_inner_dim,
+ activation_fn="geglu",
+ )
+
+ self.norm1 = nn.LayerNorm(time_mix_inner_dim)
+ self.attn1 = Attention(
+ query_dim=time_mix_inner_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ cross_attention_dim=None,
+ )
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None:
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+ # the second cross attention block.
+ self.norm2 = nn.LayerNorm(time_mix_inner_dim)
+ self.attn2 = Attention(
+ query_dim=time_mix_inner_dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.norm2 = None
+ self.attn2 = None
+
+ # 3. Feed-forward
+ self.norm3 = nn.LayerNorm(time_mix_inner_dim)
+ self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
+
+ # let chunk size default to None
+ self._chunk_size = None
+ self._chunk_dim = None
+
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
+ # Sets chunk feed-forward
+ self._chunk_size = chunk_size
+ # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
+ self._chunk_dim = 1
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ num_frames: int,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ # Notice that normalization is always applied before the real computation in the following blocks.
+ # 0. Self-Attention
+ batch_size = hidden_states.shape[0]
+
+ batch_frames, seq_length, channels = hidden_states.shape
+ batch_size = batch_frames // num_frames
+
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
+ hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
+
+ residual = hidden_states
+ hidden_states = self.norm_in(hidden_states)
+
+ if self._chunk_size is not None:
+ hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ hidden_states = self.ff_in(hidden_states)
+
+ if self.is_res:
+ hidden_states = hidden_states + residual
+
+ norm_hidden_states = self.norm1(hidden_states)
+ attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
+ hidden_states = attn_output + hidden_states
+
+ # 3. Cross-Attention
+ if self.attn2 is not None:
+ norm_hidden_states = self.norm2(hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ norm_hidden_states = self.norm3(hidden_states)
+
+ if self._chunk_size is not None:
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+ else:
+ ff_output = self.ff(norm_hidden_states)
+
+ if self.is_res:
+ hidden_states = ff_output + hidden_states
+ else:
+ hidden_states = ff_output
+
+ hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
+ hidden_states = hidden_states.permute(0, 2, 1, 3)
+ hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
+
+ return hidden_states
+
+
+class SkipFFTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ kv_input_dim: int,
+ kv_input_dim_proj_use_bias: bool,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ attention_out_bias: bool = True,
+ ):
+ super().__init__()
+ if kv_input_dim != dim:
+ self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
+ else:
+ self.kv_mapper = None
+
+ self.norm1 = RMSNorm(dim, 1e-06)
+
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim,
+ out_bias=attention_out_bias,
+ )
+
+ self.norm2 = RMSNorm(dim, 1e-06)
+
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ out_bias=attention_out_bias,
+ )
+
+ def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+
+ if self.kv_mapper is not None:
+ encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
+
+ norm_hidden_states = self.norm1(hidden_states)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ **cross_attention_kwargs,
+ )
+
+ hidden_states = attn_output + hidden_states
+
+ norm_hidden_states = self.norm2(hidden_states)
+
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ **cross_attention_kwargs,
+ )
+
+ hidden_states = attn_output + hidden_states
+
+ return hidden_states
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ final_dropout: bool = False,
+ inner_dim=None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ if inner_dim is None:
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim, bias=bias)
+ if activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ #self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
+ self.net.append(nn.Conv2d(inner_dim, dim_out, kernel_size=1, bias=bias))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/attention_processors.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/attention_processors.py
new file mode 100644
index 000000000..61800f2a2
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/attention_processors.py
@@ -0,0 +1,639 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+""" Custom attention processors for improving on device latency and accuracy of Diffusers-based models https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py"""
+
+from typing import Optional
+
+import torch
+from aimet_torch import elementwise_ops
+from diffusers.models.attention_processor import Attention
+
+from genai_lib.common.dev.model_adaptation.linear_to_conv import ConvInplaceLinear
+
+
+class AttnProcessorConvSHA:
+ """
+ Convolution SHA processor for performing attention-related computations.
+ Modified from source: https://github.com/huggingface/diffusers/blob/d7001400764acb8de5df343bbc4c54479c0e6ebe/src/diffusers/models/attention_processor.py#L710
+ """
+
+ def __call__(
+ self,
+ attn,
+ hidden_states: torch.FloatTensor, # 2, 4096, 640 / 2, 1024, 1280
+ encoder_hidden_states: Optional[torch.FloatTensor] = None, # 2, 77, 2048
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ head_dim = int(attn.inner_dim/attn.heads)
+ args = () #if USE_PEFT_BACKEND else (scale,)
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ q_weight = attn.to_q.weight.reshape(attn.to_q.weight.shape[:2])
+ k_weight = attn.to_k.weight.reshape(attn.to_k.weight.shape[:2])
+ v_weight = attn.to_v.weight.reshape(attn.to_v.weight.shape[:2])
+ proj_weight = attn.to_out[0].weight.reshape(attn.to_out[0].weight.shape[:2])
+ scale_sqrt = attn.scale**0.5
+
+ ### SHA ###
+ hidden_states = hidden_states.unsqueeze(-1).permute(0, 2, 3, 1) # -> (B, C, 1, emb_dim)
+ encoder_hidden_states = encoder_hidden_states.unsqueeze(-1).permute(0, 2, 3, 1) # -> (B, C, 1, emb_dim)
+ outs = []
+ for idx in range(attn.heads):
+ # (B, C, H, W) -> conv -> (B, head_dim, H, W)
+ _query = torch.nn.functional.conv2d(hidden_states,
+ weight=q_weight[idx*head_dim:(idx+1)*head_dim, :, None, None] * scale_sqrt,
+ bias=attn.to_q.bias[idx*head_dim:(idx+1)*head_dim] * scale_sqrt if attn.to_q.bias is not None else None,
+ ).permute(0, 2, 3, 1) # -> (B, H, W, head_dim)
+ _key_T = torch.nn.functional.conv2d(encoder_hidden_states,
+ weight=k_weight[idx*head_dim:(idx+1)*head_dim, :, None, None] * scale_sqrt,
+ bias=attn.to_k.bias[idx*head_dim:(idx+1)*head_dim] * scale_sqrt if attn.to_k.bias is not None else None,
+ ).reshape(batch_size, 1, head_dim, -1) # -> (B, 1, head_dim, H*W)
+ _value = torch.nn.functional.conv2d(encoder_hidden_states,
+ weight=v_weight[idx*head_dim:(idx+1)*head_dim, :, None, None],
+ bias=attn.to_v.bias[idx*head_dim:(idx+1)*head_dim] if attn.to_v.bias is not None else None,
+ ).permute(0, 2, 3, 1).reshape(batch_size, 1, -1, head_dim) # -> (B, 1, H*W, head_dim)
+
+ _attn = elementwise_ops.MatMul()(_query, _key_T) # -> (B, H, W, H*W)
+ _attn = torch.nn.functional.softmax(_attn, dim=-1) # -> (B, H, W, H*W)
+ _out = elementwise_ops.MatMul()(_attn, _value) # -> (B, H, W, head_dim)
+ outs.append(_out)
+
+ hidden_states = elementwise_ops.Concat(-1)(*outs) # -> (B, H, W, C)
+ ### END ###
+
+ # linear proj
+ hidden_states = hidden_states.permute(0, 3, 1, 2) # -> (B, C, H, W)
+ hidden_states = torch.nn.functional.conv2d(hidden_states,
+ weight=proj_weight[:, :, None, None],
+ bias=attn.to_out[0].bias if attn.to_out[0].bias is not None else None,
+ ) # -> (B, C, H, W)
+ hidden_states = hidden_states.permute(0, 2, 3, 1).squeeze(1) # -> (B, emb_dim, C)
+
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+ hidden_states = hidden_states / attn.rescale_output_factor
+ return hidden_states
+
+
+class AttnProcessorWithoutMask:
+ """
+ Linear MHA processor without Attention Mask for performing attention-related computations.
+ Modified from source: https://github.com/huggingface/diffusers/blob/d7001400764acb8de5df343bbc4c54479c0e6ebe/src/diffusers/models/attention_processor.py#L710
+ """
+
+ def __call__(
+ self,
+ attn,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+
+ args = () #if USE_PEFT_BACKEND else (scale,)
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ # attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ ### Without Mask ###
+ query = attn.to_q(hidden_states, *args)
+ key = attn.to_k(encoder_hidden_states, *args)
+ value = attn.to_v(encoder_hidden_states, *args)
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+ attention_scores = torch.bmm(query * attn.scale, key.transpose(-1, -2))
+ attention_probs = attention_scores.softmax(dim=-1)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+ ## END ###
+
+ ### MHA reference ###
+ # query = attn.to_q(hidden_states, *args)
+ # key = attn.to_k(encoder_hidden_states, *args)
+ # value = attn.to_v(encoder_hidden_states, *args)
+ # query = attn.head_to_batch_dim(query)
+ # key = attn.head_to_batch_dim(key)
+ # value = attn.head_to_batch_dim(value)
+ # attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ # hidden_states = torch.bmm(attention_probs, value)
+ # hidden_states = attn.batch_to_head_dim(hidden_states)
+ ## END ###
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states, *args)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+ hidden_states = hidden_states / attn.rescale_output_factor
+ return hidden_states
+
+
+class AttnProcessorVAE:
+ """
+ Convolution tiled-SHA processor for slicing VAE-attention computations.
+ Modified from source: https://github.com/huggingface/diffusers/blob/d7001400764acb8de5df343bbc4c54479c0e6ebe/src/diffusers/models/attention_processor.py#L710
+ """
+
+ def __call__(
+ self,
+ attn,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ inner_dim = 512,
+ heads = 1,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ head_dim = int(inner_dim/heads)
+ args = () #if USE_PEFT_BACKEND else (scale,)
+
+ batch_size, channel, height, width = hidden_states.shape
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states)#.transpose(1, 2))#.transpose(1, 2)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ q_weight = attn.to_q.weight.reshape(attn.to_q.weight.shape[:2])
+ k_weight = attn.to_k.weight.reshape(attn.to_k.weight.shape[:2])
+ v_weight = attn.to_v.weight.reshape(attn.to_v.weight.shape[:2])
+ proj_weight = attn.to_out[0].weight.reshape(attn.to_out[0].weight.shape[:2])
+ scale_sqrt = attn.scale**0.5
+
+ ### SHA ###
+ outs = []
+ for idx in range(heads):
+ # (B, C, H, W) -> conv -> (B, head_dim, H, W)
+ _query = torch.nn.functional.conv2d(hidden_states,
+ weight=q_weight[idx*head_dim:(idx+1)*head_dim, :, None, None] * scale_sqrt,
+ bias=attn.to_q.bias[idx*head_dim:(idx+1)*head_dim] * scale_sqrt if attn.to_q.bias is not None else None,
+ ).permute(0, 2, 3, 1) # -> (B, H, W, head_dim)
+ _key_T = torch.nn.functional.conv2d(encoder_hidden_states,
+ weight=k_weight[idx*head_dim:(idx+1)*head_dim, :, None, None] * scale_sqrt,
+ bias=attn.to_k.bias[idx*head_dim:(idx+1)*head_dim] * scale_sqrt if attn.to_k.bias is not None else None,
+ ).reshape(batch_size, 1, head_dim, -1) # -> (B, 1, head_dim, H*W)
+ _value = torch.nn.functional.conv2d(encoder_hidden_states,
+ weight=v_weight[idx*head_dim:(idx+1)*head_dim, :, None, None],
+ bias=attn.to_v.bias[idx*head_dim:(idx+1)*head_dim] if attn.to_v.bias is not None else None,
+ ).permute(0, 2, 3, 1).reshape(batch_size, 1, -1, head_dim) # -> (B, 1, H*W, head_dim)
+
+ _attn = elementwise_ops.MatMul()(_query, _key_T) # -> (B, H, W, H*W)
+ _attn = torch.nn.functional.softmax(_attn, dim=-1) # -> (B, H, W, H*W)
+ _out = elementwise_ops.MatMul()(_attn, _value) # -> (B, H, W, head_dim)
+ outs.append(_out)
+
+ hidden_states = elementwise_ops.Concat(-1)(*outs) # -> (B, H, W, C)
+ ### END ###
+
+ # linear proj
+ hidden_states = hidden_states.permute(0, 3, 1, 2) # -> (B, C, H, W)
+ hidden_states = torch.nn.functional.conv2d(hidden_states,
+ weight=proj_weight[:, :, None, None],
+ bias=attn.to_out[0].bias if attn.to_out[0].bias is not None else None,
+ ) # -> (B, C, H, W)
+
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states) # -> (B, C, H, W)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual # -> (B, C, H, W)
+ hidden_states = hidden_states / attn.rescale_output_factor
+ return hidden_states
+
+
+class AttnProcessorUnet4dSHA:
+ """
+ Convolution SHA processor for performing attention-related computations.
+ Modified from source: https://github.com/huggingface/diffusers/blob/d7001400764acb8de5df343bbc4c54479c0e6ebe/src/diffusers/models/attention_processor.py#L710
+ """
+
+ def __call__(
+ self,
+ attn,
+ hidden_states: torch.FloatTensor, # 2, 4096, 640 / 2, 1024, 1280
+ encoder_hidden_states: Optional[torch.FloatTensor] = None, # 2, 77, 2048
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ head_dim = int(attn.inner_dim/attn.heads)
+ args = () #if USE_PEFT_BACKEND else (scale,)
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ batch_size, channel, height, width = hidden_states.shape
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ if encoder_hidden_states.ndim == 3:
+ encoder_hidden_states = encoder_hidden_states.unsqueeze(-1).permute(0, 2, 3, 1) # B, emb_dim, C -> B, C, 1, emb_dim
+
+ q_weight = attn.to_q.weight.reshape(attn.to_q.weight.shape[:2])
+ k_weight = attn.to_k.weight.reshape(attn.to_k.weight.shape[:2])
+ v_weight = attn.to_v.weight.reshape(attn.to_v.weight.shape[:2])
+ proj_weight = attn.to_out[0].weight.reshape(attn.to_out[0].weight.shape[:2])
+ scale_sqrt = attn.scale**0.5
+
+ ### SHA ###
+ outs = []
+ for idx in range(attn.heads):
+ # (B, C, H, W) -> conv -> (B, head_dim, H, W)
+ _query = torch.nn.functional.conv2d(hidden_states,
+ weight=q_weight[idx*head_dim:(idx+1)*head_dim, :, None, None] * scale_sqrt,
+ bias=attn.to_q.bias[idx*head_dim:(idx+1)*head_dim] * scale_sqrt if attn.to_q.bias is not None else None,
+ ).permute(0, 2, 3, 1) # -> (B, H, W, head_dim)
+ _key_T = torch.nn.functional.conv2d(encoder_hidden_states,
+ weight=k_weight[idx*head_dim:(idx+1)*head_dim, :, None, None] * scale_sqrt,
+ bias=attn.to_k.bias[idx*head_dim:(idx+1)*head_dim] * scale_sqrt if attn.to_k.bias is not None else None,
+ ).reshape(batch_size, 1, head_dim, -1) # -> (B, 1, head_dim, H*W)
+ _value = torch.nn.functional.conv2d(encoder_hidden_states,
+ weight=v_weight[idx*head_dim:(idx+1)*head_dim, :, None, None],
+ bias=attn.to_v.bias[idx*head_dim:(idx+1)*head_dim] if attn.to_v.bias is not None else None,
+ ).permute(0, 2, 3, 1).reshape(batch_size, 1, -1, head_dim) # -> (B, 1, H*W, head_dim)
+
+ _attn = elementwise_ops.MatMul()(_query, _key_T) # -> (B, H, W, H*W)
+ _attn = torch.nn.functional.softmax(_attn, dim=-1) # -> (B, H, W, H*W)
+ _out = elementwise_ops.MatMul()(_attn, _value) # -> (B, H, W, head_dim)
+ outs.append(_out)
+
+ hidden_states = elementwise_ops.Concat(-1)(*outs) # -> (B, H, W, C)
+ ### END ###
+
+ # linear proj
+ hidden_states = hidden_states.permute(0, 3, 1, 2) # -> (B, C, H, W)
+ hidden_states = torch.nn.functional.conv2d(hidden_states,
+ weight=proj_weight[:, :, None, None],
+ bias=attn.to_out[0].bias if attn.to_out[0].bias is not None else None,
+ ) # -> (B, C, H, W)
+
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+ hidden_states = hidden_states / attn.rescale_output_factor
+ return hidden_states
+
+
+class AttnProcessorUnet4dMHA:
+ """
+ Convolution SHA processor for performing attention-related computations.
+ Modified from source: https://github.com/huggingface/diffusers/blob/d7001400764acb8de5df343bbc4c54479c0e6ebe/src/diffusers/models/attention_processor.py#L710
+ Forward with SHA but treat linears as convs
+ """
+ def __call__(
+ self,
+ attn,
+ hidden_states: torch.FloatTensor, # 2, 4096, 640 / 2, 1024, 1280
+ encoder_hidden_states: Optional[torch.FloatTensor] = None, # 2, 77, 2048
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ head_dim = int(attn.inner_dim/attn.heads)
+ input_dim = hidden_states.ndim
+
+ if attn.spatial_norm is not None: # None
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ if attn.group_norm is not None: # None
+ hidden_states = attn.group_norm(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ batch_size, channel, height, width = hidden_states.shape
+
+ if encoder_hidden_states.ndim == 3:
+ encoder_hidden_states = encoder_hidden_states.unsqueeze(-1).permute(0, 2, 3, 1) # (B, L, D) -> (B, D, 1, L)/(B, C, 1, emb_dim)
+
+ q_weight = attn.to_q.weight.reshape(attn.to_q.weight.shape[:2])
+ k_weight = attn.to_k.weight.reshape(attn.to_k.weight.shape[:2])
+ v_weight = attn.to_v.weight.reshape(attn.to_v.weight.shape[:2])
+ proj_weight = attn.to_out[0].weight.reshape(attn.to_out[0].weight.shape[:2])
+
+ ### MHA2SHA friendly MHA-Conv ###
+ # Assuming hidden_state shape = (B, C, H, W)
+ # Assuming encoder_hidden_states shape = (B, C, 1, emb_dim) or (B, C, H, W)
+ # Returning shape (to match sha after concat) = (B, H, W, C)
+
+ # All the ops in MHA will be replaced with sha in G2G, so inefficient ops will be cleaned up.
+ # They are doing in such way that G2G can capture pattern better.
+ # Requirement for a G2G friendly MHA-Conv:
+ # 1) G2G uses the first reshape's shape[-2] after q_proj to determine head num.
+ # Remember to have the reshape after query in [..., head_num, head_dim]
+ # 2) We have observed that ONNX will fold Mul in Matmul(Q, K) -> Mul -> Softmax to somewhere
+ # G2G do not have logic to search for scale in MatMul(Q*scale, K) pattern, so before the
+ # folding issue in Matmul(Q, K)*scale -> Softmax is clear, We should manually fold scale
+ # into query and key weights.
+ # e.g. Conv2d(x, q.weight) -> Conv2d(x, q.weight*scale), or
+ # Fold with sqrt: Conv2d(x, q.weight*sqrt(scale)) on both Q, and K for better numirical stability.
+ #
+ #
+ # ONNX mha2sha converter will cleanup permute -> reshape after qkv_matmul.
+
+ _query = torch.nn.functional.conv2d(hidden_states,
+ weight=q_weight[..., None, None],
+ bias=attn.to_q.bias,
+ ).permute(0, 2, 3, 1).reshape(batch_size, -1, attn.heads, head_dim) # (B, C, H, W) -> (B, H, W, head_dim) -> (B, H*W, head, head_dim)
+ _query = _query.permute(0, 2, 1, 3) # (B, H*W, head, head_dim) -> (B, head, H*W, head_dim)
+
+ _key = torch.nn.functional.conv2d(encoder_hidden_states,
+ weight=k_weight[..., None, None],
+ bias=attn.to_k.bias,
+ ).reshape(batch_size, attn.heads, head_dim, -1) # (B, D, 1, L) -> (B, head, d, L)
+ # (B, C, H, W) -> (B, head, d, H*W)
+
+ _value = torch.nn.functional.conv2d(encoder_hidden_states,
+ weight=v_weight[..., None, None],
+ bias=attn.to_v.bias if attn.to_v.bias is not None else None,
+ ).reshape(batch_size, attn.heads, head_dim, -1).permute(0, 1, 3, 2) # (B, D, 1, L) -> (B, head, d, L) -> (B, head, L, d)
+ # # (B, C, H, W) -> (B, head, d, H*W) -> (B, head, H*W, d)
+ _attn = elementwise_ops.MatMul()(_query * attn.scale, _key) # (B, head, L, L)/(B, head, H*W, H*W)
+ _attn = torch.nn.functional.softmax(_attn, dim=-1)
+ _out = elementwise_ops.MatMul()(_attn, _value) # (B, head, L, d)/(B, head, H*W, head_dim)
+
+ # These permute reshape will be cleaned up in MHA2SHA-onnx-converter
+ _out = _out.permute(0, 2, 1, 3) # (B, head, H*W, d) -> (B, H*W, head, d)
+ hidden_states = _out.reshape(batch_size, height, width, -1) # (B, H*W, head, d) -> (B, H, W, C)
+ ### END ###
+
+ # linear proj
+ hidden_states = hidden_states.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W)
+ hidden_states = torch.nn.functional.conv2d(hidden_states,
+ weight=proj_weight[..., None, None],
+ bias=attn.to_out[0].bias if attn.to_out[0].bias is not None else None) # -> (B, C, H, W)
+
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+ hidden_states = hidden_states / attn.rescale_output_factor
+ hidden_states = hidden_states.reshape(batch_size, channel, height, width)
+ return hidden_states
+
+
+class AttnProcessorLoRA4D:
+ """
+ LoRA-Compatible processor for performing attention-related computations.
+ Modified from source: https://github.com/huggingface/diffusers/blob/d7001400764acb8de5df343bbc4c54479c0e6ebe/src/diffusers/models/attention_processor.py#L710
+ """
+
+ def __call__(
+ self,
+ attn,
+ hidden_states: torch.FloatTensor, # 2, 4096, 640 / 2, 1024, 1280
+ encoder_hidden_states: Optional[torch.FloatTensor] = None, # 2, 77, 2048
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ head_dim = int(attn.inner_dim/attn.heads)
+
+ if attn.spatial_norm is not None: # None
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ if attn.group_norm is not None: # None
+ hidden_states = attn.group_norm(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ batch_size, channel, height, width = hidden_states.shape
+
+ if encoder_hidden_states.ndim == 3:
+ encoder_hidden_states = encoder_hidden_states.unsqueeze(-1).permute(0, 2, 3, 1) # (B, L, D) -> (B, D, 1, L)/(B, C, 1, emb_dim)
+
+ if isinstance(attn.to_q, (torch.nn.Linear, ConvInplaceLinear)): # Linear expects channel's last
+ _query = attn.to_q(hidden_states.permute(0, 2, 3, 1)).reshape(batch_size, -1, attn.heads, head_dim).permute(0, 2, 1, 3) # -> (B, head, H*W, head_dim)
+ else: # Conv expects channel's first
+ _query = attn.to_q(hidden_states).permute(0, 2, 3, 1).reshape(batch_size, -1, attn.heads, head_dim).permute(0, 2, 1, 3) # -> (B, head, H*W, head_dim)
+
+ if isinstance(attn.to_k, (torch.nn.Linear, ConvInplaceLinear)): # Linear expects channel's last
+ _key_T = attn.to_k(encoder_hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).reshape(batch_size, attn.heads, head_dim, -1) # -> (B, heads, head_dim, H*W)
+ else: # Conv expects channel's first
+ _key_T = attn.to_k(encoder_hidden_states).reshape(batch_size, attn.heads, head_dim, -1) # -> (B, heads, head_dim, H*W)
+
+ if isinstance(attn.to_v, (torch.nn.Linear, ConvInplaceLinear)): # Linear expects channel's last
+ _value = attn.to_v(encoder_hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).reshape(batch_size, attn.heads, head_dim, -1).permute(0, 1, 3, 2) # -> (B, heads, H*W, head_dim)
+ else: # Conv expects channel's first
+ _value = attn.to_v(encoder_hidden_states).reshape(batch_size, attn.heads, head_dim, -1).permute(0, 1, 3, 2) # -> (B, heads, H*W, head_dim)
+
+ _attn = elementwise_ops.MatMul()(_query * attn.scale, _key_T) # -> (B, heads, H*W, H*W)
+ _attn = torch.nn.functional.softmax(_attn, dim=-1) # -> (B, heads, H*W, H*W)
+ _out = elementwise_ops.MatMul()(_attn, _value) # -> (B, heads, H*W, head_dim)
+
+ # linear proj
+ _out = _out.permute(0, 2, 1, 3) # (B, heads, H*W, head_dim) -> (B, H*W, heads, head_dim)
+ hidden_states = _out.reshape(batch_size, height, width, -1) # (B, H*W, heads, head_dim) -> (B, H, W, C)
+ hidden_states = hidden_states.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W)
+
+ if isinstance(attn.to_out[0], (torch.nn.Linear, ConvInplaceLinear)): # Linear expects channel's last
+ hidden_states = attn.to_out[0](hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ else: # Conv expects channel's first
+ hidden_states = attn.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+ hidden_states = hidden_states / attn.rescale_output_factor
+ hidden_states = hidden_states.reshape(batch_size, channel, height, width)
+ return hidden_states
+
+class AttnProcessorUnet4dMhaNamed:
+ """
+ Convolution SHA processor for performing attention-related computations.
+ Modified from source: https://github.com/huggingface/diffusers/blob/d7001400764acb8de5df343bbc4c54479c0e6ebe/src/diffusers/models/attention_processor.py#L710
+ Forward with SHA but treat linears as convs
+ """
+ def __call__(
+ self,
+ attn,
+ hidden_states: torch.FloatTensor, # 2, 4096, 640 / 2, 1024, 1280
+ encoder_hidden_states: Optional[torch.FloatTensor] = None, # 2, 77, 2048
+ attention_mask: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ head_dim = int(attn.inner_dim/attn.heads)
+ input_dim = hidden_states.ndim
+
+ if attn.spatial_norm is not None: # None
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ if attn.group_norm is not None: # None
+ hidden_states = attn.group_norm(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ batch_size, channel, height, width = hidden_states.shape
+
+ if encoder_hidden_states.ndim == 3:
+ encoder_hidden_states = encoder_hidden_states.unsqueeze(-1).permute(0, 2, 3, 1) # (B, L, D) -> (B, D, 1, L)/(B, C, 1, emb_dim)
+
+ _query = attn.to_q(hidden_states).permute(0, 2, 3, 1).reshape(batch_size, -1, attn.heads, head_dim).permute(0, 2, 1, 3) # (B, C, H, W) -> (B, H, W, head_dim) -> (B, H*W, head, head_dim) -> (B, head, H*W, head_dim)
+ _key_T = attn.to_k(encoder_hidden_states).reshape(batch_size, attn.heads, head_dim, -1) # (B, C, H, W) -> (B, head, d, H*W)
+ _value = attn.to_v(encoder_hidden_states).reshape(batch_size, attn.heads, head_dim, -1).permute(0, 1, 3, 2) # (B, C, H, W) -> (B, head, d, H*W) -> (B, head, H*W, d)
+
+ _attn = elementwise_ops.MatMul()(_query * attn.scale, _key_T) # (B, head, L, L)/(B, head, H*W, H*W)
+ _attn = torch.nn.functional.softmax(_attn, dim=-1)
+ _out = elementwise_ops.MatMul()(_attn, _value) # (B, head, L, d)/(B, head, H*W, head_dim)
+
+ # These permute reshape will be cleaned up in MHA2SHA-onnx-converter
+ _out = _out.permute(0, 2, 1, 3) # (B, head, H*W, d) -> (B, H*W, head, d)
+ hidden_states = _out.reshape(batch_size, height, width, -1) # (B, H*W, head, d) -> (B, H, W, C)
+ ### END ###
+
+ # linear proj
+ hidden_states = hidden_states.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W)
+ hidden_states = attn.to_out[0](hidden_states) # -> (B, C, H, W)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+ hidden_states = hidden_states / attn.rescale_output_factor
+ hidden_states = hidden_states.reshape(batch_size, channel, height, width)
+ return hidden_states
+
+
+def replace_attention_linear_with_conv(module):
+ '''
+ Replace linear layers in attention with conv2d layers. This function will also align the module's qkvo layer
+ names with that in AttnProcessorLoRA4D for ease of loading base encodings. When this function is used to do
+ linear to conv adaptation in attention layers, AttnProcessorUnet4dMhaNamed should be used as attention processor.
+
+ :param module: The attention module in which linears are to be replaced
+ :return: None. Replacement are in-place
+ '''
+ q_weight = module.to_q.weight.reshape(module.to_q.weight.shape[:2])
+ q_bias = module.to_q.bias
+ module.to_q = torch.nn.Conv2d(q_weight.shape[1], q_weight.shape[0], kernel_size=1,
+ bias=module.to_q.bias is not None, device=q_weight.device)
+ module.to_q.weight.data = q_weight[:, :, None, None]
+ if module.to_q.bias is not None:
+ module.to_q.bias.data = q_bias
+
+ k_weight = module.to_k.weight.reshape(module.to_k.weight.shape[:2])
+ k_bias = module.to_k.bias
+ module.to_k = torch.nn.Conv2d(k_weight.shape[1], k_weight.shape[0], kernel_size=1,
+ bias=module.to_k.bias is not None, device=k_weight.device)
+ module.to_k.weight.data = k_weight[:, :, None, None]
+ if module.to_k.bias is not None:
+ module.to_k.bias.data = k_bias
+
+ v_weight = module.to_v.weight.reshape(module.to_v.weight.shape[:2])
+ v_bias = module.to_v.bias
+ module.to_v = torch.nn.Conv2d(v_weight.shape[1], v_weight.shape[0], kernel_size=1,
+ bias=module.to_v.bias is not None, device=v_weight.device)
+ module.to_v.weight.data = v_weight[:, :, None, None]
+ if module.to_v.bias is not None:
+ module.to_v.bias.data = v_bias
+
+ proj_weight = module.to_out[0].weight.reshape(module.to_out[0].weight.shape[:2])
+ proj_bias = module.to_out[0].bias
+ module.to_out[0] = torch.nn.Conv2d(proj_weight.shape[1], proj_weight.shape[0], kernel_size=1,
+ bias=module.to_out[0].bias is not None, device=proj_weight.device)
+ module.to_out[0].weight.data = proj_weight[:, :, None, None]
+ if module.to_out[0].bias is not None:
+ module.to_out[0].bias.data = proj_bias
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/remove_timestep/pipeline_stable_diffusion.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/remove_timestep/pipeline_stable_diffusion.py
new file mode 100644
index 000000000..0c6caba36
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/remove_timestep/pipeline_stable_diffusion.py
@@ -0,0 +1,1050 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""This file provides adaptations to the stable diffusion pipeline. These adaptations are being done to optimize the model execution on the HTP backend. https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py"""
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionPipeline
+
+ >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class StableDiffusionPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ LoraLoaderMixin,
+ IPAdapterMixin,
+ FromSingleFileMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ **kwargs,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ **kwargs,
+ )
+
+ # concatenate for backwards comp
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack(
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
+ )
+
+ if do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+ else:
+ repeat_dims = [1]
+ image_embeds = []
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+ )
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ else:
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ def decode_latents(self, latents):
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
+
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ timesteps (`torch.Tensor`):
+ generate embedding vectors at these timesteps
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
+ Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
+ if `do_classifier_free_guidance` is set to `True`.
+ If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+ using zero terminal SNR.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+ # to deal with lora scaling and other possible forward hooks
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 6.1 Add image embeds for IP-Adapter
+ added_cond_kwargs = (
+ {"image_embeds": image_embeds}
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None)
+ else None
+ )
+
+ # 6.2 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 7. Denoising loop
+ self.unet = self.unet.to(device)
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # Qualcomm change, pass precomputed t_emb instead of having Unet compute from timestep input
+ t_emb = self.unet.get_time_embed(sample=latent_model_input, timestep=t).to(self.unet.device)
+ emb = self.unet.time_embedding(t_emb, timestep_cond).to(self.unet.device)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ emb,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+ 0
+ ]
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/remove_timestep/pipeline_stable_diffusion_xl.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/remove_timestep/pipeline_stable_diffusion_xl.py
new file mode 100644
index 000000000..d50511207
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/remove_timestep/pipeline_stable_diffusion_xl.py
@@ -0,0 +1,1380 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""This file provides adaptations to the stable diffusion xl pipeline. These adaptations are being done to optimize the model execution on the HTP backend. https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py"""
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ImageProjection#, UNet2DConditionModel
+from .unet_2d_condition import UNet2DConditionModel #custom module supporting precomputing of timestep
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ FusedAttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_invisible_watermark_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
+
+if is_invisible_watermark_available():
+ from .watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLPipeline
+
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class StableDiffusionXLPipeline(
+ DiffusionPipeline,
+ FromSingleFileMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+ IPAdapterMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([` CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "image_encoder",
+ "feature_extractor",
+ ]
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ "negative_pooled_prompt_embeds",
+ "negative_add_time_ids",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ self.default_sample_size = self.unet.config.sample_size
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack(
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
+ )
+
+ if self.do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+ else:
+ image_embeds = ip_adapter_image_embeds
+ return image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def _get_add_time_ids(
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
+ ):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ FusedAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stages where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if not hasattr(self, "unet"):
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
+ def disable_freeu(self):
+ """Disables the FreeU mechanism if enabled."""
+ self.unet.disable_freeu()
+
+ def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ Args:
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+ """
+ self.fusing_unet = False
+ self.fusing_vae = False
+
+ if unet:
+ self.fusing_unet = True
+ self.unet.fuse_qkv_projections()
+ self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+ if vae:
+ if not isinstance(self.vae, AutoencoderKL):
+ raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+ self.fusing_vae = True
+ self.vae.fuse_qkv_projections()
+ self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+ def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+ """Disable QKV projection fusion if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ Args:
+ unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+ vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+
+ """
+ if unet:
+ if not self.fusing_unet:
+ logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.unet.unfuse_qkv_projections()
+ self.fusing_unet = False
+
+ if vae:
+ if not self.fusing_vae:
+ logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+ else:
+ self.vae.unfuse_qkv_projections()
+ self.fusing_vae = False
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ timesteps (`torch.Tensor`):
+ generate embedding vectors at these timesteps
+ embedding_dim (`int`, *optional*, defaults to 512):
+ dimension of the embeddings to generate
+ dtype:
+ data type of the generated embeddings
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def denoising_end(self):
+ return self._denoising_end
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. It should be as same
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 8.1 Apply denoising_end
+ if (
+ self.denoising_end is not None
+ and isinstance(self.denoising_end, float)
+ and self.denoising_end > 0
+ and self.denoising_end < 1
+ ):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ # 9. Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self.unet = self.unet.to(device=device)
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds.to(self.unet.device)}#, "time_ids": add_time_ids}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+
+ # Qualcomm change, pass precomputed t_emb instead of having Unet compute from timestep input
+ t_emb = self.unet.get_time_embed(sample=latent_model_input.to(self.dtype), timestep=t).to(self.device)
+ t_emb = t_emb.to(self.dtype)
+ timestep_cond = timestep_cond.to(self.dtype) if timestep_cond is not None else None
+ emb = self.unet.time_embedding(t_emb, timestep_cond).to(self.unet.device)
+
+ # Qualcomm change, pass precomputed time_embeds instead of having Unet compute from timestep input
+ time_embeds = self.unet.add_time_proj(add_time_ids.flatten().to(self.unet.device).to(self.dtype))
+ time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
+ time_embeds = time_embeds.to(self.dtype)
+ self.unet.time_embeds = time_embeds
+ #added_cond_kwargs.update({'time_embeds': time_embeds})
+
+ noise_pred = self.unet(
+ latent_model_input,
+ emb,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=None, # Qualcomm change: `emb`was already precomputed with `timestep_cond`, so pass None as input here
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.decoder.parameters())).dtype)
+
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image.sample if not torch.is_tensor(image) else image, output_type=output_type) #changed
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/remove_timestep/unet_2d_condition.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/remove_timestep/unet_2d_condition.py
new file mode 100644
index 000000000..3d16fce48
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/remove_timestep/unet_2d_condition.py
@@ -0,0 +1,1333 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""This file provides adaptations to unet. These adaptations are being done to optimize the model execution on the HTP backend. https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_condition.py"""
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ Attention,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.embeddings import (
+ GaussianFourierProjection,
+ GLIGENTextBoundingboxProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unets.unet_2d_blocks import (
+ get_down_block,
+ get_mid_block,
+ get_up_block,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """
+ The output of [`UNet2DConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ If `None`, normalization and activation layers is skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
+ Dimension for the timestep embeddings.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
+ An optional override for the dimension of the projected time embedding.
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of `cond_proj` layer in the timestep embedding.
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+ embeddings with the class embeddings.
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
+ otherwise.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ dropout: float = 0.0,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ attention_type: str = "default",
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads: int = 64,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ self._check_config(
+ down_block_types=down_block_types,
+ up_block_types=up_block_types,
+ only_cross_attention=only_cross_attention,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
+ attention_head_dim=attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ )
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
+ time_embedding_type,
+ block_out_channels=block_out_channels,
+ flip_sin_to_cos=flip_sin_to_cos,
+ freq_shift=freq_shift,
+ time_embedding_dim=time_embedding_dim,
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ self._set_encoder_hid_proj(
+ encoder_hid_dim_type,
+ cross_attention_dim=cross_attention_dim,
+ encoder_hid_dim=encoder_hid_dim,
+ )
+
+ # class embedding
+ self._set_class_embedding(
+ class_embed_type,
+ act_fn=act_fn,
+ num_class_embeds=num_class_embeds,
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+ time_embed_dim=time_embed_dim,
+ timestep_input_dim=timestep_input_dim,
+ )
+
+ self._set_add_embedding(
+ addition_embed_type,
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
+ addition_time_embed_dim=addition_time_embed_dim,
+ cross_attention_dim=cross_attention_dim,
+ encoder_hid_dim=encoder_hid_dim,
+ flip_sin_to_cos=flip_sin_to_cos,
+ freq_shift=freq_shift,
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+ time_embed_dim=time_embed_dim,
+ )
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = get_mid_block(
+ mid_block_type,
+ temb_channels=blocks_time_embed_dim,
+ in_channels=block_out_channels[-1],
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ output_scale_factor=mid_block_scale_factor,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ num_attention_heads=num_attention_heads[-1],
+ cross_attention_dim=cross_attention_dim[-1],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[-1],
+ dropout=dropout,
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = (
+ list(reversed(transformer_layers_per_block))
+ if reverse_transformer_layers_per_block is None
+ else reverse_transformer_layers_per_block
+ )
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resolution_idx=i,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+
+ self.conv_act = get_activation(act_fn)
+
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+
+ self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
+
+ def _check_config(
+ self,
+ down_block_types: Tuple[str],
+ up_block_types: Tuple[str],
+ only_cross_attention: Union[bool, Tuple[bool]],
+ block_out_channels: Tuple[int],
+ layers_per_block: Union[int, Tuple[int]],
+ cross_attention_dim: Union[int, Tuple[int]],
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
+ reverse_transformer_layers_per_block: bool,
+ attention_head_dim: int,
+ num_attention_heads: Optional[Union[int, Tuple[int]]],
+ ):
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+ for layer_number_per_block in transformer_layers_per_block:
+ if isinstance(layer_number_per_block, list):
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
+
+ def _set_time_proj(
+ self,
+ time_embedding_type: str,
+ block_out_channels: int,
+ flip_sin_to_cos: bool,
+ freq_shift: float,
+ time_embedding_dim: int,
+ ) -> Tuple[int, int]:
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ return time_embed_dim, timestep_input_dim
+
+ def _set_encoder_hid_proj(
+ self,
+ encoder_hid_dim_type: Optional[str],
+ cross_attention_dim: Union[int, Tuple[int]],
+ encoder_hid_dim: Optional[int],
+ ):
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ def _set_class_embedding(
+ self,
+ class_embed_type: Optional[str],
+ act_fn: str,
+ num_class_embeds: Optional[int],
+ projection_class_embeddings_input_dim: Optional[int],
+ time_embed_dim: int,
+ timestep_input_dim: int,
+ ):
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ def _set_add_embedding(
+ self,
+ addition_embed_type: str,
+ addition_embed_type_num_heads: int,
+ addition_time_embed_dim: Optional[int],
+ flip_sin_to_cos: bool,
+ freq_shift: float,
+ cross_attention_dim: Optional[int],
+ encoder_hid_dim: Optional[int],
+ projection_class_embeddings_input_dim: Optional[int],
+ time_embed_dim: int,
+ ):
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
+ if attention_type in ["gated", "gated-text-image"]:
+ positive_len = 768
+ if isinstance(cross_attention_dim, int):
+ positive_len = cross_attention_dim
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+ positive_len = cross_attention_dim[0]
+
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
+ self.position_net = GLIGENTextBoundingboxProjection(
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
+ )
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ setattr(upsample_block, k, None)
+
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def unload_lora(self):
+ """Unloads LoRA weights."""
+ deprecate(
+ "unload_lora",
+ "0.28.0",
+ "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
+ )
+ for module in self.modules():
+ if hasattr(module, "set_lora_layer"):
+ module.set_lora_layer(None)
+
+ def get_time_embed(
+ self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
+ ) -> Optional[torch.Tensor]:
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+ return t_emb
+
+ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ class_emb = None
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+ return class_emb
+
+ def get_aug_embed(
+ self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+ ) -> Optional[torch.Tensor]:
+ aug_emb = None
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ # SDXL - style
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds").to(self.device)
+ # Qualcomm change to precompute time_embeds from time_ids, instead of passing time_ids and having Unet compute the embeddings
+ # if "time_ids" not in added_cond_kwargs:
+ # raise ValueError(
+ # f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ # )
+ # time_ids = added_cond_kwargs.get("time_ids")
+ # time_embeds = self.add_time_proj(time_ids.flatten())
+ # time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+ # time_embeds = added_cond_kwargs.get("time_embeds") # Qualcomm: get precomputed time_embeds
+ time_embeds = self.time_embeds[:text_embeds.shape[0]].to(self.device)
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb = self.add_embedding(image_embs, hint)
+ return aug_emb
+
+ def process_encoder_hidden_states(
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+ ) -> torch.Tensor:
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ image_embeds = self.encoder_hid_proj(image_embeds)
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
+ return encoder_hidden_states
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ emb: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
+ A tensor that if specified is added to the residual of the middle unet block.
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
+ encoder_attention_mask (`torch.Tensor`):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is the sample tensor.
+ """
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ for dim in sample.shape[-2:]:
+ if dim % default_overall_up_factor != 0:
+ # Forward upsample size to force interpolation output size.
+ forward_upsample_size = True
+ break
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ # t_emb = self.get_time_embed(sample=sample, timestep=timestep) # Qualcomm change, pass precomputed t_emb instead of having Unet compute from timestep input
+ # emb = self.time_embedding(t_emb, timestep_cond) # Comes precomputed from the pipeline
+ if timestep_cond is not None:
+ raise ValueError("This custom model with precomputed timestep embeddings does not use the `timestep_cond` input, expected None. \
+ The pipeline class should use `timestep_cond` to precompute embeddings and harcode them into Unet, as a latency optimization")
+
+ aug_emb = None
+
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
+ if class_emb is not None:
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ aug_emb = self.get_aug_embed(
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+ if self.config.addition_embed_type == "image_hint":
+ aug_emb, hint = aug_emb
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ encoder_hidden_states = self.process_encoder_hidden_states(
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 2.5 GLIGEN position net
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ gligen_args = cross_attention_kwargs.pop("gligen")
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+ # 3. down
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
+ if cross_attention_kwargs is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+ is_adapter = down_intrablock_additional_residuals is not None
+ # maintain backward compatibility for legacy usage, where
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+ # but can only use one or the other
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+ deprecate(
+ "T2I should not use down_block_additional_residuals",
+ "1.3.0",
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
+ standard_warn=False,
+ )
+ down_intrablock_additional_residuals = down_block_additional_residuals
+ is_adapter = True
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # To support T2I-Adapter-XL
+ if (
+ is_adapter
+ and len(down_intrablock_additional_residuals) > 0
+ and sample.shape == down_intrablock_additional_residuals[0].shape
+ ):
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/transformer_2d.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/transformer_2d.py
new file mode 100644
index 000000000..c7aca82ed
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/transformer_2d.py
@@ -0,0 +1,545 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+""" This file provides adaptations to transformers 2d. These adaptations are being done to optimize the model execution on the HTP backend. https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_2d.py"""
+
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, deprecate, is_torch_version, logging
+from genai_lib.lvm.dev.diffusers.attention import BasicTransformerBlock
+from diffusers.models.embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNormSingle
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class Transformer2DModelOutput(BaseOutput):
+ """
+ The output of [`Transformer2DModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+ distributions for the unnoised latent pixels.
+ """
+
+ sample: torch.FloatTensor
+
+
+class Transformer2DModel(ModelMixin, ConfigMixin):
+ """
+ A 2D Transformer model for image-like data.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input and output (specify if the input is **continuous**).
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+ This is fixed during training since it is used to learn a number of position embeddings.
+ num_vector_embeds (`int`, *optional*):
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*):
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+ added to the hidden states.
+
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
+ attention_bias (`bool`, *optional*):
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ patch_size: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ double_self_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
+ norm_elementwise_affine: bool = True,
+ norm_eps: float = 1e-5,
+ attention_type: str = "default",
+ caption_channels: int = None,
+ interpolation_scale: float = None,
+ use_additional_conditions = None, #for supporting Diffusers > 0.27
+ ):
+ super().__init__()
+
+ # Validate inputs.
+ if patch_size is not None:
+ if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
+ raise NotImplementedError(
+ f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
+ )
+ elif norm_type in ["ada_norm", "ada_norm_zero"] and num_embeds_ada_norm is None:
+ raise ValueError(
+ f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
+ )
+
+ # Set some common variables used across the board.
+ self.use_linear_projection = False # Qualcomm modification, force model to use conv projection
+ self.interpolation_scale = interpolation_scale
+ self.caption_channels = caption_channels
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.gradient_checkpointing = False
+
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+ # Define whether input is continuous or discrete depending on configuration
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
+ self.is_input_vectorized = num_vector_embeds is not None
+ self.is_input_patches = in_channels is not None and patch_size is not None
+
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
+ deprecation_message = (
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
+ " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
+ )
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
+ norm_type = "ada_norm"
+
+ if self.is_input_continuous and self.is_input_vectorized:
+ raise ValueError(
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
+ " sure that either `in_channels` or `num_vector_embeds` is None."
+ )
+ elif self.is_input_vectorized and self.is_input_patches:
+ raise ValueError(
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
+ " sure that either `num_vector_embeds` or `num_patches` is None."
+ )
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
+ raise ValueError(
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
+ )
+
+ # 2. Initialize the right blocks.
+ # These functions follow a common structure:
+ # a. Initialize the input blocks. b. Initialize the transformer blocks.
+ # c. Initialize the output blocks and other projection blocks when necessary.
+ if self.is_input_continuous:
+ self._init_continuous_input(norm_type=norm_type)
+ elif self.is_input_vectorized:
+ self._init_vectorized_inputs(norm_type=norm_type)
+ elif self.is_input_patches:
+ self._init_patched_inputs(norm_type=norm_type)
+
+ def _init_continuous_input(self, norm_type):
+ self.norm = torch.nn.GroupNorm(
+ num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True
+ )
+ if self.use_linear_projection:
+ self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim)
+ else:
+ self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ self.inner_dim,
+ self.config.num_attention_heads,
+ self.config.attention_head_dim,
+ dropout=self.config.dropout,
+ cross_attention_dim=self.config.cross_attention_dim,
+ activation_fn=self.config.activation_fn,
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+ attention_bias=self.config.attention_bias,
+ only_cross_attention=self.config.only_cross_attention,
+ double_self_attention=self.config.double_self_attention,
+ upcast_attention=self.config.upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
+ norm_eps=self.config.norm_eps,
+ attention_type=self.config.attention_type,
+ )
+ for _ in range(self.config.num_layers)
+ ]
+ )
+
+ if self.use_linear_projection:
+ self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels)
+ else:
+ self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0)
+
+ def _init_vectorized_inputs(self, norm_type):
+ assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+ assert (
+ self.config.num_vector_embeds is not None
+ ), "Transformer2DModel over discrete input must provide num_embed"
+
+ self.height = self.config.sample_size
+ self.width = self.config.sample_size
+ self.num_latent_pixels = self.height * self.width
+
+ self.latent_image_embedding = ImagePositionalEmbeddings(
+ num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width
+ )
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ self.inner_dim,
+ self.config.num_attention_heads,
+ self.config.attention_head_dim,
+ dropout=self.config.dropout,
+ cross_attention_dim=self.config.cross_attention_dim,
+ activation_fn=self.config.activation_fn,
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+ attention_bias=self.config.attention_bias,
+ only_cross_attention=self.config.only_cross_attention,
+ double_self_attention=self.config.double_self_attention,
+ upcast_attention=self.config.upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
+ norm_eps=self.config.norm_eps,
+ attention_type=self.config.attention_type,
+ )
+ for _ in range(self.config.num_layers)
+ ]
+ )
+
+ self.norm_out = nn.LayerNorm(self.inner_dim)
+ self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1)
+
+ def _init_patched_inputs(self, norm_type):
+ assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+ self.height = self.config.sample_size
+ self.width = self.config.sample_size
+
+ self.patch_size = self.config.patch_size
+ interpolation_scale = (
+ self.config.interpolation_scale
+ if self.config.interpolation_scale is not None
+ else max(self.config.sample_size // 64, 1)
+ )
+ self.pos_embed = PatchEmbed(
+ height=self.config.sample_size,
+ width=self.config.sample_size,
+ patch_size=self.config.patch_size,
+ in_channels=self.in_channels,
+ embed_dim=self.inner_dim,
+ interpolation_scale=interpolation_scale,
+ )
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ self.inner_dim,
+ self.config.num_attention_heads,
+ self.config.attention_head_dim,
+ dropout=self.config.dropout,
+ cross_attention_dim=self.config.cross_attention_dim,
+ activation_fn=self.config.activation_fn,
+ num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+ attention_bias=self.config.attention_bias,
+ only_cross_attention=self.config.only_cross_attention,
+ double_self_attention=self.config.double_self_attention,
+ upcast_attention=self.config.upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=self.config.norm_elementwise_affine,
+ norm_eps=self.config.norm_eps,
+ attention_type=self.config.attention_type,
+ )
+ for _ in range(self.config.num_layers)
+ ]
+ )
+
+ if self.config.norm_type != "ada_norm_single":
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
+ self.proj_out_2 = nn.Linear(
+ self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+ )
+ elif self.config.norm_type == "ada_norm_single":
+ self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+ self.proj_out = nn.Linear(
+ self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+ )
+
+ # PixArt-Alpha blocks.
+ self.adaln_single = None
+ self.use_additional_conditions = False
+ if self.config.norm_type == "ada_norm_single":
+ self.use_additional_conditions = self.config.sample_size == 128
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
+ # additional conditions until we find better name
+ self.adaln_single = AdaLayerNormSingle(
+ self.inner_dim, use_additional_conditions=self.use_additional_conditions
+ )
+
+ self.caption_projection = None
+ if self.caption_channels is not None:
+ self.caption_projection = PixArtAlphaTextProjection(
+ in_features=self.caption_channels, hidden_size=self.inner_dim
+ )
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ):
+ """
+ The [`Transformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+ Input `hidden_states`.
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.LongTensor`, *optional*):
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+ `AdaLayerZeroNorm`.
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ attention_mask ( `torch.Tensor`, *optional*):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+ above. This bias will be added to the cross-attention scores.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 1. Input
+ if self.is_input_continuous:
+ batch, _, height, width = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ ### Qualcomm modifications: Keep data flowing as 4D for better on-device attention computation ###
+ if not self.use_linear_projection:
+ hidden_states = self.proj_in(hidden_states)
+ # inner_dim = hidden_states.shape[1]
+ # hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+ else:
+ raise Exception('AIMET/QNN expect self.use_linear_projection = False')
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+ ### Qualcomm modifications: end ###
+
+ elif self.is_input_vectorized:
+ hidden_states = self.latent_image_embedding(hidden_states)
+ elif self.is_input_patches:
+ height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
+ hidden_states = self.pos_embed(hidden_states)
+
+ if self.adaln_single is not None:
+ if self.use_additional_conditions and added_cond_kwargs is None:
+ raise ValueError(
+ "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
+ )
+ batch_size = hidden_states.shape[0]
+ timestep, embedded_timestep = self.adaln_single(
+ timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
+ # 2. Blocks
+ if self.is_input_patches and self.caption_projection is not None:
+ batch_size = hidden_states.shape[0]
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+ for block in self.transformer_blocks:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ timestep,
+ cross_attention_kwargs,
+ class_labels,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = block(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ )
+
+ # 3. Output
+ if self.is_input_continuous:
+ ### Qualcomm modifications: Keep data flowing as 4D for better on-device attention computation ###
+ if not self.use_linear_projection:
+ # hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+ hidden_states = self.proj_out(hidden_states)
+ else:
+ raise Exception('AIMET/QNN expect self.use_linear_projection = False')
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+ ### Qualcomm modifications: end ###
+
+ output = hidden_states + residual
+ elif self.is_input_vectorized:
+ hidden_states = self.norm_out(hidden_states)
+ logits = self.out(hidden_states)
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+ logits = logits.permute(0, 2, 1)
+
+ # log(p(x_0))
+ output = F.log_softmax(logits.double(), dim=1).float()
+
+ if self.is_input_patches:
+ if self.config.norm_type != "ada_norm_single":
+ conditioning = self.transformer_blocks[0].norm1.emb(
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+ hidden_states = self.proj_out_2(hidden_states)
+ elif self.config.norm_type == "ada_norm_single":
+ shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states)
+ # Modulation
+ hidden_states = hidden_states * (1 + scale) + shift
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.squeeze(1)
+
+ # unpatchify
+ if self.adaln_single is None:
+ height = width = int(hidden_states.shape[1] ** 0.5)
+ hidden_states = hidden_states.reshape(
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+ )
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/vae_decoder_merged.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/vae_decoder_merged.py
new file mode 100644
index 000000000..f21113796
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/diffusers/vae_decoder_merged.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+import torch
+from diffusers.models.autoencoders.vae import DecoderOutput
+
+class VAEDecoderMerged(torch.nn.Module):
+ '''
+ Unify decoding steps from vae.post_quant_conv and vae.decoder
+ into a single ONNX graph
+ '''
+ def __init__(self, post_quant_conv, decoder):
+ super().__init__()
+ self.post_quant_conv = post_quant_conv
+ self.decoder = decoder
+ dtype = next(iter(decoder.parameters())).dtype
+ self.to(dtype)
+
+ def __getattr__(self, attr):
+ if attr in set(self._modules.keys()):
+ return self._modules[attr]
+ return getattr(self._modules['decoder'], attr)
+
+ def forward(self, z: torch.FloatTensor, return_dict: bool = True):
+ '''
+ Adapted from https://github.com/huggingface/diffusers/blob/v0.27.2/src/diffusers/models/autoencoders/autoencoder_kl.py#L270-L280
+ '''
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/transformer/clip_attention.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/transformer/clip_attention.py
new file mode 100644
index 000000000..5f4ce4d8c
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/transformer/clip_attention.py
@@ -0,0 +1,335 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. Not a Contribution
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+# =============================================================================
+# coding=utf-8
+# Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+""" This file provides adaptations to clip attention. These adaptations are being done to optimize the model execution on the HTP backend. https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py"""
+
+import torch
+from torch import nn
+from aimet_torch import elementwise_ops
+from typing import Optional, Tuple
+
+
+class CLIPAttention(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper
+ Modified from source: https://github.com/huggingface/transformers/blob/5d29530ea25fab34eaf193116512753609f2ff54/src/transformers/models/clip/modeling_clip.py#L224
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim**-0.5
+ self.dropout = torch.nn.Dropout(config.attention_dropout)
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ self.bmm_1 = elementwise_ops.MatMul()
+ self.bmm_2 = elementwise_ops.MatMul()
+ self.softmax = torch.nn.Softmax(dim=-1)
+ self.mask_causal_attn = elementwise_ops.Add()
+ self.mask_attn = elementwise_ops.Add()
+
+ self.is_sha = False
+ self.mha_conv = False
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def _split_attention_heads(self):
+ self.q_convs = nn.ModuleList([nn.Conv2d(self.embed_dim, self.head_dim, 1, bias=self.q_proj.bias is not None) for _ in range(self.num_heads)])
+ self.k_convs = nn.ModuleList([nn.Conv2d(self.embed_dim, self.head_dim, 1, bias=self.k_proj.bias is not None) for _ in range(self.num_heads)])
+ self.v_convs = nn.ModuleList([nn.Conv2d(self.embed_dim, self.head_dim, 1, bias=self.v_proj.bias is not None) for _ in range(self.num_heads)])
+
+ self.matmul_1 = nn.ModuleList([elementwise_ops.MatMul() for _ in range(self.num_heads)])
+ self.matmul_2 = nn.ModuleList([elementwise_ops.MatMul() for _ in range(self.num_heads)])
+ self.concat_1 = elementwise_ops.Concat(-1)
+
+ for idx in range(self.num_heads):
+ self.q_convs[idx].weight.data.copy_(self.q_proj.weight[idx*self.head_dim:(idx+1)*self.head_dim, :, None, None]).to(self.q_proj.weight)
+ self.k_convs[idx].weight.data.copy_(self.k_proj.weight[idx*self.head_dim:(idx+1)*self.head_dim, :, None, None]).to(self.k_proj.weight)
+ self.v_convs[idx].weight.data.copy_(self.v_proj.weight[idx*self.head_dim:(idx+1)*self.head_dim, :, None, None]).to(self.v_proj.weight)
+ if self.q_convs[idx].bias is not None: self.q_convs[idx].bias.data.copy_(self.q_proj.bias[idx*self.head_dim:(idx+1)*self.head_dim]).to(self.q_proj.weight)
+ if self.k_convs[idx].bias is not None: self.k_convs[idx].bias.data.copy_(self.k_proj.bias[idx*self.head_dim:(idx+1)*self.head_dim]).to(self.k_proj.weight)
+ if self.v_convs[idx].bias is not None: self.v_convs[idx].bias.data.copy_(self.v_proj.bias[idx*self.head_dim:(idx+1)*self.head_dim]).to(self.v_proj.weight)
+ self.is_sha = True
+
+ def _replace_linear_to_conv(self):
+ self.k_proj_conv = nn.Conv2d(self.embed_dim, self.embed_dim, (1, 1))
+ self.v_proj_conv = nn.Conv2d(self.embed_dim, self.embed_dim, (1, 1))
+ self.q_proj_conv = nn.Conv2d(self.embed_dim, self.embed_dim, (1, 1))
+ self.out_proj_conv = nn.Conv2d(self.embed_dim, self.embed_dim, (1, 1))
+
+ with torch.no_grad():
+ self.k_proj_conv.weight.copy_(self.k_proj.weight.unsqueeze(-1).unsqueeze(-1))
+ self.k_proj_conv.bias.copy_(self.k_proj.bias)
+ self.v_proj_conv.weight.copy_(self.v_proj.weight.unsqueeze(-1).unsqueeze(-1))
+ self.v_proj_conv.bias.copy_(self.v_proj.bias)
+ self.q_proj_conv.weight.copy_(self.q_proj.weight.unsqueeze(-1).unsqueeze(-1))
+ self.q_proj_conv.bias.copy_(self.q_proj.bias)
+ self.out_proj_conv.weight.copy_(self.out_proj.weight.unsqueeze(-1).unsqueeze(-1))
+ self.out_proj_conv.bias.copy_(self.out_proj.bias)
+ self.k_proj_conv.to(self.k_proj.weight)
+ self.v_proj_conv.to(self.v_proj.weight)
+ self.q_proj_conv.to(self.q_proj.weight)
+ self.out_proj_conv.to(self.out_proj.weight)
+ self.mha_conv = True
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ """
+ if self.is_sha:
+ return self.forward_sha(hidden_states)
+ elif self.mha_conv:
+ return self.forward_mha_conv(hidden_states = hidden_states,
+ attention_mask= attention_mask,
+ causal_attention_mask= causal_attention_mask,
+ output_attentions= output_attentions)
+ return self.forward_mha_linear(hidden_states = hidden_states,
+ attention_mask= attention_mask,
+ causal_attention_mask= causal_attention_mask,
+ output_attentions= output_attentions)
+
+
+ def forward_sha(self, x: torch.Tensor) -> torch.Tensor:
+ """"input shape = (B, seq_len, emb_dim) = (1, 257, 1024) for 224x224 image inputs"""
+ B, seq_len, emb_dim = x.shape
+ x = x.permute(0, 2, 1).unsqueeze(-1) # (1, 257, 1024) -> (1, 1024, 257, 1)
+
+ outs = []
+ for ndx in range(self.num_heads):
+ # (1, 1024, 257, 1) -> conv -> (1, 64, 257, 1) -> permute -> (1, 257, 1, 64)-> reshape -> (1, 257, 64)
+ _query = self.q_convs[ndx](x).permute(0, 2, 3, 1).reshape(B, -1, self.head_dim)
+ _key = self.k_convs[ndx](x).permute(0, 2, 3, 1).reshape(B, -1, self.head_dim)
+ _value = self.v_convs[ndx](x).permute(0, 2, 3, 1).reshape(B, -1, self.head_dim)
+
+ _attn = self.matmul_1[ndx]((_query * self.scale), _key.transpose(-2, -1))
+ _attn = self.softmax(_attn)
+ _attn = self.dropout(_attn)
+ _out = self.matmul_2[ndx](_attn, _value) # -> (1, 257, 64)
+ outs.append(_out)
+
+ out = self.concat_1(*outs) # -> (1, 257, 1024)
+ out = self.out_proj(out)
+ return out, None
+
+
+ def forward_mha_conv(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, embed_dim = hidden_states.size()
+ assert bsz == 1
+ hidden_states = hidden_states.permute(0, 2, 1).unsqueeze(-1) # input shape from 3d to 4d: (1, 77, 768) -> (1, 768, 77) -> (1, 768, 77, 1)
+
+ # get query proj
+ # output shape from 4d to 3d: (1, 768, 77, 1) -> (1, 768, 77) -> (1, 77, 768)
+ query_states = self.q_proj_conv(hidden_states).squeeze(-1).permute(0, 2, 1) * self.scale
+ key_states = self._shape(self.k_proj_conv(hidden_states).squeeze(-1).permute(0, 2, 1), -1, bsz)
+ value_states = self._shape(self.v_proj_conv(hidden_states).squeeze(-1).permute(0, 2, 1), -1, bsz)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ #attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+ attn_weights = self.bmm_1(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # apply the causal_attention_mask first
+ if causal_attention_mask is not None:
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {causal_attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ #attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+ attn_weights = self.softmax(attn_weights)
+
+ if output_attentions:
+ # this operation is a bit akward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ #attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+ attn_probs = self.dropout(attn_weights)
+
+ #attn_output = self.bmm_2(attn_probs, value_states)
+ # Change to support 16-bit value (attn_probs [12, 14, 14], value_states [12, 14, 64], attn_output [12, 14, 64])
+ attn_output = self.bmm_2(value_states.transpose(-1,-2), attn_probs.transpose(-1,-2)).transpose(-1,-2)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+ # input shape from 3d to 4d: (1, 77, 768) -> (1, 768, 77) -> (1, 768, 77, 1)
+ attn_output = attn_output.permute(0, 2, 1).unsqueeze(-1)
+ attn_output = self.out_proj_conv(attn_output)
+ # output shape from 4d to 3d: (1, 768, 77, 1) -> (1, 768, 77) -> (1, 77, 768)
+ attn_output = attn_output.squeeze(-1).permute(0, 2, 1)
+ return attn_output, attn_weights_reshaped
+
+
+ def forward_mha_linear(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ bsz, tgt_len, embed_dim = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scale
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.view(*proj_shape)
+ value_states = value_states.view(*proj_shape)
+
+ src_len = key_states.size(1)
+ #attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+ attn_weights = self.bmm_1(query_states, key_states.transpose(1, 2))
+
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ # apply the causal_attention_mask first
+ if causal_attention_mask is not None:
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+ f" {causal_attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ #attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+ attn_weights = self.softmax(attn_weights)
+
+ if output_attentions:
+ # this operation is a bit akward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ #attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+ attn_probs = self.dropout(attn_weights)
+
+ #attn_output = self.bmm_2(attn_probs, value_states)
+ # Change to support 16-bit value (attn_probs [12, 14, 14], value_states [12, 14, 64], attn_output [12, 14, 64])
+ attn_output = self.bmm_2(value_states.transpose(-1,-2), attn_probs.transpose(-1,-2)).transpose(-1,-2)
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped
+
+
+def replace_mha_with_sha_blocks(clip_model):
+ print("linear layers in Attention will be replaced with convs with single head")
+ for name, module in clip_model.named_modules():
+ if isinstance(module, CLIPAttention):
+ module._split_attention_heads()
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/transformer/siglip_adaptation.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/transformer/siglip_adaptation.py
new file mode 100644
index 000000000..30c820261
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/transformer/siglip_adaptation.py
@@ -0,0 +1,282 @@
+# /usr/bin/env python3
+# -*- mode: python -*-
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+""" model adaptations for optimal accuracy and latency when quantizing and executing on constrained hardware """
+
+import functools
+import math
+import torch
+from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer, SiglipVisionEmbeddings
+from aimet_torch.nn.modules import custom as elementwise_ops
+from peft.tuners.lora.layer import Linear as LoraLinear
+from aimet_torch.peft import LoraLayer as AimetLoraLayer
+
+
+@torch.no_grad
+def _convert_linear_to_conv(linear, scale=None):
+ conv = torch.nn.Conv2d(linear.weight.shape[1], linear.weight.shape[0], 1, bias=linear.bias is not None)
+ conv.weight.copy_(linear.weight.data[..., None, None] * scale if scale is not None else
+ linear.weight.data[..., None, None])
+ if linear.bias is not None:
+ conv.bias.copy_(linear.bias.data * scale if scale is not None else linear.bias.data)
+ return conv
+
+
+class SelfAttention4d(torch.nn.Module):
+
+ def __init__(self,
+ self_attn):
+ super(SelfAttention4d, self).__init__()
+ self.q_proj = self_attn.q_proj
+ self.k_proj = self_attn.k_proj
+ self.v_proj = self_attn.v_proj
+ self.out_proj = _convert_linear_to_conv(self_attn.out_proj)
+ self.scale = self_attn.scale
+
+ self.num_heads = self_attn.num_heads
+ self.head_dim = self_attn.head_dim
+
+ if isinstance(self.q_proj, ((AimetLoraLayer, LoraLinear))):
+ self.q_proj = self._replace_lora_linear_with_conv(self.q_proj)
+ self.k_proj = self._replace_lora_linear_with_conv(self.k_proj)
+ self.v_proj = self._replace_lora_linear_with_conv(self.v_proj)
+ else:
+ q_proj, k_proj, v_proj = self.q_proj, self.k_proj, self.v_proj
+ self.q_proj = _convert_linear_to_conv(q_proj)
+ self.k_proj = _convert_linear_to_conv(k_proj)
+ self.v_proj = _convert_linear_to_conv(v_proj)
+
+ del q_proj
+ del k_proj
+ del v_proj
+
+ del self_attn.out_proj
+
+ def forward(self, hidden_states, attn_mask=None):
+ batch_size, channel, height, width = hidden_states.shape
+ _query = self.q_proj(hidden_states).permute(0, 2, 3, 1).reshape(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3) * self.scale # -> (B, num_heads, H*W, head_dim)
+
+ _key_T = self.k_proj(hidden_states).reshape(batch_size, self.num_heads, self.head_dim, -1) # -> (B, num_heads, head_dim, H*W)
+
+ _value = self.v_proj(hidden_states).reshape(batch_size, self.num_heads, self.head_dim, -1).permute(0, 1, 3, 2) # -> (B, num_heads, head_dim, H*W) -> (B, num_heads, H*W, head_dim)
+
+ _attn = elementwise_ops.MatMul()(_query, _key_T) # -> (B, num_heads, H*W, H*W)
+ _attn = torch.nn.functional.softmax(_attn, dim=-1)
+ _out = elementwise_ops.MatMul()(_attn, _value) # -> (B, num_heads, H*W, head_dim)
+
+ # These permute reshape will be cleaned up in MHA2SHA-onnx-converter
+ _out = _out.permute(0, 2, 1, 3) # -> (B, H*W, head, d)
+ hidden_states = _out.reshape(batch_size, height, width, -1) # (B, H*W, head, d) -> (B, H, W, C)
+
+ # linear proj
+ hidden_states = hidden_states.permute(0, 3, 1, 2) # (B, H, W, C) -> (B, C, H, W)
+ hidden_states = self.out_proj(hidden_states) # -> (B, C, H, W)
+
+ return hidden_states
+
+ @torch.no_grad
+ def _replace_lora_linear_with_conv(self, lora_layer, scale=None):
+ def _replace_lora_A_or_B(lora_A_or_B, scale=None):
+ if isinstance(lora_A_or_B, torch.nn.ModuleList):
+ lora_A_or_B_conv = torch.nn.ModuleList()
+ for mod in lora_A_or_B:
+ mod_conv = _convert_linear_to_conv(mod, scale)
+ lora_A_or_B_conv.append(mod_conv)
+ elif isinstance(lora_A_or_B, torch.nn.ModuleDict):
+ lora_A_or_B_conv = torch.nn.ModuleDict()
+ for name, mod in lora_A_or_B.items():
+ mod_conv = _convert_linear_to_conv(mod, scale)
+ lora_A_or_B_conv[name] = mod_conv
+ else:
+ raise ValueError(f"{lora_A_or_B} should be LoRA Layer, got: {type(lora_A_or_B)}")
+ return lora_A_or_B_conv
+
+ base_layer, lora_A, lora_B = lora_layer.base_layer, lora_layer.lora_A, lora_layer.lora_B
+ base_layer_conv = _convert_linear_to_conv(base_layer, scale)
+
+ lora_layer.base_layer = base_layer_conv
+ lora_layer.lora_A = _replace_lora_A_or_B(lora_A, scale=scale)
+ lora_layer.lora_B = _replace_lora_A_or_B(lora_B, scale=None)
+
+ if hasattr(lora_layer, "scaling"):
+ # Expand the lora alpha vector to be 4d
+ for i, scaling in enumerate(lora_layer.scaling):
+ if isinstance(scaling, torch.Tensor) and scaling.size():
+ lora_layer.scaling[i].data = scaling.data.view(1, scaling.size()[0], 1, 1)
+
+ del base_layer
+ del lora_A
+ del lora_B
+
+ return lora_layer
+
+
+class Mlp4d(torch.nn.Module):
+
+ def __init__(self,
+ siglip_mlp):
+ super(Mlp4d, self).__init__()
+ self.fc1 = _convert_linear_to_conv(siglip_mlp.fc1)
+ self.activation_fn = siglip_mlp.activation_fn
+ self.fc2 = _convert_linear_to_conv(siglip_mlp.fc2)
+ del siglip_mlp.fc1
+ del siglip_mlp.fc2
+
+ def forward(self, x, **kwargs):
+ x = self.fc1(x)
+ x = self.activation_fn(x)
+ x = self.fc2(x)
+ return x
+
+
+class TransformerBlock4d(torch.nn.Module):
+
+ def __init__(self,
+ siglip_encoder_layer):
+ super(TransformerBlock4d, self).__init__()
+ self.layer_norm1 = LayerNorm4d(siglip_encoder_layer.layer_norm1)
+ self.self_attn = SelfAttention4d(siglip_encoder_layer.self_attn)
+ self.layer_norm2 = LayerNorm4d(siglip_encoder_layer.layer_norm2)
+ self.mlp = Mlp4d(siglip_encoder_layer.mlp)
+
+ def forward(self, hidden_states, attention_mask=None, output_attentions=None, **kwargs):
+ if not hidden_states.ndim == 4:
+ raise NotImplementedError(f'{self.__class__.__name__} expects 4d channels first inputs, but got {hidden_states.ndim}-dimensional inputs')
+
+ output_attentions_kwarg = kwargs.get('output_attentions', output_attentions)
+ if not (output_attentions_kwarg is None or output_attentions_kwarg is False):
+ raise NotImplementedError(f'{self.__class__.__name__} has not implemented support for output_attentions=True option')
+
+ for k, v in kwargs.items():
+ if not (v is None or v is False):
+ raise NotImplementedError(f'{self.__class__.__name__} has not implemented support param {k} with value {v}.')
+
+ attn_out = self.layer_norm1(hidden_states)
+ attn_out = self.self_attn(attn_out)
+ attn_out += hidden_states
+ mlp_out = self.layer_norm2(attn_out)
+ mlp_out = self.mlp(mlp_out)
+ mlp_out += attn_out
+ return mlp_out, None
+
+
+class SiglipVisionEmbeddings4d(torch.nn.Module):
+ """ Adapted from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/models/siglip/modeling_siglip.py#L250-L318 """
+ def __init__(self, siglip_vision_embeddings):
+ super().__init__()
+ self.siglip_vision_embeddings = siglip_vision_embeddings
+
+ def forward(self, x, **kwargs):
+ for k, v in kwargs.items():
+ if not (v is None or v is False):
+ raise NotImplementedError(f'{self.__class__.__name__} cannot support param {k} with value {v}.')
+
+ x = self.siglip_vision_embeddings(x)
+ B, emb_dim, C = x.shape
+ emb_dim_sqrt = int(math.sqrt(emb_dim)) # 4D handling from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/models/siglip/modeling_siglip.py#L295-L296
+ x = x.permute(0,2,1)
+ x = x.reshape(B, C, emb_dim_sqrt, emb_dim_sqrt)
+ return x
+
+
+class LayerNorm4d(torch.nn.Module):
+ def __init__(self, layer_norm):
+ super().__init__()
+ self.layer_norm = layer_norm
+
+ def forward(self, x, **kwargs):
+ for k, v in kwargs.items():
+ if not (v is None or v is False):
+ raise NotImplementedError(f'{self.__class__.__name__} has not implemented support param {k} with value {v}.')
+
+ return self.layer_norm(x.permute(0,2,3,1)).permute(0,3,1,2)
+
+
+class GLU4d(torch.nn.Module):
+ def __init__(self, GLU):
+ super().__init__()
+ self.linear_proj = _convert_linear_to_conv(GLU.linear_proj)
+ self.norm1 = LayerNorm4d(GLU.norm1)
+ self.act1 = GLU.act1
+ self.act2 = GLU.act2
+ self.dense_h_to_4h = _convert_linear_to_conv(GLU.dense_h_to_4h)
+ self.gate_proj = _convert_linear_to_conv(GLU.gate_proj)
+ self.dense_4h_to_h = _convert_linear_to_conv(GLU.dense_4h_to_h)
+
+ del GLU.linear_proj
+ del GLU.dense_h_to_4h
+ del GLU.gate_proj
+ del GLU.dense_4h_to_h
+
+ def forward(self, x, **kwargs):
+ for k, v in kwargs.items():
+ if not (v is None or v is False):
+ raise NotImplementedError(f'{self.__class__.__name__} has not implemented support param {k} with value {v}.')
+
+ B, C, H, W = x.shape
+ if not H == W:
+ raise NotImplementedError(f'{self.__class__.__name__} only supports square 4d inputs, but got height = {H}, and width = {W}.')
+
+ x = self.linear_proj(x)
+ x = self.act1(self.norm1(x))
+ x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
+ x = self.dense_4h_to_h(x)
+
+ return x
+
+
+class Adapter4d(torch.nn.Module):
+ def __init__(self, adapter):
+ super().__init__()
+ self.conv = adapter.conv
+ self.boi = adapter.boi
+ self.eoi = adapter.eoi
+ self.linear_proj = GLU4d(adapter.linear_proj)
+
+ def forward(self, image_emb, **kwargs):
+ for k, v in kwargs.items():
+ if not (v is None or v is False):
+ raise NotImplementedError(f'{self.__class__.__name__} has not implemented support param {k} with value {v}.')
+
+ B, C, H, W = image_emb.shape
+ if not H == W:
+ raise NotImplementedError(f'{self.__class__.__name__} only supports square 4d inputs, but got height = {H}, and width = {W}.')
+
+ image_emb = self.conv(image_emb)
+ image_emb = self.linear_proj(image_emb)
+ image_emb = image_emb.flatten(2).transpose(1, 2)
+ image_emb = torch.cat([self.boi.repeat(len(image_emb), 1, 1), image_emb, self.eoi.repeat(len(image_emb), 1, 1)], dim=1)
+ return image_emb
+
+
+def rsetattr(obj, attr, val):
+ pre, _, post = attr.rpartition('.')
+ return setattr(rgetattr(obj, pre) if pre else obj, post, val)
+
+
+def rgetattr(obj, attr, *args):
+ def _getattr(obj, attr):
+ return getattr(obj, attr, *args)
+ return functools.reduce(_getattr, [obj] + attr.split('.'))
+
+
+def align_siglip_vision_model_tensor_dimensionlity(model):
+ for name, module in model.named_modules():
+
+ if isinstance(module, SiglipVisionEmbeddings):
+ embedding_4d = SiglipVisionEmbeddings4d(module)
+ rsetattr(model, name, embedding_4d)
+
+ if isinstance(module, SiglipEncoderLayer):
+ transfomer_block_4d = TransformerBlock4d(module)
+ rsetattr(model, name, transfomer_block_4d)
+
+ if 'post_layernorm' in name:
+ layer_norm_4d = LayerNorm4d(module)
+ rsetattr(model, name, layer_norm_4d)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/utils/model_utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/utils/model_utils.py
new file mode 100644
index 000000000..161c29e0a
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/utils/model_utils.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+""" common utilities and class implementation for model architecture adaptation """
+from typing import Dict, List, Tuple, Union, Any, Callable
+
+import torch
+
+from aimet_common.utils import AimetLogger
+logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils)
+
+class ModelOutputIndexFilter(torch.nn.Module):
+ """
+ Wrap a model that generates several outputs, and filter it to return only a subset of those
+ Used to simplify an ONNX model, shrinking its size and preventing generation of unused and unwanted outputs
+ """
+ def __init__(self, model: torch.nn.Module,
+ output_index_filter: List,
+ forward_kwargs: Dict[str, Any] = None,
+ out_mapper_fn: Callable = None):
+ """
+ :param model: model whose output is to be filtered
+ :param output_index_filter: list matching length of model's outputs, indicating which idx of each output to keep,
+ None means entirely skipping an output, integers mean selecting output at that index,
+ a string with a single colon ":" means selecting all entries for the current output
+ :param forward_kwargs: additional kwargs to be passed to the model during inference
+ :param out_mapper_fn: function to map the filtered outputs to the desired format
+ """
+ super().__init__()
+ self.model = model
+ self.output_index_filter = output_index_filter
+ self.forward_kwargs = forward_kwargs if forward_kwargs else {}
+ self.out_mapper_fn = out_mapper_fn if out_mapper_fn is not None else lambda x: x
+ self.device = next(model.parameters()).device
+ self.dtype = next(model.parameters()).dtype
+ self.config = model.config if hasattr(model, "config") else None
+
+
+ def forward(self, *args, **kwargs):
+ """
+ :param args: args to be forwarded to the model
+ :param kwargs: kwargs to be forwarded to the model
+ """
+ kwargs.update(self.forward_kwargs)
+ outputs = self.model(*args, **kwargs)
+ assert len(outputs) == len(self.output_index_filter), f'output_index_filter must match expected model output length, got {len(outputs)} model outputs, but {len(self.output_index_filter)} index filters'
+ filtered_outputs = []
+ for idx, output_filter in enumerate(self.output_index_filter):
+ if output_filter is None:
+ continue
+ if output_filter == ':':
+ chosen_output = outputs[idx]
+ else:
+ chosen_output = outputs[idx][output_filter]
+ # unsqueeze to make up for the lost dimension when we access at [output_filter]
+ if isinstance(outputs[idx], torch.Tensor):
+ chosen_output = chosen_output.unsqueeze(0)
+ else:
+ chosen_output = (chosen_output, )
+ filtered_outputs.append(chosen_output)
+ return self.out_mapper_fn(filtered_outputs)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/utils/unet_3d_to_4d.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/utils/unet_3d_to_4d.py
new file mode 100644
index 000000000..25737fc85
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/dev/utils/unet_3d_to_4d.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+import torch
+import copy
+from genai_lib.common.dev.utils import rsetattr
+from diffusers.models.transformers.transformer_2d import Transformer2DModel as Transformer2DModel_3d
+from genai_lib.lvm.dev.diffusers.transformer_2d import Transformer2DModel as Transformer2DModel_4d
+
+
+def convert_3d_unet_to_4d(model: torch.nn.Module):
+ model_4d = copy.deepcopy(model)
+
+ for name, module_3d in model_4d.named_modules():
+ if isinstance(module_3d, Transformer2DModel_3d):
+
+ layers_geglu = {f'transformer_blocks.{idx}.ff.net.0.proj.weight' for idx in range(10)}
+ layers_ff = {f'transformer_blocks.{idx}.ff.net.2.weight' for idx in range(10)}
+ layers_proj = {'proj_in.weight', 'proj_out.weight'}
+ linear_to_conv_weights = set.union(*[layers_geglu, layers_ff, layers_proj])
+
+ state_dict_3d = module_3d.state_dict()
+ layer_matches = set.intersection(linear_to_conv_weights, set(state_dict_3d.keys()))
+ for layer_name in layer_matches:
+ shape_2d = state_dict_3d[layer_name].shape[:2]
+ state_dict_3d[layer_name] = state_dict_3d[layer_name].reshape(shape_2d)[:,:,None,None]
+
+ module_4d = Transformer2DModel_4d(**module_3d.config)
+ module_4d.load_state_dict(state_dict_3d)
+
+ rsetattr(model_4d, name, module_4d)
+
+ model_4d.to(model.device)
+ return model_4d
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/eval_utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/eval_utils.py
new file mode 100644
index 000000000..1c75e5082
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/eval_utils.py
@@ -0,0 +1,373 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+""" common utilities and class implementation for generating test vector, accuracy analysis """
+
+import contextlib
+import inspect
+import math
+import os
+import re
+import resource
+import sys
+from typing import Optional, Callable, Dict, Tuple, List, Union, Any, Iterator
+
+import numpy as np
+import pandas as pd
+import torch
+from tqdm import tqdm
+from bokeh.models import ColumnDataSource, HoverTool
+from bokeh.plotting import figure, output_file, save
+
+from aimet_common.utils import AimetLogger
+from genai_lib.lvm.calibration import ModelDataReader
+from aimet_torch.v2.quantsim import QuantizationSimModel
+from aimet_torch.v2.nn import BaseQuantizationMixin
+logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils)
+
+def eval_model(model: torch.nn.Module,
+ model_input: Tuple[torch.Tensor],
+ intermediate_modules_names: Tuple[str] = None):
+ """
+ eval models with optionally enabling intermediate tensors.
+ :param model: fp32 model or quantsim model.
+ :param model_input: prepared model input.
+ :param intermediate_modules_names: list of intermediate modules to dump tensors for.
+ """
+
+ intermediate_tensors = {}
+ hooks = []
+ if intermediate_modules_names:
+ modules_to_name = {module: name for name, module in model.named_modules() if name in intermediate_modules_names}
+
+ modules_not_found = set(intermediate_modules_names) - set(modules_to_name.values())
+ if modules_not_found:
+ raise ValueError(f'intermediate module name(s)={modules_not_found} not found in model: {type(model)}, '
+ f'defined at {inspect.getfile(model.__class__)}')
+
+ def hook(module: torch.nn.Module, _, output):
+ if isinstance(output, torch.Tensor):
+ intermediate_tensors[modules_to_name[module]] = np_tensor(output)
+
+ for module in modules_to_name.keys():
+ hooks.append(module.register_forward_hook(hook))
+
+ with torch.no_grad():
+ if isinstance(model_input, torch.Tensor):
+ output = model(model_input)
+ else:
+ output = model(*model_input)
+
+ for hook in hooks:
+ hook.remove()
+
+ return output, intermediate_tensors
+
+
+def generate_vectors(output_path: str,
+ sim_model: torch.nn.Module,
+ prepared_model_reader: ModelDataReader = None,
+ sample_inputs: Union[List[torch.Tensor], Tuple[torch.Tensor]] = None,
+ format_model_inputs_fn: Iterator[Dict[str, torch.Tensor]] = None,
+ input_names: Optional[List[str]] = None,
+ output_names: Optional[List[str]] = None,
+ align_fp_model: bool = True,
+ num_samples: int = 1,
+ intermediate_modules_names: Tuple[str] = None):
+ """
+ Generate test vectors and sample inputs for on-target compilation steps.
+ :param output_path: Path to save exported DLC file
+ :param sim_model: quantsim model to generate quantized o/p for on-target verification.
+ :param prepared_model_reader: recorder with calibration data. Mutually exclusive with sample_inputs
+ :param sample_inputs: iterable with calibration data. Mutually exclusive with prepared_model_reader
+ :param format_model_inputs_fn: user-provided function to adjust model inputs for export
+ :param input_names: input names to use for inputs tensors
+ :param output_names: input names to user for output tensors
+ :param align_fp_model: if True, extracts fp model from sim user-provided function to adjust model inputs for export
+ :param num_samples: number of inputs to use for sampling
+ :param intermediate_modules_names: list of intermediate modules to dump tensors for.
+ """
+ if (prepared_model_reader is not None) and (sample_inputs is not None):
+ raise ValueError("Expected only one of prepared_model_reader or sample_inputs to be passed, both were passed in this case")
+ elif (prepared_model_reader is None) and (sample_inputs is None):
+ raise ValueError("Expected atleast one of prepared_model_reader or sample_inputs to be passed, both were not passed in this case")
+ fp_vectors = []
+ sim_vectors = []
+ prepared_model = QuantizationSimModel.get_original_model(sim_model) if align_fp_model else None
+ os.makedirs(output_path, exist_ok=True)
+ if input_names is None:
+ input_names = list(inspect.signature(sim_model.forward).parameters.keys())
+
+ if format_model_inputs_fn is None:
+ def dummy_format(x: Union[Dict, Tuple]) -> Union[Dict, Tuple]:
+ yield x
+ format_model_inputs_fn = dummy_format
+
+ def dump(tensors: List[np.ndarray], names: List[str] = None) -> Dict:
+ if not isinstance(tensors, (list, tuple)):
+ tensors = [tensors]
+ if names:
+ assert len(tensors) == len(names), f'Length mismatch: tensors length is {len(tensors)}, name length is {len(names)}.'
+ else:
+ names = [f't{i}' for i in range(len(tensors))]
+ return dict(zip(names, tensors))
+
+ def eval_prepared_and_quantsim_models(model_inputs: Tuple):
+
+ if prepared_model:
+ fp_outputs, intermediate_tensors = eval_model(prepared_model, model_inputs, intermediate_modules_names)
+ outputs = dump(np_tensor(fp_outputs), output_names)
+ outputs.update(intermediate_tensors)
+ _sim_outputs, intermediate_tensors = eval_model(sim_model, model_inputs, intermediate_modules_names)
+ sim_outputs = dump(np_tensor(_sim_outputs), output_names)
+ sim_outputs.update(intermediate_tensors)
+ model_inputs = dump(np_tensor(model_inputs), input_names)
+
+ fp_vectors.append({'inputs': model_inputs, 'outputs': np_tensor(outputs)})
+ sim_vectors.append({'inputs': model_inputs, 'outputs': sim_outputs})
+
+ with torch.no_grad():
+
+ for i in tqdm(range(min(len(prepared_model_reader if sample_inputs is None else sample_inputs), num_samples))):
+ if num_samples is not None and num_samples == i:
+ break
+ for formatted_input in format_model_inputs_fn(prepared_model_reader.get_flattened_inputs(i) if sample_inputs is None else sample_inputs[i]):
+ model_inputs = tuple(formatted_input.values()) if isinstance(formatted_input,dict) else formatted_input
+ eval_prepared_and_quantsim_models(model_inputs)
+
+ np.save(f'{output_path}/fp.npy', fp_vectors, allow_pickle=True)
+ np.save(f'{output_path}/int8.npy', sim_vectors, allow_pickle=True)
+
+ inputs_dir = os.path.join(output_path, 'raw_inputs')
+ os.makedirs(inputs_dir, exist_ok=True)
+ sample_inputs = sim_vectors[0]['inputs']
+ with open(os.path.join(inputs_dir, 'input_list.txt'), 'w') as f:
+ for name, tensor in sample_inputs.items():
+ filename = f'{name}.raw'
+ tensor.astype(np.float32).tofile(os.path.join(inputs_dir, filename))
+ f.write(f"{filename} ")
+
+
+def get_sqnr_nptensor(
+ fp_tensor: np.ndarray, sim_tensor: np.ndarray, eps: float, sample_wise: bool = False
+) -> float:
+ """
+ Computes SQNR
+
+ :param fp_tensor: FP32 np array
+ :param sim_tensor: np array with QDQ noise.
+ :param eps: the smallest positive value to avoid div-by-zero error.
+ :param sample_wise: Specifies whether to calculate SQNR per sample before averaging.
+ If set to True, It mitigates the dominance of stronger signal samples in the SQNR calculation.
+ If set to False, SQNR is calculated across all elements. Defaults to False for backward compatibility.
+ :return: Sigal-to-quantization noise ratio in dB scale.
+ """
+ if fp_tensor.shape != sim_tensor.shape:
+ raise ValueError('Both tensors must have the same shape')
+
+ if sample_wise and fp_tensor.ndim < 2:
+ raise ValueError(
+ 'For sample-wise calculation, both tensors must have at least two dimensions'
+ )
+
+ sim_error = fp_tensor - sim_tensor
+
+ if sample_wise:
+ # Sample-wise calculation (dims_except_sample_dim)
+ axis = tuple(range(fp_tensor.ndim)[1:])
+ else:
+ axis = None
+
+ exp_noise = (sim_error**2).mean(axis=axis) + eps
+ exp_signal = (fp_tensor**2).mean(axis=axis)
+ sqnr_db = 10 * (np.log10(exp_signal) - np.log10(exp_noise))
+ return sqnr_db.mean()
+
+
+def get_sqnr(
+ fp_tensor: torch.Tensor,
+ sim_tensor: torch.Tensor,
+ eps: float = sys.float_info.min,
+ sample_wise: bool = False,
+) -> float:
+ """
+ Computes SQNR
+
+ :param fp_tensor: FP32 torch tensor
+ :param sim_tensor: torch tensor with QDQ noise.
+ :param eps: the smallest positive value to avoid div-by-zero error.
+ :param sample_wise: Specifies whether to calculate SQNR per sample before averaging.
+ If set to True, it mitigates the dominance of stronger signal samples in the SQNR calculation.
+ If set to False, SQNR is calculated across all elements. Defaults to False for backward compatibility
+ :return: Sigal-to-quantization noise ratio in dB scale.
+ """
+ return get_sqnr_nptensor(
+ np_tensor(fp_tensor), np_tensor(sim_tensor), eps, sample_wise
+ )
+
+
+def np_tensor(tensor):
+ """
+ :return: numpy tensor(s) from torch tensor(s)
+ """
+ if isinstance(tensor, torch.Tensor):
+ return tensor.detach().cpu().numpy()
+
+ if isinstance(tensor, (list, tuple)):
+ cls = tuple if isinstance(tensor, tuple) else list
+ return cls(np_tensor(x) for x in tensor)
+
+ if isinstance(tensor, dict):
+ return {key: np_tensor(value) for key, value in tensor.items()}
+
+ return tensor
+
+
+def compute_sqnr(model_inputs, fp_model, sim_model, sample_wise: bool = False) -> float:
+ """
+ method to compute sqnr given model inputs as an iterable, fp and sim models
+
+ :param model_inputs: input to the models, iterable
+ :param fp_model: floating point model
+ :param sim_model: quantsim model
+ :param sample_wise: Specifies whether to calculate SQNR per sample before averaging.
+ If set to True, it mitigates the dominance of stronger signal samples in the SQNR calculation.
+ If set to False, SQNR is calculated across all elements. Defaults to False for backward compatibility
+ """
+ fp_outs, qt_outs = [], []
+ with torch.no_grad():
+ for i, img in enumerate(model_inputs):
+ embeddings_fp = fp_model(img)
+ fp_outs.append(embeddings_fp)
+ embeddings_qt = sim_model(img)
+ qt_outs.append(embeddings_qt)
+ sqnr = get_sqnr(
+ torch.stack(fp_outs).cpu(), torch.stack(qt_outs).cpu(), sample_wise=sample_wise
+ )
+ return sqnr
+
+
+def analyze_layerwise(sim: Union[QuantizationSimModel, Any],
+ calibration_data: ModelDataReader,
+ prepared_model: torch.nn.Module,
+ use_prev_layer_quant_output: bool = False,
+ num_samples: int = 1,
+ output_path: str = None,
+ match_name: str = None,
+ export_csv_and_plot: bool = False):
+ """
+ method to dump and log layerwise stats (FP vs quantsim).
+ :param sim: quantization simulation model.
+ :param calibration_data: calibration data to use for comparing sim against fp
+ :param prepared_model: use prepared model as reference for FP.
+ :param use_prev_layer_quant_output: supports two options for layer wise analysis based on the flag
+ True => sim input -> fp/sim output (if prepared model not provided use wrapped fp module)
+ False => fp input -> fp/sim output
+ :param num_samples: number of inputs to use for sampling
+ :param output_path: path to save the stats as pickle file.
+ :param match_name: regex string to match one or more layers
+ :param export_csv_and_plot: export layerwise stats as csv and plot the bottom 25 sqnr values
+ :return: optimized and calibrated quantsim model.
+ """
+ modules_stats = {}
+ fp_to_quant_modules = {}
+ quant_to_fp_modules = {}
+
+ def hook(module: torch.nn.Module, inputs, output):
+ """ Save model input tuple of tensors and output tensor"""
+ kwargs = inspect.currentframe().f_back.f_locals['kwargs']
+ if isinstance(module, BaseQuantizationMixin):
+ quant_module = module
+ fp_module = quant_to_fp_modules[module]
+ fp_output, sim_output = fp_module(*inputs, **kwargs), output
+ else:
+ quant_module = fp_to_quant_modules[module]
+ fp_output, sim_output = output, quant_module(*inputs, **kwargs)
+
+ if isinstance(fp_output, (tuple,list)) and isinstance(sim_output, (tuple, list)):
+ modules_stats[quant_module]['down-selected'] = len(fp_output),len(sim_output)
+ fp_tensor, sim_tensor =np_tensor(fp_output[0]), np_tensor(sim_output[0])
+ else:
+ fp_tensor, sim_tensor =np_tensor(fp_output), np_tensor(sim_output)
+
+ try:
+ modules_stats[quant_module]['sqnr'].append(get_sqnr_nptensor(fp_tensor, sim_tensor, sys.float_info.min))
+ modules_stats[quant_module]['min'].append({'fp': np.min(fp_tensor), 'int': np.min(sim_tensor)})
+ modules_stats[quant_module]['max'].append({'fp': np.min(fp_tensor), 'int': np.min(sim_tensor)})
+ except Exception as e:
+ modules_stats[quant_module]['skipped_on_error'] = type(e)
+
+ name_to_fp_modules = {name: module for name, module in prepared_model.named_modules()} if prepared_model else {}
+
+ hooks = []
+ for name, quant_module in sim.quant_wrappers():
+ modules_stats[quant_module] = {'name': name, 'sqnr': [], 'min': [], 'max': []}
+ if not prepared_model:
+ fp_module = quant_module._module_to_wrap # pylint: disable = protected-access
+ else:
+ fp_module = name_to_fp_modules[name]
+ fp_to_quant_modules[fp_module] = quant_module
+ quant_to_fp_modules[quant_module] = fp_module
+ if use_prev_layer_quant_output:
+ hooks.append(quant_module.register_forward_hook(hook))
+ else:
+ hooks.append(fp_module.register_forward_hook(hook))
+
+ with torch.no_grad():
+ for i, inputs in enumerate(tqdm(calibration_data, total=(min(len(calibration_data),num_samples)))):
+ if i == num_samples:
+ break
+ _ = prepared_model(*inputs) if not use_prev_layer_quant_output else sim.model(*inputs)
+
+ for hook in hooks:
+ hook.remove()
+
+ if match_name:
+ match_name = re.compile(match_name)
+ logger.info('SQNR layerwise stats (num_layers=%d) for %s (%s):', len(modules_stats), type(sim.model),
+ ("sim input -> fp/sim output" if use_prev_layer_quant_output else "fp input -> fp/sim output") )
+ for i, stats in enumerate(modules_stats.values()):
+ mean_sqnr = stats['mean_sqnr'] = np.array(stats['sqnr']).mean()
+ if (match_name is None and not math.isnan(mean_sqnr) and mean_sqnr < 100) \
+ or (match_name and re.match(match_name, stats['name'])):
+ logger.info('%-4.0f\t%-5.2f dB : %s', i, stats['mean_sqnr'], stats['name'])
+
+ if output_path:
+ filepath = os.path.join(output_path, 'layerwise_stats' + ('_sim_in' if use_prev_layer_quant_output else '_fp_in'))
+ np.save(filepath, list(modules_stats.values()), allow_pickle=True)
+ logger.info('saved stats at:%s.npy', filepath)
+
+ if export_csv_and_plot:
+ layerwise_stats_df=pd.DataFrame(list(modules_stats.values()))
+ layerwise_stats_df=layerwise_stats_df.sort_values('mean_sqnr')
+ source = ColumnDataSource(layerwise_stats_df[['name','mean_sqnr']].head(25))
+ plot = figure(y_range=layerwise_stats_df['name'].head(25), height=350,width=700,title="Layerwise mean SQNR Value - Bottom 25")
+ plot.hbar(y='name', right='mean_sqnr', height=0.9,source=source)
+ hover = HoverTool()
+ hover.tooltips = [("Name", "@name"), ("Mean SQNR", "@mean_sqnr")]
+ plot.add_tools(hover)
+ output_file(f'{output_path}/plot.html', mode='inline')
+ save(plot)
+ layerwise_stats_df=layerwise_stats_df[['name','min','max','mean_sqnr']]
+ layerwise_stats_df.to_csv(f"{output_path}/layerwise_data.csv",index=False)
+
+@contextlib.contextmanager
+def increase_recursion_limit():
+ """
+ Utility to temporarily increase recursion limit.
+ :return: None
+ """
+ original_recursion_limit = sys.getrecursionlimit()
+ original_rlimit = resource.getrlimit(resource.RLIMIT_STACK)
+ try:
+ sys.setrecursionlimit(10 ** 9)
+ resource.setrlimit(resource.RLIMIT_STACK, (-1, -1))
+ yield
+ finally:
+ sys.setrecursionlimit(original_recursion_limit)
+ resource.setrlimit(resource.RLIMIT_STACK, original_rlimit)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/lvm_utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/lvm_utils.py
new file mode 100644
index 000000000..d5fbce21a
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/lvm_utils.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+""" Utilities to manage AIMET LVM pipeline elements """
+
+import torch
+from aimet_common.utils import AimetLogger
+logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils)
+
+
+def adjust_buffer_batch_size(model: torch.nn.Module, batch_size: int, buffer_names: list[str] = '*') -> None:
+ """
+ Adjusts all buffers in a torch.nn.Module to align with desired inference batch size
+ :param model: model to be adjusted inplace.
+ :param batch_size: batch size to set buffer to.
+ :param buffer_names: list of which buffers to update batch-size, or '*' to update all buffers
+ :return: None
+ """
+ all_buffers = [(name, buffer.ndim) for name, buffer in model.named_buffers() if name in buffer_names or buffer_names == '*']
+ for name, ndim in all_buffers:
+ shape = torch.ones(ndim, dtype=torch.int64) * -1
+ shape[0] = batch_size
+ model._buffers[name] = model._buffers[name][:1].expand(tuple(shape))
+ logger.info(f'Adjusted buffers {[name for name, ndim in all_buffers]} to have batch size {batch_size}')
+
+
+def clip_activations(quantizer, threshold, clip_method):
+ if quantizer.get_scale() > threshold:
+ old_min = quantizer.get_min()
+ old_max = quantizer.get_max()
+ old_scaling = quantizer.get_scale()
+
+ if clip_method == 'minside':
+ k = (old_max - threshold / old_scaling * (old_max - old_min)) / old_min
+ new_max = old_max
+ new_min = old_min * k
+ elif clip_method == 'twoside':
+ k = threshold / old_scaling
+ new_max = old_max * k
+ new_min = old_min * k
+ else:
+ raise ValueError(f'{clip_method} is not supported! `clip_method` should be a choice of [`minside`, `twoside`]')
+
+ quantizer.set_range(new_min, new_max)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/model_input_formatter.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/model_input_formatter.py
new file mode 100644
index 000000000..859d662b9
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/model_input_formatter.py
@@ -0,0 +1,68 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+""" common utilities and class implementation for model inputs adaptation """
+from typing import Dict
+import torch
+from aimet_common.utils import AimetLogger
+
+logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils)
+
+
+def all_equal(iterable):
+ iterator = iter(iterable)
+ try:
+ first = next(iterator)
+ except StopIteration:
+ return True
+ return all(first == x for x in iterator)
+
+
+def _recursor_set_batch_to_1(input_dict, batch_idx):
+ """
+ recursively shrink an input_dict from batch size B to 1,
+ selecting an specific entry withing the batch size, based on batch_idx
+ :param input_dict: flattened input tensor dictionary
+ :param batch_idx: which batch idx to select
+ :return: dictionary adjusted for batch=1
+ """
+ input_batch_sizes = (value.shape[0] for value in input_dict.values() if isinstance(value, torch.Tensor) and value.ndim>1)
+ assert all_equal(input_batch_sizes), "All inputs must have same batch size"
+
+ batch_1_input_dict = {}
+ for k in list(input_dict.keys()):
+ if isinstance(input_dict[k], dict):
+ batch_1_input_dict[k] = _recursor_set_batch_to_1(input_dict[k], batch_idx)
+ elif isinstance(input_dict[k], torch.Tensor) and input_dict[k].ndim >= 2:
+ batch_1_input_dict[k] = torch.unsqueeze(input_dict[k][batch_idx], 0)
+ elif not isinstance(input_dict[k], bool) and input_dict[k] is not None:
+ batch_1_input_dict[k] = torch.tensor(input_dict[k]).reshape(-1)
+ return batch_1_input_dict
+
+
+def generator_set_batch_to_1(input_dict):
+ """
+ generator that consumes an input_dict with data having
+ batch size B and yields samples having batch size 1
+ :param input_dict: flattened input tensor dictionary
+ :return: tuple of Tensor with 4-d tensor adjusted for batch=1
+ """
+ input_batch_sizes = (value.shape[0] for value in input_dict.values() if isinstance(value, torch.Tensor) and value.ndim>1)
+ assert all_equal(input_batch_sizes), "All inputs must have same batch size"
+ input_batch_sizes = (value.shape[0] for value in input_dict.values() if isinstance(value, torch.Tensor) and value.ndim>1)
+
+ for idx in range(next(input_batch_sizes)):
+ batch_1_input_dict = {}
+ for k in list(input_dict.keys()):
+ if isinstance(input_dict[k], dict):
+ batch_1_input_dict[k] = _recursor_set_batch_to_1(input_dict[k], batch_idx=idx)
+ elif isinstance(input_dict[k], torch.Tensor) and input_dict[k].ndim >= 2:
+ batch_1_input_dict[k] = torch.unsqueeze(input_dict[k][idx], 0)
+ elif not isinstance(input_dict[k], bool) and input_dict[k] is not None:
+ batch_1_input_dict[k] = torch.tensor(input_dict[k]).reshape(-1)
+ yield batch_1_input_dict
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/sim_model_wrapper.py b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/sim_model_wrapper.py
new file mode 100644
index 000000000..27834d5eb
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/genai_lib/lvm/sim_model_wrapper.py
@@ -0,0 +1,103 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+""" Proxy model to insert sim.model into an existing model workflow """
+import inspect
+from typing import Callable, Optional
+
+import torch
+from diffusers import DiffusionPipeline
+
+from aimet_common.utils import AimetLogger
+from genai_lib.lvm.calibration import flatten_model_inputs, ModelDataReader
+from aimet_torch.utils import get_device
+
+logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
+
+
+class ModelWrapper(torch.nn.Module):
+ """
+ Adapter class to enable interfacing sim object within original model context.
+ """
+ def __init__(self,
+ model_to_wrap: torch.nn.Module,
+ allow_passthrough: bool,
+ model: torch.nn.Module,
+ output_mapper_function: Callable):
+ super(ModelWrapper, self).__init__()
+ self.original_forward_parameters = inspect.signature(model_to_wrap.forward).parameters
+ self._model = model
+ self.pruned = False
+ self.config = getattr(model_to_wrap, 'config', None)
+ self._model_to_wrap = model_to_wrap if allow_passthrough else None
+ self._output_mapper_fn = output_mapper_function
+
+ def forward(self, *inputs, **kwargs):
+ """ forward proxy """
+ inputs = ModelDataReader.format_inputs((*inputs, kwargs), self.original_forward_parameters)
+ inputs = tuple(flatten_model_inputs(dict(zip(self.original_forward_parameters.keys(), inputs))).values())
+ outputs = self._modules['_model'](*inputs)
+ outputs = self._output_mapper_fn(outputs)
+ return outputs
+
+ def __getattr__(self, name):
+ """
+ method to allow forwarding getattr request to the original model
+ """
+ try:
+ model = self._modules['_model']
+ if name == 'device':
+ return get_device(model)
+ if name == 'dtype':
+ return next(model.parameters()).dtype
+
+ return self.__dict__[name]
+ except KeyError as e:
+ info = inspect.getframeinfo(inspect.currentframe().f_back)
+ original_model = self._modules['_model_to_wrap']
+ if original_model is None:
+ logger.error('Unsupported access for %s on prepared model (%s), from %s (%s #%s), ',
+ name, type(model), info.function, info.filename, info.lineno)
+ raise e
+ item = getattr(original_model, name)
+ if isinstance(item, torch.nn.Module):
+ logger.error('Unsupported access for %s of prepared model modules (%s), from %s (%s #%s), ',
+ name, type(original_model), info.function, info.filename, info.lineno)
+ return item
+
+
+def set_model_in_pipeline(model,
+ target: DiffusionPipeline,
+ target_module: str,
+ use_wrapper: bool,
+ output_mapper_function: Optional[Callable] = None,
+ allow_passthrough: bool = False):
+ """
+ replaces the target module with provided model instance
+ :param model: model to replace with.
+ :param target: pipeline containing the target module
+ :param use_wrapper: If True, sets Wrapper model to adapt with original model context.
+ :param target_module: instance path with the target module e.g. 'vae.decoder'
+ :param output_mapper_function: maps the prepared model output which is tensor or Tuple of Tensor
+ :param allow_passthrough: if True, allow accessing to wrapped model for attribute access.
+ """
+ if '.' in target_module:
+ *parent, target_module = target_module.split('.')
+ for m in parent:
+ target = getattr(target, m)
+
+ if not use_wrapper:
+ setattr(target, target_module, model)
+ return
+
+ if output_mapper_function is None:
+ output_mapper_function = lambda x: x
+
+ model_to_wrap = getattr(target, target_module)
+ logger.info('wrapping model(%s) @ %s', type(model_to_wrap), target_module)
+ setattr(target, target_module, ModelWrapper(model_to_wrap, allow_passthrough, model, output_mapper_function))
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/htp_sc8380xp.json b/microsoft-Phi-4-mini-instruct/QAIRT/htp_sc8380xp.json
new file mode 100644
index 000000000..8040b6c4c
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/htp_sc8380xp.json
@@ -0,0 +1,61 @@
+{
+ "input_model": {
+ "type": "QairtPreparedModel",
+ "model_path": "microsoft/phi-4-mini-instruct"
+ },
+ "passes": {
+ "qp": {
+ "type": "QairtPreparation",
+ "script_path": "phi4.py",
+ "script_config": {
+ "PLATFORM_GEN": 2
+ }
+ },
+ "qgab": {
+ "type": "QairtGenAIBuilder",
+ "log_level": "ERROR",
+ "backend": "HTP",
+ "soc_details": "chipset:SC8380XP",
+ "hvx_threads": 8,
+ "vtcm_size_in_mb": 8,
+ "extended_udma": false,
+ "sequence_lengths": [1, 128],
+ "native_kv": false,
+ "multi_graph": true
+ },
+ "qe": {
+ "type": "QairtEncapsulation",
+ "log_level": "ERROR",
+ "run_checker": "true",
+ "engine_config_overrides": {
+ "n_threads": 3,
+ "htp": {
+ "cpu_mask": "0xe0",
+ "allow_async_init": false,
+ "poll": false,
+ "mmap_budget": 0
+ }
+ },
+ "backend_extensions_overrides": {
+ "graphs": null,
+ "devices": [{"pd_session": "unsigned"}],
+ "groupContext": {"share_resources": false}
+ },
+ "genie_overrides": {
+ "eos_token": [199999, 200020],
+ "positional_encoding": {
+ "rope_theta": 100000,
+ "rope_scaling": {
+ "rope_type": "longrope",
+ "factor": 32,
+ "original_max_position_embeddings": 4096
+ }
+ }
+ }
+ }
+ },
+ "log_severity_level": 1,
+ "output_dir": "models/phi4-mini-instruct-hamoa",
+ "cache_dir": "cache",
+ "no_artifacts": true
+}
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/amazon_dataloader.py b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/amazon_dataloader.py
new file mode 100644
index 000000000..ab3af9a93
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/amazon_dataloader.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+""" utility method to build and pre-process Amazon dataset """
+import torch
+from datasets import IterableDataset, load_dataset
+from functools import partial
+
+
+def get_amazon_dataset(tokenizer, processor, cache_dir, dataset_path="philschmid/amazon-product-descriptions-vlm", append_assistant_response=False):
+ def _map(example, tokenizer, append_assistant_response):
+ user_prompt = """Create a Short Product description based on the provided ##PRODUCT NAME## and ##CATEGORY## and image.
+ Only return description. The description should be SEO optimized and for a better mobile search experience.
+
+ ##PRODUCT NAME##: {product_name}
+ ##CATEGORY##: {category}"""
+
+ content = user_prompt.format(product_name=example["Product Name"], category=example["Category"])
+
+ conversations = [
+ {
+ "role": "system",
+ "content": [{"type": "text", "text": "You are an expert product description writer for Amazon."}],
+ },
+ {
+ "role": "user",
+ "content": [{"type": "image"}, {"type": "text", "text": content}],
+ },
+ ]
+
+ if append_assistant_response:
+ conversations.append(
+ {
+ "role": "assistant",
+ "content": [{"type": "text", "text": example["description"]}],
+ }
+ )
+
+ return {"text": conversations}
+
+ def _transform(example):
+ inputs = tokenizer.apply_chat_template(example['text'], return_tensors="pt", return_dict=True, tokenize=True,
+ add_generation_prompt=True)
+ inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}
+ inputs.update({"pixel_values": torch.tensor(processor(example["image"]).pixel_values).unsqueeze(0)})
+ return inputs
+
+ dataset = load_dataset(path=dataset_path, cache_dir=cache_dir, split='train')
+ dataset = dataset.map(partial(_map, tokenizer=tokenizer, append_assistant_response=append_assistant_response))
+ return dataset.with_transform(_transform)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/disc_dataloader.py b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/disc_dataloader.py
new file mode 100644
index 000000000..0d411deb4
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/disc_dataloader.py
@@ -0,0 +1,82 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+import json
+from itertools import chain
+
+from datasets import load_dataset, Dataset
+from transformers import default_data_collator
+from torch.utils.data import DataLoader
+import copy
+
+class DiSCDataset():
+ def __init__(self, tokenizer, block_size, batch_size, return_train_dataset=True, return_test_dataset=False):
+ self._tokenizer = copy.deepcopy(tokenizer)
+ self._block_size = block_size
+ self._batch_size = batch_size
+ self._return_train_dataset = return_train_dataset
+ self._return_test_dataset = return_test_dataset
+
+ def get_disc_dataloader(self, path):
+ tokenizer = self._tokenizer
+ block_size = self._block_size
+ batch_size = self._batch_size
+
+ dataset = load_dataset(path)
+ categor_list = ['grade_elementary', 'length_long']
+
+ # Process both train and test data
+ processed_datasets = {}
+ for split, data in dataset.items():
+ split_data = {'input_ids': [], 'attention_mask': []}
+ for row_tmp in data:
+ if row_tmp['category'] in categor_list:
+ chat, _, _ = self.process_style_adapt_template(row_tmp, False)
+ row_text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=self._return_test_dataset)
+ tokenizer.pad_token = tokenizer.eos_token
+ row1 = tokenizer(row_text, return_tensors="pt", truncation=True, max_length=block_size, padding="max_length")
+ split_data['input_ids'].append(row1['input_ids'][0])
+ split_data['attention_mask'].append(row1['attention_mask'][0])
+
+ processed_datasets[split] = Dataset.from_dict(split_data)
+
+ collate_fn = default_data_collator
+
+ train_dataloader = DataLoader(
+ processed_datasets['train'], shuffle=False,
+ batch_size=batch_size,
+ collate_fn=collate_fn,
+ ) if self._return_train_dataset else None
+
+ test_dataloader = DataLoader(
+ processed_datasets['train'], shuffle=False,
+ batch_size=batch_size,
+ collate_fn=collate_fn,
+ ) if self._return_test_dataset else None
+
+ return train_dataloader, test_dataloader, processed_datasets
+
+ def process_style_adapt_template(self, row, reverse):
+ system_prompt = 'You are a helpful style adapter. Produce one unique rephrasing of the following content using a different style.'
+ if reverse:
+ user_content = row['generation']
+ assistant_content = row['original']
+ else:
+ user_content = row['original']
+ assistant_content = row['generation']
+
+ if self._return_train_dataset:
+ chat = [{'role': 'system', 'content': system_prompt},
+ {'role': 'user', 'content': f"{user_content}"},
+ {'role': 'assistant', 'content': f"{assistant_content}"}]
+ if self._return_test_dataset :
+ chat = [{'role': 'system', 'content': system_prompt},
+ {'role': 'user', 'content': f"{user_content}"}]
+ #row_text = self.cfg['tokenizer'].apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
+ return chat, user_content, assistant_content
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/forward_pass_wrapper.py b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/forward_pass_wrapper.py
new file mode 100644
index 000000000..c6df03099
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/forward_pass_wrapper.py
@@ -0,0 +1,718 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+#
+# This file contains certain notices of software components included with the
+# software that Qualcomm Technologies, Inc. ("QTI") is required to provide you.
+# Except where prohibited by the open source license, the content of this file is
+# provided solely to satisfy QTI's attribution and notice requirement; your use of
+# these software components together with the QTI software ("Software") is subject
+# to the terms of your license from QTI. Compliance with all copyright laws and
+# software license agreements included in the notice section of this file are the
+# responsibility of the user. Except as may be granted by separate express written
+# agreement, this file provides no license to any patents, trademarks, copyrights,
+# or other intellectual property of Qualcomm Incorporated or any of its
+# subsidiaries.
+#
+# Software provided with this notice is NOT A CONTRIBUTION to any open source
+# project. If alternative licensing is available for any of the components with
+# licenses or attributions provided below, a license choice is made for receiving
+# such code by QTI.
+
+# Copyright (c) 2023 Qualcomm Technologies, Inc. All rights reserved.
+
+# Qualcomm is a trademark of Qualcomm Incorporated, registered in the United
+# States and other countries. All Qualcomm Incorporated trademarks are used with
+# permission. Other products and brand names may be trademarks or registered
+# trademarks of their respective owners.
+#
+# =============================================================================
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =============================================================================
+
+""" utility method to adapt original model, prepared model and model forward pass invocation """
+import inspect
+
+import contextlib
+
+import json
+import math
+import torch
+from torch import nn
+from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
+from aimet_torch.utils import get_device
+
+from packaging import version
+from importlib.metadata import version as impLib_version
+
+from transformers.models.phi3.modeling_phi3 import DynamicCache
+
+class Phi4RotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, partial_rotary_factor=0.75, device=None):
+ super().__init__()
+
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ self.partial_rotary_factor = partial_rotary_factor
+ self.dim = int(dim * self.partial_rotary_factor)
+ self.register_buffer("inv_freq", None, persistent=False)
+
+ @torch.no_grad()
+ def forward(self, x, position_ids, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if self.inv_freq is None:
+ self.inv_freq = 1.0 / (
+ self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
+ )
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+class Phi4LongRoPEScaledRotaryEmbedding(Phi4RotaryEmbedding):
+ def __init__(self, dim, config, partial_rotary_factor, device=None):
+ super().__init__(dim,
+ max_position_embeddings = config.max_position_embeddings,
+ base = config.rope_theta,
+ partial_rotary_factor = partial_rotary_factor,
+ device = device)
+
+ self.short_factor = config.rope_scaling["short_factor"]
+ self.long_factor = config.rope_scaling["long_factor"]
+ self.original_max_position_embeddings = config.original_max_position_embeddings
+
+ @torch.no_grad()
+ def forward(self, x, position_ids, seq_len=None):
+ seq_len = seq_len or torch.max(position_ids) + 1
+ if seq_len > self.original_max_position_embeddings:
+ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
+ else:
+ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
+
+ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
+
+ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
+
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+
+ # Force float32 since bfloat16 loses precision on long contexts
+ # See https://github.com/huggingface/transformers/pull/29285
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+
+ scale = self.max_position_embeddings / self.original_max_position_embeddings
+ if scale <= 1.0:
+ scaling_factor = 1.0
+ else:
+ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+
+ cos = emb.cos() * scaling_factor
+ sin = emb.sin() * scaling_factor
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+def flatten_tensors(tup):
+ if not isinstance(tup, (tuple, list)):
+ yield tup
+ return
+ for x in tup:
+ yield from flatten_tensors(x)
+
+def get_padded_kv_values(past_size, num_layers,hidden_size, num_attention_heads, batch_size=1, num_kv_heads=32,
+ transposed_key_cache=True, device='cuda', dtype=torch.float32):
+
+ def _cache(shape):
+ return torch.zeros(shape).to(device=device, dtype=dtype)
+
+ head_dim = num_kv_heads
+ value = (batch_size, head_dim, past_size, hidden_size // num_attention_heads)
+ key = (value[0], value[1], value[3], value[2]) if transposed_key_cache else tuple(value)
+ past_key_values = tuple((_cache(key), _cache(value)) for _ in range(num_layers))
+ return past_key_values
+
+
+class RopeEmbedding:
+ def __init__(self, device, head_dim=128, max_length=2048, partial_rotary_factor=0.75, config=None):
+ self.cos, self.sin = self.precompute(device, head_dim, partial_rotary_factor, max_length, config)
+
+ def precompute(self, device, head_dim, partial_rotary_factor, max_length, config):
+ def _support_llama3_rope():
+ import transformers
+ return tuple([int(i) for i in transformers.__version__.split(".")]) >= (4,43,2)
+ #return version.parse(impLib_version('transformers')) >= version.parse('4.43.2')
+
+ head_dim = config.head_dim if hasattr(config, 'head_dim') else config.hidden_size // config.num_attention_heads
+ kwargs = {
+ 'max_position_embeddings': config.max_position_embeddings,
+ 'base': config.rope_theta,
+ 'device': device,
+ }
+ if _support_llama3_rope():
+ kwargs['config'] = config
+
+ if not hasattr(config, 'rope_scaling'):
+ setattr(config, 'rope_scaling', None)
+
+ if config.rope_scaling is None:
+ rope = Phi4RotaryEmbedding(
+ dim=head_dim,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rope_theta,
+ partial_rotary_factor=partial_rotary_factor,
+ device=device
+ )
+ else:
+ scaling_type = config.rope_scaling["type"]
+ if scaling_type == "longrope":
+ rope = Phi4LongRoPEScaledRotaryEmbedding(dim = head_dim,
+ config = config,
+ partial_rotary_factor = partial_rotary_factor,
+ device=device)
+ else:
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+ dummy_x = torch.Tensor([1.0]).to(device)
+ # dummy_x device
+ position_ids = torch.arange(max_length).view(1, -1).to(device)
+ if hasattr(rope, '_original_forward'):
+ embeddings = rope._original_forward(dummy_x, position_ids)
+ else:
+ embeddings = rope.forward(dummy_x, position_ids)
+
+ # for adapted llama
+ emb_size = embeddings[0].size(-1) // 2
+ embeddings = [emb[:, :, :emb_size] for emb in embeddings]
+ embeddings = [emb.unsqueeze(0) for emb in embeddings]
+ return embeddings
+
+ def get_embedding(self, position_ids, dtype=torch.float32):
+ '''
+ position_ids: [batch_size, sequence_length]
+ return [batch_size, 1, sequence_length, head_sim//2][2]
+ '''
+ cos = self.cos[0,0,:,:] # [seq_len, dim]
+ sin = self.sin[0,0,:,:] # [seq_len, dim]
+ cos = cos[position_ids].unsqueeze(1).to(dtype=dtype)
+ sin = sin[position_ids].unsqueeze(1).to(dtype=dtype)
+ return cos, sin
+
+def prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length, mask_neg=-100.0, sliding_window=None):
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
+ def _make_causal_mask(
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0,
+ mask_neg: float = -100.0, sliding_window = None
+ ):
+ """
+ Make causal mask used for bi-directional self-attention.
+ """
+ bsz, tgt_len = input_ids_shape[0], input_ids_shape[1]
+ # mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(mask_neg, device=device), device=device)
+ mask_cond = torch.arange(mask.size(-1), device=device)
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+ mask = mask.to(dtype)
+
+ if past_key_values_length > 0:
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+
+ # sliding window attention
+ if sliding_window is not None:
+ sliding_attn_mask = torch.zeros(mask.shape, device=device)
+ sliding_mask_cond = (mask_cond + past_key_values_length - sliding_window).view(mask_cond.size(0), 1)
+ sliding_attn_mask.masked_fill_(torch.arange(tgt_len + past_key_values_length, device=device) <= sliding_mask_cond, mask_neg)
+ mask += sliding_attn_mask
+
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, mask_neg: float = -100.0, tgt_len: int = None):
+ """
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+ """
+ bsz, src_len = mask.size()
+ tgt_len = tgt_len if tgt_len is not None else src_len
+
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+ inverted_mask = 1.0 - expanded_mask
+
+ # return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), mask_neg)
+
+
+ # create causal mask
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ combined_attention_mask = None
+ if input_shape[-1] > 1:
+ combined_attention_mask = _make_causal_mask(
+ input_shape,
+ inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ past_key_values_length=past_key_values_length,
+ mask_neg=mask_neg,
+ sliding_window=sliding_window,
+ )
+
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[1], mask_neg=mask_neg).to(
+ inputs_embeds.device
+ )
+
+ combined_attention_mask = (
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+ )
+
+ return combined_attention_mask
+
+def get_position_embeddings_from_position_ids(position_ids, head_dim, max_length, partial_rotary_factor, device, dtype, config):
+ return RopeEmbedding(device = device,
+ head_dim = head_dim,
+ max_length = max_length,
+ partial_rotary_factor = partial_rotary_factor,
+ config = config).get_embedding(position_ids, dtype=dtype)
+
+def prepare_combined_attention_mask(attention_mask, input_shape, past_key_values_length, device, mask_neg=-100.0,
+ sliding_window=None, dtype=torch.float32):
+ dummy_embedding = torch.tensor((1.0,)).to(torch.float32).to(device)
+ new_mask = prepare_decoder_attention_mask(attention_mask, input_shape, dummy_embedding, past_key_values_length, mask_neg, sliding_window)
+ return new_mask.clamp_min(mask_neg).to(dtype)
+
+
+class LLMForwardPassManager:
+ def __init__(self, cfg, model, tokenizer, separate_tuple_input_output, num_tokens):
+ self.tokenizer = tokenizer
+ self.model = model
+ self.config = cfg
+ self.device = get_device(model)
+
+ self.num_heads = getattr(cfg, 'num_attention_heads', 1)
+ self.num_kv_heads = getattr(cfg, 'num_key_value_heads')
+ self.num_layers = getattr(cfg, 'num_hidden_layers', 32)
+ self.embed_dim = getattr(cfg, 'hidden_size', 1024)
+ self.rope_theta = getattr(cfg, "rope_theta", 10000.0)
+ self.sliding_window = getattr(cfg, "sliding_window", None)
+ self.max_tokens = tokenizer.model_max_length
+ self.num_tokens = num_tokens
+ self.use_position_embedding_input = getattr(cfg, 'use_position_embedding_input', False)
+ self.use_combined_mask_input = getattr(cfg, 'use_combined_mask_input', False)
+ self.transposed_key_cache = getattr(cfg, 'transposed_key_cache', False)
+ self.mask_neg = getattr(cfg, 'mask_neg', -100)
+ self.use_input_embeddings = getattr(cfg, 'use_input_embeddings', False)
+ self.return_new_key_value_only = getattr(cfg, 'return_new_key_value_only', False)
+ self.use_cache = getattr(cfg, 'use_cache', False)
+ self.separate_tuple_input_output = separate_tuple_input_output
+ self.record_test_vectors = False # users of this block wil enable/disable this as necessary with provided functions
+ self.dummy_kvcache_generator = None # DummyKvcacheGenerator(cfg)
+ self.input_id_to_embedding_converter = None
+ self.partial_rotary_factor = getattr(cfg, 'partial_rotary_factor', 0.75)
+
+
+ @property
+ def dtype(self):
+ return next(self.parameters()).dtype
+
+ def replace_model(self, new_model):
+ self.model = new_model
+ self.model.to(self.device)
+
+ @contextlib.contextmanager
+ def place_on_device(self, device):
+ original_device = self.device
+ try:
+ self.to(device)
+ yield
+ finally:
+ self.to(original_device)
+
+ def to(self, device=torch.device):
+ self.device = torch.device(device)
+ self.model.to(self.device)
+
+ def named_modules(self):
+ return self.model.named_modules()
+
+ def parameters(self):
+ return self.model.parameters()
+
+ def _tokenize_text(self, text, max_length):
+ if self.tokenizer == None:
+ print(
+ "No tokenizer was registered with forward pass manager. Attempt to forward text inputs has failed.")
+ assert False
+
+ encoded_tensor = self.tokenizer(text, add_special_tokens=False, max_length=max_length, truncation=True)
+ return encoded_tensor
+
+ def _update_kv_cache(self, prev_key_value, new_key_value, max_cache_size, is_concatenated=False):
+ # past_key_value: [num_layers][2][key_value], where key_value can be a tensor or tuple of heads
+ def _concat(a, b, dim):
+ if isinstance(a, tuple):
+ assert len(a) == len(b), 'Unexpected key/value pair'
+ return tuple(_concat(ai, bi, dim) for ai, bi in zip(a, b))
+ return torch.cat((a, b), dim=dim)
+
+ def _do_concat(a, b, key_dim, value_dim):
+ return tuple((_concat(ak, bk, key_dim), _concat(av, bv, value_dim)) for (ak, av), (bk, bv) in zip(a, b))
+
+ def _shift(a, dim, shift_size):
+ if isinstance(a, tuple):
+ return tuple(_shift(ai, dim) for ai in a)
+ assert dim in (2, 3), 'Unexpected shift axis'
+ return a[:, :, shift_size:, :] if dim == 2 else a[:, :, :, shift_size:]
+
+ def _do_shift(a, key_dim, value_dim, shift_size):
+ return tuple((_shift(k, key_dim, shift_size), _shift(v, value_dim, shift_size)) for k, v in a)
+
+ value_dim = 2
+ key_dim = 3 if self.transposed_key_cache else 2
+
+ if prev_key_value is None or is_concatenated:
+ # some models concat new key values and old key values internally
+ # `is_concatenated` indicates whether new_key_value is already concatenated
+ next_key_value = new_key_value
+ elif new_key_value is None:
+ # when dummy_kv + None
+ next_key_value = prev_key_value
+ else:
+ # if concat is NOT done, then concat
+ next_key_value = _do_concat(prev_key_value, new_key_value, key_dim, value_dim)
+
+ shift_size = next_key_value[0][1].shape[-2] - max_cache_size
+ if shift_size > 0:
+ next_key_value = _do_shift(next_key_value, key_dim, value_dim, shift_size)
+
+ return next_key_value
+
+ def validate_inputs(self, input_text=None, input_ids=None, input_embeddings=None, past_key_values=None):
+ # make sure only one of input_text, input_ids, input_embeddings is passed in
+ input_count = 0
+ for input in (input_text, input_ids, input_embeddings):
+ if input is not None:
+ input_count = input_count + 1
+ if input_count != 1:
+ print("Incorrect number of arguments: one of (input_text, input_ids, input_embeddings) expected.")
+ return False
+
+ # make sure that input embedding function has been selected if input embeddings are to be used
+ if self.use_input_embeddings and self.input_id_to_embedding_converter is None and input_embeddings is None:
+ print(
+ "use_input_embeddings is set to true, but no input_embeddings were provided, and input_id_to_embedding_converter is None.")
+ return False
+
+ if past_key_values is not None and past_key_values[0][1].shape[-2] > self.max_tokens - self.num_tokens:
+ print(
+ "Provided past_key_values are too long. past_key_values length cannot exceed max_tokens - num_tokens.")
+ return False
+
+ return True
+
+ def validate_input_lengths(self, input_length, mask_length, attn_length):
+ if 1 > input_length or input_length > self.num_tokens:
+ print(
+ f"Incorrect sequence length provided: input_length({input_length}) must be less than or equal to num_tokens ({self.num_tokens}).")
+ return False
+
+ if attn_length < mask_length or mask_length < input_length:
+ print(
+ f"Incorrect attention length provided: mask_length({mask_length}) must be greater than or equal to input_length({input_length}) and less than or equal to the sum({attn_length}) of input_length and kv_length.")
+ return False
+
+ return True
+
+ def validate_processed_inputs(self, input=None, attention_mask=None, past_key_values=None):
+ # if input make sure that only correct length sequence is provided
+ if input.shape[1] != self.num_tokens:
+ print(
+ f"Incorrect prcessing for sequence length: dim 1({input.shape[1]}) of input must be of length num_tokens in KV cache mode.")
+ return False
+
+ if attention_mask.shape[1] != self.max_tokens:
+ print(
+ f"Incorrect prcessing for attention length: dim 1({attention_mask.shape[1]}) of input must be of length max_tokens.")
+ return False
+
+ if past_key_values is not None and past_key_values[0][1].shape[-2] != self.max_tokens - self.num_tokens:
+ print(
+ f"Incorrect prcessing for past_kv length: dim 1({past_key_values[0][1].shape[-2]}) of input must be of length max_tokens - num_tokens.")
+ return False
+
+ return True
+
+ def get_position_embeddings_from_position_ids(self, position_ids):
+ return get_position_embeddings_from_position_ids(position_ids,
+ head_dim=self.embed_dim // self.num_heads,
+ max_length=self.max_tokens,
+ partial_rotary_factor = self.partial_rotary_factor,
+ device=self.device,
+ dtype=self.dtype,
+ config=self.config)
+
+ def prepare_combined_attention_mask(self, attention_mask, input_shape, past_kv_length, sliding_window):
+ return prepare_combined_attention_mask(attention_mask, input_shape=input_shape,
+ past_key_values_length=past_kv_length, device=self.device,
+ mask_neg=self.mask_neg, sliding_window=sliding_window, dtype=self.dtype)
+
+ def prepare_inputs(self, input_text=None, input_ids=None, input_embeddings=None, attention_mask=None,
+ past_key_values=None, **kwargs):
+ assert self.validate_inputs(input_text, input_ids, input_embeddings, past_key_values)
+
+ kvcache_info_bundle = {} # dict to hold values needed for KV cache post-processing
+
+ if input_text is not None:
+ max_length = self.num_tokens
+ encoded = self._tokenize_text(input_text, max_length=max_length)
+ input_ids = encoded.input_ids
+ attention_mask = encoded.attention_mask
+
+ if self.use_input_embeddings:
+ if input_embeddings is None:
+ input_embeddings = self.input_id_to_embedding_converter(input_ids).to(dtype=self.dtype)
+ input = input_embeddings
+ # if we cast this input to long, all floats become zero in the input which we do not want
+ input = torch.tensor(input.clone().detach(), dtype=self.dtype, device=self.device)
+ else:
+ input = input_ids
+ input = torch.tensor(input.clone().detach(), dtype=torch.long, device=self.device)
+ batch_size = input.shape[0]
+ input_length = input.shape[1]
+
+ kvcache_info_bundle["input_length"] = input_length
+
+ # get kv_length from past values because values are not transposed.
+ kv_length = past_key_values[0][1].shape[-2] if past_key_values is not None else 0
+ attn_length = min(input_length + kv_length, self.max_tokens)
+
+ # Checking attention_mask first, otherwise we will create attention_mask from input_extensions.
+ # input_extensions will be empty tensors and so as attention_mask.
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, input_length + kv_length), dtype=torch.long, device=self.device)
+
+ # cast type and move device
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = attention_mask.to(dtype=torch.long, device=self.device)
+ else:
+ # if attention_mask is not a tensor, get tensor
+ attention_mask = torch.tensor(attention_mask, dtype=torch.long, device=self.device)
+ mask_length = attention_mask.shape[1]
+
+ assert self.validate_input_lengths(input_length, mask_length, attn_length)
+
+ # Pad inputs
+ if input_length < self.num_tokens:
+ shape = (batch_size, self.num_tokens - input_length)
+ # expand shape if input is input_embeddings
+ if self.use_input_embeddings:
+ shape += (input.shape[-1],)
+ input_extensions = torch.full(
+ shape,
+ fill_value=self.tokenizer.eos_token_id,
+ dtype=input.dtype,
+ device=self.device
+ )
+ input = torch.cat((input_extensions, input), dim=1)
+
+ # Pad attention_mask
+ attention_mask_extension_for_padded_kvcache = torch.zeros((batch_size, attn_length - mask_length),
+ dtype=torch.long, device=self.device)
+ attn_mask_extensions_for_padded_input = torch.zeros((batch_size, self.num_tokens - input_length), \
+ dtype=torch.long, device=self.device)
+ attention_mask = torch.cat((
+ attention_mask_extension_for_padded_kvcache,
+ attention_mask[:, :-input_length],
+ attn_mask_extensions_for_padded_input,
+ attention_mask[:, -input_length:]
+ ), dim=1
+ )
+
+ desired_kv_length = self.max_tokens - self.num_tokens
+ kv_padding_length = max(desired_kv_length - kv_length, 0)
+ kvcache_info_bundle['kv_padding_length'] = kv_padding_length
+
+ past_key_values_extension = get_padded_kv_values(past_size=kv_padding_length,
+ num_layers=self.num_layers,
+ hidden_size=self.embed_dim,
+ num_attention_heads=self.num_heads,
+ num_kv_heads=self.num_kv_heads,
+ transposed_key_cache=self.transposed_key_cache,
+ device=self.device,
+ dtype=self.dtype)
+ past_key_values = self._update_kv_cache(past_key_values_extension, past_key_values, desired_kv_length)
+
+ attention_mask_extension = torch.zeros((batch_size, kv_padding_length), dtype=torch.long,
+ device=self.device)
+ attention_mask = torch.cat((attention_mask_extension, attention_mask), dim=1)
+
+ assert self.validate_processed_inputs(input, attention_mask, past_key_values)
+
+ position_ids = torch.cumsum(attention_mask, dim=1) - 1
+ position_ids = position_ids.clip(0, self.max_tokens - 1)
+ position_ids = position_ids[..., -self.num_tokens:]
+
+ if self.use_position_embedding_input:
+ position_ids = self.get_position_embeddings_from_position_ids(position_ids)
+
+ if self.use_combined_mask_input:
+ past_kv_length = self.max_tokens - self.num_tokens
+ attention_mask = self.prepare_combined_attention_mask(attention_mask, input.shape, past_kv_length, self.sliding_window)
+
+ inputs = {
+ 'attention_mask': attention_mask,
+ }
+
+ if self.separate_tuple_input_output and self.config.use_position_embedding_input:
+ inputs['position_ids_cos'] = position_ids[0]
+ inputs['position_ids_sin'] = position_ids[1]
+ else:
+ inputs['position_ids'] = position_ids
+
+ if self.use_input_embeddings:
+ inputs['inputs_embeds'] = input
+ else:
+ inputs['input_ids'] = input
+
+ if self.separate_tuple_input_output:
+ if "input_names" in kwargs:
+ input_names = kwargs['input_names']
+ else:
+ signature = inspect.signature(self.model.forward)
+ input_names = tuple(signature.parameters.keys())
+ flattened_key_values = flatten_tensors(past_key_values)
+ # input_ids, attention_mask, position_ids_cos, position_ids_sin, (past_key_values)
+ # this order is different when we use the input_embeddings -> attention_mask, position_ids_cos, position_ids_sin, (past_key_values), inputs_embeds
+ if not self.use_input_embeddings:
+ offset = 4 if self.config.use_position_embedding_input else 3
+ for key, value in zip(input_names[offset:], flattened_key_values):
+ inputs[key] = value
+ else:
+ for key, value in zip(input_names[3:-1], flattened_key_values):
+ inputs[key] = value
+ else:
+ inputs['past_key_values'] = past_key_values
+
+ # print("past_key_values:", type(inputs['past_key_values']))
+ # print(inputs.keys())
+ # print("lm_logits:", type(lm_logits))
+
+ return inputs, kvcache_info_bundle
+
+ def prepare_outputs(self, outputs, prepared_inputs, kvcache_info_bundle):
+ """
+ Args:
+ outputs (tuple): Tuple of model outputs.
+ outputs[0]: logits (batch, num_tokens, vocab_size)
+ outputs[-1]: kv caches with max_tokens length
+ prepared_inputs (dict): Dictionary of prepared inputs.
+ kvcache_info_bundle (dict): Dictionary containing information about key-value cache.
+
+ Returns:
+ dict: A dictionary containing 'lm_logits' and 'past_key_values'.
+ lm_logits: (batch, num_tokens, vocab_size)
+ past_key_values: having length as the number of non-dummy inputs
+ """
+ lm_logits = outputs[0]
+ lm_logits = lm_logits[:, -kvcache_info_bundle["input_length"]:, :]
+
+ def _get_past_kv_from_outputs(outputs):
+ if self.separate_tuple_input_output:
+ return tuple((outputs[(2 * i) + 1], outputs[(2 * i) + 2]) for i in range(self.num_layers))
+ else:
+ return outputs[-1]
+
+ def _get_past_kv_from_prepared_inputs(prepared_inputs):
+ if self.separate_tuple_input_output:
+ return tuple((prepared_inputs[f"past_key_{i}_in"], prepared_inputs[f"past_value_{i}_in"]) for i in range(self.num_layers))
+ else:
+ return prepared_inputs['past_key_values'] if 'past_key_values' in prepared_inputs else None
+
+ new_past_key_values = _get_past_kv_from_outputs(outputs)
+ new_past_key_values = self._update_kv_cache(
+ None,
+ new_past_key_values,
+ kvcache_info_bundle["input_length"]
+ )
+ old_past_key_values = _get_past_kv_from_prepared_inputs(prepared_inputs)
+
+ current_kv_length_with_padding_removed = self.max_tokens - self.num_tokens - kvcache_info_bundle[
+ 'kv_padding_length'] + kvcache_info_bundle['input_length'] # number of non-dummy inputs
+
+ past_key_values = self._update_kv_cache(
+ old_past_key_values,
+ new_past_key_values,
+ current_kv_length_with_padding_removed
+ )
+
+ return {'lm_logits': lm_logits, 'past_key_values': past_key_values}
+
+ def __call__(self, *args, **kwargs):
+ prepared_inputs, kvcache_info_bundle = self.prepare_inputs(*args, **kwargs)
+ outputs = self.model(**prepared_inputs)
+ prepared_outputs = self.prepare_outputs(outputs, prepared_inputs, kvcache_info_bundle)
+ return prepared_outputs
+
+
+def slice_inputs_and_run_successive_kvcache_inference(fpm, input_ids=None, input_embeds=None, **kwargs):
+ if input_ids is not None:
+ input_length = input_ids.shape[1]
+ else:
+ input_length = input_embeds.shape[1]
+
+ outputs = {}
+
+ attention_mask = kwargs.pop('attention_mask', None)
+
+ for idx in range(0, input_length, fpm.num_tokens)[::-1]:
+ idx = input_length - idx
+
+ if attention_mask is not None:
+ cache_offset = attention_mask.shape[1] - input_length
+ kwargs["attention_mask"] = attention_mask[:, max(0, cache_offset + idx - fpm.max_tokens):cache_offset + idx]
+
+ if input_ids is not None:
+ cur_outputs = fpm(input_ids=input_ids[:, max(0, idx - fpm.num_tokens):idx], **kwargs)
+ elif input_embeds is not None:
+ cur_outputs = fpm(input_ids=None, input_embeddings=input_embeds[:, max(0, idx - fpm.num_tokens):idx, :],
+ **kwargs)
+ else:
+ print("No input_ids or inputs_embeds provided to inference generator!")
+ assert False
+
+ # get valid outputs
+ bsz, length, dim = cur_outputs['lm_logits'].shape
+
+ outputs['lm_logits'] = torch.cat(
+ (outputs.get('lm_logits', torch.zeros((bsz, 0, dim), device=fpm.device)), cur_outputs['lm_logits']),
+ dim=1)
+ kwargs['past_key_values'] = outputs['past_key_values'] = cur_outputs['past_key_values']
+
+ return outputs
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/llava_dataloader.py b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/llava_dataloader.py
new file mode 100644
index 000000000..be8ce83f5
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/llava_dataloader.py
@@ -0,0 +1,76 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+""" utility method to build and pre-process WikiText dataset """
+import os
+import torch
+from itertools import chain
+from torch.utils.data import DataLoader, Dataset
+from datasets import IterableDataset, load_dataset
+from transformers import default_data_collator
+from PIL import Image as PILImage
+from transformers.feature_extraction_utils import BatchFeature
+
+def _convert_one_conversation(conversation: list[dict[str, str]]) -> dict[str, str | list[dict]]:
+ """Convert the given conversation to LLaVA style.
+
+ Examples:
+
+ >>> conversation = {"from": "human", "value": "What are the colors of the bus in the image?"}
+ >>> LlavaConversationProcessor._convert(conversation)
+ {
+ 'role': 'user',
+ 'content': [{'type': 'image'}, {'type': 'text', 'text': 'What are the colors of the bus in the image?'}]
+ }
+ >>> conversation = {"from": "gpt", "value": "The bus in the image is white and red."}
+ >>> _convert(conversation)
+ {
+ 'role': 'assistant',
+ 'content': [{'type': 'text', 'text': 'The bus in the image is white and red.'}]
+ }
+ """
+ who = conversation.get("from")
+ match who:
+ case "human":
+ role = "user"
+ case "gpt":
+ role = "assistant"
+ case _:
+ raise ValueError(f"Unknown role: {who}")
+
+ text = conversation.get("value")
+
+ if "" in text:
+ has_image = True
+ text = text.replace("", "")
+ else:
+ has_image = False
+
+ return {
+ "role": role,
+ "content": (
+ [{"type": "image"}, {"type": "text", "text": text}] if has_image else [{"type": "text", "text": text}]
+ ),
+ }
+
+def get_llava_dataset(tokenzier, processor, data_files, dataset_path, cache_dir):
+ def _map(examples):
+ examples['text'] = [_convert_one_conversation(conversation=conversation) for conversation in
+ examples['conversations']]
+ return examples
+
+ def _load_image_and_tokenize(example):
+ inputs = tokenzier.apply_chat_template(example['text'], add_generation_prompt=True, tokenize=True,
+ return_tensors="pt", return_dict=True)
+ inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}
+ inputs.update({"pixel_values": torch.tensor(processor(PILImage.open(fp=os.path.join(dataset_path, example["image"][0]))).pixel_values).unsqueeze(0)})
+ return inputs
+
+ dataset = load_dataset("json", data_files=data_files, cache_dir=cache_dir, split='train')
+ dataset = dataset.map(_map)
+ return dataset.with_transform(_load_image_and_tokenize)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/lora_utils.py b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/lora_utils.py
new file mode 100644
index 000000000..d8853cc2c
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/lora_utils.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+import os
+import re
+
+import torch
+from safetensors.torch import save_file
+from aimet_torch.peft import LoraLayer
+
+from peft.tuners.lora.layer import LoraLayer as PeftLoraLayer
+from peft import PeftMixedModel
+from typing import Dict
+
+def set_lora_scaling(
+ model: PeftMixedModel,
+ lora_scaling: Dict[str, float]
+):
+
+ """Set the LoRA adapters' scaling parameter to the given value.
+
+ Args:
+ model (PeftModel): The model for which to set the scale parameter.
+ scale (float): The scale value to set as a dictionary ['adapter_name': value].
+ """
+ for name, module in model.named_modules():
+ if isinstance(module, PeftLoraLayer):
+ for adapter_name, scaling in lora_scaling.items():
+ module.scaling[adapter_name] = scaling
+
+
+def save_lora_weights_after_adaptation(model: torch.nn.Module, path: str, filename_prefix: str):
+ """
+ Utility to save model weights after model adaptations
+
+ :param model: PEFT model
+ :param path: path where to store weights after adaptation
+ :param filename_prefix: Prefix to use for filenames
+ """
+ param_to_name = {}
+
+ for name, param in model.named_parameters():
+ param_to_name[param] = name
+
+ lora_weights = {}
+ for name, module in model.named_modules():
+ if isinstance(module, LoraLayer):
+ for _, param in module.lora_A.named_parameters():
+ name = param_to_name[param]
+ lora_weights[name] = param
+ for _, param in module.lora_B.named_parameters():
+ name = param_to_name[param]
+ lora_weights[name] = param
+
+ filename_prefix = filename_prefix + '.safetensor'
+ model_params_path = os.path.join(path, filename_prefix)
+ save_file(lora_weights, model_params_path)
+
+def adapt_true_base_encodings(true_base_encodings_path: str, target_modules: list, output_path: str):
+ """
+ Modifies the encodings by appending "_base_layer" to the target module substrings
+ within the file and saves the updated content.
+ :param true_base_encodings_path (str): Path to the encodings file.
+ :param target_modules (list): List of module names to search for and modify.
+ :param output_path (str): Path to save the modified encodings.
+ """
+ # Read the content of the file
+ with open(true_base_encodings_path, 'r') as file:
+ lines = file.readlines()
+ # Initialize modified lines
+ modified_lines = []
+ for line in lines:
+ modified_line = line
+ for module in target_modules:
+ if module in line:
+ # Append ".base_layer" after the module substring
+ modified_line = modified_line.replace(module, f"{module}.base_layer")
+ modified_lines.append(modified_line)
+ # Write the modified lines back to the output file
+ with open(output_path, 'w') as file:
+ file.writelines(modified_lines)
+
+def adapt_peft_config_for_lora_meta_data(adapter_name, peft_config):
+ """
+ Given the target_modules in the HuggingFace PEFT config, find all the corresponding full layer names
+ in the JSON prepare layer map. Return a list of the new target modules.
+ :param adapter_name (str): Name of the adapter.
+ :param peft_config (dict): Individual HuggingFace PEFT config containing `target_modules`.
+ Returns:
+ set: A list of new target modules (full prepared graph layer names matching the criteria).
+ """
+ mpp_peft_config = {}
+ mpp_peft_config["name"] = adapter_name
+ mpp_peft_config["rank"] = peft_config.__dict__["r"]
+ mpp_peft_config["alpha"] = peft_config.__dict__["lora_alpha"]
+ mpp_peft_config["target_modules"] = list(peft_config.target_modules)
+ return mpp_peft_config
+
+
+def convert_linear_to_conv_weights(weights_dict):
+ """
+ Convert 2D linear LoRA weights to 4D format compatible with convolutional operations.
+ Processes a state dict and returns the converted version.
+ :param weights_dict: Dictionary containing LoRA weights
+ Returns:
+ Dictionary with weights converted to convolutional format
+ """
+ converted_weights = {}
+ for name, tensor in weights_dict.items():
+ if 'lora_A' in name or 'lora_B' in name:
+ converted_weights[name] = tensor.unsqueeze(-1).unsqueeze(-1)
+ else:
+ converted_weights[name] = tensor
+ return converted_weights
+
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/mixed_precision_overrides.py b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/mixed_precision_overrides.py
new file mode 100644
index 000000000..6f329f87b
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/mixed_precision_overrides.py
@@ -0,0 +1,147 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+""" util to consume mixed precision config and apply to quantizer """
+import re
+import json
+from collections import defaultdict
+
+from aimet_torch.v2.nn.base import BaseQuantizationMixin
+from aimet_torch.v2.quantization.float.quantizer import FloatQuantizeDequantize
+from aimet_torch.quantsim_config.builder import LazyQuantizer
+from aimet_common.defs import QuantScheme, QuantizationDataType
+
+
+def apply_input_output_exception(quantizer, exception):
+ assert isinstance(exception, dict), f"Exception is not a dictionary type: {exception}"
+
+ if exception.get("enabled", True) is False:
+ return None
+
+ if quantizer is None:
+ quantizer = LazyQuantizer(exception.get("bitwidth", 16),
+ 'nearest',
+ QuantScheme.post_training_tf,
+ exception.get("asymmetric", True) is False,
+ enabled_by_default=True,
+ data_type=QuantizationDataType.float if exception.get("data_type", "int") == "float" else QuantizationDataType.int
+ ).realize()
+
+ if "bitwidth" in exception.keys():
+ if isinstance(quantizer, FloatQuantizeDequantize):
+ assert exception['bitwidth'] in [8, 16], "Bitwidth for FloatQuantizeDequantize should only be 8 or 16"
+ # For quantizers with float dtype, we can't set value to the property "bitwidth" directly,
+ # we should set the "exponent_bits" and "mantissa_bits" correspondingly
+ quantizer.exponent_bits = 5 if exception['bitwidth'] == 16 else 4
+ quantizer.mantissa_bits = 10 if exception['bitwidth'] == 16 else 3
+ else:
+ quantizer.bitwidth = exception['bitwidth']
+
+ if "asymmetric" in exception.keys():
+ quantizer._symmetric = not exception["asymmetric"]
+ quantizer._signed = not exception["asymmetric"]
+ assert (quantizer._symmetric is True and quantizer._signed is True) or (quantizer._symmetric is False and quantizer._signed is False), "symmetric and signed must be aligned (True or False)"
+
+ if "encoding_overrides" in exception.keys():
+ encodings = exception['encoding_overrides']
+ assert isinstance(encodings, dict)
+
+ if not "is_symmetric" in encodings.keys():
+ assert hasattr(quantizer, "_symmetric"), f"Quantizer {quantizer} doesn't have attribute \"_symmetric\""
+ # set_legacy_encodings function expects the value of is_symmetric to be a string
+ encodings["is_symmetric"] = str(quantizer._symmetric)
+
+ if not "bitwidth" in encodings.keys():
+ assert hasattr(quantizer, "bitwidth"), f"Quantizer {quantizer} doesn't have attribute \"bitwidth\""
+ encodings["bitwidth"] = quantizer.bitwidth
+
+ quantizer.set_legacy_encodings([encodings])
+ quantizer._allow_overwrite = False
+
+ return quantizer
+
+def apply_param_exception(quantizer, exception):
+ assert quantizer
+ assert isinstance(exception, dict), f"Exception is not a dictionary type: {exception}"
+
+ if exception.get("enabled", True) is False:
+ quantizer._allow_overwrite = False
+
+ quantizer.symmetric = exception.get("asymmetric", False) is False
+ quantizer.signed = exception.get("asymmetric", False) is False
+ assert (quantizer.symmetric is True and quantizer.signed is True) or (quantizer.symmetric is False and quantizer.signed is False), "symmetric and signed must be aligned (True or False)"
+ quantizer.bitwidth = exception.get("bitwidth", 4)
+
+ return quantizer
+
+def get_module_exception(module, exception_dict):
+ if type(module).__name__ in exception_dict.keys():
+ return exception_dict[module._get_name()]
+
+ return None
+
+
+class ManualQuantsimMixedPrecisionConfig:
+ def __init__(self, mixed_precision_config_file):
+ exception_types = ("name", "module")
+ exceptions = {k: defaultdict(lambda: nop_exception) for k in exception_types}
+
+ with open(mixed_precision_config_file) as f:
+ exception_config = json.load(f)
+
+ # Populate op_list here
+ for etype in exception_types:
+ for item in exception_config[f'{etype}_list']:
+ expections_str = {k: v for k, v in item['exceptions'].items() if v is not None}
+ print(f"Applying {item['module_name']}:\t{expections_str}")
+ exceptions[etype].update({item['module_name']: item['exceptions']})
+
+ self.exceptions_dict = exceptions
+
+ def apply_exceptions(self, quant_sim):
+ for etype in ("module", "name"):
+ exception_modules = self.exceptions_dict[etype].keys()
+
+ for name, module in quant_sim.model.named_modules():
+ if isinstance(module, BaseQuantizationMixin):
+ exception = None
+ if etype == 'module':
+ exception = get_module_exception(module, self.exceptions_dict[etype])
+ elif etype == 'name':
+ for key in exception_modules:
+ if "*" in key and key.replace("*", "") in name:
+ exception = self.exceptions_dict[etype][key]
+ else:
+ match = re.fullmatch(key, name)
+ if match:
+ exception = self.exceptions_dict[etype][key]
+
+ if exception is not None:
+ if exception["param_exceptions"] is not None:
+ self.apply_param_exception_to_module(exception["param_exceptions"], module)
+
+ if exception["input_exceptions"] is not None:
+ self.apply_input_exception_to_module(exception["input_exceptions"], module)
+
+ if exception["output_exceptions"] is not None:
+ self.apply_output_exception_to_module(exception["output_exceptions"], module)
+
+ def apply_param_exception_to_module(self, param_exceptions, module):
+ is_enable = (int(param_exceptions["bitwidth"]) < 32) if "bitwidth" in param_exceptions else True
+ if not is_enable:
+ param_exceptions["enabled"] = is_enable
+
+ module.param_quantizers['weight'] = apply_param_exception(module.param_quantizers['weight'], param_exceptions)
+
+ def apply_input_exception_to_module(self, input_exceptions, module):
+ for index in range(len(input_exceptions)):
+ module.input_quantizers[index] = apply_input_output_exception(module.input_quantizers[index], input_exceptions[index])
+
+ def apply_output_exception_to_module(self, output_exceptions, module):
+ for index in range(len(output_exceptions)):
+ module.output_quantizers[index] = apply_input_output_exception(module.output_quantizers[index], output_exceptions[index])
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/qcphi4_adaptation.py b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/qcphi4_adaptation.py
new file mode 100644
index 000000000..b61270f22
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/qcphi4_adaptation.py
@@ -0,0 +1,506 @@
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+#
+# This file contains certain notices of software components included with the
+# software that Qualcomm Technologies, Inc. ("QTI") is required to provide you.
+# Except where prohibited by the open source license, the content of this file is
+# provided solely to satisfy QTI's attribution and notice requirement; your use of
+# these software components together with the QTI software ("Software") is subject
+# to the terms of your license from QTI. Compliance with all copyright laws and
+# software license agreements included in the notice section of this file are the
+# responsibility of the user. Except as may be granted by separate express written
+# agreement, this file provides no license to any patents, trademarks, copyrights,
+# or other intellectual property of Qualcomm Incorporated or any of its
+# subsidiaries.
+#
+# Software provided with this notice is NOT A CONTRIBUTION to any open source
+# project. If alternative licensing is available for any of the components with
+# licenses or attributions provided below, a license choice is made for receiving
+# such code by QTI.
+
+# Copyright (c) 2023 Qualcomm Technologies, Inc. All rights reserved.
+
+# Qualcomm is a trademark of Qualcomm Incorporated, registered in the United
+# States and other countries. All Qualcomm Incorporated trademarks are used with
+# permission. Other products and brand names may be trademarks or registered
+# trademarks of their respective owners.
+#
+# ==============================================================================
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =============================================================================
+
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.phi3.modeling_phi3 import (
+ repeat_kv,
+ Cache,
+ DynamicCache,
+ Phi3Attention,
+ Phi3Config,
+ apply_rotary_pos_emb,
+ Phi3ForCausalLM,
+ Phi3MLP,
+)
+
+def _apply_rope_single(x, rope_vals: Tuple[torch.Tensor, torch.Tensor]):
+ '''
+ Based on FacebookResearch's llama, provided by Carl
+ '''
+ rope_real = rope_vals[0] # shape should be 1, 1, seqlen, head_dim * partial_rotary_factor
+ rope_im = rope_vals[1] # shape should be 1, 1, seqlen, head_dim * partial_rotary_factor
+
+ # TODO: Why HF uses different coordinates from the paper
+ x_real = x[...,:x.shape[-1]//2] # extract first half elements
+ x_im = x[...,x.shape[-1]//2:] # extract second half elements
+
+ x_prod_real = x_real * rope_real - x_im * rope_im
+ x_prod_im = x_real * rope_im + x_im * rope_real
+
+ # TODO: HF need to uses different interleaving
+ x = torch.cat((x_prod_real,x_prod_im),dim=3).view(*x.shape)
+ return x
+ # return x_prod_real, x_prod_im
+
+def bypass_Phi4RotaryEmbedding(self, x, position_ids, *args, **kwargs):
+ return position_ids
+
+class QcPhi4Attention(Phi3Attention):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+ #QC
+ return_new_key_value_only = self.config.return_new_key_value_only if hasattr(self.config, 'return_new_key_value_only') else False
+ transposed_key_cache = self.config.transposed_key_cache if hasattr(self.config, 'transposed_key_cache') else False
+ partial_rotary_factor = self.config.partial_rotary_factor if hasattr(self.config, 'partial_rotary_factor') else 0.75
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv = self.qkv_proj(hidden_states)
+ query_pos = self.config.num_attention_heads * self.head_dim
+ query_states = qkv[..., :query_pos]
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+ query_states = query_states.view(bsz, q_len, self.config.num_attention_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ # Partial rotary embedding
+ rope_dim = int(partial_rotary_factor * self.head_dim)
+ query_rot, query_pass = (
+ query_states[..., : rope_dim],
+ query_states[..., rope_dim :],
+ )
+ key_rot, key_pass = (
+ key_states[..., : rope_dim],
+ key_states[..., rope_dim :],
+ )
+
+ if isinstance(position_ids, (tuple, list)): # QC
+ rope_embedding = position_ids
+ cos, sin = rope_embedding
+ query_rot = _apply_rope_single(query_rot, rope_embedding)
+ key_rot = _apply_rope_single(key_rot, rope_embedding)
+ else:
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
+
+ # [batch_size, seq_length, num_heads, head_dim]
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
+
+ if transposed_key_cache: # QC
+ key_states = key_states.transpose(2, 3)
+
+ if past_key_value is not None:
+ assert isinstance(past_key_value, DynamicCache)
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position,
+ "return_new_key_value_only": return_new_key_value_only,
+ "transposed_key_cache": transposed_key_cache,
+ }
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if transposed_key_cache: # QC
+ attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim)
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.config.num_attention_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.config.num_attention_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+ def prepare_conv(self):
+ if not hasattr(self, 'forward_no_conv'):
+ self.q_proj_conv = nn.Conv2d(self.config.hidden_size, self.config.num_attention_heads * self.head_dim, 1, bias=False)
+ self.k_proj_conv = nn.Conv2d(self.config.hidden_size, self.num_key_value_heads * self.head_dim, 1, bias=False)
+ self.v_proj_conv = nn.Conv2d(self.config.hidden_size, self.num_key_value_heads * self.head_dim, 1, bias=False)
+ self.o_proj_conv = nn.Conv2d(self.config.num_attention_heads * self.head_dim, self.config.hidden_size, 1, bias=False)
+
+ self.forward_no_conv = self.forward
+ self.forward = self.forward_conv
+
+ query_pos = self.config.num_attention_heads * self.head_dim
+ kv_dim = self.num_key_value_heads * self.head_dim
+ self.q_proj_conv.weight.data.copy_(self.qkv_proj.weight[:query_pos, :, None, None])
+ self.k_proj_conv.weight.data.copy_(self.qkv_proj.weight[query_pos : query_pos + kv_dim, :, None, None])
+ self.v_proj_conv.weight.data.copy_(self.qkv_proj.weight[query_pos + kv_dim :, :, None, None])
+ self.o_proj_conv.weight.data.copy_(self.o_proj.weight[:, :, None, None])
+
+ del self.qkv_proj
+
+ def forward_conv(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+ #QC
+ return_new_key_value_only = self.config.return_new_key_value_only if hasattr(self.config, 'return_new_key_value_only') else False
+ transposed_key_cache = self.config.transposed_key_cache if hasattr(self.config, 'transposed_key_cache') else False
+ partial_rotary_factor = self.config.partial_rotary_factor if hasattr(self.config, 'partial_rotary_factor') else 0.75
+
+ bsz, q_len, _ = hidden_states.size()
+
+ hidden_states = torch.reshape(hidden_states, (bsz, q_len, 1, self.config.hidden_size)).transpose(1, 3)
+
+ query_states = self.q_proj_conv(hidden_states)
+ key_states = self.k_proj_conv(hidden_states)
+ value_states = self.v_proj_conv(hidden_states)
+
+ query_states = query_states.reshape(bsz, self.config.num_attention_heads, self.head_dim, q_len).transpose(2, 3)
+ key_states = key_states.reshape(bsz, self.num_key_value_heads, self.head_dim, q_len).transpose(2, 3)
+ value_states = value_states.reshape(bsz, self.num_key_value_heads, self.head_dim, q_len).transpose(2, 3)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ # Partial rotary embedding
+ rope_dim = int(partial_rotary_factor * self.head_dim)
+ query_rot, query_pass = (
+ query_states[..., : rope_dim],
+ query_states[..., rope_dim :],
+ )
+ key_rot, key_pass = (
+ key_states[..., : rope_dim],
+ key_states[..., rope_dim :],
+ )
+
+ if isinstance(position_ids, (tuple, list)): # QC
+ rope_embedding = position_ids
+ cos, sin = rope_embedding
+ query_rot = _apply_rope_single(query_rot, rope_embedding)
+ key_rot = _apply_rope_single(key_rot, rope_embedding)
+ else:
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
+
+ # [batch_size, seq_length, num_heads, head_dim]
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
+
+ if transposed_key_cache:
+ key_states = key_states.transpose(2, 3)
+
+ if past_key_value is not None:
+ assert isinstance(past_key_value, DynamicCache)
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position,
+ "return_new_key_value_only": return_new_key_value_only,
+ "transposed_key_cache": transposed_key_cache,
+ }
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if transposed_key_cache: # QC
+ attn_weights = torch.matmul(query_states, key_states) / math.sqrt(self.head_dim)
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.config.num_attention_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.config.num_attention_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ attn_output = attn_output.reshape(bsz, q_len, 1, self.config.hidden_size)
+ attn_output = attn_output.transpose(1, 3)
+ attn_output = self.o_proj_conv(attn_output)
+ attn_output = attn_output.transpose(1, 3)
+ attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+def bypass_update_causal_mask(self, attention_mask, *args, **kwargs):
+ # attention_mask is Causal mask and given as model input
+ return attention_mask
+
+
+def Phi4MLP_prepare_conv(self):
+ if not hasattr(self, 'forward_linear'):
+ self.gate_proj_conv = nn.Conv2d(self.config.hidden_size, self.config.intermediate_size, 1, bias=False)
+ self.up_proj_conv = nn.Conv2d(self.config.hidden_size, self.config.intermediate_size, 1, bias=False)
+ self.down_proj_conv = nn.Conv2d(self.config.intermediate_size, self.config.hidden_size, 1, bias=False)
+ self.forward_linear = self.forward
+ self.forward = self.forward_conv
+
+ self.gate_proj_conv.weight.data.copy_(self.gate_up_proj.weight[:self.config.intermediate_size, :, None, None])
+ self.up_proj_conv.weight.data.copy_(self.gate_up_proj.weight[self.config.intermediate_size:, :, None, None])
+ self.down_proj_conv.weight.data.copy_(self.down_proj.weight[:, :, None, None])
+
+ del self.gate_up_proj
+ del self.down_proj
+
+def Phi4MLP_forward_conv(self, x):
+ bsz, _, _ = x.size()
+ x = torch.reshape(x, (bsz, -1, 1, self.config.hidden_size))
+ x = x.transpose(1,3) # Transpose right before and after Conv
+ x = self.down_proj_conv(self.activation_fn(self.gate_proj_conv(x)) * self.up_proj_conv(x))
+ x = x.transpose(1,3)
+ x = torch.reshape(x, (bsz, -1, self.config.hidden_size))
+ return x
+
+def Phi4ForCausalLM_prepare_conv(self):
+ if not hasattr(self, 'lm_head_conv'):
+
+ def lm_head_conv_forward(x):
+ bsz, _, _ = x.size()
+ x = torch.reshape(x, (bsz, -1, 1, self.config.hidden_size))
+ x = x.transpose(1,3) # Transpose right before and after Conv
+ x = self.lm_head_conv(x)
+ x = x.transpose(1,3)
+ x = torch.reshape(x, (bsz, -1, self.config.vocab_size))
+ return x
+
+ self.lm_head_conv = nn.Conv2d(self.config.hidden_size, self.config.vocab_size, 1, bias=False)
+ self.lm_head_conv.weight.data.copy_(self.lm_head.weight[:, :, None, None])
+
+ del self.lm_head
+ self.lm_head = lm_head_conv_forward
+
+def Phi4ForCausalLM_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **loss_kwargs,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+def DynamicCache_update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Update the number of seen tokens
+ if layer_idx == 0:
+ self._seen_tokens += value_states.shape[-2]
+
+ # Update the cache
+ if len(self.key_cache) <= layer_idx:
+ self.key_cache.append(key_states)
+ self.value_cache.append(value_states)
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+ else:
+ return_new_key_value_only = cache_kwargs.get('return_new_key_value_only', False)
+ transposed_key_cache = cache_kwargs.get('transposed_key_cache', False)
+ key_cat_dim = -1 if transposed_key_cache else -2
+
+ key_cache = torch.cat([self.key_cache[layer_idx], key_states], dim=key_cat_dim)
+ value_cache = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+ if return_new_key_value_only:
+ self.key_cache[layer_idx] = key_states
+ self.value_cache[layer_idx] = value_states
+ else:
+ self.key_cache[layer_idx] = key_cache
+ self.value_cache[layer_idx] = key_cache
+ return key_cache, value_cache
+
+
+def DynamicCache_get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ # TODO: deprecate this function in favor of `cache_position`
+ # if len(self.value_cache) <= layer_idx:
+ # return 0
+ # return self.value_cache[layer_idx].shape[-2]
+
+ """
+ Replacement for DynamicCache.get_seq_length for Transformers 4.56.0+
+
+ - Does NOT use value_cache / key_cache
+ - Does NOT call the original get_seq_length (avoids recursion)
+ - Uses cache_position, the canonical source of truth
+ """
+
+ cache_position = getattr(self, "cache_position", None)
+
+ if cache_position is None:
+ return 0
+
+ # cache_position = tensor([0, 1, ..., seq_len - 1])
+ if cache_position.numel() > 0:
+ return int(cache_position[-1].item() + 1)
+
+ return 0
+
+
+
+def update_attr(cls, attr_name, new_attr):
+ attr_backup_name = f'_original_{attr_name}'
+ if hasattr(cls, attr_name):
+ if not hasattr(cls, attr_backup_name):
+ setattr(cls, attr_backup_name, getattr(cls, attr_name))
+ setattr(cls, attr_name, new_attr)
+ return True
+ return False
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/test_vectors.py b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/test_vectors.py
new file mode 100644
index 000000000..b7a2f8508
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/test_vectors.py
@@ -0,0 +1,276 @@
+#!/usr/bin/env python3
+# ==============================================================================
+# @@-COPYRIGHT-START-@@
+#
+# Copyright 2024 Qualcomm Technologies, Inc. All rights reserved.
+# Confidential & Proprietary - Qualcomm Technologies, Inc. ("QTI")
+#
+# The party receiving this software directly from QTI (the "Recipient")
+# may use this software as reasonably necessary solely for the purposes
+# set forth in the agreement between the Recipient and QTI (the
+# "Agreement"). The software may be used in source code form solely by
+# the Recipient's employees (if any) authorized by the Agreement. Unless
+# expressly authorized in the Agreement, the Recipient may not sublicense,
+# assign, transfer or otherwise provide the source code to any third
+# party. Qualcomm Technologies, Inc. retains all ownership rights in and
+# to the software
+#
+# This notice supersedes any other QTI notices contained within the software
+# except copyright notices indicating different years of publication for
+# different portions of the software. This notice does not supersede the
+# application of any third party copyright notice to that third party's
+# code.
+#
+# @@-COPYRIGHT-END-@@
+# ==============================================================================
+""" utils to generate test vectors """
+from typing import Tuple, Union, List, Dict
+import re
+import os
+
+import numpy as np
+import torch.nn
+import torch
+import pickle
+import contextlib
+
+from tqdm import tqdm
+from aimet_torch.utils import change_tensor_device_placement, nested_map
+
+from aimet_torch.utils import in_eval_mode, is_leaf_module
+from aimet_torch.onnx_utils import OnnxExportApiArgs
+from aimet_torch.layer_output_utils import LayerOutput, LayerOutputUtil, NamingScheme
+
+from aimet_torch.v2.nn.base import BaseQuantizationMixin
+from aimet_torch.v2.quantization import QuantizedTensorBase
+from aimet_torch.quantsim import ExportableQuantModule
+
+MODULE_TYPE_FOR_ATTACHING_HOOK = (ExportableQuantModule,)
+modules_to_treat_as_leaf = []
+
+def to_torch_tensor(t):
+ """ utilty to move test vectors from DequantizedTensor to torch.Tensor """
+ return nested_map(t, lambda x: torch.tensor(x) if isinstance(x, QuantizedTensorBase) else x)
+
+def to_cpu(t):
+ return change_tensor_device_placement(t, torch.device('cpu'))
+
+def quantizers_state(sim, disabled) -> contextlib.ExitStack:
+ exit_stack = contextlib.ExitStack()
+ if disabled:
+ for _, module in sim.model.named_modules():
+ if isinstance(module, BaseQuantizationMixin):
+ exit_stack.enter_context(module._remove_all_quantizers())
+ return exit_stack
+
+def run_hook_for_layers_with_given_input_get_output(model: torch.nn.Module,
+ input_tensor: Union[torch.Tensor, Tuple, Dict], hook,
+ module_type_for_attaching_hook=None, module_regex_to_include=None,
+ leaf_node_only=True, fwd_func=None):
+ """
+ Register the given hook function for all layers in the model
+ :param model: Model
+ :param input_tensor: Input tensor to the model. If more than one model inputs, use a tuple
+ :param hook: Hook function to register
+ :param module_type_for_attaching_hook: Tuple of torch.nn module types for which hook has to be attached
+ :param leaf_node_only: Set to False if all modules are required
+ :param fwd_func: forward function for model inference
+ :return: None
+ """
+ # ------------------------
+ # Register hook function
+ # ------------------------
+ hooks = []
+ # All leaf modules
+ modules = []
+
+ # Based on the modules in modules_to_treat_as_leaf, we do not want to further continue searching for next level
+ # of modules present in modules_to_treat_as_leaf. To achieve this, save them in modules_to_skip
+ modules_to_skip = set()
+
+ if module_regex_to_include:
+ patterns = [re.compile(pattern) for pattern in module_regex_to_include]
+ name_match_modules = [module for name, module in model.named_modules() if any (re.match(pattern, name) for pattern in patterns)]
+ else:
+ name_match_modules = model.modules()
+
+ for module in name_match_modules:
+ if module not in modules_to_skip:
+ # pylint: disable=protected-access
+ if isinstance(module, tuple(modules_to_treat_as_leaf)):
+ modules.append(module)
+ # check for modules inside the 'module' and add them to modules_to_skip
+ for sub_module in module._modules.values():
+ modules_to_skip.add(sub_module)
+ else:
+ if leaf_node_only:
+ if is_leaf_module(module):
+ modules.append(module)
+ else:
+ modules.append(module)
+
+ if module_type_for_attaching_hook:
+ # if needed, filter by module types specified by caller
+ modules = [module for module in modules if isinstance(module, module_type_for_attaching_hook)]
+
+ try:
+ for module in modules:
+ hooks.append(module.register_forward_hook(hook))
+
+ # ------------------------------------------------
+ # Run forward pass to execute the hook functions
+ # ------------------------------------------------
+ with in_eval_mode(model), torch.no_grad():
+ if fwd_func:
+ output = fwd_func(model, input_tensor)
+ else:
+ if isinstance(input_tensor, (list, tuple)):
+ output = model(*input_tensor)
+ elif isinstance(input_tensor, dict):
+ output = model(**input_tensor)
+ else:
+ output = model(input_tensor)
+
+ finally:
+ # --------------------------
+ # Remove all hooks we added
+ # --------------------------
+ for h in hooks:
+ h.remove()
+
+ return output
+
+
+class LLMLayerOutput(LayerOutput):
+ def __init__(self, model: torch.nn.Module, dir_path: str, naming_scheme: NamingScheme = NamingScheme.PYTORCH,
+ dummy_input = None, onnx_export_args: Union[OnnxExportApiArgs, Dict] = None, regex_patterns = None):
+ super().__init__(model, dir_path, naming_scheme, dummy_input, onnx_export_args)
+ self.regex_patterns = regex_patterns
+
+ def record_outputs(self, module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
+ """
+ Hook function to capture output of a layer.
+
+ :param module: Layer-module in consideration.
+ :param input: Input of the layer-module.
+ :param output: Output of the layer-module.
+ :return: None
+ """
+ layer_name = self.module_to_name_dict[module]
+ self.layer_name_to_layer_output_dict[layer_name] = {"input": to_cpu(input[0]), "output": to_cpu(output)}
+
+ def get_outputs(self, input_batch) -> Dict[str, torch.Tensor]:
+ """
+ This function captures layer-outputs and renames them as per the AIMET exported pytorch/onnx/torchscript model.
+
+ :param input_batch: Batch of inputs for which we want to obtain layer-outputs.
+ :return: layer-name to layer-output batch dict
+ """
+
+ # Fetch outputs of all the layers
+ self.layer_name_to_layer_output_dict = {}
+ if self.is_quantsim_model:
+ # Apply record-output hook to QuantizeWrapper modules (one node above leaf node in model graph)
+ model_output = run_hook_for_layers_with_given_input_get_output(self.model, input_batch, self.record_outputs,
+ module_type_for_attaching_hook=MODULE_TYPE_FOR_ATTACHING_HOOK,
+ leaf_node_only=False, module_regex_to_include=self.regex_patterns)
+ else:
+ # Apply record-output hook to Original modules (leaf node in model graph)
+ model_output = run_hook_for_layers_with_given_input_get_output(self.model, input_batch, self.record_outputs,
+ leaf_node_only=True, module_regex_to_include=self.regex_patterns)
+
+ # Rename outputs according to pytorch/onnx/torchscript model
+ layer_output_name_to_layer_output_dict = LayerOutput.rename_layer_outputs(self.layer_name_to_layer_output_dict,
+ self.layer_name_to_layer_output_name_dict)
+
+ return layer_output_name_to_layer_output_dict, model_output
+
+
+class LLMLayerOutputUtil(LayerOutputUtil):
+ def __init__(self, model: torch.nn.Module, dir_path: str, file_prefix: str, naming_scheme: NamingScheme = NamingScheme.PYTORCH,
+ dummy_input = None, onnx_export_args: Union[OnnxExportApiArgs, Dict] = None, regex_patterns = None):
+ """
+ Constructor for LayerOutputUtil.
+
+ :param model: Model whose layer-outputs are needed.
+ :param dir_path: Directory wherein layer-outputs will be saved.
+ :param naming_scheme: Naming scheme to be followed to name layer-outputs. There are multiple schemes as per
+ the exported model (pytorch, onnx or torchscript). Refer the NamingScheme enum definition.
+ :param dummy_input: Dummy input to model. Required if naming_scheme is 'NamingScheme.ONNX' or 'NamingScheme.TORCHSCRIPT'.
+ :param onnx_export_args: Should be same as that passed to quantsim export API to have consistency between
+ layer-output names present in exported onnx model and generated layer-outputs. Required if naming_scheme is
+ 'NamingScheme.ONNX'.
+ """
+ super().__init__(model, dir_path, naming_scheme, dummy_input, onnx_export_args)
+ self.output_dir = dir_path
+ self.file_prefix = file_prefix
+
+ # Utility to capture layer-outputs
+ self.layer_output = LLMLayerOutput(model=model, naming_scheme=naming_scheme, dir_path=dir_path, dummy_input=dummy_input,
+ onnx_export_args=onnx_export_args, regex_patterns=regex_patterns)
+
+ def generate_layer_outputs(self, input_batch, batch_idx):
+ """
+ This method captures output of every layer of a model & saves the inputs and corresponding layer-outputs to disk.
+
+ :param input_batch: Batch of inputs for which we want to obtain layer-outputs.
+ :return: None
+ """
+
+ # Obtain layer-output name to output dictionary
+ layer_output_batch_dict, model_outputs = self.layer_output.get_outputs(input_batch)
+
+ test_vectors = {f"{batch_idx}": {**to_cpu(to_torch_tensor(input_batch)),
+ **to_cpu(to_torch_tensor(layer_output_batch_dict))}}
+
+ assert os.path.exists(self.output_dir), "output_dir for test vectors doesn't exist"
+
+ for key, value in test_vectors.items():
+ filename = os.path.join(self.output_dir, self.file_prefix + f"_{batch_idx}.pkl")
+ with open(filename, 'wb') as file:
+ pickle.dump({key: value}, file)
+
+ return model_outputs
+
+def generate_test_vectors(sim, forward_pass_manager, data_loader, output_dir, num_batches, test_vector_layers, input_names):
+ vector_output_dir = os.path.join(output_dir, "test_vectors")
+ os.makedirs(vector_output_dir, exist_ok=True)
+
+ def _sanitize_and_update_test_vectors(test_vectors, test_outputs):
+ if "past_key_values" in test_outputs:
+ test_outputs["output_key_values"] = test_outputs.pop("past_key_values")
+
+ test_vectors.update(to_cpu(test_outputs))
+
+ if "lm_logits" in test_vectors:
+ test_vectors["logits"] = test_vectors.pop("lm_logits")
+
+ past_key_values = []
+ for i in range(forward_pass_manager.num_layers):
+ past_key = test_vectors.pop(f"past_key_{i}_in")
+ past_val = test_vectors.pop(f"past_value_{i}_in")
+ past_key_values.append([past_key, past_val])
+ test_vectors["past_key_values"] = past_key_values
+
+ for idx, batch in enumerate(tqdm(data_loader, total=num_batches, desc="Test vector generation")):
+ if idx >= num_batches:
+ break
+ for vector_type in ['fp', 'qt']:
+
+ recorder = LLMLayerOutputUtil(forward_pass_manager.model, dir_path=vector_output_dir,
+ file_prefix=vector_type, regex_patterns=test_vector_layers)
+
+ with quantizers_state(sim, disabled=(vector_type == 'fp')):
+ prepared_inputs, kvcache_info_bundle = forward_pass_manager.prepare_inputs(
+ input_ids=batch['input_ids'][..., :forward_pass_manager.num_tokens], input_names= input_names)
+ outputs = recorder.generate_layer_outputs(prepared_inputs, idx)
+ outputs = forward_pass_manager.prepare_outputs(outputs, prepared_inputs, kvcache_info_bundle)
+
+ filename = os.path.join(vector_output_dir, f"{vector_type}_{idx}.pkl")
+ test_vector_dict = np.load(filename, allow_pickle=True)
+
+ _sanitize_and_update_test_vectors(test_vector_dict[f"{idx}"], outputs)
+ test_vector_dict = to_cpu(to_torch_tensor(test_vector_dict))
+
+ with open(filename, 'wb') as file:
+ pickle.dump(test_vector_dict, file)
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/wikitext_dataloader.py b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/wikitext_dataloader.py
new file mode 100644
index 000000000..f55470ca5
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/wikitext_dataloader.py
@@ -0,0 +1,200 @@
+# -*- mode: python -*-
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+#
+# This file contains certain notices of software components included with the
+# software that Qualcomm Technologies, Inc. ("QTI") is required to provide you.
+# Except where prohibited by the open source license, the content of this file is
+# provided solely to satisfy QTI's attribution and notice requirement; your use of
+# these software components together with the QTI software ("Software") is subject
+# to the terms of your license from QTI. Compliance with all copyright laws and
+# software license agreements included in the notice section of this file are the
+# responsibility of the user. Except as may be granted by separate express written
+# agreement, this file provides no license to any patents, trademarks, copyrights,
+# or other intellectual property of Qualcomm Incorporated or any of its
+# subsidiaries.
+#
+# Software provided with this notice is NOT A CONTRIBUTION to any open source
+# project. If alternative licensing is available for any of the components with
+# licenses or attributions provided below, a license choice is made for receiving
+# such code by QTI.
+
+# Copyright (c) 2023 Qualcomm Technologies, Inc. All rights reserved.
+
+# Qualcomm is a trademark of Qualcomm Incorporated, registered in the United
+# States and other countries. All Qualcomm Incorporated trademarks are used with
+# permission. Other products and brand names may be trademarks or registered
+# trademarks of their respective owners.
+#
+# =============================================================================
+# @@-COPYRIGHT-START-@@
+#
+# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice,
+# this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its contributors
+# may be used to endorse or promote products derived from this software
+# without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# @@-COPYRIGHT-END-@@
+# =============================================================================
+""" utility method to evaluate perplexity score on WikiText """
+from itertools import chain
+from torch.utils.data import DataLoader, Dataset
+from datasets import IterableDataset, load_dataset
+from transformers import default_data_collator
+
+
+class CustomDataset(Dataset):
+ """
+ Dataset for GPTQ-preprocessed tokens
+ """
+ def __init__(self, tokens, block_size=2048):
+ self.full_tokens = tokens
+ self.block_size = block_size
+ self._len = len(tokens["input_ids"][0]) // block_size
+
+ def __len__(self):
+ return self._len
+
+ def __getitem__(self, idx):
+ start_idx = idx * self.block_size
+ end_idx = (idx+1) * self.block_size
+
+ input_ids = self.full_tokens["input_ids"][0, start_idx:end_idx]
+ labels = input_ids.clone()
+ attn_mask = self.full_tokens["attention_mask"][0, start_idx:end_idx]
+ output = {"input_ids": input_ids,
+ "attention_mask": attn_mask,
+ "labels": labels}
+ return output
+
+def _get_column_names(dataset):
+ if hasattr(dataset, "column_names"):
+ return dataset.column_names
+ else:
+ return next(iter(dataset.take(1))).keys()
+
+
+def get_column_name(dataset):
+ column_names = _get_column_names(dataset)
+ if "text" in column_names:
+ return "text"
+ else:
+ return column_names[0]
+
+class PreprocessGptqSplit:
+ def __init__(self, tokenizer, block_size, add_special_tokens=True):
+ self._tokenizer = tokenizer
+ self._block_size = block_size
+ self._add_special_tokens = add_special_tokens
+
+ def preprocess(self, dataset):
+ column_name = get_column_name(dataset)
+ tokens = self._tokenizer("\n\n".join(dataset[column_name]), return_tensors="pt", add_special_tokens=self._add_special_tokens)
+ dataset = CustomDataset(tokens, self._block_size)
+ return dataset
+
+class PreprocessSplit:
+ def __init__(self, tokenizer, block_size, column_name="text", add_special_tokens=True):
+ self._tokenizer = tokenizer
+ self._block_size = block_size
+ self._column_name = column_name
+ self._add_special_tokens = add_special_tokens
+
+ def _tokenize_fn(self, examples):
+ return self._tokenizer(examples[self._column_name], return_token_type_ids=False, add_special_tokens=self._add_special_tokens)
+
+ # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
+ def _group_texts(self, examples):
+ # Concatenate all texts.
+ concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
+ total_length = len(concatenated_examples[list(examples.keys())[0]])
+ # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
+ # customize this part to your needs.
+ if total_length >= self._block_size:
+ total_length = (total_length // self._block_size) * self._block_size
+ # Split by chunks of max_len.
+ result = {
+ k: [t[i: i + self._block_size] for i in range(0, total_length, self._block_size)]
+ for k, t in concatenated_examples.items()
+ }
+ result["labels"] = result["input_ids"].copy()
+ return result
+
+ def preprocess(self, dataset):
+ map_kwargs = {
+ "num_proc": None,
+ "load_from_cache_file": True,
+ "desc": "Running tokenizer on dataset",
+ }
+
+ tokenized_dataset = dataset.map(
+ self._tokenize_fn,
+ batched=True,
+ remove_columns=dataset.column_names,
+ **(map_kwargs if not isinstance(dataset, IterableDataset) else {}),
+ )
+
+ # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
+ # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
+ # to preprocess.
+ #
+ # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
+ print(f"[Load Dataset]grouped train dataset")
+ map_kwargs["desc"] = f"Grouping texts in chunks of {self._block_size}"
+ dataset = tokenized_dataset.map(
+ self._group_texts,
+ batched=True,
+ **(map_kwargs if not isinstance(dataset, IterableDataset) else {}),
+ )
+
+ return dataset
+
+def get_wiki_dataset(block_size, tokenizer, cache_dir):
+ dataset = {}
+ dataset['train'] = load_dataset(path='wikitext',
+ name='wikitext-2-raw-v1',
+ cache_dir=cache_dir,
+ split='train')
+ dataset['train'] = PreprocessSplit(tokenizer, block_size).preprocess(dataset['train'])
+
+ dataset['test'] = load_dataset(path='wikitext',
+ name='wikitext-2-raw-v1',
+ cache_dir=cache_dir,
+ split='test')
+ dataset['test'] = PreprocessGptqSplit(tokenizer, block_size).preprocess(dataset['test'])
+
+ train_dataloader = DataLoader(dataset['train'], shuffle=False, batch_size=1, collate_fn=default_data_collator)
+ test_dataloader = DataLoader(dataset['test'], shuffle=False, batch_size=1, collate_fn=default_data_collator)
+
+ return train_dataloader, test_dataloader, dataset
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/xlam_dataloader.py b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/xlam_dataloader.py
new file mode 100644
index 000000000..be832e1fc
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/llm_utils/xlam_dataloader.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python3
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+
+import json
+from itertools import chain
+
+from datasets import load_dataset, Dataset
+from transformers import default_data_collator
+from torch.utils.data import DataLoader
+import copy
+
+class xLAMDataset():
+ def __init__(self, tokenizer, block_size, batch_size, return_train_dataset=True, return_test_dataset=False):
+ self._tokenizer = copy.deepcopy(tokenizer)
+ self._block_size = block_size
+ self._batch_size = batch_size
+ self._return_train_dataset = return_train_dataset
+ self._return_test_dataset = return_test_dataset
+
+ def get_xlam_dataloader(self, path):
+ tokenizer = self._tokenizer
+ block_size = self._block_size
+ batch_size = self._batch_size
+
+ dataset = load_dataset(path)
+
+ # Process both train and test data
+ processed_datasets = {}
+ for split, data in dataset.items():
+ split_data = {'input_ids': [], 'attention_mask': []}
+ for row_tmp in data:
+ prompt, answer = self.process_xLAM_processing(row_tmp)
+ if self._return_test_dataset:
+ row_text = prompt
+ if self._return_train_dataset:
+ row_text = prompt + answer + tokenizer.eos_token
+ tokenizer.pad_token = tokenizer.eos_token
+ row1 = tokenizer(row_text, return_tensors="pt", truncation=True, max_length=block_size, padding="max_length")
+ split_data['input_ids'].append(row1['input_ids'][0])
+ split_data['attention_mask'].append(row1['attention_mask'][0])
+
+ processed_datasets[split] = Dataset.from_dict(split_data)
+
+ collate_fn = default_data_collator
+
+ train_dataloader = DataLoader(
+ processed_datasets['train'], shuffle=False,
+ batch_size=batch_size,
+ collate_fn=collate_fn,
+ ) if self._return_train_dataset else None
+
+ test_dataloader = DataLoader(
+ processed_datasets['train'], shuffle=False,
+ batch_size=batch_size,
+ collate_fn=collate_fn,
+ ) if self._return_test_dataset else None
+
+ return train_dataloader, test_dataloader, processed_datasets
+
+ def process_xLAM_processing(self, row):
+ tokenizer = self._tokenizer
+
+ TEST_SYSTEM_PROMPT_FOR_CHAT_MODEL ="You are an expert in composing functions. You are given a question and a set of possible functions.\nBased on the question, you will need to make one or more function/tool calls to achieve the purpose.\nIf none of the function can be used, point it out. If the given question lacks the parameters required by the function, also point it out. You should only return the function call in tools call sections.\nIf you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]\nYou SHOULD NOT include any other text in the response.\nHere is a list of functions in JSON format that you can invoke.\n{functions}"
+
+ SYSTEM_PROMPT_FOR_CHAT_MODEL = """
+ You are an expert in composing functions. You are given a question and a set of possible functions.
+ Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
+ If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
+ also point it out. You should only return the function call in tools call sections.
+ """
+
+ USER_PROMPT_FOR_CHAT_MODEL = """
+ Questions:{user_prompt}\nHere is a list of functions in JSON format that you can invoke:\n{functions}.
+ Should you decide to return the function call(s),Put it in the format of [func1(params_name=params_value, params_name2=params_value2...), func2(params)]\n
+ NO other text MUST be included.
+ """
+ def _format_prompt(prompt, function):
+
+ if self._return_test_dataset:
+ raw_prompt = [{"role": "system", "content": TEST_SYSTEM_PROMPT_FOR_CHAT_MODEL.format(functions=function)},{"role": "user", "content": prompt}]
+ return tokenizer.apply_chat_template(raw_prompt, tokenize=False,add_generation_prompt=True)
+
+ if self._return_train_dataset:
+ conversations = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>{SYSTEM_PROMPT_FOR_CHAT_MODEL}<|eot_id|><|start_header_id|>user<|end_header_id|>{USER_PROMPT_FOR_CHAT_MODEL.format(user_prompt=prompt, functions=str(function))}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
+ return conversations
+
+ def convert_json_to_function_calls(json_str):
+ # Parse the JSON string
+ data = json.loads(json_str)
+
+ # Create the desired output format
+ output_list = []
+ for item in data:
+ func_name = item['name']
+ args = ', '.join([f"{k}={json.dumps(v)}" for k, v in item['arguments'].items()])
+ output_list.append(f"{func_name}({args})")
+
+ # Convert the list to a string
+ output_str = f"[{', '.join(output_list)}]"
+ return output_str
+
+ prompt = _format_prompt(row["query"], row["tools"])
+ answer = convert_json_to_function_calls(row['answers'])
+ return prompt, answer
\ No newline at end of file
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/phi4.py b/microsoft-Phi-4-mini-instruct/QAIRT/phi4.py
new file mode 100644
index 000000000..5ca097f22
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/phi4.py
@@ -0,0 +1,788 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+# # AIMET Quantization workflow for Phi-4-mini-instruct Context Length of 4K
+#
+# This notebook shows a working code example of how to use AIMET to quantize Phi-4-mini-instruct model.
+#
+#
+# ---
+# ### Required packages
+# The notebook assumes AIMET and Phi-4-mini-instruct related packages are already installed.
+
+# In[ ]:
+
+
+try:
+ # Required for proper Python environment configuration of qairt-dev
+ import qairt # noqa: F401 # pylint: disable=unused-import
+except ImportError as exc:
+ raise ImportError(
+ "Failed to import QAIRT SDK - please install olive-ai[qairt] to use QAIRT passes."
+ "If already installed, please run `qairt-vm -i` for help troubleshooting issues."
+ ) from exc
+
+# Guard to prevent child processes from executing the main script
+if __name__ != '__main__':
+ import sys
+ sys.exit(0)
+
+import json
+import argparse
+import sys, os
+
+# Parse command-line arguments for optional config file
+parser = argparse.ArgumentParser(
+ description='Llama 3.1 8B Instruct AdaScale + Quantization Script',
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+)
+parser.add_argument('--config', type=str, default=None,
+ help='Path to JSON configuration file')
+args, unknown = parser.parse_known_args()
+
+# Load JSON config if provided
+json_config = {}
+if args.config:
+ try:
+ with open(args.config, 'r') as f:
+ json_config = json.load(f)
+ print(f"Loaded configuration from: {args.config}")
+ except FileNotFoundError:
+ print(f"Warning: Config file not found: {args.config}")
+ except json.JSONDecodeError as e:
+ print(f"Warning: Invalid JSON in config file: {e}")
+
+def get_config_value(key, default, value_type='str'):
+ """
+ Get configuration value with 3-tier priority:
+ 1. JSON config file
+ 2. Environment variable
+ 3. Default value
+
+ Args:
+ key: Configuration key name
+ default: Default value if not found in config or environment
+ value_type: Type of value ('str', 'int', 'bool', 'none')
+
+ Returns:
+ Configuration value with appropriate type
+ """
+ # Priority 1: Check JSON config
+ if key in json_config:
+ value = json_config[key]
+ if value_type == 'bool':
+ if isinstance(value, bool):
+ return value
+ return str(value).lower() in ('true', '1', 't', 'yes')
+ elif value_type == 'int':
+ return int(value)
+ elif value_type == 'none':
+ return value
+ else: # str
+ return str(value) if value is not None else None
+
+ # Priority 2: Check environment variable
+ env_value = os.getenv(key)
+ if env_value is not None:
+ if value_type == 'bool':
+ return env_value.lower() in ('true', '1', 't')
+ elif value_type == 'int':
+ return int(env_value)
+ elif value_type == 'none':
+ return env_value
+ else: # str
+ return env_value
+
+ # Priority 3: Use default value
+ return default
+
+# ### 1.2 Setting NSP Target
+
+# In[ ]:
+
+
+sys.path.append('../')
+from utilities.nsptargets import NspTargets
+
+# setup Target platform and its generation
+TARGET_PLATFORM = get_config_value("TARGET_PLATFORM", "Windows").capitalize()
+
+# Android GEN4 and GEN5 is supported for this notebook
+PLATFORM_GEN = get_config_value("PLATFORM_GEN", 3, 'int')
+
+# Set up nsp target specification
+nsp_target = eval(f"NspTargets.{TARGET_PLATFORM}.GEN{PLATFORM_GEN}")
+
+# Select quantsim config based on target
+htp_config_file = f'{sys.prefix}/lib/python3.10/site-packages/aimet_common/quantsim_config/htp_quantsim_config_{nsp_target.dsp_arch}.json'
+
+
+
+# ### 2. Instantiate and adapt FP32 model
+#
+# #### 2.1 Instantiate adapted FP32 model definition
+
+# In[ ]:
+
+
+from tqdm import tqdm
+import torch
+
+model_name = get_config_value("MODEL_NAME", "phi4_mini_instruct")
+model_id = get_config_value("MODEL_ID", "microsoft/phi-4-mini-instruct")
+
+cache_dir = get_config_value("CACHE_DIR", './cache_dir')
+output_dir = get_config_value("OUTPUT_DIR", "./output_dir")
+os.makedirs(output_dir, exist_ok=True)
+
+onnx_dir = os.path.join(output_dir, 'base', 'onnx')
+os.makedirs(onnx_dir, exist_ok=True)
+
+#======================Configurable setting by users================================
+from transformers import AutoConfig, AutoTokenizer
+llm_config = AutoConfig.from_pretrained(model_id, cache_dir=cache_dir, trust_remote_code=True)
+# Set context length to be 2048, 4096, or 8192 here, user can change this value to ones' desire (but less than Phi4-mini' trained contex length)
+context_length= 4096
+
+# To help with debugging num_hidden_layers could be set to 2 to quickly verify the pipeline and export a two layer model for verification purposes
+llm_config.num_hidden_layers = 32
+print(f'num_layer: {llm_config.num_hidden_layers}, context_length : {context_length},'
+ f'num_hidden_size :{llm_config.num_attention_heads}, num_kv_heads: {llm_config.num_key_value_heads}')
+
+#======================Fixed setting that should not be changed by users==============
+# Auto-regression length: number of tokens to consume and number of logits to produce.
+# This value should NOT be changed due to downstream consumption requirements
+
+if 8192 == context_length:
+ ARN = 7073
+elif 4096 == context_length:
+ ARN = 2073
+elif 2048 == context_length:
+ ARN = 1073
+else:
+ ARN = 573
+
+setattr(llm_config, 'return_new_key_value_only', True)
+setattr(llm_config, 'transposed_key_cache', True)
+setattr(llm_config, 'use_combined_mask_input', True)
+setattr(llm_config, 'use_position_embedding_input', True)
+setattr(llm_config, '_attn_implementation', 'eager')
+setattr(llm_config, '_attn_implementation_internal', 'eager')
+setattr(llm_config, 'mask_neg', -7100) #-100
+setattr(llm_config, 'partial_rotary_factor', 0.75)
+
+llm_config.save_pretrained(output_dir)
+
+# #### 2.2 Adapt FP32 model definition for inference on HTP
+# The following adaptations are done to replace default attention module with attention definition that compatible with NSP backend
+# - use conv instead of linear for Q,K,V,O projections
+# - bypass attention and causal mask generation and replace with pre-generated 2D-mask input
+# - output only newly created V and transposed K instead of entire augmented KV sequence
+# - input pre-calculated positional embedding instead of position ids, thus bypass the embedding generation in the model
+
+# In[ ]:
+
+
+from transformers.models.phi3 import modeling_phi3
+# from aimet_torch.pro.utils.profiler import event_marker
+from genai_lib.common.debug.profiler import event_marker
+with event_marker('FP model'):
+ model = modeling_phi3.Phi3ForCausalLM.from_pretrained(model_id,cache_dir=cache_dir, config=llm_config)
+
+ os.environ['TOKENIZERS_PARALLELISM'] = '0'
+ tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir,) #, use_fast=True, trust_remote_code=True
+ ## Adjust the tokenizer to limit to context_length
+ tokenizer.model_max_length = context_length
+ tokenizer.save_pretrained(output_dir)
+
+# In[ ]:
+
+
+from transformers.models.phi3 import modeling_phi3
+from transformers import cache_utils, modeling_attn_mask_utils
+# from aimet_torch.pro.utils.profiler import event_marker
+from llm_utils.qcphi4_adaptation import (
+ QcPhi4Attention,
+ bypass_update_causal_mask,
+ bypass_Phi4RotaryEmbedding,
+ Phi4MLP_prepare_conv,
+ Phi4MLP_forward_conv,
+ Phi4ForCausalLM_prepare_conv,
+ Phi4ForCausalLM_forward,
+ DynamicCache_update,
+ DynamicCache_get_seq_length,
+ update_attr
+)
+
+with event_marker("FP model adaptation configuration"):
+ for layer in model.model.layers:
+ layer.self_attn.__class__ = QcPhi4Attention
+
+ # Bypass attention_mask preparation
+ assert update_attr(modeling_phi3.Phi3Model, '_update_causal_mask', bypass_update_causal_mask) or \
+ update_attr(modeling_phi3.Phi3Model, '_prepare_decoder_attention_mask', bypass_update_causal_mask), \
+ f"neither _prepare_decoder_attention_mask(..) nor _update_causal_mask(..) found, Unknown Phi3Model definition {modeling_phi3.Phi3Model}"
+
+ # Bypass rotary_emb module
+ assert update_attr(modeling_phi3.Phi3RotaryEmbedding, 'forward', bypass_Phi4RotaryEmbedding), \
+ f"Unknown RotaryEmbedding definition: {modeling_phi3.Phi3RotaryEmbedding}"
+
+ # Adaptation to use Conv instead of linear
+ setattr(modeling_phi3.Phi3MLP, 'prepare_conv', Phi4MLP_prepare_conv)
+ setattr(modeling_phi3.Phi3MLP, 'forward_conv', Phi4MLP_forward_conv)
+ setattr(modeling_phi3.Phi3ForCausalLM, 'prepare_conv', Phi4ForCausalLM_prepare_conv)
+ update_attr(modeling_phi3.Phi3ForCausalLM, 'forward', Phi4ForCausalLM_forward)
+
+ # Adapting KV$ management
+ assert update_attr(cache_utils.DynamicCache, 'update', DynamicCache_update), f"Unknown DynamicCache definition: {cache_utils.DynamicCache}"
+ assert update_attr(cache_utils.DynamicCache, 'get_seq_length', DynamicCache_get_seq_length), f"Unknown DynamicCache definition: {cache_utils.DynamicCache}"
+
+
+# ### 3. Complete the last step(s) of model adaptation
+#
+# The following model adaptation are enabled for inference:
+# - apply linear to conv in attention, MLP and lmhead and arrange linear weights properly for conv
+
+# In[ ]:
+
+
+with event_marker('FP model adaptation for NSP backend completion'):
+ for name, module in model.named_modules():
+ if hasattr(module, "prepare_conv"):
+ module.prepare_conv()
+
+
+# ### 4. Model Evaluation
+
+# In[ ]:
+
+
+from torch.nn import CrossEntropyLoss
+from llm_utils.forward_pass_wrapper import slice_inputs_and_run_successive_kvcache_inference
+
+def ppl_eval(data_loader, forward_pass_manager, num_batches=0):
+ if num_batches == 0:
+ num_batches = len(data_loader)
+ loss = 0
+
+ if llm_config.num_hidden_layers < 10:
+ num_batches = 1
+
+ for batch_id, batch in enumerate(tqdm(data_loader, total=num_batches, desc="Evaluating")):
+ if batch_id >= num_batches:
+ break
+ outputs = slice_inputs_and_run_successive_kvcache_inference(forward_pass_manager, input_ids=batch['input_ids'])
+ lm_logits = outputs["lm_logits"].cpu()
+
+ # we can either pass input_ids or input_embeds in our fpm, hence with input_embeds we pass the labels.
+ if 'input_ids' not in batch:
+ batch['input_ids'] = batch['labels']
+
+ lm_logits = lm_logits.reshape(batch['input_ids'].shape[0], -1, lm_logits.shape[-1])
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = batch['input_ids'][..., 1:].contiguous().to(shift_logits.device)
+ loss_fct = CrossEntropyLoss()
+ loss += loss_fct(
+ shift_logits.view(-1, shift_logits.size(-1)),
+ shift_labels.view(-1),
+ )
+
+ loss = loss / num_batches
+ ppl = loss.exp()
+
+ return ppl
+
+
+# #### 4.1 FP32 PPL Eval
+
+# In[ ]:
+
+
+from llm_utils.forward_pass_wrapper import LLMForwardPassManager
+orig_fpm = LLMForwardPassManager(cfg=llm_config,
+ model=model,
+ tokenizer=tokenizer,
+ separate_tuple_input_output=False,
+ num_tokens=ARN)
+
+from llm_utils.wikitext_dataloader import get_wiki_dataset
+train_dataloader, test_dataloader, _ = get_wiki_dataset(context_length, tokenizer, cache_dir)
+
+with event_marker("FP eval"):
+ with torch.no_grad():
+ with orig_fpm.place_on_device("cuda"):
+ orig_ppl = ppl_eval(test_dataloader, orig_fpm)
+
+print(f"ppl score of original fp model: {orig_ppl}")
+
+
+# ### 5. Model Sample Input
+
+# In[ ]:
+
+
+from llm_utils.forward_pass_wrapper import get_position_embeddings_from_position_ids, prepare_combined_attention_mask, get_padded_kv_values, flatten_tensors
+
+def get_dummy_data(config, tokenizer, device, separate_tuple_input_output, num_tokens=None, dtype=torch.float32):
+
+ num_layers = config.num_hidden_layers
+ hidden_size = config.hidden_size
+ num_attention_heads = config.num_attention_heads
+ num_kv_heads = config.num_key_value_heads
+ partial_rotary_factor = config.partial_rotary_factor
+
+ max_tokens = tokenizer.model_max_length
+ attention_mask = torch.ones((1, max_tokens), dtype=torch.long, device=device)
+
+ position_ids = torch.cumsum(attention_mask, dim=1) - 1
+ position_ids = position_ids.clip(0, max_tokens - 1)
+ position_ids = position_ids[..., :num_tokens]
+ position_ids = position_ids.to(device=device)
+ if config.use_combined_mask_input:
+ past_kv_length = max_tokens - num_tokens
+ attention_mask = prepare_combined_attention_mask(attention_mask, input_shape=(1, num_tokens),
+ past_key_values_length=past_kv_length, device=device,
+ mask_neg=llm_config.mask_neg, dtype=dtype)
+
+ if config.use_position_embedding_input:
+ position_ids = get_position_embeddings_from_position_ids(position_ids,
+ head_dim=hidden_size/num_attention_heads,
+ max_length=max_tokens,
+ partial_rotary_factor=partial_rotary_factor,
+ device=device, dtype=dtype,
+ config=config)
+
+ inputs = {
+ 'attention_mask': attention_mask,
+ 'position_ids': position_ids,
+ 'input_ids': torch.randint(0, len(tokenizer), (1, num_tokens), device=device)
+ }
+
+ inputs['past_key_values'] = get_padded_kv_values(past_size=max_tokens - num_tokens,
+ num_layers=num_layers,
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ num_kv_heads=num_kv_heads,
+ transposed_key_cache=config.transposed_key_cache,
+ device=device,
+ dtype=dtype)
+
+ if separate_tuple_input_output:
+ flattened_kvcache = tuple(flatten_tensors(inputs['past_key_values']))
+ if isinstance(inputs['position_ids'], tuple):
+ inputs = inputs['input_ids'], inputs['attention_mask'], inputs['position_ids'][0], inputs['position_ids'][1]
+ else:
+ inputs = inputs['input_ids'], inputs['attention_mask'], inputs['position_ids']
+ inputs = inputs + flattened_kvcache
+
+ return inputs
+
+
+# ### 6. Prepare model using AIMET model preparer pro
+#
+# #### 6.1 KVCache MHA model preparation
+
+# In[ ]:
+
+
+import time
+from aimet_torch import onnx_utils
+
+from genai_lib.llm.model_preparation_utils import llm_build_preparer_converter_args
+
+from qti.aisw.preparer_api import model_preparer
+from qti.aisw.emitter.utils.torch_utils import load_torch_model_using_safetensors
+
+# Setting this flag to False means that the prepared model will be flattened
+onnx_utils.EXPORT_TO_ONNX_DIRECT = True
+
+def _get_past_key_values_names(sfx, n_layers):
+ all_kvs = []
+ for i in range(n_layers):
+ all_kvs.append(f'past_key_{i}_{sfx}')
+ all_kvs.append(f'past_value_{i}_{sfx}')
+ return all_kvs
+
+dummy_input = get_dummy_data(llm_config,
+ tokenizer, 'cpu', separate_tuple_input_output=False, num_tokens=ARN, dtype=model.dtype)
+input_names = ['input_ids', 'attention_mask']
+input_names += ['position_ids_cos', 'position_ids_sin'] if llm_config.use_position_embedding_input else ['position_ids']
+input_names += _get_past_key_values_names('in', llm_config.num_hidden_layers)
+output_names = ['logits'] + _get_past_key_values_names('out', llm_config.num_hidden_layers)
+
+# Build converter args
+converter_args = llm_build_preparer_converter_args(llm_config.num_hidden_layers, input_names, use_qairt_mpp=True) # Build converter args
+
+prepare_path = os.path.join(output_dir, 'prepare')
+os.makedirs(prepare_path, exist_ok=True)
+prepare_filename = f'{model_name}_kvcache_{llm_config.num_hidden_layers}_layer'
+
+skip_prepare = False
+if skip_prepare:
+ with event_marker(f"KVCache load pre-prepared {prepare_filename}", flush_ram=True):
+ prepared_model_path = os.path.join(prepare_path, f'{prepare_filename}.py')
+ if not os.path.exists(prepared_model_path):
+ raise ValueError(f"prepared artifacts not found in {prepare_path}")
+ else:
+ print(f'WARNING: preparation skipped for model={prepare_filename}, prepared at {time.ctime(os.path.getmtime(prepared_model_path))}')
+ prepared_model = load_torch_model_using_safetensors(path=prepare_path, filename=prepare_filename, model_name=prepare_filename)
+else:
+ with event_marker("KVCache prepare model", flush_ram=True):
+ model.num_logits_to_return = ARN # configuring the model for KVCache mode
+ prepared_model = model_preparer.prepare_model(model,
+ dummy_input,
+ model_name = prepare_filename,
+ filename = prepare_filename,
+ path = prepare_path,
+ input_names = input_names,
+ output_names = output_names,
+ onnx_export_args = {"opset_version":17},
+ converter_args = converter_args,
+ keep_original_model_structure = False, # Flatten the model to enable weight-sharing by setting
+ skipped_optimizers = ['eliminate_common_subexpression',
+ 'eliminate_nop_with_unit',
+ 'eliminate_duplicate_initializer'
+ ],
+ return_prepare_model = False
+ )
+ prepared_model = load_torch_model_using_safetensors(path=prepare_path, filename=prepare_filename, model_name=prepare_filename)
+
+
+# del model # original model no longer needed
+
+
+# #### 6.2 Model prepare verification
+#
+# Verify if prepared KV cache model generates the same PPL as FP model.
+
+# In[ ]:
+
+
+# prepared_model = load_torch_model_using_safetensors(path=prepare_path, filename=prepare_filename, model_name=prepare_filename)
+
+
+# In[ ]:
+
+
+# Calculate ppl score for prepared fp model
+fp_prepared_fpm = LLMForwardPassManager(cfg=llm_config,
+ model=prepared_model,
+ tokenizer=tokenizer,
+ separate_tuple_input_output=True,
+ num_tokens=ARN)
+
+with event_marker("KVcache prepared FP eval", flush_ram=True):
+ with torch.no_grad():
+ with fp_prepared_fpm.place_on_device("cuda"):
+ prepared_kvcache_ppl = ppl_eval(test_dataloader, fp_prepared_fpm)
+
+# This should be very close (<1e-4 delta) to original model's perplexity
+# If the perplexity score goes further up, it indicates the AIMET/QNN pair is producing a faulty prepared model
+print(f"ppl score of KVCACHE prepared fp model: {prepared_kvcache_ppl}\n"
+ f"orig ppl - prepared ppl = {orig_ppl - prepared_kvcache_ppl}")
+
+
+# ### 7. Quantization
+#
+# The _Quantization_ step is the primary focus of this notebook, this section could be modified to execute various quantization experiments.
+
+# #### 7.1 Create quantsim configured for QNN HTP target
+#
+# The following member function allows creation of a shallow copied model. This shallow copied model is a separate model object from the original, but contains shared weights, biases, and parameters. As a result, the shallow copied model has very little memory overhead, which is useful for PTQ techniques like sequential MSE that expect separate FP and QuantSim models.
+
+# In[ ]:
+
+
+# Helper function that creates a shallow copy of the provided model
+# Creates a new model object, but all the underlying parameters are shared
+import copy
+from copy import deepcopy
+import functools
+
+def copy_model_with_shared_weights(source_model):
+ target_model = deepcopy(source_model)
+ for name, source_parameter in source_model.named_parameters():
+ pre, _, post = name.rpartition('.')
+ pre_obj = functools.reduce(getattr, [target_model] + pre.split('.')) if pre else target_model
+ setattr(pre_obj, post, source_parameter)
+ return target_model
+
+
+# In[ ]:
+
+
+from aimet_common.defs import QuantScheme
+from aimet_torch.v2.quantsim import QuantizationSimModel
+
+sim_fpm = LLMForwardPassManager(cfg=llm_config,
+ model=copy_model_with_shared_weights(prepared_model), # to avoid creating the sim in_place on the original model
+ tokenizer=tokenizer,
+ separate_tuple_input_output=True,
+ num_tokens=ARN)
+
+dummy_input = get_dummy_data(llm_config,
+ tokenizer, 'cuda', separate_tuple_input_output=True,
+ num_tokens=ARN, dtype=sim_fpm.dtype)
+
+with event_marker("create KVCache Quantsim"):
+ with sim_fpm.place_on_device("cuda"):
+ quantsim = QuantizationSimModel(model=sim_fpm.model,
+ quant_scheme=QuantScheme.post_training_tf,
+ dummy_input=dummy_input,
+ default_output_bw=16,
+ default_param_bw=4,
+ in_place=True,
+ config_file=htp_config_file)
+
+
+# #### 7.2 Setting 16bit x 8bit matmuls
+#
+# To keep key and value tensors as 8 bits, reducing data I/O costs associated with KV-cache orchestration.
+
+# In[ ]:
+
+
+from aimet_torch.v2.experimental.quantsim_utils import set_matmul_second_input_producer_to_8bit_symmetric
+
+set_matmul_second_input_producer_to_8bit_symmetric(quantsim)
+
+
+# #### 7.3 Concat encoding unification
+#
+# Configuring concat ops to have shared encoding on input and output activations.
+
+# In[ ]:
+
+
+from aimet_torch.v2.experimental import propagate_output_encodings
+import aimet_torch.elementwise_ops as aimet_ops
+
+propagate_output_encodings(quantsim, aimet_ops.Concat)
+
+
+# #### 7.4 Manual Mixed Precision
+#
+# Applying mixed precision configuration to ops
+
+# In[ ]:
+
+
+from llm_utils.mixed_precision_overrides import ManualQuantsimMixedPrecisionConfig
+
+config_file = "./config/mixed_precision_config/exceptions.json"
+
+quantsim_adjuster = ManualQuantsimMixedPrecisionConfig(mixed_precision_config_file=config_file)
+quantsim_adjuster.apply_exceptions(quantsim)
+
+
+# In[ ]:
+
+
+from aimet_torch.v2.nn.modules.custom import QuantizedRmsNorm
+from aimet_torch.v2.quantization.affine import QuantizeDequantize
+
+# Make RMSNorm encodings per-tensor (they default to per-channel)
+for name, qmodule in quantsim.named_qmodules():
+ if isinstance(qmodule, QuantizedRmsNorm):
+ qmodule.param_quantizers['weight'] = QuantizeDequantize(shape=(), bitwidth=16, symmetric=False).to(qmodule.weight.device)
+
+
+# #### 7.5 Optimize parameter encodings
+#
+# Apply either SeqMSE or LPBQ for optimized parameter quantization encodings.
+
+# In[ ]:
+
+
+quant_type = 'lpbq' # Quantization type: lpbq | seqmse
+
+if quant_type == 'lpbq':
+ from aimet_torch.v2.nn.true_quant import QuantizedConv2d
+ from aimet_torch.v2.quantsim.config_utils import set_grouped_blockwise_quantization_for_weights
+ import aimet_common.quantsim as qs
+
+ qs.encoding_version = '1.0.0'
+
+ arg = lambda module: isinstance(module, QuantizedConv2d) and module.param_quantizers['weight'].bitwidth == 4
+ BLOCK_QUANT_SIZE = 64
+ BITWIDTH = 4
+ DECOMPRESSED_BITWIDTH = 8
+ print(arg)
+ set_grouped_blockwise_quantization_for_weights(sim = quantsim,
+ arg = arg,
+ bitwidth = BITWIDTH,
+ symmetric = True,
+ decompressed_bw = DECOMPRESSED_BITWIDTH,
+ block_size = BLOCK_QUANT_SIZE,
+ block_grouping = -1)
+else: # seqmse
+ from aimet_torch.v2.seq_mse import apply_seq_mse
+ from aimet_torch.seq_mse import SeqMseParams
+ # from aimet_torch.utils import load_pytorch_model
+ import aimet_common.quantsim as qs
+
+ qs.encoding_version = '0.6.1'
+ def _forward_fn(model, inputs):
+ if model == fp_prepared_fpm.model:
+ fpm = fp_prepared_fpm
+ else:
+ fpm = sim_fpm
+
+ # slice inputs so that we only end up doing inference using first n tokens
+ input_length = inputs["input_ids"].shape[1]
+ prepared_inputs, _ = fpm.prepare_inputs(input_ids=inputs["input_ids"][:, :min(input_length, fpm.num_tokens), ...])
+ prepared_inputs = {name: t.to(torch.half) if t.is_floating_point() else t for name, t in prepared_inputs.items()}
+ fpm.model(**prepared_inputs)
+
+ params = SeqMseParams(num_batches=20,
+ inp_symmetry="symqt",
+ num_candidates=60,
+ loss_fn="mse",
+ forward_fn=_forward_fn)
+
+ with event_marker("SeqMSE"):
+ with fp_prepared_fpm.place_on_device("cuda"), sim_fpm.place_on_device("cuda"):
+ apply_seq_mse(fp_prepared_fpm.model, quantsim, train_dataloader, params)
+
+ del fp_prepared_fpm
+ del prepared_model
+
+
+# In[ ]:
+
+
+from aimet_torch.v2.seq_mse import apply_seq_mse
+from aimet_torch.seq_mse import SeqMseParams
+# from aimet_torch.utils import load_pytorch_model
+
+def _forward_fn(model, inputs):
+ if model == fp_prepared_fpm.model:
+ fpm = fp_prepared_fpm
+ else:
+ fpm = sim_fpm
+
+ # slice inputs so that we only end up doing inference using first n tokens
+ input_length = inputs["input_ids"].shape[1]
+ prepared_inputs, _ = fpm.prepare_inputs(input_ids=inputs["input_ids"][:, :min(input_length, fpm.num_tokens), ...])
+ prepared_inputs = {name: t.to(torch.half) if t.is_floating_point() else t for name, t in prepared_inputs.items()}
+ fpm.model(**prepared_inputs)
+
+params = SeqMseParams(num_batches=20,
+ inp_symmetry="symqt",
+ num_candidates=20,
+ loss_fn="mse",
+ forward_fn=_forward_fn)
+
+with event_marker("SeqMSE"):
+ with fp_prepared_fpm.place_on_device("cuda"), sim_fpm.place_on_device("cuda"):
+ apply_seq_mse(fp_prepared_fpm.model, quantsim, train_dataloader, params)
+
+del fp_prepared_fpm
+del prepared_model
+
+
+# #### 7.6 Calibration
+#
+# Compute activation encodings using AIMET
+
+# In[ ]:
+
+
+def _forward_fn(model, kwargs):
+ data_loader = kwargs['data_loader']
+ fpm = kwargs['fpm']
+ max_iterations = kwargs['num_batches']
+ for batch_id, batch in enumerate(tqdm(data_loader, total=max_iterations)):
+ if batch_id < max_iterations:
+ slice_inputs_and_run_successive_kvcache_inference(fpm, input_ids=batch['input_ids'])
+ else:
+ break
+kwargs = {
+ 'data_loader': train_dataloader,
+ 'fpm': sim_fpm,
+ 'num_batches': 100
+}
+
+with event_marker("compute encoding", flush_ram=True):
+ with sim_fpm.place_on_device("cuda"):
+ quantsim.compute_encodings(_forward_fn, kwargs)
+
+
+# #### 7.7 Eval KV Cache sim model
+
+# In[ ]:
+
+
+with event_marker("KV cache sim eval", flush_ram=True):
+ with torch.no_grad():
+ with sim_fpm.place_on_device("cuda"):
+ sim_ppl = ppl_eval(test_dataloader, sim_fpm)
+
+print(f"ppl score of KVCACHE sim fp model: {sim_ppl}\n"
+ f"orig ppl - kvcache sim ppl = {orig_ppl - sim_ppl}")
+
+
+# ### 8. Export
+#
+# The pipeline call below would export onnx model, encoding and test vector for KVCache models.
+#
+# #### 8.1 Export KVCache Model
+
+# In[ ]:
+
+
+from aimet_torch.utils import change_tensor_device_placement
+from aimet_torch.onnx_utils import OnnxExportApiArgs
+
+onnx_api_args = OnnxExportApiArgs(input_names=input_names,output_names=output_names)
+sample_inputs = change_tensor_device_placement(dummy_input, torch.device('cpu'))
+with event_marker("KVCache export", flush_ram=True):
+ quantsim.export(onnx_dir, model_name, sample_inputs, onnx_export_args=onnx_api_args)
+
+# Export chat template
+if getattr(tokenizer, "chat_template", None):
+ with open(os.path.join(output_dir, "chat_template.jinja"), "w", encoding="utf-8") as f:
+ f.write(tokenizer.chat_template)
+else:
+ print("No chat_template found on tokenizer; nothing to export.")
+
+# Export generation config
+model.generation_config.save_pretrained(output_dir)
+
+# #### 8.2 Generating test vectors for QNN SDK
+
+# In[ ]:
+
+
+from llm_utils.test_vectors import generate_test_vectors
+
+test_vector_layers = [
+ "rms_norm_\\d+",
+ "lm_head_conv_Conv$"
+]
+
+with event_marker("generate test vector"):
+ with sim_fpm.place_on_device("cuda"):
+ generate_test_vectors(quantsim, sim_fpm, train_dataloader, output_dir, num_batches=1, test_vector_layers=test_vector_layers, input_names=input_names)
+
+
+# ### Summary
+
+# In[ ]:
+
+
+# from aimet_torch.pro.utils.profiler import EventProfiler
+from genai_lib.common.debug.profiler import EventProfiler
+EventProfiler().report()
+EventProfiler().json_dump(os.path.join(output_dir, 'profiling_stats.json'))
+
+import json
+with open(f'{output_dir}/ppl.json', 'wt') as f:
+ json.dump({
+ "original": float(orig_ppl),
+ "prepared_kvcache": float(prepared_kvcache_ppl),
+ "QuantSim": float(sim_ppl),
+ }, f, indent=2)
+
+
+# Copyright (c) 2025 Qualcomm Technologies, Inc. and/or its subsidiaries.
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/quantizer_utils/quantizer_manipulation.py b/microsoft-Phi-4-mini-instruct/QAIRT/quantizer_utils/quantizer_manipulation.py
new file mode 100644
index 000000000..ef443a353
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/quantizer_utils/quantizer_manipulation.py
@@ -0,0 +1,130 @@
+# =============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All rights reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+# =============================================================================
+"""This file contains utilities for manipulating quantsim quantizers"""
+
+from typing import Sequence, Set, Optional, Union
+
+import torch
+from aimet_torch.v2.quantsim import QuantizationSimModel
+from aimet_torch.v2.nn import QuantizedConv2d, QuantizedLinear
+from aimet_torch.v2.nn.modules.custom import QuantizedMultiply, QuantizedAdd
+
+
+def freeze_quantizers(
+ sim: QuantizationSimModel,
+ target_types: Union[Sequence[str], str] = ('input', 'output', 'param'),
+ return_act_encodings: bool = False,
+ return_param_encodings: bool = False,
+ target_layers: Optional[Union[Set[str], Sequence[str]]] = None,
+):
+ """
+ Freezes quantizer encodings and optionally return the frozen encodings
+ :param sim: The sim object of which the quantizers are set to be non-overwritable
+ :param target_types: The type of quantizers to freeze. Can be any of 'input', 'output', 'param' or a sequence of
+ any combinations of these options
+ :param return_act_encodings: Whether to return the activation encodings
+ :param return_param_encodings: Whether to return the parameter encodings
+ :param target_layers [Optional]: The layers to freeze. If not specified, default to freeze all layers
+ :return: Optionally return the frozen encodings
+ """
+ if isinstance(target_types, str):
+ target_types = [target_types]
+
+ if target_layers is not None and not isinstance(target_layers, set):
+ target_layers = set(target_layers)
+
+ for layer_name, qmodule in sim.named_qmodules():
+ # Skip layers if not found in frozen_layers
+ if target_layers is not None and layer_name not in target_layers:
+ continue
+
+ for target_type in target_types:
+ if target_type.lower() == 'input':
+ for input_quantizer in qmodule.input_quantizers:
+ if input_quantizer is not None:
+ input_quantizer.allow_overwrite(False)
+
+ if target_type.lower() == 'output':
+ for output_quantizer in qmodule.output_quantizers:
+ if output_quantizer is not None:
+ output_quantizer.allow_overwrite(False)
+
+ if target_type.lower() == 'param':
+ for param_quantizer in qmodule.param_quantizers.values():
+ if param_quantizer is not None:
+ param_quantizer.allow_overwrite(False)
+
+ if return_act_encodings or return_param_encodings:
+ act_encodings, param_encodings = sim.get_activation_param_encodings()
+
+ return_encodings = {}
+
+ if return_act_encodings:
+ return_encodings['activation_encodings'] = act_encodings
+
+ if return_param_encodings:
+ return_encodings['param_encodings'] = param_encodings
+
+ return return_encodings
+
+def set_lora_quantizer(
+ sim: QuantizationSimModel,
+ target_modules: Union[Set[str], Sequence[str]],
+ param_bitwidth: int = 8,
+ scaling_min: float = 0.0,
+ scaling_max: float = 1.0,
+):
+ """
+ Set the quantizer range for LoRA scalings and bitwidth of LoRA parameters
+ :param sim: The sim object of which the LoRA quantizers are set to the provided range
+ :param target_modules: The LoRA target_modules. We use target_modules to do substring
+ match for the sim.model's layer names in order to find the correspoding lora layers
+ :param param_bitwidth: The bitwidth of the LoRA param quantizers
+ :param scaling_min: The min value to set for LoRA scalings quantizers
+ :param scaling_max: The max value to set for LoRA scalings quantizers
+ :return: None. Quantizers are set in-place
+ """
+ for layer_name, module in sim.named_qmodules():
+ if (isinstance(module, (QuantizedConv2d, QuantizedLinear, QuantizedMultiply)) and
+ any(map(lambda target_module: target_module in layer_name, target_modules))):
+ if 'lora' in layer_name and isinstance(module, (QuantizedConv2d, QuantizedLinear)):
+ # Set bitwidth for LoRA A/B (aka up/down) modules
+ module.param_quantizers['weight'].bitwidth = param_bitwidth
+ if isinstance(module, QuantizedMultiply):
+ # Set range of the LoRA multiplication layer
+ module.input_quantizers[1].set_range(
+ torch.as_tensor(scaling_min),
+ torch.as_tensor(scaling_max)
+ )
+ # Freeze Quantizer
+ module.input_quantizers[1].allow_overwrite(False)
+
+def propagate_lora_add(sim: QuantizationSimModel, target_modules: Union[Set[str], Sequence[str]]):
+ """
+ Propagate LoRA add's output quantizers from its base layer via unifying their output_quantizers
+ :param sim: The sim object of which the LoRA add quantizers are to be unified
+ :param target_modules: The LoRA target_modules. We use target_modules to do substring matching.
+ The substrings in target_modules must uniquely identify each layer s.t. we can find the unique
+ match of the layer in sim
+ :return: None. Propagation happens in-place
+ """
+ base_output_quantizers = {}
+ lora_add_modules = {}
+
+ for layer_name, module in sim.named_qmodules():
+ if isinstance(module, (QuantizedConv2d, QuantizedLinear, QuantizedAdd)):
+ for target_module in target_modules:
+ if target_module in layer_name:
+ if 'base_layer' in layer_name and isinstance(module, (QuantizedConv2d, QuantizedLinear)):
+ base_output_quantizers[target_module] = module.output_quantizers
+ if isinstance(module, QuantizedAdd):
+ lora_add_modules[target_module] = module
+ break
+
+ for target_module, base_output_quantizer in base_output_quantizers.items():
+ lora_add_modules[target_module].output_quantizers = base_output_quantizer
diff --git a/microsoft-Phi-4-mini-instruct/QAIRT/requirements.txt b/microsoft-Phi-4-mini-instruct/QAIRT/requirements.txt
new file mode 100644
index 000000000..2b22f6611
--- /dev/null
+++ b/microsoft-Phi-4-mini-instruct/QAIRT/requirements.txt
@@ -0,0 +1,315 @@
+absl-py==2.4.0
+accelerate==1.13.0
+aenum==3.1.15
+aimet-torch==2.14.0
+aiohappyeyeballs==2.6.1
+aiohttp==3.13.3
+aiosignal==1.4.0
+alembic==1.18.4
+annotated-doc==0.0.4
+annotated-types==0.7.0
+ansi2html==1.9.2
+anyio==4.13.0
+anykeystore==0.2
+apex==0.9.10.dev0
+argon2-cffi==25.1.0
+argon2-cffi-bindings==25.1.0
+arrow==1.4.0
+asttokens==3.0.1
+async-lru==2.3.0
+async-timeout==5.0.1
+attrs==25.4.0
+autograd==1.8.0
+babel==2.18.0
+bcrypt==5.0.0
+beautifulsoup4==4.14.3
+bidict==0.23.1
+bleach==6.3.0
+bokeh==3.2.2
+certifi==2026.1.4
+cffi==2.0.0
+chardet==5.2.0
+charset-normalizer==3.4.4
+clarabel==0.11.1
+click==8.3.3
+cma==2.7.0
+colorcet==3.1.0
+coloredlogs==15.0.1
+colorlog==6.10.1
+comm==0.2.3
+concurrent-log-handler==0.9.29
+contourpy==1.3.2
+cryptacular==1.6.2
+cryptography==45.0.7
+cuda-bindings==12.9.4
+cuda-pathfinder==1.3.3
+cuda-toolkit==13.0.2
+cvxpy==1.6.0
+cycler==0.12.1
+dash==2.12.1
+dash-core-components==2.0.0
+dash-html-components==2.0.0
+dash-table==5.0.0
+dataclasses==0.6
+datasets==4.0.0
+debugpy==1.8.20
+decorator==5.2.1
+defusedxml==0.7.1
+Deprecated==1.3.1
+dill==0.3.8
+exceptiongroup==1.3.1
+executing==2.2.1
+fastapi==0.135.1
+fastjsonschema==2.21.2
+filelock==3.20.3
+Flask==2.2.5
+flatbuffers==25.12.19
+fonttools==4.61.1
+fqdn==1.5.1
+frozenlist==1.8.0
+fsspec==2025.3.0
+gguf==0.9.1
+greenlet==3.5.0
+grpcio==1.80.0
+grpcio-tools==1.80.0
+h11==0.16.0
+h5py==3.16.0
+hf-xet==1.5.0
+holoviews==1.18.3
+httpcore==1.0.9
+httpx==0.28.1
+huggingface_hub==0.36.2
+humanfriendly==10.0
+hupper==1.12.1
+hvplot==0.9.2
+idna==3.11
+importlib_metadata==8.7.1
+iniconfig==2.3.0
+invoke==1.7.3
+ipykernel==7.1.0
+ipython==8.38.0
+islpy==2025.2.5
+isoduration==20.11.0
+itsdangerous==2.2.0
+jedi==0.19.2
+Jinja2==3.1.0
+joblib==1.5.3
+json5==0.14.0
+jsonpointer==3.1.1
+jsonschema==4.26.0
+jsonschema-specifications==2025.9.1
+jupyter-events==0.12.1
+jupyter-lsp==2.3.1
+jupyter_client==8.8.0
+jupyter_core==5.9.1
+jupyter_server==2.18.2
+jupyter_server_terminals==0.5.4
+jupyterlab==4.5.7
+jupyterlab_pygments==0.3.0
+jupyterlab_server==2.28.0
+kiwisolver==1.4.9
+lark==1.3.1
+lightning-utilities==0.15.3
+linkify-it-py==2.0.3
+lxml==5.2.1
+Mako==1.3.12
+Markdown==3.10.1
+markdown-it-py==4.0.0
+MarkupSafe==3.0.3
+matplotlib==3.10.8
+matplotlib-inline==0.2.1
+mdit-py-plugins==0.5.0
+mdurl==0.1.2
+mistune==3.2.1
+ml_dtypes==0.5.4
+mock==3.0.5
+mpmath==1.3.0
+multidict==6.7.1
+multiprocess==0.70.16
+narwhals==2.15.0
+nbclient==0.10.4
+nbconvert==7.17.1
+nbformat==5.10.4
+nest-asyncio==1.6.0
+networkx==3.4.2
+notebook==7.5.6
+notebook_shim==0.2.4
+numpy==1.24.4
+nvidia-cublas==13.1.0.3
+nvidia-cublas-cu11==11.10.3.66
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti==13.0.85
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc==13.0.88
+nvidia-cuda-nvrtc-cu11==11.7.99
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime==13.0.96
+nvidia-cuda-runtime-cu11==11.7.99
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu11==8.5.0.96
+nvidia-cudnn-cu12==8.9.2.26
+nvidia-cudnn-cu13==9.19.0.56
+nvidia-cufft==12.0.0.61
+nvidia-cufft-cu12==11.0.2.54
+nvidia-cufile==1.15.1.6
+nvidia-cufile-cu12==1.13.1.3
+nvidia-curand==10.4.0.35
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver==12.0.4.66
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse==12.6.3.3
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-cusparselt-cu12==0.7.1
+nvidia-cusparselt-cu13==0.8.0
+nvidia-nccl-cu12==2.18.1
+nvidia-nccl-cu13==2.28.9
+nvidia-nvjitlink==13.0.88
+nvidia-nvjitlink-cu12==12.8.93
+nvidia-nvshmem-cu12==3.4.5
+nvidia-nvshmem-cu13==3.4.5
+nvidia-nvtx==13.0.85
+nvidia-nvtx-cu12==12.1.105
+oauthlib==3.3.1
+onnx==1.17.0
+onnx-ir==0.1.15
+onnx_graphsurgeon==0.5.8
+onnxruntime==1.23.2
+onnxruntime-genai==0.8.2
+onnxruntime_extensions==0.11.0
+onnxscript==0.6.0
+onnxsim==0.4.36
+opencv-python==4.8.1.78
+opentelemetry-api==1.41.1
+opentelemetry-sdk==1.41.1
+opentelemetry-semantic-conventions==0.62b1
+optuna==4.8.0
+osqp==1.1.0
+overrides==7.7.0
+packaging==26.0
+pandas==1.5.3
+pandocfilters==1.5.1
+panel==1.3.8
+param==2.3.1
+paramiko==3.5.1
+parso==0.8.5
+PasteDeploy==3.1.0
+pathlib2==2.3.6
+pbkdf2==1.3
+peft==0.19.1
+pexpect==4.9.0
+Pillow==8.4.0
+plaster==1.1.2
+plaster-pastedeploy==1.0.1
+platformdirs==4.5.1
+plotly==5.20.0
+pluggy==1.6.0
+portalocker==3.2.0
+prometheus_client==0.25.0
+prompt_toolkit==3.0.52
+propcache==0.4.1
+protobuf==3.20.2
+psutil==6.1.1
+ptflops==0.7.5
+ptyprocess==0.7.0
+pure_eval==0.2.3
+pyaml==26.2.1
+pyarrow==23.0.0
+pybind11==3.0.1
+pycparser==3.0
+pydantic==2.8.2
+pydantic_core==2.20.1
+pyDOE2==1.3.0
+Pygments==2.19.2
+pymoo==0.4.1
+PyNaCl==1.6.2
+pyparsing==3.3.2
+pyramid==2.1
+pyramid-mailer==0.15.1
+pytest==8.1.1
+python-dateutil==2.9.0.post0
+python-engineio==4.13.1
+python-json-logger==4.1.0
+python-socketio==5.16.1
+python3-openid==3.2.0
+pytz==2025.2
+pyviz_comms==3.0.6
+PyYAML==6.0.3
+pyzmq==26.4.0
+questionary==2.1.1
+referencing==0.37.0
+regex==2026.1.15
+repoze-sendmail==4.5
+requests==2.32.5
+requests-oauthlib==2.0.0
+retrying==1.4.2
+rfc3339-validator==0.1.4
+rfc3986-validator==0.1.1
+rfc3987-syntax==1.1.0
+rich==14.3.2
+rpds-py==0.30.0
+safetensors==0.5.3
+scikit-learn==1.7.2
+scikit-optimize==0.9.0
+scipy==1.8.1
+scs==3.2.11
+Send2Trash==2.1.0
+sentencepiece==0.2.0
+shellingham==1.5.4
+simple-websocket==1.1.0
+six==1.17.0
+soupsieve==2.8.3
+SQLAlchemy==2.0.49
+stack-data==0.6.3
+starlette==1.0.0
+sympy==1.14.0
+tabulate==0.9.0
+tenacity==9.1.4
+tensorboard==2.20.0
+tensorboard-data-server==0.7.2
+terminado==0.18.1
+threadpoolctl==3.6.0
+timm==0.4.12
+tinycss2==1.4.0
+tokenizers==0.20.3
+tomli==2.4.1
+torch==2.1.2 --index-url https://download.pytorch.org/whl/cu121
+torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
+torchmetrics==1.9.0
+torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121
+tornado==6.5.4
+tqdm==4.67.2
+traitlets==5.14.3
+transaction==5.1
+transformers==4.46.0
+translationstring==1.4
+triton==2.1.0
+typer==0.25.1
+types-paramiko==3.5.0.20250516
+typing-inspection==0.4.2
+typing_extensions==4.15.0
+tzdata==2025.3
+uc-micro-py==1.0.3
+uri-template==1.3.0
+urllib3==2.6.3
+uvicorn==0.46.0
+velruse==1.1.1
+venusian==3.1.1
+wcwidth==0.5.3
+webcolors==25.10.0
+webencodings==0.5.1
+WebOb==1.8.9
+websocket-client==1.9.0
+Werkzeug==2.2.3
+wrapt==2.1.2
+wsproto==1.3.2
+WTForms==3.2.2
+wtforms-recaptcha==0.3.2
+XlsxWriter==1.2.2
+xxhash==3.6.0
+xyzservices==2025.11.0
+yarl==1.22.0
+zipp==3.23.1
+zope.deprecation==6.0
+zope.interface==8.4
+zope.sqlalchemy==4.1
+zstandard==0.25.0