easydel.layers.caching.paged_attention.__init__


class easydel.layers.caching.paged_attention.__init__.ActiveSequenceBatch(token_ids: Union[Array, List[int]], positions: Array, page_table: Array, sampling_params: SamplingParams, available_slots: Optional[SimpleQueue], active_slot_requests_map: Optional[Dict[int, GenerationStepTask]], context_lock: Optional[threading.Lock] = <unlocked _thread.lock object>, page_update_slots: Optional[Array] = None, page_update_page_idxs: Optional[Array] = None, page_update_mapped_idxs: Optional[Array] = None)

Bases: object

Manages the batch state for sequences in the decoding phase.

This class holds the dynamic state for all sequences currently undergoing token generation (decoding) within the paged attention batch. It includes the input token IDs for the next step, current positions, page table mappings, sampling parameters, and structures for managing available batch slots.
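To make the page table concrete, the sketch below (plain Python, not EasyDeL code) shows how a logical token position could be translated into a physical page and an in-page offset. The helper `locate_token` is hypothetical, and it assumes the usual paged-attention layout in which logical page `i` of a sequence covers positions `i * page_size` through `(i + 1) * page_size - 1`.

```python
def locate_token(page_table, slot, position, page_size):
    """Return (physical_page, offset_in_page) for token `position` of `slot`.

    Assumption: logical page i of a sequence holds token positions
    [i * page_size, (i + 1) * page_size).
    """
    logical_page = position // page_size  # which page of the sequence
    offset = position % page_size         # where inside that page
    return page_table[slot][logical_page], offset


# Toy page table: 2 slots, 3 logical pages each, page_size = 4.
# 0 is used here as an "unmapped" filler (hypothetical convention).
page_table = [
    [7, 2, 0],  # slot 0: logical pages 0 and 1 live in physical pages 7 and 2
    [5, 0, 0],  # slot 1: only logical page 0 is mapped (physical page 5)
]

print(locate_token(page_table, slot=0, position=5, page_size=4))  # (2, 1)
```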

token_ids

JAX array (during model execution) or list (during host-side updates) holding the input token ID for the next decode step for each active slot. Shape: (batch_size,).

Type

jax.Array | tp.List[int]

positions

JAX array holding the current sequence position (length) for each active slot. Inactive slots might have a placeholder value (e.g., -1 or 1e6). Shape: (batch_size,).

Type

jax.Array

page_table

JAX array mapping logical page indices to physical HBM page indices for each sequence slot. Shape: (batch_size, num_pages_per_sequence).

Type

jax.Array

sampling_params

Sampling parameters for all sequences in the batch. (Inference Argument)

Type

SamplingParams

available_slots

A queue managing indices of free batch slots. Used by the host-side scheduler. (Scheduler Argument)

Type

tp.Optional[queue.SimpleQueue]

active_slot_requests_map

A dictionary mapping active slot indices to their corresponding GenerationStepTask. Used by the host-side scheduler. (Scheduler Argument)

Type

tp.Optional[tp.Dict[int, GenerationStepTask]]

context_lock

A lock for thread-safe updates to scheduler-related attributes (available_slots, active_slot_requests_map). (Scheduler Argument)

Type

tp.Optional[threading.Lock]

page_update_slots

Array storing slot indices for pending page table updates. (Internal State)

Type

tp.Optional[jax.Array]

page_update_page_idxs

Array storing logical page indices for pending page table updates. (Internal State)

Type

tp.Optional[jax.Array]

page_update_mapped_idxs

Array storing physical mapped indices for pending page table updates. (Internal State)

Type

tp.Optional[jax.Array]
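The scheduler-side fields (available_slots, active_slot_requests_map, context_lock) are meant to change together. A minimal sketch of that pattern follows, using a hypothetical `SlotBook` class rather than the real ActiveSequenceBatch; it assumes free slots are handed out in FIFO order from a `queue.SimpleQueue` and that one lock per instance guards both structures.

```python
import queue
import threading


class SlotBook:
    """Hypothetical stand-in for the scheduler fields of ActiveSequenceBatch."""

    def __init__(self, batch_size):
        self.available_slots = queue.SimpleQueue()
        for slot in range(batch_size):
            self.available_slots.put(slot)
        self.active_slot_requests_map = {}
        # One lock guards both structures so they never disagree.
        self.context_lock = threading.Lock()

    def acquire(self, task):
        """Assign a free slot to `task`; return the slot, or None if the batch is full."""
        with self.context_lock:
            try:
                slot = self.available_slots.get_nowait()
            except queue.Empty:
                return None
            self.active_slot_requests_map[slot] = task
            return slot

    def release(self, slot):
        """Free `slot` and return the task that occupied it."""
        with self.context_lock:
            task = self.active_slot_requests_map.pop(slot)
            self.available_slots.put(slot)
            return task
```

Released slots go to the back of the queue, so a freed slot is reused only after the slots that were already waiting.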

active_slot_requests_map: Optional[Dict[int, GenerationStepTask]]
apply_assignment(assignment: list[SlotPageAssignment])

Applies page table assignments to the internal update arrays (host-side).

This method populates the page_update_* arrays based on a list of SlotPageAssignment objects, preparing them for later use in JAX computations (insert_decode_state).

Parameters

assignment (list[SlotPageAssignment]) – A list of page assignments to apply.
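A minimal sketch of the host-side packing that apply_assignment describes, followed by how the packed arrays could later be written into a page table. `Assignment` is a hypothetical stand-in for SlotPageAssignment (its field names are assumptions chosen to match the page_update_* attribute names), and padding to a fixed length with -1 is also an assumption; fixed shapes are a common way to avoid recompiling JIT-compiled code when the number of updates varies.

```python
from typing import List, NamedTuple, Tuple


class Assignment(NamedTuple):
    """Hypothetical stand-in for SlotPageAssignment (field names assumed)."""
    slot: int        # batch slot whose page table changes
    page_idx: int    # logical page index within that sequence
    mapped_idx: int  # physical HBM page now backing it


def apply_assignment(assignments: List[Assignment],
                     max_updates: int) -> Tuple[List[int], List[int], List[int]]:
    """Pack assignments into three fixed-length arrays padded with -1."""
    if len(assignments) > max_updates:
        raise ValueError("more assignments than update buffer entries")
    slots = [-1] * max_updates
    page_idxs = [-1] * max_updates
    mapped_idxs = [-1] * max_updates
    for i, a in enumerate(assignments):
        slots[i], page_idxs[i], mapped_idxs[i] = a.slot, a.page_idx, a.mapped_idx
    return slots, page_idxs, mapped_idxs


def apply_updates(page_table, slots, page_idxs, mapped_idxs):
    """Write each pending update into the page table, skipping -1 padding."""
    for s, p, m in zip(slots, page_idxs, mapped_idxs):
        if s >= 0:
            page_table[s][p] = m
    return page_table
```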

available_slots: Optional[SimpleQueue]
context_lock: Optional[threading.Lock] = <unlocked _thread.lock object>
copy_decode(decode: ActiveSequenceBatch)

Copies essential decode state from another ActiveSequenceBatch (host-side).

Updates the current object’s token_ids, positions, page_table, and sampling_params based on the source decode object. Assumes host-side operation.

Parameters

decode (ActiveSequenceBatch) – The source batch state to copy from.

classmethod create(metadata: PagedAttentionCacheMetaData, mesh: Mesh)

Creates and initializes an ActiveSequenceBatch for JAX execution.

This factory method sets up the ActiveSequenceBatch with JAX arrays (token_ids, positions, page_table, sampling_params) appropriately sharded across the provided mesh. It also initializes the host-side scheduler components (available_slots, active_slot_requests_map).

Parameters
  • metadata (PagedAttentionCacheMetaData) – Paged attention cache configuration.

  • mesh (common_types.Mesh) – The JAX device mesh for array distribution.

Returns

An initialized batch state ready for JAX operations.

Return type

ActiveSequenceBatch

classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

classmethod init_empty()
classmethod init_numpy(metadata: PagedAttentionCacheMetaData) → ActiveSequenceBatch

Initializes ActiveSequenceBatch with NumPy arrays for host-side use.

Creates the necessary arrays (positions, page_table, page_update_*) using NumPy, suitable for manipulation outside JAX computations. token_ids is initialized as an empty list.

Parameters

metadata (PagedAttentionCacheMetaData) – Configuration metadata.

Returns

An initialized batch state object with NumPy arrays.

Return type

ActiveSequenceBatch

insert_decode_state(insert_slots: Array, update: ActiveSequenceBatch)

Updates the decode state in specified slots from another batch state.

This is typically used during JAX computation to incorporate updates from newly scheduled decode tasks or page table modifications.

Parameters
  • insert_slots (jax.Array) – An array of slot indices to update.

  • update (ActiveSequenceBatch) – The batch containing the new state to insert.
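The slot-wise update that insert_decode_state describes is essentially a scatter. The plain-Python sketch below illustrates the idea under stated assumptions: both batches are represented as hypothetical dicts of per-slot lists, slot `i` in `update` overwrites slot `i` in `state`, and negative entries in `insert_slots` are padding.

```python
import copy


def insert_decode_state(state, update, insert_slots):
    """Copy per-slot fields from `update` into `state` for each slot listed.

    Assumptions: both batches are dicts of per-slot lists indexed by slot,
    and negative entries in `insert_slots` are padding to be ignored.
    """
    for name in ("token_ids", "positions", "page_table"):
        dst, src = state[name], update[name]
        for slot in insert_slots:
            if slot >= 0:
                dst[slot] = copy.copy(src[slot])  # copy rows to avoid aliasing
    return state
```

Inside JAX code the same scatter would normally be written with the functional `array.at[indices].set(values)` update instead of in-place assignment.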

insert_from_task(slot: int, task: GenerationStepTask)

Inserts state from a GenerationStepTask into a specific slot (host-side).

Updates the host-side representations (token_ids list, positions array, page_table array, sampling_params) for the given slot based on the provided task.

Parameters
  • slot (int) – The batch slot index to insert the task state into.

  • task (GenerationStepTask) – The task containing the state to insert.

property is_active

Checks if the batch state is active (i.e., has associated token IDs).

Returns

True if token_ids is a non-empty array, False otherwise.

Return type

bool

pad_tokens(pad_length: int)

Pads the host-side token list with placeholder scalars.

Ensures the token_ids list (when used host-side) reaches the required pad_length by appending placeholder scalar JAX arrays.

Parameters

pad_length (int) – The target length for the token_ids list.

page_table: Array
page_update_mapped_idxs: Optional[Array] = None
page_update_page_idxs: Optional[Array] = None
page_update_slots: Optional[Array] = None
positions: Array
replace(**kwargs)

Creates a new instance with specified fields replaced.

sampling_params: SamplingParams
to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

token_ids: Union[Array, List[int]]
class easydel.layers.caching.paged_attention.__init__.AllocatedPrefillPages(page_indices: list[int])

Bases: object

Holds the indices of HBM pages allocated during a prefill step.

This simple structure is used to communicate which physical memory pages have been assigned to a sequence during its prefill processing.

page_indices

A list containing the indices of the HBM (High Bandwidth Memory) pages allocated for a specific prefill chunk.

Type

list[int]

classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

page_indices: list[int]
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.paged_attention.__init__.GenerationStepTask(id: str, slot: int, position: int, page_indices: list[int], prefill_token_id: Array, sampling_params: SamplingParams)

Bases: object

Represents a sequence actively undergoing token generation (decoding).

This class holds the necessary information for a single sequence that is currently in the decoding phase within the paged attention batch.

id

The unique identifier tracing back to the original request.

Type

str

slot

The assigned batch slot index for this sequence.

Type

int

position

The current sequence length (position) for the next token.

Type

int

page_indices

The list of physical HBM page indices allocated to this sequence’s KV cache.

Type

list[int]

prefill_token_id

The token ID generated in the previous step (either prefill or the last decode step), which serves as input for the current decode step.

Type

jax.Array

sampling_params

The sampling parameters associated with this specific generation task.

Type

SamplingParams

classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

id: str
page_indices: list[int]
position: int
prefill_token_id: Array
replace(**kwargs)

Creates a new instance with specified fields replaced.

sampling_params: SamplingParams
slot: int
to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.paged_attention.__init__.HBMPageManager(metadata: PagedAttentionCacheMetaData)

Bases: object

Manages the allocation and deallocation of physical HBM pages for the KV cache. It keeps track of available pages.

_metadata

Configuration for the paged cache.

Type

PagedAttentionCacheMetaData

_current_page_index

Index representing the initial dummy page.

Type

int

_available_hbm_pages

Queue of free HBM page indices.

Type

queue.SimpleQueue

alloc_hbm_pages(n: int) → list[int]

Allocates a specific number of HBM pages.

Parameters

n (int) – Number of pages to allocate.

Returns

Allocated HBM page indices (empty if insufficient pages).

Return type

list[int]

alloc_prefill_hbm_pages(prompt_len) → list[int]

Allocates the required number of HBM pages for a prompt prefill based on its length.

Parameters

prompt_len (int) – The length of the prompt (or chunk).

Returns

List of allocated HBM page indices (empty if insufficient pages).

Return type

list[int]

property current_page_index

Returns the dummy page index (usually 0).

free_hbm_pages(pages: list[int])

Returns a list of HBM pages back to the available pool.

Parameters

pages (list[int]) – HBM page indices to free (ignores dummy page).

property metadata: PagedAttentionCacheMetaData

Returns the cache metadata.

property page_size

Number of tokens whose KV cache entries fit in one page.
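The allocation contract above (all-or-nothing allocation, a reserved dummy page, freeing that ignores the dummy page) can be illustrated with a small pool. `ToyPageManager` is a hypothetical class, not the EasyDeL implementation; it assumes the dummy page is index 0 and that a prefill of `prompt_len` tokens needs `ceil(prompt_len / page_size)` pages.

```python
import math
import queue


class ToyPageManager:
    """Hypothetical HBM page pool with a reserved dummy page 0."""

    def __init__(self, num_pages, page_size):
        self.page_size = page_size
        self.current_page_index = 0  # dummy page, never handed out
        self._available = queue.SimpleQueue()
        for p in range(1, num_pages):
            self._available.put(p)

    def alloc_hbm_pages(self, n):
        """Return n page indices, or [] if fewer than n are free (all or nothing)."""
        if n > self._available.qsize():
            return []
        return [self._available.get_nowait() for _ in range(n)]

    def alloc_prefill_hbm_pages(self, prompt_len):
        # Assumption: one page per page_size tokens, rounded up.
        return self.alloc_hbm_pages(math.ceil(prompt_len / self.page_size))

    def free_hbm_pages(self, pages):
        for p in pages:
            if p != self.current_page_index:  # never return the dummy page
                self._available.put(p)
```

For example, with five pages and `page_size=4`, a 6-token prompt takes pages `[1, 2]`, and a later request for three pages fails with `[]` because only two remain.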

class easydel.layers.caching.paged_attention.__init__.InferenceScheduler(manager: HBMPageManager)

Bases: object

Schedules incoming prefill and decode requests based on HBM page and slot availability. It coordinates with HBMPageManager and ActiveSequenceBatch.

prefill_queue

Incoming prefill requests queue.

Type

queue.Queue[InitialSequenceRequest]

decodes_queue

Queue of requests that have finished prefill and are waiting to join the decode batch.

Type

queue.Queue[GenerationStepTask]

manager

Manager for HBM page allocation.

Type

HBMPageManager

batch_size

Max concurrent decode requests.

Type

int

max_seq_len

Max sequence length supported.

Type

int

create_plan(active_prefill: InitialSequenceRequest | None, decodes_state: ActiveSequenceBatch) → NextIterationPlan | None

Determines the workload for the next model iteration.

This function decides:

1. Whether a new prefill request can be started or an existing one continued.
2. Whether HBM pages need to be allocated for the prefill request.
3. Which decode requests from the decodes_queue can be added to the active batch.
4. Whether HBM pages need to be allocated for ongoing decode requests.

It updates host-side state (such as decodes_state.active_slot_requests_map) directly, whereas updates to device-side state (JAX arrays) are only prepared and returned in the NextIterationPlan.

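Parameters
  • active_prefill (InitialSequenceRequest | None) – The prefill request currently in progress, or None if there is none.

  • decodes_state (ActiveSequenceBatch) – The current decode batch state; its host-side scheduler attributes are updated in place.

Returns

An object containing the scheduling decisions and necessary updates for the next iteration, or None if no work can be scheduled (e.g., all queues empty).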

Return type

NextIterationPlan | None

Raises

NotImplementedError – If page allocation fails and eviction is required (currently not supported).
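The decode half of this decision process (steps 3 and 4 above) can be sketched in plain Python. This is a simplified illustration, not the scheduler's actual code: prefill handling is omitted, tasks are hypothetical dicts with `position` and `pages` keys, and a decode is assumed to need a new page exactly when its next token position falls on a logical page it does not yet own. Running out of pages raises NotImplementedError, mirroring the documented lack of eviction.

```python
import queue


def plan_decodes(decodes_queue, free_slots, active, alloc_page, page_size):
    """Toy version of the decode half of a scheduling step.

    decodes_queue: queue.SimpleQueue of dicts with 'id', 'position', 'pages'.
    free_slots:    list of free slot indices (consumed in order).
    active:        dict slot -> task, updated in place (host-side state).
    alloc_page:    callable returning a physical page index or None.
    Returns (new_decodes, page_updates); page_updates holds
    (slot, logical_page, physical_page) triples.
    """
    new_decodes = []
    # Step 3: move waiting decodes into free slots.
    while free_slots and not decodes_queue.empty():
        task = decodes_queue.get_nowait()
        slot = free_slots.pop(0)
        task["slot"] = slot
        active[slot] = task
        new_decodes.append(task)

    # Step 4: give a sequence a new page when its next token starts one.
    page_updates = []
    for slot, task in active.items():
        logical_page = task["position"] // page_size
        if logical_page >= len(task["pages"]):
            page = alloc_page()
            if page is None:
                raise NotImplementedError("out of HBM pages; eviction required")
            task["pages"].append(page)
            page_updates.append((slot, logical_page, page))
    return new_decodes, page_updates
```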

enqueue_decodes_request(request: GenerationStepTask)

Adds a completed prefill request to the decode queue.

enqueue_prefill_request(request: InitialSequenceRequest)

Adds a prefill request to the prefill queue.

class easydel.layers.caching.paged_attention.__init__.InitialSequenceRequest(token_ids: Array, positions: Array, page_indices: Array, sampling_params: SamplingParams, id: Optional[Union[str, int]], chunk_idx: Optional[int], chunk_size: Optional[int], prompt_token_ids: Optional[list[int]], length: Optional[Array] = None)

Bases: object

Represents a request for processing a new sequence during the prefill phase.

This class encapsulates the information needed to process a new input sequence (prompt) in the paged attention mechanism. It includes the token IDs, positions, allocated page indices, and associated metadata.

token_ids

JAX array of token IDs for the prefill sequence, potentially padded. (Runtime Argument)

Type

jax.Array

positions

JAX array of position IDs corresponding to token_ids. (Runtime Argument)

Type

jax.Array

page_indices

JAX array holding the indices of HBM pages allocated for this request’s KV cache. (Runtime Argument)

Type

jax.Array

sampling_params

Sampling parameters for this specific request. (Runtime Argument)

Type

SamplingParams

id

A unique identifier for the request. (Scheduler Argument)

Type

tp.Optional[str | int]

chunk_idx

The index of the current chunk being processed if the prompt is chunked. (Scheduler Argument)

Type

tp.Optional[int]

chunk_size

The size of the token chunk being processed in this prefill step. (Scheduler Argument)

Type

tp.Optional[int]

prompt_token_ids

The original list of token IDs for the complete prompt. (Scheduler Argument)

Type

tp.Optional[list[int]]

length

The actual length of the sequence processed so far (relevant for chunked prefill). Defaults to None. (Scheduler Argument)

Type

tp.Optional[jax.Array]

chunk_idx: Optional[int]
chunk_size: Optional[int]
copy_prefill(prefill: InitialSequenceRequest)

Copies runtime state from another InitialSequenceRequest (prefill source).

This method updates the current request’s runtime attributes (token_ids, positions, page_indices, sampling_params, length) based on the state of a source prefill request, typically used when advancing through chunks of a long prompt.

Parameters

prefill (InitialSequenceRequest) – The source request from which to copy runtime state.

classmethod create(id: str, mesh: Mesh, metadata: PagedAttentionCacheMetaData, chunk_size: int, prompt_token_ids: list[int], max_prefill_length: Optional[int] = None, prefill_lengths: Optional[List[int]] = None, sampling_params: Optional[SamplingParams] = None)

Creates an InitialSequenceRequest from prompt token IDs.

This factory method takes a list of token IDs and prepares them for the prefill phase, including padding, creating position IDs, initializing page indices, and setting up sampling parameters. Arrays are placed on the specified JAX mesh.

Parameters
  • id (str) – A unique identifier for the request.

  • mesh (common_types.Mesh) – The JAX device mesh for array distribution.

  • metadata (PagedAttentionCacheMetaData) – Paged attention cache configuration.

  • chunk_size (int) – The number of prompt tokens processed per prefill chunk.

  • prompt_token_ids (list[int]) – The input prompt token IDs.

  • max_prefill_length (tp.Optional[int]) – Target length for padding token IDs. Defaults to metadata.max_sequences.

  • sampling_params (tp.Optional[SamplingParams]) – Custom sampling parameters. If None, default parameters are used.

Returns

An initialized request object ready for prefill.

Return type

InitialSequenceRequest
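The padding and chunking that create performs can be illustrated without JAX. `build_prefill` below is a hypothetical helper: it pads with token 0 and, when no target length is given, rounds the prompt up to a whole number of chunks (the real method instead defaults to metadata.max_sequences). Positions are simply `0 .. prefill_length - 1`, which is also an assumption.

```python
def build_prefill(prompt_token_ids, chunk_size, pad_id=0, prefill_length=None):
    """Pad a prompt and split it into fixed-size chunks (hypothetical sketch).

    Returns (token_ids, positions, chunks, true_length).
    """
    n = len(prompt_token_ids)
    if prefill_length is None:
        # Round up to a whole number of chunks (ceiling division).
        prefill_length = -(-n // chunk_size) * chunk_size
    if n > prefill_length:
        raise ValueError("prompt longer than prefill_length")
    token_ids = list(prompt_token_ids) + [pad_id] * (prefill_length - n)
    positions = list(range(prefill_length))
    chunks = [token_ids[i:i + chunk_size]
              for i in range(0, prefill_length, chunk_size)]
    return token_ids, positions, chunks, n
```

A 5-token prompt with `chunk_size=2` becomes three chunks, the last one padded: `[[11, 12], [13, 14], [15, 0]]`.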

classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

id: Optional[Union[str, int]]
classmethod init_empty()

Creates an empty InitialSequenceRequest placeholder.

Initializes attributes with placeholder scalar JAX arrays suitable for use in contexts where a valid request might not be present (e.g., padding).

Returns

A placeholder request object.

Return type

InitialSequenceRequest

property is_active

Checks if the request is active (i.e., has associated token IDs).

Returns

True if token_ids is a non-empty array, False otherwise.

Return type

bool

length: Optional[Array] = None
page_indices: Array
positions: Array
prompt_token_ids: Optional[list[int]]
replace(**kwargs)

Creates a new instance with specified fields replaced.

sampling_params: SamplingParams
to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

token_ids: Array
class easydel.layers.caching.paged_attention.__init__.ModelIOProcessor

Bases: object

Processes and transforms model inputs and outputs for paged attention.

This class handles the construction of model input batches from prefill and decode states, and organizes raw model outputs into structured data for further processing.

classmethod build_input(iteration_plan: NextIterationPlan, metadata: PagedAttentionCacheMetaData, decodes_state: ActiveSequenceBatch)

Orchestrates the preparation of model inputs based on a scheduling decision.

It retrieves the necessary data from the iteration_plan and decodes_state, formats it (converting lists/numpy arrays to JAX arrays as needed), and calls prepare_model_input to perform the JIT-compiled merging and updating. It also updates the decodes_state (device arrays) based on the results from prepare_model_input.

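Parameters
  • iteration_plan (NextIterationPlan) – The scheduling decision for this iteration.

  • metadata (PagedAttentionCacheMetaData) – Paged attention cache configuration.

  • decodes_state (ActiveSequenceBatch) – The current decode batch state; its device arrays are updated from the results of prepare_model_input.

Returns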

The structured input ready for the model’s forward pass.

Return type

ModelInputBatch

static prepare_model_input(attn_meta: PagedAttentionMetadata, ongoing_prefill: InitialSequenceRequest, chunk_id: Array, chunk_size: Array, ongoing_decodes: ActiveSequenceBatch, ongoing_updates: ActiveSequenceBatch, insert_slots: Array)

Prepares model input by slicing and concatenating prefill and decode tokens, and updates the attention metadata accordingly.

Parameters
  • attn_meta (PagedAttentionMetadata) – Metadata to update with positions and page tables.

  • ongoing_prefill (InitialSequenceRequest) – Current prefill request state.

  • chunk_id (jax.Array) – Index of the current prefill chunk.

  • chunk_size (jax.Array) – Size of each prefill chunk.

  • ongoing_decodes (ActiveSequenceBatch) – Current decode batch state.

  • ongoing_updates (ActiveSequenceBatch) – Batch updates for ongoing decodes.

  • insert_slots (jax.Array) – Slot indices indicating where to insert new decodes.

Returns

(input_ids, decodes_token_ids, positions, attn_meta), where:

  • input_ids (jnp.ndarray) – Combined token IDs for model input.

  • decodes_token_ids (jnp.ndarray) – Token IDs array for the decode phase.

  • positions (jnp.ndarray) – Position indices matching input_ids.

  • attn_meta (PagedAttentionMetadata) – Updated attention metadata.

Return type

tuple
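The slicing and concatenation step can be shown with plain lists. This hypothetical `prepare_model_input` sketch ignores attention metadata and slot insertion; it assumes the prefill chunk comes first in the flat batch, followed by one token per decode sequence, and it returns the prefill token count so the output can be split again later.

```python
def prepare_model_input(prefill_tokens, prefill_positions, chunk_id, chunk_size,
                        decode_tokens, decode_positions):
    """Build one flat input from a prefill chunk plus all decode tokens.

    Returns (input_ids, positions, num_prefill_tokens). The attention-metadata
    updates done by the real method are omitted.
    """
    start = chunk_id * chunk_size
    chunk_tokens = prefill_tokens[start:start + chunk_size]
    chunk_positions = prefill_positions[start:start + chunk_size]
    input_ids = chunk_tokens + decode_tokens  # prefill first, then decodes
    positions = chunk_positions + decode_positions
    return input_ids, positions, len(chunk_tokens)
```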

static prepare_model_output(next_token: Array, complete: Array, attn_meta: PagedAttentionMetadata, sampling_params: SamplingParams) → ModelOutputBatch

Processes the raw model output (logits converted to tokens) and completion flags to create a structured ModelOutputBatch object.

It separates the outputs corresponding to the prefill step (if any) and the decode steps (if any) based on the structure defined in attn_meta.

Parameters
  • next_token (jax.Array) – The generated token IDs from the model. Concatenated prefill token (if any) and decode tokens (if any).

  • complete (jax.Array) – Boolean flags indicating sequence completion (e.g., EOS token) for the corresponding generated tokens.

  • attn_meta (PagedAttentionMetadata) – Metadata describing the structure of the input batch (prefill vs. decode parts).

  • sampling_params (SamplingParams) – Sampling parameters for the sequences in the batch.

Returns

A structured object containing separated prefill/decode outputs and next positions.

Return type

ModelOutputBatch
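The split described above can be sketched with plain lists. In this hypothetical helper, the prefill token (if any) is element 0 of `next_token`, matching the parameter description; the prefill's next position is taken to be the number of prompt tokens processed so far, and a finished decode gets the special position -1 while the others advance by one. Those position rules are assumptions.

```python
def prepare_model_output(next_token, complete, has_prefill, prefill_length,
                         decode_positions):
    """Split flat model outputs into prefill and decode parts (hypothetical)."""
    offset = 1 if has_prefill else 0
    decode_completes = complete[offset:]
    # Assumption: finished decodes get -1, others advance by one token.
    decode_next = [-1 if done else pos + 1
                   for pos, done in zip(decode_positions, decode_completes)]
    return {
        "prefill_token_id": next_token[0] if has_prefill else None,
        "prefill_complete": complete[0] if has_prefill else False,
        "prefill_next_position": prefill_length if has_prefill else None,
        "decodes_token_ids": next_token[offset:],
        "decodes_completes": decode_completes,
        "decodes_next_position": decode_next,
    }
```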

class easydel.layers.caching.paged_attention.__init__.ModelInputBatch(input_ids: Array, positions: Array, attn_meta: PagedAttentionMetadata, sampling_params: SamplingParams)

Bases: xTree

Consolidated input data for a single model forward pass in paged attention.

This structure gathers all necessary inputs for the model, potentially combining data for both a prefill step and multiple decode steps into a single batch structure suitable for the paged attention kernel.

input_ids

Combined token IDs for prefill and decode sequences.

Type

jax.Array

positions

Combined position IDs corresponding to input_ids.

Type

jax.Array

attn_meta

Metadata required by the paged attention kernel, including sequence lengths, block tables, etc., for both prefill and decode parts of the batch.

Type

PagedAttentionMetadata

sampling_params

Combined sampling parameters for all sequences included in the batch.

Type

SamplingParams

attn_meta: PagedAttentionMetadata
input_ids: Array
positions: Array
replace(**updates)

Returns a new instance of the dataclass with specified fields updated.

Parameters

**updates – Keyword arguments where keys are field names and values are the new values for those fields.

Returns

A new instance of the dataclass with the updated fields.

sampling_params: SamplingParams
class easydel.layers.caching.paged_attention.__init__.ModelOutputBatch(prefill_complete: Array, decodes_completes: Array, prefill_token_id: Array, decodes_token_ids: Array, prefill_next_position: Array, decodes_next_position: Array, next_sampling_params: SamplingParams)

Bases: object

Output generated by the model after a paged attention forward pass.

Contains the results from the model, separating outputs corresponding to the prefill phase (if one was run) and the decode phase.

prefill_complete

Scalar boolean JAX array. True if the prefill operation (if run) generated a completion token (e.g., EOS).

Type

jax.Array

decodes_completes

Boolean JAX array. Indicates for each sequence in the decode part of the batch whether a completion token was generated. Shape: (num_decode_sequences,).

Type

jax.Array

prefill_token_id

Scalar JAX array. The token ID generated by the prefill step (if run).

Type

jax.Array

decodes_token_ids

JAX array. The token IDs generated for each sequence in the decode part of the batch. Shape: (num_decode_sequences,).

Type

jax.Array

prefill_next_position

Scalar JAX array. The next position index for the sequence processed in the prefill step.

Type

jax.Array

decodes_next_position

JAX array. The next position indices for each sequence in the decode part of the batch. Completed sequences might have a special value (e.g., -1). Shape: (num_decode_sequences,).

Type

jax.Array

next_sampling_params

Updated sampling parameters after the forward pass (e.g., a modified max_tokens).

Type

SamplingParams

decodes_completes: Array
decodes_next_position: Array
decodes_token_ids: Array
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

classmethod init_empty()

Factory method to create an empty ModelOutputBatch instance with placeholder values.

Returns

An instance with default scalar/vector values indicating no output.

Return type

ModelOutputBatch

next_sampling_params: SamplingParams
prefill_complete: Array
prefill_next_position: Array
prefill_token_id: Array
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.paged_attention.__init__.ModelOutputSummary(prefill_request_id: str | None, prefill_token_id: Array, prefill_complete: Array, decodes_active_slots: list[int], decodes_active_request_ids: list[str], decodes_token_ids: Array, decodes_completes: Array)

Bases: object

Summarizes model output for scheduler and state updates (host-side).

This structure extracts and organizes key information from the ModelOutputBatch (which contains JAX arrays) into a format suitable for the host-side scheduler to process and update its internal state (like ActiveSequenceBatch’s scheduler attributes).

prefill_request_id

The ID of the prefill request processed in the last step, if any.

Type

str | None

prefill_token_id

The token ID generated by the prefill step.

Type

jax.Array

prefill_complete

Boolean flag indicating if the prefill step completed the sequence.

Type

jax.Array

decodes_active_slots

List of batch slot indices that were active during the decode phase of the last step.

Type

list[int]

decodes_active_request_ids

List of request IDs corresponding to the decodes_active_slots.

Type

list[str]

decodes_token_ids

Token IDs generated for the active decode slots.

Type

jax.Array

decodes_completes

Boolean flags indicating completion for each active decode slot.

Type

jax.Array

decodes_active_request_ids: list[str]
decodes_active_slots: list[int]
decodes_completes: Array
decodes_token_ids: Array
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

classmethod from_output(output: ModelOutputBatch) → ModelOutputSummary

Creates a ModelOutputSummary from a ModelOutputBatch.

Initializes the summary, copying relevant fields from the model output. Scheduler-specific fields (prefill_request_id, decodes_active_slots, decodes_active_request_ids) are initialized as empty/None and need to be populated separately by the scheduler based on its knowledge of the batch composition.

Parameters

output (ModelOutputBatch) – The model output batch containing JAX arrays.

Returns

An initialized summary object ready for scheduler processing.

Return type

ModelOutputSummary
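A minimal sketch of the two-stage flow described above: copy the model outputs first, then let the scheduler fill in the slot and request-ID fields. `Summary`, `summarize`, and `finished_requests` are hypothetical names; the sketch stores request IDs (not GenerationStepTask objects) in the slot map and assumes the decode arrays are indexed by slot.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Summary:
    """Hypothetical host-side summary mirroring ModelOutputSummary's fields."""
    prefill_token_id: Optional[int]
    prefill_complete: bool
    decodes_token_ids: List[int]
    decodes_completes: List[bool]
    prefill_request_id: Optional[str] = None
    decodes_active_slots: List[int] = field(default_factory=list)
    decodes_active_request_ids: List[str] = field(default_factory=list)


def summarize(output, prefill_id, active_slot_requests_map):
    """Copy model outputs, then add the scheduler's view of the batch."""
    s = Summary(output["prefill_token_id"], output["prefill_complete"],
                output["decodes_token_ids"], output["decodes_completes"])
    s.prefill_request_id = prefill_id
    for slot in sorted(active_slot_requests_map):
        s.decodes_active_slots.append(slot)
        s.decodes_active_request_ids.append(active_slot_requests_map[slot])
    return s


def finished_requests(summary):
    """Request IDs whose slot reported completion (slot-indexed arrays assumed)."""
    return [rid for slot, rid in zip(summary.decodes_active_slots,
                                     summary.decodes_active_request_ids)
            if summary.decodes_completes[slot]]
```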

prefill_complete: Array
prefill_request_id: str | None
prefill_token_id: Array
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.paged_attention.__init__.NextIterationPlan(prefill_request: InitialSequenceRequest, schedule_prefill: bool, schedule_decodes: bool, prefill_pages_update: AllocatedPrefillPages, new_decodes_requests: list[GenerationStepTask], decodes_state_page_updates: list[SlotPageAssignment])

Bases: object

Encapsulates the scheduling decisions for the next model iteration.

Based on available resources (like HBM pages and batch slots) and pending requests, the scheduler produces this plan, detailing which prefill and decode operations to execute next, along with necessary state updates.

prefill_request

The prefill request scheduled for the next iteration. Can be an empty/placeholder request if no prefill is scheduled.

Type

InitialSequenceRequest

schedule_prefill

True if a prefill operation should be executed.

Type

bool

schedule_decodes

True if decode operations should be executed.

Type

bool

prefill_pages_update

Contains the indices of pages newly allocated for the scheduled prefill request. Can be empty if no new pages were needed or no prefill is scheduled.

Type

AllocatedPrefillPages

new_decodes_requests

A list of sequences that are newly transitioning into the decode phase in this iteration (e.g., after completing prefill).

Type

list[GenerationStepTask]

decodes_state_page_updates

A list of updates to the page tables of sequences already in the decode phase, typically due to new page allocations for them.

Type

list[SlotPageAssignment]

decodes_state_page_updates: list[SlotPageAssignment]
classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

new_decodes_requests: list[GenerationStepTask]
prefill_pages_update: AllocatedPrefillPages
prefill_request: InitialSequenceRequest
replace(**kwargs)

Creates a new instance with specified fields replaced.

schedule_decodes: bool
schedule_prefill: bool
to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.paged_attention.__init__.PagedAttentionCache(views: List[PagedAttentionCacheView])

Bases: BaseCache

Represents the complete Paged Attention KV cache for all layers of a model.

It holds a list of PagedAttentionCacheView objects, one for each layer. It inherits from BaseCache.

views

A list containing the cache view for each layer in the model.

Type

tp.List[PagedAttentionCacheView]

classmethod from_dict(data: Dict[str, Any]) → T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) → T

Deserializes a JSON string into a PyTree object.

classmethod init_cache(mesh: Mesh, dtype: dtype, metadata: PagedAttentionCacheMetaData, partition_manager: PartitionManager, quantizer: Optional[object] = None)

Initializes the entire PagedAttentionCache for all layers.

Creates a list of PagedAttentionCacheView instances, one for each layer specified in the metadata, by calling PagedAttentionCacheView.init for each layer.

Parameters
  • mesh (Mesh) – The JAX device mesh.

  • dtype (jnp.dtype) – The data type for the cache pages.

  • metadata (PagedAttentionCacheMetaData) – Static configuration for the cache.

  • partition_manager (es.PartitionManager) – Manages tensor sharding.

  • quantizer (tp.Optional["EasyQuantizer"]) – Optional quantizer to apply.

Returns

An initialized cache object containing views for all layers.

Return type

PagedAttentionCache

init_empty(*args, **kwargs)

Not typically used for PagedAttentionCache; returns None.

replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() → Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) → str

Serializes the PyTree object to a JSON string.

views: List[PagedAttentionCacheView]
class easydel.layers.caching.paged_attention.__init__.PagedAttentionCacheMetaData(batch_size: int, num_hidden_layers: int, num_pages_per_layer: int, num_pages_per_sequence: int, max_sequences: int, page_size: int, num_kv_heads: int, kv_head_dim_size: int, hbm_utilization: float)

Bases: BaseCacheMetadata

Metadata holding configuration parameters for the Paged Attention KV cache.

This class stores static configuration details required to initialize and manage a paged KV cache, such as dimensions, page sizes, and resource utilization hints. It inherits from BaseCacheMetadata.

batch_size

The maximum number of sequences processed concurrently during decoding.

Type

int

num_hidden_layers

The total number of transformer layers in the model.

Type

int

num_pages_per_layer

The total number of physical memory pages allocated for the KV cache per layer across all sequences. This is calculated based on available memory and hbm_utilization.

Type

int

num_pages_per_sequence

The maximum number of pages a single sequence can occupy: enough pages of page_size tokens to hold max_sequences tokens (the maximum sequence length).

Type

int

max_sequences

The maximum sequence length, in tokens, supported by the cache allocation. Despite its name, this is a length rather than a count of sequences.

Type

int

page_size

The number of tokens stored per page in the KV cache.

Type

int

num_kv_heads

The number of key/value heads in the attention mechanism.

Type

int

kv_head_dim_size

The dimension size of each key/value head.

Type

int

hbm_utilization

The target fraction of available High Bandwidth Memory (HBM) to be utilized for the KV cache pages. Should be between 0.0 and 1.0.

Type

float

batch_size: int
classmethod create(mesh: Mesh, batch_size: int, num_hidden_layers: int, max_sequences: int, page_size: int, num_kv_heads: int, kv_head_dim_size: int, hbm_utilization: float, dtype: dtype = jnp.bfloat16) PagedAttentionCacheMetaData

Factory method to create and initialize a PagedAttentionCacheMetaData instance.

Calculates derived values like num_pages_per_layer and num_pages_per_sequence based on the provided parameters and estimated available memory.

Parameters
  • mesh (Mesh) – The JAX device mesh.

  • batch_size (int) – Maximum concurrent sequences for decode.

  • num_hidden_layers (int) – Number of transformer layers.

  • max_sequences (int) – Maximum supported sequence length.

  • page_size (int) – Number of tokens per cache page.

  • num_kv_heads (int) – Number of KV heads.

  • kv_head_dim_size (int) – Dimension of each KV head.

  • hbm_utilization (float) – Target HBM utilization fraction (0.0 to 1.0).

  • dtype (jnp.dtype) – Data type used for cache size calculation.

Returns

An initialized metadata object.

Return type

PagedAttentionCacheMetaData

Raises

ValueError – If input parameters are invalid (e.g., non-positive dimensions, invalid utilization factor).
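The derived sizes can be estimated roughly as shown below. This is a minimal sketch, not EasyDeL's implementation: it assumes num_pages_per_sequence is the ceiling of max_sequences / page_size, it takes a hypothetical available_hbm_bytes argument in place of the memory estimate that create obtains from the mesh, and its byte accounting (one key page plus one value page, split evenly across layers) is likewise an assumption.

```python
def estimate_page_counts(
    max_sequences: int,
    page_size: int,
    num_hidden_layers: int,
    num_kv_heads: int,
    kv_head_dim_size: int,
    hbm_utilization: float,
    available_hbm_bytes: int,  # hypothetical input; create() estimates memory itself
    dtype_bytes: int = 2,  # e.g. 2 bytes per element for bfloat16
) -> tuple[int, int]:
    """Hypothetical estimate of (num_pages_per_layer, num_pages_per_sequence)."""
    if not 0.0 < hbm_utilization <= 1.0:
        raise ValueError("hbm_utilization must be in (0.0, 1.0]")
    if min(max_sequences, page_size, num_hidden_layers, num_kv_heads, kv_head_dim_size) <= 0:
        raise ValueError("all dimensions must be positive")
    # Enough pages to hold one sequence of max_sequences tokens (integer ceiling division).
    num_pages_per_sequence = (max_sequences + page_size - 1) // page_size
    # One page holds page_size tokens for every KV head; count a key page and a value page.
    bytes_per_page_pair = 2 * num_kv_heads * page_size * kv_head_dim_size * dtype_bytes
    # Only the requested fraction of memory is budgeted for the cache.
    budget = int(available_hbm_bytes * hbm_utilization)
    # Split the budget evenly across all layers.
    num_pages_per_layer = budget // (bytes_per_page_pair * num_hidden_layers)
    return num_pages_per_layer, num_pages_per_sequence
```

For example, 16 GiB at 50% utilization with 32 layers, 8 KV heads of size 128, and 128-token bfloat16 pages leaves room for 512 pages per layer, while a 2048-token maximum length needs 16 pages per sequence.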

classmethod from_dict(data: Dict[str, Any]) T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T

Deserializes a JSON string into a PyTree object.

hbm_utilization: float
kv_head_dim_size: int
max_sequences: int
num_hidden_layers: int
num_kv_heads: int
num_pages_per_layer: int
num_pages_per_sequence: int
page_size: int
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.paged_attention.__init__.PagedAttentionCacheView(metadata: PagedAttentionCacheMetaData, layer_index: int, key_pages: Union[Array, ndarray, bool, number, ImplicitArray], value_pages: Union[Array, ndarray, bool, number, ImplicitArray])

Bases: BaseCacheView

Represents the view of the Paged Attention KV cache for a single transformer layer.

It holds references to the physical key and value pages allocated for this layer and the associated metadata. It provides methods to write new key/value pairs into the correct pages based on runtime metadata. It inherits from BaseCacheView.

metadata

The static configuration metadata for the entire paged cache.

Type

PagedAttentionCacheMetaData

layer_index

The index of the transformer layer this view corresponds to.

Type

int

key_pages

The tensor holding all key pages for this layer. Shape: (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size). Can be a JAX array or an ImplicitArray if quantization is used.

Type

tp.Union[cx.Array, ImplicitArray]

value_pages

The tensor holding all value pages for this layer. Shape: (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size). Can be a JAX array or an ImplicitArray if quantization is used.

Type

tp.Union[cx.Array, ImplicitArray]

concatenate_to_cache(*args, **kwargs)

Concatenation is not applicable for Paged Attention. Raises NotImplementedError.

classmethod from_dict(data: Dict[str, Any]) T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T

Deserializes a JSON string into a PyTree object.

classmethod init(mesh: Mesh, dtype: dtype, metadata: PagedAttentionCacheMetaData, layer_index: int, partition_manager: PartitionManager, quantizer: Optional[object] = None)

Initializes the PagedAttentionCacheView for a specific layer.

Allocates the key_pages and value_pages tensors with the appropriate shape, dtype, and sharding based on the provided metadata and partition manager. Optionally applies quantization if a quantizer is provided.

Parameters
  • mesh (Mesh) – The JAX device mesh.

  • dtype (jnp.dtype) – The data type for the cache pages (e.g., jnp.bfloat16).

  • metadata (PagedAttentionCacheMetaData) – Static configuration for the cache.

  • layer_index (int) – The index of the layer this view is for.

  • partition_manager (es.PartitionManager) – Manages tensor sharding across the mesh.

  • quantizer (tp.Optional["EasyQuantizer"]) – Optional quantizer to apply to the pages.

Returns

An initialized cache view for the specified layer.

Return type

PagedAttentionCacheView
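Ignoring sharding and quantization, allocating one layer's pages reduces to creating two zero-filled tensors with the documented shape (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size). The sketch below illustrates only that simplified case: it does not use partition_manager or a quantizer, and both helper names are hypothetical. Building one pair per layer mirrors what init_cache does with views.

```python
import jax.numpy as jnp


def init_layer_pages_sketch(num_kv_heads, num_pages_per_layer, page_size,
                            kv_head_dim_size, dtype=jnp.bfloat16):
    """Allocate unsharded, unquantized key/value page tensors for a single layer."""
    shape = (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size)
    key_pages = jnp.zeros(shape, dtype=dtype)
    value_pages = jnp.zeros(shape, dtype=dtype)
    return key_pages, value_pages


def init_all_layers_sketch(num_hidden_layers, **page_kwargs):
    """Return one (key_pages, value_pages) pair per transformer layer."""
    return [init_layer_pages_sketch(**page_kwargs) for _ in range(num_hidden_layers)]
```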

key_pages: Union[Array, ndarray, bool, number, ImplicitArray]
layer_index: int
metadata: PagedAttentionCacheMetaData
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str

Serializes the PyTree object to a JSON string.

value_pages: Union[Array, ndarray, bool, number, ImplicitArray]
write_decodes_to_cache(key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], metadata: PagedAttentionMetadata)

Writes the key/value pairs from a decode step into the appropriate cache pages.

Uses the decodes_position and decodes_page_table from the runtime metadata to calculate, for each sequence in the batch, the physical page index and the offset within that page where the new key/value pair should be written. It reshapes the cache pages and input keys/values for efficient scattered updates using .at[...].set(...).

Parameters
  • key (cx.Array) – Key tensor for the decode tokens. Shape (batch_size, num_kv_heads, kv_head_dim_size).

  • value (cx.Array) – Value tensor for the decode tokens. Shape (batch_size, num_kv_heads, kv_head_dim_size).

  • metadata (PagedAttentionMetadata) – Runtime metadata containing decodes_position and decodes_page_table.

Returns

Returns self after updating the pages.

Return type

PagedAttentionCacheView
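The page/offset arithmetic can be illustrated with a simplified sketch. It assumes each entry of decodes_position is the index of the token being written, and it uses direct advanced indexing with .at[...].set(...) instead of the reshaping the method performs, so it shows the idea rather than EasyDeL's code. The function name is hypothetical.

```python
import jax.numpy as jnp


def write_decodes_sketch(key_pages, value_pages, key, value, positions, page_table):
    """Scatter one new token per sequence into paged K/V storage.

    key_pages, value_pages: (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size)
    key, value:             (batch_size, num_kv_heads, kv_head_dim_size)
    positions:              (batch_size,) position of the token being written
    page_table:             (batch_size, num_pages_per_sequence), logical -> physical page
    """
    page_size = key_pages.shape[2]
    batch_idx = jnp.arange(positions.shape[0])
    # Which logical page of its sequence the token falls in, and where inside that page.
    logical_page = positions // page_size
    offset = positions % page_size
    # Translate each logical page into a physical page through the page table.
    physical_page = page_table[batch_idx, logical_page]
    # Indexing axes 1 and 2 with (batch_size,) arrays selects a
    # (num_kv_heads, batch_size, kv_head_dim_size) region, so move batch to axis 1.
    key_pages = key_pages.at[:, physical_page, offset, :].set(key.transpose(1, 0, 2))
    value_pages = value_pages.at[:, physical_page, offset, :].set(value.transpose(1, 0, 2))
    return key_pages, value_pages
```

With page_size = 4, a sequence at position 5 writes to offset 1 of its second logical page, i.e. whichever physical page its page-table row lists at index 1.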

write_prefill_to_cache(key: Union[Array, ndarray, bool, number], value: Union[Array, ndarray, bool, number], metadata: PagedAttentionMetadata)

Writes the key/value pairs from a prefill step into the appropriate cache pages.

Uses the prefill_page_table from the runtime metadata to determine which physical pages (key_pages, value_pages) correspond to the logical pages of the prefill sequence. It transposes and reshapes the input key/value tensors and uses jax.lax.dynamic_update_slice_in_dim within a while_loop to update the relevant pages.

Parameters
  • key (cx.Array) – Key tensor for the prefill sequence. Shape (padded_prefill_len, num_kv_heads, kv_head_dim_size).

  • value (cx.Array) – Value tensor for the prefill sequence. Shape (padded_prefill_len, num_kv_heads, kv_head_dim_size).

  • metadata (PagedAttentionMetadata) – Runtime metadata containing the prefill_length and prefill_page_table.

Returns

Returns self after updating the pages.

Return type

PagedAttentionCacheView
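A simplified sketch of the prefill write follows. It assumes the padded prefill length is a multiple of page_size and that only the pages covering prefill_length are written. It uses jax.lax.fori_loop (which JAX lowers to a while_loop when the trip count is not static) together with jax.lax.dynamic_update_slice_in_dim; the function name is hypothetical and the code is not EasyDeL's implementation.

```python
import jax.numpy as jnp
from jax import lax


def write_prefill_sketch(key_pages, value_pages, key, value, prefill_length, prefill_page_table):
    """Copy a prefill sequence page by page into the physical pages listed in the page table.

    key_pages, value_pages: (num_kv_heads, num_pages_per_layer, page_size, kv_head_dim_size)
    key, value:             (padded_prefill_len, num_kv_heads, kv_head_dim_size)
    prefill_page_table:     (num_pages_for_prefill,) logical -> physical page
    """
    num_heads, _, page_size, head_dim = key_pages.shape
    n_pages = key.shape[0] // page_size  # assumes padded length is a multiple of page_size
    # (padded_len, H, D) -> (H, padded_len, D) -> (H, n_pages, page_size, D)
    k = key.transpose(1, 0, 2).reshape(num_heads, n_pages, page_size, head_dim)
    v = value.transpose(1, 0, 2).reshape(num_heads, n_pages, page_size, head_dim)
    # Only pages that contain real (unpadded) tokens are written.
    pages_used = (prefill_length + page_size - 1) // page_size

    def body(i, carry):
        kp, vp = carry
        dst = prefill_page_table[i]  # physical page for logical page i
        k_page = lax.dynamic_slice_in_dim(k, i, 1, axis=1)
        v_page = lax.dynamic_slice_in_dim(v, i, 1, axis=1)
        kp = lax.dynamic_update_slice_in_dim(kp, k_page, dst, axis=1)
        vp = lax.dynamic_update_slice_in_dim(vp, v_page, dst, axis=1)
        return kp, vp

    return lax.fori_loop(0, pages_used, body, (key_pages, value_pages))
```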

class easydel.layers.caching.paged_attention.__init__.PagedAttentionMetadata(prefill_length: Array, prefill_position: Array, prefill_page_table: Array, decodes_position: Array, decodes_page_table: Array)

Bases: object

Runtime metadata required for performing a Paged Attention computation step.

This object holds the necessary information for a single forward pass of the paged attention mechanism, distinguishing between prefill and decode steps and providing the mappings (page tables) from logical sequence positions to physical cache pages.

prefill_length

Scalar JAX array containing the actual length of the prompt being processed in a prefill step. Shape (). Set to 0 if not in prefill.

Type

jax.Array

prefill_position

JAX array of positions for the prefill tokens. Shape (padded_prompt_length,). Empty shape () if not in prefill.

Type

jax.Array

prefill_page_table

JAX array mapping logical page indices of the prefill sequence to physical page indices in the KV cache. Shape (num_pages_for_prefill,). Empty shape () if not in prefill.

Type

jax.Array

decodes_position

JAX array containing the current sequence position (or length - 1) for each sequence in the decode batch. Shape (batch_size,). Empty shape () if not in decode.

Type

jax.Array

decodes_page_table

JAX array mapping logical page indices to physical page indices for each sequence in the decode batch. Shape (batch_size, num_pages_per_sequence). Empty shape () if not in decode.

Type

jax.Array
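The empty-shape convention can be illustrated with a hypothetical NamedTuple that mirrors the five fields above; it is not the library class. A decode-only step carries 0-d placeholders in the prefill fields, and a prefill-only step carries them in the decode fields, which is the condition is_prefill_mode is documented to check.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class MetadataSketch(NamedTuple):
    """Hypothetical stand-in mirroring PagedAttentionMetadata's fields."""
    prefill_length: jax.Array
    prefill_position: jax.Array
    prefill_page_table: jax.Array
    decodes_position: jax.Array
    decodes_page_table: jax.Array


def is_prefill_only(m: MetadataSketch) -> bool:
    # Decode fields hold 0-d placeholders when no decode work is present.
    return m.decodes_position.shape == () and m.decodes_page_table.shape == ()


def is_decode_only(m: MetadataSketch) -> bool:
    # Prefill fields hold 0-d placeholders when no prompt is being prefilled.
    return m.prefill_position.shape == () and m.prefill_page_table.shape == ()


placeholder = jnp.array(0, dtype=jnp.int32)

# A prefill step: a 5-token prompt padded to 8 positions, spread over 2 pages.
prefill_step = MetadataSketch(
    prefill_length=jnp.array(5, dtype=jnp.int32),
    prefill_position=jnp.arange(8, dtype=jnp.int32),
    prefill_page_table=jnp.array([3, 4], dtype=jnp.int32),
    decodes_position=placeholder,
    decodes_page_table=placeholder,
)

# A decode step for a batch of 2 sequences with 2 pages per sequence.
decode_step = MetadataSketch(
    prefill_length=placeholder,  # 0 when not in prefill
    prefill_position=placeholder,
    prefill_page_table=placeholder,
    decodes_position=jnp.array([4, 9], dtype=jnp.int32),
    decodes_page_table=jnp.array([[0, 1], [2, 3]], dtype=jnp.int32),
)
```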

decodes_page_table: Array
decodes_position: Array
classmethod from_dict(data: Dict[str, Any]) T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T

Deserializes a JSON string into a PyTree object.

classmethod init_empty()

Creates an initial or placeholder PagedAttentionMetadata object (internal helper method).

Returns

An instance with scalar placeholder values.

Return type

PagedAttentionMetadata

is_decode_mode() bool

Checks if the current metadata represents a decode-only step (the counterpart of is_prefill_mode).

Return type

bool

is_prefill_mode() bool

Checks if the current metadata represents a prefill-only step.

Returns

True if only prefill information is present (decode arrays have empty shape), False otherwise.

Return type

bool

prefill_length: Array
prefill_page_table: Array
prefill_position: Array
replace(**kwargs)

Creates a new instance with specified fields replaced.

to_dict() Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str

Serializes the PyTree object to a JSON string.

class easydel.layers.caching.paged_attention.__init__.SamplingParams(top_p: jax.Array | float = <factory>, max_tokens: jax.Array | int = <factory>, temperature: jax.Array | float = <factory>)

Bases: object

Configuration parameters for controlling text generation sampling.

This class holds parameters that influence the sampling process during text generation: top-p (nucleus) sampling, the maximum number of tokens to generate, and temperature scaling.

top_p

The probability threshold for nucleus sampling. Defaults to 1.0 (no nucleus sampling).

Type

jax.Array | float

max_tokens

The maximum number of tokens to generate for a sequence. Defaults to 32.

Type

jax.Array | int

temperature

The temperature for scaling logits before sampling. Defaults to 0.0 (deterministic).

Type

jax.Array | float
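One common way to apply these fields to a single sequence's logits is sketched below. This is a generic temperature plus nucleus (top-p) sampler written for illustration, not EasyDeL's sampling code; it treats temperature == 0.0 as greedy decoding, consistent with the documented default, and leaves max_tokens to the caller's stopping logic. The function name is hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_sketch(logits, top_p, temperature, rng):
    """Sample one token id from logits of shape (vocab_size,).

    temperature == 0.0 is treated as greedy decoding; otherwise logits are
    divided by temperature and restricted to the smallest set of top tokens
    whose probability mass reaches top_p before sampling.
    """
    greedy = jnp.argmax(logits)
    # Guard against division by zero; the greedy result is chosen below in that case.
    scaled = logits / jnp.maximum(temperature, 1e-6)
    probs = jax.nn.softmax(scaled)
    order = jnp.argsort(-probs)  # token ids sorted by descending probability
    sorted_probs = probs[order]
    # Keep a token if the mass strictly before it is below top_p; for top_p > 0
    # this always keeps the most likely token.
    mass_before = jnp.cumsum(sorted_probs) - sorted_probs
    keep_sorted = mass_before < top_p
    keep = jnp.zeros_like(keep_sorted).at[order].set(keep_sorted)  # back to vocab order
    masked = jnp.where(keep, scaled, -jnp.inf)
    sampled = jax.random.categorical(rng, masked)
    return jnp.where(temperature == 0.0, greedy, sampled)
```

A very small top_p keeps only the most likely token, so the sampler then behaves like greedy decoding even with a nonzero temperature.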

classmethod from_dict(data: Dict[str, Any]) T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T

Deserializes a JSON string into a PyTree object.

classmethod init_empty() SamplingParams

Creates an empty SamplingParams placeholder with scalar JAX arrays.

Returns

A placeholder SamplingParams object.

Return type

SamplingParams

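classmethod init_jax(metadata: PagedAttentionCacheMetaData, sharding: NamedSharding) SamplingParams

Initializes SamplingParams with JAX arrays on the specified device/sharding.

Parameters
  • metadata (PagedAttentionCacheMetaData) – Metadata containing batch size.

  • sharding (NamedSharding) – Sharding used to place the arrays on devices.

Returns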

An initialized SamplingParams object with JAX arrays.

Return type

SamplingParams

classmethod init_numpy(metadata: PagedAttentionCacheMetaData) SamplingParams

Initializes SamplingParams with NumPy arrays.

Parameters

metadata (PagedAttentionCacheMetaData) – Metadata containing batch size.

Returns

An initialized SamplingParams object with NumPy arrays.

Return type

SamplingParams

insert_decode_state(insert_slots: Array, update: ActiveSequenceBatch)

Updates sampling parameters in specified slots from an ActiveSequenceBatch.

Parameters
  • insert_slots (jax.Array) – An array of slot indices to update.

  • update (ActiveSequenceBatch) – The batch containing the new sampling parameters.

insert_from_task(slot: int, task: GenerationStepTask)

Inserts sampling parameters from a GenerationStepTask into a specific slot.

Parameters
  • slot (int) – The batch slot index to insert the parameters into.

  • task (GenerationStepTask) – The task containing the sampling parameters to insert.
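To make the slot-oriented layout concrete, here is a hypothetical sketch that stores each field as a per-slot array of length batch_size, fills a single slot on the host (in the spirit of insert_from_task), and scatters several slots at once on device (in the spirit of insert_decode_state). The dict layout, the function names, and the use of the documented defaults (1.0, 32, 0.0) as fill values are assumptions, not the library's internals.

```python
import jax.numpy as jnp
import numpy as np


def init_params_numpy_sketch(batch_size: int) -> dict:
    # One entry per decode slot, filled with the documented per-field defaults.
    return {
        "top_p": np.full((batch_size,), 1.0, dtype=np.float32),
        "max_tokens": np.full((batch_size,), 32, dtype=np.int32),
        "temperature": np.full((batch_size,), 0.0, dtype=np.float32),
    }


def insert_one_slot_sketch(params: dict, slot: int, top_p: float,
                           max_tokens: int, temperature: float) -> None:
    # Host-side NumPy arrays are mutable, so a slot is overwritten in place.
    params["top_p"][slot] = top_p
    params["max_tokens"][slot] = max_tokens
    params["temperature"][slot] = temperature


def insert_slots_jax_sketch(params: dict, insert_slots: jnp.ndarray, update: dict) -> dict:
    # JAX arrays are immutable: .at[...].set(...) returns updated copies.
    return {name: params[name].at[insert_slots].set(update[name]) for name in params}
```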

max_tokens: jax.Array | int
replace(**kwargs)

Creates a new instance with specified fields replaced.

temperature: jax.Array | float
to_dict() Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str

Serializes the PyTree object to a JSON string.

top_p: jax.Array | float

class easydel.layers.caching.paged_attention.__init__.SlotPageAssignment(slot: int, page_idx: int, mapped_idx: int)

Bases: object

Represents the assignment of a physical page to a logical page of the sequence in a given batch slot.

During decoding, as sequences grow, new physical memory pages might be allocated. This class represents the update instruction to map a specific logical page index within a sequence’s page table (identified by its slot) to a newly allocated physical HBM page index (mapped_idx).

slot

The batch slot index of the sequence whose page table is updated.

Type

int

page_idx

The logical page index within the sequence’s page table.

Type

int

mapped_idx

The physical HBM page index to map the logical page to.

Type

int
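Applying a batch of such assignments to a page table amounts to a 2-D scatter, page_table[slot, page_idx] = mapped_idx. The sketch below uses a hypothetical NamedTuple with the same three fields in place of the real class, and a hypothetical function name. ActiveSequenceBatch exposes page_update_slots, page_update_page_idxs and page_update_mapped_idxs fields that plausibly hold the same information in array form; that correspondence is an assumption.

```python
from typing import List, NamedTuple

import jax
import jax.numpy as jnp


class AssignmentSketch(NamedTuple):
    """Hypothetical stand-in with the same fields as SlotPageAssignment."""
    slot: int
    page_idx: int
    mapped_idx: int


def apply_assignments_sketch(page_table: jax.Array,
                             assignments: List[AssignmentSketch]) -> jax.Array:
    """Return a copy of page_table (batch_size, num_pages_per_sequence) with updates applied."""
    if not assignments:
        return page_table
    slots = jnp.array([a.slot for a in assignments], dtype=jnp.int32)
    page_idxs = jnp.array([a.page_idx for a in assignments], dtype=jnp.int32)
    mapped = jnp.array([a.mapped_idx for a in assignments], dtype=page_table.dtype)
    # page_table[slot, page_idx] = mapped_idx for every assignment, in one scatter.
    return page_table.at[slots, page_idxs].set(mapped)
```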

classmethod from_dict(data: Dict[str, Any]) T

Deserializes a dictionary into a PyTree object.

classmethod from_json(json_str: str) T

Deserializes a JSON string into a PyTree object.

mapped_idx: int
page_idx: int
replace(**kwargs)

Creates a new instance with specified fields replaced.

slot: int
to_dict() Dict[str, Any]

Serializes the PyTree object to a dictionary.

to_json(**kwargs) str

Serializes the PyTree object to a JSON string.