Source code for easydel.utils.lazy_import

# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import os
import typing as tp
from itertools import chain
from types import ModuleType

BACKENDS_T = frozenset[str]
IMPORT_STRUCTURE_T = dict[BACKENDS_T, dict[str, set[str]]]


[docs]class LazyModule(ModuleType): # copy pasted from huggingface lazy module def __init__( self, name: str, module_file: str, import_structure: IMPORT_STRUCTURE_T, module_spec: importlib.machinery.ModuleSpec | None = None, extra_objects: dict[str, object] | None = None, ): super().__init__(name) self._object_missing_backend = {} if any(isinstance(key, frozenset) for key in import_structure.keys()): self._modules = set() self._class_to_module = {} self.__all__ = [] _import_structure = {} for _backends, module in import_structure.items(): self._modules = self._modules.union(set(module.keys())) for key, values in module.items(): for value in values: self._class_to_module[value] = key _import_structure.setdefault(key, []).extend(values) # Needed for autocompletion in an IDE self.__all__.extend(list(module.keys()) + list(chain(*module.values()))) self.__file__ = module_file self.__spec__ = module_spec self.__path__ = [os.path.dirname(module_file)] self._objects = {} if extra_objects is None else extra_objects self._name = name self._import_structure = _import_structure # This can be removed once every exportable object has a `export()` export. 
else: self._modules = set(import_structure.keys()) self._class_to_module = {} for key, values in import_structure.items(): for value in values: self._class_to_module[value] = key # Needed for autocompletion in an IDE self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values())) self.__file__ = module_file self.__spec__ = module_spec self.__path__ = [os.path.dirname(module_file)] self._objects = {} if extra_objects is None else extra_objects self._name = name self._import_structure = import_structure def __dir__(self): result = super().__dir__() for attr in self.__all__: if attr not in result: result.append(attr) return result def __getattr__(self, name: str) -> tp.Any: if name in self._objects: return self._objects[name] if name in self._object_missing_backend.keys(): missing_backends = self._object_missing_backend[name] class Placeholder(metaclass=DummyObject): _backends = missing_backends Placeholder.__name__ = name Placeholder.__module__ = self.__spec__ value = Placeholder elif name in self._class_to_module.keys(): module = self._get_module(self._class_to_module[name]) value = getattr(module, name) elif name in self._modules: value = self._get_module(name) else: raise AttributeError(f"module {self.__name__} has no attribute {name}") setattr(self, name, value) return value def _get_module(self, module_name: str): try: return importlib.import_module("." + module_name, self.__name__) except Exception as e: raise RuntimeError( f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its" f" traceback):\n{e}" ) from e def __reduce__(self): return (self.__class__, (self._name, self.__file__, self._import_structure))
class DummyObject(type):
    """
    Metaclass for placeholder classes that stand in for objects whose backends are missing.

    Attributes starting with ``_`` (other than ``_from_config``) resolve normally, so the class
    can still be created, renamed and introspected. Accessing any other attribute raises an
    ``ImportError`` naming the backends listed in the class's ``_backends`` attribute.
    """

    def __getattribute__(cls, key):
        if key.startswith("_") and key != "_from_config":
            return super().__getattribute__(key)
        # Previously this fell off the end and silently returned None, hiding the missing backend.
        backends = getattr(cls, "_backends", ())
        if isinstance(backends, str):
            backends = (backends,)
        names = ", ".join(sorted(str(backend) for backend in backends)) or "unknown"
        raise ImportError(f"{cls.__name__} requires the following backend(s), which are not installed: {names}")
def is_package_available(package_name: str) -> bool:
    """
    Checks if a package is available in the current Python environment.

    Hyphens are translated to underscores before lookup, so hyphenated names map to their
    underscore import names. Dotted names are accepted; note that looking up ``"pkg.sub"``
    imports ``pkg``, and a missing parent yields ``False`` rather than an exception.

    Args:
        package_name: The name of the package to check (e.g., "numpy").

    Returns:
        True if the package is available, False otherwise.
    """
    # ``import importlib`` alone does not guarantee the ``util`` submodule is loaded.
    import importlib.util

    module_name = package_name.replace("-", "_")
    if not module_name:
        return False
    try:
        return importlib.util.find_spec(module_name) is not None
    except ImportError:
        # Parent package of a dotted name is missing (ModuleNotFoundError), or the name is a
        # relative name that cannot be resolved without an anchor package.
        return False
    except ValueError:
        # find_spec raises ValueError when the module is already in sys.modules but its
        # ``__spec__`` is None/unset; it has been imported, so it is available.
        return True