Convolution operator not marked as quantizable when padding is defined in the class instantiation #398
Description of the bug:
A `torch.nn.Conv2d` instance is not marked as quantizable when `padding` is passed to its constructor.

Minimal working example:
Actual vs expected behavior:
Expected:
Both models are quantized in the same way.
Actual:
![gh_aiet_tc1](https://private-user-images.githubusercontent.com/59509547/390852716-915e910a-362e-4aad-8554-e4c4d014327b.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk0NzY4MTEsIm5iZiI6MTczOTQ3NjUxMSwicGF0aCI6Ii81OTUwOTU0Ny8zOTA4NTI3MTYtOTE1ZTkxMGEtMzYyZS00YWFkLTg1NTQtZTRjNGQwMTQzMjdiLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTMlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjEzVDE5NTUxMVomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWMxYjFmNzEzNWEzN2UxOTc1MmJlMzJhNDhhZDhjMTU1ODM1YWE2YTc1ZjMwYTZkYmExNzc5MzJjODY3MzY2MWUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.k8EDV3Dy-5Z6R_CI4KOr3aIl9CHcow9GA7lUCNApzQs)
TinyConv1 (above) has quantized parameters; TinyConv2 (below) does not:
![gh_aiet_tc2](https://private-user-images.githubusercontent.com/59509547/390852726-69b5bcf2-a842-45b4-89c6-995dbc37c54e.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk0NzY4MTEsIm5iZiI6MTczOTQ3NjUxMSwicGF0aCI6Ii81OTUwOTU0Ny8zOTA4NTI3MjYtNjliNWJjZjItYTg0Mi00NWI0LTg5YzYtOTk1ZGJjMzdjNTRlLnBuZz9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTMlMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjEzVDE5NTUxMVomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPTU1MGViYTZhYWNhMjUxOTgwNTkzNGQwN2ZjYWNlMGRiMzAyNTNmYjZhZjM5NTA2NDA5MDZhMWNjNmM2NzZiOGUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.R3x__oasIr7puhxBQvvYToWTgxzMAdP12-GXPgd3nCg)
Any other information you'd like to share?
The cause of this issue is that the test for identifying a `conv` node in the FX graph when applying quantization is too restrictive: `ai-edge-torch/ai_edge_torch/quantize/pt2e_quantizer_utils.py`, lines 311 to 315 in d4e358e.
When `padding` is defined in the instantiation, `n.target` becomes an instance of `torch._ops.OpOverload`, so this test fails even though it probably shouldn't.
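A less restrictive check could accept every conv2d overload rather than a single one. The sketch below is a hypothetical helper, `is_conv2d_node`; it is not the ai-edge-torch code at lines 311 to 315, which is not reproduced here. It assumes the MWE passes a string such as `padding="same"`. PyTorch routes that form through the separate `aten::conv2d.padding` overload, so a test that matches only `aten.conv2d.default` would skip the node. If the MWE passes an integer padding instead, this explanation may not apply. The overload names are compared as strings so the predicate can be tried without importing torch. `str()` of a `torch._ops.OpOverload` has the form `aten.conv2d.padding`.

```python
# Hypothetical helper, NOT the ai-edge-torch implementation: recognise a conv2d
# call in an FX graph regardless of which ATen overload it was traced to.
# Assumption: string padding (e.g. padding="same") traces to aten.conv2d.padding,
# while integer/tuple padding traces to aten.conv2d.default.

# str() of a torch._ops.OpOverload is "<namespace>.<op>.<overload>",
# e.g. str(torch.ops.aten.conv2d.padding) == "aten.conv2d.padding".
CONV2D_OVERLOAD_NAMES = frozenset({"aten.conv2d.default", "aten.conv2d.padding"})


def is_conv2d_node(node) -> bool:
    """Return True if an FX node calls any of the listed conv2d overloads."""
    # Only call_function nodes carry an operator as their target.
    if node.op != "call_function":
        return False
    # Match by overload name, so both the default and the padding overloads count.
    return str(node.target) in CONV2D_OVERLOAD_NAMES
```

With torch imported, the more usual form is `n.target in (torch.ops.aten.conv2d.default, torch.ops.aten.conv2d.padding)`, evaluated while iterating over the nodes of the exported graph module.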