diff --git a/src/peakrdl_regblock/__peakrdl__.py b/src/peakrdl_regblock/__peakrdl__.py index ad69375..eaf6543 100644 --- a/src/peakrdl_regblock/__peakrdl__.py +++ b/src/peakrdl_regblock/__peakrdl__.py @@ -21,7 +21,7 @@ CPUIF_DICT = { } # Load any user-plugins -for ep in entry_points.get_entry_points("peakrdl_regblock.cpuif"): # type: ignore +for ep, dist in entry_points.get_entry_points("peakrdl_regblock.cpuif"): name = ep.name cpuif = ep.load() if name in CPUIF_DICT: diff --git a/src/peakrdl_regblock/entry_points.py b/src/peakrdl_regblock/entry_points.py index 48cdb14..dd8e2d7 100644 --- a/src/peakrdl_regblock/entry_points.py +++ b/src/peakrdl_regblock/entry_points.py @@ -1,19 +1,50 @@ -# type: ignore import sys +from typing import List, Tuple, TYPE_CHECKING + +if TYPE_CHECKING: + from importlib.metadata import EntryPoint, Distribution if sys.version_info >= (3,10,0): from importlib import metadata - def get_entry_points(group_name): - return metadata.entry_points().select(group=group_name) + def _get_entry_points(group_name: str) -> List[Tuple['EntryPoint', 'Distribution']]: + eps = [] + for ep in metadata.entry_points().select(group=group_name): + eps.append((ep, ep.dist)) + return eps + + def _get_name_from_dist(dist: 'Distribution') -> str: + return dist.name elif sys.version_info >= (3,8,0): from importlib import metadata - def get_entry_points(group_name): - return metadata.entry_points().get(group_name, tuple()) + def _get_entry_points(group_name: str) -> List[Tuple['EntryPoint', 'Distribution']]: + eps = [] + for dist in metadata.distributions(): + for ep in dist.entry_points: + if ep.group == group_name: + eps.append((ep, dist)) + return eps + + def _get_name_from_dist(dist: 'Distribution') -> str: + return dist.metadata["Name"] else: - import pkg_resources - def get_entry_points(group_name): - return pkg_resources.iter_entry_points(group_name) + import pkg_resources # type: ignore + + def _get_entry_points(group_name: str) -> List[Tuple['EntryPoint', 'Distribution']]: + eps = [] + for ep in pkg_resources.iter_entry_points(group_name): + eps.append((ep, ep.dist)) + return eps + + def _get_name_from_dist(dist: 'Distribution') -> str: + return dist.project_name # type: ignore + + +def get_entry_points(group_name: str) -> List[Tuple['EntryPoint', 'Distribution']]: + return _get_entry_points(group_name) + +def get_name_from_dist(dist: 'Distribution') -> str: + return _get_name_from_dist(dist)