|
9 | 9 | from dataclasses import dataclass, field, replace
|
10 | 10 | from pathlib import Path
|
11 | 11 | from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
|
12 |
| - Final, List, Literal, Mapping, Optional, Set, Tuple, Type, |
13 |
| - Union) |
| 12 | + Final, List, Literal, Mapping, Optional, Protocol, Set, |
| 13 | + Tuple, Type, Union) |
14 | 14 |
|
15 | 15 | import torch
|
16 | 16 | from pydantic import BaseModel, Field, PrivateAttr
|
|
75 | 75 | PretrainedConfig]]
|
76 | 76 |
|
77 | 77 |
|
| 78 | +class SupportsHash(Protocol): |
| 79 | + |
| 80 | + def compute_hash(self) -> str: |
| 81 | + ... |
| 82 | + |
| 83 | + |
78 | 84 | class ModelConfig:
|
79 | 85 | """Configuration for the model.
|
80 | 86 |
|
@@ -2969,6 +2975,10 @@ class VllmConfig:
|
2969 | 2975 | init=True) # type: ignore
|
2970 | 2976 | kv_transfer_config: KVTransferConfig = field(default=None,
|
2971 | 2977 | init=True) # type: ignore
|
| 2978 | + # some opaque config, only used to provide additional information |
| 2979 | + # for the hash computation, mainly used for testing and debugging. |
| 2980 | + additional_config: SupportsHash = field(default=None, |
| 2981 | + init=True) # type: ignore |
2972 | 2982 | instance_id: str = ""
|
2973 | 2983 |
|
2974 | 2984 | def compute_hash(self) -> str:
|
@@ -3000,33 +3010,62 @@ def compute_hash(self) -> str:
|
3000 | 3010 | vllm_factors.append(__version__)
|
3001 | 3011 | if self.model_config:
|
3002 | 3012 | vllm_factors.append(self.model_config.compute_hash())
|
| 3013 | + else: |
| 3014 | + vllm_factors.append("None") |
3003 | 3015 | if self.cache_config:
|
3004 | 3016 | vllm_factors.append(self.cache_config.compute_hash())
|
| 3017 | + else: |
| 3018 | + vllm_factors.append("None") |
3005 | 3019 | if self.parallel_config:
|
3006 | 3020 | vllm_factors.append(self.parallel_config.compute_hash())
|
| 3021 | + else: |
| 3022 | + vllm_factors.append("None") |
3007 | 3023 | if self.scheduler_config:
|
3008 | 3024 | vllm_factors.append(self.scheduler_config.compute_hash())
|
| 3025 | + else: |
| 3026 | + vllm_factors.append("None") |
3009 | 3027 | if self.device_config:
|
3010 | 3028 | vllm_factors.append(self.device_config.compute_hash())
|
| 3029 | + else: |
| 3030 | + vllm_factors.append("None") |
3011 | 3031 | if self.load_config:
|
3012 | 3032 | vllm_factors.append(self.load_config.compute_hash())
|
| 3033 | + else: |
| 3034 | + vllm_factors.append("None") |
3013 | 3035 | if self.lora_config:
|
3014 | 3036 | vllm_factors.append(self.lora_config.compute_hash())
|
| 3037 | + else: |
| 3038 | + vllm_factors.append("None") |
3015 | 3039 | if self.speculative_config:
|
3016 | 3040 | vllm_factors.append(self.speculative_config.compute_hash())
|
| 3041 | + else: |
| 3042 | + vllm_factors.append("None") |
3017 | 3043 | if self.decoding_config:
|
3018 | 3044 | vllm_factors.append(self.decoding_config.compute_hash())
|
| 3045 | + else: |
| 3046 | + vllm_factors.append("None") |
3019 | 3047 | if self.observability_config:
|
3020 | 3048 | vllm_factors.append(self.observability_config.compute_hash())
|
| 3049 | + else: |
| 3050 | + vllm_factors.append("None") |
3021 | 3051 | if self.prompt_adapter_config:
|
3022 | 3052 | vllm_factors.append(self.prompt_adapter_config.compute_hash())
|
| 3053 | + else: |
| 3054 | + vllm_factors.append("None") |
3023 | 3055 | if self.quant_config:
|
3024 | 3056 | pass # should be captured by model_config.quantization
|
3025 | 3057 | if self.compilation_config:
|
3026 | 3058 | vllm_factors.append(self.compilation_config.compute_hash())
|
| 3059 | + else: |
| 3060 | + vllm_factors.append("None") |
3027 | 3061 | if self.kv_transfer_config:
|
3028 | 3062 | vllm_factors.append(self.kv_transfer_config.compute_hash())
|
3029 |
| - |
| 3063 | + else: |
| 3064 | + vllm_factors.append("None") |
| 3065 | + if self.additional_config: |
| 3066 | + vllm_factors.append(self.additional_config.compute_hash()) |
| 3067 | + else: |
| 3068 | + vllm_factors.append("None") |
3030 | 3069 | factors.append(vllm_factors)
|
3031 | 3070 |
|
3032 | 3071 | hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10]
|
|
0 commit comments