[Softeer] 바이러스 (C++)
문제
바이러스가 숙주의 몸속에서 1초당 P배씩 증가한다. 처음에 바이러스 K마리가 있었다면 N초 후에는 총 몇 마리의 바이러스로 불어날까? N초 동안 죽는 바이러스는 없다고 가정한다.
제약조건
1 ≤ K ≤ 10^8
1 ≤ P ≤ 10^8
1 ≤ N ≤ 10^6
입력 형식
첫 번째 줄에 처음 바이러스의 수 K, 증가율 P, 총 시간 N(초)이 주어진다.
출력 형식
최종 바이러스 개수를 1000000007로 나눈 나머지를 출력하라.
입력 예시
2 3 2
출력 예시
18
문제의 의도
문제의 의도는 주어진 시간 동안 기하급수적으로 증가하는 바이러스의 수를 효율적으로 계산할 수 있도록 코드를 작성하는 것이다.
접근법
- long long 타입 변수
C++에서 int 타입은 4바이트 32비트 정수인데, 이는 약 -2^31 ~ 2^31 - 1( 2,147,483,647) 사이의 값을 저장할 수 있다. K, P, N 숫자들의 범위를 보면 알 수 있듯이 바이러스의 총 계수를 계산(P^N * K)하다 보면 결과 값이 매우 커질 수 있기 때문에 64비트 정수를 표현할 수 있는 long long 타입을 사용하는 것이 좋다.
- 지수법칙
P^N을 pow와 같은 함수를 통해 직접 계산하는 것은 반환형이 double이기 때문에 모듈러 연산을 위해 정수로의 형변환 과정에서 오차가 발생할 수 있어서 부적절하다. 또한, 단순히 N까지 반복문으로 바이러스 수를 계산할 경우, N = 10^6이라면, P를 10^6번 곱해야 하는데, 이것은 시간적으로 매우 비효율적이고 시간 초과될 것이다.
그래서, 지수법칙(Exponentiation by Squaring)을 이용한다. 지수법칙은 분할 정복 알고리즘을 사용하여 지수를 절반으로 나누어 계산하기 때문에 훨씬 적은 연산으로 결과를 얻을 수 있다. 예를 들면, P^128이라면 128번이 아닌 단 7번의 곱셈으로 답을 얻을 수 있다.
P^128 = (P^64)^2 = ((P^32)^2)^2 = (((P^16)^2)^2)^2 = ((((P^8)^2)^2)^2)^2 = (((((P^4)^2)^2)^2)^2)^2 = ((((((P^2)^2)^2)^2)^2)^2)^2 (7번의 곱셈)
- 모듈러 연산
큰 수의 곱셈을 여러 번 반복하면 중간 결과가 매우 커질 수 있다. 예를 들면, 에서 P와 N이 매우 크다면, 이 계산 결과는 여전 long long 자료형의 범위를 초과하여 오버플로우가 발생할 수 있다. 그렇기 때문에 모듈러 연산을 사용하여 중간 결과를 항상 일정 범위 내(모듈러 값 이하)로 제한하자는 것이다. 예들 들어, MOD = 1000000007로 모듈러 연산을 적용하면, 중간 결과가 절대 1000000007을 초과하지 않게 되기 때문에 오버플로우를 방지할 수 있다.
추가적으로, 모듈러 연산은 수학적으로 다음과 같은 수식이 성립한다.
(a×b) mod m = ((a mod m) × (b mod m)) mod m
수식에 의하면, 곱셈 결과에 모듈러 연산을 직접 적용하는 것과 중간중간에 모듈러 연산을 적용하는 것이 동일한 결과를 가진다. 즉, 최종 결과에만 영향을 미칠 뿐, 연산 중간에는 영향을 미치지 않기 때문에 중간 계산에 모듈러 연산을 사용해도 최종 결과의 정확성을 유지할 수 있다.
C++ 구현 코드
위 3가지 접근법을 이용한 코드는 다음과 같다.
#include<iostream>
using namespace std;
int MOD = 1000000007;
// 거듭제곱을 모듈러 연산을 사용하여 계산
long long mod_exp(long long p, long long n) {
long long result = 1;
// n이 0이 될 때까지 지수를 줄임
while (n > 0) {
// 지수가 홀수 일때 p를 따로 곱함
if (n % 2 == 1) {
result = (result * p) % MOD;
}
p = (p * p) % MOD;
n /= 2;
}
return result;
}
int main() {
long long k, p, n;
cin >> k >> p >> n;
// P^N % MOD 계산
long long result = mod_exp(p, n);
// 최종 바이러스 수 계산
result = (k * result) % MOD;
cout << result << endl;
return 0;
}