Skip to content

Commit 1557b81

Browse files
YixinSong-esyx
authored and committed
Add solver (ggml-org#4)
* add solver * update solver --------- Co-authored-by: syx <[email protected]>
1 parent 9adba26 commit 1557b81

File tree

1 file changed

+106
-0
lines changed

1 file changed

+106
-0
lines changed

solver.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
#!/usr/bin/env python
2+
# coding=utf-8
3+
import argparse
4+
from cvxopt.glpk import ilp
5+
import numpy as np
6+
from cvxopt import matrix
7+
import torch
8+
import pickle
9+
10+
# Solver script: pick, for every layer, how many groups of `batch` neurons to
# select under a global capacity budget by solving a 0/1 integer linear program
# with GLPK, then pickle the resulting per-layer neuron counts.


def parse_args(argv=None):
    """Parse the solver's command line options.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Optimize neuron activation based on VRAM capacity and other parameters.')
    parser.add_argument('--activation_path', type=str, required=True, help='Path to the directory containing activation data.')
    parser.add_argument('--neuron', type=int, default=8192*4, help='Total number of neurons in the network.')
    parser.add_argument('--capacity', type=int, default=int(8192*4*32*0.1), help='Total VRAM capacity for the model.')
    parser.add_argument('--layer', type=int, default=59, help='Total number of layers in the neural network.')
    parser.add_argument('--batch', type=int, default=32, help='Batch size for processing.')
    parser.add_argument('--threshold', type=int, default=512, help='Threshold for splitting a layer across multiple GPUs.')
    parser.add_argument('--output', type=str, required=True, help='File path for the output pickle file.')
    return parser.parse_args(argv)


def load_objective(activation_path, layer, batch):
    """Build the ILP objective coefficients from the per-layer activation files.

    For each layer ``i`` the tensor stored in ``activation_{i}.pt`` is sorted in
    descending order, negated, and summed in consecutive groups of ``batch``
    entries, so minimizing the objective favours the groups with the largest
    summed values. One trailing ``0.0`` per layer is appended for the per-layer
    indicator variables, which carry no cost.

    Args:
        activation_path: Directory holding ``activation_0.pt`` ... files.
        layer: Number of layers (and therefore files) to read.
        batch: Number of neurons merged into one decision variable.

    Returns:
        A flat list of floats: all group coefficients layer by layer, followed
        by ``layer`` zeros.
    """
    values = []
    for i in range(layer):
        # NOTE(review): torch.load unpickles the file by default, so only
        # activation dumps from a trusted source should be passed in.
        freq = torch.load(f"{activation_path}/activation_{i}.pt")
        freq, _ = torch.sort(freq, descending=True)
        # view(-1, batch) requires the element count to be a multiple of batch.
        grouped = (freq * -1.0).view(-1, batch).sum(dim=1)
        values.extend(grouped.tolist())
    values.extend([0.0] * layer)
    return values


def build_constraints(groups, layer, cap, min_groups, big_m=1000000):
    """Build the rows of the inequality system ``G x <= h``.

    Variables are laid out as ``groups * layer`` group variables (layer-major)
    followed by ``layer`` indicator variables, one per layer.

    Rows, in order:
      1. One row: the total number of selected groups is at most ``cap``.
      2. One row per layer: ``min_groups * indicator <= selected groups``,
         i.e. a layer whose indicator is set selects at least ``min_groups``.
      3. One row per layer: ``selected groups <= big_m * indicator``,
         i.e. a layer can only select groups when its indicator is set.

    Args:
        groups: Number of group variables per layer.
        layer: Number of layers.
        cap: Capacity budget, expressed in groups.
        min_groups: Minimum number of groups for a layer that selects any.
        big_m: Big-M constant for row set 3; must be at least ``groups``.

    Returns:
        ``(coeff, h)`` where ``coeff`` is a list of rows and ``h`` the matching
        list of right-hand sides.
    """
    n_group_vars = groups * layer
    width = n_group_vars + layer

    # Constraint 1: global capacity.
    coeff = [[1] * n_group_vars + [0] * layer]
    h = [cap]

    # Constraint 2: lower bound on a layer's selection once it is enabled.
    for i in range(layer):
        row = [0] * width
        row[i * groups:(i + 1) * groups] = [-1] * groups
        row[n_group_vars + i] = min_groups
        coeff.append(row)
        h.append(0)

    # Constraint 3: big-M link between a layer's groups and its indicator.
    for i in range(layer):
        row = [0] * width
        row[i * groups:(i + 1) * groups] = [1] * groups
        row[n_group_vars + i] = -big_m
        coeff.append(row)
        h.append(0)

    return coeff, h


def per_layer_totals(solution, groups, layer, batch):
    """Convert a solver solution into the number of selected neurons per layer.

    Args:
        solution: Flat sequence of variable values; only the first
            ``groups * layer`` entries (the group variables) are read.
        groups: Number of group variables per layer.
        layer: Number of layers.
        batch: Number of neurons represented by one group variable.

    Returns:
        A list with one entry per layer: selected groups times ``batch``.
    """
    return [
        sum(solution[i * groups:(i + 1) * groups]) * batch
        for i in range(layer)
    ]


def main():
    """Run the solver end to end and write the per-layer result as a pickle.

    Raises:
        ValueError: If the activation data does not match ``--neuron``,
            ``--batch`` and ``--layer``.
        RuntimeError: If GLPK does not return a solution.
    """
    args = parse_args()
    layer = args.layer
    batch = args.batch

    # Capacity, neuron count and threshold are all expressed in groups of
    # `batch` neurons, because one decision variable covers one such group.
    groups = int(args.neuron / batch)
    cap = int(args.capacity / batch)
    min_groups = int(args.threshold / batch)
    n_vars = groups * layer + layer

    values = load_objective(args.activation_path, layer, batch)
    if len(values) != n_vars:
        raise ValueError(
            f"activation data yields {len(values)} objective coefficients, "
            f"but --neuron/--batch/--layer imply {n_vars} variables"
        )
    c = matrix(np.array(values, dtype=float))

    coeff, h = build_constraints(groups, layer, cap, min_groups)
    G = matrix(np.array(coeff, dtype=float))
    h = matrix(np.array(h, dtype=float))

    # cvxopt.glpk.ilp takes (c, G, h, A, b, I, B): I = integer variables,
    # B = binary variables. Every variable here is binary and none is a
    # general integer, which is what the positional call below expresses.
    integer_vars = set()
    binary_vars = set(range(n_vars))

    (status, x) = ilp(c, G, h, None, None, integer_vars, binary_vars)
    print(f"ILP Status: {status}")
    if x is None:
        # Without a solution there is nothing meaningful to serialize.
        raise RuntimeError(f"ILP solver returned no solution (status: {status})")
    ans = list(x)
    print(f"Total Activation Units: {sum(ans)}")

    aligned_lst = per_layer_totals(ans, groups, layer, batch)

    # Save the solution to a pickle file
    with open(args.output, 'wb') as handle:
        pickle.dump(aligned_lst, handle)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)