|  | 
| 1 |  | -from typing import TYPE_CHECKING, Dict, Optional, Union | 
|  | 1 | +from typing import TYPE_CHECKING, Dict, List, Optional, Union | 
| 2 | 2 | 
 | 
| 3 | 3 | import aiohttp | 
| 4 | 4 | from aiohttp.client_exceptions import ClientResponseError | 
| 5 | 5 | from aleph_message.models import ItemHash | 
|  | 6 | +from pydantic import BaseModel | 
| 6 | 7 | 
 | 
| 7 | 8 | from aleph.sdk.conf import settings | 
| 8 | 9 | from aleph.sdk.exceptions import MethodNotAvailableOnCRN, VmNotFoundOnHost | 
|  | 
| 13 | 14 |     from aleph.sdk.client.http import AlephHttpClient | 
| 14 | 15 | 
 | 
| 15 | 16 | 
 | 
|  | 17 | +class GPU(BaseModel): | 
|  | 18 | +    vendor: str | 
|  | 19 | +    model: str | 
|  | 20 | +    device_name: str | 
|  | 21 | +    device_class: str | 
|  | 22 | +    pci_host: str | 
|  | 23 | +    compatible: bool | 
|  | 24 | + | 
|  | 25 | + | 
|  | 26 | +class NetworkGPUS(BaseModel): | 
|  | 27 | +    total_gpu_count: int | 
|  | 28 | +    available_gpu_count: int | 
|  | 29 | +    available_gpu_list: dict[str, List[GPU]]  # str = node_url | 
|  | 30 | +    used_gpu_list: dict[str, List[GPU]]  # str = node_url | 
|  | 31 | + | 
|  | 32 | + | 
| 16 | 33 | class Crn: | 
| 17 | 34 |     """ | 
| 18 | 35 |     This services allow interact with CRNS API | 
| @@ -136,3 +153,48 @@ async def update_instance_config(self, crn_address: str, item_hash: ItemHash): | 
| 136 | 153 |             async with session.post(full_url) as resp: | 
| 137 | 154 |                 resp.raise_for_status() | 
| 138 | 155 |                 return await resp.json() | 
|  | 156 | + | 
|  | 157 | +    # Gpu Functions Helper | 
|  | 158 | +    async def fetch_gpu_on_network( | 
|  | 159 | +        self, | 
|  | 160 | +        crn_list: Optional[List[dict]] = None, | 
|  | 161 | +    ) -> NetworkGPUS: | 
|  | 162 | +        if not crn_list: | 
|  | 163 | +            crn_list = (await self._client.crn.get_crns_list()).get("crns", []) | 
|  | 164 | + | 
|  | 165 | +        gpu_count: int = 0 | 
|  | 166 | +        available_gpu_count: int = 0 | 
|  | 167 | + | 
|  | 168 | +        compatible_gpu: Dict[str, List[GPU]] = {} | 
|  | 169 | +        available_compatible_gpu: Dict[str, List[GPU]] = {} | 
|  | 170 | + | 
|  | 171 | +        # Ensure crn_list is a list before iterating | 
|  | 172 | +        if not isinstance(crn_list, list): | 
|  | 173 | +            crn_list = [] | 
|  | 174 | + | 
|  | 175 | +        for crn_ in crn_list: | 
|  | 176 | +            if not crn_.get("gpu_support", False): | 
|  | 177 | +                continue | 
|  | 178 | + | 
|  | 179 | +            # Only process CRNs with GPU support | 
|  | 180 | +            crn_address = crn_["address"] | 
|  | 181 | + | 
|  | 182 | +            # Extracts used GPU | 
|  | 183 | +            for gpu in crn_.get("compatible_gpus", []): | 
|  | 184 | +                compatible_gpu[crn_address] = [] | 
|  | 185 | +                compatible_gpu[crn_address].append(GPU.model_validate(gpu)) | 
|  | 186 | +                gpu_count += 1 | 
|  | 187 | + | 
|  | 188 | +            # Extracts available GPU | 
|  | 189 | +            for gpu in crn_.get("compatible_available_gpus", []): | 
|  | 190 | +                available_compatible_gpu[crn_address] = [] | 
|  | 191 | +                available_compatible_gpu[crn_address].append(GPU.model_validate(gpu)) | 
|  | 192 | +                gpu_count += 1 | 
|  | 193 | +                available_gpu_count += 1 | 
|  | 194 | + | 
|  | 195 | +        return NetworkGPUS( | 
|  | 196 | +            total_gpu_count=gpu_count, | 
|  | 197 | +            available_gpu_count=available_gpu_count, | 
|  | 198 | +            used_gpu_list=compatible_gpu, | 
|  | 199 | +            available_gpu_list=available_compatible_gpu, | 
|  | 200 | +        ) | 
0 commit comments