Source code for easydel.inference.esurge.outputs

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from collections.abc import Hashable
from dataclasses import dataclass
from typing import NamedTuple, TypeVar

from jax import Array
from jax import numpy as jnp

_K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V")


class LogprobsLists(NamedTuple):
    """Log-probability data held as plain Python lists.

    ``LogprobsTensors.tolists`` produces instances of this class. The three
    fields are sliced together by ``slice``, so they are treated as parallel
    per-position sequences (entry ``i`` of each field describes the same
    position — presumably one sampled position per entry; confirm with callers).
    """

    # Token ids whose log-probabilities were recorded, one inner list per position.
    logprob_token_ids: list[list[int]]
    # Log-probability values aligned with ``logprob_token_ids``.
    logprobs: list[list[float]]
    # One rank per position (presumably the sampled token's rank — TODO confirm).
    sampled_token_ranks: list[int]

    def slice(self, start: int, end: int):
        """Return a new ``LogprobsLists`` keeping only entries ``[start:end]``.

        Every field is cut with the same Python slice, so negative indices and
        out-of-range bounds follow normal list-slicing rules.
        """
        # Iterating a NamedTuple yields its fields in declaration order, which
        # matches the positional order of the constructor.
        return LogprobsLists(*(column[start:end] for column in self))
class LogprobsTensors(NamedTuple):
    """Log-probability data held as JAX arrays.

    Use ``tolists`` to convert to the list-based ``LogprobsLists``. Note that
    the rank field is called ``selected_token_ranks`` here but lands in
    ``LogprobsLists``'s third (``sampled_token_ranks``) slot positionally.
    """

    # int32 array of token ids; ``empty`` allocates it as (num_positions, num_tokens_per_position).
    logprob_token_ids: Array
    # float32 array with the same shape as ``logprob_token_ids`` when built by ``empty``.
    logprobs: Array
    # int32 array with one entry per position when built by ``empty``.
    selected_token_ranks: Array

    def tolists(self):
        """Convert each array field with ``.tolist()`` and wrap the result in ``LogprobsLists``."""
        # Field iteration order is declaration order, so the conversions run in
        # the same sequence as the positional arguments of LogprobsLists.
        token_ids, values, ranks = (array.tolist() for array in self)
        return LogprobsLists(token_ids, values, ranks)

    @staticmethod
    def empty(num_positions: int, num_tokens_per_position: int) -> LogprobsTensors:
        """Allocate placeholder arrays of the right shapes and dtypes.

        Args:
            num_positions: Leading dimension of every field.
            num_tokens_per_position: Second dimension of the id/logprob arrays.

        Returns:
            A ``LogprobsTensors`` built with ``jnp.empty``/``jnp.empty_like``;
            callers should treat the contents as unspecified and overwrite them.
        """
        ids = jnp.empty((num_positions, num_tokens_per_position), dtype=jnp.int32)
        return LogprobsTensors(
            logprob_token_ids=ids,
            # Same shape as ``ids`` but float32 to hold log-probabilities.
            logprobs=jnp.empty_like(ids, dtype=jnp.float32),
            selected_token_ranks=jnp.empty(num_positions, dtype=jnp.int32),
        )
@dataclass
class ModelRunnerOutput:
    """Per-step output record, presumably produced by the eSurge model runner.

    The first six fields are required and positional (dataclass field order
    defines the generated ``__init__`` signature); the last four default to
    ``None``.
    """

    # Request ids covered by this output.
    req_ids: list[str]
    # Maps request id -> integer index (presumably its position in ``req_ids``; confirm with producer).
    req_id_to_index: dict[str, int]
    # Sampled token ids, one inner list per request (assumed to follow ``req_ids`` order — TODO confirm).
    sampled_token_ids: list[list[int]]
    # Optional per-request token id lists; name suggests speculative-decoding proposals — verify.
    spec_token_ids: list[list[int]] | None
    # List-based log-probability data, or None when not produced.
    logprobs: LogprobsLists | None
    # Keyed by request id (presumably); value may be None for a request.
    prompt_logprobs_dict: dict[str, LogprobsTensors | None]
    # Optional sets of request ids; semantics of "sending"/"recving" are defined by the producer.
    finished_sending: set[str] | None = None
    finished_recving: set[str] | None = None
    # Optional per-request integer counts (by name: NaNs observed in logits — confirm).
    num_nans_in_logits: dict[str, int] | None = None
    # Optional per-request float values (by name: a log-probability per request — confirm).
    token_logprobs: dict[str, float] | None = None
def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
    """Swap, in place, the values stored under ``key1`` and ``key2`` in ``obj``.

    If only one of the keys is present, its value moves to the other key and
    the original key is removed. If neither key is present, ``obj`` is left
    unchanged. A key whose stored value is ``None`` counts as present, so its
    ``None`` is swapped like any other value rather than being dropped.

    Keys that already exist keep their position in the dict's insertion order
    (they are overwritten in place); a key that did not exist before is added
    at the end.

    Args:
        obj: Dictionary to modify in place.
        key1: First key.
        key2: Second key. If equal to ``key1``, ``obj`` is unchanged.

    Returns:
        None. ``obj`` is mutated.
    """
    # Private sentinel so "key absent" is distinguishable from "key maps to None".
    missing = object()
    first = obj.get(key1, missing)
    second = obj.get(key2, missing)

    if first is missing:
        obj.pop(key2, None)
    else:
        obj[key2] = first

    if second is missing:
        obj.pop(key1, None)
    else:
        obj[key1] = second