partition_module package

class easydel.etils.partition_module.PartitionAxis(batch_axis: Optional[Union[Tuple[str, ...], str]] = ('fsdp', 'dp'), sequence_axis: Optional[Union[Tuple[str, ...], str]] = 'sp', query_sequence_axis: Optional[Union[Tuple[str, ...], str]] = 'sp', head_axis: Optional[Union[Tuple[str, ...], str]] = 'tp', key_sequence_axis: Optional[Union[Tuple[str, ...], str]] = 'sp', hidden_state_axis: Optional[Union[Tuple[str, ...], str]] = 'tp', attention_dim_axis: Optional[Union[Tuple[str, ...], str]] = None, bias_head_sequence_axis: Optional[Union[Tuple[str, ...], str]] = None, bias_key_sequence_axis: Optional[Union[Tuple[str, ...], str]] = None, generation_query_sequence_axis: Optional[Union[Tuple[str, ...], str]] = None, generation_head_axis: Optional[Union[Tuple[str, ...], str]] = 'tp', generation_key_sequence_axis: Optional[Union[Tuple[str, ...], str]] = 'sp', generation_attention_dim_axis: Optional[Union[Tuple[str, ...], str]] = None)

Bases: NamedTuple

A NamedTuple that maps each logical tensor dimension of a model (batch, sequence, attention heads, and so on) to the mesh dimension or dimensions across which it is partitioned.

Each field names one tensor dimension and holds its partitioning strategy. A field value can be:

  • None: The axis is not partitioned.

  • str: The name of the single mesh dimension across which the axis is partitioned.

  • Tuple[str, ...]: A tuple of mesh dimension names; the axis is split across all of them.
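The three value forms can be illustrated with a small stand-in. The sketch below does not import EasyDeL: `AxisSpec`, `PartitionAxisSketch`, and `mesh_dims` are hypothetical names, and only four of the documented fields are reproduced, with their documented defaults.

```python
from typing import NamedTuple, Optional, Tuple, Union

# Same value type as every PartitionAxis field in the class signature.
AxisSpec = Optional[Union[Tuple[str, ...], str]]


class PartitionAxisSketch(NamedTuple):
    """Hypothetical stand-in with four of the documented fields and defaults."""

    batch_axis: AxisSpec = ("fsdp", "dp")
    sequence_axis: AxisSpec = "sp"
    head_axis: AxisSpec = "tp"
    attention_dim_axis: AxisSpec = None


def mesh_dims(value: AxisSpec) -> Tuple[str, ...]:
    """Hypothetical helper: normalize a field value to a tuple of mesh dimension names."""
    if value is None:
        return ()  # the axis is not partitioned
    if isinstance(value, str):
        return (value,)  # partitioned across a single mesh dimension
    return tuple(value)  # split across several mesh dimensions


paxis = PartitionAxisSketch()
for name, value in paxis._asdict().items():
    print(f"{name}: {mesh_dims(value)}")

# A NamedTuple is immutable; _replace returns a modified copy.
no_sp = paxis._replace(sequence_axis=None)
```

Because the class is a NamedTuple, overriding a single axis means either passing it as a keyword argument at construction or calling `_replace` on an existing instance; the original instance is left unchanged.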

batch_axis

Partitioning strategy for the batch dimension. Defaults to ("fsdp", "dp").

Type

Optional[Union[Tuple[str, ...], str]]

sequence_axis

Partitioning strategy for the sequence dimension. Defaults to "sp".

Type

Optional[Union[Tuple[str, ...], str]]

query_sequence_axis

Partitioning strategy for the query sequence dimension. Defaults to "sp".

Type

Optional[Union[Tuple[str, ...], str]]

head_axis

Partitioning strategy for the attention head dimension. Defaults to "tp".

Type

Optional[Union[Tuple[str, ...], str]]

key_sequence_axis

Partitioning strategy for the key sequence dimension. Defaults to "sp".

Type

Optional[Union[Tuple[str, ...], str]]

hidden_state_axis

Partitioning strategy for the hidden state dimension. Defaults to "tp".

Type

Optional[Union[Tuple[str, ...], str]]

attention_dim_axis

Partitioning strategy for the attention dimension. Defaults to None.

Type

Optional[Union[Tuple[str, ...], str]]

bias_head_sequence_axis

Partitioning strategy for the bias head sequence dimension. Defaults to None.

Type

Optional[Union[Tuple[str, ...], str]]

bias_key_sequence_axis

Partitioning strategy for the bias key sequence dimension. Defaults to None.

Type

Optional[Union[Tuple[str, ...], str]]

generation_query_sequence_axis

Partitioning strategy for the query sequence dimension during generation. Defaults to None.

Type

Optional[Union[Tuple[str, ...], str]]

generation_head_axis

Partitioning strategy for the attention head dimension during generation. Defaults to "tp".

Type

Optional[Union[Tuple[str, ...], str]]

generation_key_sequence_axis

Partitioning strategy for the key sequence dimension during generation. Defaults to "sp".

Type

Optional[Union[Tuple[str, ...], str]]

generation_attention_dim_axis

Partitioning strategy for the attention dimension during generation. Defaults to None.

Type

Optional[Union[Tuple[str, ...], str]]
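The four generation_* fields mirror query_sequence_axis, head_axis, key_sequence_axis, and attention_dim_axis. With the documented defaults, only the query sequence axis differs between the two sets. A minimal sketch of how a caller might choose between them follows; the `DEFAULTS` dictionary copies the documented defaults, and `attention_axes` is a hypothetical helper, not part of EasyDeL.

```python
from typing import Dict, Optional, Tuple, Union

AxisSpec = Optional[Union[Tuple[str, ...], str]]

# Documented defaults, copied from the field descriptions in this section.
DEFAULTS: Dict[str, AxisSpec] = {
    "query_sequence_axis": "sp",
    "head_axis": "tp",
    "key_sequence_axis": "sp",
    "attention_dim_axis": None,
    "generation_query_sequence_axis": None,
    "generation_head_axis": "tp",
    "generation_key_sequence_axis": "sp",
    "generation_attention_dim_axis": None,
}

ATTENTION_FIELDS = (
    "query_sequence_axis",
    "head_axis",
    "key_sequence_axis",
    "attention_dim_axis",
)


def attention_axes(axes: Dict[str, AxisSpec], generation: bool) -> Dict[str, AxisSpec]:
    """Hypothetical helper: pick the regular or the generation_* attention axes."""
    prefix = "generation_" if generation else ""
    return {name: axes[prefix + name] for name in ATTENTION_FIELDS}


regular = attention_axes(DEFAULTS, generation=False)
decoding = attention_axes(DEFAULTS, generation=True)

# Names of the attention axes whose default changes during generation.
changed = [name for name in ATTENTION_FIELDS if regular[name] != decoding[name]]
print(changed)
```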

attention_dim_axis: Optional[Union[Tuple[str, ...], str]]

Alias for field number 6

batch_axis: Optional[Union[Tuple[str, ...], str]]

Alias for field number 0

bias_head_sequence_axis: Optional[Union[Tuple[str, ...], str]]

Alias for field number 7

bias_key_sequence_axis: Optional[Union[Tuple[str, ...], str]]

Alias for field number 8

generation_attention_dim_axis: Optional[Union[Tuple[str, ...], str]]

Alias for field number 12

generation_head_axis: Optional[Union[Tuple[str, ...], str]]

Alias for field number 10

generation_key_sequence_axis: Optional[Union[Tuple[str, ...], str]]

Alias for field number 11

generation_query_sequence_axis: Optional[Union[Tuple[str, ...], str]]

Alias for field number 9

head_axis: Optional[Union[Tuple[str, ...], str]]

Alias for field number 3

hidden_state_axis: Optional[Union[Tuple[str, ...], str]]

Alias for field number 5

key_sequence_axis: Optional[Union[Tuple[str, ...], str]]

Alias for field number 4

query_sequence_axis: Optional[Union[Tuple[str, ...], str]]

Alias for field number 2

sequence_axis: Optional[Union[Tuple[str, ...], str]]

Alias for field number 1
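The field numbers are the positional indices of the NamedTuple, in the order the parameters appear in the class signature. The short check below uses a stand-in built with `collections.namedtuple` rather than importing EasyDeL; `PartitionAxisSketch` and `index_of` are hypothetical names, while the field names and defaults are copied from the signature.

```python
from collections import namedtuple

# (field name, documented default), in signature order.
FIELDS = [
    ("batch_axis", ("fsdp", "dp")),
    ("sequence_axis", "sp"),
    ("query_sequence_axis", "sp"),
    ("head_axis", "tp"),
    ("key_sequence_axis", "sp"),
    ("hidden_state_axis", "tp"),
    ("attention_dim_axis", None),
    ("bias_head_sequence_axis", None),
    ("bias_key_sequence_axis", None),
    ("generation_query_sequence_axis", None),
    ("generation_head_axis", "tp"),
    ("generation_key_sequence_axis", "sp"),
    ("generation_attention_dim_axis", None),
]

# Hypothetical stand-in with the same field order and defaults.
PartitionAxisSketch = namedtuple(
    "PartitionAxisSketch",
    [name for name, _ in FIELDS],
    defaults=[default for _, default in FIELDS],
)

paxis = PartitionAxisSketch()

# Attribute access and positional access return the same value.
assert paxis.head_axis == paxis[3]

# Map each field name to its field number.
index_of = {name: i for i, name in enumerate(PartitionAxisSketch._fields)}
print(index_of["hidden_state_axis"])
```

Positional access is available because the class is a tuple, but the named attributes are easier to read and do not depend on field order.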