mirror of
https://github.com/lemeow125/DocManagerBackend.git
synced 2025-01-18 17:13:00 +08:00
Add Ollama JSON schema for categorization
This commit is contained in:
parent
844113d44f
commit
b24646d42d
5 changed files with 88 additions and 40 deletions
|
@ -1,4 +1,3 @@
|
|||
from ollama import ChatResponse
|
||||
import base64
|
||||
import httpx
|
||||
from django.core.management.base import BaseCommand
|
||||
|
@ -18,6 +17,9 @@ from django.core.files import File
|
|||
import logging
|
||||
import time
|
||||
from ollama import Client
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
import json
|
||||
|
||||
|
||||
class PDFHandler(FileSystemEventHandler):
|
||||
|
@ -83,55 +85,52 @@ class PDFHandler(FileSystemEventHandler):
|
|||
# Perform OCR
|
||||
text = pytesseract.image_to_string(img).strip()
|
||||
|
||||
# Get document category
|
||||
# Try to pass image to the Ollama image recognition API first
|
||||
try:
|
||||
class DocumentCategory(BaseModel):
|
||||
category: str = "other"
|
||||
explanation: Optional[str] = None
|
||||
|
||||
client = Client(
|
||||
host=get_secret("OLLAMA_URL"),
|
||||
auth=httpx.BasicAuth(
|
||||
username=get_secret("OLLAMA_USERNAME"), password=get_secret("OLLAMA_PASSWORD")) if get_secret("OLLAMA_USE_AUTH") else None
|
||||
username=get_secret("OLLAMA_USERNAME"), password=get_secret("OLLAMA_PASSWORD")) if get_secret("OLLAMA_USE_AUTH") else None,
|
||||
)
|
||||
|
||||
encoded_image = base64.b64encode(
|
||||
img_buffer.getvalue()).decode()
|
||||
|
||||
attempts = 0
|
||||
while True:
|
||||
if attempts >= 3:
|
||||
raise Exception(
|
||||
"Unable to categorize using Ollama API")
|
||||
attempts += 1
|
||||
possible_categories = set((Document.objects.all().values_list(
|
||||
"document_type", flat=True), "Documented Procedures Manual", "Form", "Special Order"))
|
||||
prompt = f"""
|
||||
Read the text from the image and provide a category. Return as JSON.
|
||||
|
||||
content = f"""
|
||||
Read the text from the image and provide a category.
|
||||
|
||||
Possible categories are: Announcement, Manual, Form
|
||||
|
||||
Respond only with the category. No explanations are necessary.
|
||||
Possible categories are: {possible_categories}
|
||||
"""
|
||||
|
||||
response: ChatResponse = client.chat(
|
||||
response = client.chat(
|
||||
model=get_secret("OLLAMA_MODEL"),
|
||||
messages=[
|
||||
{"role": "user", "content": content,
|
||||
{"role": "user",
|
||||
"content": prompt,
|
||||
"images": [encoded_image]},
|
||||
],
|
||||
format=DocumentCategory.model_json_schema(),
|
||||
options={
|
||||
"temperature": 0
|
||||
},
|
||||
|
||||
)
|
||||
|
||||
document_type = response["message"]["content"].replace(
|
||||
"*", "").replace(".", "")
|
||||
|
||||
# A few safety checks if the model does not follow through with output instructions
|
||||
if len(document_type) > 16:
|
||||
self.logger.warning(
|
||||
f"Ollama API gave incorrect document category: {response['message']['content']}. Retrying...")
|
||||
break
|
||||
DocumentCategory.model_validate_json(
|
||||
response.message.content)
|
||||
result = json.loads(response.message.content)
|
||||
document_type = result.get("category")
|
||||
|
||||
# If that fails, just use regular OCR read the title as a dirty fix/fallback
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error! {e}")
|
||||
self.logger.warning(
|
||||
"Ollama OCR offloading failed. Falling back to default OCR")
|
||||
"Ollama OCR offload failed. Falling back to default OCR")
|
||||
lines = text.split("\n")
|
||||
|
||||
for line in lines:
|
||||
|
@ -158,7 +157,8 @@ class PDFHandler(FileSystemEventHandler):
|
|||
DOCUMENT.file.save(
|
||||
name=filename, content=File(open(file_path, "rb")))
|
||||
self.logger.info(
|
||||
f"Document '{filename}' created successfully with type '{document_type}'."
|
||||
f"Document '{filename}' created successfully with type '{
|
||||
document_type}'."
|
||||
)
|
||||
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
# Generated by Django 5.1.3 on 2024-12-18 07:58
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    # Adds an optional ``questionnaire`` foreign key to ``documentrequest``.

    # The new column lives on document_requests and points at questionnaires,
    # so both apps' initial migrations are listed as prerequisites.
    dependencies = [
        ("document_requests", "0001_initial"),
        ("questionnaires", "0001_initial"),
    ]

    operations = [
        migrations.AddField(
            model_name="documentrequest",
            name="questionnaire",
            # null/blank are both enabled, and SET_NULL clears the reference
            # when the target questionnaire row is deleted.
            field=models.ForeignKey(
                to="questionnaires.questionnaire",
                on_delete=models.SET_NULL,
                null=True,
                blank=True,
            ),
        ),
    ]
|
|
@ -6,12 +6,16 @@ class DocumentRequestUnit(models.Model):
|
|||
document_request = models.ForeignKey(
|
||||
"document_requests.DocumentRequest", on_delete=models.CASCADE
|
||||
)
|
||||
document = models.ForeignKey("documents.Document", on_delete=models.CASCADE)
|
||||
document = models.ForeignKey(
|
||||
"documents.Document", on_delete=models.CASCADE)
|
||||
copies = models.IntegerField(default=1, null=False, blank=False)
|
||||
|
||||
|
||||
class DocumentRequest(models.Model):
|
||||
requester = models.ForeignKey("accounts.CustomUser", on_delete=models.CASCADE)
|
||||
requester = models.ForeignKey(
|
||||
"accounts.CustomUser", on_delete=models.CASCADE)
|
||||
questionnaire = models.ForeignKey(
|
||||
"questionnaires.Questionnaire", on_delete=models.SET_NULL, null=True, blank=True)
|
||||
documents = models.ManyToManyField("document_requests.DocumentRequestUnit")
|
||||
date_requested = models.DateTimeField(default=now, editable=False)
|
||||
college = models.CharField(max_length=64, blank=False, null=False)
|
||||
|
@ -23,11 +27,13 @@ class DocumentRequest(models.Model):
|
|||
("denied", "Denied"),
|
||||
)
|
||||
|
||||
status = models.CharField(max_length=32, choices=STATUS_CHOICES, default="pending")
|
||||
status = models.CharField(
|
||||
max_length=32, choices=STATUS_CHOICES, default="pending")
|
||||
|
||||
TYPE_CHOICES = (
|
||||
("softcopy", "Softcopy"),
|
||||
("hardcopy", "Hardcopy"),
|
||||
)
|
||||
|
||||
type = models.CharField(max_length=16, choices=TYPE_CHOICES, default="softcopy")
|
||||
type = models.CharField(
|
||||
max_length=16, choices=TYPE_CHOICES, default="softcopy")
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from rest_framework import serializers
|
||||
from documents.models import Document
|
||||
from documents.serializers import DocumentSerializer, DocumentFileSerializer
|
||||
from questionnaires.models import Questionnaire
|
||||
from accounts.models import CustomUser
|
||||
from emails.templates import RequestUpdateEmail
|
||||
from .models import DocumentRequest, DocumentRequestUnit
|
||||
|
@ -24,7 +25,8 @@ class DocumentRequestCreationSerializer(serializers.ModelSerializer):
|
|||
documents = DocumentRequestUnitCreationSerializer(many=True, required=True)
|
||||
college = serializers.CharField(max_length=64)
|
||||
purpose = serializers.CharField(max_length=512)
|
||||
type = serializers.ChoiceField(choices=DocumentRequest.TYPE_CHOICES, required=True)
|
||||
type = serializers.ChoiceField(
|
||||
choices=DocumentRequest.TYPE_CHOICES, required=True)
|
||||
|
||||
class Meta:
|
||||
model = DocumentRequest
|
||||
|
@ -79,6 +81,12 @@ class DocumentRequestSerializer(serializers.ModelSerializer):
|
|||
queryset=CustomUser.objects.all(),
|
||||
required=False,
|
||||
)
|
||||
requester = serializers.SlugRelatedField(
|
||||
many=False,
|
||||
slug_field="id",
|
||||
queryset=CustomUser.objects.all(),
|
||||
required=False,
|
||||
)
|
||||
purpose = serializers.CharField(max_length=512)
|
||||
date_requested = serializers.DateTimeField(
|
||||
format="%m-%d-%Y %I:%M %p", read_only=True
|
||||
|
@ -108,10 +116,10 @@ class DocumentRequestSerializer(serializers.ModelSerializer):
|
|||
]
|
||||
|
||||
def get_documents(self, obj):
    """Serialize the document units attached to a document request.

    ``DocumentRequestUnitWithFileSerializer`` is used only when the request
    has a questionnaire linked and its status is ``"approved"``; every other
    request is serialized with ``DocumentRequestUnitSerializer``.

    Args:
        obj: The request instance being serialized. Its ``questionnaire``,
            ``status`` and ``documents`` attributes are read.

    Returns:
        The ``.data`` of the selected serializer applied to
        ``obj.documents`` with ``many=True``.
    """
    # Single decision point: the superseded ``status != "approved"`` branch
    # and its dangling ``else:`` are dropped so only one condition chain
    # remains.
    if obj.questionnaire and obj.status == "approved":
        serializer_class = DocumentRequestUnitWithFileSerializer
    else:
        serializer_class = DocumentRequestUnitSerializer
    return serializer_class(obj.documents, many=True).data
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
annotated-types==0.7.0
|
||||
anyio==4.7.0
|
||||
asgiref==3.8.1
|
||||
attrs==24.2.0
|
||||
black==24.10.0
|
||||
|
@ -22,6 +24,9 @@ drf-spectacular-sidecar==2024.11.1
|
|||
filelock==3.16.1
|
||||
fsspec==2024.10.0
|
||||
gunicorn==23.0.0
|
||||
h11==0.14.0
|
||||
httpcore==1.0.7
|
||||
httpx==0.27.2
|
||||
idna==3.10
|
||||
inflection==0.5.1
|
||||
Jinja2==3.1.4
|
||||
|
@ -32,11 +37,14 @@ mpmath==1.3.0
|
|||
mypy-extensions==1.0.0
|
||||
networkx==3.4.2
|
||||
oauthlib==3.2.2
|
||||
ollama==0.4.4
|
||||
packaging==24.2
|
||||
pathspec==0.12.1
|
||||
pillow==11.0.0
|
||||
platformdirs==4.3.6
|
||||
pycparser==2.22
|
||||
pydantic==2.10.3
|
||||
pydantic_core==2.27.1
|
||||
pyflakes==3.2.0
|
||||
PyJWT==2.10.0
|
||||
PyMuPDF==1.24.14
|
||||
|
@ -49,6 +57,7 @@ requests==2.32.3
|
|||
requests-oauthlib==2.0.0
|
||||
rpds-py==0.21.0
|
||||
setuptools==70.2.0
|
||||
sniffio==1.3.1
|
||||
social-auth-app-django==5.4.2
|
||||
social-auth-core==4.5.4
|
||||
sqlparse==0.5.2
|
||||
|
|
Loading…
Reference in a new issue