mirror of
https://github.com/dreamcmi/AT32F403A-TinyMaix.git
synced 2026-05-22 01:12:22 +00:00
324 lines
13 KiB
C
324 lines
13 KiB
C
/* Copyright 2022 Sipeed Technology Co., Ltd. All Rights Reserved.
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
// This is the default (TM_OPT0, unoptimized) implementation
|
|
#include "tinymaix.h"
|
|
#include "float.h"
|
|
#include "math.h"
|
|
|
|
#if TM_OPT_LEVEL == TM_OPT0
|
|
|
|
#if TM_ARCH==TM_ARCH_CPU
|
|
#include "arch_cpu.h"
|
|
#elif TM_ARCH==TM_ARCH_ARM_SIMD
|
|
#include "arch_arm_simd.h"
|
|
#elif TM_ARCH==TM_ARCH_ARM_NEON
|
|
#include "arch_arm_neon.h"
|
|
#elif TM_ARCH==TM_ARCH_ARM_MVEI
|
|
#include "arch_arm_mvei.h"
|
|
#elif TM_ARCH==TM_ARCH_RV32P
|
|
#include "arch_rv32p.h"
|
|
#elif TM_ARCH==TM_ARCH_RV64V
|
|
#include "arch_rv64v.h"
|
|
#else
|
|
#error "UNSUPPORT ARCH!"
|
|
#endif
|
|
|
|
|
|
// Timing regions used by tml_conv2d_dwconv2d via TM_PERF_INIT/TM_PERF_START/TM_PERF_ADD.
// NOTE(review): TM_PERF_REG is defined outside this file; presumably it declares the
// per-region timing state (and may expand to nothing when profiling is off) - confirm.
// t_sbuf/t_dotp/t_post: only referenced from TM_PERF_INIT and commented-out probes in this file.
TM_PERF_REG(t_sbuf);TM_PERF_REG(t_dotp);TM_PERF_REG(t_post);
|
|
// t_valid: output pixels whose kernel window lies fully inside the input; t_pad: windows touching padding.
TM_PERF_REG(t_valid); TM_PERF_REG(t_pad);
|
|
// t_conv: regular conv path; t_pwconv: 1x1 (pointwise) path; t_dwconv: depthwise path.
TM_PERF_REG(t_conv); TM_PERF_REG(t_pwconv); TM_PERF_REG(t_dwconv);
|
|
|
|
/*************************** TML_CONV2D **********************************/
|
|
// Element offset of each kernel tap relative to the top-left input element of the
// current window; rebuilt by tml_conv2d_dwconv2d on every non-pointwise call.
static uint32_t k_oft[TM_MAX_KSIZE];
|
|
// Scratch buffer holding the gathered input window (one run of kw*kh values per channel)
// for the current output pixel. Shared static storage, so the conv kernel is not reentrant.
static mtype_t sbuf[TM_MAX_KCSIZE];
|
|
#if (TM_MDL_TYPE==TM_MDL_FP32) || (TM_MDL_TYPE==TM_MDL_FP16)
|
|
#define SUMSCALE 1.0
|
|
TM_INLINE void l_postprocess_sum(sumtype_t sum, btype_t b, int act, mtype_t* outp, \
|
|
sctype_t scale, sctype_t out_s, zptype_t out_zp)
|
|
{ sum += b;
|
|
switch(act){ //activation func
|
|
case TM_ACT_RELU:
|
|
sum = sum>0?sum:0;
|
|
break;
|
|
case TM_ACT_RELU6:
|
|
sum = sum>0?sum:0;
|
|
sum = sum>6?6:sum;
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
*outp = (mtype_t)sum;
|
|
return;
|
|
}
|
|
#elif (TM_MDL_TYPE==TM_MDL_INT8) || (TM_MDL_TYPE==TM_MDL_INT16)
|
|
|
|
// Per-output-channel rescale table for quantized models, refilled at the start of
// every tml_conv2d_dwconv2d call (fixed-point reciprocal when TM_FASTSCALE, float otherwise).
#if TM_FASTSCALE
|
|
static int32_t sumscale[TM_MAX_CSIZE];
|
|
#else
|
|
static float sumscale[TM_MAX_CSIZE];
|
|
#endif
|
|
// SUMSCALE picks the entry of the current output channel; it expands to sumscale[c],
// so a variable named `c` must be in scope wherever the macro is used.
#define SUMSCALE (sumscale[c])
|
|
|
|
#if !TM_FASTSCALE
|
|
TM_INLINE void l_postprocess_sum(sumtype_t sum, btype_t b, int act, mtype_t* outp, sctype_t scale, sctype_t out_s, zptype_t out_zp)
|
|
#else
|
|
TM_INLINE void l_postprocess_sum(sumtype_t sum, btype_t b, int act, mtype_t* outp, int32_t scale, int32_t out_s, zptype_t out_zp)
|
|
#endif
|
|
{ sum += b;
|
|
#if !TM_FASTSCALE
|
|
float sumf = sum*scale;
|
|
#else
|
|
sumtype_t sumf = (sum<<TM_FASTSCALE_SHIFT)/scale;
|
|
#endif
|
|
switch(act){ //activation func
|
|
case TM_ACT_RELU:
|
|
sumf = sumf>0?sumf:0;
|
|
break;
|
|
case TM_ACT_RELU6:
|
|
sumf = sumf>0?sumf:0;
|
|
#if (!TM_FASTSCALE)
|
|
sumf = sumf>6?6:sumf;
|
|
#else
|
|
sumf = sumf>(6<<TM_FASTSCALE_SHIFT)?(6<<TM_FASTSCALE_SHIFT):sumf;
|
|
#endif
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
#if !TM_FASTSCALE
|
|
*outp = (mtype_t)(sumf/out_s + out_zp); //(mtype_t)((int)(sumf/out_s) + out_zp) //(mtype_t)((int)(sumf/out_s +0.5) + out_zp)
|
|
#else
|
|
*outp = (mtype_t)(((sumf*out_s)>>(TM_FASTSCALE_SHIFT+TM_FASTSCALE_SHIFT))+out_zp);
|
|
#endif
|
|
return;
|
|
}
|
|
#endif
|
|
|
|
//for valid or kernel in valid part, use fast method
|
|
tm_err_t __attribute__((weak)) tml_conv2d_dwconv2d(tm_mat_t* in, tm_mat_t* out, wtype_t* w, btype_t* b, \
|
|
int kw, int kh, int sx, int sy, int dx, int dy, int act, \
|
|
int pad_top, int pad_bottom, int pad_left, int pad_right, int dmul, \
|
|
sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp) //kernel: (cho, chi, h, w)
|
|
{ TM_PERF_INIT(t_sbuf);TM_PERF_INIT(t_dotp);TM_PERF_INIT(t_post);
|
|
TM_PERF_INIT(t_valid);TM_PERF_INIT(t_pad);
|
|
TM_PERF_INIT(t_conv); TM_PERF_INIT(t_pwconv); TM_PERF_INIT(t_dwconv);
|
|
int pad_flag = (pad_top != 0 ||pad_bottom != 0 ||pad_left != 0 ||pad_right != 0);
|
|
if(dx!=1 || dy!= 1) return TM_ERR_TODO;
|
|
if(act >= TM_ACT_MAXCNT) return TM_ERR_UNSUPPORT;
|
|
int maxk = kw*kh;
|
|
if(maxk>TM_MAX_KSIZE) return TM_ERR_KSIZE;
|
|
if(maxk==1 && (pad_flag||dmul)) return TM_ERR_UNSUPPORT; //assume no pad or dwconv when pwconv
|
|
int chi = in->c;
|
|
int cho = out->c;
|
|
sumtype_t sum = 0;
|
|
mtype_t* outp = out->data;
|
|
|
|
#if (TM_MDL_TYPE == TM_MDL_INT8) || (TM_MDL_TYPE == TM_MDL_INT16)
|
|
#if TM_FASTSCALE
|
|
int32_t outscale = (1<<TM_FASTSCALE_SHIFT)/out_s;
|
|
for(int c=0; c<out->c;c++) sumscale[c]=1.0/ws[c]/in_s;
|
|
#else
|
|
sctype_t outscale = out_s;
|
|
for(int c=0; c<out->c;c++) sumscale[c]=ws[c]*in_s;
|
|
#endif
|
|
#else
|
|
sctype_t outscale = out_s;
|
|
#endif
|
|
|
|
if(maxk==1){ TM_PERF_START(t_pwconv); //pointwise conv
|
|
#define BATCH_SIZE 2
|
|
sumtype_t sums[BATCH_SIZE];
|
|
for (int y = 0; y < out->h; y++) {
|
|
for (int x = 0; x < out->w; x++) {
|
|
mtype_t* sptr = (mtype_t*)TM_MATP(in, sy*y, sx*x, 0);
|
|
wtype_t* kptr = (wtype_t*)w;
|
|
int c = 0;
|
|
for(; c<out->c-BATCH_SIZE+1; ){
|
|
tm_dot_prod_pack2(sptr, kptr, chi, sums);
|
|
l_postprocess_sum(sums[0], b[c], act, outp, SUMSCALE, outscale, out_zp); c++; outp++;
|
|
l_postprocess_sum(sums[1], b[c], act, outp, SUMSCALE, outscale, out_zp); c++; outp++;
|
|
kptr += chi*BATCH_SIZE;//*2;
|
|
}
|
|
for(; c<out->c; c++){
|
|
tm_dot_prod(sptr, kptr, chi, &sum); //size=maxk*chi //pw maxk==1
|
|
l_postprocess_sum(sum, b[c], act, outp, SUMSCALE, outscale, out_zp); outp++;
|
|
kptr += chi;
|
|
}
|
|
}
|
|
}
|
|
TM_PERF_ADD(t_pwconv);
|
|
return TM_OK;
|
|
}
|
|
|
|
if(dmul) {TM_PERF_START(t_dwconv);} else {TM_PERF_START(t_conv);};
|
|
int oft = 0;
|
|
int idx = 0;
|
|
for(int y=0; y<kh; y++){ //gen k_oft table
|
|
for(int x=0; x<kw; x++){
|
|
k_oft[idx] = oft;
|
|
idx += 1;
|
|
oft += chi;
|
|
}
|
|
oft += (in->w - kw)*chi;
|
|
}
|
|
chi = dmul ? 1 : in->c; // dmul>=1 indicate depthwise; dummy chi for dwconv compatible
|
|
int slow_flag = 0; //same pad part is slow
|
|
for (int y = 0; y < out->h; y++) {
|
|
int src_y0 = sy*y - pad_top;
|
|
for (int x = 0; x < out->w; x++) {
|
|
int src_x0 = sx*x - pad_left;
|
|
sumtype_t sum;
|
|
slow_flag = ((src_y0<0)+(src_x0<0)+(src_y0+kh>in->h)+(src_x0+kw>in->w));
|
|
//TM_PERF_START(t_sbuf);
|
|
if(!slow_flag) {TM_PERF_START(t_valid); //valid or same valid part
|
|
mtype_t* sptr_base = (mtype_t*)TM_MATP(in, src_y0, src_x0, 0); //?c/dmul:0
|
|
mtype_t* sptr = sptr_base; //= (mtype_t*)TM_MATP(in, src_y0, src_x0, 0); //sbuf 不变
|
|
uint32_t sidx=0; //sbuf:cho,chi,maxk //dw:chi==1;
|
|
for (int cc = 0; cc < (dmul?cho:chi); cc++) {
|
|
for (int k = 0; k < maxk; k++) {
|
|
sbuf[sidx+k] = sptr[k_oft[k]];
|
|
}
|
|
sidx += maxk;
|
|
sptr = sptr_base + (dmul?(cc+1)/dmul:(cc+1));
|
|
}
|
|
} else { TM_PERF_START(t_pad); //same pad part
|
|
int _ky0 = src_y0<0 ? -src_y0 : 0;
|
|
int _kx0 = src_x0<0 ? -src_x0 : 0;
|
|
int _ky1 = in->h-src_y0>kh ? kh : in->h-src_y0;
|
|
int _kx1 = in->w-src_x0>kw ? kw : in->w-src_x0;
|
|
uint32_t sidx=0; //sbuf:cho,chi,maxk //dw:chi==1;
|
|
uint32_t s_step = (_ky1-_ky0)*(_kx1-_kx0);
|
|
mtype_t* sptr_base = (mtype_t*)TM_MATP(in, src_y0, src_x0, 0);
|
|
mtype_t* sptr = sptr_base;
|
|
#if TM_MDL_TYPE == TM_MDL_INT8
|
|
memset(sbuf, in_zp, dmul?cho*maxk:chi*maxk); //do padding
|
|
#elif (TM_MDL_TYPE == TM_MDL_FP32)||(TM_MDL_TYPE == TM_MDL_FP16)||(TM_MDL_TYPE == TM_MDL_FP8_143)||(TM_MDL_TYPE == TM_MDL_FP8_152)
|
|
memset(sbuf, 0, (dmul?cho*maxk:chi*maxk)*sizeof(mtype_t));
|
|
#else
|
|
#error "unsupport mdl type"
|
|
#endif
|
|
for (int cc = 0; cc < (dmul?cho:chi); cc++) {
|
|
for(int _ky=_ky0; _ky<_ky1; _ky++){
|
|
for(int _kx=_kx0; _kx<_kx1; _kx++){
|
|
int k = _ky*kw + _kx;
|
|
sbuf[sidx+k] = sptr[k_oft[k]];
|
|
}
|
|
}
|
|
sidx += maxk;
|
|
sptr = sptr_base + (dmul?(cc+1)/dmul:(cc+1));
|
|
}
|
|
}
|
|
//TM_PERF_ADD(t_sbuf);
|
|
mtype_t* sptr = sbuf; //sbuf prepare ok~
|
|
if(maxk*chi==9 && dmul){ //simple opt for 3x3 dwconv
|
|
for(int c=0; c<out->c; c++){
|
|
wtype_t* kptr = (wtype_t*)w + c*chi*maxk;//TM_PERF_START(t_dotp);
|
|
tm_dot_prod_3x3x1(sptr, kptr, &sum);//TM_PERF_ADD(t_dotp);TM_PERF_START(t_post);
|
|
l_postprocess_sum(sum, b[c], act, outp, SUMSCALE, outscale, out_zp); outp++;//TM_PERF_ADD(t_post);
|
|
sptr += maxk; //dwconv need move step
|
|
}
|
|
}else {
|
|
for(int c=0; c<out->c; c++){
|
|
wtype_t* kptr = (wtype_t*)w + c*chi*maxk;//TM_PERF_START(t_dotp);
|
|
tm_dot_prod(sptr, kptr, maxk*chi, &sum);//TM_PERF_ADD(t_dotp);TM_PERF_START(t_post);
|
|
l_postprocess_sum(sum, b[c], act, outp, SUMSCALE, outscale, out_zp); outp++;//TM_PERF_ADD(t_post);
|
|
if(dmul) sptr += maxk; //dwconv need move step
|
|
}
|
|
}
|
|
if(!slow_flag) {TM_PERF_ADD(t_valid);} else {TM_PERF_ADD(t_pad);}
|
|
}
|
|
}
|
|
if(dmul) {TM_PERF_ADD(t_dwconv);} else {TM_PERF_ADD(t_conv);};
|
|
return TM_OK;
|
|
}
|
|
|
|
/*************************** TML_GAP **********************************/
|
|
/*
 * Global average pooling: for every output channel, average that channel over
 * all in->h * in->w positions of the input and store one value in out->data.
 *
 * The input is walked with a stride of out->c elements per pixel.
 * Quantized models: the integer mean is shifted by in_zp, rescaled by
 * in_s/out_s and offset by out_zp. Float models: plain mean. Other model
 * types write nothing.
 *
 * Always returns TM_OK.
 */
tm_err_t __attribute__((weak)) tml_gap(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
{   TM_DBGT_INIT();
    mtype_t* px;
    int ch = 0;
    while(ch < out->c){
        sumtype_t acc = 0;
        px = in->data + ch;                 // first element of this channel
        for(int row=0; row<in->h; row++){
            for(int col=0; col<in->w; col++, px += out->c){
                acc += ((sumtype_t)(*px));
            }
        }
    #if TM_MDL_TYPE == TM_MDL_INT8 || TM_MDL_TYPE == TM_MDL_INT16
        out->data[ch] = (mtype_t)((acc/((in->h)*(in->w))-in_zp)*in_s/out_s + out_zp); //requant
    #elif TM_MDL_TYPE == TM_MDL_FP32 || TM_MDL_TYPE == TM_MDL_FP16
        out->data[ch] = (mtype_t)(acc/((in->h)*(in->w)));
    //#else //#elif TM_MDL_TYPE == TM_MDL_FP8_143 || TM_MDL_TYPE == TM_MDL_FP8_152
    #endif
        ch++;
    }
    return TM_OK;
}
|
|
|
|
/*************************** TML_FC **********************************/
|
|
/*
 * Fully connected layer: out[c] = dot(in, w[c]) + b[c] for every output channel.
 *
 *   w   weights, row c starts at w + c*in->c and has in->c entries
 *   b   one bias per output channel
 *   ws  weight scale; only ws[0] is used (quantized models)
 *
 * Quantized models requantize with in_s*ws[0]/out_s and add out_zp; other
 * models store the biased sum directly. Always returns TM_OK.
 */
tm_err_t __attribute__((weak)) tml_fc(tm_mat_t* in, tm_mat_t* out,  wtype_t* w, btype_t* b, \
    sctype_t* ws, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
{   TM_DBGT_INIT();
    mtype_t* src = in->data;
    int row = 0;
    while(row < out->c){
        sumtype_t acc = 0;
        wtype_t* wrow = w + row*in->c;
        tm_dot_prod(src, wrow, in->c, &acc);
        acc += b[row];  //fuse with zp
    #if TM_MDL_TYPE == TM_MDL_INT8 || TM_MDL_TYPE == TM_MDL_INT16
        out->data[row] = (mtype_t)(acc*in_s*ws[0]/out_s + out_zp); //requant
    #else
        out->data[row] = (mtype_t)(acc);
    #endif
        row++;
    }
    return TM_OK;
}
|
|
|
|
/*************************** TML_SOFTMAX **********************************/
|
|
/*
 * Softmax over the in->c input values.
 *
 * out->data is used as a float work buffer first (so it must be able to hold
 * in->c floats even for INT8/INT16 models) and is then overwritten in place
 * with the final mtype_t results.
 *
 * Pass 1: dequantize (quantized models) or copy, tracking the maximum.
 * Pass 2: exp(x - max), accumulate the sum, then nudge each value down slightly.
 * Pass 3: normalize; quantized models additionally divide by out_s and add out_zp.
 *
 * Always returns TM_OK.
 */
tm_err_t __attribute__((weak)) tml_softmax(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
{   TM_DBGT_INIT(); //note we have float size output buf even in INT8/INT16 mode
    mtype_t* src = in->data;
    float* work  = (float*)(out->data);
    float peak   = -FLT_MAX;
    int i;

    for(i=0; i<in->c; i++){
    #if TM_MDL_TYPE == TM_MDL_INT8 || TM_MDL_TYPE == TM_MDL_INT16
        work[i] = (float)((sumtype_t)src[i] - in_zp)*in_s;
    #else
        work[i] = src[i];
    #endif
        if(work[i] > peak) {
            peak = work[i];
        }
    }

    float total = 0;
    for(i=0; i<in->c; i++){
        work[i] -= peak;
        work[i] = (float)tm_exp(work[i]);
        total += work[i];
        work[i] -= 0.000001;  //prevent 1.0 value (cause 256 overflow)
    }

    //int8/int16 <= fp32, so overwriting element i never touches a float not yet read
    for(i=0; i<in->c; i++){
    #if TM_MDL_TYPE == TM_MDL_INT8 || TM_MDL_TYPE == TM_MDL_INT16
        out->data[i] = (mtype_t)(work[i]/total/out_s + out_zp); //requant
    #else
        out->data[i] = (mtype_t)(work[i]/total);
    #endif
    }
    return TM_OK;
}
|
|
|
|
/*************************** TML_RESHAPE **********************************/
|
|
/*
 * Reshape layer: a no-op in this implementation. No argument is read and no
 * data is moved or modified; the function only reports success.
 */
tm_err_t __attribute__((weak)) tml_reshape(tm_mat_t* in, tm_mat_t* out, sctype_t in_s, zptype_t in_zp, sctype_t out_s, zptype_t out_zp)
{
    // nothing to do here
    (void)in;   (void)out;
    (void)in_s; (void)in_zp;
    (void)out_s; (void)out_zp;
    return TM_OK;
}
|
|
|
|
#endif
|