Source code for easydel.infra.elarge_model.builders

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Builder functions for creating models and inference engines from ELM configurations.

This module provides high-level functions to build EasyDeL models and eSurge inference
engines from ELM configuration dictionaries.
"""

from __future__ import annotations

import pathlib
import typing as tp
from collections.abc import Mapping
from typing import Any

if tp.TYPE_CHECKING:
    from easydel.data.core.protocols import ShardedDataSource

from easydel.inference.esurge.esurge_engine import DEFAULT_DETOKENIZER_MAX_STATES
from easydel.infra.base_module import EasyDeLBaseModule
from easydel.infra.factory import TaskType
from easydel.layers.quantization.quantizers import EasyDeLQuantizationConfig
from easydel.modules.auto import (
    AutoEasyDeLModel,
    AutoEasyDeLModelForCausalLM,
    AutoEasyDeLModelForDiffusionLM,
    AutoEasyDeLModelForImageTextToText,
    AutoEasyDeLModelForSeq2SeqLM,
    AutoEasyDeLModelForSpeechSeq2Seq,
    AutoEasyDeLModelForZeroShotImageClassification,
)

from .normalizer import materialize_base_config, normalize, resolve_task
from .types import ELMConfig
from .utils import coerce_dtype, coerce_precision


def to_from_pretrained_kwargs(cfg_like: ELMConfig | Mapping[str, Any]) -> dict[str, Any]:
    """Translate an ELM configuration into ``from_pretrained()`` keyword arguments.

    The configuration is normalized first; values are then read from the
    ``model``, ``loader``, ``sharding``, ``platform`` and ``quantization``
    sections and reshaped into the keyword names used below.

    Args:
        cfg_like: ELM configuration dictionary or mapping.

    Returns:
        Dictionary of keyword arguments for ``from_pretrained()`` methods.
        Entries of ``model.extra_kwargs`` are merged in last; a key that
        duplicates one of the generated keywords raises ``TypeError``.

    Example:
        >>> cfg = {
        ...     "model": {"name_or_path": "meta-llama/Llama-2-7b"},
        ...     "loader": {"dtype": "bf16"},
        ...     "sharding": {"axis_dims": (1, 1, 1, -1, 1)}
        ... }
        >>> kwargs = to_from_pretrained_kwargs(cfg)
        >>> model = AutoEasyDeLModelForCausalLM.from_pretrained(**kwargs)
    """
    cfg = normalize(cfg_like)
    model_section = cfg["model"]
    loader_section = cfg.get("loader", {})
    sharding_section = cfg.get("sharding", {})
    platform_section = cfg.get("platform", {})
    quant_section = cfg.get("quantization", {})

    # These three keys are passed as explicit keyword arguments below, so they
    # are dropped from the materialized base config.
    base_config_kwargs = materialize_base_config(cfg, prefer="base")
    for explicit_key in ("partition_axis", "backend", "platform"):
        base_config_kwargs.pop(explicit_key, None)

    quantization_config = quant_section.get("model")
    if quantization_config is not None:
        quantization_config = EasyDeLQuantizationConfig(**quantization_config)

    dcn_axis_dims = sharding_section.get("dcn_axis_dims")

    generated = {
        "pretrained_model_name_or_path": model_section["name_or_path"],
        "device": loader_section.get("device"),
        "dtype": coerce_dtype(loader_section.get("dtype")),
        "param_dtype": coerce_dtype(loader_section.get("param_dtype")),
        "precision": coerce_precision(loader_section.get("precision")),
        "sharding_axis_dims": tuple(sharding_section.get("axis_dims", (1, 1, 1, -1, 1))),
        "sharding_dcn_axis_dims": tuple(dcn_axis_dims) if dcn_axis_dims else None,
        "sharding_axis_names": tuple(sharding_section.get("axis_names", ("dp", "fsdp", "ep", "tp", "sp"))),
        "partition_axis": sharding_section.get("partition_axis"),
        "shard_fns": sharding_section.get("shard_fns"),
        "backend": platform_section.get("backend"),
        "platform": platform_section.get("platform"),
        "config_kwargs": base_config_kwargs,
        "auto_shard_model": bool(sharding_section.get("auto_shard_model", True)),
        "partition_rules": sharding_section.get("partition_rules"),
        "quantization_config": quantization_config,
        "quantize_tensors": bool(quant_section.get("quantize_tensors", False)),
        "verbose": bool(loader_section.get("verbose", True)),
        "from_torch": loader_section.get("from_torch"),
    }
    user_extras = model_section.get("extra_kwargs") or {}
    # Double unpacking through dict() keeps duplicate keys an error rather than
    # letting the extras silently override a generated keyword.
    return dict(**generated, **user_extras)
def build_model(cfg_like: ELMConfig | Mapping[str, Any]) -> EasyDeLBaseModule:
    """Load an EasyDeL model described by an ELM configuration.

    The task resolved from the configuration picks the ``AutoEasyDeLModel*``
    class; any task without a dedicated class falls back to
    ``AutoEasyDeLModel``.

    Args:
        cfg_like: ELM configuration dictionary or mapping.

    Returns:
        EasyDeLBaseModule: The loaded model instance.

    Example:
        >>> cfg = {
        ...     "model": {"name_or_path": "meta-llama/Llama-2-7b", "task": "causal_lm"},
        ...     "loader": {"dtype": "bf16"}
        ... }
        >>> model = build_model(cfg)
    """
    kw = to_from_pretrained_kwargs(cfg_like)
    task = resolve_task(normalize(cfg_like))

    # Ordered (task, loader class) pairs; the first pair equal to ``task`` wins.
    loaders_by_task = (
        (TaskType.CAUSAL_LM, AutoEasyDeLModelForCausalLM),
        (TaskType.SEQUENCE_TO_SEQUENCE, AutoEasyDeLModelForSeq2SeqLM),
        (TaskType.SPEECH_SEQUENCE_TO_SEQUENCE, AutoEasyDeLModelForSpeechSeq2Seq),
        (TaskType.IMAGE_TEXT_TO_TEXT, AutoEasyDeLModelForImageTextToText),
        (TaskType.ZERO_SHOT_IMAGE_CLASSIFICATION, AutoEasyDeLModelForZeroShotImageClassification),
        (TaskType.DIFFUSION_LM, AutoEasyDeLModelForDiffusionLM),
    )
    loader_cls = next(
        (auto_cls for known_task, auto_cls in loaders_by_task if task == known_task),
        AutoEasyDeLModel,
    )
    return loader_cls.from_pretrained(**kw)
def to_esurge_kwargs(cfg_like: ELMConfig | Mapping[str, Any]) -> dict[str, Any]:
    """Translate an ELM configuration into eSurge constructor keyword arguments.

    Values come from the ``esurge`` section. ``max_model_len`` falls back to
    ``mask_max_position_embeddings`` and then ``freq_max_position_embeddings``
    from ``base_config.values``, and finally to ``8192``. Every other option
    left unset (or set to ``None``) takes the default written in this function.

    Args:
        cfg_like: ELM configuration dictionary or mapping.

    Returns:
        Dictionary of keyword arguments for eSurge initialization.

    Example:
        >>> cfg = {
        ...     "model": {"name_or_path": "meta-llama/Llama-2-7b"},
        ...     "esurge": {"max_model_len": 4096, "max_num_seqs": 32}
        ... }
        >>> kwargs = to_esurge_kwargs(cfg)
        >>> kwargs["max_model_len"]
        4096
    """
    cfg = normalize(cfg_like)
    es = cfg.get("esurge", {})
    base_vals = dict(cfg.get("base_config", {}).get("values", {}) or {})

    def _cast_or(key: str, cast, fallback=None):
        """Return ``cast(es[key])``, or ``fallback`` when the value is missing/None."""
        raw = es.get(key)
        return fallback if raw is None else cast(raw)

    context_length = (
        es.get("max_model_len")
        or base_vals.get("mask_max_position_embeddings")
        or base_vals.get("freq_max_position_embeddings")
        or 8192
    )

    # An absent key uses the default layout; an explicit None is passed through.
    axis_dims_raw = es.get("sharding_axis_dims", (1, 1, 1, -1, 1))
    sharding_axis_dims = None if axis_dims_raw is None else tuple(axis_dims_raw)

    max_num_batched_tokens = _cast_or("max_num_batched_tokens", int)
    reserve_tokens = _cast_or("reserve_tokens", int)

    # An absent key uses the engine default; an explicit None is passed through.
    detokenizer_max_states = es.get("detokenizer_max_states", DEFAULT_DETOKENIZER_MAX_STATES)
    if detokenizer_max_states is not None:
        detokenizer_max_states = int(detokenizer_max_states)

    extra_eos_token_ids = _cast_or("extra_eos_token_ids", list)
    # ``runner_verbose`` wins over the shorter ``verbose`` key when both exist.
    runner_verbose = bool(es.get("runner_verbose", es.get("verbose", False)))
    truncate_mode = es.get("truncate_mode", "left")

    return {
        "max_model_len": int(context_length),
        "min_input_pad": _cast_or("min_input_pad", int, 16),
        "max_num_seqs": _cast_or("max_num_seqs", int, 256),
        "max_num_batched_tokens": max_num_batched_tokens,
        "hbm_utilization": _cast_or("hbm_utilization", float, 0.85),
        "page_size": _cast_or("page_size", int, 128),
        "use_aot_forward": _cast_or("use_aot_forward", bool, True),
        "enable_prefix_caching": _cast_or("enable_prefix_caching", bool, True),
        "auto_shard_model": _cast_or("auto_shard_model", bool, True),
        "sharding_axis_dims": sharding_axis_dims,
        "compile_runner": _cast_or("compile_runner", bool, True),
        "runner_verbose": runner_verbose,
        "overlap_execution": _cast_or("overlap_execution", bool, False),
        "sampler_metrics": _cast_or("sampler_metrics", bool, False),
        "esurge_name": es.get("esurge_name"),
        "reserve_tokens": reserve_tokens,
        "auto_truncate_prompt": _cast_or("auto_truncate_prompt", bool, True),
        "auto_cap_new_tokens": _cast_or("auto_cap_new_tokens", bool, True),
        "strict_context": _cast_or("strict_context", bool, False),
        "truncate_mode": truncate_mode,
        "prefer_preserve_prompt": _cast_or("prefer_preserve_prompt", bool, True),
        "decode_truncated_prompt": _cast_or("decode_truncated_prompt", bool, True),
        "destroy_pages_on_pause": _cast_or("destroy_pages_on_pause", bool, True),
        "detokenizer_max_states": detokenizer_max_states,
        "tokenizer_endpoint": es.get("tokenizer_endpoint"),
        "detokenizer_endpoint": es.get("detokenizer_endpoint"),
        "sampling_params_callback": es.get("sampling_params_callback"),
        "extra_eos_token_ids": extra_eos_token_ids,
        "silent_mode": _cast_or("silent_mode", bool, False),
    }
def build_esurge(cfg_like: ELMConfig | Mapping[str, Any], model: EasyDeLBaseModule | None = None):
    """Build an eSurge inference engine from ELM configuration.

    Creates an eSurge instance from a model, a tokenizer loaded with
    ``AutoTokenizer.from_pretrained`` and the keyword arguments produced by
    ``to_esurge_kwargs``.

    Args:
        cfg_like: ELM configuration dictionary or mapping.
        model: Optional already-loaded model. When ``None`` the model is built
            from the configuration with ``build_model``.

    Returns:
        eSurge: Configured eSurge inference engine.

    Raises:
        NotImplementedError: If the task type is not supported by eSurge.

    Example:
        >>> cfg = {
        ...     "model": {"name_or_path": "meta-llama/Llama-2-7b"},
        ...     "esurge": {"max_model_len": 4096, "max_num_seqs": 32}
        ... }
        >>> engine = build_esurge(cfg)
    """
    from transformers import AutoTokenizer

    from easydel.inference import eSurge

    cfg = normalize(cfg_like)
    task = resolve_task(cfg)

    supported_tasks = [TaskType.CAUSAL_LM, TaskType.IMAGE_TEXT_TO_TEXT]
    # VISION_LM is looked up defensively because it may not exist on TaskType.
    # It is only added when present, so a missing member never puts ``None``
    # in the list and lets an unresolved task slip through the guard.
    vision_lm = getattr(TaskType, "VISION_LM", None)
    if vision_lm is not None:
        supported_tasks.append(vision_lm)
    if task not in supported_tasks:
        raise NotImplementedError(f"eSurge supports [CAUSAL_LM, IMAGE_TEXT_TO_TEXT, VISION_LM]; got {task}")

    model_cfg = cfg["model"]
    # ``or`` (not a ``.get`` default) so a ``tokenizer`` key that is present
    # but None/empty still falls back to the model path.
    tok_path = model_cfg.get("tokenizer") or model_cfg["name_or_path"]

    if model is None:
        model = build_model(cfg)

    return eSurge(
        model=model,
        tokenizer=AutoTokenizer.from_pretrained(tok_path),
        **to_esurge_kwargs(cfg),
    )
def to_data_mixture_kwargs(cfg_like: ELMConfig | Mapping[str, Any]) -> dict[str, Any]:
    """Translate the ``mixture`` section of an ELM configuration into DatasetMixture kwargs.

    Each entry of ``mixture.informs`` becomes a ``VisualDatasetInform`` when it
    has a ``pixel_field`` key and a ``TextDatasetInform`` otherwise. Packing and
    mixing options are forwarded only when they are present in the
    configuration.

    Args:
        cfg_like: ELM configuration dictionary or mapping.

    Returns:
        Dictionary of keyword arguments for DatasetMixture initialization, or
        an empty dict when no mixture is configured.

    Example:
        >>> cfg = {
        ...     "mixture": {
        ...         "informs": [
        ...             {"type": "json", "data_files": "train.json", "content_field": "text"}
        ...         ],
        ...         "batch_size": 32,
        ...         "block_mixture": True,
        ...         "pack_tokens": True,
        ...         "pack_seq_length": 2048
        ...     }
        ... }
        >>> kwargs = to_data_mixture_kwargs(cfg)
        >>> mixture = DatasetMixture(**kwargs)
    """
    from easydel.data import TextDatasetInform, VisualDatasetInform

    cfg = normalize(cfg_like)
    mixture_cfg = cfg.get("mixture", {})
    if not mixture_cfg:
        return {}

    informs = []
    for entry in mixture_cfg.get("informs", []):
        # Fields read the same way for text and visual informs.
        shared_fields = dict(
            type=entry.get("type"),
            data_files=entry["data_files"],
            dataset_split_name=entry.get("dataset_split_name", None),
            split=entry.get("split", "train"),
            num_rows=entry.get("num_rows"),
            format_callback=entry.get("format_callback"),
            format_fields=entry.get("format_fields"),
        )
        if "pixel_field" in entry:
            image_size = entry.get("image_size")
            informs.append(
                VisualDatasetInform(
                    pixel_field=entry.get("pixel_field", "images"),
                    content_field=entry.get("content_field"),
                    image_size=tuple(image_size) if image_size else None,
                    **shared_fields,
                )
            )
        else:
            informs.append(
                TextDatasetInform(
                    content_field=entry.get("content_field", "content"),
                    additional_fields=entry.get("additional_fields"),
                    **shared_fields,
                )
            )

    kwargs = {
        "informs": informs,
        "cache_dir": mixture_cfg.get("cache_dir", f"{pathlib.Path.home()}/.cache/easydel"),
        "streaming": mixture_cfg.get("streaming", True),
        "text_target_field": mixture_cfg.get("text_target_field", "text"),
        "image_target_field": mixture_cfg.get("image_target_field", "image"),
        "batch_size": mixture_cfg.get("batch_size", 1),
        "shuffle_buffer_size": mixture_cfg.get("shuffle_buffer_size"),
        "seed": mixture_cfg.get("seed", 42),
    }

    # Forwarded verbatim, and only when the user supplied them, so that
    # DatasetMixture's own defaults apply otherwise.
    passthrough_keys = (
        "pack_tokens",
        "tokens_field_name",
        "pack_seq_length",
        "pack_eos_token_id",
        "pack_shuffle",
        "pack_shuffle_buffer_factor",
        "dask_storage_options",
        "pack_on_the_fly",
        "tokenize_callback",
        "block_mixture",
        "mixture_block_size",
        "stop_strategy",
        "mixture_weights",
    )
    for key in passthrough_keys:
        if key in mixture_cfg:
            kwargs[key] = mixture_cfg[key]

    return kwargs
def build_dataset(cfg_like: ELMConfig | Mapping[str, Any]):
    """Build a dataset from the ``mixture`` section of an ELM configuration.

    The mixture keyword arguments come from ``to_data_mixture_kwargs``; a
    ``DatasetMixture`` is constructed from them and its ``build()`` result is
    returned.

    Args:
        cfg_like: ELM configuration dictionary or mapping.

    Returns:
        The object returned by ``DatasetMixture.build()``, or ``None`` when the
        configuration has no mixture or the mixture lists no informs.

    Example:
        >>> cfg = {
        ...     "mixture": {
        ...         "informs": [
        ...             {"type": "json", "data_files": "data.json", "content_field": "text"}
        ...         ],
        ...         "block_mixture": True,
        ...         "pack_tokens": True,
        ...         "pack_seq_length": 2048
        ...     }
        ... }
        >>> dataset = build_dataset(cfg)
    """
    from easydel.data import DatasetMixture

    cfg = normalize(cfg_like)
    mixture_cfg = cfg.get("mixture", {})

    has_informs = bool(mixture_cfg) and bool(mixture_cfg.get("informs"))
    if not has_informs:
        return None

    return DatasetMixture(**to_data_mixture_kwargs(cfg)).build()
def tokenize_dataset(
    dataset,
    tokenizer,
    text_field: str = "text",
    output_field: str = "tokens",
    max_length: int = 2048,
    truncation: bool = True,
    padding: bool | str = False,
    add_special_tokens: bool = True,
    return_attention_mask: bool = True,
    num_proc: int | None = None,
    batched: bool = True,
    batch_size: int = 1000,
    remove_columns: list[str] | None = None,
    keep_in_memory: bool = False,
):
    """Tokenize a dataset using the provided tokenizer.

    Args:
        dataset: HuggingFace Dataset or IterableDataset to tokenize
        tokenizer: HuggingFace tokenizer instance
        text_field: Field name containing text to tokenize (default: "text")
        output_field: Field name for tokenized output (default: "tokens")
        max_length: Maximum sequence length (default: 2048)
        truncation: Whether to truncate sequences (default: True)
        padding: Padding strategy (default: False)
        add_special_tokens: Add special tokens like BOS/EOS (default: True)
        return_attention_mask: Return attention masks (default: True)
        num_proc: Number of processes for parallel tokenization (default: None).
            Not forwarded for streaming datasets.
        batched: Process examples in batches (default: True)
        batch_size: Batch size for batched processing (default: 1000).
            Not forwarded for streaming datasets.
        remove_columns: Columns to remove after tokenization. When None, all
            of ``dataset.column_names`` are removed if that attribute exists.
        keep_in_memory: Keep processed dataset in memory (default: False).
            Not forwarded for streaming datasets.

    Returns:
        Tokenized dataset with token IDs in the output_field

    Example:
        >>> from transformers import AutoTokenizer
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b")
        >>> tokenized = tokenize_dataset(dataset, tokenizer, text_field="content")
    """
    from easydel.data.utils import is_streaming

    def tokenize_fn(examples):
        texts = examples[text_field]
        # With batched=False the mapper hands over one example, so the field is
        # a bare string. It is wrapped for the tokenizer call and unwrapped
        # again below so the row stores a flat list of ids, not [[ids]].
        single_example = isinstance(texts, str)
        outputs = tokenizer(
            [texts] if single_example else texts,
            max_length=max_length,
            truncation=truncation,
            padding=padding,
            add_special_tokens=add_special_tokens,
            return_attention_mask=return_attention_mask,
        )

        input_ids = outputs["input_ids"]
        result = {output_field: input_ids[0] if single_example else input_ids}
        if return_attention_mask:
            attention_mask = outputs["attention_mask"]
            result["attention_mask"] = attention_mask[0] if single_example else attention_mask
        return result

    # Default to dropping every original column so only tokenizer output remains.
    if remove_columns is None:
        if hasattr(dataset, "column_names"):
            remove_columns = dataset.column_names
        else:
            remove_columns = []

    # Streaming datasets receive only the options forwarded here.
    if is_streaming(dataset):
        return dataset.map(
            tokenize_fn,
            batched=batched,
            remove_columns=remove_columns,
        )

    return dataset.map(
        tokenize_fn,
        batched=batched,
        batch_size=batch_size,
        num_proc=num_proc,
        remove_columns=remove_columns,
        keep_in_memory=keep_in_memory,
    )
[docs]def save_dataset( dataset, output_path: str, format: str = "parquet", # noqa:A002 num_shards: int | None = None, compression: str | None = "snappy", max_shard_size: str | int = "500MB", overwrite: bool = False, push_to_hub: bool = False, hub_repo_id: str | None = None, hub_private: bool = False, hub_token: str | None = None, ): """Save a dataset to disk or HuggingFace Hub. Args: dataset: HuggingFace Dataset to save output_path: Path to save the dataset format: Output format - "parquet", "arrow", "json", "jsonl" (default: "parquet") num_shards: Number of shards (default: None, auto-detect) compression: Compression algorithm (default: "snappy") max_shard_size: Maximum shard size (default: "500MB") overwrite: Whether to overwrite existing files (default: False) push_to_hub: Push to HuggingFace Hub (default: False) hub_repo_id: Hub repository ID (required if push_to_hub=True) hub_private: Make Hub repo private (default: False) hub_token: HuggingFace token (default: None) Returns: Path to saved dataset or Hub URL if pushed Example: >>> save_dataset(tokenized_dataset, "output/tokenized", format="parquet") >>> # Or push to hub >>> save_dataset(tokenized_dataset, "output/tokenized", ... push_to_hub=True, hub_repo_id="username/my-dataset") """ import os from easydel.data.utils import is_streaming # Check if output exists if os.path.exists(output_path) and not overwrite: raise FileExistsError(f"Output path '{output_path}' already exists. 
Set overwrite=True to replace.") # Handle streaming datasets - materialize first if is_streaming(dataset): from datasets import Dataset # Convert iterable dataset to regular dataset dataset = Dataset.from_generator(lambda: (ex for ex in dataset)) # Create output directory os.makedirs(output_path, exist_ok=True) # Save based on format if format == "parquet": dataset.to_parquet( os.path.join(output_path, "data.parquet"), compression=compression, ) elif format == "arrow": dataset.save_to_disk(output_path, num_shards=num_shards, max_shard_size=max_shard_size) elif format in ("json", "jsonl"): dataset.to_json( os.path.join(output_path, "data.jsonl"), lines=(format == "jsonl"), ) else: raise ValueError(f"Unsupported format: {format}. Use 'parquet', 'arrow', 'json', or 'jsonl'.") # Push to Hub if requested if push_to_hub: if not hub_repo_id: raise ValueError("hub_repo_id is required when push_to_hub=True") dataset.push_to_hub( repo_id=hub_repo_id, private=hub_private, token=hub_token, ) return f"https://huggingface.co/datasets/{hub_repo_id}" return output_path
def build_tokenized_dataset(cfg_like: ELMConfig | Mapping[str, Any], save: bool = True):
    """Build, tokenize, and optionally save a dataset from ELM configuration.

    This is the main entry point for the tokenization pipeline. It:
    1. Loads the dataset from the mixture configuration
    2. Tokenizes using the specified tokenizer
    3. Optionally saves to disk or HuggingFace Hub

    The tokenizer path is ``mixture.tokenization.tokenizer`` if set, else
    ``model.tokenizer``, else ``model.name_or_path``.

    Args:
        cfg_like: ELM configuration dictionary or mapping
        save: Whether to save the tokenized dataset (default: True). When
            true, streaming is switched off for the dataset build.

    Returns:
        Tuple of (tokenized_dataset, save_path) when ``save`` is true and the
        ``mixture.save`` section is non-empty; otherwise just the tokenized
        dataset.

    Raises:
        ValueError: If ``mixture.informs`` is missing, no tokenizer can be
            determined, the dataset cannot be built, or saving is configured
            without ``mixture.save.output_path``.

    Example:
        >>> cfg = {
        ...     "model": {"name_or_path": "meta-llama/Llama-2-7b"},
        ...     "mixture": {
        ...         "informs": [
        ...             {"type": "json", "data_files": "data.json", "content_field": "text"}
        ...         ],
        ...         "streaming": False,  # Must be False for saving
        ...         "tokenization": {
        ...             "max_length": 2048,
        ...             "text_field": "text",
        ...             "output_field": "tokens",
        ...             "num_proc": 4
        ...         },
        ...         "save": {
        ...             "output_path": "tokenized_data",
        ...             "format": "parquet"
        ...         }
        ...     }
        ... }
        >>> dataset, path = build_tokenized_dataset(cfg)
    """
    from transformers import AutoTokenizer

    cfg = normalize(cfg_like)
    mixture_cfg = cfg.get("mixture", {})

    if not mixture_cfg or not mixture_cfg.get("informs"):
        raise ValueError("mixture.informs is required for tokenization")

    tok_cfg = mixture_cfg.get("tokenization", {})
    save_cfg = mixture_cfg.get("save", {})

    tokenizer_path = tok_cfg.get("tokenizer")
    if tokenizer_path is None:
        model_cfg = cfg.get("model", {})
        # ``or`` (not a ``.get`` default) so a ``tokenizer`` key that is
        # present but None still falls back to the model path.
        tokenizer_path = model_cfg.get("tokenizer") or model_cfg.get("name_or_path")

    if not tokenizer_path:
        raise ValueError("Tokenizer not specified. Set mixture.tokenization.tokenizer or model.name_or_path")

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    if save and mixture_cfg.get("streaming", True):
        # Build from a shallow copy with streaming disabled so the caller's
        # configuration is left untouched.
        cfg_copy = dict(cfg)
        cfg_copy["mixture"] = dict(mixture_cfg)
        cfg_copy["mixture"]["streaming"] = False
        dataset = build_dataset(cfg_copy)
    else:
        dataset = build_dataset(cfg)

    if dataset is None:
        raise ValueError("Failed to build dataset from configuration")

    tokenized = tokenize_dataset(
        dataset,
        tokenizer,
        text_field=tok_cfg.get("text_field", mixture_cfg.get("text_target_field", "text")),
        output_field=tok_cfg.get("output_field", "tokens"),
        max_length=tok_cfg.get("max_length", 2048),
        truncation=tok_cfg.get("truncation", True),
        padding=tok_cfg.get("padding", False),
        add_special_tokens=tok_cfg.get("add_special_tokens", True),
        return_attention_mask=tok_cfg.get("return_attention_mask", True),
        num_proc=tok_cfg.get("num_proc"),
        batched=tok_cfg.get("batched", True),
        batch_size=tok_cfg.get("batch_size", 1000),
        remove_columns=tok_cfg.get("remove_columns"),
        keep_in_memory=tok_cfg.get("keep_in_memory", False),
    )

    # Saving needs both the flag and a non-empty ``mixture.save`` section.
    if not (save and save_cfg):
        return tokenized

    output_path = save_cfg.get("output_path")
    if not output_path:
        raise ValueError("mixture.save.output_path is required when save=True")

    saved_path = save_dataset(
        tokenized,
        output_path=output_path,
        format=save_cfg.get("format", "parquet"),
        num_shards=save_cfg.get("num_shards"),
        compression=save_cfg.get("compression", "snappy"),
        max_shard_size=save_cfg.get("max_shard_size", "500MB"),
        overwrite=save_cfg.get("overwrite", False),
        push_to_hub=save_cfg.get("push_to_hub", False),
        hub_repo_id=save_cfg.get("hub_repo_id"),
        hub_private=save_cfg.get("hub_private", False),
        hub_token=save_cfg.get("hub_token"),
    )
    return tokenized, saved_path
