| #include <u.h> |
| #include <libc.h> |
| #include <bio.h> |
| #include <regexp.h> |
| #include "hash.h" |
| |
| enum |
| { |
| MAXTAB = 256, |
| MAXBEST = 32 |
| }; |
| |
| typedef struct Table Table; |
| struct Table |
| { |
| char *file; |
| Hash *hash; |
| int nmsg; |
| }; |
| |
| typedef struct Word Word; |
| struct Word |
| { |
| Stringtab *s; /* from hmsg */ |
| int count[MAXTAB]; /* counts from each table */ |
| double p[MAXTAB]; /* probabilities from each table */ |
| double mp; /* max probability */ |
| int mi; /* w.p[w.mi] = w.mp */ |
| }; |
| |
| Table tab[MAXTAB]; |
| int ntab; |
| |
| Word best[MAXBEST]; |
| int mbest; |
| int nbest; |
| |
| int debug; |
| |
| void |
| usage(void) |
| { |
| fprint(2, "usage: bayes [-D] [-m maxword] boxhash ... ~ msghash ...\n"); |
| exits("usage"); |
| } |
| |
| void* |
| emalloc(int n) |
| { |
| void *v; |
| |
| v = mallocz(n, 1); |
| if(v == nil) |
| sysfatal("out of memory"); |
| return v; |
| } |
| |
| void |
| noteword(Word *w) |
| { |
| int i; |
| |
| for(i=nbest-1; i>=0; i--) |
| if(w->mp < best[i].mp) |
| break; |
| i++; |
| |
| if(i >= mbest) |
| return; |
| if(nbest == mbest) |
| nbest--; |
| if(i < nbest) |
| memmove(&best[i+1], &best[i], (nbest-i)*sizeof(best[0])); |
| best[i] = *w; |
| nbest++; |
| } |
| |
| Hash* |
| hread(char *s) |
| { |
| Hash *h; |
| Biobuf *b; |
| |
| if((b = Bopenlock(s, OREAD)) == nil) |
| sysfatal("open %s: %r", s); |
| |
| h = emalloc(sizeof(Hash)); |
| Breadhash(b, h, 1); |
| Bterm(b); |
| return h; |
| } |
| |
| void |
| main(int argc, char **argv) |
| { |
| int i, j, a, mi, oi, tot, keywords; |
| double totp, p, xp[MAXTAB]; |
| Hash *hmsg; |
| Word w; |
| Stringtab *s, *t; |
| Biobuf bout; |
| |
| mbest = 15; |
| keywords = 0; |
| ARGBEGIN{ |
| case 'D': |
| debug = 1; |
| break; |
| case 'k': |
| keywords = 1; |
| break; |
| case 'm': |
| mbest = atoi(EARGF(usage())); |
| if(mbest > MAXBEST) |
| sysfatal("cannot keep more than %d words", MAXBEST); |
| break; |
| default: |
| usage(); |
| }ARGEND |
| |
| for(i=0; i<argc; i++) |
| if(strcmp(argv[i], "~") == 0) |
| break; |
| |
| if(i > MAXTAB) |
| sysfatal("cannot handle more than %d tables", MAXTAB); |
| |
| if(i+1 >= argc) |
| usage(); |
| |
| for(i=0; i<argc; i++){ |
| if(strcmp(argv[i], "~") == 0) |
| break; |
| tab[ntab].file = argv[i]; |
| tab[ntab].hash = hread(argv[i]); |
| s = findstab(tab[ntab].hash, "*nmsg*", 6, 1); |
| if(s == nil || s->count == 0) |
| tab[ntab].nmsg = 1; |
| else |
| tab[ntab].nmsg = s->count; |
| ntab++; |
| } |
| |
| Binit(&bout, 1, OWRITE); |
| |
| oi = ++i; |
| for(a=i; a<argc; a++){ |
| hmsg = hread(argv[a]); |
| nbest = 0; |
| for(s=hmsg->all; s; s=s->link){ |
| w.s = s; |
| tot = 0; |
| totp = 0.0; |
| for(i=0; i<ntab; i++){ |
| t = findstab(tab[i].hash, s->str, s->n, 0); |
| if(t == nil) |
| w.count[i] = 0; |
| else |
| w.count[i] = t->count; |
| tot += w.count[i]; |
| p = w.count[i]/(double)tab[i].nmsg; |
| if(p >= 1.0) |
| p = 1.0; |
| w.p[i] = p; |
| totp += p; |
| } |
| |
| if(tot < 5){ /* word does not appear enough; give to box 0 */ |
| w.p[0] = 0.5; |
| for(i=1; i<ntab; i++) |
| w.p[i] = 0.1; |
| w.mp = 0.5; |
| w.mi = 0; |
| noteword(&w); |
| continue; |
| } |
| |
| w.mp = 0.0; |
| for(i=0; i<ntab; i++){ |
| p = w.p[i]; |
| p /= totp; |
| if(p < 0.01) |
| p = 0.01; |
| else if(p > 0.99) |
| p = 0.99; |
| if(p > w.mp){ |
| w.mp = p; |
| w.mi = i; |
| } |
| w.p[i] = p; |
| } |
| noteword(&w); |
| } |
| |
| totp = 0.0; |
| for(i=0; i<ntab; i++){ |
| p = 1.0; |
| for(j=0; j<nbest; j++) |
| p *= best[j].p[i]; |
| xp[i] = p; |
| totp += p; |
| } |
| for(i=0; i<ntab; i++) |
| xp[i] /= totp; |
| mi = 0; |
| for(i=1; i<ntab; i++) |
| if(xp[i] > xp[mi]) |
| mi = i; |
| if(oi != argc-1) |
| Bprint(&bout, "%s: ", argv[a]); |
| Bprint(&bout, "%s %f", tab[mi].file, xp[mi]); |
| if(keywords){ |
| for(i=0; i<nbest; i++){ |
| Bprint(&bout, " "); |
| Bwrite(&bout, best[i].s->str, best[i].s->n); |
| Bprint(&bout, " %f", best[i].p[mi]); |
| } |
| } |
| freehash(hmsg); |
| Bprint(&bout, "\n"); |
| if(debug){ |
| for(i=0; i<nbest; i++){ |
| Bwrite(&bout, best[i].s->str, best[i].s->n); |
| Bprint(&bout, " %f", best[i].p[mi]); |
| if(best[i].p[mi] < best[i].mp) |
| Bprint(&bout, " (%f %s)", best[i].mp, tab[best[i].mi].file); |
| Bprint(&bout, "\n"); |
| } |
| } |
| } |
| Bterm(&bout); |
| } |