update_hosts.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import sys
  5. import re
  6. import socket
  7. import getopt
  8. import threading
  9. import subprocess
  10. import shlex
  11. import time
  12. import select
  13. blackhole = (
  14. '10::2222',
  15. '21:2::2',
  16. '101::1234',
  17. '200:2:807:c62d::',
  18. '200:2:253d:369e::',
  19. '200:2:2e52:ae44::',
  20. '200:2:3b18:3ad::',
  21. '200:2:4e10:310f::',
  22. '200:2:5d2e:859::',
  23. '200:2:9f6a:794b::',
  24. '200:2:cb62:741::',
  25. '200:2:cc9b:953e::',
  26. '200:2:f3b9:bb27::',
  27. '2001::212',
  28. '2001:da8:112::21ae',
  29. '2003:ff:1:2:3:4:5fff:6',
  30. '2003:ff:1:2:3:4:5fff:7',
  31. '2003:ff:1:2:3:4:5fff:8',
  32. '2003:ff:1:2:3:4:5fff:9',
  33. '2003:ff:1:2:3:4:5fff:10',
  34. '2003:ff:1:2:3:4:5fff:11',
  35. '2003:ff:1:2:3:4:5fff:12',
  36. '2123::3e12',
  37. '3059:83eb::e015:2bee:0:0',
  38. '1.2.3.4',
  39. '4.36.66.178',
  40. '8.7.198.45',
  41. '37.61.54.158',
  42. '46.82.174.68',
  43. '59.24.3.173',
  44. '64.33.88.161',
  45. '78.16.49.15',
  46. '93.46.8.89',
  47. '127.0.0.1',
  48. '159.106.121.75',
  49. '202.181.7.85',
  50. '203.98.7.65',
  51. '243.185.187.39'
  52. )
  53. dns = {
  54. 'google_a': '2001:4860:4860::8888',
  55. 'google_b': '2001:4860:4860::8844',
  56. 'he_net': '2001:470:20::2',
  57. 'lax_he_net': '2001:470:0:9d::2'
  58. }
  59. config = {
  60. 'dns': dns['google_b'],
  61. 'infile': '',
  62. 'outfile': '',
  63. 'querytype': 'aaaa',
  64. 'cname': False,
  65. 'threadnum': 10
  66. }
  67. hosts = []
  68. done_num = 0
  69. thread_lock = threading.Lock()
  70. running = True
  71. class worker_thread(threading.Thread):
  72. def __init__(self, start_pt, end_pt):
  73. threading.Thread.__init__(self)
  74. self.start_pt = start_pt
  75. self.end_pt = end_pt
  76. def run(self):
  77. global hosts, done_num
  78. for i in range(self.start_pt, self.end_pt):
  79. if not running: break
  80. line = hosts[i].strip()
  81. if line == '' or line[0:2] == '##':
  82. hosts[i] = line + '\r\n'
  83. with thread_lock: done_num += 1
  84. continue
  85. # uncomment line
  86. line = line.lstrip('#')
  87. # split comment that appended to line
  88. comment = ''
  89. p = line.find('#')
  90. if p > 0:
  91. comment = line[p:]
  92. line = line[:p]
  93. arr = line.split()
  94. if len(arr) == 1:
  95. domain = arr[0]
  96. else:
  97. domain = arr[1]
  98. flag = False
  99. if validate_domain(domain):
  100. cname, ip = query_domain(domain, False)
  101. if ip == '' or ip in blackhole:
  102. cname, ip = query_domain(domain, True)
  103. if ip:
  104. flag = True
  105. arr[0] = ip
  106. if len(arr) == 1:
  107. arr.append(domain)
  108. if config['cname'] and cname:
  109. arr.append('#' + cname)
  110. else:
  111. if comment:
  112. arr.append(comment)
  113. if not flag:
  114. arr[0] = '#' + arr[0]
  115. if comment:
  116. arr.append(comment)
  117. hosts[i] = ' '.join(arr)
  118. hosts[i] += '\r\n'
  119. with thread_lock: done_num += 1
  120. class watcher_thread(threading.Thread):
  121. def run(self):
  122. total_num = len(hosts)
  123. wn = int(config['threadnum'])
  124. if wn > total_num:
  125. wn = total_num
  126. print "There are %d threads working..." % wn
  127. print "Press 'Enter' to exit.\n"
  128. while True:
  129. if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
  130. raw_input()
  131. print 'Waiting threads to exit...'
  132. global running
  133. with thread_lock:
  134. running = False
  135. break
  136. dn = done_num
  137. outbuf = "Total: %d lines, Done: %d lines, Ratio: %d %%.\r"\
  138. % (total_num, dn, dn * 100 / total_num)
  139. print outbuf,
  140. sys.stdout.flush()
  141. if dn == total_num:
  142. print outbuf
  143. break
  144. time.sleep(1)
  145. def query_domain(domain, tcp):
  146. cmd = "dig +short +time=2 -6 %s @'%s' '%s'"\
  147. % (config['querytype'], config['dns'], domain)
  148. if tcp:
  149. cmd = cmd + ' +tcp'
  150. proc = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE)
  151. out, _ = proc.communicate()
  152. outarr = out.splitlines()
  153. cname = ip = ''
  154. for v in outarr:
  155. if cname == '' and validate_domain(v[:-1]):
  156. cname = v[:-1]
  157. if ip == '' and validate_ip_addr(v):
  158. ip = v
  159. break
  160. return (cname, ip)
  161. def validate_domain(domain):
  162. pattern = '^((?!-)[*A-Za-z0-9-]{1,63}(?<!-)\\.)+[A-Za-z]{2,6}$'
  163. p = re.compile(pattern)
  164. m = p.match(domain)
  165. if m:
  166. return True
  167. else:
  168. return False
  169. def validate_ip_addr(ip_addr):
  170. if ':' in ip_addr:
  171. try:
  172. socket.inet_pton(socket.AF_INET6, ip_addr)
  173. return True
  174. except socket.error:
  175. return False
  176. else:
  177. try:
  178. socket.inet_pton(socket.AF_INET, ip_addr)
  179. return True
  180. except socket.error:
  181. return False
  182. def print_help():
  183. print '''usage: update_hosts [OPTIONS] FILE
  184. A simple multi-threading tool used for updating hosts file.
  185. Options:
  186. -h, --help show this help message and exit
  187. -s DNS set another dns server, default: 2001:4860:4860::8844
  188. -o OUT_FILE output file, default: inputfilename.out
  189. -t QUERY_TYPE dig command query type, default: aaaa
  190. -c, --cname write canonical name into hosts file
  191. -n THREAD_NUM set the number of worker threads, default: 10
  192. '''
  193. def get_config():
  194. shortopts = 'hs:o:t:n:c'
  195. longopts = ['help', 'cname']
  196. try:
  197. optlist, args = getopt.gnu_getopt(sys.argv[1:], shortopts, longopts)
  198. except getopt.GetoptError as e:
  199. print e, '\n'
  200. print_help()
  201. sys.exit(1)
  202. global config
  203. for key, value in optlist:
  204. if key == '-s':
  205. config['dns'] = value
  206. elif key == '-o':
  207. config['outfile'] = value
  208. elif key == '-t':
  209. config['querytype'] = value
  210. elif key in ('-c', '--cname'):
  211. config['cname'] = True
  212. elif key == '-n':
  213. config['threadnum'] = int(value)
  214. elif key in ('-h', '--help'):
  215. print_help()
  216. sys.exit(0)
  217. if len(args) != 1:
  218. print "You must specify the input hosts file (only one)."
  219. sys.exit(1)
  220. config['infile'] = args[0]
  221. if config['outfile'] == '':
  222. config['outfile'] = config['infile'] + '.out'
  223. def main():
  224. get_config()
  225. dig_path = '/usr/bin/dig'
  226. if not os.path.isfile(dig_path) or not os.access(dig_path, os.X_OK):
  227. print "It seems you don't have 'dig' command installed properly "\
  228. "on your system."
  229. sys.exit(2)
  230. global hosts
  231. try:
  232. with open(config['infile'], 'r') as infile:
  233. hosts = infile.readlines()
  234. except IOError as e:
  235. print e
  236. sys.exit(e.errno)
  237. if os.path.exists(config['outfile']):
  238. config['outfile'] += '.new'
  239. try:
  240. outfile = open(config['outfile'], 'w')
  241. except IOError as e:
  242. print e
  243. sys.exit(e.errno)
  244. print "Input: %s Output: %s\n" % (config['infile'], config['outfile'])
  245. threads = []
  246. t = watcher_thread()
  247. t.start()
  248. threads.append(t)
  249. worker_num = config['threadnum']
  250. lines_num = len(hosts)
  251. lines_per_thread = lines_num / worker_num
  252. lines_remain = lines_num % worker_num
  253. start_pt = 0
  254. for _ in range(worker_num):
  255. if not running: break
  256. lines_for_thread = lines_per_thread
  257. if lines_for_thread == 0 and lines_remain == 0:
  258. break
  259. if lines_remain > 0:
  260. lines_for_thread += 1
  261. lines_remain -= 1
  262. t = worker_thread(start_pt, start_pt + lines_for_thread)
  263. start_pt += lines_for_thread
  264. t.start()
  265. threads.append(t)
  266. for t in threads:
  267. t.join()
  268. try:
  269. outfile.writelines(hosts)
  270. except IOError as e:
  271. print e
  272. sys.exit(e.errno)
  273. sys.exit(0)
  274. if __name__ == '__main__':
  275. main()