easydel.modules.clip.modeling_clip

class easydel.modules.clip.modeling_clip.CLIPAttention(*args: Any, **kwargs: Any)

Bases: AttentionModule

CLIP Attention module, supporting both text (causal) and vision (non-causal) attention.

config

Configuration object.

Type

Union[CLIPTextConfig, CLIPVisionConfig]

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

JAX precision level.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
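As a rough illustration of the causal/non-causal distinction, the sketch below implements plain multi-head self-attention in JAX with a causal flag. It is not EasyDeL's CLIPAttention: the helper name multi_head_attention, the bias-free (dim, dim) weight matrices, and the absence of dropout and padding masks are simplifying assumptions.

```python
import jax
import jax.numpy as jnp


def multi_head_attention(x, wq, wk, wv, wo, num_heads, causal):
    """Scaled dot-product self-attention over x of shape (batch, seq, dim).

    Hypothetical helper, not EasyDeL's CLIPAttention; weights are plain
    (dim, dim) matrices without biases to keep the sketch short.
    """
    batch, seq, dim = x.shape
    head_dim = dim // num_heads

    def split_heads(t):
        # (batch, seq, dim) -> (batch, heads, seq, head_dim)
        return t.reshape(batch, seq, num_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = split_heads(x @ wq), split_heads(x @ wk), split_heads(x @ wv)
    scores = q @ k.transpose(0, 1, 3, 2) / jnp.sqrt(head_dim)  # (b, h, s, s)
    if causal:
        # Text tower: position i may only attend to positions <= i.
        mask = jnp.tril(jnp.ones((seq, seq), dtype=bool))
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    # Merge heads back: (b, h, s, head_dim) -> (b, s, dim)
    out = (weights @ v).transpose(0, 2, 1, 3).reshape(batch, seq, dim)
    return out @ wo
```

With causal=True (the text tower) the output at position 0 depends only on token 0; with causal=False (the vision tower) every patch attends to every other patch.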

class easydel.modules.clip.modeling_clip.CLIPEncoder(*args: Any, **kwargs: Any)

Bases: Module

Transformer encoder built from a stack of CLIPEncoderLayer modules.

config

Configuration object.

Type

Union[CLIPTextConfig, CLIPVisionConfig]

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

JAX precision level.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

property causal_mask

Returns the causal mask if the encoder is for text, otherwise None.

Returns

The causal mask for a text encoder, or None for a vision encoder.

Return type

Optional[chex.Array]
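A minimal sketch of what a causal mask looks like and how it could be combined with a padding mask, assuming a boolean convention (True means "may attend") and a broadcastable (1, 1, seq_len, seq_len) shape. The helper names and both conventions are assumptions for illustration; the exact form returned by causal_mask may differ.

```python
import jax.numpy as jnp


def make_causal_mask(seq_len):
    """Boolean (1, 1, seq_len, seq_len) mask: True where attention is allowed."""
    # Lower-triangular: query i sees keys 0..i. Leading axes broadcast over batch and heads.
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))[None, None, :, :]


def combine_with_padding(causal_mask, attention_mask):
    """Combine a causal mask with a (batch, seq_len) padding mask (1 = real token)."""
    padding = attention_mask.astype(bool)[:, None, None, :]  # block padded keys
    return causal_mask & padding
```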

class easydel.modules.clip.modeling_clip.CLIPEncoderLayer(*args: Any, **kwargs: Any)

Bases: Module

Single CLIP encoder layer, combining self-attention with an MLP (feed-forward) block.

config

Configuration object.

Type

Union[CLIPTextConfig, CLIPVisionConfig]

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

JAX precision level.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
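The sketch below shows one plausible wiring of such a layer, assuming the pre-LayerNorm residual ordering used by the original CLIP transformer: normalize, apply the sublayer, add the residual. attn_fn and mlp_fn are hypothetical stand-in callables, and layer_norm is a hand-written helper rather than a Flax module.

```python
import jax.numpy as jnp


def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize over the hidden (last) axis, then scale and shift.
    mean = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * gamma + beta


def encoder_layer(x, attn_fn, mlp_fn, ln1, ln2):
    """Pre-LayerNorm residual block (assumed ordering)."""
    x = x + attn_fn(layer_norm(x, *ln1))  # self-attention sublayer
    x = x + mlp_fn(layer_norm(x, *ln2))   # feed-forward sublayer
    return x
```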

class easydel.modules.clip.modeling_clip.CLIPForImageClassification(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

CLIP vision model with an image classification head on top (a linear layer on the pooled final hidden state).

config

Configuration object.

Type

CLIPVisionConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

JAX precision level.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
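A linear head on a pooled vision representation can be sketched as follows. This page does not say how the final hidden state is pooled; averaging the patch tokens (skipping the class token at index 0) is an assumption borrowed from the Hugging Face port, and classify, kernel and bias are hypothetical names.

```python
import jax.numpy as jnp


def classify(patch_hidden_states, kernel, bias):
    """Linear classification head on a pooled vision representation.

    patch_hidden_states: (batch, num_tokens, hidden) final hidden states.
    Mean-pooling the patch tokens is an assumed pooling choice.
    """
    pooled = patch_hidden_states[:, 1:, :].mean(axis=1)  # (batch, hidden)
    return pooled @ kernel + bias  # (batch, num_labels)
```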

get_decoder()

Returns the decoder part of the model’s graph definition, if any. This is an encoder-only classification model, so it has no decoder.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition.

get_lm_head()

Returns the language model head of the module. This model has an image classification head, not a language model head.

class easydel.modules.clip.modeling_clip.CLIPMLP(*args: Any, **kwargs: Any)

Bases: Module

CLIP MLP (Feed-Forward) layer.

config

Configuration object.

Type

Union[CLIPTextConfig, CLIPVisionConfig]

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

JAX precision level.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
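A minimal sketch of a CLIP-style feed-forward block, assuming two dense layers with the activation in between. The activation actually used presumably comes from a hidden_act-style config field (an assumption); quick_gelu, x * sigmoid(1.702 x), is shown because the original OpenAI CLIP checkpoints use it. The function and weight names are hypothetical.

```python
import jax
import jax.numpy as jnp


def quick_gelu(x):
    # Sigmoid approximation of GELU used by the original OpenAI CLIP checkpoints.
    return x * jax.nn.sigmoid(1.702 * x)


def clip_mlp(x, w1, b1, w2, b2):
    """hidden -> intermediate -> hidden feed-forward block."""
    return quick_gelu(x @ w1 + b1) @ w2 + b2
```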

class easydel.modules.clip.modeling_clip.CLIPModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Contrastive CLIP model wiring together text and vision towers with projection heads.
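The core contrastive computation can be sketched in a few lines of JAX, assuming the formulation of the original CLIP paper: L2-normalize both sets of projected embeddings, then scale their dot products by a learned temperature stored in log space. clip_logits is a hypothetical helper, not a method of CLIPModel.

```python
import jax.numpy as jnp


def clip_logits(text_embeds, image_embeds, logit_scale):
    """Similarity logits between projected text and image embeddings.

    text_embeds: (n_text, d), image_embeds: (n_image, d); logit_scale is the
    learned log-temperature, exponentiated before use.
    """
    text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
    image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
    logits_per_text = jnp.exp(logit_scale) * text_embeds @ image_embeds.T
    logits_per_image = logits_per_text.T
    return logits_per_text, logits_per_image
```

Row i of logits_per_text scores text i against every image in the batch; matching pairs sit on the diagonal.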

compute_loss(*, labels=None, loss_config=None, loss_kwargs=None, **batch) → tuple[Any, easydel.infra.modeling_outputs.CLIPOutput]

Computes the loss for the model given a batch of inputs and labels.

This method performs a forward pass using the provided batch arguments, then calculates the loss using the determined loss_function. It handles potential label inference (e.g., using input_ids as labels for Causal LM) and default loss configurations.

Parameters
  • labels (tp.Optional[chex.Array], optional) – The target labels. If None and the task is Causal LM, input_ids from the batch might be used. Defaults to None.

  • loss_config (tp.Optional[LossConfig], optional) – Specific configuration for the loss calculation. If None, defaults might be inferred (e.g., for sequence classification). Defaults to None.

  • loss_kwargs (tp.Optional[tp.Dict], optional) – Additional keyword arguments to pass directly to the loss function. Defaults to None.

  • **batch – Keyword arguments representing the input batch (e.g., input_ids, attention_mask).

Returns

A tuple containing:
  • The model’s output (a pytree typically including logits, hidden states, etc.)

  • A LossMetrics object containing the calculated loss and potentially other metrics.

Return type

tp.Tuple[tp.Any, LossMetrics]

Raises
  • AssertionError – If labels are required for the loss function but are not provided or inferred.

  • AssertionError – If sequence classification loss is used without num_labels in the config.

get_decoder()

Returns the decoder part of the model’s graph definition. The text model acts as the “decoder” or text processor in this multi-modal setup.

get_embedding()

Returns the embedding layer of the text model.

get_encoder()

Returns the encoder part of the model’s graph definition. The vision tower acts as the encoder in this multi-modal setup.

get_image_features(pixel_values: Union[Array, ndarray, bool, number])

Returns the projected image features (image embeddings) computed from pixel_values.

get_lm_head()

Returns the language model head of the module. This model has projection heads rather than a traditional language model head.

get_text_features(input_ids: Int[Array, 'batch seq_len'], attention_mask: jaxtyping.Bool[Array, 'batch seq_len'] | None = None, mask_info: ejkernel.types.mask.MaskInfo | None = None, position_ids: jaxtyping.Int[Array, 'batch seq_len'] | None = None)

Returns the projected text features (text embeddings) computed from input_ids.

class easydel.modules.clip.modeling_clip.CLIPTextEmbeddings(*args: Any, **kwargs: Any)

Bases: Module

Constructs the text embeddings for CLIP.

config

Configuration object.

Type

CLIPTextConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

JAX precision level.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
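A sketch of how CLIP text embeddings are commonly built, assuming token embeddings plus learned absolute position embeddings as in the original CLIP. The lookup tables are passed in as plain arrays; text_embeddings and its parameters are hypothetical.

```python
import jax.numpy as jnp


def text_embeddings(input_ids, token_table, position_table, position_ids=None):
    """Token embedding lookup plus learned absolute position embeddings."""
    seq_len = input_ids.shape[-1]
    if position_ids is None:
        # Default positions 0..seq_len-1, broadcast over the batch axis.
        position_ids = jnp.arange(seq_len)[None, :]
    return token_table[input_ids] + position_table[position_ids]
```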

class easydel.modules.clip.modeling_clip.CLIPTextModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Bare CLIP text model (transformer) outputting raw hidden-states without any specific head on top.

config

Configuration object.

Type

CLIPTextConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

JAX precision level.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

get_decoder()

Returns the decoder part of the model’s graph definition, if any. This is an encoder-only model, so it has no decoder.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition.

get_lm_head()

Returns the language model head of the module. Base models don’t have a language model head.

class easydel.modules.clip.modeling_clip.CLIPTextModelWithProjection(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

CLIP text model with a projection layer on top.

config

Configuration object.

Type

CLIPTextConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

JAX precision level.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
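A sketch of the pooling-plus-projection step, under two assumptions taken from the original OpenAI CLIP code rather than from this page: the sequence is pooled at the end-of-text token, located via argmax because that token has the largest id in the vocabulary, and the projection is a bias-free matrix. project_text is a hypothetical helper, and last_hidden_state is assumed to have already passed through the final layer norm.

```python
import jax.numpy as jnp


def project_text(last_hidden_state, input_ids, projection):
    """Pool at the end-of-text token, then apply a bias-free projection.

    Assumes the EOS token has the largest id in each sequence, so argmax
    over input_ids finds its position.
    """
    eos_index = jnp.argmax(input_ids, axis=-1)          # (batch,)
    batch = jnp.arange(last_hidden_state.shape[0])
    pooled = last_hidden_state[batch, eos_index]        # (batch, hidden)
    return pooled @ projection                          # (batch, proj_dim)
```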

get_decoder()

Returns the decoder part of the model’s graph definition, if any. This is an encoder-only model, so it has no decoder.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition.

get_lm_head()

Returns the language model head of the module. This model has a projection head, not a language model head.

class easydel.modules.clip.modeling_clip.CLIPTextTransformer(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The transformer encoder for the CLIP text model.

config

Configuration object.

Type

CLIPTextConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

JAX precision level.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

get_decoder()

Returns the decoder part of the model’s graph definition, if any. This is an encoder-only model, so it has no decoder.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition.

get_lm_head()

Returns the language model head of the module. This model has no language model head; the text projection belongs to CLIPTextModelWithProjection and CLIPModel.

class easydel.modules.clip.modeling_clip.CLIPVisionEmbeddings(*args: Any, **kwargs: Any)

Bases: Module

Constructs the vision embeddings for CLIP.

config

Configuration object.

Type

CLIPVisionConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

JAX precision level.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs
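A sketch of ViT-style patch embeddings as used by CLIP, assuming channels-last pixel_values (the Flax convention), a bias-free patch projection, a learned class token prepended to the sequence, and learned position embeddings. Flattening non-overlapping patches and multiplying by a matrix is equivalent to a convolution with kernel_size == stride == patch_size once the kernel is reshaped accordingly. All names here are hypothetical.

```python
import jax.numpy as jnp


def vision_embeddings(pixel_values, patch_kernel, class_embedding, position_table, patch_size):
    """Patchify, project, prepend a class token, add position embeddings.

    pixel_values: (batch, height, width, channels), channels-last.
    patch_kernel: (patch_size * patch_size * channels, hidden), bias-free.
    """
    b, h, w, c = pixel_values.shape
    p = patch_size
    # Split into non-overlapping p x p patches in row-major order, then flatten each.
    patches = pixel_values.reshape(b, h // p, p, w // p, p, c)
    patches = patches.transpose(0, 1, 3, 2, 4, 5).reshape(b, (h // p) * (w // p), p * p * c)
    tokens = patches @ patch_kernel                                  # (b, num_patches, hidden)
    cls = jnp.broadcast_to(class_embedding, (b, 1, class_embedding.shape[-1]))
    tokens = jnp.concatenate([cls, tokens], axis=1)                  # (b, 1 + num_patches, hidden)
    return tokens + position_table[None, : tokens.shape[1]]
```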

class easydel.modules.clip.modeling_clip.CLIPVisionModel(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

Bare CLIP vision model (transformer) outputting raw hidden-states without any specific head on top.

config

Configuration object.

Type

CLIPVisionConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

JAX precision level.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

get_decoder()

Returns the decoder part of the model’s graph definition, if any. This is an encoder-only model, so it has no decoder.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition.

get_lm_head()

Returns the language model head of the module. This vision model does not have a language model head.

class easydel.modules.clip.modeling_clip.CLIPVisionTransformer(*args: Any, **kwargs: Any)

Bases: EasyDeLBaseModule

The transformer encoder for the CLIP vision model.

config

Configuration object.

Type

CLIPVisionConfig

dtype

Data type for computation.

Type

jnp.dtype

param_dtype

Data type for parameters.

Type

jnp.dtype

precision

JAX precision level.

Type

jax.lax.PrecisionLike

rngs

Random number generators.

Type

nn.Rngs

get_decoder()

Returns the decoder part of the model’s graph definition, if any. This is an encoder-only model, so it has no decoder.

get_embedding()

Returns the embedding layer of the module.

get_encoder()

Returns the encoder part of the model’s graph definition.

get_lm_head()

Returns the language model head of the module. This vision model does not have a language model head.

easydel.modules.clip.modeling_clip.clip_loss(similarity: Array) → Array

Computes the CLIP loss.

Parameters

similarity (jax.Array) – Square similarity matrix between the text and image embeddings of a batch.

Returns

CLIP loss.

Return type

jax.Array

easydel.modules.clip.modeling_clip.contrastive_loss(logits: Array) → Array

Computes the contrastive loss.

Parameters

logits (jax.Array) – Square logit matrix in which the correct target for row i is column i.

Returns

Contrastive loss.

Return type

jax.Array
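Both loss functions can be sketched in JAX under the usual CLIP formulation, assuming the correct target for row i of a square logit matrix is column i and that clip_loss averages the text-to-image and image-to-text directions. These are illustrative reimplementations that reuse the documented names, not EasyDeL's source.

```python
import jax
import jax.numpy as jnp


def contrastive_loss(logits):
    """Mean cross-entropy where row i's correct class is column i."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.diagonal(log_probs))


def clip_loss(similarity):
    """Symmetric loss: average the two retrieval directions."""
    return (contrastive_loss(similarity) + contrastive_loss(similarity.T)) / 2.0
```

With all-zero logits every target has probability 1/n, so the loss equals log(n); strongly diagonal logits drive it toward zero.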