PS/BOJ (Baekjoon Online Judge)

[BOJ_Platinum 2] 12995 - 트리나라 [python]

까진호두 2022. 3. 20. 18:38

12995번: 트리나라

트리나라는 N개의 도시로 이루어져 있고, 각각의 도시는 1번부터 N번까지 번호가 매겨져 있다. 트리나라의 도로 체계는 트리를 이룬다. 즉, 트리나라에는 N-1개의 양방향도로가 있다. 또, 모두 연

www.acmicpc.net

문제

트리나라는 N개의 도시로 이루어져 있고, 각각의 도시는 1번부터 N번까지 번호가 매겨져 있다. 트리나라의 도로 체계는 트리를 이룬다. 즉, 트리나라에는 N-1개의 양방향도로가 있다. 또, 모두 연결되어 있기 때문에, 임의의 두 도시 사이를 항상 오갈 수 있다.

스타트링크의 직원 K명은 트리나라로 이사를 가려고 한다. 모든 직원은 서로 다른 도시로 이사를 가야한다. 즉, 이사할 도시 K개를 선택해야 한다. 이사할 도시에는 중요한 조건이 하나 있는데, 모든 직원이 사는 도시는 연결되어 있어야 한다는 점이다. 예를 들어, 임의의 두 직원 사는 도시가 i와 j라면, i와 j를 연결하는 경로상에 있는 도시에도 직원이 살고 있어야 한다는 점이다.

트리나라의 트리 구조가 주어졌을 때, 이사할 도시 K개를 고르는 방법의 수를 구하는 프로그램을 작성하시오.

입력

첫째 줄에 도시의 수 N과 스타트링크 직원의 수 K가 주어진다. (2 ≤ N ≤ 50, 1 ≤ K ≤ N)

둘째 줄부터 N-1개의 줄에는 도로 정보가 주어진다.

출력

첫째 줄에 도시 K개를 선택하는 방법의 수를 1,000,000,007로 나눈 나머지를 출력한다. 


트리디피를 활용해 풀 수 있는 문제입니다.

DP[N][K] - 노드 N을 루트로 하는 서브트리에서 K개의 노드를 선택하는 경우의 수

라고 정의하면, 트리디피로 문제를 해결할 수 있습니다.

def dfs(node,parent):
    dp[node][1] = 1  #자기 자신 선택하기

    for nextnode in graph[node]:
        if nextnode == parent: continue

        dfs(nextnode,node)  #자식 노드 호출
        nodeCount[node] += nodeCount[nextnode]  #노드 값 추가

        if nodeCount[node] == nodeCount[nextnode] + 1:  #처음인 경우
            for i in range(nodeCount[nextnode] + 1):
                dp[node][i + 1] = (dp[node][i + 1] + dp[nextnode][i]) % MOD
        else:
            tmp = [0] * 51
            for i in range(nodeCount[node],0,-1):  #두개까지 보면 됨
                for j in range(1,nodeCount[nextnode]+1):
                    if i + j > nodeCount[node]: break
                    tmp[i + j] = (tmp[i + j] + (dp[node][i] * dp[nextnode][j]) % MOD) % MOD

            for i in range(nodeCount[node] + 1):
                dp[node][i] = (dp[node][i] + tmp[i]) % MOD

dfs 함수는 위와 같이 정의됩니다.
임의의 노드 X는 자신의 자식들을 순서대로 재귀호출하며, 그렇게 얻어낸 자식들의 DP값을 통해
자신을 루트로 하여 만들어지는, 자신을 포함한 (중요!)K개의 노드로 이루어진 모든 경우의 수를 연산합니다.

n,k = map(int,input().split())
graph = [[] for _ in range(n+1)]
dp = [[0] * 51 for _ in range(n+1)]
nodeCount = [1] * (n + 1)
MOD = 1000000007

for _ in range(n-1):
    a,b = map(int,input().split())
    graph[a].append(b)
    graph[b].append(a)

def dfs(node,parent):
    dp[node][1] = 1  #자기 자신 선택하기

    for nextnode in graph[node]:
        if nextnode == parent: continue

        dfs(nextnode,node)  #자식 노드 호출
        nodeCount[node] += nodeCount[nextnode]  #노드 값 추가

        if nodeCount[node] == nodeCount[nextnode] + 1:  #처음인 경우
            for i in range(nodeCount[nextnode] + 1):
                dp[node][i + 1] = (dp[node][i + 1] + dp[nextnode][i]) % MOD
        else:
            tmp = [0] * 51
            for i in range(nodeCount[node],0,-1):  #두개까지 보면 됨
                for j in range(1,nodeCount[nextnode]+1):
                    if i + j > nodeCount[node]: break
                    tmp[i + j] = (tmp[i + j] + (dp[node][i] * dp[nextnode][j]) % MOD) % MOD

            for i in range(nodeCount[node] + 1):
                dp[node][i] = (dp[node][i] + tmp[i]) % MOD
        #print(node,dp[node][:nodeCount[node] + 1])
dfs(1,-1)
ans = 0
for i in range(1,n+1):
    if nodeCount[i] < k: continue
    ans += dp[i][k]
print(ans)

각각의 노드들을 한번씩 방문하며 $nodeCount[x]^2$만큼의 반복문을 (전체 노드의 수 -1)회 만큼 반복하므로,
전체 시간 복잡도는 $O(N^3)$정도라고 볼 수 있겠습니다.

728x90