def _extract_dataset_name(inform_cfg: dict, fallback_index: int = 0) -> str: """Extract a meaningful dataset name from inform configuration. Uses the following priority: 1. Explicit 'name' field if provided 2. HuggingFace repo name from 'data_files' (e.g., "LDJnr/Puffin" -> "Puffin") 3. GCS/S3/cloud path - extract bucket or meaningful path segment 4. File/directory name from 'data_files' path 5. Fallback to "dataset_{index}" Args: inform_cfg: Dataset inform configuration dictionary. fallback_index: Index to use in fallback name. Returns: Extracted dataset name string. Examples: - "LDJnr/Puffin" -> "Puffin" - "gs://my-bucket/datasets/alpaca/*.parquet" -> "alpaca" - "s3://bucket/data/train.json" -> "train" - "/local/path/to/data.parquet" -> "data" - "hf://datasets/tatsu-lab/alpaca" -> "alpaca" """ import os import re # Check for explicit name if inform_cfg.get("name"): return inform_cfg["name"] data_files = inform_cfg.get("data_files", "") if not isinstance(data_files, str): # Handle list of files - use first file if isinstance(data_files, list) and data_files: data_files = data_files[0] else: return f"dataset_{fallback_index}" # Strip whitespace data_files = data_files.strip() # Handle cloud storage paths (gs://, s3://, az://, hf://) cloud_match = re.match(r"^(gs|s3|az|gcs|https?|hf)://(.+)$", data_files, re.IGNORECASE) if cloud_match: path_part = cloud_match.group(2) # Remove bucket name for gs/s3/az, keep meaningful path # e.g., "my-bucket/datasets/alpaca/*.parquet" -> extract "alpaca" # e.g., "datasets/tatsu-lab/alpaca" -> extract "alpaca" parts = [p for p in path_part.split("/") if p and not p.startswith("*")] # Walk backwards to find a meaningful name (skip bucket, globs, extensions) for part in reversed(parts): # Skip if it's just a file extension pattern if part.startswith("*."): continue # Clean glob patterns and extensions clean = part.rstrip("*").rstrip("/") name, _ = os.path.splitext(clean) # Skip if empty or looks like a bucket name (too generic) if name 
and name not in ("data", "train", "test", "val", "dataset", "datasets"): return name if name: # Use it even if generic, but keep looking candidate = name # Use the last candidate found if "candidate" in dir() and candidate: return candidate # Fallback to last non-empty part if parts: name, _ = os.path.splitext(parts[-1].rstrip("*")) if name: return name if "/" in data_files and not data_files.startswith(("/", ".", "~")): parts = data_files.split("/") if len(parts) == 2 and not any( p.endswith( ( ".json", ".parquet", ".arrow", ".csv", ".txt", ".jsonl", ) ) for p in parts ): return parts[-1] clean_path = data_files.rstrip("*").rstrip("/") base_name = os.path.basename(clean_path) if base_name: # Remove extension if present name, _ = os.path.splitext(base_name) if name: return name # Fallback return f"dataset_{fallback_index}" def _create_source_from_inform( inform_cfg: dict, mixture_cfg: dict, ) -> "ShardedDataSource": """Create a ShardedDataSource from an inform configuration. Maps the inform config to the appropriate ShardedDataSource type based on the dataset type (JSON, Parquet, Arrow, CSV, HF, etc.). Args: inform_cfg: Dataset inform configuration dictionary. mixture_cfg: Parent mixture configuration dictionary. Returns: ShardedDataSource: The created data source. Raises: ValueError: If the dataset type is not supported. 
""" from easydel.data.sources import ( ArrowShardedSource, CsvShardedSource, HuggingFaceShardedSource, JsonShardedSource, ParquetShardedSource, TextShardedSource, expand_data_files, ) data_files = inform_cfg.get("data_files") source_type = inform_cfg.get("type", "").lower() split = inform_cfg.get("split", "train") dataset_split_name = inform_cfg.get("dataset_split_name") # Check if it's a HuggingFace dataset if source_type in ("huggingface", "hf"): return HuggingFaceShardedSource( dataset_name=data_files, split=split, subset=dataset_split_name, ) # Expand files try: files = expand_data_files(data_files) except FileNotFoundError: # Might be a HuggingFace dataset if type wasn't specified if isinstance(data_files, str) and not source_type: return HuggingFaceShardedSource( dataset_name=data_files, split=split, subset=dataset_split_name, ) raise if not files: raise ValueError(f"No files found for pattern: {data_files}") # Infer type from first file if not specified if not source_type: first_file = files[0] if first_file.endswith((".json", ".jsonl", ".json.gz", ".jsonl.gz")): source_type = "json" elif first_file.endswith(".parquet"): source_type = "parquet" elif first_file.endswith(".arrow"): source_type = "arrow" elif first_file.endswith((".csv", ".tsv")): source_type = "csv" elif first_file.endswith(".txt"): source_type = "txt" # Create appropriate source if source_type in ("json", "jsonl"): return JsonShardedSource(files) elif source_type == "parquet": return ParquetShardedSource(files) elif source_type == "arrow": return ArrowShardedSource(files) elif source_type in ("csv", "tsv"): return CsvShardedSource(files) elif source_type == "txt": return TextShardedSource(files) else: raise ValueError(f"Unsupported dataset type: {source_type}")
def build_sharded_source(cfg_like: ELMConfig | Mapping[str, Any]) -> "ShardedDataSource | None":
    """Build a ShardedDataSource from ELM configuration.

    Creates one source per entry of ``mixture.informs`` and then optionally
    applies:
    - a ``format_callback`` via MapTransform
    - field renaming via RenameFields
    - dataset mixing via MixedShardedSource (when there is more than one source)
    - sequence packing via PackedShardedSource (when ``pack_tokens`` is truthy)

    Source names come from ``_extract_dataset_name``. When two informs resolve
    to the same name, later ones get a numeric suffix so that no source is
    dropped.

    Args:
        cfg_like: ELM configuration dictionary or mapping containing a
            'mixture' section with dataset configurations.

    Returns:
        ShardedDataSource if mixture is configured, None otherwise.

    Example:
        >>> cfg = {
        ...     "mixture": {
        ...         "informs": [
        ...             {"type": "json", "data_files": "data.json", "content_field": "text"}
        ...         ],
        ...         "use_sharded_source": True,
        ...         "pack_tokens": True,
        ...         "pack_seq_length": 2048
        ...     }
        ... }
        >>> source = build_sharded_source(cfg)
        >>> for batch in source.open_shard(source.shard_names[0]):
        ...     process(batch)
    """
    from easydel.data.transforms import (
        MapTransform,
        MixedShardedSource,
        RenameFields,
    )
    from easydel.data.transforms.pack import PackedShardedSource

    cfg = normalize(cfg_like)
    mixture_cfg = cfg.get("mixture", {})

    if not mixture_cfg or not mixture_cfg.get("informs"):
        return None

    sources: dict[str, "ShardedDataSource"] = {}
    content_target = mixture_cfg.get("text_target_field", "text")

    for i, inform_cfg in enumerate(mixture_cfg.get("informs", [])):
        base_name = _extract_dataset_name(inform_cfg, fallback_index=i)
        # Two informs can resolve to the same name (e.g. both ".../train.json").
        # Disambiguate instead of overwriting the earlier source.
        name = base_name
        suffix = i
        while name in sources:
            name = f"{base_name}_{suffix}"
            suffix += 1

        source = _create_source_from_inform(inform_cfg, mixture_cfg)

        # Custom per-example transformation, applied before any renaming.
        format_callback = inform_cfg.get("format_callback")
        if format_callback is not None:
            source = source.transform(MapTransform(format_callback))

        if inform_cfg.get("format_fields"):
            source = source.transform(RenameFields(inform_cfg["format_fields"]))

        # Rename this inform's content field to the shared target field.
        content_field = inform_cfg.get("content_field", "content")
        if content_field != content_target:
            source = source.transform(RenameFields({content_field: content_target}))

        sources[name] = source

    if len(sources) > 1:
        weights = mixture_cfg.get("mixture_weights")
        # Positional weights are matched to sources in inform order.
        if isinstance(weights, (list, tuple)) and len(weights) == len(sources):
            weights = dict(zip(sources, weights))

        source = MixedShardedSource(
            sources=sources,
            weights=weights,
            block_size=mixture_cfg.get("mixture_block_size", 2048),
            seed=mixture_cfg.get("seed", 42),
            stop_strategy=mixture_cfg.get("stop_strategy", "restart"),
        )
    else:
        source = next(iter(sources.values()))

    if mixture_cfg.get("pack_tokens"):
        source = PackedShardedSource(
            source=source,
            seq_length=mixture_cfg.get("pack_seq_length", 2048),
            eos_token_id=mixture_cfg.get("pack_eos_token_id", 0),
            # NOTE(review): padding reuses ``pack_eos_token_id``; no separate
            # pad-token key is read here. Confirm this is intended.
            pad_token_id=mixture_cfg.get("pack_eos_token_id", 0),
            strategy="greedy",
            input_field=mixture_cfg.get("tokens_field_name", "input_ids"),
            shuffle=mixture_cfg.get("pack_shuffle", True),
            shuffle_buffer_factor=mixture_cfg.get("pack_shuffle_buffer_factor", 16),
            seed=mixture_cfg.get("seed", 42),
        )

    return source