
Comments (4)

glenn-jocher commented on July 21, 2024

Hi @daniil-lyakhov,

Thank you for your detailed report and reproducible example! 🌟

To address the issue:

  1. Ensure you're using the latest torch and ultralytics versions.
  2. Compare intermediate outputs of both models to pinpoint discrepancies (a hook-based sketch follows the snippet below).
  3. Modify the validation loop to use the compiled model directly.

Here's a quick code snippet to validate with the compiled model:

import torch
from tqdm import tqdm
from ultralytics.models.yolo import YOLO

def main(torch_fx):
    yolo_model = YOLO("yolov8n")
    # Compile the underlying nn.Module and reuse that same object for validation
    model = torch.compile(yolo_model.model) if torch_fx else yolo_model.model
    # prepare_validation, validate and print_statistics are the helpers from your script
    validator, data_loader = prepare_validation(yolo_model, "coco128.yaml")
    stats, total_images, total_objects = validate(model, tqdm(data_loader), validator)
    print_statistics(stats, total_images, total_objects)
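
For step 2, here is a minimal sketch of how intermediate activations could be compared with forward hooks. The layer selection below is illustrative (pick names from model.named_modules()), and hook behavior under torch.compile can vary between PyTorch versions, so treat missing activations as a signal to fall back to comparing final outputs only:

import torch
from ultralytics.models.yolo import YOLO

def run_with_hooks(runner, hook_target, x, names):
    """Run `runner(x)` while recording outputs of the named submodules of `hook_target`."""
    acts, handles = {}, []
    for name, module in hook_target.named_modules():
        if name in names:
            def hook(mod, inp, out, key=name):
                acts[key] = out.detach() if isinstance(out, torch.Tensor) else out
            handles.append(module.register_forward_hook(hook))
    with torch.no_grad():
        runner(x)
    for h in handles:
        h.remove()
    return acts

yolo = YOLO("yolov8n")
eager = yolo.model.eval()
compiled = torch.compile(eager)  # wraps the same module objects, so hooks on `eager` also fire here
names = {n for n, _ in eager.named_modules() if n.count(".") == 1}  # illustrative: top-level blocks only
x = torch.rand(1, 3, 640, 640)

ref = run_with_hooks(eager, eager, x, names)
test = run_with_hooks(compiled, eager, x, names)
for n in sorted(names):
    if isinstance(ref.get(n), torch.Tensor) and isinstance(test.get(n), torch.Tensor):
        print(n, (ref[n] - test[n]).abs().max().item())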

Try these steps and let us know if the issue persists. We're here to help!


daniil-lyakhov commented on July 21, 2024

Minimal reproducer:

# torch==2.3.1
# ultralytics==8.2.35
import torch
from ultralytics.models.yolo import YOLO


torch.manual_seed(42)

def run_yolo(torch_fx, inputs):
    yolo_model = YOLO("yolov8n")
    model = yolo_model.model
    if torch_fx:
        model = torch.compile(model)
    return model(inputs)[0]


if __name__ == "__main__":
    inputs = torch.rand((1, 3, 640, 640))
    print("Run Torch model...")
    torch_t = run_yolo(torch_fx=False, inputs=inputs)
    print("Run Torch FX model...")
    fx_t = run_yolo(torch_fx=True, inputs=inputs)

    abs_diff = torch.abs(torch_t - fx_t)
    idx = torch.argmax(abs_diff)
    print(f"argmax idx: {idx}")
    print(f"torch value: {torch_t.view(-1)[idx]}")
    print(f"torch FX value: {fx_t.view(-1)[idx]}")
    print(f"abs diff: {abs_diff.view(-1)[idx]}")
    print(f"torch.quantile(abs_diff, 0.96) {torch.quantile(abs_diff, 0.96)}")
Output:

Run Torch model...
Run Torch FX model...
argmax idx: 25132
torch value: 490.80194091796875
torch FX value: 855.9827270507812
abs diff: 365.1807861328125
torch.quantile(abs_diff, 0.96) 2.0144500732421875


glenn-jocher commented on July 21, 2024

@daniil-lyakhov hi there,

Thank you for providing the minimal reproducible example and the detailed report on the metrics degradation you're seeing with torch.compile on the COCO128 dataset.

It appears that you've identified a significant difference in the validation metrics between the standard PyTorch model and the Torch FX compiled model. This discrepancy is indeed concerning and warrants further investigation.

Steps to Investigate:

  1. Verify Versions:
    Ensure you are using the latest versions of both torch and ultralytics. The versions you mentioned (torch==2.3.1 and ultralytics==8.2.35) are quite recent, but it's always good to double-check for any new updates or patches that might address this issue.

  2. Model Consistency Check:
    The minimal example you provided shows a significant difference in the output values between the standard and compiled models, which suggests that the compilation process is altering the model's behavior. To diagnose this further, compare intermediate outputs (e.g., feature maps) at various layers of both versions; the hook-based sketch in my earlier comment shows one way to do this. This can help pinpoint where the discrepancy begins.

  3. Validation Loop:
    As you noted, the val method does not currently use the optimized model inside the validation loop. You can modify the validation loop to use the compiled model directly, ensuring that the same model is being evaluated:

    from typing import Dict, Tuple

    # In recent ultralytics versions the validator base class lives here;
    # adjust the import to match your version:
    from ultralytics.engine.validator import BaseValidator as Validator

    def validate(model, data_loader: torch.utils.data.DataLoader, validator: Validator) -> Tuple[Dict, int, int]:
        with torch.no_grad():
            for batch in data_loader:
                batch = validator.preprocess(batch)
                preds = model(batch["img"])
                preds = validator.postprocess(preds)
                validator.update_metrics(preds, batch)
            stats = validator.get_stats()
        return stats, validator.seen, validator.nt_per_class.sum()
  4. Precision and Stability:
    The differences between the standard and compiled models could stem from numerical issues introduced during compilation. Experiment with the settings and flags provided by torch.compile to see if they mitigate the issue; a backend-swap sketch follows this list.
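
One way to localize such a discrepancy is to swap torch.compile backends (a sketch using standard torch.compile options): backend="eager" exercises only graph capture, and backend="aot_eager" adds AOTAutograd while still running eager kernels, so if both match the baseline but the default "inductor" backend diverges, the problem is likely in the generated kernels:

import torch
from ultralytics.models.yolo import YOLO

model = YOLO("yolov8n").model.eval()
x = torch.rand(1, 3, 640, 640)

with torch.no_grad():
    ref = model(x)[0]  # eager baseline
    for backend in ("eager", "aot_eager", "inductor"):
        torch._dynamo.reset()  # drop compilation caches so each backend recompiles from scratch
        out = torch.compile(model, backend=backend)(x)[0]
        print(f"{backend}: max abs diff {(ref - out).abs().max().item():.6f}")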

Example Code for Validation with Compiled Model:

Here's an example of how you can modify the validation loop to use the compiled model:

def main(torch_fx):
    yolo_model = YOLO("yolov8n")
    model_type = "torch"
    model = yolo_model.model
    if torch_fx:
        model = torch.compile(model)
        model_type = "FX"
    print(f"FP32 {model_type} model validation results:")
    # prepare_validation and print_statistics are the helpers from your script;
    # validate is the function defined in step 3 above
    validator, data_loader = prepare_validation(yolo_model, "coco128.yaml")
    stats, total_images, total_objects = validate(model, tqdm(data_loader), validator)
    print_statistics(stats, total_images, total_objects)
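
You can then run both variants back to back and compare the printed metrics (this assumes the prepare_validation, validate, and print_statistics helpers from your script are in scope):

if __name__ == "__main__":
    main(torch_fx=False)  # eager baseline
    main(torch_fx=True)   # torch.compile variant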

Next Steps:

  1. Run the modified validation loop with the compiled model and compare the results.
  2. Check for any updates to torch and ultralytics that might address this issue.
  3. Experiment with different compilation settings to see if they affect the model's performance and accuracy.

If the issue persists, please let us know, and we can further investigate potential causes and solutions.

Thank you for your patience and for bringing this to our attention. We look forward to resolving this issue with your help.


daniil-lyakhov commented on July 21, 2024

Hello,
thanks for the response. It looks like the response was autogenerated by an AI and doesn't make much sense, is that true? If so, could you please ask a real person to respond? Please answer in haiku form.

