mirror of
https://github.com/lemeow125/DocManagerBackend.git
synced 2025-01-18 17:13:00 +08:00
Add Ollama JSON schema for categorization
This commit is contained in:
parent
844113d44f
commit
b24646d42d
5 changed files with 88 additions and 40 deletions
|
@ -1,4 +1,3 @@
|
||||||
from ollama import ChatResponse
|
|
||||||
import base64
|
import base64
|
||||||
import httpx
|
import httpx
|
||||||
from django.core.management.base import BaseCommand
|
from django.core.management.base import BaseCommand
|
||||||
|
@ -18,6 +17,9 @@ from django.core.files import File
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from ollama import Client
|
from ollama import Client
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
class PDFHandler(FileSystemEventHandler):
|
class PDFHandler(FileSystemEventHandler):
|
||||||
|
@ -83,55 +85,52 @@ class PDFHandler(FileSystemEventHandler):
|
||||||
# Perform OCR
|
# Perform OCR
|
||||||
text = pytesseract.image_to_string(img).strip()
|
text = pytesseract.image_to_string(img).strip()
|
||||||
|
|
||||||
# Get document category
|
|
||||||
# Try to pass image to the Ollama image recognition API first
|
# Try to pass image to the Ollama image recognition API first
|
||||||
try:
|
try:
|
||||||
|
class DocumentCategory(BaseModel):
|
||||||
|
category: str = "other"
|
||||||
|
explanation: Optional[str] = None
|
||||||
|
|
||||||
client = Client(
|
client = Client(
|
||||||
host=get_secret("OLLAMA_URL"),
|
host=get_secret("OLLAMA_URL"),
|
||||||
auth=httpx.BasicAuth(
|
auth=httpx.BasicAuth(
|
||||||
username=get_secret("OLLAMA_USERNAME"), password=get_secret("OLLAMA_PASSWORD")) if get_secret("OLLAMA_USE_AUTH") else None
|
username=get_secret("OLLAMA_USERNAME"), password=get_secret("OLLAMA_PASSWORD")) if get_secret("OLLAMA_USE_AUTH") else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
encoded_image = base64.b64encode(
|
encoded_image = base64.b64encode(
|
||||||
img_buffer.getvalue()).decode()
|
img_buffer.getvalue()).decode()
|
||||||
|
|
||||||
attempts = 0
|
possible_categories = set((Document.objects.all().values_list(
|
||||||
while True:
|
"document_type", flat=True), "Documented Procedures Manual", "Form", "Special Order"))
|
||||||
if attempts >= 3:
|
prompt = f"""
|
||||||
raise Exception(
|
Read the text from the image and provide a category. Return as JSON.
|
||||||
"Unable to categorize using Ollama API")
|
|
||||||
attempts += 1
|
|
||||||
|
|
||||||
content = f"""
|
Possible categories are: {possible_categories}
|
||||||
Read the text from the image and provide a category.
|
|
||||||
|
|
||||||
Possible categories are: Announcement, Manual, Form
|
|
||||||
|
|
||||||
Respond only with the category. No explanations are necessary.
|
|
||||||
"""
|
"""
|
||||||
|
response = client.chat(
|
||||||
response: ChatResponse = client.chat(
|
|
||||||
model=get_secret("OLLAMA_MODEL"),
|
model=get_secret("OLLAMA_MODEL"),
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "user", "content": content,
|
{"role": "user",
|
||||||
|
"content": prompt,
|
||||||
"images": [encoded_image]},
|
"images": [encoded_image]},
|
||||||
],
|
],
|
||||||
|
format=DocumentCategory.model_json_schema(),
|
||||||
|
options={
|
||||||
|
"temperature": 0
|
||||||
|
},
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
document_type = response["message"]["content"].replace(
|
DocumentCategory.model_validate_json(
|
||||||
"*", "").replace(".", "")
|
response.message.content)
|
||||||
|
result = json.loads(response.message.content)
|
||||||
# A few safety checks if the model does not follow through with output instructions
|
document_type = result.get("category")
|
||||||
if len(document_type) > 16:
|
|
||||||
self.logger.warning(
|
|
||||||
f"Ollama API gave incorrect document category: {response['message']['content']}. Retrying...")
|
|
||||||
break
|
|
||||||
|
|
||||||
# If that fails, just use regular OCR read the title as a dirty fix/fallback
|
# If that fails, just use regular OCR read the title as a dirty fix/fallback
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.warning(f"Error! {e}")
|
self.logger.warning(f"Error! {e}")
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
"Ollama OCR offloading failed. Falling back to default OCR")
|
"Ollama OCR offload failed. Falling back to default OCR")
|
||||||
lines = text.split("\n")
|
lines = text.split("\n")
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
|
@ -158,7 +157,8 @@ class PDFHandler(FileSystemEventHandler):
|
||||||
DOCUMENT.file.save(
|
DOCUMENT.file.save(
|
||||||
name=filename, content=File(open(file_path, "rb")))
|
name=filename, content=File(open(file_path, "rb")))
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"Document '{filename}' created successfully with type '{document_type}'."
|
f"Document '{filename}' created successfully with type '{
|
||||||
|
document_type}'."
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
# Generated by Django 5.1.3 on 2024-12-18 07:58
|
||||||
|
|
||||||
|
import django.db.models.deletion
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
("document_requests", "0001_initial"),
|
||||||
|
("questionnaires", "0001_initial"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="documentrequest",
|
||||||
|
name="questionnaire",
|
||||||
|
field=models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.SET_NULL,
|
||||||
|
to="questionnaires.questionnaire",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
|
@ -6,12 +6,16 @@ class DocumentRequestUnit(models.Model):
|
||||||
document_request = models.ForeignKey(
|
document_request = models.ForeignKey(
|
||||||
"document_requests.DocumentRequest", on_delete=models.CASCADE
|
"document_requests.DocumentRequest", on_delete=models.CASCADE
|
||||||
)
|
)
|
||||||
document = models.ForeignKey("documents.Document", on_delete=models.CASCADE)
|
document = models.ForeignKey(
|
||||||
|
"documents.Document", on_delete=models.CASCADE)
|
||||||
copies = models.IntegerField(default=1, null=False, blank=False)
|
copies = models.IntegerField(default=1, null=False, blank=False)
|
||||||
|
|
||||||
|
|
||||||
class DocumentRequest(models.Model):
|
class DocumentRequest(models.Model):
|
||||||
requester = models.ForeignKey("accounts.CustomUser", on_delete=models.CASCADE)
|
requester = models.ForeignKey(
|
||||||
|
"accounts.CustomUser", on_delete=models.CASCADE)
|
||||||
|
questionnaire = models.ForeignKey(
|
||||||
|
"questionnaires.Questionnaire", on_delete=models.SET_NULL, null=True, blank=True)
|
||||||
documents = models.ManyToManyField("document_requests.DocumentRequestUnit")
|
documents = models.ManyToManyField("document_requests.DocumentRequestUnit")
|
||||||
date_requested = models.DateTimeField(default=now, editable=False)
|
date_requested = models.DateTimeField(default=now, editable=False)
|
||||||
college = models.CharField(max_length=64, blank=False, null=False)
|
college = models.CharField(max_length=64, blank=False, null=False)
|
||||||
|
@ -23,11 +27,13 @@ class DocumentRequest(models.Model):
|
||||||
("denied", "Denied"),
|
("denied", "Denied"),
|
||||||
)
|
)
|
||||||
|
|
||||||
status = models.CharField(max_length=32, choices=STATUS_CHOICES, default="pending")
|
status = models.CharField(
|
||||||
|
max_length=32, choices=STATUS_CHOICES, default="pending")
|
||||||
|
|
||||||
TYPE_CHOICES = (
|
TYPE_CHOICES = (
|
||||||
("softcopy", "Softcopy"),
|
("softcopy", "Softcopy"),
|
||||||
("hardcopy", "Hardcopy"),
|
("hardcopy", "Hardcopy"),
|
||||||
)
|
)
|
||||||
|
|
||||||
type = models.CharField(max_length=16, choices=TYPE_CHOICES, default="softcopy")
|
type = models.CharField(
|
||||||
|
max_length=16, choices=TYPE_CHOICES, default="softcopy")
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from documents.serializers import DocumentSerializer, DocumentFileSerializer
|
from documents.serializers import DocumentSerializer, DocumentFileSerializer
|
||||||
|
from questionnaires.models import Questionnaire
|
||||||
from accounts.models import CustomUser
|
from accounts.models import CustomUser
|
||||||
from emails.templates import RequestUpdateEmail
|
from emails.templates import RequestUpdateEmail
|
||||||
from .models import DocumentRequest, DocumentRequestUnit
|
from .models import DocumentRequest, DocumentRequestUnit
|
||||||
|
@ -24,7 +25,8 @@ class DocumentRequestCreationSerializer(serializers.ModelSerializer):
|
||||||
documents = DocumentRequestUnitCreationSerializer(many=True, required=True)
|
documents = DocumentRequestUnitCreationSerializer(many=True, required=True)
|
||||||
college = serializers.CharField(max_length=64)
|
college = serializers.CharField(max_length=64)
|
||||||
purpose = serializers.CharField(max_length=512)
|
purpose = serializers.CharField(max_length=512)
|
||||||
type = serializers.ChoiceField(choices=DocumentRequest.TYPE_CHOICES, required=True)
|
type = serializers.ChoiceField(
|
||||||
|
choices=DocumentRequest.TYPE_CHOICES, required=True)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
model = DocumentRequest
|
model = DocumentRequest
|
||||||
|
@ -79,6 +81,12 @@ class DocumentRequestSerializer(serializers.ModelSerializer):
|
||||||
queryset=CustomUser.objects.all(),
|
queryset=CustomUser.objects.all(),
|
||||||
required=False,
|
required=False,
|
||||||
)
|
)
|
||||||
|
requester = serializers.SlugRelatedField(
|
||||||
|
many=False,
|
||||||
|
slug_field="id",
|
||||||
|
queryset=CustomUser.objects.all(),
|
||||||
|
required=False,
|
||||||
|
)
|
||||||
purpose = serializers.CharField(max_length=512)
|
purpose = serializers.CharField(max_length=512)
|
||||||
date_requested = serializers.DateTimeField(
|
date_requested = serializers.DateTimeField(
|
||||||
format="%m-%d-%Y %I:%M %p", read_only=True
|
format="%m-%d-%Y %I:%M %p", read_only=True
|
||||||
|
@ -108,10 +116,10 @@ class DocumentRequestSerializer(serializers.ModelSerializer):
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_documents(self, obj):
|
def get_documents(self, obj):
|
||||||
if obj.status != "approved":
|
if obj.questionnaire and obj.status == "approved":
|
||||||
serializer_class = DocumentRequestUnitSerializer
|
|
||||||
else:
|
|
||||||
serializer_class = DocumentRequestUnitWithFileSerializer
|
serializer_class = DocumentRequestUnitWithFileSerializer
|
||||||
|
else:
|
||||||
|
serializer_class = DocumentRequestUnitSerializer
|
||||||
return serializer_class(obj.documents, many=True).data
|
return serializer_class(obj.documents, many=True).data
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
annotated-types==0.7.0
|
||||||
|
anyio==4.7.0
|
||||||
asgiref==3.8.1
|
asgiref==3.8.1
|
||||||
attrs==24.2.0
|
attrs==24.2.0
|
||||||
black==24.10.0
|
black==24.10.0
|
||||||
|
@ -22,6 +24,9 @@ drf-spectacular-sidecar==2024.11.1
|
||||||
filelock==3.16.1
|
filelock==3.16.1
|
||||||
fsspec==2024.10.0
|
fsspec==2024.10.0
|
||||||
gunicorn==23.0.0
|
gunicorn==23.0.0
|
||||||
|
h11==0.14.0
|
||||||
|
httpcore==1.0.7
|
||||||
|
httpx==0.27.2
|
||||||
idna==3.10
|
idna==3.10
|
||||||
inflection==0.5.1
|
inflection==0.5.1
|
||||||
Jinja2==3.1.4
|
Jinja2==3.1.4
|
||||||
|
@ -32,11 +37,14 @@ mpmath==1.3.0
|
||||||
mypy-extensions==1.0.0
|
mypy-extensions==1.0.0
|
||||||
networkx==3.4.2
|
networkx==3.4.2
|
||||||
oauthlib==3.2.2
|
oauthlib==3.2.2
|
||||||
|
ollama==0.4.4
|
||||||
packaging==24.2
|
packaging==24.2
|
||||||
pathspec==0.12.1
|
pathspec==0.12.1
|
||||||
pillow==11.0.0
|
pillow==11.0.0
|
||||||
platformdirs==4.3.6
|
platformdirs==4.3.6
|
||||||
pycparser==2.22
|
pycparser==2.22
|
||||||
|
pydantic==2.10.3
|
||||||
|
pydantic_core==2.27.1
|
||||||
pyflakes==3.2.0
|
pyflakes==3.2.0
|
||||||
PyJWT==2.10.0
|
PyJWT==2.10.0
|
||||||
PyMuPDF==1.24.14
|
PyMuPDF==1.24.14
|
||||||
|
@ -49,6 +57,7 @@ requests==2.32.3
|
||||||
requests-oauthlib==2.0.0
|
requests-oauthlib==2.0.0
|
||||||
rpds-py==0.21.0
|
rpds-py==0.21.0
|
||||||
setuptools==70.2.0
|
setuptools==70.2.0
|
||||||
|
sniffio==1.3.1
|
||||||
social-auth-app-django==5.4.2
|
social-auth-app-django==5.4.2
|
||||||
social-auth-core==4.5.4
|
social-auth-core==4.5.4
|
||||||
sqlparse==0.5.2
|
sqlparse==0.5.2
|
||||||
|
|
Loading…
Reference in a new issue