Source code for easydel.utils.lazy_import
# Copyright 2025 The EasyDeL Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import importlib.machinery
import importlib.util
import os
import typing as tp
from itertools import chain
from types import ModuleType
# Set of backend names that a group of lazily exported objects is declared under.
# NOTE(review): LazyModule only uses these sets as grouping keys; it never checks the backends' availability.
BACKENDS_T = frozenset[str]
# Lazy-import layout: {backend set: {submodule name: names exported from that submodule}}.
# LazyModule also accepts the flat form {submodule name: names} without the backend level.
IMPORT_STRUCTURE_T = dict[BACKENDS_T, dict[str, set[str]]]
class LazyModule(ModuleType):
    """Module object that imports its submodules and exported objects on first attribute access.

    Adapted (copy-pasted) from the Hugging Face ``transformers`` lazy module.

    Args:
        name: Fully qualified name of the package this object stands in for. Submodules are
            imported relative to it.
        module_file: Path of the package's file; its directory becomes ``__path__``.
        import_structure: Either a flat mapping ``{submodule: names}`` or a backend-grouped
            mapping ``{frozenset(backends): {submodule: names}}``. Backend sets are only used
            for grouping here; their availability is not checked.
        module_spec: Optional spec stored as ``__spec__``.
        extra_objects: Optional ``{name: object}`` mapping returned as-is by attribute lookup.
    """

    def __init__(
        self,
        name: str,
        module_file: str,
        import_structure: IMPORT_STRUCTURE_T,
        module_spec: importlib.machinery.ModuleSpec | None = None,
        extra_objects: dict[str, object] | None = None,
    ):
        super().__init__(name)
        # name -> missing backends; names listed here resolve to a placeholder class.
        # Nothing in this initializer populates it.
        self._object_missing_backend = {}

        grouped = any(isinstance(key, frozenset) for key in import_structure.keys())
        # A flat structure is processed as a single group so both layouts share one code path.
        groups = list(import_structure.values()) if grouped else [import_structure]

        self._modules = set()
        self._class_to_module = {}
        self.__all__ = []
        flat_structure = {}
        for group in groups:
            self._modules.update(group.keys())
            for submodule, names in group.items():
                for obj_name in names:
                    # Later declarations of the same object name win.
                    self._class_to_module[obj_name] = submodule
                if grouped:
                    flat_structure.setdefault(submodule, []).extend(names)
            # Needed for autocompletion in an IDE
            self.__all__.extend(list(group.keys()) + list(chain(*group.values())))

        self.__file__ = module_file
        self.__spec__ = module_spec
        self.__path__ = [os.path.dirname(module_file)]
        self._objects = {} if extra_objects is None else extra_objects
        self._name = name
        # Backend-grouped input is stored flattened ({submodule: [names]}) so that __reduce__
        # can rebuild the module through the flat code path.
        self._import_structure = flat_structure if grouped else import_structure

    def __dir__(self):
        """Return the default module attributes plus every name in ``__all__`` not already listed."""
        result = super().__dir__()
        seen = set(result)
        for attr in self.__all__:
            if attr not in seen:
                result.append(attr)
                seen.add(attr)
        return result

    def __getattr__(self, name: str) -> tp.Any:
        """Resolve ``name`` lazily.

        The lookup order is:

        1. ``extra_objects``, which are returned without caching.
        2. Names with missing backends, which resolve to a placeholder class.
        3. Exported objects, which trigger an import of their submodule.
        4. Submodules themselves.

        Resolved values from steps 2–4 are stored on the module with ``setattr``, so later lookups
        do not reach ``__getattr__``.

        Raises:
            AttributeError: If ``name`` is not known to this module.
            RuntimeError: If importing the owning submodule fails.
        """
        if name in self._objects:
            return self._objects[name]
        if name in self._object_missing_backend:
            missing_backends = self._object_missing_backend[name]

            class Placeholder(metaclass=DummyObject):
                _backends = missing_backends

                def __init__(self, *args, **kwargs):
                    # DummyObject only intercepts class-attribute access; instantiation needs
                    # its own guard.
                    raise ImportError(
                        f"{name} requires the following backends which are not available: "
                        f"{', '.join(map(str, missing_backends))}"
                    )

            Placeholder.__name__ = name
            Placeholder.__qualname__ = name
            # __module__ must be a module name string, not the ModuleSpec.
            Placeholder.__module__ = self.__name__
            value = Placeholder
        elif name in self._class_to_module:
            module = self._get_module(self._class_to_module[name])
            value = getattr(module, name)
        elif name in self._modules:
            value = self._get_module(name)
        else:
            raise AttributeError(f"module {self.__name__} has no attribute {name}")
        setattr(self, name, value)
        return value

    def _get_module(self, module_name: str):
        """Import ``.module_name`` relative to this package.

        Raises:
            RuntimeError: Wrapping (and chained to) any exception raised by the import.
        """
        try:
            return importlib.import_module("." + module_name, self.__name__)
        except Exception as e:
            raise RuntimeError(
                f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
                f" traceback):\n{e}"
            ) from e

    def __reduce__(self):
        """Pickle support: rebuild from name, file and the stored import structure.

        ``module_spec`` and ``extra_objects`` are not included and are lost on unpickling.
        """
        return (self.__class__, (self._name, self.__file__, self._import_structure))
class DummyObject(type):
    """
    Metaclass for the dummy objects. Accessing any public attribute or method of a class that uses
    it (and also ``_from_config``) raises an ImportError naming the missing backends. The backends
    are read from the class attribute ``_backends``.
    """

    def __getattribute__(cls, key):
        # Private and dunder names (except ``_from_config``) resolve normally, so Python's class
        # machinery and the ``_backends`` lookup below keep working.
        if key.startswith("_") and key != "_from_config":
            return super().__getattribute__(key)
        backends = getattr(cls, "_backends", ())
        # Accept a single backend name as well as a collection of them.
        if not isinstance(backends, (list, tuple, set, frozenset)):
            backends = [backends]
        raise ImportError(
            f"{cls.__name__} requires the following backends which are not available: "
            f"{', '.join(sorted(map(str, backends)))}"
        )
def is_package_available(package_name: str) -> bool:
    """
    Checks if a package is available in the current Python environment.

    Dashes in ``package_name`` are replaced with underscores before the lookup, so a name such as
    ``"my-package"`` is checked as ``my_package``. For a dotted name, ``importlib.util.find_spec``
    imports the parent package as part of the lookup.

    Args:
        package_name: The name of the package to check (e.g., "numpy").

    Returns:
        True if the package is available, False otherwise (including when a parent package of a
        dotted name is missing).
    """
    # ``import importlib`` alone does not guarantee the ``importlib.util`` submodule is loaded.
    import importlib.util

    try:
        return importlib.util.find_spec(package_name.replace("-", "_")) is not None
    except ImportError:
        # find_spec raises ModuleNotFoundError when a parent package of a dotted name is missing.
        return False