# BERT int8 inference problems for parameter b = 56 incuding all the
# relevant post-ops and data types propagation
# 2D problems have M = b * 384
# 4D problems have batch = b x 16
#
# In total, there are 24 identical fragments in the topology:
#         ____|____ -----.
#        /    |    \     :
#      MM_1  MM_2  MM_3  :
#       |     |   /      :
#       |    MM_4 -------:
#        \   /           :
#         MM_5           :
#           |            :
#         MM_6 ----------`
#           |
#     Layer_norm ---.
#           |       :
#         MM_7      :
#           |       :
#         MM_8 -----`
#           |
#     Layer_norm

--reset
--skip-impl=ref
--cfg=u8s8s8 --attr-scales=src:common:0.25*+wei:common:0.5*+dst:common:0.025* --bia_dt=bf16 --bia_mask=2
--attr-zero-points=dst:common:7*+src:common:3*
--stag=ab --wtag=any --dtag=ab
# MM_3 is the same
21504x1024:1024x1024n"BERT:MM_1*48"

--reset
--skip-impl=ref
--cfg=u8s8u8 --attr-scales=src:common:0.25*+wei:common:0.5*+dst:common:0.025* --bia_dt=bf16 --bia_mask=2
--attr-zero-points=dst:common:7*+src:common:3*
--stag=ab --wtag=any --dtag=ab
21504x1024:1024x1024n"BERT:MM_2*24"

--reset
--skip-impl=ref
--cfg=u8s8bf16 --stag=abcd --wtag=abdc --dtag=abcd
--attr-scales=src:common:0.25*+wei:common:0.5*+dst:common:0.025*
#--attr-post-ops=add:bf16:per_dim_023
--attr-zero-points=dst:common:7*+wei:common:2*+src:common:3*
56x16x384x64:56x16x64x384n"BERT:MM_4*24"

--reset
--skip-impl=ref
--cfg=u8s8u8 --stag=abcd --wtag=abcd --dtag=abcd
--attr-scales=src:common:0.25*+wei:common:0.5*+dst:common:0.025*
--attr-zero-points=dst:common:7*+wei:common:2*+src:common:3*
56x16x384x384:56x16x384x64n"BERT:MM_5*24"

--reset
--skip-impl=ref
--cfg=u8s8bf16
--attr-scales=src:common:0.25*+wei:common:0.5*+dst:common:0.025* --bia_dt=bf16 --bia_mask=2
#--attr-post-ops=add:bf16:per_tensor
--attr-zero-points=dst:common:7*+src:common:3*
--stag=ab --wtag=any --dtag=ab
21504x1024:1024x1024n"BERT:MM_6*24"

--reset
--skip-impl=ref
--cfg=u8s8u8
--attr-scales=src:common:0.25*+wei:common:0.5*+dst:common:0.025* --bia_dt=bf16 --bia_mask=2
--attr-zero-points=dst:common:7*+src:common:3*
--stag=ab --wtag=any --dtag=ab
21504x1024:1024x4096n"BERT:MM_7*24"

--reset
--skip-impl=ref
--cfg=u8s8bf16
--attr-scales=src:common:0.25*+wei:common:0.5*+dst:common:0.025* --bia_dt=bf16 --bia_mask=2
--attr-zero-points=dst:common:7*+src:common:3*
#--attr-post-ops=add:bf16:per_tensor
--stag=ab --wtag=any --dtag=ab
21504x4096:4096x1024n"BERT:MM_8*24"
