본문 바로가기
공부/SW

비트 연산을 활용한 ReLu 함수 구현

by 기찬이즘 2024. 11. 26.
반응형

ReLu 함수를 그냥 구현하는 것은 굉장히 쉽다.
그냥 C 코드로 구현한다라고 생각하면

int relu(int x) {
	if (x < 0) 
	   return 0;
	else
	   return x;
}

이런식으로 구현일 될 것이다.
 
그런데 하드웨어 관점으로 본다면 이런식의 코딩을 굉장히 비효율적일 수 있다.
저런식으로의 코드가 컴파일 되서 ISA를 통해 하드웨어에서 구현된다라고 생각해보자.
어셈블리어로 구현된 각 instruction은 수행될 때, 한 clock 안에 끝나지 않고 몇 번의 clock에 걸쳐 수행되는 구조를 가지고 있다. 하지만 clock 마다 여러개의 instruction이 동시에 수행되게 하여 효율을 높인다.

위의 pipeline과 같이 ADD insturction이 입력되고 execude 과정이 수행되고 있는 시간에 STR insturction의 decode 과정과 또 다른 ADD instruction이 같이 수행되는 것 을 볼 수 있다.
이런식의 pipeline을 통해 낮은 CPI (Clock Cycles per Instruction)을 기대할 수 있게 된다.
 
하지만 이런 pipeline은 앞선 코드와 같이 분기문이 들어가서 memory 위치가 들쭉날쭉한 코드에 적용되기 어렵다.
따라서 최대한 분기문을 적게 쓰는데 하드웨어에 최적화된 코드를 작성하는데 유리하다.
그러면 분기문이 없는 ReLu 코드를 어떻게 만들 수 있을까?
 

void arm_relu_q7(int8_t *data, uint16_t size)
{
...
    /* Run the following code for M cores with DSP extension */

    uint16_t i = size >> 2;
    int8_t *input = data;
    int8_t *output = data;
    int32_t in;
    int32_t buf;
    int32_t mask;

    while (i)
    {
        in = arm_nn_read_s8x4_ia((const int8_t **)&input);

        /* extract the first bit */
        buf = (int32_t)ROR((uint32_t)in & 0x80808080, 7);

        /* if MSB=1, mask will be 0xFF, 0x0 otherwise */
        mask = QSUB8(0x00000000, buf);

        arm_nn_write_s8x4_ia(&output, in & (~mask));

        i--;
    }

    i = size & 0x3;
    while (i)
    {
        if (*input < 0)
        {
            *input = 0;
        }
        input++;
        i--;
    }
    
 ...

위 코드는 ARM에서 제공하는 CMSIS-NN Library 중 arm_relu_q7 함수의 일부이다.
간단히 설명하자면 8비트로 된 데이터를 4개씩 arm_nn_read_s8x4_ia 함수를 통해 가지고 온다.
그리고 0x80808080 값과 and 연산을 해주고 결과 값을 우측으로 7비트 이동시킨다.
0x80은 0b1000 0000 이므로 입력과 and 연산 시 입력이 음수일 때는 7비트 이동 후 0000 0001이 될것이고, 양수 일때는 0000 0000이 될것이다.
비트 쉬프트 이후 0x00000000에서 쉬프트된 비트 값을 빼는데, 결국 입력이 음수일 때는 0000 0000에서 0000 0001을 빼므로 1111 1111이 되고 양수일 때는 0000 0000에서 0000 0000을 빼므로 0000 0000이 된다.
자 이제 마스크가 완성이 되었다.
출력으로 처음에 들어왔던 입력과 mask에 not 연산을 취한 값을 and 연산 해주면 입력이 음수 일때는 0, 양수일 때는 입력 그대로 나가는 ReLu 함수가 만들어지게 된다.
 
이런식으로 비트 연산을 잘 이용하면 분기 없이 하드웨어 최적화된 코드를 작성할 수 있다.

반응형