/* Mirror of https://github.com/dreamcmi/AT32F403A-TinyMaix.git
   (synced 2026-05-21 01:11:59 +00:00; 806 lines, 34 KiB, C) */
/* Copyright 2022 Sipeed Technology Co., Ltd. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the O1 implementation.
// Warning: O1 trades considerably larger code size for better performance; experimental for now.
/*
pwconv
dwconv
    3x3: tm_dot_prod_gap_3x3x1
    non-3x3:
mdwconv
conv
*/

#include "tinymaix.h"
#include "float.h"
#include "math.h"

#if TM_OPT_LEVEL == TM_OPT1

/* Pull in the dot-product kernels of the selected target architecture. */
#if TM_ARCH==TM_ARCH_CPU
    #include "arch_cpu.h"
#elif TM_ARCH==TM_ARCH_ARM_SIMD
    #include "arch_arm_simd.h"
#elif TM_ARCH==TM_ARCH_ARM_NEON
    #include "arch_arm_neon.h"
#elif TM_ARCH==TM_ARCH_ARM_MVEI
    #include "arch_arm_mvei.h"
#elif TM_ARCH==TM_ARCH_RV32P
    #include "arch_rv32p.h"
#elif TM_ARCH==TM_ARCH_RV64V
    #include "arch_rv64v.h"
#else
    #error "UNSUPPORT ARCH!"
#endif

/* Performance counters used by the TM_PERF_* macros in this file. */
TM_PERF_REG(t_sbuf);TM_PERF_REG(t_dotp);TM_PERF_REG(t_post);
TM_PERF_REG(t_valid); TM_PERF_REG(t_pad);
TM_PERF_REG(t_conv); TM_PERF_REG(t_pwconv); TM_PERF_REG(t_dwconv);

#define BATCH_SIZE 2    //number of output channels computed per tm_dot_prod_pack2 call
/*************************** TML_CONV2D **********************************/
/* Element offset of every kernel tap relative to the window's top-left input
   element; regenerated at the start of each convolution helper. */
static uint32_t k_oft[TM_MAX_KSIZE];
/* Scratch buffer the helpers gather input windows into before the dot product. */
static mtype_t sbuf[TM_MAX_KCSIZE];
#if (TM_MDL_TYPE==TM_MDL_FP32) || (TM_MDL_TYPE==TM_MDL_FP16)
#define SUMSCALE (1.0)      //float models need no per-channel rescaling
static sctype_t outscale;   //set from out_s by tml_conv2d_dwconv2d
/* Float-model post-processing of one accumulated sum.
 *   sum   accumulated dot product
 *   b     bias added to the sum
 *   act   activation selector; TM_ACT_RELU and TM_ACT_RELU6 both clamp
 *         negatives to zero (RELU6 is deliberately treated as RELU here,
 *         without the upper clamp, for speed)
 *   outp  destination element
 *   scale, out_s, out_zp are not used in the float build. */
TM_INLINE void l_postprocess_sum(sumtype_t sum, btype_t b, int act, mtype_t* outp,
    sctype_t scale, sctype_t out_s, zptype_t out_zp)
{
    sumtype_t biased = sum + b;
    if((act == TM_ACT_RELU || act == TM_ACT_RELU6) && !(biased > 0)){
        biased = 0;     //anything that is not strictly positive becomes zero
    }
    *outp = (mtype_t)biased;
}

