-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathB200TC.m
More file actions
175 lines (98 loc) · 4.95 KB
/
B200TC.m
File metadata and controls
175 lines (98 loc) · 4.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
function D = B200TC(alpha, A, B, beta, C, informat, outformat)
%
% B200TC Compute GEMM with a model of a tensor core of the B200 GPU.
%
% This function evaluates the expression D = A * B + C (with the scalar
% coefficients alpha and beta forwarded to GEMM) using the B200 TC
% numerical-feature-based model. The accumulation of block products is
% performed using recursive summation.
%
% Inputs
% alpha: Scalar coefficient forwarded unchanged to GEMM; presumably it
%        scales the product A * B (confirm against GEMM.m).
% A: Left matrix operand for the matrix multiplication A * B.
% B: Right matrix operand for the matrix multiplication A * B.
% beta: Scalar coefficient forwarded unchanged to GEMM; presumably it
%       scales C (confirm against GEMM.m).
% C: Matrix added to the product A * B.
% informat: (optional) a case-insensitive string specifying the format
%           of A and B. Defaults to 'fp16' when omitted or empty.
%           Supported input formats:
%           fp8-e5m2, fp8-e4m3 (also e5m2, e4m3), fp16, binary16, half,
%           bf16, bfloat16, tensorfloat32, tf32.
% outformat: (optional) a case-insensitive string specifying the
%            numerical format for C and D. Defaults to 'fp32' when
%            omitted or empty.
%            Supported output formats:
%            fp32, single, binary32,
%            fp16, binary16, half.
%            For fp8 inputs the output format is always forced to
%            'fp32'; a warning (ID 'B200TC:fp8OutputFormat') is issued
%            if an fp16 output format was requested.
%
% Output
% D: Result returned by GEMM under the specified tensor core
%    configuration.
%
% Errors
% Raises an error if informat or outformat is not one of the supported
% formats listed above.

% Allowed formats
allowedOutFormats = {'fp32', 'single', 'binary32', ...
    'fp16', 'binary16', 'half'};
allowedInFormats = {'fp8-e5m2', 'fp8-e4m3', 'e5m2', 'e4m3', ...
    'fp16', 'binary16', 'half', 'bf16', 'bfloat16', 'tensorfloat32', 'tf32'};
fp16OutFormats = {'fp16', 'binary16', 'half'};

% Fall back to the defaults this model is built around (fp16 in, fp32
% out) when a format is omitted or empty; without this, the code below
% would reference an undefined variable.
if nargin < 6 || isempty(informat)
    informat = 'fp16';
end
if nargin < 7 || isempty(outformat)
    outformat = 'fp32';
end

% Normalise once, then validate.
informat = lower(informat);
if ~ismember(informat, allowedInFormats)
    error('The specified input format is not supported.');
end
outformat = lower(outformat);
if ~ismember(outformat, allowedOutFormats)
    error('The specified output format is not supported.');
end

% Default structures assuming fp16 in and fp32 output. See
% Generic_TC_Model.m for the information.
%---------------- Core configuration ----------------%
def_params.fma = 16;              % Number of products in one FMA group
def_params.neab = 2;              % Number of extra alignment bits (guard precision)
%---------------- Rounding configuration ----------------%
def_params.frmode = 'rz';         % Final rounding mode. 'rz' is presumably
                                  % round-toward-zero; 'rne' =
                                  % round-to-nearest-even (used below for
                                  % fp16 outputs).
def_params.armode = 'rz';         % Rounding mode during 2-operand alignment
                                  % (other modes include 'rd' = round-down,
                                  % towards -Inf). Multi-operand alignment
                                  % uses truncation.
def_params.stkbitenabled = 0;     % Enable sticky bit during alignment (1 = enabled)
%---------------- Accumulation architecture ----------------%
def_params.global_alignment = 1;  % Align all products (and optionally c) to a common exponent
def_params.late_partial_sum = 0;  % Add accumulation term 'c' after product summation
                                  % (products kept in denormalised form)
def_params.odd_even_grouping = 0; % Enable separate accumulation of odd/even products
def_params.pair_wise_sum = 0;     % Enable pair-wise summation (not implemented)
%---------------- Exponent handling ----------------%
def_params.min_exp_limit = -133;  % Minimum exponent allowed for product alignment
def_params.c_min_exp_limit = 0;   % Control minimum exponent for c:
                                  %   1 -> clamp to -126 (FP32 subnormal boundary)
                                  %   0 -> allow special handling when c = 0
def_params.prd_limit = 0;         % Products limited by output exponent bits:
                                  %   1 -> limited, 0 -> allowed to exceed
%---------------- Accuracy / reference model ----------------%
def_params.correct_rounding = 0;  % Enable exact (Kulisch-style) accumulation
                                  % (used as reference / ground truth model)
%---------------- Subnormal handling ----------------%
def_params.in_subnormals = 1;     % Input subnormal support:
                                  %   1 -> preserve and process subnormals
                                  %   0 -> flush subnormals to zero (FTZ)
def_params.out_subnormals = 1;    % Output subnormal support:
                                  %   1 -> generate subnormal outputs
                                  %   0 -> flush subnormal results to zero

% Set up the model according to the formats specified.
if ismember(informat, {'fp16', 'half', 'binary16'})
    if ismember(outformat, fp16OutFormats)
        def_params.frmode = 'rne'; % TC final rounding mode for fp16 output
    end
elseif ismember(informat, {'tf32', 'tensorfloat32'})
    def_params.fma = 8;
elseif ismember(informat, {'fp8-e5m2', 'fp8-e4m3', 'e5m2', 'e4m3'})
    % FMA size is 16, but interleaved pattern is used to join two
    % 16-element vectors.
    def_params.fma = 32;
    % This model always produces fp32 output for fp8 inputs. Tell the
    % caller instead of silently discarding an fp16 request.
    % NOTE(review): confirm that forcing fp32 here is intended.
    if ismember(outformat, fp16OutFormats)
        warning('B200TC:fp8OutputFormat', ...
            ['B200TC forces fp32 output for fp8 inputs; requested ' ...
             'output format ''%s'' is ignored.'], outformat);
    end
    outformat = 'fp32';
end
% bf16/bfloat16 inputs use the default parameters unchanged.

D = GEMM(alpha, A, B, beta, C, informat, outformat, def_params);
end