| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346 | 
							
- import os
 
- import sys
 
- import re
 
- import socket
 
- import ipaddress
 
- import getopt
 
- import threading
 
- import subprocess
 
- import shlex
 
- import time
 
- import select
 
- blackhole = (
 
-     '10::2222',
 
-     '21:2::2',
 
-     '101::1234',
 
-     '200:2:807:c62d::',
 
-     '200:2:253d:369e::',
 
-     '200:2:2e52:ae44::',
 
-     '200:2:3b18:3ad::',
 
-     '200:2:4e10:310f::',
 
-     '200:2:5d2e:859::',
 
-     '200:2:9f6a:794b::',
 
-     '200:2:cb62:741::',
 
-     '200:2:cc9b:953e::',
 
-     '200:2:f3b9:bb27::',
 
-     '2001::212',
 
-     '2001:da8:112::21ae',
 
-     '2003:ff:1:2:3:4:5fff:6',
 
-     '2003:ff:1:2:3:4:5fff:7',
 
-     '2003:ff:1:2:3:4:5fff:8',
 
-     '2003:ff:1:2:3:4:5fff:9',
 
-     '2003:ff:1:2:3:4:5fff:10',
 
-     '2003:ff:1:2:3:4:5fff:11',
 
-     '2003:ff:1:2:3:4:5fff:12',
 
-     '2123::3e12',
 
-     '3059:83eb::e015:2bee:0:0',
 
-     '1.2.3.4',
 
-     '4.36.66.178',
 
-     '8.7.198.45',
 
-     '37.61.54.158',
 
-     '46.82.174.68',
 
-     '59.24.3.173',
 
-     '64.33.88.161',
 
-     '78.16.49.15',
 
-     '93.46.8.89',
 
-     '127.0.0.1',
 
-     '159.106.121.75',
 
-     '202.181.7.85',
 
-     '203.98.7.65',
 
-     '243.185.187.39'
 
- )
 
- invalid_network = [
 
-     '200:2::/32',
 
-     '2001::/32', 
 
-     'a000::/8',
 
- ]
 
- dns = {
 
-     'google_a': '2001:4860:4860::8888',
 
-     'google_b': '2001:4860:4860::8844',
 
-     'he_net': '2001:470:20::2',
 
-     'lax_he_net': '2001:470:0:9d::2'
 
- }
 
- config = {
 
-     'dns': dns['google_b'],
 
-     'infile': '',
 
-     'outfile': '',
 
-     'querytype': 'aaaa',
 
-     'cname': False,
 
-     'threadnum': 10
 
- }
 
- hosts = []
 
- done_num = 0
 
- thread_lock = threading.Lock()
 
- running = True
 
- class worker_thread(threading.Thread):
 
-     def __init__(self, start_pt, end_pt):
 
-         threading.Thread.__init__(self)
 
-         self.start_pt = start_pt
 
-         self.end_pt = end_pt
 
-     def run(self):
 
-         global hosts, done_num
 
-         for i in range(self.start_pt, self.end_pt):
 
-             if not running: break
 
-             line = hosts[i].strip()
 
-             if line == '' or line[0:2] == '##':
 
-                 hosts[i] = line + '\r\n'
 
-                 with thread_lock: done_num += 1
 
-                 continue
 
-             
 
-             line = line.lstrip('#')
 
-             
 
-             comment = ''
 
-             p = line.find('#')
 
-             if p > 0:
 
-                 comment = line[p:]
 
-                 line = line[:p]
 
-             arr = line.split()
 
-             if len(arr) == 1:
 
-                 domain = arr[0]
 
-             else:
 
-                 domain = arr[1]
 
-             flag = False
 
-             if validate_domain(domain):
 
-                 cname, ip = query_domain(domain, False)
 
-                 if ip == '' or ip in blackhole or invalid_address(ip):
 
-                     cname, ip = query_domain(domain, True)
 
-                 if ip:
 
-                     flag = True
 
-                     arr[0] = ip
 
-                     if len(arr) == 1:
 
-                         arr.append(domain)
 
-                     if config['cname'] and cname:
 
-                         arr.append('#' + cname)
 
-                     else:
 
-                         if comment:
 
-                             arr.append(comment)
 
-             if not flag:
 
-                 arr[0] = '#' + arr[0]
 
-                 if comment:
 
-                     arr.append(comment)
 
-             hosts[i] = ' '.join(arr)
 
-             hosts[i] += '\r\n'
 
-             with thread_lock: done_num += 1
 
- class watcher_thread(threading.Thread):
 
-     def run(self):
 
-         total_num = len(hosts)
 
-         wn = int(config['threadnum'])
 
-         if wn > total_num:
 
-             wn = total_num
 
-         print("There are %d threads working..." % wn)
 
-         print("Press 'Enter' to exit.\n")
 
-         while True:
 
-             if sys.stdin in select.select([sys.stdin], [], [], 0)[0]:
 
-                 input()
 
-                 print("Waiting threads to exit...")
 
-                 global running
 
-                 with thread_lock:
 
-                     running = False
 
-                 break
 
-             dn = done_num
 
-             outbuf = "Total: %d lines, Done: %d lines, Ratio: %d %%.\r"\
 
-                      % (total_num, dn, dn * 100 / total_num)
 
-             print(outbuf, end='', flush=True)
 
-             if dn == total_num:
 
-                 print(outbuf)
 
-                 break
 
-             time.sleep(1)
 