#elif (TM_MDL_TYPE==TM_MDL_INT8) || (TM_MDL_TYPE==TM_MDL_INT16)
|
||
|
||
#if TM_FASTSCALE
|
||
static int32_t sumscale[TM_MAX_CSIZE];
|
||
static int32_t outscale;
|
||
#else
|
||
static float sumscale[TM_MAX_CSIZE];
|
||
static sctype_t outscale;
|
||
#endif
|
||
#define SUMSCALE (sumscale[c])
|
||
|
||
#if !TM_FASTSCALE
|
||
TM_INLINE void l_postprocess_sum(sumtype_t sum, btype_t b, int act, mtype_t* outp, sctype_t scale, sctype_t out_s, zptype_t out_zp)
|
||
#else
|
||
TM_INLINE void l_postprocess_sum(sumtype_t sum, btype_t b, int act, mtype_t* outp, int32_t scale, int32_t out_s, zptype_t out_zp)
|
||
#endif
|
||
{ sum += b;
|
||
#if !TM_FASTSCALE
|
||
float sumf = sum*scale;
|
||
#else
|
||
sumtype_t sumf = (sum<<TM_FASTSCALE_SHIFT)/scale;
|
||
#endif
|
||
switch(act){ //activation func
|
||
case TM_ACT_RELU:
|
||
sumf = sumf>0?sumf:0;
|
||
break;
|
||
case TM_ACT_RELU6:
|
||
sumf = sumf>0?sumf:0;
|
||
#if (!TM_FASTSCALE)
|
||
sumf = sumf>6?6:sumf;
|
||
#else
|
||
sumf = sumf>(6<<TM_FASTSCALE_SHIFT)?(6<<TM_FASTSCALE_SHIFT):sumf;
|
||
#endif
|
||
break;
|
||
default:
|
||
break;
|
||
}
|
||
#if !TM_FASTSCALE
|
||
*outp = (mtype_t)(sumf/out_s + out_zp); //(mtype_t)((int)(sumf/out_s) + out_zp) //(mtype_t)((int)(sumf/out_s +0.5) + out_zp)
|
||
#else
|
||
*outp = (mtype_t)(((sumf*out_s)>>(TM_FASTSCALE_SHIFT+TM_FASTSCALE_SHIFT))+out_zp);
|
||
#endif
|
||
return;
|
||
}
|
||
#endif
|
||
|
||
//1x1 pw conv
|
||
TM_INLINE tm_err_t l_tml_pwconv2d(tm_mat_t* in, tm_mat_t* out, wtype_t* w, btype_t* b, \
|
||
int kw, int kh, int sx, int sy, int dx, int dy, int act, \
|
||
int pad_top, int pad_bottom, int pad_left, int pad_right, int dmul, \
|
||
sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
|
||
{
|
||
int pad_flag = (pad_top != 0 ||pad_bottom != 0 ||pad_left != 0 ||pad_right != 0);
|
||
int maxk = kw*kh;
|
||
int chi = in->c;
|
||
int cho = out->c;
|
||
sumtype_t sum;
|
||
mtype_t* outp = out->data;
|
||
sumtype_t sums[BATCH_SIZE];
|
||
|
||
for (int y = 0; y < out->h; y++) {
|
||
for (int x = 0; x < out->w; x++) {
|
||
mtype_t* sptr = (mtype_t*)TM_MATP(in, sy*y, sx*x, 0);
|
||
wtype_t* kptr = (wtype_t*)w;
|
||
|
||
int c = 0;
|
||
for(; c<out->c-BATCH_SIZE+1; ){
|
||
tm_dot_prod_pack2(sptr, kptr, chi, sums);
|
||
l_postprocess_sum(sums[0], b[c], act, outp, SUMSCALE, outscale, out_zp); c++; outp++;
|
||
l_postprocess_sum(sums[1], b[c], act, outp, SUMSCALE, outscale, out_zp); c++; outp++;
|
||
kptr += chi*BATCH_SIZE;//*2;
|
||
}
|
||
for(; c<out->c; c++){
|
||
tm_dot_prod(sptr, kptr, chi, &sum); //size=maxk*chi //pw maxk==1
|
||
l_postprocess_sum(sum, b[c], act, outp, SUMSCALE, outscale, out_zp); outp++;
|
||
kptr += chi;
|
||
}
|
||
}
|
||
}
|
||
return TM_OK;
|
||
}
|
||
|
||
//normal conv
|
||
TM_INLINE tm_err_t l_tml_conv2d(tm_mat_t* in, tm_mat_t* out, wtype_t* w, btype_t* b, \
|
||
int kw, int kh, int sx, int sy, int dx, int dy, int act, \
|
||
int pad_top, int pad_bottom, int pad_left, int pad_right, int dmul, \
|
||
sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
|
||
{
|
||
int pad_flag = (pad_top != 0 ||pad_bottom != 0 ||pad_left != 0 ||pad_right != 0);
|
||
int maxk = kw*kh;
|
||
int chi = in->c;
|
||
int cho = out->c;
|
||
sumtype_t sum;
|
||
mtype_t* outp = out->data;
|
||
|
||
int oft = 0;
|
||
int idx = 0;
|
||
for(int y=0; y<kh; y++){ //gen k_oft table
|
||
for(int x=0; x<kw; x++){
|
||
k_oft[idx] = oft;
|
||
idx += 1;
|
||
oft += chi;
|
||
}
|
||
oft += (in->w - kw)*chi;
|
||
}
|
||
int slow_flag = 0; //same pad part is slow
|
||
|
||
for (int y = 0; y < out->h; y++) {
|
||
int src_y0 = sy*y - pad_top;
|
||
for (int x = 0; x < out->w; x++) {
|
||
int src_x0 = sx*x - pad_left;
|
||
sumtype_t sum;
|
||
slow_flag = ((src_y0<0)+(src_x0<0)+(src_y0+kh>in->h)+(src_x0+kw>in->w));
|
||
if(!slow_flag) {//valid or same valid part
|
||
mtype_t* sptr_base = (mtype_t*)TM_MATP(in, src_y0, src_x0, 0); //?c/dmul:0
|
||
mtype_t* sptr = sptr_base; //= (mtype_t*)TM_MATP(in, src_y0, src_x0, 0); //sbuf 不变
|
||
uint32_t sidx=0; //sbuf:cho,chi,maxk //dw:chi==1;
|
||
for (int cc = 0; cc < chi; cc++) {
|
||
for (int k = 0; k < maxk; k++) {
|
||
sbuf[sidx+k] = sptr[k_oft[k]];
|
||
}
|
||
sidx += maxk;
|
||
sptr += 1;
|
||
}
|
||
} else { //same pad part
|
||
int _ky0 = src_y0<0 ? -src_y0 : 0;
|
||
int _kx0 = src_x0<0 ? -src_x0 : 0;
|
||
int _ky1 = in->h-src_y0>kh ? kh : in->h-src_y0;
|
||
int _kx1 = in->w-src_x0>kw ? kw : in->w-src_x0;
|
||
uint32_t sidx=0; //sbuf:cho,chi,maxk //dw:chi==1;
|
||
uint32_t s_step = (_ky1-_ky0)*(_kx1-_kx0);
|
||
mtype_t* sptr_base = (mtype_t*)TM_MATP(in, src_y0, src_x0, 0);
|
||
mtype_t* sptr = sptr_base;
|
||
#if TM_MDL_TYPE == TM_MDL_INT8
|
||
memset(sbuf, in_zp, chi*maxk); //do padding
|
||
#elif (TM_MDL_TYPE == TM_MDL_FP32)||(TM_MDL_TYPE == TM_MDL_FP16)
|
||
memset(sbuf, 0, chi*maxk*sizeof(mtype_t));
|
||
#else
|
||
#error "unsupport mdl type"
|
||
#endif
|
||
for (int cc = 0; cc < chi; cc++) {
|
||
for(int _ky=_ky0; _ky<_ky1; _ky++){
|
||
for(int _kx=_kx0; _kx<_kx1; _kx++){
|
||
int k = _ky*kw + _kx;
|
||
sbuf[sidx+k] = sptr[k_oft[k]];
|
||
}
|
||
}
|
||
sidx += maxk;
|
||
sptr += 1;
|
||
}
|
||
}
|
||
mtype_t* sptr = sbuf; //sbuf prepare ok~
|
||
sumtype_t sums[BATCH_SIZE];
|
||
int c = 0;
|
||
wtype_t* kptr = (wtype_t*)w;
|
||
for(; c<out->c-BATCH_SIZE+1; ){
|
||
tm_dot_prod_pack2(sptr, kptr, maxk*chi, sums);
|
||
l_postprocess_sum(sums[0], b[c], act, outp, SUMSCALE, outscale, out_zp); c++; outp++;
|
||
l_postprocess_sum(sums[1], b[c], act, outp, SUMSCALE, outscale, out_zp); c++; outp++;
|
||
kptr += chi*maxk*BATCH_SIZE;
|
||
}
|
||
for(; c<out->c; c++){
|
||
tm_dot_prod(sptr, kptr, maxk*chi, &sum);
|
||
l_postprocess_sum(sum, b[c], act, outp, SUMSCALE, outscale, out_zp); outp++;
|
||
kptr += chi*maxk;
|
||
}
|
||
}
|
||
}
|
||
return TM_OK;
|
||
}
|
||
|
||
//dwconv
|
||
TM_INLINE tm_err_t l_tml_dwconv2d(tm_mat_t* in, tm_mat_t* out, wtype_t* w, btype_t* b, \
|
||
int kw, int kh, int sx, int sy, int dx, int dy, int act, \
|
||
int pad_top, int pad_bottom, int pad_left, int pad_right, int dmul, \
|
||
sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
|
||
{
|
||
int pad_flag = (pad_top != 0 ||pad_bottom != 0 ||pad_left != 0 ||pad_right != 0);
|
||
int maxk = kw*kh;
|
||
int chi = in->c;
|
||
int cho = out->c;
|
||
sumtype_t sum;
|
||
mtype_t* outp = out->data;
|
||
|
||
int oft = 0;
|
||
int idx = 0;
|
||
for(int y=0; y<kh; y++){ //gen k_oft table
|
||
for(int x=0; x<kw; x++){
|
||
k_oft[idx] = oft;
|
||
idx += 1;
|
||
oft += chi;
|
||
}
|
||
oft += (in->w - kw)*chi;
|
||
}
|
||
int slow_flag = 0; //same pad part is slow
|
||
|
||
for (int y = 0; y < out->h; y++) {
|
||
int src_y0 = sy*y - pad_top;
|
||
for (int x = 0; x < out->w; x++) {
|
||
int src_x0 = sx*x - pad_left;
|
||
sumtype_t sum;
|
||
slow_flag = ((src_y0<0)+(src_x0<0)+(src_y0+kh>in->h)+(src_x0+kw>in->w));
|
||
if(!slow_flag) {//valid or same valid part
|
||
mtype_t* sptr_base = (mtype_t*)TM_MATP(in, src_y0, src_x0, 0); //?c/dmul:0
|
||
mtype_t* sptr = sptr_base; //= (mtype_t*)TM_MATP(in, src_y0, src_x0, 0); //sbuf 不变
|
||
if(maxk==9){
|
||
for (int c = 0; c < cho; c++) {
|
||
wtype_t* kptr = (wtype_t*)w + c*9;
|
||
tm_dot_prod_gap_3x3x1(sptr, kptr, k_oft, &sum);
|
||
l_postprocess_sum(sum, b[c], act, outp, SUMSCALE, outscale, out_zp); outp++;
|
||
sptr += 1;
|
||
}
|
||
} else {
|
||
uint32_t sidx = 0;
|
||
for (int c = 0; c < cho; c++) {
|
||
for (int k = 0; k < maxk; k++) {
|
||
sbuf[sidx+k] = sptr[k_oft[k]];
|
||
}
|
||
sidx += maxk;
|
||
sptr += 1;
|
||
}
|
||
sptr = sbuf; //sbuf prepare ok~
|
||
for(int c=0; c<out->c; c++){
|
||
sum = 0;
|
||
wtype_t* kptr = (wtype_t*)w + c*maxk;
|
||
tm_dot_prod(sptr, kptr, maxk, &sum);
|
||
//sum = sptr[0]*kptr[0] + sptr[1]*kptr[1] + sptr[2]*kptr[2] + sptr[3]*kptr[3] + sptr[4]*kptr[4] + sptr[5]*kptr[5] + sptr[6]*kptr[6] + sptr[7]*kptr[7] + sptr[8]*kptr[8] ;
|
||
l_postprocess_sum(sum, b[c], act, outp, SUMSCALE, outscale, out_zp); outp++;
|
||
sptr += maxk; //dwconv need move step
|
||
}
|
||
}
|
||
} else { //same pad part
|
||
int _ky0 = src_y0<0 ? -src_y0 : 0;
|
||
int _kx0 = src_x0<0 ? -src_x0 : 0;
|
||
int _ky1 = in->h-src_y0>kh ? kh : in->h-src_y0;
|
||
int _kx1 = in->w-src_x0>kw ? kw : in->w-src_x0;
|
||
uint32_t sidx=0; //sbuf:cho,chi,maxk //dw:chi==1;
|
||
uint32_t s_step = (_ky1-_ky0)*(_kx1-_kx0);
|
||
mtype_t* sptr_base = (mtype_t*)TM_MATP(in, src_y0, src_x0, 0);
|
||
mtype_t* sptr = sptr_base;
|
||
#if TM_MDL_TYPE == TM_MDL_INT8
|
||
memset(sbuf, in_zp, cho*maxk); //do padding
|
||
#elif (TM_MDL_TYPE == TM_MDL_FP32)||(TM_MDL_TYPE == TM_MDL_FP16)
|
||
memset(sbuf, 0, cho*maxk*sizeof(mtype_t));
|
||
#else
|
||
#error "unsupport mdl type"
|
||
#endif
|
||
for (int cc = 0; cc < cho; cc++) {
|
||
for(int _ky=_ky0; _ky<_ky1; _ky++){
|
||
for(int _kx=_kx0; _kx<_kx1; _kx++){
|
||
int k = _ky*kw + _kx;
|
||
sbuf[sidx+k] = sptr[k_oft[k]];
|
||
}
|
||
}
|
||
sidx += maxk;
|
||
sptr += 1;
|
||
}
|
||
sptr = sbuf; //sbuf prepare ok~
|
||
if(maxk==9){
|
||
for (int c = 0; c < cho; c++) {
|
||
wtype_t* kptr = (wtype_t*)w + c*9;
|
||
tm_dot_prod_3x3x1(sptr, kptr, &sum);
|
||
l_postprocess_sum(sum, b[c], act, outp, SUMSCALE, outscale, out_zp); outp++;
|
||
sptr += maxk;
|
||
}
|
||
} else {
|
||
for(int c=0; c<out->c; c++){
|
||
wtype_t* kptr = (wtype_t*)w + c*maxk;
|
||
tm_dot_prod(sptr, kptr, maxk, &sum);
|
||
l_postprocess_sum(sum, b[c], act, outp, SUMSCALE, outscale, out_zp); outp++;
|
||
sptr += maxk; //dwconv need move step
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
return TM_OK;
|
||
}
|
||
|
||
//dwconv
|
||
TM_INLINE tm_err_t l_tml_dwconv2d_3x3_part(tm_mat_t* in, tm_mat_t* out, wtype_t* w, btype_t* b, \
|
||
int kw, int kh, int sx, int sy, int dx, int dy, int act, \
|
||
int pad_top, int pad_bottom, int pad_left, int pad_right, int dmul, \
|
||
sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp,
|
||
int x0, int x1, int y0, int y1)
|
||
{
|
||
int pad_flag = (pad_top != 0 ||pad_bottom != 0 ||pad_left != 0 ||pad_right != 0);
|
||
int maxk = kw*kh;
|
||
int chi = in->c;
|
||
int cho = out->c;
|
||
sumtype_t sum;
|
||
mtype_t* outp = out->data;
|
||
|
||
int oft = 0;
|
||
int idx = 0;
|
||
for(int y=0; y<kh; y++){ //gen k_oft table
|
||
for(int x=0; x<kw; x++){
|
||
k_oft[idx] = oft;
|
||
idx += 1;
|
||
oft += chi;
|
||
}
|
||
oft += (in->w - kw)*chi;
|
||
}
|
||
int slow_flag = 0; //same pad part is slow
|
||
|
||
for (int y = 0; y < out->h; y++) {
|
||
int src_y0 = sy*y - pad_top;
|
||
for (int x = 0; x < out->w; x++) {
|
||
int src_x0 = sx*x - pad_left;
|
||
sumtype_t sum;
|
||
slow_flag = ((src_y0<0)+(src_x0<0)+(src_y0+kh>in->h)+(src_x0+kw>in->w));
|
||
if(!slow_flag) {//valid or same valid part
|
||
mtype_t* sptr_base = (mtype_t*)TM_MATP(in, src_y0, src_x0, 0); //?c/dmul:0
|
||
mtype_t* sptr = sptr_base; //= (mtype_t*)TM_MATP(in, src_y0, src_x0, 0); //sbuf 不变
|
||
for (int c = 0; c < cho; c++) {
|
||
wtype_t* kptr = (wtype_t*)w + c*9;
|
||
tm_dot_prod_gap_3x3x1(sptr, kptr, k_oft, &sum);
|
||
l_postprocess_sum(sum, b[c], act, outp, SUMSCALE, outscale, out_zp); outp++;
|
||
sptr += 1;
|
||
}
|
||
} else { //same pad part
|
||
int _ky0 = src_y0<0 ? -src_y0 : 0;
|
||
int _kx0 = src_x0<0 ? -src_x0 : 0;
|
||
int _ky1 = in->h-src_y0>kh ? kh : in->h-src_y0;
|
||
int _kx1 = in->w-src_x0>kw ? kw : in->w-src_x0;
|
||
uint32_t sidx=0; //sbuf:cho,chi,maxk //dw:chi==1;
|
||
uint32_t s_step = (_ky1-_ky0)*(_kx1-_kx0);
|
||
mtype_t* sptr_base = (mtype_t*)TM_MATP(in, src_y0, src_x0, 0);
|
||
mtype_t* sptr = sptr_base;
|
||
#if TM_MDL_TYPE == TM_MDL_INT8
|
||
memset(sbuf, in_zp, cho*maxk); //do padding
|
||
#elif (TM_MDL_TYPE == TM_MDL_FP32)||(TM_MDL_TYPE == TM_MDL_FP16)
|
||
memset(sbuf, 0, cho*maxk*sizeof(mtype_t));
|
||
#else
|
||
#error "unsupport mdl type"
|
||
#endif
|
||
for (int cc = 0; cc < cho; cc++) {
|
||
for(int _ky=_ky0; _ky<_ky1; _ky++){
|
||
for(int _kx=_kx0; _kx<_kx1; _kx++){
|
||
int k = _ky*kw + _kx;
|
||
sbuf[sidx+k] = sptr[k_oft[k]];
|
||
}
|
||
}
|
||
sidx += maxk;
|
||
sptr += 1;
|
||
}
|
||
sptr = sbuf; //sbuf prepare ok~
|
||
outp = out->data + (y*out->w+x)*cho;
|
||
for (int c = 0; c < cho; c++) {
|
||
wtype_t* kptr = (wtype_t*)w + c*9;
|
||
tm_dot_prod_3x3x1(sptr, kptr, &sum);
|
||
l_postprocess_sum(sum, b[c], act, outp, SUMSCALE, outscale, out_zp); outp++;
|
||
sptr += maxk;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
return TM_OK;
|
||
}
|
||
|
||
|
||
#define CONV_BLK_W (4)
|
||
#define CONV_BLK_H (4)
|
||
#define CONV_BLK_STEPX (CONV_BLK_W-3+1)
|
||
#define CONV_BLK_STEPY (CONV_BLK_H-3+1)
|
||
static mtype_t dw_sbuf[CONV_BLK_W*CONV_BLK_H*TM_MAX_CSIZE];
|
||
static uint32_t dw_koft[CONV_BLK_W*CONV_BLK_H];
|
||
|
||
// 40ms->27ms
|
||
TM_INLINE tm_err_t l_tml_dwconv2d_3x3_nostride(tm_mat_t* in, tm_mat_t* out, wtype_t* w, btype_t* b, \
|
||
int kw, int kh, int sx, int sy, int dx, int dy, int act, \
|
||
int pad_top, int pad_bottom, int pad_left, int pad_right, int dmul, \
|
||
sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
|
||
{
|
||
int pad_flag = (pad_top != 0 ||pad_bottom != 0 ||pad_left != 0 ||pad_right != 0);
|
||
int maxk = kw*kh;
|
||
int chi = in->c;
|
||
int cho = out->c;
|
||
sumtype_t sum;
|
||
mtype_t* outp = out->data;
|
||
|
||
int oft = 0;
|
||
int idx = 0;
|
||
for(int y=0; y<kh; y++){ //gen k_oft table
|
||
for(int x=0; x<kw; x++){
|
||
k_oft[idx] = oft;
|
||
idx += 1;
|
||
oft += chi;
|
||
}
|
||
oft += (in->w - kw)*chi;
|
||
}
|
||
oft = 0;
|
||
idx = 0;
|
||
for(int y=0; y<CONV_BLK_H; y++){ //gen k_oft table
|
||
for(int x=0; x<CONV_BLK_W; x++){
|
||
dw_koft[idx] = oft;
|
||
idx += 1;
|
||
oft += chi;
|
||
}
|
||
oft += (in->w - CONV_BLK_W)*chi;
|
||
}
|
||
int slow_flag = 0; //same pad part is slow
|
||
sumtype_t sum0,sum1,sum2,sum3;
|
||
|
||
uint32_t maxk_blk = CONV_BLK_H*CONV_BLK_W;
|
||
for (int y = 0; y+CONV_BLK_STEPY <= out->h; y+=CONV_BLK_STEPY) {
|
||
int src_y0 = sy*y - pad_top;
|
||
for (int x = 0; x+CONV_BLK_STEPX <= out->w; x+=CONV_BLK_STEPX) {
|
||
int src_x0 = sx*x - pad_left;
|
||
sumtype_t sum;
|
||
slow_flag = ((src_y0<0)+(src_x0<0)+(src_y0+CONV_BLK_H>in->h)+(src_x0+CONV_BLK_W>in->w));
|
||
if(!slow_flag) {//valid or same valid part
|
||
mtype_t* sptr_base = (mtype_t*)TM_MATP(in, src_y0, src_x0, 0); //?c/dmul:0
|
||
mtype_t* sptr = sptr_base; //= (mtype_t*)TM_MATP(in, src_y0, src_x0, 0); //sbuf 不变
|
||
uint32_t sidx=0; //sbuf:cho,chi,maxk //dw:chi==1;
|
||
outp = out->data + (y*out->w+x)*cho;
|
||
for (int c = 0; c < cho; c++) {
|
||
wtype_t* kptr = (wtype_t*)w + c*maxk;
|
||
sum0 = sptr[dw_koft[0]]*kptr[0] + sptr[dw_koft[1]]*kptr[1] + sptr[dw_koft[2]]*kptr[2] + \
|
||
sptr[dw_koft[4]]*kptr[3] + sptr[dw_koft[5]]*kptr[4] + sptr[dw_koft[6]]*kptr[5] + \
|
||
sptr[dw_koft[8]]*kptr[6] + sptr[dw_koft[9]]*kptr[7] + sptr[dw_koft[10]]*kptr[8] ;
|
||
sum1 = sptr[dw_koft[1]]*kptr[0] + sptr[dw_koft[2]]*kptr[1] + sptr[dw_koft[3]]*kptr[2] + \
|
||
sptr[dw_koft[5]]*kptr[3] + sptr[dw_koft[6]]*kptr[4] + sptr[dw_koft[7]]*kptr[5] + \
|
||
sptr[dw_koft[9]]*kptr[6] + sptr[dw_koft[10]]*kptr[7] + sptr[dw_koft[11]]*kptr[8] ;
|
||
sum2 = sptr[dw_koft[4]]*kptr[0] + sptr[dw_koft[5]]*kptr[1] + sptr[dw_koft[6]]*kptr[2] + \
|
||
sptr[dw_koft[8]]*kptr[3] + sptr[dw_koft[9]]*kptr[4] + sptr[dw_koft[10]]*kptr[5] + \
|
||
sptr[dw_koft[12]]*kptr[6] + sptr[dw_koft[13]]*kptr[7] + sptr[dw_koft[14]]*kptr[8] ;
|
||
sum3 = sptr[dw_koft[5]]*kptr[0] + sptr[dw_koft[6]]*kptr[1] + sptr[dw_koft[7]]*kptr[2] + \
|
||
sptr[dw_koft[9]]*kptr[3] + sptr[dw_koft[10]]*kptr[4] + sptr[dw_koft[11]]*kptr[5] + \
|
||
sptr[dw_koft[13]]*kptr[6] + sptr[dw_koft[14]]*kptr[7] + sptr[dw_koft[15]]*kptr[8] ;
|
||
l_postprocess_sum(sum0, b[c], act, outp+0*cho, SUMSCALE, outscale, out_zp);
|
||
l_postprocess_sum(sum1, b[c], act, outp+1*cho, SUMSCALE, outscale, out_zp);
|
||
l_postprocess_sum(sum2, b[c], act, outp+(out->w+0)*cho, SUMSCALE, outscale, out_zp);
|
||
l_postprocess_sum(sum3, b[c], act, outp+(out->w+1)*cho, SUMSCALE, outscale, out_zp);
|
||
outp ++;
|
||
sptr ++;
|
||
}
|
||
} else { //same pad part
|
||
int _ky0 = src_y0<0 ? -src_y0 : 0;
|
||
int _kx0 = src_x0<0 ? -src_x0 : 0;
|
||
int _ky1 = in->h-src_y0>CONV_BLK_H ? CONV_BLK_H : in->h-src_y0;
|
||
int _kx1 = in->w-src_x0>CONV_BLK_W ? CONV_BLK_W : in->w-src_x0;
|
||
uint32_t sidx=0; //sbuf:cho,chi,maxk //dw:chi==1;
|
||
uint32_t s_step = (_ky1-_ky0)*(_kx1-_kx0);
|
||
mtype_t* sptr_base = (mtype_t*)TM_MATP(in, src_y0, src_x0, 0);
|
||
mtype_t* sptr = sptr_base;
|
||
#if TM_MDL_TYPE == TM_MDL_INT8
|
||
memset(dw_sbuf, in_zp, cho*maxk_blk); //do padding
|
||
#elif (TM_MDL_TYPE == TM_MDL_FP32)||(TM_MDL_TYPE == TM_MDL_FP16)
|
||
memset(dw_sbuf, 0, cho*maxk_blk*sizeof(mtype_t));
|
||
#else
|
||
#error "unsupport mdl type"
|
||
#endif
|
||
for (int cc = 0; cc < cho; cc++) {
|
||
for(int _ky=_ky0; _ky<_ky1; _ky++){
|
||
for(int _kx=_kx0; _kx<_kx1; _kx++){
|
||
int k = _ky*CONV_BLK_W + _kx;
|
||
dw_sbuf[sidx+k] = sptr[dw_koft[k]];
|
||
}
|
||
}
|
||
sidx += maxk_blk;
|
||
sptr ++;
|
||
}
|
||
sptr = dw_sbuf; //sbuf prepare ok~
|
||
outp = out->data + (y*out->w+x)*cho;
|
||
sumtype_t sum0,sum1,sum2,sum3;
|
||
for(int c=0; c<out->c; c++){
|
||
wtype_t* kptr = (wtype_t*)w + c*maxk;
|
||
//tm_dot_prod(sptr, kptr, maxk, &sum);TM_PERF_ADD(t_dotp);
|
||
//l_postprocess_sum(sum, b[c], act, outp+0*cho, SUMSCALE, outscale, out_zp);
|
||
sum0 = sptr[0]*kptr[0] + sptr[1]*kptr[1] + sptr[2]*kptr[2] + \
|
||
sptr[4]*kptr[3] + sptr[5]*kptr[4] + sptr[6]*kptr[5] + \
|
||
sptr[8]*kptr[6] + sptr[9]*kptr[7] + sptr[10]*kptr[8] ;
|
||
sum1 = sptr[1]*kptr[0] + sptr[2]*kptr[1] + sptr[3]*kptr[2] + \
|
||
sptr[5]*kptr[3] + sptr[6]*kptr[4] + sptr[7]*kptr[5] + \
|
||
sptr[9]*kptr[6] + sptr[10]*kptr[7] + sptr[11]*kptr[8] ;
|
||
sum2 = sptr[4]*kptr[0] + sptr[5]*kptr[1] + sptr[6]*kptr[2] + \
|
||
sptr[8]*kptr[3] + sptr[9]*kptr[4] + sptr[10]*kptr[5] + \
|
||
sptr[12]*kptr[6] + sptr[13]*kptr[7] + sptr[14]*kptr[8] ;
|
||
sum3 = sptr[5]*kptr[0] + sptr[6]*kptr[1] + sptr[7]*kptr[2] + \
|
||
sptr[9]*kptr[3] + sptr[10]*kptr[4] + sptr[11]*kptr[5] + \
|
||
sptr[13]*kptr[6] + sptr[14]*kptr[7] + sptr[15]*kptr[8] ;
|
||
l_postprocess_sum(sum0, b[c], act, outp+0*cho, SUMSCALE, outscale, out_zp);
|
||
l_postprocess_sum(sum1, b[c], act, outp+1*cho, SUMSCALE, outscale, out_zp);
|
||
l_postprocess_sum(sum2, b[c], act, outp+(out->w+0)*cho, SUMSCALE, outscale, out_zp);
|
||
l_postprocess_sum(sum3, b[c], act, outp+(out->w+1)*cho, SUMSCALE, outscale, out_zp);
|
||
//printf("==%.1f,%.1f,%.1f,%.1f\r\n", out->data[0], out->data[1], out->data[2], out->data[3]);
|
||
sptr += maxk_blk; //dwconv need move step
|
||
outp++;
|
||
}
|
||
}
|
||
}
|
||
//x loop end
|
||
}
|
||
|
||
if(out->w % CONV_BLK_W != 0 || out->h % CONV_BLK_H != 0) { //cal rest part
|
||
// x = [out->w-out->w%CONV_BLK_W, out->w), y = [0, out->h/CONV_BLK_H*CONV_BLK_H)
|
||
// x = [0, out->w), y = [out->h-out->h%CONV_BLK_H, out->h/CONV_BLK_H*CONV_BLK_H)
|
||
l_tml_dwconv2d_3x3_part(in,out,w,b, kw,kh, sx,sy, dx,dy, act, \
|
||
pad_top, pad_bottom, pad_left, pad_right, dmul, ws, in_s, in_zp, out_s, out_zp,\
|
||
out->w-out->w%CONV_BLK_W, out->w, 0, out->h/CONV_BLK_H*CONV_BLK_H);
|
||
l_tml_dwconv2d_3x3_part(in,out,w,b, kw,kh, sx,sy, dx,dy, act, \
|
||
pad_top, pad_bottom, pad_left, pad_right, dmul, ws, in_s, in_zp, out_s, out_zp,\
|
||
0, out->w, out->h-out->h%CONV_BLK_H, out->h/CONV_BLK_H*CONV_BLK_H);
|
||
}
|
||
|
||
//TODO: rest
|
||
return TM_OK;
|
||
}
|
||
|
||
|
||
//dmul>1 dwconv //not opt yet
|
||
TM_INLINE tm_err_t l_tml_mdwconv2d(tm_mat_t* in, tm_mat_t* out, wtype_t* w, btype_t* b, \
|
||
int kw, int kh, int sx, int sy, int dx, int dy, int act, \
|
||
int pad_top, int pad_bottom, int pad_left, int pad_right, int dmul, \
|
||
sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
|
||
{
|
||
int pad_flag = (pad_top != 0 ||pad_bottom != 0 ||pad_left != 0 ||pad_right != 0);
|
||
int maxk = kw*kh;
|
||
int chi = in->c;
|
||
int cho = out->c;
|
||
sumtype_t sum;
|
||
mtype_t* outp = out->data;
|
||
|
||
int oft = 0;
|
||
int idx = 0;
|
||
for(int y=0; y<kh; y++){ //gen k_oft table
|
||
for(int x=0; x<kw; x++){
|
||
k_oft[idx] = oft;
|
||
idx += 1;
|
||
oft += chi;
|
||
}
|
||
oft += (in->w - kw)*chi;
|
||
}
|
||
int slow_flag = 0; //same pad part is slow
|
||
|
||
for (int y = 0; y < out->h; y++) {
|
||
int src_y0 = sy*y - pad_top;
|
||
for (int x = 0; x < out->w; x++) {
|
||
int src_x0 = sx*x - pad_left;
|
||
sumtype_t sum;
|
||
slow_flag = ((src_y0<0)+(src_x0<0)+(src_y0+kh>in->h)+(src_x0+kw>in->w));
|
||
if(!slow_flag) {//valid or same valid part
|
||
mtype_t* sptr_base = (mtype_t*)TM_MATP(in, src_y0, src_x0, 0); //?c/dmul:0
|
||
mtype_t* sptr = sptr_base; //= (mtype_t*)TM_MATP(in, src_y0, src_x0, 0); //sbuf 不变
|
||
uint32_t sidx=0; //sbuf:cho,chi,maxk //dw:chi==1;
|
||
for (int cc = 0; cc < cho; cc++) {
|
||
for (int k = 0; k < maxk; k++) {
|
||
sbuf[sidx+k] = sptr[k_oft[k]];
|
||
}
|
||
sidx += maxk;
|
||
sptr = sptr_base + (cc+1)/dmul;
|
||
}
|
||
} else { //same pad part
|
||
int _ky0 = src_y0<0 ? -src_y0 : 0;
|
||
int _kx0 = src_x0<0 ? -src_x0 : 0;
|
||
int _ky1 = in->h-src_y0>kh ? kh : in->h-src_y0;
|
||
int _kx1 = in->w-src_x0>kw ? kw : in->w-src_x0;
|
||
uint32_t sidx=0; //sbuf:cho,chi,maxk //dw:chi==1;
|
||
uint32_t s_step = (_ky1-_ky0)*(_kx1-_kx0);
|
||
mtype_t* sptr_base = (mtype_t*)TM_MATP(in, src_y0, src_x0, 0);
|
||
mtype_t* sptr = sptr_base;
|
||
#if TM_MDL_TYPE == TM_MDL_INT8
|
||
memset(sbuf, in_zp, cho*maxk); //do padding
|
||
#elif (TM_MDL_TYPE == TM_MDL_FP32)||(TM_MDL_TYPE == TM_MDL_FP16)
|
||
memset(sbuf, 0, cho*maxk*sizeof(mtype_t));
|
||
#else
|
||
#error "unsupport mdl type"
|
||
#endif
|
||
for (int cc = 0; cc < cho; cc++) {
|
||
for(int _ky=_ky0; _ky<_ky1; _ky++){
|
||
for(int _kx=_kx0; _kx<_kx1; _kx++){
|
||
int k = _ky*kw + _kx;
|
||
sbuf[sidx+k] = sptr[k_oft[k]];
|
||
}
|
||
}
|
||
sidx += maxk;
|
||
sptr = sptr_base + (cc+1)/dmul;
|
||
}
|
||
}
|
||
mtype_t* sptr = sbuf; //sbuf prepare ok~
|
||
for(int c=0; c<out->c; c++){
|
||
sum = 0;
|
||
wtype_t* kptr = (wtype_t*)w + c*maxk;
|
||
tm_dot_prod(sptr, kptr, maxk, &sum);
|
||
l_postprocess_sum(sum, b[c], act, outp, SUMSCALE, outscale, out_zp); outp++;
|
||
sptr += maxk; //dwconv need move step
|
||
}
|
||
}
|
||
}
|
||
return TM_OK;
|
||
}
|
||
|
||
|
||
/* Entry point for all 2D convolutions; validates the request, prepares the
 * file-level scale factors and dispatches to a specialised helper:
 *   kw*kh == 1            -> l_tml_pwconv2d (no padding, dmul must be 0)
 *   dmul == 0             -> l_tml_conv2d
 *   dmul == 1, 3x3, s=1   -> l_tml_dwconv2d_3x3_nostride
 *   dmul == 1, 3x3, s!=1  -> l_tml_dwconv2d_3x3_part
 *   dmul == 1, other      -> l_tml_dwconv2d
 *   dmul  > 1             -> l_tml_mdwconv2d
 * Returns TM_ERR_TODO for dilation != 1, TM_ERR_UNSUPPORT for an unknown
 * activation, padded/depthwise 1x1 requests or too many output channels for
 * the per-channel scale table, TM_ERR_KSIZE when the kernel or the gathered
 * window does not fit the static buffers, otherwise the helper's result.
 * NOTE: the padded 9-tap depthwise paths use `sbuf`/`dw_sbuf` only for border
 * windows and are not capacity-checked here; they rely on TM_MAX_KCSIZE and
 * TM_MAX_CSIZE being configured large enough for the model. */
tm_err_t tml_conv2d_dwconv2d(tm_mat_t* in, tm_mat_t* out, wtype_t* w, btype_t* b,
    int kw, int kh, int sx, int sy, int dx, int dy, int act,
    int pad_top, int pad_bottom, int pad_left, int pad_right, int dmul,
    sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp) //kernel: (cho, chi, h, w)
{   TM_PERF_INIT(t_conv); TM_PERF_INIT(t_pwconv); TM_PERF_INIT(t_dwconv);
    int pad_flag = (pad_top != 0 ||pad_bottom != 0 ||pad_left != 0 ||pad_right != 0);
    if(dx!=1 || dy!= 1) return TM_ERR_TODO;
    if(act >= TM_ACT_MAXCNT) return TM_ERR_UNSUPPORT;
    int maxk = kw*kh;
    if(maxk>TM_MAX_KSIZE) return TM_ERR_KSIZE;
    if(maxk==1 && (pad_flag||dmul)) return TM_ERR_UNSUPPORT; //assume no pad or dwconv when pwconv

    //Reject shapes whose gathered window would overflow the static `sbuf`.
    //Only the paths that always gather are checked, so no working model is refused.
    if(dmul == 0){
        if(maxk > 1 && in->c*maxk > TM_MAX_KCSIZE) return TM_ERR_KSIZE;     //l_tml_conv2d: chi*maxk
    } else if(dmul > 1 || maxk != 9){
        if(out->c*maxk > TM_MAX_KCSIZE) return TM_ERR_KSIZE;                //dw/mdw: cho*maxk
    }

#if (TM_MDL_TYPE == TM_MDL_INT8) || (TM_MDL_TYPE == TM_MDL_INT16)
    if(out->c > TM_MAX_CSIZE) return TM_ERR_UNSUPPORT;  //sumscale[] holds TM_MAX_CSIZE entries
#if TM_FASTSCALE
    outscale = (1<<TM_FASTSCALE_SHIFT)/out_s;
    for(int c=0; c<out->c;c++) sumscale[c]=1.0/ws[c]/in_s;
#else
    outscale = out_s;
    for(int c=0; c<out->c;c++) sumscale[c]=ws[c]*in_s;
#endif
#else
    outscale = out_s;
#endif

    tm_err_t res = TM_OK;
    if(maxk == 1) {  TM_PERF_START(t_pwconv);       //pw conv
        res = l_tml_pwconv2d(in,out,w,b, kw,kh, sx,sy, dx,dy, act,
            pad_top, pad_bottom, pad_left, pad_right, dmul, ws, in_s, in_zp, out_s, out_zp);
        TM_PERF_ADD(t_pwconv);
    } else if(dmul == 0) { TM_PERF_START(t_conv);   //normal conv
        res = l_tml_conv2d(in,out,w,b, kw,kh, sx,sy, dx,dy, act,
            pad_top, pad_bottom, pad_left, pad_right, dmul, ws, in_s, in_zp, out_s, out_zp);
        TM_PERF_ADD(t_conv);
    } else if(dmul == 1) { TM_PERF_START(t_dwconv); //dw conv
        if(kh==3 && kw==3){ //opt for 3x3 dwconv
            if(sx==1&&sy==1){
                res = l_tml_dwconv2d_3x3_nostride(in,out,w,b, kw,kh, sx,sy, dx,dy, act,
                    pad_top, pad_bottom, pad_left, pad_right, dmul, ws, in_s, in_zp, out_s, out_zp);
            } else {
                res = l_tml_dwconv2d_3x3_part(in,out,w,b, kw,kh, sx,sy, dx,dy, act,
                    pad_top, pad_bottom, pad_left, pad_right, dmul, ws, in_s, in_zp, out_s, out_zp,
                    0, out->w, 0, out->h);
            }
        }else{
            res = l_tml_dwconv2d(in,out,w,b, kw,kh, sx,sy, dx,dy, act,
                pad_top, pad_bottom, pad_left, pad_right, dmul, ws, in_s, in_zp, out_s, out_zp);
        }
        TM_PERF_ADD(t_dwconv);
    } else {    //dmul>1 dwconv, rare
        res = l_tml_mdwconv2d(in,out,w,b, kw,kh, sx,sy, dx,dy, act,
            pad_top, pad_bottom, pad_left, pad_right, dmul, ws, in_s, in_zp, out_s, out_zp);
    }

    return res;
}


/*************************** TML_GAP **********************************/
|
||
tm_err_t __attribute__((weak)) tml_gap(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
|
||
{ TM_DBGT_INIT();
|
||
mtype_t* data;
|
||
for(int c=0; c <out->c; c++){
|
||
sumtype_t sum = 0;
|
||
data = in->data + c;
|
||
for(int y=0; y <in->h; y++){
|
||
for(int x=0; x <in->w; x++){
|
||
sum += ((sumtype_t)(*data));
|
||
data += out->c;
|
||
}
|
||
}
|
||
#if TM_MDL_TYPE == TM_MDL_INT8 || TM_MDL_TYPE == TM_MDL_INT16
|
||
out->data[c] = (mtype_t)((sum/((in->h)*(in->w))-in_zp)*in_s/out_s + out_zp); //requant
|
||
#elif TM_MDL_TYPE == TM_MDL_FP32 || TM_MDL_TYPE == TM_MDL_FP16
|
||
out->data[c] = (mtype_t)(sum/((in->h)*(in->w)));
|
||
//#else //#elif TM_MDL_TYPE == TM_MDL_FP8_143 || TM_MDL_TYPE == TM_MDL_FP8_152
|
||
#endif
|
||
}
|
||
return TM_OK;
|
||
}
|
||
|
||
/*************************** TML_FC **********************************/
|
||
tm_err_t __attribute__((weak)) tml_fc(tm_mat_t* in, tm_mat_t* out, wtype_t* w, btype_t* b, \
|
||
sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
|
||
{ TM_DBGT_INIT();
|
||
mtype_t* data = in->data;
|
||
for(int c=0; c <out->c; c++){
|
||
sumtype_t sum = 0;
|
||
tm_dot_prod(data, w+c*in->c, in->c, &sum);
|
||
sum += b[c]; //fuse with zp
|
||
#if TM_MDL_TYPE == TM_MDL_INT8 || TM_MDL_TYPE == TM_MDL_INT16
|
||
out->data[c] = (mtype_t)(sum*in_s*ws[0]/out_s + out_zp); //requant
|
||
#else
|
||
out->data[c] = (mtype_t)(sum);
|
||
#endif
|
||
}
|
||
return TM_OK;
|
||
}
|
||
|
||
/*************************** TML_SOFTMAX **********************************/
|
||
tm_err_t __attribute__((weak)) tml_softmax(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
|
||
{ TM_DBGT_INIT(); //note we have float size output buf even in INT8/INT16 mode
|
||
mtype_t* din = in->data;
|
||
float* dout = (float*)(out->data);
|
||
float dmax = -FLT_MAX;
|
||
for(int c=0; c <in->c; c++){
|
||
#if TM_MDL_TYPE == TM_MDL_INT8 || TM_MDL_TYPE == TM_MDL_INT16
|
||
dout[c] = (float)((sumtype_t)din[c] - in_zp)*in_s;
|
||
#else
|
||
dout[c] = din[c];
|
||
#endif
|
||
if(dout[c] > dmax) dmax = dout[c];
|
||
}
|
||
float sum = 0;
|
||
for(int c=0; c <in->c; c++){
|
||
dout[c] -= dmax;
|
||
dout[c] = (float)tm_exp(dout[c]);
|
||
sum += dout[c];
|
||
dout[c] -= 0.000001; //prevent 1.0 value (cause 256 overflow)
|
||
}
|
||
for(int c=0; c <in->c; c++){ //int8/int16 <= fp32, so it is ok
|
||
#if TM_MDL_TYPE == TM_MDL_INT8 || TM_MDL_TYPE == TM_MDL_INT16
|
||
out->data[c] = (mtype_t)(dout[c]/sum/out_s + out_zp); //requant
|
||
#else
|
||
out->data[c] = (mtype_t)(dout[c]/sum);
|
||
#endif
|
||
}
|
||
return TM_OK;
|
||
}
|
||
|
||
/*************************** TML_RESHAPE **********************************/
|
||
tm_err_t __attribute__((weak)) tml_reshape(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
|
||
{
|
||
//in fact do nothing... out shape
|
||
return TM_OK;
|
||
}
|
||
|
||
|
||
#endif
|
||
|