# Source code for easydel.etils.partition_module
# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import NamedTuple, Optional, Tuple, Union
# A partitioning spec for one logical axis: ``None``, a single mesh-axis name,
# or a tuple of mesh-axis names.
AxisType = Optional[Union[Tuple[str, ...], str]]


class PartitionAxis(NamedTuple):
    """Describe how each logical axis of a model is partitioned.

    Each field represents an axis and its corresponding partitioning
    strategy. The value of each field can be:

    * ``None``: The axis is not partitioned.
    * ``str``: The name of the single mesh dimension across which the axis
      is partitioned.
    * ``Tuple[str, ...]``: A tuple of mesh dimension names, indicating a
      sharding strategy where the axis is split across multiple mesh
      dimensions.

    Field order is part of the interface: as a ``NamedTuple`` this class can
    be built positionally and unpacked, so new fields must only be appended.

    Attributes:
        batch_axis: Partitioning strategy for the batch dimension.
            Defaults to ("fsdp", "dp").
        sequence_axis: Partitioning strategy for the sequence dimension.
            Defaults to "sp".
        query_sequence_axis: Partitioning strategy for the query sequence
            dimension. Defaults to "sp".
        head_axis: Partitioning strategy for the attention head dimension.
            Defaults to "tp".
        key_sequence_axis: Partitioning strategy for the key sequence
            dimension. Defaults to "sp".
        hidden_state_axis: Partitioning strategy for the hidden state
            dimension. Defaults to "tp".
        attention_dim_axis: Partitioning strategy for the attention
            dimension. Defaults to None.
        bias_head_sequence_axis: Partitioning strategy for the bias head
            sequence dimension. Defaults to None.
        bias_key_sequence_axis: Partitioning strategy for the bias key
            sequence dimension. Defaults to None.
        generation_query_sequence_axis: Partitioning strategy for the query
            sequence dimension during generation. Defaults to None.
        generation_head_axis: Partitioning strategy for the attention head
            dimension during generation. Defaults to "tp".
        generation_key_sequence_axis: Partitioning strategy for the key
            sequence dimension during generation. Defaults to "sp".
        generation_attention_dim_axis: Partitioning strategy for the
            attention dimension during generation. Defaults to None.
    """

    # General axes.
    batch_axis: AxisType = ("fsdp", "dp")
    sequence_axis: AxisType = "sp"
    query_sequence_axis: AxisType = "sp"
    head_axis: AxisType = "tp"
    key_sequence_axis: AxisType = "sp"
    hidden_state_axis: AxisType = "tp"
    attention_dim_axis: AxisType = None

    # Attention-bias axes.
    bias_head_sequence_axis: AxisType = None
    bias_key_sequence_axis: AxisType = None

    # Generation-time counterparts of the attention axes.
    generation_query_sequence_axis: AxisType = None
    generation_head_axis: AxisType = "tp"
    generation_key_sequence_axis: AxisType = "sp"
    generation_attention_dim_axis: AxisType = None