Browse Source

use a watcher thread to monitor the progress

lennylxx 9 years ago
parent
commit
c4c006734e
2 changed files with 48 additions and 16 deletions
  1. 1 2
      conv.py
  2. 47 14
      update_hosts.py

+ 1 - 2
conv.py

@@ -42,8 +42,7 @@ def code2num(code):
 
 def main():
     if len(sys.argv) != 3:
-        print 'usage:\n\t./%s -i iata\n\t./%s -s sn'\
-            % (sys.argv[0], sys.argv[0])
+        print 'usage:\tconv -i iata\n\tconv -s sn'
         sys.exit(1)
 
     input = sys.argv[2]

+ 47 - 14
update_hosts.py

@@ -9,6 +9,7 @@ import getopt
 import threading
 import subprocess
 import shlex
+import time
 
 blackhole = (
 '10::2222',
@@ -41,6 +42,8 @@ config = {
 }
 
 hosts = []
+done_num = 0
+thread_lock = threading.Lock()
 
 class worker_thread(threading.Thread):
     def __init__(self, start_pt, end_pt):
@@ -49,10 +52,13 @@ class worker_thread(threading.Thread):
         self.end_pt = end_pt
     
     def run(self):
-        global hosts
+        global hosts, done_num
         for i in range(self.start_pt, self.end_pt):
             line = hosts[i].strip()
             
+            with thread_lock:
+                done_num += 1
+
             if line == "" or line[0:2] == '##':
                 hosts[i] = line + '\r\n'
                 continue
@@ -71,7 +77,7 @@ class worker_thread(threading.Thread):
                 if ret in blackhole or not ret:
                     ret = query_domain(domain, True)
 
-                if ret and ret[:2] != ';;':
+                if ret:
                     flag = True
                     arr[0] = ret
 
@@ -84,6 +90,31 @@ class worker_thread(threading.Thread):
             hosts[i] = ' '.join(arr)
             hosts[i] += '\r\n'
 
+class watcher_thread(threading.Thread):
+    def run(self):
+        global hosts, done_num
+        total_num = len(hosts)
+
+        wn = int(config['threadnum'])
+        if wn > total_num:
+            wn = total_num
+        print "There are %d threads working..." % wn
+
+        while True:
+            with thread_lock:
+                dn = done_num
+
+            outbuf = "Total: %d lines, Done: %d lines, Ratio: %d %%.\r"\
+                   % (total_num, dn, float(dn)/total_num*100)
+            print outbuf,
+            sys.stdout.flush()
+
+            if done_num == total_num:
+                print outbuf
+                break
+            time.sleep(1)
+
+
 def query_domain(domain, tcp):
     cmd = "dig +short +time=2 -6 %s @'%s' '%s'"\
         % (config['querytype'], config['dns'], domain)
@@ -144,18 +175,17 @@ Options:
 ''')
 
 def get_config():
-    shortopts = 'hs:i:o:t:n'
+    shortopts = 'hs:i:o:t:n:'
     longopts = ['help']
 
     try:
         optlist, args = getopt.getopt(sys.argv[1:], shortopts, longopts)   
     except getopt.GetoptError as e:
-        print e
+        print e, '\n'
         print_help()
         sys.exit(2)
     
     global config
-    
     for key, value in optlist:
         if key == '-i':
             config['infile'] = value
@@ -176,10 +206,11 @@ def get_config():
 def main():
     get_config()
 
-    global config
+    global config, hosts
 
     try:
-        infile = open(config['infile'], 'r')
+        with open(config['infile'], 'r') as infile:
+            hosts = infile.readlines()
     except IOError as e:
         print e
         sys.exit(e.errno)
@@ -192,20 +223,22 @@ def main():
     except IOError as e:
         print e
         sys.exit(e.errno)
-    
-    global hosts
-    hosts = infile.readlines()
 
     threads = []
-    thread_num = config['threadnum']
+
+    t = watcher_thread()
+    t.start()
+    threads.append(t)
+
+    worker_num = config['threadnum']
     lines_num = len(hosts)
 
-    lines_per_thread = lines_num / thread_num
-    lines_remain = lines_num % thread_num
+    lines_per_thread = lines_num / worker_num
+    lines_remain = lines_num % worker_num
 
     start_pt = 0
 
-    for i in range(thread_num):
+    for i in range(worker_num):
         lines_for_thread = lines_per_thread
 
         if lines_for_thread == 0 and lines_remain == 0: