Commit b67407a

Basic SAM2 AutomaticMaskGeneration example server (#1039)
1 parent 745085f commit b67407a

5 files changed: +454 -0 lines changed

examples/sam2_amg_server/README.md

Lines changed: 5 additions & 0 deletions
To run this example you need to download the vit_h checkpoint and put it into a local folder named checkpoints.

You can find the checkpoint for vit_h here: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

To read the image you also need to install opencv-python: https://pypi.org/project/opencv-python/
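The download step can also be scripted. The following is a minimal sketch using only the Python standard library; the target folder and file name are assumptions taken from the instructions above, not part of this commit:

import os
import urllib.request

CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
CHECKPOINT_PATH = os.path.join("checkpoints", "sam_vit_h_4b8939.pth")

os.makedirs("checkpoints", exist_ok=True)
if not os.path.exists(CHECKPOINT_PATH):
    # Multi-gigabyte download; urlretrieve streams it to disk.
    urllib.request.urlretrieve(CHECKPOINT_URL, CHECKPOINT_PATH)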
examples/sam2_amg_server/amg_example.py

Lines changed: 133 additions & 0 deletions
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import torch.utils.benchmark as benchmark

# Inductor tuning knobs: give Triton kernels descriptive names and enable
# coordinate descent autotuning in all directions.
from torch._inductor import config as inductorconfig
inductorconfig.triton.unique_kernel_names = True
inductorconfig.coordinate_descent_tuning = True
inductorconfig.coordinate_descent_check_all_directions = True

# Run fn under the PyTorch profiler and export a Chrome trace to path.
def profiler_runner(path, fn, *args, **kwargs):
    with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA],
            record_shapes=True) as prof:
        result = fn(*args, **kwargs)
    print(f"Saving trace under {path}")
    prof.export_chrome_trace(path)
    return result

# Overlay the generated masks on the current axes and return them
# stacked into a single tensor.
def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:, :, 3] = 0
    ms = []
    for ann in sorted_anns:
        m = ann['segmentation']
        ms.append(torch.as_tensor(m))
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)
    return torch.stack(ms)

image = cv2.imread('dog.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


# from segment_anything_fast import sam_model_registry, sam_model_fast_registry, SamAutomaticMaskGenerator
#
# sam_checkpoint = "checkpoints/sam_vit_h_4b8939.pth"
# model_type = "vit_h"
device = "cuda"
#
# sam = sam_model_fast_registry[model_type](checkpoint=sam_checkpoint)
# sam.to(device=device)

from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

sam2_checkpoint = "checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

sam2 = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)
sam2.to(device=device)

# mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=256)
mask_generator = SAM2AutomaticMaskGenerator(sam2, points_per_batch=None)

## NOTE: Causes numerical differences
## TODO: Implement mIoU to allow approximations.
# torch.set_float32_matmul_precision('high')
# torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
##

## TODO: Using CUDA graphs can cause numerical differences?
mask_generator.predictor.model.image_encoder = torch.compile(
    mask_generator.predictor.model.image_encoder,
    # mode="max-autotune-no-cudagraphs",
    mode="max-autotune",
    fullgraph=True,
    dynamic=False,
)

# mask_generator.predictor._predict = torch.compile(
#     mask_generator.predictor._predict,
#     mode="max-autotune-no-cudagraphs",
#     fullgraph=True,
#     dynamic=False,
# )

# _process_batch contains ops with data-dependent output shapes, which
# dynamo must be allowed to capture into the graph.
torch._dynamo.config.capture_dynamic_output_shape_ops = True
mask_generator._process_batch = torch.compile(
    mask_generator._process_batch,
    mode="max-autotune-no-cudagraphs",
    fullgraph=True,
    dynamic=True,
)

# with torch.backends.cuda.sdp_kernel(enable_cudnn=False): #, enable_math=False, enable_mem_efficient=False):
with torch.backends.cuda.sdp_kernel(enable_cudnn=True): #, enable_math=False, enable_mem_efficient=False):
    # Run thrice for warmup
    masks = mask_generator.generate(image)
    masks = mask_generator.generate(image)
    masks = mask_generator.generate(image)

    # Save an example
    plt.figure(figsize=(image.shape[1]/100., image.shape[0]/100.), dpi=100)
    plt.imshow(image)
    ms = show_anns(masks)
    ms_ref = torch.load("dog_mask_fast.pt")
    torch.testing.assert_allclose(ms, ms_ref)
    print("Masks match reference")
    # torch.save(ms, "dog_mask_fast.pt")
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('dog_mask_fast.png', format='png')

    # Benchmark
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(10):
        masks = mask_generator.generate(image)
    end_event.record()
    torch.cuda.synchronize()
    print(start_event.elapsed_time(end_event) / 10.)

    # Save a GPU trace
    profiler_runner("amg_example_trace.json.gz", mask_generator.generate, image)

    # Write out memory usage
    max_memory_allocated_bytes = torch.cuda.max_memory_allocated()
    _, total_memory = torch.cuda.mem_get_info()
    max_memory_allocated_percentage = int(100 * (max_memory_allocated_bytes / total_memory))
    max_memory_allocated_bytes = max_memory_allocated_bytes >> 20
    print(f"memory(MiB): {max_memory_allocated_bytes} memory(%): {max_memory_allocated_percentage}")

examples/sam2_amg_server/example.html

Lines changed: 57 additions & 0 deletions
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Upload and Display Image from FastAPI Response</title>
    <style>
        #preview {
            margin-top: 20px;
            max-width: 100%;
            max-height: 400px;
            display: none;
        }
    </style>
</head>
<body>
    <h1>Upload an Image and Display the Response</h1>
    <form id="uploadForm">
        <label for="image">Choose an image to upload:</label><br>
        <input type="file" id="image" name="image" accept="image/*" required><br><br>
        <input type="submit" value="Upload Image">
    </form>

    <h2>Received Image Preview:</h2>
    <img id="preview" alt="Received Image">

    <script>
        document.getElementById('uploadForm').addEventListener('submit', function (e) {
            e.preventDefault();

            const formData = new FormData();
            const fileInput = document.getElementById('image');
            const file = fileInput.files[0];

            if (file) {
                formData.append('image', file);

                // Perform the image upload via Fetch API
                fetch('http://127.0.0.1:5000/upload', {
                    method: 'POST',
                    body: formData
                })
                .then(response => response.blob()) // Get the image as a Blob from the response
                .then(imageBlob => {
                    const imageObjectURL = URL.createObjectURL(imageBlob);
                    const preview = document.getElementById('preview');
                    preview.src = imageObjectURL;
                    preview.style.display = 'block';
                })
                .catch(error => {
                    console.error('Error:', error);
                });
            }
        });
    </script>
</body>
</html>
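example.html posts the chosen file as a multipart field named image to http://127.0.0.1:5000/upload and renders whatever image bytes come back. The commit's actual server file is not shown in this diff; the snippet below is only a minimal FastAPI sketch of an endpoint shaped to match that fetch() call, with the mask-generation step stubbed out:

import io

import uvicorn
from fastapi import FastAPI, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse

app = FastAPI()
# The page is typically opened from disk, so allow cross-origin requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.post("/upload")
async def upload(image: UploadFile):
    data = await image.read()
    # A real server would run mask_generator.generate() on the decoded
    # image and return the annotated result; echoing the upload back
    # keeps this sketch minimal.
    return StreamingResponse(io.BytesIO(data), media_type="image/png")

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=5000)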
examples/sam2_amg_server/sam2_hiera_l.yaml

Lines changed: 117 additions & 0 deletions
# @package _global_

# Model
model:
  _target_: sam2.modeling.sam2_base.SAM2Base
  image_encoder:
    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
    scalp: 1
    trunk:
      _target_: sam2.modeling.backbones.hieradet.Hiera
      embed_dim: 144
      num_heads: 2
      stages: [2, 6, 36, 4]
      global_att_blocks: [23, 33, 43]
      window_pos_embed_bkg_spatial_size: [7, 7]
      window_spec: [8, 4, 16, 8]
    neck:
      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
      position_encoding:
        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
        num_pos_feats: 256
        normalize: true
        scale: null
        temperature: 10000
      d_model: 256
      backbone_channel_list: [1152, 576, 288, 144]
      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
      fpn_interp_model: nearest

  memory_attention:
    _target_: sam2.modeling.memory_attention.MemoryAttention
    d_model: 256
    pos_enc_at_input: true
    layer:
      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
      activation: relu
      dim_feedforward: 2048
      dropout: 0.1
      pos_enc_at_attn: false
      self_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
      d_model: 256
      pos_enc_at_cross_attn_keys: true
      pos_enc_at_cross_attn_queries: false
      cross_attention:
        _target_: sam2.modeling.sam.transformer.RoPEAttention
        rope_theta: 10000.0
        feat_sizes: [32, 32]
        rope_k_repeat: True
        embedding_dim: 256
        num_heads: 1
        downsample_rate: 1
        dropout: 0.1
        kv_in_dim: 64
    num_layers: 4

  memory_encoder:
    _target_: sam2.modeling.memory_encoder.MemoryEncoder
    out_dim: 64
    position_encoding:
      _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
      num_pos_feats: 64
      normalize: true
      scale: null
      temperature: 10000
    mask_downsampler:
      _target_: sam2.modeling.memory_encoder.MaskDownSampler
      kernel_size: 3
      stride: 2
      padding: 1
    fuser:
      _target_: sam2.modeling.memory_encoder.Fuser
      layer:
        _target_: sam2.modeling.memory_encoder.CXBlock
        dim: 256
        kernel_size: 7
        padding: 3
        layer_scale_init_value: 1e-6
        use_dwconv: True  # depth-wise convs
      num_layers: 2

  num_maskmem: 7
  image_size: 1024
  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
  sigmoid_scale_for_mem_enc: 20.0
  sigmoid_bias_for_mem_enc: -10.0
  use_mask_input_as_output_without_sam: true
  # Memory
  directly_add_no_mem_embed: true
  # use high-resolution feature map in the SAM mask decoder
  use_high_res_features_in_sam: true
  # output 3 masks on the first click on initial conditioning frames
  multimask_output_in_sam: true
  # SAM heads
  iou_prediction_use_sigmoid: True
  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
  use_obj_ptrs_in_encoder: true
  add_tpos_enc_to_obj_ptrs: false
  only_obj_ptrs_in_the_past_for_eval: true
  # object occlusion prediction
  pred_obj_scores: true
  pred_obj_scores_mlp: true
  fixed_no_obj_ptr: true
  # multimask tracking settings
  multimask_output_for_tracking: true
  use_multimask_token_for_obj_ptr: true
  multimask_min_pt_num: 0
  multimask_max_pt_num: 1
  use_mlp_for_obj_ptr_proj: true
  # Compilation flag
  compile_image_encoder: False
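This config follows Hydra's _target_ convention: each _target_ names a class to import, and its sibling keys become that class's constructor arguments, recursively. build_sam2 in the example above resolves the config (and loads the checkpoint) internally; the direct Hydra calls below are a sketch for illustration only, assuming the yaml sits in the working directory:

from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(version_base=None, config_path="."):
    cfg = compose(config_name="sam2_hiera_l")

# Builds the SAM2Base module tree described under `model:` above.
sam2_model = instantiate(cfg.model)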
