# Dataset Mixing

EasyData provides powerful dataset mixing capabilities: static weights, dynamic weight scheduling, and deterministic block-based mixing.
## Quick Start

### Simple Mixing with HuggingFace Datasets
from datasets import load_dataset
from easydel.data import block_mixture_interleave
# Load datasets
code_ds = load_dataset("bigcode/starcoderdata", split="train", streaming=True)
text_ds = load_dataset("HuggingFaceFW/fineweb", split="train", streaming=True)
math_ds = load_dataset("hendrycks/competition_math", split="train")
# Mix with weights using dict (recommended - explicit mapping)
mixed = block_mixture_interleave(
datasets={"code": code_ds, "text": text_ds, "math": math_ds},
weights={"code": 0.4, "text": 0.5, "math": 0.1},
block_size=1000,
seed=42,
stop="restart",
)
# Use with trainer
trainer = ed.SFTTrainer(train_dataset=mixed, ...)
## `block_mixture_interleave`

The simplest way to mix HuggingFace datasets.
from easydel.data import block_mixture_interleave
# Dict format with explicit name-to-dataset mapping
mixed = block_mixture_interleave(
datasets={"code": ds1, "text": ds2, "math": ds3},
weights={"code": 0.5, "text": 0.3, "math": 0.2},
block_size=1000,
seed=42,
stop="restart",
)
# Equal weights
mixed = block_mixture_interleave(
datasets={"code": ds1, "text": ds2},
weights=None, # Equal 50/50
)
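Parameters:

| Parameter | Type | Description |
|---|---|---|
| `datasets` | dict | Dict mapping names to datasets, e.g. `{"code": ds1, "text": ds2}` |
| `weights` | dict or None | Dict mapping names to weights (keys must match `datasets`); `None` = equal weights |
| `block_size` | int | Number of examples per mixing block |
| `seed` | int | Random seed for shuffling within blocks |
| `stop` | str | What to do when a dataset is exhausted, e.g. `"restart"` (see Stop Strategies) |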
How it works:

1. Divides training into blocks of `block_size` examples.
2. Within each block, samples from the datasets according to `weights`.
3. Shuffles within each block for variety.
4. Uses a deterministic RNG per block for reproducibility.
## `MixedShardedSource`

For mixing `ShardedDataSource` objects with more control.
from easydel.data import MixedShardedSource, HFDatasetShardedSource
# Wrap datasets as ShardedDataSource
source1 = HFDatasetShardedSource(dataset1)
source2 = HFDatasetShardedSource(dataset2)
# Mix with static weights
mixed = MixedShardedSource(
sources={"code": source1, "text": source2},
weights={"code": 0.3, "text": 0.7},
block_size=1000,
seed=42,
stop_strategy="restart",
)
# Iterate
for example in mixed.open_shard(mixed.shard_names[0]):
print(example["__source__"]) # Shows which dataset
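Parameters:

| Parameter | Type | Default | Description |
|---|---|---|---|
| `sources` | dict | Required | Name to `ShardedDataSource` mapping |
| `weights` | dict | `None` | Static weights (`None` = equal) |
| `block_size` | int | `1000` | Examples per mixing block |
| `seed` | int | `None` | Random seed |
| `stop_strategy` | str | `"restart"` | What to do when a source is exhausted (see Stop Strategies) |
| `weight_scheduler` | WeightScheduler | `None` | Dynamic weight scheduler |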
## Dynamic Weight Scheduling

Change dataset weights during training:
from easydel.data import (
MixedShardedSource,
HFDatasetShardedSource,
WeightScheduler,
WeightSchedulePoint,
)
# Define schedule
scheduler = WeightScheduler(
schedule=[
WeightSchedulePoint(step=0, weights={"code": 0.2, "text": 0.8}),
WeightSchedulePoint(step=10000, weights={"code": 0.5, "text": 0.5}),
WeightSchedulePoint(step=50000, weights={"code": 0.8, "text": 0.2}),
],
interpolation="linear", # "step", "linear", or "cosine"
)
# Create sources
code_source = HFDatasetShardedSource(code_dataset)
text_source = HFDatasetShardedSource(text_dataset)
# Mix with scheduler
mixed = MixedShardedSource(
sources={"code": code_source, "text": text_source},
weight_scheduler=scheduler,
block_size=1000,
)
### Interpolation Types

| Type | Description |
|---|---|
| `"step"` | Jump to the new weights at each schedule point |
| `"linear"` | Linearly interpolate between schedule points |
| `"cosine"` | Smooth cosine annealing between schedule points |
### Visualizing the Schedule
scheduler = WeightScheduler(
schedule=[
WeightSchedulePoint(step=0, weights={"code": 0.2, "text": 0.8}),
WeightSchedulePoint(step=10000, weights={"code": 0.5, "text": 0.5}),
WeightSchedulePoint(step=50000, weights={"code": 0.8, "text": 0.2}),
],
interpolation="linear",
)
# Check weights at any step
print(scheduler.get_weights(0)) # {"code": 0.2, "text": 0.8}
print(scheduler.get_weights(5000)) # {"code": 0.35, "text": 0.65}
print(scheduler.get_weights(10000)) # {"code": 0.5, "text": 0.5}
print(scheduler.get_weights(30000)) # {"code": 0.65, "text": 0.35}
## Stop Strategies

| Strategy | Behavior |
|---|---|
| `"restart"` | Loop exhausted datasets (infinite iteration) |
| `"first_exhausted"` | Stop when any dataset is exhausted |
| `"all_exhausted"` | Stop when all datasets are exhausted |
# Infinite training - datasets loop
mixed = MixedShardedSource(sources, stop_strategy="restart")
# Stop at smallest dataset
mixed = MixedShardedSource(sources, stop_strategy="first_exhausted")
# Train until all data seen at least once
mixed = MixedShardedSource(sources, stop_strategy="all_exhausted")
## Pipeline API Mixing

Using the fluent `Pipeline` API:
from easydel.data import (
Pipeline,
PipelineConfig,
DatasetConfig,
MixStageConfig,
)
config = PipelineConfig(
datasets=[
DatasetConfig(
name="code",
data_files="code_data/*.parquet",
tokenizer="meta-llama/Llama-2-7b",
),
DatasetConfig(
name="text",
data_files="text_data/*.parquet",
tokenizer="meta-llama/Llama-2-7b",
),
],
mix=MixStageConfig(
weights={"code": 0.3, "text": 0.7},
block_size=1000,
stop_strategy="restart",
),
)
pipeline = Pipeline.from_config(config)
for batch in pipeline.source().tokenize().mix().load().build():
train_step(batch)
### `MixStageConfig`
from easydel.data import MixStageConfig, WeightSchedulePoint
config = MixStageConfig(
weights={"code": 0.3, "text": 0.7}, # Static weights
weight_schedule=[ # Or dynamic schedule
WeightSchedulePoint(step=0, weights={"code": 0.2, "text": 0.8}),
WeightSchedulePoint(step=10000, weights={"code": 0.5, "text": 0.5}),
],
weight_schedule_type="linear", # "step", "linear", "cosine"
block_size=1000,
stop_strategy="restart",
seed=42,
)
## `CompositeShardedSource`

For combining sources without weights (simple concatenation):
from easydel.data import CompositeShardedSource, ParquetShardedSource
source1 = ParquetShardedSource("data1/*.parquet")
source2 = ParquetShardedSource("data2/*.parquet")
# Simple concatenation (no mixing)
combined = CompositeShardedSource([source1, source2])
# All shards from both sources
for shard in combined.shard_names:
for example in combined.open_shard(shard):
process(example)
## Legacy `DatasetMixture` API

Kept for backward compatibility:
from easydel.data import DatasetMixture, TextDatasetInform, build_dataset
mixture = DatasetMixture(
informs=[
TextDatasetInform(
type="parquet",
data_files="data1/*.parquet",
content_field="text",
),
TextDatasetInform(
type="json",
data_files="data2/*.json",
content_field="content",
),
],
mixture_weights={"ds1": 0.7, "ds2": 0.3},
block_mixture=True,
mixture_block_size=2048,
streaming=True,
seed=42,
)
dataset = build_dataset(mixture)
## Best Practices

### 1. Choose Block Size Carefully
# Smaller blocks = more interleaving but more overhead
mixed = block_mixture_interleave(..., block_size=100)
# Larger blocks = less overhead but more bursty
mixed = block_mixture_interleave(..., block_size=10000)
# Recommended: 1000-2000 for good balance
mixed = block_mixture_interleave(..., block_size=1000)
### 2. Use Dynamic Scheduling for Curriculum Learning
# Start with easier data, gradually increase difficulty
scheduler = WeightScheduler(
schedule=[
WeightSchedulePoint(step=0, weights={"easy": 0.9, "hard": 0.1}),
WeightSchedulePoint(step=50000, weights={"easy": 0.5, "hard": 0.5}),
WeightSchedulePoint(step=100000, weights={"easy": 0.1, "hard": 0.9}),
],
interpolation="cosine", # Smooth transition
)
### 3. Handle Imbalanced Datasets
# Small dataset with important data
scheduler = WeightScheduler(
schedule=[
# High weight initially to ensure exposure
WeightSchedulePoint(step=0, weights={"small_important": 0.5, "large": 0.5}),
# Reduce after sufficient exposure
WeightSchedulePoint(step=10000, weights={"small_important": 0.1, "large": 0.9}),
],
)
### 4. Track the Data Source
mixed = MixedShardedSource(sources, ...)
for example in mixed.open_shard(mixed.shard_names[0]):
source_name = example.get("__source__") # Added automatically
if source_name == "code":
# Apply code-specific processing
pass
## Reproducibility

Block-based mixing is deterministic given the same seed:
# Same seed = same sequence
mixed1 = block_mixture_interleave(datasets, seed=42)
mixed2 = block_mixture_interleave(datasets, seed=42)
# Different seeds = different sequences
mixed3 = block_mixture_interleave(datasets, seed=123)
For distributed training, make sure all workers use the same seed:
# All workers use same seed for consistent global ordering
mixed = MixedShardedSource(
sources=sources,
seed=42, # Same across all workers
)
## Next Steps

- Pre-tokenization - Tokenize before mixing
- Pipeline API - Full pipeline with mixing
- Streaming - Stream mixed datasets from cloud storage