# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from eformer.loggings import get_logger
from easydel.infra.utils import ProcessingClassType
from ..esurge import eSurge
logger = get_logger("eSurgeLMEvalAdapter")
try:
from lm_eval.api.model import LM # type:ignore
except Exception as e:
LM = object
logger.warning(
f"consider installing lm_eval if you want to use `eSurgeLMEvalAdapter` (err : {e}).",
stacklevel=1,
)
[docs]class eSurgeLMEvalAdapter(LM):
"""Adapter for EasyDeL models to be compatible with lm-evaluation-harness.
This class inherits from lm_eval.api.model.LM to ensure compatibility with the harness,
allowing EasyDeL models to be evaluated using the lm-evaluation-harness framework.
It wraps an `eSurge` instance for efficient inference with advanced features like
smart bytecode decoding and context management.
"""
def __init__(
self,
surge: eSurge,
processor: ProcessingClassType,
max_length: int = 8192,
max_new_tokens: int = 2048,
top_p: float = 0.95,
temperature: float = 0.0,
batch_size: int | None = None,
):
"""Initializes the eSurgeLMEvalAdapter.
Args:
surge: An instance of `eSurge` for model inference.
processor: The tokenizer/processor associated with the model.
max_length: The maximum context length for the model. Defaults to 8192.
max_new_tokens: Maximum number of tokens to generate. Defaults to 2048.
top_p: Top-p sampling parameter. Defaults to 0.95.
temperature: Sampling temperature. Defaults to 0.0 (greedy).
batch_size: Optional batch size override. If None, uses surge's max_num_seqs.
"""
super().__init__()
self.max_length = max_length
self.tokenizer = processor
self.tokenizer.padding_side = "left"
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
self.temperature = temperature
self.max_new_tokens = max_new_tokens
self.top_p = top_p
self.surge = surge
self._batch_size = batch_size or surge.max_num_seqs
self.model = None
self.setup_complete = False
self._setup()
def _setup(self):
"""Set up the eSurge engine.
Ensures the eSurge scheduler is running; it should auto-init at construction,
but we guard against any cases where it is not.
"""
if self.setup_complete:
return
# eSurge automatically starts scheduler in __init__
# Just verify it's running
if not self.surge._scheduler_running:
self.surge.initiate()
self.setup_complete = True
[docs] def stop(self):
"""Stop the eSurge engine.
Terminates the underlying `eSurge` scheduler thread.
"""
if self.surge:
self.surge.terminate()
def _generate(
self,
prompts: list[str],
max_tokens: int | None = None,
temperature: float | None = None,
top_p: float | None = None,
stop_sequences: list[list[str]] | None = None,
) -> list[str]:
"""Generate responses for a list of prompts.
Args:
prompts: List of prompts to generate responses for.
max_tokens: Maximum number of tokens to generate per prompt.
temperature: Sampling temperature.
top_p: Top-p sampling parameter.
stop_sequences: List of lists of stop sequences, one list per prompt.
Generation stops if any sequence in the corresponding list is encountered.
Returns:
List of generated responses.
"""
if not self.setup_complete:
self._setup()
import easydel as ed
if max_tokens is None:
max_tokens = self.max_gen_toks
if top_p is None:
top_p = self.top_p
if temperature is None:
temperature = self.temperature
# Create sampling params for each prompt
sampling_params_list = []
for i, _ in enumerate(prompts):
current_stop_seq = None
if stop_sequences and i < len(stop_sequences):
current_stop_seq = stop_sequences[i]
sampling_params_list.append(
ed.SamplingParams(
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stop=current_stop_seq if current_stop_seq else [],
n=1,
)
)
# Use eSurge's generate method
results = self.surge.generate(
prompts=prompts,
sampling_params=sampling_params_list[0]
if len(set(str(sp) for sp in sampling_params_list)) == 1
else sampling_params_list,
use_tqdm=True,
)
generated_texts = []
for i, result in enumerate(results):
text = result.get_text()
# Apply stop sequences if present
if stop_sequences and i < len(stop_sequences):
for stop in stop_sequences[i]:
if stop in text:
text = text[: text.find(stop)]
generated_texts.append(text)
assert len(generated_texts) == len(prompts), (
f"Mismatch between prompts sent ({len(prompts)}) and results received "
f"({len(generated_texts)}) from surge.generate!."
)
return generated_texts
def _extract_choice_from_generation(self, generation: str) -> str:
"""Extract a multiple-choice answer (A, B, C, D) from generated text.
Args:
generation: The generated text string.
Returns:
The extracted choice (e.g., "A", "B", "C", "D") or an empty string if no clear choice is found.
"""
import re
patterns = [
r"^([A-Da-d])[^A-Za-z0-9]",
r"^([A-Da-d])$",
r"[Aa]nswer[^A-Za-z0-9]*([A-Da-d])",
r"[Tt]he answer is[^A-Za-z0-9]*([A-Da-d])",
r"[Oo]ption[^A-Za-z0-9]*([A-Da-d])",
r"[Cc]hoice[^A-Za-z0-9]*([A-Da-d])",
]
for pattern in patterns:
match = re.search(pattern, generation.strip())
if match:
return match.group(1).upper()
first_char = generation.strip()[0:1].upper()
if first_char in "ABCD":
return first_char
return ""
[docs] def generate_until(self, instances):
"""
Generate text until a specified set of stop sequences is reached for each instance.
This method is part of the lm-evaluation-harness LM interface.
Args:
instances: List of Instance objects from lm-evaluation-harness.
Each instance is expected to contain the prompt as the first argument
and an optional dictionary as the second argument with a 'until' key
containing a list of stop sequences.
Returns:
List of generated strings, one for each instance.
"""
requests = []
for instance in instances:
prompt = instance.arguments[0]
if len(instance.arguments) > 1 and isinstance(instance.arguments[1], dict):
config = instance.arguments[1]
stop_sequences = config.get("until", [])
else:
stop_sequences = []
requests.append((prompt, stop_sequences))
generations = self._generate(
[req[0] for req in requests],
max_tokens=self.max_gen_toks,
stop_sequences=[req[1] for req in requests],
)
return generations
@property
def eot_token_id(self):
"""Get the end-of-text token ID."""
return self.tokenizer.eos_token_id
@property
def max_length(self):
"""Get the maximum context length."""
return self._max_length
@max_length.setter
def max_length(self, value):
"""Set the maximum context length."""
self._max_length = value
@property
def max_gen_toks(self):
"""Get the maximum number of tokens to generate."""
return self.max_new_tokens
@property
def batch_size(self):
"""Get the batch size."""
return self._batch_size
@property
def device(self):
"""Get the device (CPU/GPU)."""
return "cpu"
@property
def tokenizer_name(self):
"""Get the tokenizer name for chat template support.
Returns the name or path of the tokenizer/model being used.
This is required by lm_eval for proper chat template handling.
"""
# Try to get the tokenizer name from various possible attributes
if hasattr(self.tokenizer, "name_or_path") and self.tokenizer.name_or_path:
return self.tokenizer.name_or_path
elif hasattr(self.tokenizer, "tokenizer_name") and self.tokenizer.tokenizer_name:
return self.tokenizer.tokenizer_name
elif hasattr(self.tokenizer, "__class__"):
# Return the class name as a fallback
return self.tokenizer.__class__.__name__
else:
return ""
[docs] def apply_chat_template(self, messages, add_generation_prompt: bool):
"""Apply chat template to messages.
This method is required by lm_eval for chat-based evaluations.
Args:
messages: List of message dictionaries with 'role' and 'content' keys
Returns:
String with the formatted chat template applied
"""
return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=add_generation_prompt)
[docs] def tok_encode(self, string: str):
"""Encode a string into token IDs.
Args:
string: The input string.
Returns:
A list of token IDs.
"""
return self.tokenizer.encode(string)
[docs] def tok_decode(self, tokens):
"""Decode token IDs into a string.
Args:
tokens: A list or tensor of token IDs.
Returns:
The decoded string.
"""
return self.tokenizer.decode(tokens)
def _model_call(self, inps):
"""
This method is not directly used by eSurgeLMEvalAdapter but is required by the LM interface.
In our case, loglikelihood and greedy_until handle the model calls directly
by interacting with the `eSurge` instance.
Raises:
NotImplementedError: This method is not implemented as it's not used.
"""
raise NotImplementedError("eSurgeLMEvalAdapter doesn't use _model_call directly")
def _model_generate(self, context, max_length, eos_token_id):
"""Generate text from context.
This method is not directly used by eSurgeLMEvalAdapter but is required by the LM interface.
Generation is handled by the `_generate` and `generate_until` methods.
Args:
context: The input context.
max_length: The maximum length of the generated sequence.
eos_token_id: The end-of-sequence token ID.
Raises:
NotImplementedError: This method is not implemented as it's not used.
"""
raise NotImplementedError("eSurgeLMEvalAdapter doesn't use _model_generate directly")
[docs] def loglikelihood(self, instances):
"""
Compute log-likelihood of completions given contexts.
This method is part of the lm-evaluation-harness LM interface.
It currently provides a placeholder implementation, especially for non-multiple-choice tasks.
Args:
instances: List of Instance objects from lm-evaluation-harness.
For multiple-choice tasks, instances are expected to have context and continuation.
Returns:
List of (log_likelihood, is_greedy) tuples.
For multiple-choice tasks, log-likelihood is high if the extracted choice matches the continuation,
low otherwise. For other tasks, a placeholder value is returned.
"""
requests = []
for instance in instances:
if len(instance.arguments) >= 2:
context = instance.arguments[0]
continuation = instance.arguments[1]
requests.append((context, continuation))
else:
print(f"Warning: Invalid instance format: {instance}")
requests.append(("", ""))
contexts = [req[0] for req in requests]
continuations = [req[1] for req in requests]
is_mc_task = False
if contexts and len(contexts) > 0:
mc_pattern = r"[A-D]\.\s"
import re
if any(re.search(mc_pattern, ctx) for ctx in contexts):
is_mc_task = True
results = []
if is_mc_task:
choices = "ABCD"
max_tokens = 5
generations = self._generate(contexts, max_tokens=max_tokens)
for _i, (_, continuation, generation) in enumerate(zip(contexts, continuations, generations, strict=False)):
predicted_choice = self._extract_choice_from_generation(generation)
expected_choice = continuation.strip().upper()
if expected_choice and expected_choice[0] in choices:
expected_choice = expected_choice[0]
log_likelihood = 0.0 if predicted_choice == expected_choice else -100.0
is_greedy = predicted_choice == expected_choice
results.append((log_likelihood, is_greedy))
else:
for _ in zip(contexts, continuations, strict=False):
results.append((-1.0, True))
return results
[docs] def loglikelihood_rolling(self, instances):
"""
Calculate log-likelihood of token sequences in a rolling fashion.
This method is part of the lm-evaluation-harness LM interface.
It currently provides a placeholder implementation as actual rolling log-likelihood
calculation might not be directly supported by the current `eSurge` setup.
Args:
instances: List of Instance objects from lm-evaluation-harness.
Instances are expected to contain the token sequence as the first argument.
Returns:
List of lists of (loglikelihood, is_greedy) pairs, one inner list per instance.
Each inner list contains pairs for each token in the sequence (except the first).
Currently returns placeholder values.
"""
token_lists = []
for instance in instances:
if len(instance.arguments) >= 1 and isinstance(instance.arguments[0], list):
tokens = instance.arguments[0]
token_lists.append(tokens)
else:
print(f"Warning: Invalid instance format for rolling loglikelihood: {instance}")
token_lists.append([])
results = []
for tokens in token_lists:
token_results = []
for _i in range(1, len(tokens)):
log_likelihood = -2.0
is_greedy = True
token_results.append((log_likelihood, is_greedy))
results.append(token_results)
return results
[docs] def greedy_until(self, requests):
"""
Generate completions for prompts until a stop sequence is reached using greedy decoding.
This method is part of the lm-evaluation-harness LM interface.
It currently raises NotImplementedError as its functionality is covered by `generate_until`.
Args:
requests: List of (context, stopping_sequences) tuples.
Returns:
List of generated completions.
Raises:
NotImplementedError: This method is not implemented as `generate_until` provides similar functionality.
"""
raise NotImplementedError("eSurgeLMEvalAdapter doesn't use greedy_until directly, use generate_until")