Skip to content

Commit b229455

Browse files
georgiaphillipspytorchmergebot
authored andcommitted
Update placement utils and weights to handle meta device (pytorch#162842)
Summary: This diff fixes two things which come up when testing a tgif-published pt2 model remote net: 1) Updates isSameDevice to handle meta device to avoid this error: ``` what(): Unsupported device typemeta and meta Exception raised from isSameDevice at fbcode/caffe2/torch/nativert/executor/PlacementUtils.cpp:20 ``` 2. Updates xl weight v2 loading logic in Weights.cpp to handle non-TBE xl-weights. Today, we enforce the device is the same for an old weight and new weight when replacing with ModelRunnerAdapter.setAttr(). However, the way we replace non-TBE xl weights is to find any weights on "meta" device and then replace them with their correct weight with real device from xl_weights folder. Therefore, the new weight and old weight will always have different devices and the device check is invalid. I don't think we've run into this so far bc non-TBE xl weights have not been thoroughly tested until now. Test Plan: Run MRS you model merge net, which uses non-TBE xl weights. Confirm that before change #1 we get error: ``` Unsupported device typemeta and meta ``` Then after change #1 and before change #2 we get: ``` what(): Mismatched device for merge.user_tower.linear.weight: meta vs cpu Exception raised from validateValue at fbcode/caffe2/torch/nativert/executor/Weights.cpp:374 ``` After change run is successful Command: ``` MODEL_ENTITY_ID=921242082 SNAPSHOT_ID=1269 module_name=merge SAMPLE_INPUT_DIR=/data/users/georgiaphillips/models/921242082/${SNAPSHOT_ID}/${module_name}_archive/package/data/sample_inputs buck2 run mode/dev-nosan -c fbcode.nvcc_arch=h100,a100 -c fbcode.enable_gpu_sections=true caffe2/torch/fb/model_transform/fx2trt/packaging:load_net_predictor -- --loadMode=Benchmark --inputNetFile=/data/users/$USER/models/${MODEL_ENTITY_ID}/${SNAPSHOT_ID}/${MODEL_ENTITY_ID}_${SNAPSHOT_ID}.predictor.${module_name} --moduleName=${module_name} --submodToDevice="merge|cuda0" --benchmarkEnableProfiling=false --disableStaticRuntime=true --doNotRandomizeSampleInputs=true --benchmarkDontRebatchSamples=true --pytorch_predictor_sigmoid_static_dispatch_enable=false --pytorch_predictor_sigmoid_graph_passes_enable=false --sampleInputFilePath=${SAMPLE_INPUT_DIR}/${module_name}.pt ``` Rollback Plan: Differential Revision: D80713052 Pull Request resolved: pytorch#162842 Approved by: https://github.com/henryoier
1 parent a541974 commit b229455

File tree

3 files changed

+39
-12
lines changed

3 files changed

+39
-12
lines changed

torch/nativert/executor/PlacementUtils.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ bool isSameDevice(const c10::Device& a, const c10::Device& b) {
1717
return false;
1818
}
1919
}
20+
if (a.is_meta()) {
21+
return b.is_meta();
22+
}
2023
TORCH_CHECK(false, "Unsupported device type", a, " and ", b);
2124
return false;
2225
}

torch/nativert/executor/Weights.cpp

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,13 @@ void Weights::loadStateDict(
337337

338338
void Weights::validateValue(const std::string& name, const at::Tensor& newValue)
339339
const {
340+
validateValue(name, newValue, /*skipDeviceCheck=*/false);
341+
}
342+
343+
void Weights::validateValue(
344+
const std::string& name,
345+
const at::Tensor& newValue,
346+
bool skipDeviceCheck) const {
340347
auto& weightMeta = weightsMeta_.at(name);
341348

342349
TORCH_CHECK(
@@ -360,23 +367,32 @@ void Weights::validateValue(const std::string& name, const at::Tensor& newValue)
360367
" vs ",
361368
newValue.dtype());
362369

363-
auto targetDevice = weightMeta.device();
364-
if (targetDevice.is_cpu() && targetDevice.has_index()) {
365-
LOG(WARNING) << "Target device is cpu but has index: " << targetDevice;
370+
if (!skipDeviceCheck) {
371+
auto targetDevice = weightMeta.device();
372+
if (targetDevice.is_cpu() && targetDevice.has_index()) {
373+
LOG(WARNING) << "Target device is cpu but has index: " << targetDevice;
374+
}
375+
TORCH_CHECK(
376+
isSameDevice(targetDevice, newValue.device()),
377+
"Mismatched device for ",
378+
name,
379+
": ",
380+
targetDevice,
381+
" vs ",
382+
newValue.device());
366383
}
367-
TORCH_CHECK(
368-
isSameDevice(targetDevice, newValue.device()),
369-
"Mismatched device for ",
370-
name,
371-
": ",
372-
targetDevice,
373-
" vs ",
374-
newValue.device());
375384
}
376385

377386
void Weights::setValue(const std::string& name, const at::Tensor& newValue) {
387+
setValue(name, newValue, /*skipDeviceCheck=*/false);
388+
}
389+
390+
void Weights::setValue(
391+
const std::string& name,
392+
const at::Tensor& newValue,
393+
bool skipDeviceCheck) {
378394
if (allValues_.find(name) != allValues_.end()) {
379-
validateValue(name, newValue);
395+
validateValue(name, newValue, skipDeviceCheck);
380396
} else {
381397
LOG(WARNING) << name << " is not found in the registered weights";
382398
}

torch/nativert/executor/Weights.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ class Weights {
6666
* Replace the value stored at the weight with name "name".
6767
*/
6868
void setValue(const std::string& name, const at::Tensor& newValue);
69+
void setValue(
70+
const std::string& name,
71+
const at::Tensor& newValue,
72+
bool skipDeviceCheck);
6973

7074
/*
7175
* Update the value stored at the weight with name "name".
@@ -77,6 +81,10 @@ class Weights {
7781
const std::unordered_map<std::string, at::Tensor>& newValues);
7882

7983
void validateValue(const std::string& name, const at::Tensor& newValue) const;
84+
void validateValue(
85+
const std::string& name,
86+
const at::Tensor& newValue,
87+
bool skipDeviceCheck) const;
8088

8189
void validateAllWeightsLoaded();
8290

0 commit comments

Comments
 (0)