update_hosts.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import sys
  5. import re
  6. import socket
  7. import getopt
  8. import threading
  9. import subprocess
  10. import shlex
  11. import time
  12. import select
  13. blackhole = (
  14. '10::2222',
  15. '101::1234',
  16. '2001::212',
  17. '2001:da8:112::21ae',
  18. '2003:ff:1:2:3:4:5fff:6',
  19. '2003:ff:1:2:3:4:5fff:7',
  20. '2003:ff:1:2:3:4:5fff:8',
  21. '2003:ff:1:2:3:4:5fff:9',
  22. '2003:ff:1:2:3:4:5fff:10',
  23. '2003:ff:1:2:3:4:5fff:11',
  24. '2003:ff:1:2:3:4:5fff:12',
  25. '200:2:3b18:3ad::',
  26. '21:2::2',
  27. '2123::3e12',
  28. '1.2.3.4',
  29. '127.0.0.1',
  30. '159.106.121.75',
  31. '202.181.7.85',
  32. '203.98.7.65',
  33. '243.185.187.39',
  34. '37.61.54.158',
  35. '4.36.66.178',
  36. '46.82.174.68',
  37. '59.24.3.173',
  38. '64.33.88.161',
  39. '78.16.49.15',
  40. '8.7.198.45',
  41. '93.46.8.89',
  42. )
  43. dns = {
  44. 'google_a': '2001:4860:4860::8888',
  45. 'google_b': '2001:4860:4860::8844',
  46. 'he_net': '2001:470:20::2',
  47. 'lax_he_net': '2001:470:0:9d::2'
  48. }
  49. config = {
  50. 'dns': dns['google_b'],
  51. 'infile': '',
  52. 'outfile': '',
  53. 'querytype': 'aaaa',
  54. 'cname': False,
  55. 'threadnum': 10
  56. }
  57. hosts = []
  58. done_num = 0
  59. thread_lock = threading.Lock()
  60. running = True
  61. class worker_thread(threading.Thread):
  62. def __init__(self, start_pt, end_pt):
  63. threading.Thread.__init__(self)
  64. self.start_pt = start_pt
  65. self.end_pt = end_pt
  66. def run(self):
  67. global hosts, done_num
  68. for i in range(self.start_pt, self.end_pt):
  69. if not running: break
  70. line = hosts[i].strip()
  71. if line == '' or line[0:2] == '##':
  72. hosts[i] = line + '\r\n'
  73. with thread_lock: done_num += 1
  74. continue
  75. # uncomment line
  76. line = line.lstrip('#')
  77. # split comment that appended to line
  78. comment = ''
  79. p = line.find('#')
  80. if p > 0:
  81. comment = line[p:]
  82. line = line[:p]
  83. arr = line.split()
  84. if len(arr) == 1:
  85. domain = arr[0]
  86. else:
  87. domain = arr[1]
  88. flag = False
  89. if validate_domain(domain):
  90. cname, ip = query_domain(domain, False)
  91. if ip == '' or ip in blackhole:
  92. cname, ip = query_domain(domain, True)
  93. if ip:
  94. flag = True
  95. arr[0] = ip
  96. if len(arr) == 1:
  97. arr.append(domain)
  98. if config['cname'] and cname:
  99. arr.append('#' + cname)
  100. else:
  101. if comment:
  102. arr.append(comment)
  103. if not flag:
  104. arr[0] = '#' + arr[0]
  105. if comment:
  106. arr.append(comment)
  107. hosts[i] = ' '.join(arr)
  108. hosts[i] += '\r\n'
  109. with thread_lock: done_num += 1
  110. class watcher_thread(threading.Thread):
  111. def run(self):
  112. total_num = len(hosts)
  113. wn = int(config['threadnum'])
  114. if wn > total_num:
  115. wn = total_num
  116. print "There are %d threads working..." % wn
  117. print "Press 'Enter' to exit.\n"
  118. while True:
  119. if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
  120. raw_input()
  121. print 'Waiting threads to exit...'
  122. global running
  123. with thread_lock:
  124. running = False
  125. break
  126. dn = done_num
  127. outbuf = "Total: %d lines, Done: %d lines, Ratio: %d %%.\r"\
  128. % (total_num, dn, dn * 100 / total_num)
  129. print outbuf,
  130. sys.stdout.flush()
  131. if dn == total_num:
  132. print outbuf
  133. break
  134. time.sleep(1)
  135. def query_domain(domain, tcp):
  136. cmd = "dig +short +time=2 -6 %s @'%s' '%s'"\
  137. % (config['querytype'], config['dns'], domain)
  138. if tcp:
  139. cmd = cmd + ' +tcp'
  140. proc = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE)
  141. out, _ = proc.communicate()
  142. outarr = out.splitlines()
  143. cname = ip = ''
  144. for v in outarr:
  145. if cname == '' and validate_domain(v[:-1]):
  146. cname = v[:-1]
  147. if ip == '' and validate_ip_addr(v):
  148. ip = v
  149. break
  150. return (cname, ip)
  151. def validate_domain(domain):
  152. pattern = '^((?!-)[*A-Za-z0-9-]{1,63}(?<!-)\\.)+[A-Za-z]{2,6}$'
  153. p = re.compile(pattern)
  154. m = p.match(domain)
  155. if m:
  156. return True
  157. else:
  158. return False
  159. def validate_ip_addr(ip_addr):
  160. if ':' in ip_addr:
  161. try:
  162. socket.inet_pton(socket.AF_INET6, ip_addr)
  163. return True
  164. except socket.error:
  165. return False
  166. else:
  167. try:
  168. socket.inet_pton(socket.AF_INET, ip_addr)
  169. return True
  170. except socket.error:
  171. return False
  172. def print_help():
  173. print '''usage: update_hosts [OPTIONS] FILE
  174. A simple multi-threading tool used to update hosts file.
  175. Options:
  176. -h, --help show this help message and exit
  177. -s DNS set another dns server, default: 2001:4860:4860::8844
  178. -o OUT_FILE ouput file, default: inputfilename.out
  179. -t QUERY_TYPE dig command query type, defalut: aaaa
  180. -c, --cname write canonical name into hosts file
  181. -n THREAD_NUM set the number of worker threads, default: 10
  182. '''
  183. def get_config():
  184. shortopts = 'hs:o:t:n:c'
  185. longopts = ['help', 'cname']
  186. try:
  187. optlist, args = getopt.gnu_getopt(sys.argv[1:], shortopts, longopts)
  188. except getopt.GetoptError as e:
  189. print e, '\n'
  190. print_help()
  191. sys.exit(1)
  192. global config
  193. for key, value in optlist:
  194. if key == '-s':
  195. config['dns'] = value
  196. elif key == '-o':
  197. config['outfile'] = value
  198. elif key == '-t':
  199. config['querytype'] = value
  200. elif key in ('-c', '--cname'):
  201. config['cname'] = True
  202. elif key == '-n':
  203. config['threadnum'] = int(value)
  204. elif key in ('-h', '--help'):
  205. print_help()
  206. sys.exit(0)
  207. if len(args) != 1:
  208. print "You must specify the input hosts file (only one)."
  209. sys.exit(1)
  210. config['infile'] = args[0]
  211. if config['outfile'] == '':
  212. config['outfile'] = config['infile'] + '.out'
  213. def main():
  214. get_config()
  215. dig_path = '/usr/bin/dig'
  216. if not os.path.isfile(dig_path) or not os.access(dig_path, os.X_OK):
  217. print "It seems you don't have 'dig' command installed properly "\
  218. "on your system."
  219. sys.exit(2)
  220. global hosts
  221. try:
  222. with open(config['infile'], 'r') as infile:
  223. hosts = infile.readlines()
  224. except IOError as e:
  225. print e
  226. sys.exit(e.errno)
  227. if os.path.exists(config['outfile']):
  228. config['outfile'] += '.new'
  229. try:
  230. outfile = open(config['outfile'], 'w')
  231. except IOError as e:
  232. print e
  233. sys.exit(e.errno)
  234. print "Input: %s Output: %s\n" % (config['infile'], config['outfile'])
  235. threads = []
  236. t = watcher_thread()
  237. t.start()
  238. threads.append(t)
  239. worker_num = config['threadnum']
  240. lines_num = len(hosts)
  241. lines_per_thread = lines_num / worker_num
  242. lines_remain = lines_num % worker_num
  243. start_pt = 0
  244. for _ in range(worker_num):
  245. if not running: break
  246. lines_for_thread = lines_per_thread
  247. if lines_for_thread == 0 and lines_remain == 0:
  248. break
  249. if lines_remain > 0:
  250. lines_for_thread += 1
  251. lines_remain -= 1
  252. t = worker_thread(start_pt, start_pt + lines_for_thread)
  253. start_pt += lines_for_thread
  254. t.start()
  255. threads.append(t)
  256. for t in threads:
  257. t.join()
  258. try:
  259. outfile.writelines(hosts)
  260. except IOError as e:
  261. print e
  262. sys.exit(e.errno)
  263. sys.exit(0)
  264. if __name__ == '__main__':
  265. main()