update_hosts.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import sys
  5. import re
  6. import socket
  7. import getopt
  8. import threading
  9. import subprocess
  10. import shlex
  11. import time
  12. import select
  13. blackhole = (
  14. '10::2222',
  15. '101::1234',
  16. '2001::212',
  17. '2001:da8:112::21ae',
  18. '2003:ff:1:2:3:4:5fff:6',
  19. '2003:ff:1:2:3:4:5fff:7',
  20. '2003:ff:1:2:3:4:5fff:8',
  21. '2003:ff:1:2:3:4:5fff:9',
  22. '2003:ff:1:2:3:4:5fff:10',
  23. '2003:ff:1:2:3:4:5fff:11',
  24. '2003:ff:1:2:3:4:5fff:12',
  25. '21:2::2',
  26. '2123::3e12')
  27. dns = {
  28. 'google_a':'2001:4860:4860::8888',
  29. 'google_b':'2001:4860:4860::8844',
  30. 'he_net':'2001:470:20::2',
  31. 'lax_he_net':'2001:470:0:9d::2'
  32. }
  33. config = {
  34. 'dns':dns['google_b'],
  35. 'infile':'',
  36. 'outfile':'',
  37. 'querytype':'aaaa',
  38. 'cname': False,
  39. 'threadnum':10
  40. }
  41. hosts = []
  42. done_num = 0
  43. thread_lock = threading.Lock()
  44. running = True
  45. class worker_thread(threading.Thread):
  46. def __init__(self, start_pt, end_pt):
  47. threading.Thread.__init__(self)
  48. self.start_pt = start_pt
  49. self.end_pt = end_pt
  50. def run(self):
  51. global hosts, done_num
  52. for i in range(self.start_pt, self.end_pt):
  53. if not running: break
  54. line = hosts[i].strip()
  55. if line == '' or line[0:2] == '##':
  56. hosts[i] = line + '\r\n'
  57. with thread_lock: done_num += 1
  58. continue
  59. # uncomment line
  60. line = line.lstrip('#')
  61. # split comment that appended to line
  62. comment = ''
  63. p = line.find('#')
  64. if p > 0:
  65. comment = line[p:]
  66. line = line[:p]
  67. arr = line.split()
  68. if len(arr) == 1:
  69. domain = arr[0]
  70. else:
  71. domain = arr[1]
  72. flag = False
  73. if validate_domain(domain):
  74. cname, ip = query_domain(domain, False)
  75. if ip == '' or ip in blackhole:
  76. cname, ip = query_domain(domain, True)
  77. if ip:
  78. flag = True
  79. arr[0] = ip
  80. if len(arr) == 1:
  81. arr.append(domain)
  82. if config['cname'] and cname:
  83. arr.append('#' + cname)
  84. else:
  85. arr.append(comment)
  86. if not flag:
  87. arr[0] = '#' + arr[0]
  88. arr.append(comment)
  89. hosts[i] = ' '.join(arr)
  90. hosts[i] += '\r\n'
  91. with thread_lock: done_num += 1
  92. class watcher_thread(threading.Thread):
  93. def run(self):
  94. total_num = len(hosts)
  95. wn = int(config['threadnum'])
  96. if wn > total_num:
  97. wn = total_num
  98. print "There are %d threads working..." % wn
  99. print "Press 'Enter' to exit.\n"
  100. while True:
  101. if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
  102. t = raw_input()
  103. global running
  104. with thread_lock:
  105. running = False
  106. print 'Waiting threads to exit...'
  107. break
  108. with thread_lock:
  109. dn = done_num
  110. outbuf = "Total: %d lines, Done: %d lines, Ratio: %d %%.\r"\
  111. % (total_num, dn, dn * 100 / total_num)
  112. print outbuf,
  113. sys.stdout.flush()
  114. if done_num == total_num:
  115. print outbuf
  116. break
  117. time.sleep(1)
  118. def query_domain(domain, tcp):
  119. cmd = "dig +short +time=2 -6 %s @'%s' '%s'"\
  120. % (config['querytype'], config['dns'], domain)
  121. if tcp:
  122. cmd = cmd[:3] + ' +tcp' + cmd[3:]
  123. proc = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE)
  124. out, err = proc.communicate()
  125. outarr = out.splitlines()
  126. cname = ip = ''
  127. if len(outarr) == 0:
  128. return ()
  129. else:
  130. for v in outarr:
  131. if cname == '' and validate_domain(v[:-1]):
  132. cname = v[:-1]
  133. if ip == '' and validate_ip_addr(v):
  134. ip = v
  135. break
  136. return (cname, ip)
  137. def validate_domain(domain):
  138. pattern = '^((?!-)[*A-Za-z0-9-]{1,63}(?<!-)\\.)+[A-Za-z]{2,6}$'
  139. p = re.compile(pattern)
  140. m = p.match(domain)
  141. if m:
  142. return True
  143. else:
  144. return False
  145. def validate_ip_addr(ip_addr):
  146. if ':' in ip_addr:
  147. try:
  148. socket.inet_pton(socket.AF_INET6, ip_addr)
  149. return True
  150. except socket.error:
  151. return False
  152. else:
  153. try:
  154. socket.inet_pton(socket.AF_INET, ip_addr)
  155. return True
  156. except socket.error:
  157. return False
  158. def print_help():
  159. print('''usage: update_hosts [OPTIONS] FILE
  160. A simple multi-threading tool used to update hosts file.
  161. Options:
  162. -h, --help show this help message and exit
  163. -s DNS set another dns server, default: 2001:4860:4860::8844
  164. -o OUT_FILE ouput file, default: inputfilename.out
  165. -t QUERY_TYPE dig command query type, defalut: aaaa
  166. -c, --cname write canonical name into hosts file
  167. -n THREAD_NUM set the number of worker threads, default: 10
  168. ''')
  169. def get_config():
  170. shortopts = 'hs:o:t:n:c'
  171. longopts = ['help', 'cname']
  172. try:
  173. optlist, args = getopt.gnu_getopt(sys.argv[1:], shortopts, longopts)
  174. except getopt.GetoptError as e:
  175. print e, '\n'
  176. print_help()
  177. sys.exit(1)
  178. global config
  179. for key, value in optlist:
  180. if key == '-s':
  181. config['dns'] = value
  182. elif key == '-o':
  183. config['outfile'] = value
  184. elif key == '-t':
  185. config['querytype'] = value
  186. elif key in ('-c', '--cname'):
  187. config['cname'] = True
  188. elif key == '-n':
  189. config['threadnum'] = int(value)
  190. elif key in ('-h', '--help'):
  191. print_help()
  192. sys.exit(0)
  193. if len(args) != 1:
  194. print "You must specify the input hosts file (only one)."
  195. sys.exit(1)
  196. config['infile'] = args[0]
  197. if config['outfile'] == '':
  198. config['outfile'] = config['infile'] + '.out'
  199. def main():
  200. get_config()
  201. dig_path = '/usr/bin/dig'
  202. if not os.path.isfile(dig_path) or not os.access(dig_path, os.X_OK):
  203. print "It seems you don't have 'dig' command installed properly "\
  204. "on your system."
  205. sys.exit(2)
  206. global hosts
  207. try:
  208. with open(config['infile'], 'r') as infile:
  209. hosts = infile.readlines()
  210. except IOError as e:
  211. print e
  212. sys.exit(e.errno)
  213. if os.path.exists(config['outfile']):
  214. config['outfile'] += '.new'
  215. try:
  216. outfile = open(config['outfile'], 'w')
  217. except IOError as e:
  218. print e
  219. sys.exit(e.errno)
  220. print "Input: %s Output: %s\n" % (config['infile'], config['outfile'])
  221. threads = []
  222. t = watcher_thread()
  223. t.start()
  224. threads.append(t)
  225. worker_num = config['threadnum']
  226. lines_num = len(hosts)
  227. lines_per_thread = lines_num / worker_num
  228. lines_remain = lines_num % worker_num
  229. start_pt = 0
  230. for i in range(worker_num):
  231. if not running: break
  232. lines_for_thread = lines_per_thread
  233. if lines_for_thread == 0 and lines_remain == 0:
  234. break
  235. if lines_remain > 0:
  236. lines_for_thread += 1
  237. lines_remain -= 1
  238. t = worker_thread(start_pt, start_pt + lines_for_thread)
  239. start_pt += lines_for_thread
  240. t.start()
  241. threads.append(t)
  242. for t in threads:
  243. t.join()
  244. try:
  245. outfile.writelines(hosts)
  246. except IOError as e:
  247. print e
  248. sys.exit(e.errno)
  249. sys.exit(0)
  250. if __name__ == '__main__':
  251. main()