Dataset Mixing#

EasyData provides powerful dataset mixing capabilities with static weights, dynamic scheduling, and block-based deterministic mixing.

Quick Start#

Simple Mixing with HuggingFace Datasets#

from datasets import load_dataset
from easydel.data import block_mixture_interleave

# Load datasets
code_ds = load_dataset("bigcode/starcoderdata", split="train", streaming=True)
text_ds = load_dataset("HuggingFaceFW/fineweb", split="train", streaming=True)
math_ds = load_dataset("hendrycks/competition_math", split="train")

# Mix with weights using dict (recommended - explicit mapping)
mixed = block_mixture_interleave(
    datasets={"code": code_ds, "text": text_ds, "math": math_ds},
    weights={"code": 0.4, "text": 0.5, "math": 0.1},
    block_size=1000,
    seed=42,
    stop="restart",
)

# Use with trainer
trainer = ed.SFTTrainer(train_dataset=mixed, ...)

block_mixture_interleave#

The simplest way to mix HuggingFace datasets.

from easydel.data import block_mixture_interleave

# Dict format with explicit name-to-dataset mapping
mixed = block_mixture_interleave(
    datasets={"code": ds1, "text": ds2, "math": ds3},
    weights={"code": 0.5, "text": 0.3, "math": 0.2},
    block_size=1000,
    seed=42,
    stop="restart",
)

# Equal weights
mixed = block_mixture_interleave(
    datasets={"code": ds1, "text": ds2},
    weights=None,  # Equal 50/50
)

Parameters:

Parameter

Type

Description

datasets

dict

Dict mapping names to datasets: {"code": code_ds, "text": text_ds}

weights

dict or None

Dict mapping names to weights (keys must match the keys of datasets); None = equal weights

block_size

int

Number of examples per mixing block

seed

int

Random seed for shuffling within blocks

stop

str

"restart" to loop exhausted datasets, "first_exhausted" to stop when any dataset is exhausted

How it works:

  1. Divides training into blocks of block_size examples

  2. Within each block, samples according to weights

  3. Shuffles within block for variety

  4. Uses deterministic RNG per block for reproducibility

MixedShardedSource#

Use MixedShardedSource to mix ShardedDataSource objects when you need more control.

from easydel.data import MixedShardedSource, HFDatasetShardedSource

# Wrap datasets as ShardedDataSource
source1 = HFDatasetShardedSource(dataset1)
source2 = HFDatasetShardedSource(dataset2)

# Mix with static weights
mixed = MixedShardedSource(
    sources={"code": source1, "text": source2},
    weights={"code": 0.3, "text": 0.7},
    block_size=1000,
    seed=42,
    stop_strategy="restart",
)

# Iterate
for example in mixed.open_shard(mixed.shard_names[0]):
    print(example["__source__"])  # Shows which dataset

Parameters:

Parameter

Type

Default

Description

sources

dict

Required

Name to ShardedDataSource mapping

weights

dict

None

Static weights (None = equal)

block_size

int

1000

Examples per mixing block

seed

int

None

Random seed

stop_strategy

str

"restart"

"restart", "first_exhausted", "all_exhausted"

weight_scheduler

WeightScheduler

None

Dynamic weight scheduler

Dynamic Weight Scheduling#

Change dataset weights during training:

from easydel.data import (
    MixedShardedSource,
    HFDatasetShardedSource,
    WeightScheduler,
    WeightSchedulePoint,
)

# Define schedule
scheduler = WeightScheduler(
    schedule=[
        WeightSchedulePoint(step=0, weights={"code": 0.2, "text": 0.8}),
        WeightSchedulePoint(step=10000, weights={"code": 0.5, "text": 0.5}),
        WeightSchedulePoint(step=50000, weights={"code": 0.8, "text": 0.2}),
    ],
    interpolation="linear",  # "step", "linear", or "cosine"
)

# Create sources
code_source = HFDatasetShardedSource(code_dataset)
text_source = HFDatasetShardedSource(text_dataset)

# Mix with scheduler
mixed = MixedShardedSource(
    sources={"code": code_source, "text": text_source},
    weight_scheduler=scheduler,
    block_size=1000,
)

Interpolation Types#

Type

Description

"step"

Jump to new weights at each schedule point

"linear"

Linearly interpolate between schedule points

"cosine"

Smooth cosine annealing between points

Visualizing Schedule#

scheduler = WeightScheduler(
    schedule=[
        WeightSchedulePoint(step=0, weights={"code": 0.2, "text": 0.8}),
        WeightSchedulePoint(step=10000, weights={"code": 0.5, "text": 0.5}),
        WeightSchedulePoint(step=50000, weights={"code": 0.8, "text": 0.2}),
    ],
    interpolation="linear",
)

# Check weights at any step
print(scheduler.get_weights(0))      # {"code": 0.2, "text": 0.8}
print(scheduler.get_weights(5000))   # {"code": 0.35, "text": 0.65}
print(scheduler.get_weights(10000))  # {"code": 0.5, "text": 0.5}
print(scheduler.get_weights(30000))  # {"code": 0.65, "text": 0.35}

Stop Strategies#

Strategy

Behavior

"restart"

Loop exhausted datasets (infinite iteration)

"first_exhausted"

Stop when any dataset is exhausted

"all_exhausted"

Stop when all datasets are exhausted

# Infinite training - datasets loop
mixed = MixedShardedSource(sources, stop_strategy="restart")

# Stop at smallest dataset
mixed = MixedShardedSource(sources, stop_strategy="first_exhausted")

# Train until all data seen at least once
mixed = MixedShardedSource(sources, stop_strategy="all_exhausted")

Pipeline API Mixing#

Using the Pipeline fluent API:

from easydel.data import (
    Pipeline,
    PipelineConfig,
    DatasetConfig,
    MixStageConfig,
)

config = PipelineConfig(
    datasets=[
        DatasetConfig(
            name="code",
            data_files="code_data/*.parquet",
            tokenizer="meta-llama/Llama-2-7b",
        ),
        DatasetConfig(
            name="text",
            data_files="text_data/*.parquet",
            tokenizer="meta-llama/Llama-2-7b",
        ),
    ],
    mix=MixStageConfig(
        weights={"code": 0.3, "text": 0.7},
        block_size=1000,
        stop_strategy="restart",
    ),
)

pipeline = Pipeline.from_config(config)
for batch in pipeline.source().tokenize().mix().load().build():
    train_step(batch)

MixStageConfig#

from easydel.data import MixStageConfig, WeightSchedulePoint

config = MixStageConfig(
    weights={"code": 0.3, "text": 0.7},  # Static weights
    weight_schedule=[  # Or dynamic schedule
        WeightSchedulePoint(step=0, weights={"code": 0.2, "text": 0.8}),
        WeightSchedulePoint(step=10000, weights={"code": 0.5, "text": 0.5}),
    ],
    weight_schedule_type="linear",  # "step", "linear", "cosine"
    block_size=1000,
    stop_strategy="restart",
    seed=42,
)

CompositeShardedSource#

For combining sources without weighted mixing (simple concatenation):

from easydel.data import CompositeShardedSource, ParquetShardedSource

source1 = ParquetShardedSource("data1/*.parquet")
source2 = ParquetShardedSource("data2/*.parquet")

# Simple concatenation (no mixing)
combined = CompositeShardedSource([source1, source2])

# All shards from both sources
for shard in combined.shard_names:
    for example in combined.open_shard(shard):
        process(example)

Legacy DatasetMixture API#

For backward compatibility:

from easydel.data import DatasetMixture, TextDatasetInform, build_dataset

mixture = DatasetMixture(
    informs=[
        TextDatasetInform(
            type="parquet",
            data_files="data1/*.parquet",
            content_field="text",
        ),
        TextDatasetInform(
            type="json",
            data_files="data2/*.json",
            content_field="content",
        ),
    ],
    mixture_weights={"ds1": 0.7, "ds2": 0.3},
    block_mixture=True,
    mixture_block_size=2048,
    streaming=True,
    seed=42,
)

dataset = build_dataset(mixture)

Best Practices#

1. Choose Block Size Carefully#

# Smaller blocks = more interleaving but more overhead
mixed = block_mixture_interleave(..., block_size=100)

# Larger blocks = less overhead but more bursty
mixed = block_mixture_interleave(..., block_size=10000)

# Recommended: 1000-2000 for good balance
mixed = block_mixture_interleave(..., block_size=1000)

2. Use Dynamic Scheduling for Curriculum#

# Start with easier data, gradually increase difficulty
scheduler = WeightScheduler(
    schedule=[
        WeightSchedulePoint(step=0, weights={"easy": 0.9, "hard": 0.1}),
        WeightSchedulePoint(step=50000, weights={"easy": 0.5, "hard": 0.5}),
        WeightSchedulePoint(step=100000, weights={"easy": 0.1, "hard": 0.9}),
    ],
    interpolation="cosine",  # Smooth transition
)

3. Handle Imbalanced Datasets#

# Small dataset with important data
scheduler = WeightScheduler(
    schedule=[
        # High weight initially to ensure exposure
        WeightSchedulePoint(step=0, weights={"small_important": 0.5, "large": 0.5}),
        # Reduce after sufficient exposure
        WeightSchedulePoint(step=10000, weights={"small_important": 0.1, "large": 0.9}),
    ],
)

4. Track Data Source#

mixed = MixedShardedSource(sources, ...)

for example in mixed.open_shard(mixed.shard_names[0]):
    source_name = example.get("__source__")  # Added automatically
    if source_name == "code":
        # Apply code-specific processing
        pass

Reproducibility#

Block-based mixing is deterministic given the same seed:

# Same seed = same sequence
mixed1 = block_mixture_interleave(datasets, seed=42)
mixed2 = block_mixture_interleave(datasets, seed=42)

# Different seeds = different sequences
mixed3 = block_mixture_interleave(datasets, seed=123)

For distributed training, ensure workers use consistent seeds:

# All workers use same seed for consistent global ordering
mixed = MixedShardedSource(
    sources=sources,
    seed=42,  # Same across all workers
)

Next Steps#