# Copyright 2023 The EASYDEL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import typing as tp
from dataclasses import field
from pathlib import Path
from eformer.pytree import auto_pytree
from jinja2 import Environment, FileSystemLoader, select_autoescape
from easydel import __version__
from easydel.utils.helpers import get_logger
# Module-level logger; generate_readme reports saved paths and rendering errors through it.
logger = get_logger(__name__)
# Markdown model-card template. It is rendered with a single variable, `model`,
# and reads `model.type`, `model.name`, `model.repo_id`, `model.description`,
# `model.overview`, `model.features`, `model.supported_tasks` and
# `model.limitations`. Each optional section falls back to the built-in default
# text below when the corresponding value is falsy.
#
# The usage snippet imports `easydel as ed` because it references
# `ed.EasyDeLBaseConfigDict` and `ed.AttentionMechanisms`; without that import
# the snippet fails with a NameError when copied from the generated README.
JINJA_TEMPLATE = """
---
tags:
- EasyDeL
- {{model.type}}
- safetensors
- TPU
- GPU
- XLA
- Flax
---
# {{ model.name }}
[](https://github.com/erfanzar/EasyDeL)
[](https://github.com/erfanzar/EasyDeL)
{{ model.description if model.description else "A model implemented using the EasyDeL framework, designed to deliver optimal performance for large-scale natural language processing tasks." }}
## Overview
{{ model.overview if model.overview else "EasyDeL provides an efficient, highly-optimized, and customizable machine learning model compatible with both GPU and TPU environments. Built with JAX, this model supports advanced features such as sharded model parallelism, making it suitable for distributed training and inference and customized kernels." }}
## Features
{% if model.features %}
{% for feature in model.features %}
- {{ feature }}
{% endfor %}
{% else %}
- **Efficient Implementation**: Built with JAX/Flax for high-performance computation.
- **Multi-Device Support**: Optimized to run on TPU, GPU, and CPU environments for sharding model over 2^(1-1000+) of devices.
- **Sharded Model Parallelism**: Supports model parallelism across multiple devices for scalability.
- **Customizable Precision**: Allows specification of floating-point precision for performance optimization.
{% endif %}
## Installation
To install EasyDeL, simply run:
```bash
pip install easydel
```
## Usage
### Loading the Pre-trained Model
To load a pre-trained version of the model with EasyDeL:
```python
import easydel as ed
from easydel import AutoEasyDeLModelForCausalLM
from jax import numpy as jnp, lax
max_length = None # can be set to use lower memory for caching
# Load model and parameters
model = AutoEasyDeLModelForCausalLM.from_pretrained(
"{{ model.repo_id }}",
config_kwargs=ed.EasyDeLBaseConfigDict(
use_scan_mlp=False,
attn_dtype=jnp.float16,
freq_max_position_embeddings=max_length,
mask_max_position_embeddings=max_length,
attn_mechanism=ed.AttentionMechanisms.FLASH_ATTN2
),
dtype=jnp.float16,
param_dtype=jnp.float16,
precision=lax.Precision("fastest"),
auto_shard_model=True,
)
```
## Supported Tasks
{% if model.supported_tasks %}
This model is well-suited for the following tasks:
{% for task in model.supported_tasks %}
- **{{ task }}**
{% endfor %}
{% else %}
[Need more information]
{% endif %}
## Limitations
{% if model.limitations %}
{% for limitation in model.limitations %}
- {{ limitation }}
{% endfor %}
{% else %}
- **Hardware Dependency**: Performance can vary significantly based on the hardware used.
- **JAX/Flax Setup Required**: The environment must support JAX/Flax for optimal use.
- **Experimental Features**: Some features (like custom kernel usage or ed-ops) may require additional configuration and tuning.
{% endif %}
"""
@auto_pytree
class ModelInfo:
    """
    Model information container.

    Holds the metadata that is passed to the README template as ``model``.
    ``name``, ``type`` and ``repo_id`` have no default and must be supplied;
    every other field is optional.

    Attributes:
        name: The name of the model.
        type: The type of the model.
        repo_id: The repository ID of the model.
        description: A description of the model.
        model_type: The model type.
        model_task: The model task.
        features: A list of features of the model.
        supported_tasks: A list of tasks supported by the model.
        limitations: A list of limitations of the model.
        version: The version of the model; defaults to the installed
            ``easydel.__version__``.
        overview: Overview text for the model.
    """

    name: str = field(metadata={"help": "The name of the model."})
    type: str = field(metadata={"help": "The type of the model."})
    repo_id: str = field(metadata={"help": "The repository ID of the model."})
    description: tp.Optional[str] = field(
        default=None,
        metadata={"help": "A description of the model."},
    )
    model_type: tp.Optional[str] = field(
        default=None,
        metadata={"help": "The model type."},
    )
    model_task: tp.Optional[str] = field(
        default=None,
        metadata={"help": "The model task."},
    )
    features: tp.Optional[tp.List[str]] = field(
        default=None,
        metadata={"help": "A list of features of the model."},
    )
    supported_tasks: tp.Optional[tp.List[str]] = field(
        default=None,
        metadata={"help": "A list of tasks supported by the model."},
    )
    limitations: tp.Optional[tp.List[str]] = field(
        default=None,
        metadata={"help": "A list of limitations of the model."},
    )
    version: str = field(
        default=__version__,
        metadata={"help": "The version of the model."},
    )
    # Declared last so the positional order of the existing fields is
    # unchanged. The README template reads `model.overview`; without this
    # field that lookup is always undefined and the default text is used.
    overview: tp.Optional[str] = field(
        default=None,
        metadata={"help": "An overview of the model."},
    )
