update_hosts.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import sys
  5. import re
  6. import socket
  7. import getopt
  8. import threading
  9. import subprocess
  10. import shlex
  11. import time
  12. import select
  13. blackhole = (
  14. '10::2222',
  15. '21:2::2',
  16. '101::1234',
  17. '200:2:807:c62d::',
  18. '200:2:253d:369e::',
  19. '200:2:2e52:ae44::',
  20. '200:2:3b18:3ad::',
  21. '200:2:4e10:310f::',
  22. '200:2:5d2e:859::',
  23. '200:2:9f6a:794b::',
  24. '200:2:cb62:741::',
  25. '200:2:f3b9:bb27::',
  26. '2001::212',
  27. '2001:da8:112::21ae',
  28. '2003:ff:1:2:3:4:5fff:6',
  29. '2003:ff:1:2:3:4:5fff:7',
  30. '2003:ff:1:2:3:4:5fff:8',
  31. '2003:ff:1:2:3:4:5fff:9',
  32. '2003:ff:1:2:3:4:5fff:10',
  33. '2003:ff:1:2:3:4:5fff:11',
  34. '2003:ff:1:2:3:4:5fff:12',
  35. '2123::3e12',
  36. '3059:83eb::e015:2bee:0:0',
  37. '1.2.3.4',
  38. '4.36.66.178',
  39. '8.7.198.45',
  40. '37.61.54.158',
  41. '46.82.174.68',
  42. '59.24.3.173',
  43. '64.33.88.161',
  44. '78.16.49.15',
  45. '93.46.8.89',
  46. '127.0.0.1',
  47. '159.106.121.75',
  48. '202.181.7.85',
  49. '203.98.7.65',
  50. '243.185.187.39'
  51. )
  52. dns = {
  53. 'google_a': '2001:4860:4860::8888',
  54. 'google_b': '2001:4860:4860::8844',
  55. 'he_net': '2001:470:20::2',
  56. 'lax_he_net': '2001:470:0:9d::2'
  57. }
  58. config = {
  59. 'dns': dns['google_b'],
  60. 'infile': '',
  61. 'outfile': '',
  62. 'querytype': 'aaaa',
  63. 'cname': False,
  64. 'threadnum': 10
  65. }
  66. hosts = []
  67. done_num = 0
  68. thread_lock = threading.Lock()
  69. running = True
  70. class worker_thread(threading.Thread):
  71. def __init__(self, start_pt, end_pt):
  72. threading.Thread.__init__(self)
  73. self.start_pt = start_pt
  74. self.end_pt = end_pt
  75. def run(self):
  76. global hosts, done_num
  77. for i in range(self.start_pt, self.end_pt):
  78. if not running: break
  79. line = hosts[i].strip()
  80. if line == '' or line[0:2] == '##':
  81. hosts[i] = line + '\r\n'
  82. with thread_lock: done_num += 1
  83. continue
  84. # uncomment line
  85. line = line.lstrip('#')
  86. # split comment that appended to line
  87. comment = ''
  88. p = line.find('#')
  89. if p > 0:
  90. comment = line[p:]
  91. line = line[:p]
  92. arr = line.split()
  93. if len(arr) == 1:
  94. domain = arr[0]
  95. else:
  96. domain = arr[1]
  97. flag = False
  98. if validate_domain(domain):
  99. cname, ip = query_domain(domain, False)
  100. if ip == '' or ip in blackhole:
  101. cname, ip = query_domain(domain, True)
  102. if ip:
  103. flag = True
  104. arr[0] = ip
  105. if len(arr) == 1:
  106. arr.append(domain)
  107. if config['cname'] and cname:
  108. arr.append('#' + cname)
  109. else:
  110. if comment:
  111. arr.append(comment)
  112. if not flag:
  113. arr[0] = '#' + arr[0]
  114. if comment:
  115. arr.append(comment)
  116. hosts[i] = ' '.join(arr)
  117. hosts[i] += '\r\n'
  118. with thread_lock: done_num += 1
  119. class watcher_thread(threading.Thread):
  120. def run(self):
  121. total_num = len(hosts)
  122. wn = int(config['threadnum'])
  123. if wn > total_num:
  124. wn = total_num
  125. print "There are %d threads working..." % wn
  126. print "Press 'Enter' to exit.\n"
  127. while True:
  128. if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
  129. raw_input()
  130. print 'Waiting threads to exit...'
  131. global running
  132. with thread_lock:
  133. running = False
  134. break
  135. dn = done_num
  136. outbuf = "Total: %d lines, Done: %d lines, Ratio: %d %%.\r"\
  137. % (total_num, dn, dn * 100 / total_num)
  138. print outbuf,
  139. sys.stdout.flush()
  140. if dn == total_num:
  141. print outbuf
  142. break
  143. time.sleep(1)
  144. def query_domain(domain, tcp):
  145. cmd = "dig +short +time=2 -6 %s @'%s' '%s'"\
  146. % (config['querytype'], config['dns'], domain)
  147. if tcp:
  148. cmd = cmd + ' +tcp'
  149. proc = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE)
  150. out, _ = proc.communicate()
  151. outarr = out.splitlines()
  152. cname = ip = ''
  153. for v in outarr:
  154. if cname == '' and validate_domain(v[:-1]):
  155. cname = v[:-1]
  156. if ip == '' and validate_ip_addr(v):
  157. ip = v
  158. break
  159. return (cname, ip)
  160. def validate_domain(domain):
  161. pattern = '^((?!-)[*A-Za-z0-9-]{1,63}(?<!-)\\.)+[A-Za-z]{2,6}$'
  162. p = re.compile(pattern)
  163. m = p.match(domain)
  164. if m:
  165. return True
  166. else:
  167. return False
  168. def validate_ip_addr(ip_addr):
  169. if ':' in ip_addr:
  170. try:
  171. socket.inet_pton(socket.AF_INET6, ip_addr)
  172. return True
  173. except socket.error:
  174. return False
  175. else:
  176. try:
  177. socket.inet_pton(socket.AF_INET, ip_addr)
  178. return True
  179. except socket.error:
  180. return False
  181. def print_help():
  182. print '''usage: update_hosts [OPTIONS] FILE
  183. A simple multi-threading tool used for updating hosts file.
  184. Options:
  185. -h, --help show this help message and exit
  186. -s DNS set another dns server, default: 2001:4860:4860::8844
  187. -o OUT_FILE output file, default: inputfilename.out
  188. -t QUERY_TYPE dig command query type, default: aaaa
  189. -c, --cname write canonical name into hosts file
  190. -n THREAD_NUM set the number of worker threads, default: 10
  191. '''
  192. def get_config():
  193. shortopts = 'hs:o:t:n:c'
  194. longopts = ['help', 'cname']
  195. try:
  196. optlist, args = getopt.gnu_getopt(sys.argv[1:], shortopts, longopts)
  197. except getopt.GetoptError as e:
  198. print e, '\n'
  199. print_help()
  200. sys.exit(1)
  201. global config
  202. for key, value in optlist:
  203. if key == '-s':
  204. config['dns'] = value
  205. elif key == '-o':
  206. config['outfile'] = value
  207. elif key == '-t':
  208. config['querytype'] = value
  209. elif key in ('-c', '--cname'):
  210. config['cname'] = True
  211. elif key == '-n':
  212. config['threadnum'] = int(value)
  213. elif key in ('-h', '--help'):
  214. print_help()
  215. sys.exit(0)
  216. if len(args) != 1:
  217. print "You must specify the input hosts file (only one)."
  218. sys.exit(1)
  219. config['infile'] = args[0]
  220. if config['outfile'] == '':
  221. config['outfile'] = config['infile'] + '.out'
  222. def main():
  223. get_config()
  224. dig_path = '/usr/bin/dig'
  225. if not os.path.isfile(dig_path) or not os.access(dig_path, os.X_OK):
  226. print "It seems you don't have 'dig' command installed properly "\
  227. "on your system."
  228. sys.exit(2)
  229. global hosts
  230. try:
  231. with open(config['infile'], 'r') as infile:
  232. hosts = infile.readlines()
  233. except IOError as e:
  234. print e
  235. sys.exit(e.errno)
  236. if os.path.exists(config['outfile']):
  237. config['outfile'] += '.new'
  238. try:
  239. outfile = open(config['outfile'], 'w')
  240. except IOError as e:
  241. print e
  242. sys.exit(e.errno)
  243. print "Input: %s Output: %s\n" % (config['infile'], config['outfile'])
  244. threads = []
  245. t = watcher_thread()
  246. t.start()
  247. threads.append(t)
  248. worker_num = config['threadnum']
  249. lines_num = len(hosts)
  250. lines_per_thread = lines_num / worker_num
  251. lines_remain = lines_num % worker_num
  252. start_pt = 0
  253. for _ in range(worker_num):
  254. if not running: break
  255. lines_for_thread = lines_per_thread
  256. if lines_for_thread == 0 and lines_remain == 0:
  257. break
  258. if lines_remain > 0:
  259. lines_for_thread += 1
  260. lines_remain -= 1
  261. t = worker_thread(start_pt, start_pt + lines_for_thread)
  262. start_pt += lines_for_thread
  263. t.start()
  264. threads.append(t)
  265. for t in threads:
  266. t.join()
  267. try:
  268. outfile.writelines(hosts)
  269. except IOError as e:
  270. print e
  271. sys.exit(e.errno)
  272. sys.exit(0)
  273. if __name__ == '__main__':
  274. main()