最近和朋友使用 Lightsplit 分帳覺得很方便,其中一個功能是使用者只須在裡面以筆為單位新增所有帳務,程式就會自動計算出每個人的結餘以及最少的轉帳次數及轉帳方式。
一開始覺得這個功能並沒有很難,但仔細想想後發現其實不容易,閒著沒事嘗試自己研究一下並記錄此篇文章。

實作時偷懶只想了一些簡單測資,不保證功能完全正確

思路

  1. 無論有多少筆帳,每個人最後的結餘只會有一個數字

    綜觀來看,每個人都會有一個付出的金額(先付的金額總和),以及得到的金額價值(讓別人先幫自己付),相減即是結餘,為正表應該收到別人的轉帳,為負則應該轉帳給別人

  2. 使用 Minimum Cost Maximum Flow (MCMF) 找出最小轉帳次數及方式

    想到之前上課學過 MCMF 演算法,利用結餘建立 edge,每個 edge 之 cost 表轉帳次數,感覺可以解決這個問題,便建立 MCMF 資料模型嘗試得出正確結果

資料轉換

以使用 Lightsplit 時做的設定可得到的資料模擬測資

1
2
3
4
4
1 1000 1 2 3 4
2 500 1 2
3 70 2 3
  • 第一個數字 4 表群組中共有 4 人 (編號 1 ~ 4)

  • 每筆帳務會有先付的人與金額以及平攤給誰,第二行開始以行為單位為所有帳務
    ex: 第一筆帳務 1 1000 1 2 3 4 表示 1 號先付 1000 元並平攤給 1, 2, 3, 4 號 (情境可能是 4 個人一起吃飯,其中一位先付款)

總和每筆帳務可得到

付出的金額 得到的金額價值 結餘
1 1000 500 500
2 500 535 -35
3 70 215 -215
4 0 250 -250

建立 MCMF 的 Capacity 模型並新增 super source 及 sink,以此測資為例分別為 0 與 9 號 node
為了方便建邊,將每位使用者視為兩個 node 放在中間

capacity

  • 若結餘為負則在 2 ~ 3 排 nodes 間建邊,表可以轉帳的金額 (如 2 號需付 35 元,則它最多可以轉給 1, 3, 4 號 35 元)
  • 若結餘為正表可以收到的金額,則在 3 ~ 4 排 nodes 建邊 (如 1 號可收到 500 元)

顯而易見地,也許會發生一個問題,如 2 號可能同時轉給 1 號 35 元又轉給 3 號 35 元,因此調整 super source 至使用者的 capacity 避免此問題

fixedCapacity

將資料轉換後將其視為 MCMF 的問題去解即可

執行以下程式可得出結果

result

簡單實作

將情況簡化成每筆金額都可整除,依照以上思路簡單實作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
# 第一行為所有人數
# 之後每一行: 付錢的人 付錢金額 平攤給...(人編號從 1 開始)
# ex: 1 1000 2 3 4 -> 1 先付 1000 平攤給 1, 2, 3, 4(一人 250)
TEST_DATA = '''
4
1 1000 1 2 3 4
2 500 1 2
3 70 2 3
'''
IS_PRINT_DETAIL = False
PRINT_SPACE = 9
INF = int(1e10)

sink = 0
totalPeople = 0
edges = []
capacity = []
flow = []
cost = []
distance = []
parents = []
inQueue = []

balance = []


def init_data(totalPeople):
# 0: super source
# totalPeople * 2 + 1: sink
arraySize = totalPeople * 2 + 2

global edges, capacity, flow, cost, distance, parents, inQueue
edges = [[] for _ in range(arraySize)]
capacity = [[0] * arraySize for _ in range(arraySize)]
flow = [[0] * arraySize for _ in range(arraySize)]
cost = [[0] * arraySize for _ in range(arraySize)]
distance = [0] * arraySize
parents = [0] * arraySize
inQueue = [False] * arraySize


def process_input(inputData):
global sink, edges, capacity, flow, cost, distance, parents, inQueue
global totalPeople, balance

