easydel.layers.caching.paged_attention.__init__#
- class easydel.layers.caching.paged_attention.__init__.ActiveSequenceBatch(token_ids: ~typing.Union[~jax.Array, ~typing.List[int]], positions: ~jax.Array, page_table: ~jax.Array, sampling_params: ~easydel.layers.caching.paged_attention.types.SamplingParams, available_slots: ~typing.Optional[~_queue.SimpleQueue], active_slot_requests_map: ~typing.Optional[~typing.Dict[int, ~easydel.layers.caching.paged_attention.types.GenerationStepTask]], context_lock: ~typing.Optional[~_thread.allocate_lock] = <unlocked _thread.lock object>, page_update_slots: ~typing.Optional[~jax.Array] = None, page_update_page_idxs: ~typing.Optional[~jax.Array] = None, page_update_mapped_idxs: ~typing.Optional[~jax.Array] = None)[source]#
Bases: object
Manages the batch state for sequences in the decoding phase.
This class holds the dynamic state for all sequences currently undergoing token generation (decoding) within the paged attention batch. It includes the input token IDs for the next step, current positions, page table mappings, sampling parameters, and structures for managing available batch slots.
- token_ids#
JAX array (during model execution) or list (during host-side updates) holding the input token ID for the next decode step for each active slot. Shape: (batch_size,).
- Type
jax.Array | tp.List[int]
- positions#
JAX array holding the current sequence position (length) for each active slot. Inactive slots might have a placeholder value (e.g., -1 or 1e6). Shape: (batch_size,).
- Type
jax.Array
- page_table#
JAX array mapping logical page indices to physical HBM page indices for each sequence slot. Shape: (batch_size, num_pages_per_sequence).
- Type
jax.Array
- sampling_params#
Sampling parameters for all sequences in the batch. (Inference Argument)
- Type
SamplingParams
- available_slots#
A queue managing indices of free batch slots. Used by the host-side scheduler. (Scheduler Argument)
- Type
tp.Optional[queue.SimpleQueue]
- active_slot_requests_map#
A dictionary mapping active slot indices to their corresponding GenerationStepTask. Used by the host-side scheduler. (Scheduler Argument)
- Type
tp.Optional[tp.Dict[int, GenerationStepTask]]
- context_lock#
A lock for thread-safe updates to scheduler-related attributes (available_slots, active_slot_requests_map). (Scheduler Argument)
- Type
tp.Optional[threading.Lock]
- page_update_slots#
Array storing slot indices for pending page table updates. (Internal State)
- Type
tp.Optional[jax.Array]
- page_update_page_idxs#
Array storing logical page indices for pending page table updates. (Internal State)
- Type
tp.Optional[jax.Array]
- page_update_mapped_idxs#
Array storing physical mapped indices for pending page table updates. (Internal State)
- Type
tp.Optional[jax.Array]
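The scheduler-side fields above (available_slots, active_slot_requests_map, context_lock) work together as a pool of batch slots. The sketch below is a minimal stand-in, not EasyDEL code: SlotBook, acquire and release are hypothetical names, and plain request IDs are stored in place of GenerationStepTask objects.

```python
import queue
import threading
from typing import Dict, Optional


class SlotBook:
    """Hypothetical stand-in for the scheduler fields of ActiveSequenceBatch."""

    def __init__(self, batch_size: int) -> None:
        # Every slot starts out free.
        self.available_slots: "queue.SimpleQueue[int]" = queue.SimpleQueue()
        for slot in range(batch_size):
            self.available_slots.put(slot)
        # slot -> request ID (the real map stores GenerationStepTask objects).
        self.active_slot_requests_map: Dict[int, str] = {}
        self.context_lock = threading.Lock()

    def acquire(self, request_id: str) -> Optional[int]:
        # Take a free slot and record its owner atomically; None if the batch is full.
        with self.context_lock:
            if self.available_slots.empty():
                return None
            slot = self.available_slots.get_nowait()
            self.active_slot_requests_map[slot] = request_id
            return slot

    def release(self, slot: int) -> None:
        # Forget the owning request and return the slot to the free pool.
        with self.context_lock:
            self.active_slot_requests_map.pop(slot, None)
            self.available_slots.put(slot)
```

Holding the lock around both the queue and the dictionary keeps the two structures consistent when several host threads schedule work concurrently.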
- active_slot_requests_map: Optional[Dict[int, GenerationStepTask]]#
- apply_assignment(assignment: list[easydel.layers.caching.paged_attention.types.SlotPageAssignment])[source]#
Applies page table assignments to the internal update arrays (host-side).
This method populates the page_update_* arrays based on a list of SlotPageAssignment objects, preparing them for later use in JAX computations (insert_decode_state).
- Parameters
assignment (list[SlotPageAssignment]) – A list of page assignments to apply.
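A minimal sketch of how a list of assignments could be packed into three fixed-length page_update_* arrays and later scattered into a page table. The SlotPageAssignment field names (slot, page_idx, mapped_idx), the -1 padding value and both helper functions are assumptions for illustration. Fixed-length arrays are a common way to keep shapes static for JIT-compiled code such as insert_decode_state.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class SlotPageAssignment:
    # Hypothetical field names; the real class is defined in paged_attention.types.
    slot: int        # batch slot whose page table changes
    page_idx: int    # logical page index within that sequence
    mapped_idx: int  # physical HBM page index


def pack_assignments(
    assignment: List[SlotPageAssignment], capacity: int, pad: int = -1
) -> Tuple[List[int], List[int], List[int]]:
    """Pack assignments into three fixed-length parallel arrays, padding unused entries."""
    if len(assignment) > capacity:
        raise ValueError("more assignments than update entries")
    slots, page_idxs, mapped_idxs = [pad] * capacity, [pad] * capacity, [pad] * capacity
    for i, a in enumerate(assignment):
        slots[i], page_idxs[i], mapped_idxs[i] = a.slot, a.page_idx, a.mapped_idx
    return slots, page_idxs, mapped_idxs


def scatter_into_page_table(page_table, slots, page_idxs, mapped_idxs, pad=-1):
    """Apply the packed updates: page_table[slot][page_idx] = mapped_idx, skipping padding."""
    for s, p, m in zip(slots, page_idxs, mapped_idxs):
        if s != pad:
            page_table[s][p] = m
    return page_table
```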
- available_slots: Optional[SimpleQueue]#
- context_lock: Optional[allocate_lock] = <unlocked _thread.lock object>#
- copy_decode(decode: ActiveSequenceBatch)[source]#
Copies essential decode state from another ActiveSequenceBatch (host-side).
Updates the current object’s token_ids, positions, page_table, and sampling_params based on the source decode object. Assumes host-side operation.
- Parameters
decode (ActiveSequenceBatch) – The source batch state to copy from.
- classmethod create(metadata: PagedAttentionCacheMetaData, mesh: Mesh)[source]#
Creates and initializes an ActiveSequenceBatch for JAX execution.
This factory method sets up the ActiveSequenceBatch with JAX arrays (token_ids, positions, page_table, sampling_params) appropriately sharded across the provided mesh. It also initializes the host-side scheduler components (available_slots, active_slot_requests_map).
- Parameters
metadata (PagedAttentionCacheMetaData) – Paged attention cache configuration.
mesh (common_types.Mesh) – The JAX device mesh for array distribution.
- Returns
An initialized batch state ready for JAX operations.
- Return type
ActiveSequenceBatch
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init_numpy(metadata: PagedAttentionCacheMetaData) ActiveSequenceBatch[source]#
Initializes ActiveSequenceBatch with NumPy arrays for host-side use.
Creates the necessary arrays (positions, page_table, page_update_*) using NumPy, suitable for manipulation outside JAX computations. token_ids is initialized as an empty list.
- Parameters
metadata (PagedAttentionCacheMetaData) – Configuration metadata.
- Returns
An initialized batch state object with NumPy arrays.
- Return type
ActiveSequenceBatch
- insert_decode_state(insert_slots: Array, update: ActiveSequenceBatch)[source]#
Updates the decode state in specified slots from another batch state.
This is typically used during JAX computation to incorporate updates from newly scheduled decode tasks or page table modifications.
- Parameters
insert_slots (jax.Array) – An array of slot indices to update.
update (ActiveSequenceBatch) – The batch containing the new state to insert.
- insert_from_task(slot: int, task: GenerationStepTask)[source]#
Inserts state from a GenerationStepTask into a specific slot (host-side).
Updates the host-side representations (token_ids list, positions array, page_table array, sampling_params) for the given slot based on the provided task.
- Parameters
slot (int) – The batch slot index to insert the task state into.
task (GenerationStepTask) – The task containing the state to insert.
- property is_active#
Checks if the batch state is active (i.e., has associated token IDs).
- Returns
True if token_ids is a non-empty array, False otherwise.
- Return type
bool
- pad_tokens(pad_length: int)[source]#
Pads the host-side token list with placeholder scalars.
Ensures the token_ids list (when used host-side) reaches the required pad_length by appending placeholder scalar JAX arrays.
- Parameters
pad_length (int) – The target length for the token_ids list.
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- sampling_params: SamplingParams#
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.paged_attention.__init__.AllocatedPrefillPages(page_indices: list[int])[source]#
Bases: object
Holds the indices of HBM pages allocated during a prefill step.
This simple structure is used to communicate which physical memory pages have been assigned to a sequence during its prefill processing.
- page_indices#
A list containing the indices of the HBM (High Bandwidth Memory) pages allocated for a specific prefill chunk.
- Type
list[int]
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- page_indices: list[int]#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.paged_attention.__init__.GenerationStepTask(id: str, slot: int, position: int, page_indices: list[int], prefill_token_id: Array, sampling_params: SamplingParams)[source]#
Bases: object
Represents a sequence actively undergoing token generation (decoding).
This class holds the necessary information for a single sequence that is currently in the decoding phase within the paged attention batch.
- id#
The unique identifier tracing back to the original request.
- Type
str
- slot#
The assigned batch slot index for this sequence.
- Type
int
- position#
The current sequence length (position) for the next token.
- Type
int
- page_indices#
The list of physical HBM page indices allocated to this sequence’s KV cache.
- Type
list[int]
- prefill_token_id#
The token ID generated in the previous step (either prefill or the last decode step), which serves as input for the current decode step.
- Type
jax.Array
- sampling_params#
The sampling parameters associated with this specific generation task.
- Type
SamplingParams
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- id: str#
- page_indices: list[int]#
- position: int#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- sampling_params: SamplingParams#
- slot: int#
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.paged_attention.__init__.HBMPageManager(metadata: PagedAttentionCacheMetaData)[source]#
Bases: object
Manages the allocation and deallocation of physical HBM pages for the KV cache. It keeps track of available pages.
- _metadata#
Configuration for the paged cache.
- _current_page_index#
Index representing the initial dummy page.
- Type
int
- _available_hbm_pages#
Queue of free HBM page indices.
- Type
queue.SimpleQueue
- alloc_hbm_pages(n: int) list[int][source]#
Allocates a specific number of HBM pages.
- Parameters
n (int) – Number of pages to allocate.
- Returns
Allocated HBM page indices (empty if insufficient pages).
- Return type
list[int]
- alloc_prefill_hbm_pages(prompt_len) list[int][source]#
Allocates the required number of HBM pages for a prompt prefill based on its length.
- Parameters
prompt_len (int) – The length of the prompt (or chunk).
- Returns
List of allocated HBM page indices (empty if insufficient pages).
- Return type
list[int]
- property current_page_index#
Returns the dummy page index (usually 0).
- free_hbm_pages(pages: list[int])[source]#
Returns a list of HBM pages back to the available pool.
- Parameters
pages (list[int]) – HBM page indices to free (ignores dummy page).
- property metadata: PagedAttentionCacheMetaData#
Returns the cache metadata.
- property page_size#
Number of tokens stored per KV cache page.
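The allocation behaviour described for HBMPageManager can be sketched as a free-list allocator. This toy version assumes that page 0 is the dummy page and that a prefill needs ceil(prompt_len / page_size) pages; ToyPageManager and its method names are hypothetical.

```python
import math
import queue
from typing import List


class ToyPageManager:
    """Minimal free-list page allocator; page 0 is reserved as the dummy page (assumption)."""

    def __init__(self, num_pages: int, page_size: int) -> None:
        self.page_size = page_size
        self.dummy_page = 0
        self._free: "queue.SimpleQueue[int]" = queue.SimpleQueue()
        for idx in range(1, num_pages):  # skip the dummy page
            self._free.put(idx)
        self._num_free = num_pages - 1

    def alloc(self, n: int) -> List[int]:
        # All-or-nothing: return [] if the request cannot be fully satisfied.
        if n > self._num_free:
            return []
        self._num_free -= n
        return [self._free.get_nowait() for _ in range(n)]

    def alloc_prefill(self, prompt_len: int) -> List[int]:
        # ceil(prompt_len / page_size) pages are enough to hold the whole prompt.
        return self.alloc(math.ceil(prompt_len / self.page_size))

    def free(self, pages: List[int]) -> None:
        for idx in pages:
            if idx != self.dummy_page:  # never return the dummy page to the pool
                self._free.put(idx)
                self._num_free += 1
```

For example, with 16-token pages a 33-token prompt needs three pages, and freed pages are handed out again in the order they were returned.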
- class easydel.layers.caching.paged_attention.__init__.InferenceScheduler(manager: HBMPageManager)[source]#
Bases: object
Schedules incoming prefill and decode requests based on HBM page and slot availability. It coordinates with HBMPageManager and ActiveSequenceBatch.
- prefill_queue#
Incoming prefill requests queue.
- Type
queue.Queue[InitialSequenceRequest]
- decodes_queue#
Queue of requests that have completed prefill and are waiting to enter the decode batch.
- Type
queue.Queue[GenerationStepTask]
- manager#
Manager for HBM page allocation.
- Type
HBMPageManager
- batch_size#
Max concurrent decode requests.
- Type
int
- max_seq_len#
Max sequence length supported.
- Type
int
- create_plan(active_prefill: easydel.layers.caching.paged_attention.types.InitialSequenceRequest | None, decodes_state: ActiveSequenceBatch) easydel.layers.caching.paged_attention.types.NextIterationPlan | None[source]#
Determines the workload for the next model iteration.
This function decides:
1. Whether a new prefill request can be started or an existing one continued.
2. Whether HBM pages need to be allocated for the prefill request.
3. Which decode requests from the decodes_queue can be added to the active batch.
4. Whether HBM pages need to be allocated for ongoing decode requests.
It updates host-side state (such as decodes_state.active_slot_requests_map) directly, whereas updates to device-side state (JAX arrays) are only prepared and returned within the NextIterationPlan.
- Parameters
active_prefill (InitialSequenceRequest | None) – The prefill request currently being processed, if any.
decodes_state (ActiveSequenceBatch) – The current state of the decode batch.
- Returns
An object containing the scheduling decisions and necessary updates for the next iteration, or None if no work can be scheduled (e.g., all queues are empty).
- Return type
NextIterationPlan | None
- Raises
NotImplementedError – If page allocation fails and eviction is required (currently not supported).
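Step 4 of the plan (growing the page lists of ongoing decodes) could look roughly like the sketch below. It assumes a sequence whose next position is pos needs pos // page_size + 1 pages; plan_decode_page_updates and its (slot, logical_page, physical_page) tuples are hypothetical stand-ins for the SlotPageAssignment updates the real scheduler produces.

```python
from typing import Callable, Dict, List, Tuple


def plan_decode_page_updates(
    positions: Dict[int, int],           # slot -> next position of that sequence
    page_indices: Dict[int, List[int]],  # slot -> physical pages already held
    page_size: int,
    alloc: Callable[[int], List[int]],   # returns [] when not enough pages are free
) -> List[Tuple[int, int, int]]:
    """Return (slot, logical_page, physical_page) for every decode that needs one more page."""
    updates: List[Tuple[int, int, int]] = []
    for slot, pos in positions.items():
        needed = pos // page_size + 1  # pages required to store the token at `pos`
        held = page_indices[slot]
        # Decoding advances one token per step, so at most one page is added here.
        if len(held) < needed:
            new = alloc(1)
            if not new:
                # Mirrors the documented behaviour: eviction is not supported.
                raise NotImplementedError("out of HBM pages; eviction is not supported")
            held.append(new[0])
            updates.append((slot, len(held) - 1, new[0]))
    return updates
```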
- enqueue_decodes_request(request: GenerationStepTask)[source]#
Adds a completed prefill request to the decode queue.
- enqueue_prefill_request(request: InitialSequenceRequest)[source]#
Adds a prefill request to the prefill queue.
- class easydel.layers.caching.paged_attention.__init__.InitialSequenceRequest(token_ids: Array, positions: Array, page_indices: Array, sampling_params: SamplingParams, id: Optional[Union[str, int]], chunk_idx: Optional[int], chunk_size: Optional[int], prompt_token_ids: Optional[list[int]], length: Optional[Array] = None)[source]#
Bases: object
Represents a request for processing a new sequence during the prefill phase.
This class encapsulates the information needed to process a new input sequence (prompt) in the paged attention mechanism. It includes the token IDs, positions, allocated page indices, and associated metadata.
- token_ids#
JAX array of token IDs for the prefill sequence, potentially padded. (Runtime Argument)
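- Type
jax.Array
- positions#
JAX array of position IDs for the prefill tokens. (Runtime Argument)
- Type
jax.Array
- page_indices#
JAX array holding the indices of HBM pages allocated for this request’s KV cache. (Runtime Argument)
- Type
jax.Array
- sampling_params#
Sampling parameters for this specific request. (Runtime Argument)
- Type
SamplingParams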
- id#
A unique identifier for the request. (Scheduler Argument)
- Type
tp.Optional[str | int]
- chunk_idx#
The index of the current chunk being processed if the prompt is chunked. (Scheduler Argument)
- Type
tp.Optional[int]
- chunk_size#
The size of the token chunk being processed in this prefill step. (Scheduler Argument)
- Type
tp.Optional[int]
- prompt_token_ids#
The original list of token IDs for the complete prompt. (Scheduler Argument)
- Type
tp.Optional[list[int]]
- length#
The actual length of the sequence processed so far (relevant for chunked prefill). Defaults to None. (Scheduler Argument)
- Type
tp.Optional[jax.Array]
- chunk_idx: Optional[int]#
- chunk_size: Optional[int]#
- copy_prefill(prefill: InitialSequenceRequest)[source]#
Copies runtime state from another InitialSequenceRequest (prefill source).
This method updates the current request’s runtime attributes (token_ids, positions, page_indices, sampling_params, length) based on the state of a source prefill request, typically used when advancing through chunks of a long prompt.
- Parameters
prefill (InitialSequenceRequest) – The source request from which to copy runtime state.
- classmethod create(id: str, mesh: Mesh, metadata: PagedAttentionCacheMetaData, chunk_size: int, prompt_token_ids: list[int], max_prefill_length: Optional[int] = None, prefill_lengths: Optional[List[int]] = None, sampling_params: Optional[SamplingParams] = None)[source]#
Creates an InitialSequenceRequest from prompt token IDs.
This factory method takes a list of token IDs and prepares them for the prefill phase, including padding, creating position IDs, initializing page indices, and setting up sampling parameters. Arrays are placed on the specified JAX mesh.
- Parameters
id (str) – A unique identifier for the request.
mesh (common_types.Mesh) – The JAX device mesh for array distribution.
metadata (PagedAttentionCacheMetaData) – Paged attention cache configuration.
chunk_size (int) – The number of tokens processed per chunk during prefill.
prompt_token_ids (list[int]) – The input prompt token IDs.
max_prefill_length (tp.Optional[int]) – Target length for padding token IDs. Defaults to metadata.max_sequences.
sampling_params (tp.Optional[SamplingParams]) – Custom sampling parameters. If None, default parameters are used.
- Returns
An initialized request object ready for prefill.
- Return type
InitialSequenceRequest
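A rough sketch of the host-side preparation that create describes: padding, position IDs and a page count. prepare_prefill, the pad ID of 0 and the plain range of positions are assumptions; the real method additionally places arrays on the mesh and sets up sampling parameters.

```python
import math
from typing import List, Tuple


def prepare_prefill(
    prompt_token_ids: List[int],
    max_prefill_length: int,
    page_size: int,
    pad_id: int = 0,  # assumed padding token
) -> Tuple[List[int], List[int], int]:
    """Pad a prompt, build position IDs and count the KV-cache pages it occupies."""
    n = len(prompt_token_ids)
    if n > max_prefill_length:
        raise ValueError("prompt is longer than max_prefill_length")
    token_ids = prompt_token_ids + [pad_id] * (max_prefill_length - n)
    positions = list(range(max_prefill_length))  # one position per (padded) token
    num_pages = math.ceil(n / page_size)         # pages needed for the real tokens only
    return token_ids, positions, num_pages
```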
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- id: Optional[Union[str, int]]#
- classmethod init_empty()[source]#
Creates an empty InitialSequenceRequest placeholder.
Initializes attributes with placeholder scalar JAX arrays suitable for use in contexts where a valid request might not be present (e.g., padding).
- Returns
A placeholder request object.
- Return type
InitialSequenceRequest
- property is_active#
Checks if the request is active (i.e., has associated token IDs).
- Returns
True if token_ids is a non-empty array, False otherwise.
- Return type
bool
- prompt_token_ids: Optional[list[int]]#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- sampling_params: SamplingParams#
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.paged_attention.__init__.ModelIOProcessor[source]#
Bases: object
Processes and transforms model inputs and outputs for paged attention.
This class handles the construction of model input batches from prefill and decode states, and organizes raw model outputs into structured data for further processing.
- classmethod build_input(iteration_plan: NextIterationPlan, metadata: PagedAttentionCacheMetaData, decodes_state: ActiveSequenceBatch)[source]#
Orchestrates the preparation of model inputs based on a scheduling decision.
It retrieves the necessary data from the iteration_plan and decodes_state, formats it (converting lists/numpy arrays to JAX arrays as needed), and calls prepare_model_input to perform the JIT-compiled merging and updating. It also updates the decodes_state (device arrays) based on the results from prepare_model_input.
- Parameters
cls – The class itself.
iteration_plan (NextIterationPlan) – The output of the scheduler, indicating what to run.
metadata (PagedAttentionCacheMetaData) – Cache configuration.
decodes_state (ActiveSequenceBatch) – The current state of the decode batch.
- Returns
The structured input ready for the model’s forward pass.
- Return type
ModelInputBatch
- static prepare_model_input(attn_meta: PagedAttentionMetadata, ongoing_prefill: InitialSequenceRequest, chunk_id: Array, chunk_size: Array, ongoing_decodes: ActiveSequenceBatch, ongoing_updates: ActiveSequenceBatch, insert_slots: Array)[source]#
Prepares model input by slicing and concatenating prefill and decode tokens, updating attention metadata.
- Parameters
attn_meta (PagedAttentionMetadata) – Metadata to update with positions and page tables.
ongoing_prefill (InitialSequenceRequest) – Current prefill request state.
chunk_id (jax.Array) – Index of the current prefill chunk.
chunk_size (jax.Array) – Size of each prefill chunk.
ongoing_decodes (ActiveSequenceBatch) – Current decode batch state.
ongoing_updates (ActiveSequenceBatch) – Batch updates for ongoing decodes.
insert_slots (jax.Array) – Slot indices indicating where to insert new decodes.
- Returns
(input_ids, decodes_token_ids, positions, attn_meta), where:
input_ids (jnp.ndarray) – Combined token IDs for model input.
decodes_token_ids (jnp.ndarray) – Token IDs array for the decode phase.
positions (jnp.ndarray) – Position indices matching input_ids.
attn_meta (PagedAttentionMetadata) – Updated attention metadata.
- Return type
tuple
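The slicing and concatenation performed by prepare_model_input can be illustrated with plain Python lists. merge_prefill_and_decodes is a hypothetical helper that assumes the current prefill chunk comes first, followed by one token per decode slot; the real function operates on JAX arrays and also updates attn_meta.

```python
from typing import List, Tuple


def merge_prefill_and_decodes(
    prefill_tokens: List[int],
    chunk_id: int,
    chunk_size: int,
    decode_tokens: List[int],
    decode_positions: List[int],
) -> Tuple[List[int], List[int]]:
    """Build one flat input: the current prefill chunk followed by one token per decode slot."""
    start = chunk_id * chunk_size
    chunk = prefill_tokens[start:start + chunk_size]      # last chunk may be shorter
    chunk_positions = list(range(start, start + len(chunk)))
    input_ids = chunk + decode_tokens
    positions = chunk_positions + decode_positions
    return input_ids, positions
```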
- static prepare_model_output(next_token: Array, complete: Array, attn_meta: PagedAttentionMetadata, sampling_params: SamplingParams) ModelOutputBatch[source]#
Processes the raw model output (logits converted to tokens) and completion flags to create a structured ModelOutputBatch object.
It separates the outputs corresponding to the prefill step (if any) and the decode steps (if any) based on the structure defined in attn_meta.
- Parameters
next_token (jax.Array) – The generated token IDs from the model. Concatenated prefill token (if any) and decode tokens (if any).
complete (jax.Array) – Boolean flags indicating sequence completion (e.g., EOS token) for the corresponding generated tokens.
attn_meta (PagedAttentionMetadata) – Metadata describing the structure of the input batch (prefill vs. decode parts).
sampling_params (SamplingParams) – Sampling parameters for the sequences in the batch.
- Returns
A structured object containing the separated prefill/decode outputs and next positions.
- Return type
ModelOutputBatch
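Because next_token holds the prefill token (if any) followed by the decode tokens, separating the two can be sketched as below. The has_prefill flag and split_outputs are hypothetical stand-ins for the layout information the real method reads from attn_meta.

```python
from typing import List, Optional, Tuple


def split_outputs(
    next_token: List[int],
    complete: List[bool],
    has_prefill: bool,
) -> Tuple[Optional[int], Optional[bool], List[int], List[bool]]:
    """Separate the prefill result (first entry, if a prefill ran) from the decode results."""
    if has_prefill:
        return next_token[0], complete[0], next_token[1:], complete[1:]
    return None, None, list(next_token), list(complete)
```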
- class easydel.layers.caching.paged_attention.__init__.ModelInputBatch(input_ids: Array, positions: Array, attn_meta: PagedAttentionMetadata, sampling_params: SamplingParams)[source]#
Bases: xTree
Consolidated input data for a single model forward pass in paged attention.
This structure gathers all necessary inputs for the model, potentially combining data for both a prefill step and multiple decode steps into a single batch structure suitable for the paged attention kernel.
- attn_meta#
Metadata required by the paged attention kernel, including sequence lengths, block tables, etc., for both prefill and decode parts of the batch.
- Type
PagedAttentionMetadata
- sampling_params#
Combined sampling parameters for all sequences included in the batch.
- Type
SamplingParams
- attn_meta: PagedAttentionMetadata#
- replace(**updates)#
Returns a new instance of the dataclass with specified fields updated.
- Parameters
**updates – Keyword arguments where keys are field names and values are the new values for those fields.
- Returns
A new instance of the dataclass with the updated fields.
- sampling_params: SamplingParams#
- class easydel.layers.caching.paged_attention.__init__.ModelOutputBatch(prefill_complete: Array, decodes_completes: Array, prefill_token_id: Array, decodes_token_ids: Array, prefill_next_position: Array, decodes_next_position: Array, next_sampling_params: SamplingParams)[source]#
Bases: object
Output generated by the model after a paged attention forward pass.
Contains the results from the model, separating outputs corresponding to the prefill phase (if one was run) and the decode phase.
- prefill_complete#
Scalar boolean JAX array. True if the prefill operation (if run) generated a completion token (e.g., EOS).
- Type
jax.Array
- decodes_completes#
Boolean JAX array. Indicates for each sequence in the decode part of the batch whether a completion token was generated. Shape: (num_decode_sequences,).
- Type
jax.Array
- prefill_token_id#
Scalar JAX array. The token ID generated by the prefill step (if run).
- Type
jax.Array
- decodes_token_ids#
JAX array. The token IDs generated for each sequence in the decode part of the batch. Shape: (num_decode_sequences,).
- Type
jax.Array
- prefill_next_position#
Scalar JAX array. The next position index for the sequence processed in the prefill step.
- Type
jax.Array
- decodes_next_position#
JAX array. The next position indices for each sequence in the decode part of the batch. Completed sequences might have a special value (e.g., -1). Shape: (num_decode_sequences,).
- Type
jax.Array
- next_sampling_params#
Updated sampling parameters after the forward pass (e.g., potentially modified max_tokens).
- Type
SamplingParams
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init_empty()[source]#
Factory method to create an empty ModelOutputBatch instance with placeholder values.
- Returns
An instance with default scalar/vector values indicating no output.
- Return type
ModelOutputBatch
- next_sampling_params: SamplingParams#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.paged_attention.__init__.ModelOutputSummary(prefill_request_id: str | None, prefill_token_id: Array, prefill_complete: Array, decodes_active_slots: list[int], decodes_active_request_ids: list[str], decodes_token_ids: Array, decodes_completes: Array)[source]#
Bases: object
Summarizes model output for scheduler and state updates (host-side).
This structure extracts and organizes key information from the ModelOutputBatch (which contains JAX arrays) into a format suitable for the host-side scheduler to process and update its internal state (like ActiveSequenceBatch’s scheduler attributes).
- prefill_request_id#
The ID of the prefill request processed in the last step, if any.
- Type
str | None
- prefill_complete#
Boolean flag indicating if the prefill step completed the sequence.
- Type
jax.Array
- decodes_active_slots#
List of batch slot indices that were active during the decode phase of the last step.
- Type
list[int]
- decodes_active_request_ids#
List of request IDs corresponding to the decodes_active_slots.
- Type
list[str]
- decodes_active_request_ids: list[str]#
- decodes_active_slots: list[int]#
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod from_output(output: ModelOutputBatch) ModelOutputSummary[source]#
Creates a ModelOutputSummary from a ModelOutputBatch.
Initializes the summary, copying relevant fields from the model output. Scheduler-specific fields (prefill_request_id, decodes_active_slots, decodes_active_request_ids) are initialized as empty/None and need to be populated separately by the scheduler based on its knowledge of the batch composition.
- Parameters
output (ModelOutputBatch) – The model output batch containing JAX arrays.
- Returns
An initialized summary object ready for scheduler processing.
- Return type
ModelOutputSummary
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.paged_attention.__init__.NextIterationPlan(prefill_request: InitialSequenceRequest, schedule_prefill: bool, schedule_decodes: bool, prefill_pages_update: AllocatedPrefillPages, new_decodes_requests: list[easydel.layers.caching.paged_attention.types.GenerationStepTask], decodes_state_page_updates: list[easydel.layers.caching.paged_attention.types.SlotPageAssignment])[source]#
Bases: object
Encapsulates the scheduling decisions for the next model iteration.
Based on available resources (like HBM pages and batch slots) and pending requests, the scheduler produces this plan, detailing which prefill and decode operations to execute next, along with necessary state updates.
- prefill_request#
The prefill request scheduled for the next iteration. Can be an empty/placeholder request if no prefill is scheduled.
- schedule_prefill#
True if a prefill operation should be executed.
- Type
bool
- schedule_decodes#
True if decode operations should be executed.
- Type
bool
- prefill_pages_update#
Contains the indices of pages newly allocated for the scheduled prefill request. Can be empty if no new pages were needed or no prefill is scheduled.
- new_decodes_requests#
A list of sequences that are newly transitioning into the decode phase in this iteration (e.g., after completing prefill).
- Type
list[GenerationStepTask]
- decodes_state_page_updates#
A list of updates to the page tables of sequences already in the decode phase, typically due to new page allocations for them.
- Type
list[SlotPageAssignment]
- decodes_state_page_updates: list[easydel.layers.caching.paged_attention.types.SlotPageAssignment]#
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- new_decodes_requests: list[easydel.layers.caching.paged_attention.types.GenerationStepTask]#
- prefill_pages_update: AllocatedPrefillPages#
- prefill_request: InitialSequenceRequest#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- schedule_decodes: bool#
- schedule_prefill: bool#
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.paged_attention.__init__.PagedAttentionCache(views: List[PagedAttentionCacheView])[source]#
Bases: BaseCache
Represents the complete Paged Attention KV cache for all layers of a model.
It holds a list of PagedAttentionCacheView objects, one for each layer. It inherits from BaseCache.
- views#
A list containing the cache view for each layer in the model.
- Type
tp.List[PagedAttentionCacheView]
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init_cache(mesh: Mesh, dtype: dtype, metadata: PagedAttentionCacheMetaData, partition_manager: PartitionManager, quantizer: Optional[object] = None)[source]#
Initializes the entire PagedAttentionCache for all layers.
Creates a list of PagedAttentionCacheView instances, one for each layer specified in the metadata, by calling PagedAttentionCacheView.init for each layer.
- Parameters
mesh (Mesh) – The JAX device mesh.
dtype (jnp.dtype) – The data type for the cache pages.
metadata (PagedAttentionCacheMetaData) – Static configuration for the cache.
partition_manager (es.PartitionManager) – Manages tensor sharding.
quantizer (tp.Optional["EasyQuantizer"]) – Optional quantizer to apply.
- Returns
An initialized cache object containing views for all layers.
- Return type
PagedAttentionCache
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- views: List[PagedAttentionCacheView]#
- class easydel.layers.caching.paged_attention.__init__.PagedAttentionCacheMetaData(batch_size: int, num_hidden_layers: int, num_pages_per_layer: int, num_pages_per_sequence: int, max_sequences: int, page_size: int, num_kv_heads: int, kv_head_dim_size: int, hbm_utilization: float)[source]#
Bases: BaseCacheMetadata
Metadata holding configuration parameters for the Paged Attention KV cache.
This class stores static configuration details required to initialize and manage a paged KV cache, such as dimensions, page sizes, and resource utilization hints. It inherits from BaseCacheMetadata.
- batch_size#
The maximum number of sequences processed concurrently during decoding.
- Type
int
- num_hidden_layers#
The total number of transformer layers in the model.
- Type
int
- num_pages_per_layer#
The total number of physical memory pages allocated for the KV cache per layer across all sequences. This is calculated based on available memory and hbm_utilization.
- Type
int
- num_pages_per_sequence#
The maximum number of pages a single sequence can occupy, determined by max_sequences and page_size.
- Type
int
- max_sequences#
The maximum sequence length supported by the cache allocation.
- Type
int
- page_size#
The number of tokens stored per page in the KV cache.
- Type
int
- num_kv_heads#
The number of key/value heads in the attention mechanism.
- Type
int
- kv_head_dim_size#
The dimension size of each key/value head.
- Type
int
- hbm_utilization#
The target fraction of available High Bandwidth Memory (HBM) to be utilized for the KV cache pages. Should be between 0.0 and 1.0.
- Type
float
- batch_size: int#
- classmethod create(mesh: ~jax._src.mesh.Mesh, batch_size: int, num_hidden_layers: int, max_sequences: int, page_size: int, num_kv_heads: int, kv_head_dim_size: int, hbm_utilization: float, dtype: ~numpy.dtype = <class 'jax.numpy.bfloat16'>) PagedAttentionCacheMetaData[source]#
Factory method to create and initialize a PagedAttentionCacheMetaData instance.
Calculates derived values like num_pages_per_layer and num_pages_per_sequence based on the provided parameters and estimated available memory.
- Parameters
mesh (Mesh) – The JAX device mesh.
batch_size (int) – Maximum concurrent sequences for decode.
num_hidden_layers (int) – Number of transformer layers.
max_sequences (int) – Maximum supported sequence length.
page_size (int) – Number of tokens per cache page.
num_kv_heads (int) – Number of KV heads.
kv_head_dim_size (int) – Dimension of each KV head.
hbm_utilization (float) – Target HBM utilization fraction (0.0 to 1.0).
dtype (jnp.dtype) – Data type used for cache size calculation.
- Returns
An initialized metadata object.
- Return type
PagedAttentionCacheMetaData
- Raises
ValueError – If input parameters are invalid (e.g., non-positive dimensions, invalid utilization factor).
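create derives num_pages_per_layer from estimated free memory and hbm_utilization. The following is a minimal sketch of one plausible estimate, together with the kind of validation described under Raises. It assumes a page stores page_size tokens of both keys and values for every KV head, that each layer owns its own pages, and that hbm_utilization must lie in (0, 1]. The function name, the free_hbm_bytes input and the itemsize parameter are hypothetical; EasyDeL's actual formula and memory probing may differ.

```python
def estimate_pages_per_layer(
    free_hbm_bytes: int,
    hbm_utilization: float,
    num_hidden_layers: int,
    page_size: int,
    num_kv_heads: int,
    kv_head_dim_size: int,
    itemsize: int = 2,  # bytes per element; 2 for bfloat16
) -> int:
    """Hypothetical estimate of num_pages_per_layer (not EasyDeL's exact formula)."""
    dims = {
        "free_hbm_bytes": free_hbm_bytes,
        "num_hidden_layers": num_hidden_layers,
        "page_size": page_size,
        "num_kv_heads": num_kv_heads,
        "kv_head_dim_size": kv_head_dim_size,
        "itemsize": itemsize,
    }
    for name, value in dims.items():
        if value <= 0:
            raise ValueError(f"{name} must be positive, got {value}")
    if not 0.0 < hbm_utilization <= 1.0:
        raise ValueError(f"hbm_utilization must be in (0, 1], got {hbm_utilization}")

    # One page holds page_size tokens of keys *and* values for every KV head.
    bytes_per_page = 2 * page_size * num_kv_heads * kv_head_dim_size * itemsize
    budget = int(free_hbm_bytes * hbm_utilization)
    # Every layer gets its own key/value pages, so the budget is split across layers.
    return budget // (num_hidden_layers * bytes_per_page)


# 16 GiB free, half of it for the cache, 32 layers, 8 KV heads of size 128, pages of 128 tokens.
print(estimate_pages_per_layer(16 * 2**30, 0.5, 32, 128, 8, 128))  # 512
```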
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- hbm_utilization: float#
- kv_head_dim_size: int#
- max_sequences: int#
- num_hidden_layers: int#
- num_kv_heads: int#
- num_pages_per_layer: int#
- num_pages_per_sequence: int#
- page_size: int#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.paged_attention.__init__.PagedAttentionCacheView(metadata: PagedAttentionCacheMetaData, layer_index: int, key_pages: Union[Array, ndarray, bool, number, ImplicitArray], value_pages: Union[Array, ndarray, bool, number, ImplicitArray])[source]#
Bases:
BaseCacheView
Represents the view of the Paged Attention KV cache for a single transformer layer.
It holds references to the physical key and value pages allocated for this layer and the associated metadata. It provides methods to write new key/value pairs into the correct pages based on runtime metadata. It inherits from BaseCacheView.
- metadata#
The static configuration metadata for the entire paged cache.
- Type
PagedAttentionCacheMetaData
- layer_index#
The index of the transformer layer this view corresponds to.
- Type
int
- key_pages#
The tensor holding all key pages for this layer. Shape: (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size). Can be a JAX array or an ImplicitArray if quantization is used.
- Type
tp.Union[cx.Array, ImplicitArray]
- value_pages#
The tensor holding all value pages for this layer. Shape: (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size). Can be a JAX array or an ImplicitArray if quantization is used.
- Type
tp.Union[cx.Array, ImplicitArray]
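The page layout above can be illustrated with a host-side NumPy stand-in. This sketch only allocates zero-filled key and value tensors of shape (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size) and shows how one physical page is addressed. It uses float16 because NumPy has no bfloat16, and it omits the sharding and optional quantization that init applies; alloc_layer_pages is a hypothetical name.

```python
import numpy as np


def alloc_layer_pages(num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size, dtype=np.float16):
    """Hypothetical host-side stand-in for allocating one layer's key/value pages."""
    shape = (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size)
    key_pages = np.zeros(shape, dtype=dtype)
    value_pages = np.zeros(shape, dtype=dtype)
    return key_pages, value_pages


key_pages, value_pages = alloc_layer_pages(2, 8, 4, 16)
# All tokens stored in physical page 3 for KV head 0: shape (page_size, kv_head_dim_size).
print(key_pages[0, 3].shape)                  # (4, 16)
print(key_pages.nbytes + value_pages.nbytes)  # 4096 bytes for this layer
```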
- concatenate_to_cache(*args, **kwargs)[source]#
Concatenation is not applicable for Paged Attention. Raises NotImplementedError.
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init(mesh: Mesh, dtype: dtype, metadata: PagedAttentionCacheMetaData, layer_index: int, partition_manager: PartitionManager, quantizer: Optional[object] = None)[source]#
Initializes the PagedAttentionCacheView for a specific layer.
Allocates the key_pages and value_pages tensors with the appropriate shape, dtype, and sharding based on the provided metadata and partition manager. Optionally applies quantization if a quantizer is provided.
- Parameters
mesh (Mesh) – The JAX device mesh.
dtype (jnp.dtype) – The data type for the cache pages (e.g., jnp.bfloat16).
metadata (PagedAttentionCacheMetaData) – Static configuration for the cache.
layer_index (int) – The index of the layer this view is for.
partition_manager (es.PartitionManager) – Manages tensor sharding across the mesh.
quantizer (tp.Optional["EasyQuantizer"]) – Optional quantizer to apply to the pages.
- Returns
An initialized cache view for the specified layer.
- Return type
PagedAttentionCacheView
- layer_index: int#
- metadata: PagedAttentionCacheMetaData#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- write_decodes_to_cache(key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], metadata: PagedAttentionMetadata)[source]#
Writes the key/value pairs from a decode step into the appropriate cache pages.
Uses the decodes_position and decodes_page_table from the runtime metadata to calculate the exact page index and offset within that page where the new key/value pair for each sequence in the batch should be written. It reshapes the cache pages and input keys/values for efficient scattered updates using .at[…].set(…).
- Parameters
key (cx.Array) – Key tensor for the decode tokens. Shape (batch_size, num_kv_heads, kv_head_dim_size).
value (cx.Array) – Value tensor for the decode tokens. Shape (batch_size, num_kv_heads, kv_head_dim_size).
metadata (PagedAttentionMetadata) – Runtime metadata containing decodes_position and decodes_page_table.
- Returns
Returns self after updating the pages.
- Return type
PagedAttentionCacheView
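The scattered decode write described above can be sketched in NumPy. This version assumes decodes_position holds the index of the token being written for each batch row. It looks up the physical page through the page table, takes the in-page offset, and scatters every row's key/value in one indexed assignment, which is the NumPy analogue of JAX's .at[...].set(...). It mirrors the documented idea and is not EasyDeL's code; write_decodes is a hypothetical name.

```python
import numpy as np


def write_decodes(key_pages, value_pages, key, value, decodes_position, decodes_page_table):
    """NumPy sketch of a decode-step cache write (hypothetical, not EasyDeL's code).

    key_pages, value_pages: (num_kv_heads, num_pages_per_layer, page_size, head_dim)
    key, value:             (batch_size, num_kv_heads, head_dim)
    decodes_position:       (batch_size,) index of the token being written (assumed)
    decodes_page_table:     (batch_size, num_pages_per_sequence)
    """
    page_size = key_pages.shape[2]
    rows = np.arange(key.shape[0])
    logical_page = decodes_position // page_size
    offset = decodes_position % page_size
    physical_page = decodes_page_table[rows, logical_page]  # (batch_size,)
    # [:, physical_page, offset] selects shape (num_kv_heads, batch_size, head_dim),
    # so the new keys/values are transposed to match before the scatter.
    # With JAX arrays this would be key_pages.at[:, physical_page, offset].set(...).
    key_pages[:, physical_page, offset] = key.transpose(1, 0, 2)
    value_pages[:, physical_page, offset] = value.transpose(1, 0, 2)
    return key_pages, value_pages


H, P, S, D = 2, 6, 4, 3  # heads, physical pages, page_size, head_dim
kp, vp = np.zeros((H, P, S, D)), np.zeros((H, P, S, D))
k = np.arange(1, 2 * H * D + 1, dtype=float).reshape(2, H, D)
pos = np.array([5, 2])
table = np.array([[0, 3], [1, 4]])
write_decodes(kp, vp, k, -k, pos, table)
# Row 0: position 5 -> logical page 1 -> physical page 3, offset 1.
# Row 1: position 2 -> logical page 0 -> physical page 1, offset 2.
```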
- write_prefill_to_cache(key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], metadata: PagedAttentionMetadata)[source]#
Writes the key/value pairs from a prefill step into the appropriate cache pages.
Uses the prefill_page_table from the runtime metadata to determine which physical pages (key_pages, value_pages) correspond to the logical pages of the prefill sequence. It transposes and reshapes the input key/value tensors and uses jax.lax.dynamic_update_slice_in_dim within a while_loop to update the relevant pages.
- Parameters
key (cx.Array) – Key tensor for the prefill sequence. Shape (padded_prefill_len, num_kv_heads, kv_head_dim_size).
value (cx.Array) – Value tensor for the prefill sequence. Shape (padded_prefill_len, num_kv_heads, kv_head_dim_size).
metadata (PagedAttentionMetadata) – Runtime metadata containing the prefill_length and prefill_page_table.
- Returns
Returns self after updating the pages.
- Return type
PagedAttentionCacheView
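The prefill path can be sketched the same way for one of the two tensors. The sketch assumes the padded prefill length is a multiple of page_size, and that only pages containing real prompt tokens (ceil(prefill_length / page_size) of them) are written. It transposes and reshapes the prompt into one block per logical page, then copies each block into the physical page named by prefill_page_table. A plain Python loop stands in for the documented while_loop with jax.lax.dynamic_update_slice_in_dim. write_prefill is a hypothetical name.

```python
import numpy as np


def write_prefill(pages, x, prefill_length, prefill_page_table):
    """NumPy sketch of a prefill write into key_pages or value_pages (hypothetical).

    pages:              (num_kv_heads, num_pages_per_layer, page_size, head_dim)
    x:                  (padded_prefill_len, num_kv_heads, head_dim); the padded
                        length is assumed to be a multiple of page_size
    prefill_page_table: (num_pages_for_prefill,) physical page per logical page
    """
    num_kv_heads, _, page_size, head_dim = pages.shape
    padded_len = x.shape[0]
    # (L, H, D) -> (H, L, D) -> (H, L // page_size, page_size, D): one block per logical page.
    blocks = x.transpose(1, 0, 2).reshape(num_kv_heads, padded_len // page_size, page_size, head_dim)
    # Only pages that hold real prompt tokens are written (ceiling division).
    pages_used = (int(prefill_length) + page_size - 1) // page_size
    # A Python loop stands in for the documented while_loop + dynamic_update_slice_in_dim.
    for i in range(pages_used):
        pages[:, prefill_page_table[i]] = blocks[:, i]
    return pages


kp = np.zeros((2, 6, 4, 3))
k = np.arange(1, 8 * 2 * 3 + 1, dtype=float).reshape(8, 2, 3)
write_prefill(kp, k, prefill_length=5, prefill_page_table=np.array([5, 2]))
# Tokens 0-3 land in physical page 5 and tokens 4-7 in physical page 2.
```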
- class easydel.layers.caching.paged_attention.__init__.PagedAttentionMetadata(prefill_length: Array, prefill_position: Array, prefill_page_table: Array, decodes_position: Array, decodes_page_table: Array)[source]#
Bases:
object
Runtime metadata required for performing a Paged Attention computation step.
This object holds the necessary information for a single forward pass of the paged attention mechanism, distinguishing between prefill and decode steps and providing the mappings (page tables) from logical sequence positions to physical cache pages.
- prefill_length#
Scalar JAX array containing the actual length of the prompt being processed in a prefill step. Shape (). Set to 0 if not in prefill.
- Type
jax.Array
- prefill_position#
JAX array of positions for the prefill tokens. Shape (padded_prompt_length,). Empty shape () if not in prefill.
- Type
jax.Array
- prefill_page_table#
JAX array mapping logical page indices of the prefill sequence to physical page indices in the KV cache. Shape (num_pages_for_prefill,). Empty shape () if not in prefill.
- Type
jax.Array
- decodes_position#
JAX array containing the current sequence position (or length - 1) for each sequence in the decode batch. Shape (batch_size,). Empty shape () if not in decode.
- Type
jax.Array
- decodes_page_table#
JAX array mapping logical page indices to physical page indices for each sequence in the decode batch. Shape (batch_size, num_pages_per_sequence). Empty shape () if not in decode.
- Type
jax.Array
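To make the shapes above concrete, the following sketch builds the five fields for a prefill-only step with NumPy. It assumes the page table covers the whole padded prompt, that pages are taken in order from a list of free physical pages, and that the decode fields are empty-shape placeholders, as the attribute descriptions state. build_prefill_fields and free_pages are hypothetical.

```python
import numpy as np


def build_prefill_fields(prompt_len, padded_len, page_size, free_pages):
    """Hypothetical helper returning the five metadata fields for a prefill-only step."""
    if not 0 < prompt_len <= padded_len:
        raise ValueError("prompt_len must be in (0, padded_len]")
    num_pages = (padded_len + page_size - 1) // page_size  # pages covering the padded prompt
    if len(free_pages) < num_pages:
        raise ValueError("not enough free pages")
    empty = np.zeros((), dtype=np.int32)  # placeholder for fields unused in this mode
    return {
        "prefill_length": np.asarray(prompt_len, dtype=np.int32),       # shape ()
        "prefill_position": np.arange(padded_len, dtype=np.int32),      # shape (padded_len,)
        "prefill_page_table": np.asarray(free_pages[:num_pages], dtype=np.int32),
        "decodes_position": empty,
        "decodes_page_table": empty,
    }


fields = build_prefill_fields(prompt_len=10, padded_len=16, page_size=4, free_pages=[7, 3, 9, 1, 5])
print(fields["prefill_page_table"])  # [7 3 9 1]
```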
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- is_decode_mode() bool[source]#
Checks if the current metadata represents a decode-only step.
- Returns
True if only decode information is present (prefill arrays have empty shape), False otherwise.
- Return type
bool
- is_prefill_mode() bool[source]#
Checks if the current metadata represents a prefill-only step.
- Returns
True if only prefill information is present (decode arrays have empty shape), False otherwise.
- Return type
bool
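The two mode checks rely only on which arrays carry the empty shape (). The sketch below reproduces that test on a hypothetical stand-in dataclass with the same five fields. is_prefill_mode follows the documented rule; is_decode_mode is assumed to be its mirror image. prefill_length is not inspected because it is scalar in both modes.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class MetadataSketch:
    """Hypothetical stand-in with the same five fields as PagedAttentionMetadata."""

    prefill_length: np.ndarray
    prefill_position: np.ndarray
    prefill_page_table: np.ndarray
    decodes_position: np.ndarray
    decodes_page_table: np.ndarray

    def is_prefill_mode(self) -> bool:
        # Prefill-only: the decode arrays are empty-shape placeholders.
        return self.decodes_position.shape == () and self.decodes_page_table.shape == ()

    def is_decode_mode(self) -> bool:
        # Decode-only (assumed mirror image): the prefill arrays are placeholders.
        return self.prefill_position.shape == () and self.prefill_page_table.shape == ()


EMPTY = np.zeros((), dtype=np.int32)
prefill = MetadataSketch(np.asarray(10), np.arange(16), np.array([7, 3, 9, 1]), EMPTY, EMPTY)
decode = MetadataSketch(EMPTY, EMPTY, EMPTY, np.array([5, 2]), np.array([[0, 3], [1, 4]]))
print(prefill.is_prefill_mode(), prefill.is_decode_mode())  # True False
print(decode.is_prefill_mode(), decode.is_decode_mode())    # False True
```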
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.paged_attention.__init__.SamplingParams(top_p: jax.Array | float = <factory>, max_tokens: jax.Array | int = <factory>, temperature: jax.Array | float = <factory>)[source]#
Bases:
object
Configuration parameters for controlling text generation sampling.
This class holds parameters that influence the sampling process during text generation: top-p (nucleus) sampling, the maximum number of tokens to generate, and temperature scaling.
- top_p#
The probability threshold for nucleus sampling. Defaults to 1.0 (no nucleus sampling).
- Type
jax.Array | float
- max_tokens#
The maximum number of tokens to generate for a sequence. Defaults to 32.
- Type
jax.Array | int
- temperature#
The temperature for scaling logits before sampling. Defaults to 0.0 (deterministic).
- Type
jax.Array | float
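For readers unfamiliar with these fields, here is a generic NumPy sketch of how temperature and top_p are commonly applied to one sequence's logits. It is the standard temperature plus nucleus (top-p) recipe, not EasyDeL's sampler. Treating temperature 0.0 as greedy argmax is an assumption consistent with the "deterministic" default above; max_tokens is a stopping criterion and plays no role in choosing a token. sample_token is a hypothetical name.

```python
import numpy as np


def sample_token(logits, temperature, top_p, rng):
    """Generic temperature + nucleus sampling for one sequence (not EasyDeL's sampler)."""
    if temperature <= 0.0:
        # Assumed convention: temperature 0 means greedy (deterministic) decoding.
        return int(np.argmax(logits))
    scaled = logits / temperature
    probs = np.exp(scaled - scaled.max())  # numerically stable softmax
    probs /= probs.sum()
    order = np.argsort(-probs)  # token ids, most likely first
    cumulative = np.cumsum(probs[order])
    # Keep the smallest prefix whose mass reaches top_p (always at least one token).
    cutoff = int(np.searchsorted(cumulative, top_p)) + 1
    keep = order[:cutoff]
    kept = probs[keep] / probs[keep].sum()
    return int(rng.choice(keep, p=kept))


rng = np.random.default_rng(0)
logits = np.array([1.0, 3.0, 2.0])
print(sample_token(logits, 0.0, 1.0, rng))  # 1 (greedy)
print(sample_token(logits, 1.0, 0.1, rng))  # 1 (nucleus holds only the top token)
```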
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- classmethod init_empty() SamplingParams[source]#
Creates an empty SamplingParams placeholder with scalar JAX arrays.
- Returns
A placeholder SamplingParams object.
- Return type
SamplingParams
- classmethod init_jax(metadata: PagedAttentionCacheMetaData, sharding: NamedSharding) SamplingParams[source]#
Initializes SamplingParams with JAX arrays on the specified device/sharding.
- Parameters
metadata (PagedAttentionCacheMetaData) – Metadata containing batch size.
sharding (jax.sharding.NamedSharding) – The JAX sharding configuration.
- Returns
An initialized SamplingParams object with JAX arrays.
- Return type
SamplingParams
- classmethod init_numpy(metadata: PagedAttentionCacheMetaData) SamplingParams[source]#
Initializes SamplingParams with NumPy arrays.
- Parameters
metadata (PagedAttentionCacheMetaData) – Metadata containing batch size.
- Returns
An initialized SamplingParams object with NumPy arrays.
- Return type
SamplingParams
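A minimal sketch of what a batch-sized initializer in the spirit of init_numpy might produce: one entry per decode slot for each field. It assumes every slot starts with the documented class defaults (top_p 1.0, max_tokens 32, temperature 0.0) and that the batch size comes from metadata.batch_size. SamplingParamsSketch and init_numpy_sketch are hypothetical. A JAX counterpart would presumably place the same arrays on the given NamedSharding, for example via jax.device_put.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class SamplingParamsSketch:
    """Hypothetical stand-in holding one value per batch slot for each field."""

    top_p: np.ndarray
    max_tokens: np.ndarray
    temperature: np.ndarray


def init_numpy_sketch(batch_size: int) -> SamplingParamsSketch:
    # Assumption: every slot starts with the documented defaults.
    return SamplingParamsSketch(
        top_p=np.full((batch_size,), 1.0, dtype=np.float32),
        max_tokens=np.full((batch_size,), 32, dtype=np.int32),
        temperature=np.full((batch_size,), 0.0, dtype=np.float32),
    )


params = init_numpy_sketch(batch_size=4)
print(params.max_tokens)  # [32 32 32 32]
```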
- insert_decode_state(insert_slots: Array, update: ActiveSequenceBatch)[source]#
Updates sampling parameters in specified slots from an ActiveSequenceBatch.
- Parameters
insert_slots (jax.Array) – An array of slot indices to update.
update (ActiveSequenceBatch) – The batch containing the new sampling parameters.
- insert_from_task(slot: int, task: GenerationStepTask)[source]#
Inserts sampling parameters from a GenerationStepTask into a specific slot.
- Parameters
slot (int) – The batch slot index to insert the parameters into.
task (GenerationStepTask) – The task containing the sampling parameters to insert.
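Slot insertion amounts to overwriting one index of each per-slot array. The sketch below shows the host-side idea behind insert_from_task using NumPy arrays in a plain dict. TaskSketch is a hypothetical stand-in for whatever sampling fields a GenerationStepTask carries; with immutable JAX arrays the same update would be written as array.at[slot].set(value).

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class TaskSketch:
    """Hypothetical stand-in for the sampling fields carried by a GenerationStepTask."""

    top_p: float
    max_tokens: int
    temperature: float


def insert_from_task(params: dict, slot: int, task: TaskSketch) -> None:
    """Host-side sketch: overwrite one slot's sampling settings in place."""
    # With JAX arrays: params[name] = params[name].at[slot].set(value)
    params["top_p"][slot] = task.top_p
    params["max_tokens"][slot] = task.max_tokens
    params["temperature"][slot] = task.temperature


batch_size = 4
params = {
    "top_p": np.full((batch_size,), 1.0, dtype=np.float32),
    "max_tokens": np.full((batch_size,), 32, dtype=np.int32),
    "temperature": np.full((batch_size,), 0.0, dtype=np.float32),
}
insert_from_task(params, slot=2, task=TaskSketch(top_p=0.9, max_tokens=64, temperature=0.7))
print(params["max_tokens"])  # [32 32 64 32]
```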
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
- class easydel.layers.caching.paged_attention.__init__.SlotPageAssignment(slot: int, page_idx: int, mapped_idx: int)[source]#
Bases:
object
Represents the assignment of a physical page to a logical page in a slot's page table.
During decoding, as sequences grow, new physical memory pages might be allocated. This class represents the update instruction to map a specific logical page index within a sequence’s page table (identified by its slot) to a newly allocated physical HBM page index (mapped_idx).
- slot#
The batch slot index of the sequence whose page table is updated.
- Type
int
- page_idx#
The logical page index within the sequence’s page table.
- Type
int
- mapped_idx#
The physical HBM page index to map the logical page to.
- Type
int
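Applying a batch of these assignments to a page table can be sketched in NumPy. The sketch splits the assignments into three parallel index arrays (apparently the same split as ActiveSequenceBatch's page_update_slots, page_update_page_idxs and page_update_mapped_idxs fields) and writes them with a single indexed assignment. Assignment and apply_assignments are hypothetical stand-ins, and the function returns a new table rather than mutating its input, mimicking JAX's functional updates.

```python
from dataclasses import dataclass
from typing import List

import numpy as np


@dataclass(frozen=True)
class Assignment:
    """Hypothetical stand-in mirroring SlotPageAssignment's three fields."""

    slot: int
    page_idx: int
    mapped_idx: int


def apply_assignments(page_table: np.ndarray, updates: List[Assignment]) -> np.ndarray:
    """Return a copy of page_table (batch_size, num_pages_per_sequence) with updates applied."""
    new_table = page_table.copy()
    if not updates:
        return new_table
    slots = np.array([u.slot for u in updates])
    page_idxs = np.array([u.page_idx for u in updates])
    mapped_idxs = np.array([u.mapped_idx for u in updates])
    # Row = sequence slot, column = logical page, value = physical HBM page.
    new_table[slots, page_idxs] = mapped_idxs
    return new_table


table = np.zeros((3, 4), dtype=np.int32)
updated = apply_assignments(table, [Assignment(0, 1, 7), Assignment(2, 0, 5)])
print(updated[0, 1], updated[2, 0])  # 7 5
```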
- classmethod from_dict(data: Dict[str, Any]) T#
Deserializes a dictionary into a PyTree object.
- classmethod from_json(json_str: str) T#
Deserializes a JSON string into a PyTree object.
- mapped_idx: int#
- page_idx: int#
- replace(**kwargs)#
Creates a new instance with specified fields replaced.
- slot: int#
- to_dict() Dict[str, Any]#
Serializes the PyTree object to a dictionary.
- to_json(**kwargs) str#
Serializes the PyTree object to a JSON string.
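The to_dict/from_dict/to_json/from_json/replace helpers shared by these classes follow a common round-trip pattern. The sketch below reproduces that pattern with the standard library on a hypothetical frozen dataclass shaped like SlotPageAssignment; it is not EasyDeL's base class, and array-valued fields would additionally need converting to lists before json.dumps.

```python
import dataclasses
import json
from typing import Any, Dict


@dataclasses.dataclass(frozen=True)
class SlotPageAssignmentSketch:
    """Hypothetical stand-in showing the serialization round trip."""

    slot: int
    page_idx: int
    mapped_idx: int

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SlotPageAssignmentSketch":
        return cls(**data)

    def to_json(self, **kwargs) -> str:
        # kwargs are forwarded to json.dumps (e.g. indent=2).
        return json.dumps(self.to_dict(), **kwargs)

    @classmethod
    def from_json(cls, json_str: str) -> "SlotPageAssignmentSketch":
        return cls.from_dict(json.loads(json_str))

    def replace(self, **kwargs) -> "SlotPageAssignmentSketch":
        # Frozen dataclasses are immutable, so "replace" builds a new instance.
        return dataclasses.replace(self, **kwargs)


a = SlotPageAssignmentSketch(slot=0, page_idx=3, mapped_idx=42)
b = SlotPageAssignmentSketch.from_json(a.to_json())
c = a.replace(mapped_idx=7)
print(b == a, c.mapped_idx, a.mapped_idx)  # True 7 42
```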