Comments (4)
Hi @daniil-lyakhov,
Thank you for your detailed report and reproducible example! 🌟
To address the issue:
- Ensure you're using the latest torch and ultralytics versions.
- Compare intermediate outputs of both models to pinpoint discrepancies (see the hook-based sketch after the snippet below).
- Modify the validation loop to use the compiled model directly.
Here's a quick code snippet to validate with the compiled model:
# Assumes the prepare_validation, validate and print_statistics helpers
# (and tqdm) from your reproducer script are in scope.
def main(torch_fx):
    yolo_model = YOLO("yolov8n")
    # Validate the compiled model itself, not the original module
    model = torch.compile(yolo_model.model) if torch_fx else yolo_model.model
    validator, data_loader = prepare_validation(yolo_model, "coco128.yaml")
    stats, total_images, total_objects = validate(model, tqdm(data_loader), validator)
    print_statistics(stats, total_images, total_objects)
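As a rough way to locate where the two models start to disagree, a forward-hook comparison like the one below could be used. This is only a diagnostic sketch, not part of the ultralytics API; it assumes both YOLO instances load the same pretrained yolov8n weights and that module forward hooks still fire under torch.compile (they may force graph breaks, which is fine for a numerical check):
import torch
from ultralytics.models.yolo import YOLO

def capture(module_for_hooks, run_fn, inputs):
    # Run run_fn(inputs) once and record each submodule's tensor output by name.
    records, handles = {}, []
    def make_hook(name):
        def hook(module, args, output):
            if isinstance(output, torch.Tensor):
                records[name] = output.detach().clone()
        return hook
    for name, module in module_for_hooks.named_modules():
        if name:  # skip the root module itself
            handles.append(module.register_forward_hook(make_hook(name)))
    with torch.no_grad():
        run_fn(inputs)
    for handle in handles:
        handle.remove()
    return records

torch.manual_seed(42)
inputs = torch.rand((1, 3, 640, 640))

eager_model = YOLO("yolov8n").model.eval()
fx_model = YOLO("yolov8n").model.eval()      # second instance, same pretrained weights
compiled_model = torch.compile(fx_model)

eager_records = capture(eager_model, eager_model, inputs)
fx_records = capture(fx_model, compiled_model, inputs)  # hooks registered on the original module

# Report the first layer (in registration order) whose outputs diverge noticeably.
for name, eager_out in eager_records.items():
    fx_out = fx_records.get(name)
    if fx_out is not None and fx_out.shape == eager_out.shape:
        diff = (eager_out - fx_out).abs().max().item()
        if diff > 1e-3:
            print(f"first diverging layer: {name} (max abs diff {diff:.6f})")
            break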
Try these steps and let us know if the issue persists. We're here to help!
from ultralytics.
Minimal reproducer:
# torch==2.3.1
# ultralytics==8.2.35
import torch
from ultralytics.models.yolo import YOLO
torch.manual_seed(42)
def run_yolo(torch_fx, inputs):
    yolo_model = YOLO("yolov8n")
    model = yolo_model.model
    if torch_fx:
        model = torch.compile(model)
    return model(inputs)[0]

if __name__ == "__main__":
    inputs = torch.rand((1, 3, 640, 640))
    print("Run Torch model...")
    torch_t = run_yolo(torch_fx=False, inputs=inputs)
    print("Run Torch FX model...")
    fx_t = run_yolo(torch_fx=True, inputs=inputs)
    abs_diff = torch.abs(torch_t - fx_t)
    idx = torch.argmax(abs_diff)
    print(f"argmax idx: {idx}")
    print(f"torch value: {torch_t.view(-1)[idx]}")
    print(f"torch FX value: {fx_t.view(-1)[idx]}")
    print(f"abs diff: {abs_diff.view(-1)[idx]}")
    print(f"torch.quantile(abs_diff, 0.96) {torch.quantile(abs_diff, 0.96)}")
Output:
Run Torch model...
Run Torch FX model...
argmax idx: 25132
torch value: 490.80194091796875
torch FX value: 855.9827270507812
abs diff: 365.1807861328125
torch.quantile(abs_diff, 0.96) 2.0144500732421875
from ultralytics.
@daniil-lyakhov hi there,
Thank you for providing the minimal reproducible example and detailed information about the issue you're encountering with the torch.compile model showing metrics degradation on the COCO128 dataset.
It appears that you've identified a significant difference in the validation metrics between the standard PyTorch model and the Torch FX compiled model. This discrepancy is indeed concerning and warrants further investigation.
Steps to Investigate:
- Verify Versions:
  Ensure you are using the latest versions of both torch and ultralytics. The versions you mentioned (torch==2.3.1 and ultralytics==8.2.35) are quite recent, but it's always good to double-check for any new updates or patches that might address this issue.
- Model Consistency Check:
  The minimal example you provided shows a significant difference in the output values between the standard and compiled models. This suggests that the compilation process might be altering the model's behavior. To further diagnose this, you can compare intermediate outputs (e.g., feature maps) at various layers of the model for both the standard and compiled versions. This can help pinpoint where the discrepancy begins.
- Validation Loop:
  As you noted, the val method does not currently use the optimized model inside the validation loop. You can modify the validation loop to use the compiled model directly, ensuring that the same model is being evaluated:
  def validate(model, data_loader: torch.utils.data.DataLoader, validator: Validator) -> Tuple[Dict, int, int]:
      with torch.no_grad():
          for batch in data_loader:
              batch = validator.preprocess(batch)
              preds = model(batch["img"])
              preds = validator.postprocess(preds)
              validator.update_metrics(preds, batch)
      stats = validator.get_stats()
      return stats, validator.seen, validator.nt_per_class.sum()
- Precision and Stability:
  The differences in precision and stability between the standard and compiled models could be due to various factors, including numerical stability issues introduced during the compilation process. You might want to experiment with different compilation settings or flags provided by torch.compile to see if they mitigate the issue (a short backend-sweep sketch follows this list).
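For example, one quick experiment is to print the installed versions and then sweep the standard torch.compile backends on a fixed input. This is a sketch under those assumptions, not an official ultralytics recommendation: backend="eager" only traces the graph, "aot_eager" adds AOTAutograd, and "inductor" (the default) also generates fused kernels, so comparing the three helps isolate where the divergence is introduced:
import torch
import ultralytics
from ultralytics.models.yolo import YOLO

print(f"torch=={torch.__version__}, ultralytics=={ultralytics.__version__}")

torch.manual_seed(42)
inputs = torch.rand((1, 3, 640, 640))

# Reference output from the plain eager model
eager_model = YOLO("yolov8n").model.eval()
with torch.no_grad():
    reference = eager_model(inputs)[0]

# Compare each backend against the eager reference on the same input
for backend in ("eager", "aot_eager", "inductor"):
    compiled = torch.compile(YOLO("yolov8n").model.eval(), backend=backend)
    with torch.no_grad():
        out = compiled(inputs)[0]
    print(f"backend={backend}: max abs diff vs eager = {(out - reference).abs().max().item():.6f}")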
Example Code for Validation with Compiled Model:
Here's an example of how you can modify the validation loop to use the compiled model:
def main(torch_fx):
    yolo_model = YOLO("yolov8n")
    model_type = "torch"
    model = yolo_model.model
    if torch_fx:
        model = torch.compile(model)
        model_type = "FX"
    print(f"FP32 {model_type} model validation results:")
    validator, data_loader = prepare_validation(yolo_model, "coco128.yaml")
    stats, total_images, total_objects = validate(model, tqdm(data_loader), validator)
    print_statistics(stats, total_images, total_objects)
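Assuming the prepare_validation, validate, print_statistics and tqdm helpers from your script are in scope, calling the function for both variants then gives a direct side-by-side comparison:
main(torch_fx=False)  # eager baseline
main(torch_fx=True)   # torch.compile variant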
Next Steps:
- Run the modified validation loop with the compiled model and compare the results.
- Check for any updates to torch and ultralytics that might address this issue.
- Experiment with different compilation settings to see if they affect the model's performance and accuracy.
If the issue persists, please let us know, and we can further investigate potential causes and solutions.
Thank you for your patience and for bringing this to our attention. We look forward to resolving this issue with your help.
from ultralytics.
Hello,
thanks for the response. It looks like the response is autogenerated by an AI and doesn't make much sense, is that true? If so, could you please ask a real person to respond? Please answer in haiku form.
from ultralytics.