update_hosts.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
import getopt
import ipaddress
import os
import re
import select
import shlex
import shutil
import socket
import subprocess
import sys
import threading
import time
  14. blackhole = (
  15. '10::2222',
  16. '21:2::2',
  17. '101::1234',
  18. '200:2:807:c62d::',
  19. '200:2:253d:369e::',
  20. '200:2:2e52:ae44::',
  21. '200:2:3b18:3ad::',
  22. '200:2:4e10:310f::',
  23. '200:2:5d2e:859::',
  24. '200:2:9f6a:794b::',
  25. '200:2:cb62:741::',
  26. '200:2:cc9b:953e::',
  27. '200:2:f3b9:bb27::',
  28. '2001::212',
  29. '2001:da8:112::21ae',
  30. '2003:ff:1:2:3:4:5fff:6',
  31. '2003:ff:1:2:3:4:5fff:7',
  32. '2003:ff:1:2:3:4:5fff:8',
  33. '2003:ff:1:2:3:4:5fff:9',
  34. '2003:ff:1:2:3:4:5fff:10',
  35. '2003:ff:1:2:3:4:5fff:11',
  36. '2003:ff:1:2:3:4:5fff:12',
  37. '2123::3e12',
  38. '3059:83eb::e015:2bee:0:0',
  39. '1.2.3.4',
  40. '4.36.66.178',
  41. '8.7.198.45',
  42. '37.61.54.158',
  43. '46.82.174.68',
  44. '59.24.3.173',
  45. '64.33.88.161',
  46. '78.16.49.15',
  47. '93.46.8.89',
  48. '127.0.0.1',
  49. '159.106.121.75',
  50. '202.181.7.85',
  51. '203.98.7.65',
  52. '243.185.187.39'
  53. )
  54. invalid_network = [
  55. '200:2::/32',
  56. '2001::/32', #TEREDO
  57. 'a000::/8',
  58. ]
  59. dns = {
  60. 'google_a': '2001:4860:4860::8888',
  61. 'google_b': '2001:4860:4860::8844',
  62. 'he_net': '2001:470:20::2',
  63. 'lax_he_net': '2001:470:0:9d::2'
  64. }
  65. config = {
  66. 'dns': dns['google_b'],
  67. 'infile': '',
  68. 'outfile': '',
  69. 'querytype': 'aaaa',
  70. 'cname': False,
  71. 'threadnum': 10
  72. }
  73. hosts = []
  74. done_num = 0
  75. thread_lock = threading.Lock()
  76. running = True
  77. class worker_thread(threading.Thread):
  78. def __init__(self, start_pt, end_pt):
  79. threading.Thread.__init__(self)
  80. self.start_pt = start_pt
  81. self.end_pt = end_pt
  82. def run(self):
  83. global hosts, done_num
  84. for i in range(self.start_pt, self.end_pt):
  85. if not running: break
  86. line = hosts[i].strip()
  87. if line == '' or line[0:2] == '##':
  88. hosts[i] = line + '\r\n'
  89. with thread_lock: done_num += 1
  90. continue
  91. # uncomment line
  92. line = line.lstrip('#')
  93. # split comment that appended to line
  94. comment = ''
  95. p = line.find('#')
  96. if p > 0:
  97. comment = line[p:]
  98. line = line[:p]
  99. arr = line.split()
  100. if len(arr) == 1:
  101. domain = arr[0]
  102. else:
  103. domain = arr[1]
  104. flag = False
  105. if validate_domain(domain):
  106. cname, ip = query_domain(domain, False)
  107. if ip == '' or ip in blackhole or invalid_address(ip):
  108. cname, ip = query_domain(domain, True)
  109. if ip:
  110. flag = True
  111. arr[0] = ip
  112. if len(arr) == 1:
  113. arr.append(domain)
  114. if config['cname'] and cname:
  115. arr.append('#' + cname)
  116. else:
  117. if comment:
  118. arr.append(comment)
  119. if not flag:
  120. arr[0] = '#' + arr[0]
  121. if comment:
  122. arr.append(comment)
  123. hosts[i] = ' '.join(arr)
  124. hosts[i] += '\r\n'
  125. with thread_lock: done_num += 1
  126. class watcher_thread(threading.Thread):
  127. def run(self):
  128. total_num = len(hosts)
  129. wn = int(config['threadnum'])
  130. if wn > total_num:
  131. wn = total_num
  132. print("There are %d threads working..." % wn)
  133. print("Press 'Enter' to exit.\n")
  134. while True:
  135. if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
  136. input()
  137. print("Waiting threads to exit...")
  138. global running
  139. with thread_lock:
  140. running = False
  141. break
  142. dn = done_num
  143. outbuf = "Total: %d lines, Done: %d lines, Ratio: %d %%.\r"\
  144. % (total_num, dn, dn * 100 / total_num)
  145. print(outbuf, end='', flush=True)
  146. if dn == total_num:
  147. print(outbuf)
  148. break
  149. time.sleep(1)
  150. def query_domain(domain, tcp):
  151. cmd = "dig +short +time=2 -6 %s @'%s' '%s'"\
  152. % (config['querytype'], config['dns'], domain)
  153. if tcp:
  154. cmd = cmd + ' +tcp'
  155. proc = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE)
  156. out, _ = proc.communicate()
  157. outarr = out.decode('utf-8').splitlines()
  158. cname = ip = ''
  159. for v in outarr:
  160. if cname == '' and validate_domain(v[:-1]):
  161. cname = v[:-1]
  162. if ip == '' and validate_ip_addr(v):
  163. ip = v
  164. break
  165. return (cname, ip)
  166. def validate_domain(domain):
  167. pattern = '^((?!-)[*A-Za-z0-9-]{1,63}(?<!-)\\.)+[A-Za-z]{2,6}$'
  168. p = re.compile(pattern)
  169. m = p.match(domain)
  170. if m:
  171. return True
  172. else:
  173. return False
  174. def validate_ip_addr(ip_addr):
  175. if ':' in ip_addr:
  176. try:
  177. socket.inet_pton(socket.AF_INET6, ip_addr)
  178. return True
  179. except socket.error:
  180. return False
  181. else:
  182. try:
  183. socket.inet_pton(socket.AF_INET, ip_addr)
  184. return True
  185. except socket.error:
  186. return False
  187. def invalid_address(ip_addr):
  188. address = ipaddress.ip_address(ip_addr)
  189. for cidr in invalid_network:
  190. if address in ipaddress.ip_network(cidr):
  191. return True
  192. else:
  193. return False
  194. def print_help():
  195. print('''usage: update_hosts [OPTIONS] FILE
  196. A simple multi-threading tool used for updating hosts file.
  197. Options:
  198. -h, --help show this help message and exit
  199. -s DNS set another dns server, default: 2001:4860:4860::8844
  200. -o OUT_FILE output file, default: inputfilename.out
  201. -t QUERY_TYPE dig command query type, default: aaaa
  202. -c, --cname write canonical name into hosts file
  203. -n THREAD_NUM set the number of worker threads, default: 10
  204. ''')
  205. def get_config():
  206. shortopts = 'hs:o:t:n:c'
  207. longopts = ['help', 'cname']
  208. try:
  209. optlist, args = getopt.gnu_getopt(sys.argv[1:], shortopts, longopts)
  210. except getopt.GetoptError as e:
  211. print(e)
  212. print_help()
  213. sys.exit(1)
  214. global config
  215. for key, value in optlist:
  216. if key == '-s':
  217. config['dns'] = value
  218. elif key == '-o':
  219. config['outfile'] = value
  220. elif key == '-t':
  221. config['querytype'] = value
  222. elif key in ('-c', '--cname'):
  223. config['cname'] = True
  224. elif key == '-n':
  225. config['threadnum'] = int(value)
  226. elif key in ('-h', '--help'):
  227. print_help()
  228. sys.exit(0)
  229. if len(args) != 1:
  230. print("You must specify the input hosts file (only one).")
  231. sys.exit(1)
  232. config['infile'] = args[0]
  233. if config['outfile'] == '':
  234. config['outfile'] = config['infile'] + '.out'
  235. def main():
  236. get_config()
  237. dig_path = '/usr/bin/dig'
  238. if not os.path.isfile(dig_path) or not os.access(dig_path, os.X_OK):
  239. print("It seems you don't have 'dig' command installed properly "\
  240. "on your system.")
  241. sys.exit(2)
  242. global hosts
  243. try:
  244. with open(config['infile'], 'r') as infile:
  245. hosts = infile.readlines()
  246. except IOError as e:
  247. print(e)
  248. sys.exit(e.errno)
  249. if os.path.exists(config['outfile']):
  250. config['outfile'] += '.new'
  251. try:
  252. outfile = open(config['outfile'], 'w')
  253. except IOError as e:
  254. print(e)
  255. sys.exit(e.errno)
  256. print("Input: %s Output: %s\n" % (config['infile'], config['outfile']))
  257. threads = []
  258. t = watcher_thread()
  259. t.start()
  260. threads.append(t)
  261. worker_num = config['threadnum']
  262. lines_num = len(hosts)
  263. lines_per_thread = lines_num // worker_num
  264. lines_remain = lines_num % worker_num
  265. start_pt = 0
  266. for _ in range(worker_num):
  267. if not running: break
  268. lines_for_thread = lines_per_thread
  269. if lines_for_thread == 0 and lines_remain == 0:
  270. break
  271. if lines_remain > 0:
  272. lines_for_thread += 1
  273. lines_remain -= 1
  274. t = worker_thread(start_pt, start_pt + lines_for_thread)
  275. start_pt += lines_for_thread
  276. t.start()
  277. threads.append(t)
  278. for t in threads:
  279. t.join()
  280. try:
  281. outfile.writelines(hosts)
  282. except IOError as e:
  283. print(e)
  284. sys.exit(e.errno)
  285. sys.exit(0)
  286. if __name__ == '__main__':
  287. main()