easydel.modules.clip.modeling_clip_flax

Contents

easydel.modules.clip.modeling_clip_flax#

class easydel.modules.clip.modeling_clip_flax.CLIPAttention(*args: Any, **kwargs: Any)[source]#

Bases: AttentionModule

CLIP Attention module, supporting both text (causal) and vision (non-causal) attention.

config#

Configuration object.

Type

Union[CLIPTextConfig, CLIPVisionConfig]

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.modules.clip.modeling_clip_flax.CLIPEncoder(*args: Any, **kwargs: Any)[source]#

Bases: Module

Transformer encoder consisting of a sequence of CLIPEncoderLayer modules.

config#

Configuration object.

Type

Union[CLIPTextConfig, CLIPVisionConfig]

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

property causal_mask#

Returns the causal mask if the encoder is for text, otherwise None.

Returns

Causal mask.

Return type

Optional[chex.Array]

class easydel.modules.clip.modeling_clip_flax.CLIPEncoderLayer(*args: Any, **kwargs: Any)[source]#

Bases: Module

Single CLIP encoder layer, combining self-attention and an MLP.

config#

Configuration object.

Type

Union[CLIPTextConfig, CLIPVisionConfig]

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.modules.clip.modeling_clip_flax.CLIPForImageClassification(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

CLIP vision model with an image classification head on top (a linear layer on the pooled final hidden state).

config#

Configuration object.

Type

CLIPVisionConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.modules.clip.modeling_clip_flax.CLIPMLP(*args: Any, **kwargs: Any)[source]#

Bases: Module

CLIP MLP (Feed-Forward) layer.

config#

Configuration object.

Type

Union[CLIPTextConfig, CLIPVisionConfig]

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.modules.clip.modeling_clip_flax.CLIPModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

compute_loss(*, labels=None, loss_config=None, loss_kwargs=None, **batch) → Tuple[Any, CLIPOutput][source]#

Computes the loss for the model given a batch of inputs and labels.

This method performs a forward pass using the provided batch arguments, then calculates the loss using the determined loss_function. It handles potential label inference (e.g., using input_ids as labels for Causal LM) and default loss configurations.

Parameters
  • labels (tp.Optional[chex.Array], optional) – The target labels. If None and the task is Causal LM, input_ids from the batch might be used. Defaults to None.

  • loss_config (tp.Optional[LossConfig], optional) – Specific configuration for the loss calculation. If None, defaults might be inferred (e.g., for sequence classification). Defaults to None.

  • loss_kwargs (tp.Optional[tp.Dict], optional) – Additional keyword arguments to pass directly to the loss function. Defaults to None.

  • **batch – Keyword arguments representing the input batch (e.g., input_ids, attention_mask).

Returns

A tuple containing:
  • The model’s output (a pytree, typically including logits, hidden states, etc.)

  • A LossMetrics object containing the calculated loss and potentially other metrics.

Return type

tp.Tuple[tp.Any, LossMetrics]

Raises
  • AssertionError – If labels are required for the loss function but are not provided or inferred.

  • AssertionError – If sequence classification loss is used without num_labels in the config.

get_image_features(pixel_values: Union[Array, ndarray, bool, number])[source]#

get_text_features(input_ids: Union[Array, ndarray, bool, number], attention_mask: Optional[Union[Array, ndarray, bool, number]] = None, position_ids: Optional[Union[Array, ndarray, bool, number]] = None)[source]#

class easydel.modules.clip.modeling_clip_flax.CLIPTextEmbeddings(*args: Any, **kwargs: Any)[source]#

Bases: Module

Constructs the text embeddings for CLIP.

config#

Configuration object.

Type

CLIPTextConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.modules.clip.modeling_clip_flax.CLIPTextModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Bare CLIP text model (transformer) outputting raw hidden-states without any specific head on top.

config#

Configuration object.

Type

CLIPTextConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.modules.clip.modeling_clip_flax.CLIPTextModelWithProjection(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

CLIP text model with a projection layer on top.

config#

Configuration object.

Type

CLIPTextConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.modules.clip.modeling_clip_flax.CLIPTextTransformer(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The transformer encoder for the CLIP text model.

config#

Configuration object.

Type

CLIPTextConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.modules.clip.modeling_clip_flax.CLIPVisionEmbeddings(*args: Any, **kwargs: Any)[source]#

Bases: Module

Constructs the vision embeddings for CLIP.

config#

Configuration object.

Type

CLIPVisionConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.modules.clip.modeling_clip_flax.CLIPVisionModel(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

Bare CLIP vision model (transformer) outputting raw hidden-states without any specific head on top.

config#

Configuration object.

Type

CLIPVisionConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

class easydel.modules.clip.modeling_clip_flax.CLIPVisionTransformer(*args: Any, **kwargs: Any)[source]#

Bases: EasyDeLBaseModule

The transformer encoder for the CLIP vision model.

config#

Configuration object.

Type

CLIPVisionConfig

dtype#

Data type for computation.

Type

jnp.dtype

param_dtype#

Data type for parameters.

Type

jnp.dtype

precision#

JAX precision level.

Type

jax.lax.PrecisionLike

rngs#

Random number generators.

Type

nn.Rngs

easydel.modules.clip.modeling_clip_flax.clip_loss(similarity: Array) → Array[source]#

Computes the CLIP loss.

Parameters

similarity (jax.Array) – Similarity matrix.

Returns

CLIP loss.

Return type

jax.Array

easydel.modules.clip.modeling_clip_flax.contrastive_loss(logits: Array) → Array[source]#

Computes the contrastive loss.

Parameters

logits (jax.Array) – Logits from the model.

Returns

Contrastive loss.

Return type

jax.Array