partition_module package#
- class easydel.etils.partition_module.PartitionAxis(batch_axis: Optional[Union[Tuple[str, ...], str]] = ('fsdp', 'dp'), sequence_axis: Optional[Union[Tuple[str, ...], str]] = 'sp', query_sequence_axis: Optional[Union[Tuple[str, ...], str]] = 'sp', head_axis: Optional[Union[Tuple[str, ...], str]] = 'tp', key_sequence_axis: Optional[Union[Tuple[str, ...], str]] = 'sp', hidden_state_axis: Optional[Union[Tuple[str, ...], str]] = 'tp', attention_dim_axis: Optional[Union[Tuple[str, ...], str]] = None, bias_head_sequence_axis: Optional[Union[Tuple[str, ...], str]] = None, bias_key_sequence_axis: Optional[Union[Tuple[str, ...], str]] = None, generation_query_sequence_axis: Optional[Union[Tuple[str, ...], str]] = None, generation_head_axis: Optional[Union[Tuple[str, ...], str]] = 'tp', generation_key_sequence_axis: Optional[Union[Tuple[str, ...], str]] = 'sp', generation_attention_dim_axis: Optional[Union[Tuple[str, ...], str]] = None)#
Bases: NamedTuple
A NamedTuple describing how each tensor axis of a model is partitioned across the device mesh.
Each field represents an axis and its corresponding partitioning strategy. The value of each field can be:
- None: The axis is not partitioned.
- str: The name of the single mesh dimension across which the axis is partitioned.
- Tuple[str, ...]: A tuple of mesh dimension names; the axis is split across all of the listed mesh dimensions.
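The three value kinds can be illustrated with a small standalone sketch. `PartitionAxisSketch` and `mesh_dims` are hypothetical names introduced here for illustration; the stand-in class only mirrors three of the documented fields and their defaults, and it is not the EasyDeL class itself.

```python
from typing import NamedTuple, Optional, Tuple, Union

AxisType = Optional[Union[Tuple[str, ...], str]]


class PartitionAxisSketch(NamedTuple):
    """Hypothetical stand-in mirroring three documented fields and defaults."""

    batch_axis: AxisType = ("fsdp", "dp")  # tuple: split across two mesh dims
    sequence_axis: AxisType = "sp"         # str: split across one mesh dim
    attention_dim_axis: AxisType = None    # None: not partitioned


def mesh_dims(value: AxisType) -> Tuple[str, ...]:
    """Normalize a field value to the tuple of mesh dimension names it uses."""
    if value is None:
        return ()
    if isinstance(value, str):
        return (value,)
    return tuple(value)


axes = PartitionAxisSketch()
for name, value in axes._asdict().items():
    print(f"{name}: {mesh_dims(value)}")
```

An empty tuple here means the axis is replicated rather than sharded; the helper is only a reading aid, not something the library is known to provide.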
- batch_axis#
Partitioning strategy for the batch dimension. Defaults to (“fsdp”, “dp”).
- Type
Optional[Union[Tuple[str, …], str]]
- sequence_axis#
Partitioning strategy for the sequence dimension. Defaults to “sp”.
- Type
Optional[Union[Tuple[str, …], str]]
- query_sequence_axis#
Partitioning strategy for the query sequence dimension. Defaults to “sp”.
- Type
Optional[Union[Tuple[str, …], str]]
- head_axis#
Partitioning strategy for the attention head dimension. Defaults to “tp”.
- Type
Optional[Union[Tuple[str, …], str]]
- key_sequence_axis#
Partitioning strategy for the key sequence dimension. Defaults to “sp”.
- Type
Optional[Union[Tuple[str, …], str]]
- hidden_state_axis#
Partitioning strategy for the hidden state dimension. Defaults to “tp”.
- Type
Optional[Union[Tuple[str, …], str]]
- attention_dim_axis#
Partitioning strategy for the attention dimension. Defaults to None.
- Type
Optional[Union[Tuple[str, …], str]]
- bias_head_sequence_axis#
Partitioning strategy for the bias head sequence dimension. Defaults to None.
- Type
Optional[Union[Tuple[str, …], str]]
- bias_key_sequence_axis#
Partitioning strategy for the bias key sequence dimension. Defaults to None.
- Type
Optional[Union[Tuple[str, …], str]]
- generation_query_sequence_axis#
Partitioning strategy for the query sequence dimension during generation. Defaults to None.
- Type
Optional[Union[Tuple[str, …], str]]
- generation_head_axis#
Partitioning strategy for the attention head dimension during generation. Defaults to “tp”.
- Type
Optional[Union[Tuple[str, …], str]]
- generation_key_sequence_axis#
Partitioning strategy for the key sequence dimension during generation. Defaults to “sp”.
- Type
Optional[Union[Tuple[str, …], str]]
- generation_attention_dim_axis#
Partitioning strategy for the attention dimension during generation. Defaults to None.
- Type
Optional[Union[Tuple[str, …], str]]
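A rough sketch of how the regular and generation_* fields might be combined into a per-tensor layout follows. It assumes a query tensor laid out as [batch, query sequence, heads, attention dim]; that layout, the stand-in class `PartitionAxisSketch`, and the helper `query_layout` are all assumptions made for illustration and are not confirmed by this page.

```python
from typing import NamedTuple, Optional, Tuple, Union

AxisType = Optional[Union[Tuple[str, ...], str]]


class PartitionAxisSketch(NamedTuple):
    """Hypothetical stand-in mirroring the documented fields and defaults."""

    batch_axis: AxisType = ("fsdp", "dp")
    sequence_axis: AxisType = "sp"
    query_sequence_axis: AxisType = "sp"
    head_axis: AxisType = "tp"
    key_sequence_axis: AxisType = "sp"
    hidden_state_axis: AxisType = "tp"
    attention_dim_axis: AxisType = None
    bias_head_sequence_axis: AxisType = None
    bias_key_sequence_axis: AxisType = None
    generation_query_sequence_axis: AxisType = None
    generation_head_axis: AxisType = "tp"
    generation_key_sequence_axis: AxisType = "sp"
    generation_attention_dim_axis: AxisType = None


def query_layout(axes: PartitionAxisSketch, generation: bool = False):
    """Pick one entry per tensor dimension, assuming [batch, seq, heads, dim]."""
    if generation:
        return (
            axes.batch_axis,
            axes.generation_query_sequence_axis,
            axes.generation_head_axis,
            axes.generation_attention_dim_axis,
        )
    return (
        axes.batch_axis,
        axes.query_sequence_axis,
        axes.head_axis,
        axes.attention_dim_axis,
    )


axes = PartitionAxisSketch()
train_layout = query_layout(axes)
decode_layout = query_layout(axes, generation=True)
print(train_layout)
print(decode_layout)
```

With the documented defaults, the only difference between the two layouts is the query sequence entry: "sp" in the regular case and None during generation. A tuple of this shape could be unpacked into `jax.sharding.PartitionSpec(*layout)`, which accepts None, a mesh axis name, or a tuple of names per dimension.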
- attention_dim_axis: Optional[Union[Tuple[str, ...], str]]#
Alias for field number 6
- batch_axis: Optional[Union[Tuple[str, ...], str]]#
Alias for field number 0
- bias_head_sequence_axis: Optional[Union[Tuple[str, ...], str]]#
Alias for field number 7
- bias_key_sequence_axis: Optional[Union[Tuple[str, ...], str]]#
Alias for field number 8
- generation_attention_dim_axis: Optional[Union[Tuple[str, ...], str]]#
Alias for field number 12
- generation_head_axis: Optional[Union[Tuple[str, ...], str]]#
Alias for field number 10
- generation_key_sequence_axis: Optional[Union[Tuple[str, ...], str]]#
Alias for field number 11
- generation_query_sequence_axis: Optional[Union[Tuple[str, ...], str]]#
Alias for field number 9
- head_axis: Optional[Union[Tuple[str, ...], str]]#
Alias for field number 3
- hidden_state_axis: Optional[Union[Tuple[str, ...], str]]#
Alias for field number 5
- key_sequence_axis: Optional[Union[Tuple[str, ...], str]]#
Alias for field number 4
- query_sequence_axis: Optional[Union[Tuple[str, ...], str]]#
Alias for field number 2
- sequence_axis: Optional[Union[Tuple[str, ...], str]]#
Alias for field number 1
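The "Alias for field number N" entries reflect standard NamedTuple behavior: each attribute is also reachable by position, and instances are immutable. The sketch below shows this with a hypothetical stand-in class, `PartitionAxisSketch`, that copies the documented field order and defaults; it is not imported from EasyDeL.

```python
from typing import NamedTuple, Optional, Tuple, Union

AxisType = Optional[Union[Tuple[str, ...], str]]


class PartitionAxisSketch(NamedTuple):
    """Hypothetical stand-in with the documented field order and defaults."""

    batch_axis: AxisType = ("fsdp", "dp")
    sequence_axis: AxisType = "sp"
    query_sequence_axis: AxisType = "sp"
    head_axis: AxisType = "tp"
    key_sequence_axis: AxisType = "sp"
    hidden_state_axis: AxisType = "tp"
    attention_dim_axis: AxisType = None
    bias_head_sequence_axis: AxisType = None
    bias_key_sequence_axis: AxisType = None
    generation_query_sequence_axis: AxisType = None
    generation_head_axis: AxisType = "tp"
    generation_key_sequence_axis: AxisType = "sp"
    generation_attention_dim_axis: AxisType = None


axes = PartitionAxisSketch()

# Positional access matches the documented field numbers.
first_field = axes[0]  # same object as axes.batch_axis
hidden_index = PartitionAxisSketch._fields.index("hidden_state_axis")

# NamedTuples are immutable; _replace returns a modified copy.
custom = axes._replace(sequence_axis=None, head_axis=("tp", "sp"))

print(first_field, hidden_index)
print(custom.sequence_axis, custom.head_axis)
print(axes.sequence_axis)  # the original instance is unchanged
```

Because `_replace` returns a new instance, a customized configuration can be derived from the defaults without mutating a shared object.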