# Source code for easydel.scripts.install_on_hosts

"""
Install EasyDel and dependencies on Ray TPU pods.

This script automates the installation of EasyDel and its dependencies across all hosts in a TPU pod.
It uses Ray for distributed execution and eformer's Ray executor (``execute`` with a ``TpuAcceleratorConfig``) for TPU-specific orchestration.

Usage:
    python -m easydel.scripts.install_on_hosts \
        --tpu-type <TPU_TYPE> \
        --source <pypi|github>

Options:
    --tpu-type     TPU pod slice type (e.g. v4-16, v3-8); defaults to v4-16
    --source       Installation source: 'pypi' for the PyPI package (default) or 'github' for the latest from GitHub

Example:
    python -m easydel.scripts.install_on_hosts --tpu-type v4-16 --source github
"""

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python

import argparse
import sys

import ray
from eformer.executor.ray import TpuAcceleratorConfig, execute


@ray.remote
def install_easydel_on_pods_pypi():
    """Install EasyDel[tf] from PyPI, plus JAX (TPU) and CPU PyTorch, on a Ray node.

    Every pip command is attempted even when an earlier one fails, so one
    broken step does not prevent the remaining packages from being installed.

    Returns:
        bool: ``True`` if every pip command exited with status 0, ``False``
        if at least one of them failed.
    """
    # Local imports keep the remote function self-contained.
    import subprocess
    import sys

    node_id = ray.get_runtime_context().get_node_id()
    print(f"Node {node_id}: Installing EasyDel from PyPI...")

    # Run pip through the worker's own interpreter so the packages land in the
    # environment that executes this task, not whichever `pip` is first on PATH.
    pip = [sys.executable, "-m", "pip"]
    commands = [
        [*pip, "install", "--upgrade", "pip", "-q"],
        [*pip, "install", "easydel[tf]", "-qU"],
        [*pip, "install", "jax[tpu]", "-qU"],
        [
            *pip,
            "install",
            "torch",
            "torchvision",
            "torchaudio",
            "--index-url",
            "https://download.pytorch.org/whl/cpu",
            "-qU",
        ],
    ]

    failed = []
    for command in commands:
        # Argument lists (no shell) keep extras such as "easydel[tf]" from
        # being interpreted as glob patterns.
        completed = subprocess.run(command, check=False)
        if completed.returncode != 0:
            failed.append((completed.returncode, " ".join(command)))

    if failed:
        for returncode, command in failed:
            print(f"Node {node_id}: Command failed with exit code {returncode}: {command}")
        print(f"Node {node_id}: Installation from PyPI finished with errors.")
        return False

    print(f"Node {node_id}: Installation from PyPI complete.")
    return True


@ray.remote
def install_easydel_on_pods_github():
    """Install EasyDel[tf] from GitHub head, plus JAX (TPU) and CPU PyTorch, on a Ray node.

    Any existing ``easydel`` distribution is uninstalled first, then the
    package is installed from the GitHub repository. Every pip command is
    attempted even when an earlier one fails.

    Returns:
        bool: ``True`` if every pip command exited with status 0, ``False``
        if at least one of them failed.
    """
    # Local imports keep the remote function self-contained.
    import subprocess
    import sys

    node_id = ray.get_runtime_context().get_node_id()
    print(f"Node {node_id}: Installing EasyDel from GitHub head...")

    # Run pip through the worker's own interpreter so the packages land in the
    # environment that executes this task, not whichever `pip` is first on PATH.
    pip = [sys.executable, "-m", "pip"]
    commands = [
        [*pip, "install", "--upgrade", "pip", "-q"],
        [*pip, "uninstall", "easydel", "-y", "-q"],
        [*pip, "install", "easydel[tf] @ git+https://github.com/erfanzar/easydel.git", "-qU"],
        [*pip, "install", "jax[tpu]", "-qU"],
        [
            *pip,
            "install",
            "torch",
            "torchvision",
            "torchaudio",
            "--index-url",
            "https://download.pytorch.org/whl/cpu",
            "-qU",
        ],
    ]

    failed = []
    for command in commands:
        # Argument lists (no shell) pass the direct-reference requirement
        # ("easydel[tf] @ git+...") to pip as a single, unquoted argument.
        completed = subprocess.run(command, check=False)
        if completed.returncode != 0:
            failed.append((completed.returncode, " ".join(command)))

    if failed:
        for returncode, command in failed:
            print(f"Node {node_id}: Command failed with exit code {returncode}: {command}")
        print(f"Node {node_id}: Installation from GitHub head finished with errors.")
        return False

    print(f"Node {node_id}: Installation from GitHub head complete.")
    return True


def main():
    """Parse CLI options and run the selected EasyDel installer on a TPU pod.

    The ``--source`` flag picks the PyPI or GitHub install task and
    ``--tpu-type`` is used to build the ``TpuAcceleratorConfig`` handed to
    ``execute``. Ray is initialized with ``ray.init("auto")``, the task is
    dispatched, and whatever ``ray.get`` returns for it is printed.

    On any exception the error and traceback are printed, Ray is shut down if
    it was initialized, and the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        # The executor used by this script is imported from eformer.executor.ray.
        description="Install EasyDel and dependencies on Ray TPU pods. Requires Ray and eformer.executor.ray.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--source",
        choices=["pypi", "github"],
        default="pypi",
        help="Choose the source for EasyDel installation: PyPI package or GitHub head.",
    )
    parser.add_argument("--tpu-type", type=str, default="v4-16", help="The type of TPU pod slice to use.")
    args = parser.parse_args()

    if args.source == "github":
        install_func = install_easydel_on_pods_github
        print("Selected installation source: GitHub head ")
    else:
        install_func = install_easydel_on_pods_pypi
        print("Selected installation source: PyPI")

    tpu_type = args.tpu_type
    print(f"Selected TPU type: {tpu_type}")
    print(f"Determined number of hosts for {tpu_type}")
    config = TpuAcceleratorConfig(tpu_version=tpu_type)

    try:
        print("Initializing Ray...")
        ray.init("auto")
        print(f"Ray initialized successfully. Cluster resources: {ray.cluster_resources()}")

        # Dispatch first, announce, then block: the "waiting" message is only
        # meaningful if it is printed before ray.get() returns.
        pending = execute(config)(install_func)()
        print("\nExecution command sent. Waiting for remote tasks (TPUExecutor might block or manage this)...")
        results = ray.get(pending)

        if results:
            print(f"Received results from execution (structure depends on TPUExecutor): {results}")
        print("\nInstallation process initiated (or completed, depending on executor) on pods.")
    except Exception as e:
        # Top-level boundary: report the failure, clean up Ray, and signal
        # the error to the caller through a non-zero exit status.
        print(f"\nAn error occurred during Ray initialization or execution: {e}")
        import traceback

        traceback.print_exc()
        if ray.is_initialized():
            print("Attempting to shutdown Ray...")
            ray.shutdown()
            print("Ray shutdown.")
        sys.exit(1)
# Run the installer only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()