update_hosts.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import sys
  5. import re
  6. import socket
  7. import getopt
  8. import threading
  9. import subprocess
  10. import shlex
  11. import time
  12. import select
  13. blackhole = (
  14. '10::2222',
  15. '101::1234',
  16. '2001::212',
  17. '2001:da8:112::21ae',
  18. '2003:ff:1:2:3:4:5fff:6',
  19. '2003:ff:1:2:3:4:5fff:7',
  20. '2003:ff:1:2:3:4:5fff:8',
  21. '2003:ff:1:2:3:4:5fff:9',
  22. '2003:ff:1:2:3:4:5fff:10',
  23. '2003:ff:1:2:3:4:5fff:11',
  24. '2003:ff:1:2:3:4:5fff:12',
  25. '21:2::2',
  26. '2123::3e12',
  27. '1.2.3.4',
  28. '127.0.0.1',
  29. '159.106.121.75',
  30. '202.181.7.85',
  31. '203.98.7.65',
  32. '243.185.187.39',
  33. '37.61.54.158',
  34. '4.36.66.178',
  35. '46.82.174.68',
  36. '59.24.3.173',
  37. '64.33.88.161',
  38. '78.16.49.15',
  39. '8.7.198.45',
  40. '93.46.8.89',
  41. )
  42. dns = {
  43. 'google_a': '2001:4860:4860::8888',
  44. 'google_b': '2001:4860:4860::8844',
  45. 'he_net': '2001:470:20::2',
  46. 'lax_he_net': '2001:470:0:9d::2'
  47. }
  48. config = {
  49. 'dns': dns['google_b'],
  50. 'infile': '',
  51. 'outfile': '',
  52. 'querytype': 'aaaa',
  53. 'cname': False,
  54. 'threadnum': 10
  55. }
  56. hosts = []
  57. done_num = 0
  58. thread_lock = threading.Lock()
  59. running = True
  60. class worker_thread(threading.Thread):
  61. def __init__(self, start_pt, end_pt):
  62. threading.Thread.__init__(self)
  63. self.start_pt = start_pt
  64. self.end_pt = end_pt
  65. def run(self):
  66. global hosts, done_num
  67. for i in range(self.start_pt, self.end_pt):
  68. if not running: break
  69. line = hosts[i].strip()
  70. if line == '' or line[0:2] == '##':
  71. hosts[i] = line + '\r\n'
  72. with thread_lock: done_num += 1
  73. continue
  74. # uncomment line
  75. line = line.lstrip('#')
  76. # split comment that appended to line
  77. comment = ''
  78. p = line.find('#')
  79. if p > 0:
  80. comment = line[p:]
  81. line = line[:p]
  82. arr = line.split()
  83. if len(arr) == 1:
  84. domain = arr[0]
  85. else:
  86. domain = arr[1]
  87. flag = False
  88. if validate_domain(domain):
  89. cname, ip = query_domain(domain, False)
  90. if ip == '' or ip in blackhole:
  91. cname, ip = query_domain(domain, True)
  92. if ip:
  93. flag = True
  94. arr[0] = ip
  95. if len(arr) == 1:
  96. arr.append(domain)
  97. if config['cname'] and cname:
  98. arr.append('#' + cname)
  99. else:
  100. if comment:
  101. arr.append(comment)
  102. if not flag:
  103. arr[0] = '#' + arr[0]
  104. if comment:
  105. arr.append(comment)
  106. hosts[i] = ' '.join(arr)
  107. hosts[i] += '\r\n'
  108. with thread_lock: done_num += 1
  109. class watcher_thread(threading.Thread):
  110. def run(self):
  111. total_num = len(hosts)
  112. wn = int(config['threadnum'])
  113. if wn > total_num:
  114. wn = total_num
  115. print "There are %d threads working..." % wn
  116. print "Press 'Enter' to exit.\n"
  117. while True:
  118. if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
  119. raw_input()
  120. print 'Waiting threads to exit...'
  121. global running
  122. with thread_lock:
  123. running = False
  124. break
  125. dn = done_num
  126. outbuf = "Total: %d lines, Done: %d lines, Ratio: %d %%.\r"\
  127. % (total_num, dn, dn * 100 / total_num)
  128. print outbuf,
  129. sys.stdout.flush()
  130. if dn == total_num:
  131. print outbuf
  132. break
  133. time.sleep(1)
  134. def query_domain(domain, tcp):
  135. cmd = "dig +short +time=2 -6 %s @'%s' '%s'"\
  136. % (config['querytype'], config['dns'], domain)
  137. if tcp:
  138. cmd = cmd + ' +tcp'
  139. proc = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE)
  140. out, _ = proc.communicate()
  141. outarr = out.splitlines()
  142. cname = ip = ''
  143. for v in outarr:
  144. if cname == '' and validate_domain(v[:-1]):
  145. cname = v[:-1]
  146. if ip == '' and validate_ip_addr(v):
  147. ip = v
  148. break
  149. return (cname, ip)
  150. def validate_domain(domain):
  151. pattern = '^((?!-)[*A-Za-z0-9-]{1,63}(?<!-)\\.)+[A-Za-z]{2,6}$'
  152. p = re.compile(pattern)
  153. m = p.match(domain)
  154. if m:
  155. return True
  156. else:
  157. return False
  158. def validate_ip_addr(ip_addr):
  159. if ':' in ip_addr:
  160. try:
  161. socket.inet_pton(socket.AF_INET6, ip_addr)
  162. return True
  163. except socket.error:
  164. return False
  165. else:
  166. try:
  167. socket.inet_pton(socket.AF_INET, ip_addr)
  168. return True
  169. except socket.error:
  170. return False
  171. def print_help():
  172. print '''usage: update_hosts [OPTIONS] FILE
  173. A simple multi-threading tool used to update hosts file.
  174. Options:
  175. -h, --help show this help message and exit
  176. -s DNS set another dns server, default: 2001:4860:4860::8844
  177. -o OUT_FILE ouput file, default: inputfilename.out
  178. -t QUERY_TYPE dig command query type, defalut: aaaa
  179. -c, --cname write canonical name into hosts file
  180. -n THREAD_NUM set the number of worker threads, default: 10
  181. '''
  182. def get_config():
  183. shortopts = 'hs:o:t:n:c'
  184. longopts = ['help', 'cname']
  185. try:
  186. optlist, args = getopt.gnu_getopt(sys.argv[1:], shortopts, longopts)
  187. except getopt.GetoptError as e:
  188. print e, '\n'
  189. print_help()
  190. sys.exit(1)
  191. global config
  192. for key, value in optlist:
  193. if key == '-s':
  194. config['dns'] = value
  195. elif key == '-o':
  196. config['outfile'] = value
  197. elif key == '-t':
  198. config['querytype'] = value
  199. elif key in ('-c', '--cname'):
  200. config['cname'] = True
  201. elif key == '-n':
  202. config['threadnum'] = int(value)
  203. elif key in ('-h', '--help'):
  204. print_help()
  205. sys.exit(0)
  206. if len(args) != 1:
  207. print "You must specify the input hosts file (only one)."
  208. sys.exit(1)
  209. config['infile'] = args[0]
  210. if config['outfile'] == '':
  211. config['outfile'] = config['infile'] + '.out'
  212. def main():
  213. get_config()
  214. dig_path = '/usr/bin/dig'
  215. if not os.path.isfile(dig_path) or not os.access(dig_path, os.X_OK):
  216. print "It seems you don't have 'dig' command installed properly "\
  217. "on your system."
  218. sys.exit(2)
  219. global hosts
  220. try:
  221. with open(config['infile'], 'r') as infile:
  222. hosts = infile.readlines()
  223. except IOError as e:
  224. print e
  225. sys.exit(e.errno)
  226. if os.path.exists(config['outfile']):
  227. config['outfile'] += '.new'
  228. try:
  229. outfile = open(config['outfile'], 'w')
  230. except IOError as e:
  231. print e
  232. sys.exit(e.errno)
  233. print "Input: %s Output: %s\n" % (config['infile'], config['outfile'])
  234. threads = []
  235. t = watcher_thread()
  236. t.start()
  237. threads.append(t)
  238. worker_num = config['threadnum']
  239. lines_num = len(hosts)
  240. lines_per_thread = lines_num / worker_num
  241. lines_remain = lines_num % worker_num
  242. start_pt = 0
  243. for _ in range(worker_num):
  244. if not running: break
  245. lines_for_thread = lines_per_thread
  246. if lines_for_thread == 0 and lines_remain == 0:
  247. break
  248. if lines_remain > 0:
  249. lines_for_thread += 1
  250. lines_remain -= 1
  251. t = worker_thread(start_pt, start_pt + lines_for_thread)
  252. start_pt += lines_for_thread
  253. t.start()
  254. threads.append(t)
  255. for t in threads:
  256. t.join()
  257. try:
  258. outfile.writelines(hosts)
  259. except IOError as e:
  260. print e
  261. sys.exit(e.errno)
  262. sys.exit(0)
  263. if __name__ == '__main__':
  264. main()