- def query_domain(domain, tcp):
 
-     cmd = "dig +short +time=2 -6 %s @'%s' '%s'"\
 
-         % (config['querytype'], config['dns'], domain)
 
-     if tcp:
 
-         cmd = cmd + ' +tcp'
 
-     proc = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE)
 
-     out, _ = proc.communicate()
 
-     outarr = out.decode('utf-8').splitlines()
 
-     cname = ip = ''
 
-     for v in outarr:
 
-         if cname == '' and validate_domain(v[:-1]):
 
-             cname = v[:-1]
 
-         if ip == '' and validate_ip_addr(v):
 
-             ip = v
 
-             break
 
-     return (cname, ip)
 
- def validate_domain(domain):
 
-     pattern = '^((?!-)[*A-Za-z0-9-]{1,63}(?<!-)\\.)+[A-Za-z]{2,6}$'
 
-     p = re.compile(pattern)
 
-     m = p.match(domain)
 
-     if m:
 
-         return True
 
-     else:
 
-         return False
 
- def validate_ip_addr(ip_addr):
 
-     if ':' in ip_addr:
 
-         try:
 
-             socket.inet_pton(socket.AF_INET6, ip_addr)
 
-             return True
 
-         except socket.error:
 
-             return False
 
-     else:
 
-         try:
 
-             socket.inet_pton(socket.AF_INET, ip_addr)
 
-             return True
 
-         except socket.error:
 
-             return False
 
- def invalid_address(ip_addr):
 
-     address = ipaddress.ip_address(ip_addr)
 
-     for cidr in invalid_network:
 
-         if address in ipaddress.ip_network(cidr):
 
-             return True
 
-     else:
 
-         return False
 
- def print_help():
 
-     print('''usage: update_hosts [OPTIONS] FILE
 
- A simple multi-threading tool used for updating hosts file.
 
- Options:
 
-   -h, --help             show this help message and exit
 
-   -s DNS                 set another dns server, default: 2001:4860:4860::8844
 
-   -o OUT_FILE            output file, default: inputfilename.out
 
-   -t QUERY_TYPE          dig command query type, default: aaaa
 
-   -c, --cname            write canonical name into hosts file
 
-   -n THREAD_NUM          set the number of worker threads, default: 10
 
- ''')
 
- def get_config():
 
-     shortopts = 'hs:o:t:n:c'
 
-     longopts = ['help', 'cname']
 
-     try:
 
-         optlist, args = getopt.gnu_getopt(sys.argv[1:], shortopts, longopts)
 
-     except getopt.GetoptError as e:
 
-         print(e)
 
-         print_help()
 
-         sys.exit(1)
 
-     global config
 
-     for key, value in optlist:
 
-         if key == '-s':
 
-             config['dns'] = value
 
-         elif key == '-o':
 
-             config['outfile'] = value
 
-         elif key == '-t':
 
-             config['querytype'] = value
 
-         elif key in ('-c', '--cname'):
 
-             config['cname'] = True
 
-         elif key == '-n':
 
-             config['threadnum'] = int(value)
 
-         elif key in ('-h', '--help'):
 
-             print_help()
 
-             sys.exit(0)
 
-     if len(args) != 1:
 
-         print("You must specify the input hosts file (only one).")
 
-         sys.exit(1)
 
-     config['infile'] = args[0]
 
-     if config['outfile'] == '':
 
-         config['outfile'] = config['infile'] + '.out'
 
- def main():
 
-     get_config()
 
-     dig_path = '/usr/bin/dig'
 
-     if not os.path.isfile(dig_path) or not os.access(dig_path, os.X_OK):
 
-         print("It seems you don't have 'dig' command installed properly "\
 
-               "on your system.")
 
-         sys.exit(2)
 
-     global hosts
 
-     try:
 
-         with open(config['infile'], 'r') as infile:
 
-             hosts = infile.readlines()
 
-     except IOError as e:
 
-         print(e)
 
-         sys.exit(e.errno)
 
-     if os.path.exists(config['outfile']):
 
-         config['outfile'] += '.new'
 
-     try:
 
-         outfile = open(config['outfile'], 'w')
 
-     except IOError as e:
 
-         print(e)
 
-         sys.exit(e.errno)
 
-     print("Input: %s    Output: %s\n" % (config['infile'], config['outfile']))
 
-     threads = []
 
-     t = watcher_thread()
 
-     t.start()
 
-     threads.append(t)
 
-     worker_num = config['threadnum']
 
-     lines_num = len(hosts)
 
-     lines_per_thread = lines_num // worker_num
 
-     lines_remain = lines_num % worker_num
 
-     start_pt = 0
 
-     for _ in range(worker_num):
 
-         if not running: break
 
-         lines_for_thread = lines_per_thread
 
-         if lines_for_thread == 0 and lines_remain == 0:
 
-             break
 
-         if lines_remain > 0:
 
-             lines_for_thread += 1
 
-             lines_remain -= 1
 
-         t = worker_thread(start_pt, start_pt + lines_for_thread)
 
-         start_pt += lines_for_thread
 
-         t.start()
 
-         threads.append(t)
 
-     for t in threads:
 
-         t.join()
 
-     try:
 
-         outfile.writelines(hosts)
 
-     except IOError as e:
 
-         print(e)
 
-         sys.exit(e.errno)
 
-     sys.exit(0)
 
- if __name__ == '__main__':
 
-     main()
 
 
  |