inputDataSplit = inputData.strip().split('\n')
totalPeople = int(inputDataSplit[0])
sink = totalPeople * 2 + 1
balance = [0] * (totalPeople + 1) # 值為正: 最後應收回多少錢, 值為負: 最後應付出多少錢
init_data(totalPeople)

# 連接 super source 至各節點
for i in range(1, totalPeople + 1):
capacity[0][i] = INF

# cost 紀錄轉帳次數
for i in range(1, totalPeople + 1):
for j in range(totalPeople + 1, totalPeople * 2 + 1):
cost[i][j] = 1

pay = [0] * (totalPeople + 1) # 先付多少錢
get = [0] * (totalPeople + 1) # 獲得價值多少錢
for line in inputDataSplit[1:]:
lst = line.split(' ')
p1 = int(lst[0]) # 先付的人
money = int(lst[1]) # 先付的金額
splitMoney = int(money / (len(lst) - 2)) # 平攤後每個人獲得的價值(整數)
people = list(map(int, lst[2:])) # 平攤給

pay[p1] += money
for p in people:
get[p] += splitMoney

# 應收: 先付 - 獲得
for i in range(len(balance)):
balance[i] = pay[i] - get[i]

# 調整 super source 至各節點的 capacity
for i in range(1, totalPeople + 1):
if balance[i] < 0:
capacity[0][i] = -balance[i]

if IS_PRINT_DETAIL:
print(f'{"pay:":>{PRINT_SPACE}}', pay)
print(f'{"get:":>{PRINT_SPACE}}', get)
print(f'{"balance:":>{PRINT_SPACE}}', balance)
print()

# 建邊
for i in range(1, totalPeople + 1):
edges[0].append(i)

if balance[i] >= 0:
edges[i + totalPeople].append(sink)
capacity[i + totalPeople][sink] = balance[i]
continue

for j in range(totalPeople + 1, totalPeople * 2 + 1):
if i + totalPeople != j:
edges[i].append(j)
capacity[i][j] = -balance[i]

if IS_PRINT_DETAIL:
print(f'{"capacity:":>{PRINT_SPACE}}')
for i in capacity:
print(' ' * PRINT_SPACE, i)

print(f'{"edges:":>{PRINT_SPACE}}')
for i in edges:
print(' ' * PRINT_SPACE, i)


def SPFA():
global sink, edges, capacity, flow, cost, distance, parents, inQueue

distance = [INF] * len(distance)
distance[0] = 0

inQueue = [False] * len(inQueue)

que = [0]
inQueue[0] = True

while que:
u = que.pop(0)
inQueue[u] = False

for v in edges[u]:
if capacity[u][v] > flow[u][v] and distance[u] + cost[u][v] < distance[v]:
distance[v] = distance[u] + cost[u][v]
parents[v] = u
if not inQueue[v]:
que.append(v)
inQueue[v] = True
if flow[v][u] > 0 and distance[u] + (-cost[v][u] < distance[v]):
distance[v] = distance[u] + (-cost[v][u])
parents[v] = u
if not inQueue[v]:
que.append(v)
inQueue[v] = True

return distance[sink] != INF


def augment(u, v, bottleNeck):
global parents, capacity, flow

if v == 0:
return bottleNeck
bottleNeck = augment(parents[u], u, min(capacity[u][v] - flow[u][v], bottleNeck))
flow[u][v] += bottleNeck
flow[v][u] -= bottleNeck
return bottleNeck


def MCMF():
global edges, capacity, flow, cost, distance, parents, inQueue, sink
global totalPeople, balance

times = 0
while SPFA():
augment(parents[sink], sink, INF)
times += 1

for i, bal in enumerate(balance[1:]):
pay, get = 0, 0
if bal < 0:
pay = -bal
elif bal > 0:
get = bal
print(f'{i + 1:2} 號應付 {pay:4} 元,應得 {get:4} 元')

print('\n最少轉帳次數:', times)
for i in range(totalPeople + 1, totalPeople * 2 + 1):
for j in range(1, totalPeople + 1):
if flow[i][j] < 0:
print(f'{j:2} 號需給 {i - totalPeople:2}{-flow[i][j]:4} 元')


def main():
process_input(TEST_DATA)
MCMF()


if __name__ == '__main__':
main()