class ReadmeGenerator:
    """
    Generate README files for EasyDeL models.

    Rendering always uses the module-level ``JINJA_TEMPLATE`` string; the
    Jinja environment built in ``__init__`` supplies the rendering settings.
    """

    def __init__(self, template_dir: tp.Optional[str] = None):
        """
        Initialize the README generator.

        Args:
            template_dir: Optional directory used as the Jinja loader search
                path. When it is omitted, or the path does not exist, the
                directory containing this module is used instead.
        """
        if template_dir and os.path.exists(template_dir):
            search_path = template_dir
        else:
            # Fall back to the directory of this module.
            search_path = os.path.dirname(__file__)

        # NOTE(review): per the Jinja docs, select_autoescape defaults to
        # default_for_string=True, so templates created with from_string()
        # are HTML-escaped. Confirm that is intended for Markdown output.
        self.env = Environment(
            loader=FileSystemLoader(search_path),
            autoescape=select_autoescape(["html", "xml"]),
        )

    def generate_readme(
        self,
        model_info: ModelInfo,
        output_path: tp.Optional[str] = None,
    ) -> str:
        """
        Generate README content for a model.

        Args:
            model_info: Model information, exposed to the template as ``model``.
            output_path: Optional path to save the README. Missing parent
                directories are created. When falsy, nothing is written.

        Returns:
            Generated README content.

        Raises:
            Exception: Any error raised while rendering or writing is logged
                and then re-raised unchanged.
        """
        try:
            template = self.env.from_string(JINJA_TEMPLATE)
            content = template.render(model=model_info)
            if output_path:
                # Use a separate name so the `output_path` argument is not rebound.
                destination = Path(output_path)
                destination.parent.mkdir(parents=True, exist_ok=True)
                destination.write_text(content, encoding="utf-8")
                logger.info(f"README saved to {destination}")
            return content
        except Exception as e:
            logger.error(f"Error generating README: {e!s}")
            raise
# Example usage: render the README for a sample model and save it to disk.
if __name__ == "__main__":
    model_info = ModelInfo(
        repo_id="erfanzar/LLaMA-2-7B-EasyDeL",
        type="CausalLM",
        name="LLaMA-2-7B-EasyDeL",
    )
    generator = ReadmeGenerator()
    readme = generator.generate_readme(
        model_info=model_info,
        output_path="tmp-files/readme.md",
    )