mirror of
https://github.com/lemeow125/DRF_Template.git
synced 2025-01-18 10:23:01 +08:00
Clean up docker-compose and run Black formatter over entire codebase
This commit is contained in:
parent
6c232b3e89
commit
069aba80b1
60 changed files with 1946 additions and 1485 deletions
1
Pipfile
1
Pipfile
|
@ -35,6 +35,7 @@ gunicorn = "*"
|
|||
django-silk = "*"
|
||||
django-redis = "*"
|
||||
granian = "*"
|
||||
black = "*"
|
||||
|
||||
[dev-packages]
|
||||
|
||||
|
|
1492
Pipfile.lock
generated
1492
Pipfile.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -6,11 +6,13 @@ from .models import CustomUser
|
|||
|
||||
class CustomUserAdmin(UserAdmin):
|
||||
model = CustomUser
|
||||
list_display = ('id', 'is_active', 'user_group',) + UserAdmin.list_display
|
||||
list_display = (
|
||||
"id",
|
||||
"is_active",
|
||||
"user_group",
|
||||
) + UserAdmin.list_display
|
||||
# Editable fields per instance
|
||||
fieldsets = UserAdmin.fieldsets + (
|
||||
(None, {'fields': ('avatar',)}),
|
||||
)
|
||||
fieldsets = UserAdmin.fieldsets + ((None, {"fields": ("avatar",)}),)
|
||||
|
||||
|
||||
admin.site.register(CustomUser, CustomUserAdmin)
|
||||
|
|
|
@ -2,8 +2,8 @@ from django.apps import AppConfig
|
|||
|
||||
|
||||
class AccountsConfig(AppConfig):
|
||||
default_auto_field = 'django.db.models.BigAutoField'
|
||||
name = 'accounts'
|
||||
default_auto_field = "django.db.models.BigAutoField"
|
||||
name = "accounts"
|
||||
|
||||
def ready(self):
|
||||
import accounts.signals
|
||||
|
|
|
@ -13,38 +13,145 @@ class Migration(migrations.Migration):
|
|||
initial = True
|
||||
|
||||
dependencies = [
|
||||
('auth', '0012_alter_user_first_name_max_length'),
|
||||
('user_groups', '0001_initial'),
|
||||
("auth", "0012_alter_user_first_name_max_length"),
|
||||
("user_groups", "0001_initial"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='CustomUser',
|
||||
name="CustomUser",
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('password', models.CharField(max_length=128, verbose_name='password')),
|
||||
('last_login', models.DateTimeField(blank=True, null=True, verbose_name='last login')),
|
||||
('is_superuser', models.BooleanField(default=False, help_text='Designates that this user has all permissions without explicitly assigning them.', verbose_name='superuser status')),
|
||||
('username', models.CharField(error_messages={'unique': 'A user with that username already exists.'}, help_text='Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.', max_length=150, unique=True, validators=[django.contrib.auth.validators.UnicodeUsernameValidator()], verbose_name='username')),
|
||||
('first_name', models.CharField(blank=True, max_length=150, verbose_name='first name')),
|
||||
('last_name', models.CharField(blank=True, max_length=150, verbose_name='last name')),
|
||||
('email', models.EmailField(blank=True, max_length=254, verbose_name='email address')),
|
||||
('is_staff', models.BooleanField(default=False, help_text='Designates whether the user can log into this admin site.', verbose_name='staff status')),
|
||||
('is_active', models.BooleanField(default=True, help_text='Designates whether this user should be treated as active. Unselect this instead of deleting accounts.', verbose_name='active')),
|
||||
('date_joined', models.DateTimeField(default=django.utils.timezone.now, verbose_name='date joined')),
|
||||
('avatar', django_resized.forms.ResizedImageField(crop=None, force_format='WEBP', keep_meta=True, null=True, quality=100, scale=None, size=[1920, 1080], upload_to='avatars/')),
|
||||
('onboarding', models.BooleanField(default=True)),
|
||||
('groups', models.ManyToManyField(blank=True, help_text='The groups this user belongs to. A user will get all permissions granted to each of their groups.', related_name='user_set', related_query_name='user', to='auth.group', verbose_name='groups')),
|
||||
('user_group', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, to='user_groups.usergroup')),
|
||||
('user_permissions', models.ManyToManyField(blank=True, help_text='Specific permissions for this user.', related_name='user_set', related_query_name='user', to='auth.permission', verbose_name='user permissions')),
|
||||
(
|
||||
"id",
|
||||
models.BigAutoField(
|
||||
auto_created=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
verbose_name="ID",
|
||||
),
|
||||
),
|
||||
("password", models.CharField(max_length=128, verbose_name="password")),
|
||||
(
|
||||
"last_login",
|
||||
models.DateTimeField(
|
||||
blank=True, null=True, verbose_name="last login"
|
||||
),
|
||||
),
|
||||
(
|
||||
"is_superuser",
|
||||
models.BooleanField(
|
||||
default=False,
|
||||
help_text="Designates that this user has all permissions without explicitly assigning them.",
|
||||
verbose_name="superuser status",
|
||||
),
|
||||
),
|
||||
(
|
||||
"username",
|
||||
models.CharField(
|
||||
error_messages={
|
||||
"unique": "A user with that username already exists."
|
||||
},
|
||||
help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.",
|
||||
max_length=150,
|
||||
unique=True,
|
||||
validators=[
|
||||
django.contrib.auth.validators.UnicodeUsernameValidator()
|
||||
],
|
||||
verbose_name="username",
|
||||
),
|
||||
),
|
||||
(
|
||||
"first_name",
|
||||
models.CharField(
|
||||
blank=True, max_length=150, verbose_name="first name"
|
||||
),
|
||||
),
|
||||
(
|
||||
"last_name",
|
||||
models.CharField(
|
||||
blank=True, max_length=150, verbose_name="last name"
|
||||
),
|
||||
),
|
||||
(
|
||||
"email",
|
||||
models.EmailField(
|
||||
blank=True, max_length=254, verbose_name="email address"
|
||||
),
|
||||
),
|
||||
(
|
||||
"is_staff",
|
||||
models.BooleanField(
|
||||
default=False,
|
||||
help_text="Designates whether the user can log into this admin site.",
|
||||
verbose_name="staff status",
|
||||
),
|
||||
),
|
||||
(
|
||||
"is_active",
|
||||
models.BooleanField(
|
||||
default=True,
|
||||
help_text="Designates whether this user should be treated as active. Unselect this instead of deleting accounts.",
|
||||
verbose_name="active",
|
||||
),
|
||||
),
|
||||
(
|
||||
"date_joined",
|
||||
models.DateTimeField(
|
||||
default=django.utils.timezone.now, verbose_name="date joined"
|
||||
),
|
||||
),
|
||||
(
|
||||
"avatar",
|
||||
django_resized.forms.ResizedImageField(
|
||||
crop=None,
|
||||
force_format="WEBP",
|
||||
keep_meta=True,
|
||||
null=True,
|
||||
quality=100,
|
||||
scale=None,
|
||||
size=[1920, 1080],
|
||||
upload_to="avatars/",
|
||||
),
|
||||
),
|
||||
("onboarding", models.BooleanField(default=True)),
|
||||
(
|
||||
"groups",
|
||||
models.ManyToManyField(
|
||||
blank=True,
|
||||
help_text="The groups this user belongs to. A user will get all permissions granted to each of their groups.",
|
||||
related_name="user_set",
|
||||
related_query_name="user",
|
||||
to="auth.group",
|
||||
verbose_name="groups",
|
||||
),
|
||||
),
|
||||
(
|
||||
"user_group",
|
||||
models.ForeignKey(
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
to="user_groups.usergroup",
|
||||
),
|
||||
),
|
||||
(
|
||||
"user_permissions",
|
||||
models.ManyToManyField(
|
||||
blank=True,
|
||||
help_text="Specific permissions for this user.",
|
||||
related_name="user_set",
|
||||
related_query_name="user",
|
||||
to="auth.permission",
|
||||
verbose_name="user permissions",
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
'verbose_name': 'user',
|
||||
'verbose_name_plural': 'users',
|
||||
'abstract': False,
|
||||
"verbose_name": "user",
|
||||
"verbose_name_plural": "users",
|
||||
"abstract": False,
|
||||
},
|
||||
managers=[
|
||||
('objects', django.contrib.auth.models.UserManager()),
|
||||
("objects", django.contrib.auth.models.UserManager()),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
|
|
@ -15,14 +15,16 @@ class CustomUser(AbstractUser):
|
|||
# is_admin inherited from base user class
|
||||
|
||||
avatar = ResizedImageField(
|
||||
null=True, force_format="WEBP", quality=100, upload_to='avatars/')
|
||||
null=True, force_format="WEBP", quality=100, upload_to="avatars/"
|
||||
)
|
||||
|
||||
# Used for onboarding processes
|
||||
# Set this to False later on once the user makes actions
|
||||
onboarding = models.BooleanField(default=True)
|
||||
|
||||
user_group = models.ForeignKey(
|
||||
'user_groups.UserGroup', on_delete=models.SET_NULL, null=True)
|
||||
"user_groups.UserGroup", on_delete=models.SET_NULL, null=True
|
||||
)
|
||||
|
||||
@property
|
||||
def group_member(self):
|
||||
|
@ -57,4 +59,4 @@ class CustomUser(AbstractUser):
|
|||
|
||||
@property
|
||||
def admin_url(self):
|
||||
return reverse('admin:users_customuser_change', args=(self.pk,))
|
||||
return reverse("admin:users_customuser_change", args=(self.pk,))
|
||||
|
|
|
@ -8,13 +8,14 @@ from django.core.cache import cache
|
|||
from django.core import exceptions as django_exceptions
|
||||
from rest_framework.settings import api_settings
|
||||
from django.contrib.auth.password_validation import validate_password
|
||||
|
||||
# There can be multiple subject instances with the same name, only differing in course, year level, and semester. We filter them here
|
||||
|
||||
|
||||
class SimpleCustomUserSerializer(ModelSerializer):
|
||||
class Meta(BaseUserSerializer.Meta):
|
||||
model = CustomUser
|
||||
fields = ('id', 'username', 'email', 'full_name')
|
||||
fields = ("id", "username", "email", "full_name")
|
||||
|
||||
|
||||
class CustomUserSerializer(BaseUserSerializer):
|
||||
|
@ -22,19 +23,36 @@ class CustomUserSerializer(BaseUserSerializer):
|
|||
|
||||
class Meta(BaseUserSerializer.Meta):
|
||||
model = CustomUser
|
||||
fields = ('id', 'username', 'email', 'avatar', 'first_name',
|
||||
'last_name', 'user_group', 'group_member', 'group_owner')
|
||||
read_only_fields = ('id', 'username', 'email', 'user_group',
|
||||
'group_member', 'group_owner')
|
||||
fields = (
|
||||
"id",
|
||||
"username",
|
||||
"email",
|
||||
"avatar",
|
||||
"first_name",
|
||||
"is_new",
|
||||
"last_name",
|
||||
"user_group",
|
||||
"group_member",
|
||||
"group_owner",
|
||||
)
|
||||
read_only_fields = (
|
||||
"id",
|
||||
"username",
|
||||
"email",
|
||||
"user_group",
|
||||
"group_member",
|
||||
"group_owner",
|
||||
)
|
||||
|
||||
def to_representation(self, instance):
|
||||
representation = super().to_representation(instance)
|
||||
representation['user_group'] = SimpleUserGroupSerializer(
|
||||
instance.user_group, many=False).data
|
||||
representation["user_group"] = SimpleUserGroupSerializer(
|
||||
instance.user_group, many=False
|
||||
).data
|
||||
return representation
|
||||
|
||||
def update(self, instance, validated_data):
|
||||
cache.delete(f'user:{instance.id}')
|
||||
cache.delete(f"user:{instance.id}")
|
||||
return super().update(instance, validated_data)
|
||||
|
||||
|
||||
|
@ -42,16 +60,18 @@ class UserRegistrationSerializer(serializers.ModelSerializer):
|
|||
email = serializers.EmailField(required=True)
|
||||
username = serializers.CharField(required=True)
|
||||
password = serializers.CharField(
|
||||
write_only=True, style={'input_type': 'password', 'placeholder': 'Password'})
|
||||
write_only=True, style={"input_type": "password", "placeholder": "Password"}
|
||||
)
|
||||
first_name = serializers.CharField(
|
||||
required=True, allow_blank=False, allow_null=False)
|
||||
required=True, allow_blank=False, allow_null=False
|
||||
)
|
||||
last_name = serializers.CharField(
|
||||
required=True, allow_blank=False, allow_null=False)
|
||||
required=True, allow_blank=False, allow_null=False
|
||||
)
|
||||
|
||||
class Meta:
|
||||
model = CustomUser
|
||||
fields = ['email', 'username', 'password',
|
||||
'first_name', 'last_name']
|
||||
fields = ["email", "username", "password", "first_name", "last_name"]
|
||||
|
||||
def validate(self, attrs):
|
||||
user_attrs = attrs.copy()
|
||||
|
@ -69,14 +89,15 @@ class UserRegistrationSerializer(serializers.ModelSerializer):
|
|||
raise serializers.ValidationError({"password": errors})
|
||||
if self.Meta.model.objects.filter(username=attrs.get("username")).exists():
|
||||
raise serializers.ValidationError(
|
||||
"A user with that username already exists.")
|
||||
"A user with that username already exists."
|
||||
)
|
||||
return super().validate(attrs)
|
||||
|
||||
def create(self, validated_data):
|
||||
user = self.Meta.model(**validated_data)
|
||||
user.username = validated_data['username']
|
||||
user.username = validated_data["username"]
|
||||
user.is_active = False
|
||||
user.set_password(validated_data['password'])
|
||||
user.set_password(validated_data["password"])
|
||||
user.save()
|
||||
|
||||
return user
|
||||
|
|
|
@ -12,38 +12,37 @@ import json
|
|||
@receiver(post_migrate)
|
||||
def create_users(sender, **kwargs):
|
||||
if sender.name == "accounts":
|
||||
with open(os.path.join(ROOT_DIR, 'seed_data.json'), "r") as f:
|
||||
with open(os.path.join(ROOT_DIR, "seed_data.json"), "r") as f:
|
||||
seed_data = json.loads(f.read())
|
||||
for user in seed_data['users']:
|
||||
USER = CustomUser.objects.filter(
|
||||
email=user['email']).first()
|
||||
for user in seed_data["users"]:
|
||||
USER = CustomUser.objects.filter(email=user["email"]).first()
|
||||
if not USER:
|
||||
if user['password'] == 'USE_REGULAR':
|
||||
password = get_secret('SEED_DATA_PASSWORD')
|
||||
elif user['password'] == 'USE_ADMIN':
|
||||
password = get_secret('SEED_DATA_ADMIN_PASSWORD')
|
||||
if user["password"] == "USE_REGULAR":
|
||||
password = get_secret("SEED_DATA_PASSWORD")
|
||||
elif user["password"] == "USE_ADMIN":
|
||||
password = get_secret("SEED_DATA_ADMIN_PASSWORD")
|
||||
else:
|
||||
password = user['password']
|
||||
if (user['is_superuser'] == True):
|
||||
password = user["password"]
|
||||
if user["is_superuser"] == True:
|
||||
# Admin users are created regardless of SEED_DATA value
|
||||
USER = CustomUser.objects.create_superuser(
|
||||
username=user['username'],
|
||||
email=user['email'],
|
||||
username=user["username"],
|
||||
email=user["email"],
|
||||
password=password,
|
||||
)
|
||||
print('Created Superuser:', user['email'])
|
||||
print("Created Superuser:", user["email"])
|
||||
else:
|
||||
# Only create non-admin users if SEED_DATA=True
|
||||
if SEED_DATA:
|
||||
USER = CustomUser.objects.create_user(
|
||||
username=user['email'],
|
||||
email=user['email'],
|
||||
username=user["email"],
|
||||
email=user["email"],
|
||||
password=password,
|
||||
)
|
||||
print('Created User:', user['email'])
|
||||
print("Created User:", user["email"])
|
||||
|
||||
USER.first_name = user['first_name']
|
||||
USER.last_name = user['last_name']
|
||||
USER.first_name = user["first_name"]
|
||||
USER.last_name = user["last_name"]
|
||||
USER.is_active = True
|
||||
USER.save()
|
||||
|
||||
|
@ -51,53 +50,57 @@ def create_users(sender, **kwargs):
|
|||
@receiver(post_migrate)
|
||||
def create_celery_beat_schedules(sender, **kwargs):
|
||||
if sender.name == "django_celery_beat":
|
||||
with open(os.path.join(ROOT_DIR, 'seed_data.json'), "r") as f:
|
||||
with open(os.path.join(ROOT_DIR, "seed_data.json"), "r") as f:
|
||||
seed_data = json.loads(f.read())
|
||||
# Creating Schedules
|
||||
for schedule in seed_data['schedules']:
|
||||
if schedule['type'] == 'crontab':
|
||||
for schedule in seed_data["schedules"]:
|
||||
if schedule["type"] == "crontab":
|
||||
# Check if Schedule already exists
|
||||
SCHEDULE = CrontabSchedule.objects.filter(minute=schedule['minute'],
|
||||
hour=schedule['hour'],
|
||||
day_of_week=schedule['day_of_week'],
|
||||
day_of_month=schedule['day_of_month'],
|
||||
month_of_year=schedule['month_of_year'],
|
||||
timezone=schedule['timezone']
|
||||
).first()
|
||||
SCHEDULE = CrontabSchedule.objects.filter(
|
||||
minute=schedule["minute"],
|
||||
hour=schedule["hour"],
|
||||
day_of_week=schedule["day_of_week"],
|
||||
day_of_month=schedule["day_of_month"],
|
||||
month_of_year=schedule["month_of_year"],
|
||||
timezone=schedule["timezone"],
|
||||
).first()
|
||||
# If it does not exist, create a new Schedule
|
||||
if not SCHEDULE:
|
||||
SCHEDULE = CrontabSchedule.objects.create(
|
||||
minute=schedule['minute'],
|
||||
hour=schedule['hour'],
|
||||
day_of_week=schedule['day_of_week'],
|
||||
day_of_month=schedule['day_of_month'],
|
||||
month_of_year=schedule['month_of_year'],
|
||||
timezone=schedule['timezone']
|
||||
minute=schedule["minute"],
|
||||
hour=schedule["hour"],
|
||||
day_of_week=schedule["day_of_week"],
|
||||
day_of_month=schedule["day_of_month"],
|
||||
month_of_year=schedule["month_of_year"],
|
||||
timezone=schedule["timezone"],
|
||||
)
|
||||
print(
|
||||
f'Created Crontab Schedule for Hour:{SCHEDULE.hour},Minute:{SCHEDULE.minute}')
|
||||
f"Created Crontab Schedule for Hour:{SCHEDULE.hour},Minute:{SCHEDULE.minute}"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f'Crontab Schedule for Hour:{SCHEDULE.hour},Minute:{SCHEDULE.minute} already exists')
|
||||
for task in seed_data['scheduled_tasks']:
|
||||
TASK = PeriodicTask.objects.filter(name=task['name']).first()
|
||||
f"Crontab Schedule for Hour:{SCHEDULE.hour},Minute:{SCHEDULE.minute} already exists"
|
||||
)
|
||||
for task in seed_data["scheduled_tasks"]:
|
||||
TASK = PeriodicTask.objects.filter(name=task["name"]).first()
|
||||
if not TASK:
|
||||
if task['schedule']['type'] == 'crontab':
|
||||
SCHEDULE = CrontabSchedule.objects.filter(minute=task['schedule']['minute'],
|
||||
hour=task['schedule']['hour'],
|
||||
day_of_week=task['schedule']['day_of_week'],
|
||||
day_of_month=task['schedule']['day_of_month'],
|
||||
month_of_year=task['schedule']['month_of_year'],
|
||||
timezone=task['schedule']['timezone']
|
||||
).first()
|
||||
if task["schedule"]["type"] == "crontab":
|
||||
SCHEDULE = CrontabSchedule.objects.filter(
|
||||
minute=task["schedule"]["minute"],
|
||||
hour=task["schedule"]["hour"],
|
||||
day_of_week=task["schedule"]["day_of_week"],
|
||||
day_of_month=task["schedule"]["day_of_month"],
|
||||
month_of_year=task["schedule"]["month_of_year"],
|
||||
timezone=task["schedule"]["timezone"],
|
||||
).first()
|
||||
TASK = PeriodicTask.objects.create(
|
||||
crontab=SCHEDULE,
|
||||
name=task['name'],
|
||||
task=task['task'],
|
||||
enabled=task['enabled']
|
||||
name=task["name"],
|
||||
task=task["task"],
|
||||
enabled=task["enabled"],
|
||||
)
|
||||
print(f'Created Periodic Task: {TASK.name}')
|
||||
print(f"Created Periodic Task: {TASK.name}")
|
||||
else:
|
||||
raise Exception('Schedule for Periodic Task not found')
|
||||
raise Exception("Schedule for Periodic Task not found")
|
||||
else:
|
||||
print(f'Periodic Task: {TASK.name} already exists')
|
||||
print(f"Periodic Task: {TASK.name} already exists")
|
||||
|
|
|
@ -4,20 +4,26 @@ from celery import shared_task
|
|||
@shared_task
|
||||
def get_paying_users():
|
||||
from subscriptions.models import UserSubscription
|
||||
|
||||
# Get a list of user subscriptions
|
||||
active_subscriptions = UserSubscription.objects.filter(
|
||||
valid=True).distinct('user')
|
||||
active_subscriptions = UserSubscription.objects.filter(valid=True).distinct("user")
|
||||
|
||||
# Get paying users
|
||||
active_users = []
|
||||
|
||||
# Paying regular users
|
||||
active_users += [
|
||||
subscription.user.id for subscription in active_subscriptions if subscription.user is not None and subscription.user.user_group is None]
|
||||
subscription.user.id
|
||||
for subscription in active_subscriptions
|
||||
if subscription.user is not None and subscription.user.user_group is None
|
||||
]
|
||||
|
||||
# Paying users within groups
|
||||
active_users += [
|
||||
subscription.user_group.members for subscription in active_subscriptions if subscription.user_group is not None and subscription.user is None]
|
||||
subscription.user_group.members
|
||||
for subscription in active_subscriptions
|
||||
if subscription.user_group is not None and subscription.user is None
|
||||
]
|
||||
|
||||
# Return paying users
|
||||
return active_users
|
||||
|
|
|
@ -3,10 +3,10 @@ from rest_framework.routers import DefaultRouter
|
|||
from accounts import views
|
||||
|
||||
router = DefaultRouter()
|
||||
router.register(r'users', views.CustomUserViewSet, basename='users')
|
||||
router.register(r"users", views.CustomUserViewSet, basename="users")
|
||||
|
||||
urlpatterns = [
|
||||
path('', include(router.urls)),
|
||||
path('', include('djoser.urls')),
|
||||
path('', include('djoser.urls.jwt')),
|
||||
path("", include(router.urls)),
|
||||
path("", include("djoser.urls")),
|
||||
path("", include("djoser.urls.jwt")),
|
||||
]
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.utils.translation import gettext as _
|
||||
import re
|
||||
|
@ -6,9 +5,10 @@ import re
|
|||
|
||||
class UppercaseValidator(object):
|
||||
def validate(self, password, user=None):
|
||||
if not re.findall('[A-Z]', password):
|
||||
if not re.findall("[A-Z]", password):
|
||||
raise ValidationError(
|
||||
_("The password must contain at least 1 uppercase letter (A-Z)."))
|
||||
_("The password must contain at least 1 uppercase letter (A-Z).")
|
||||
)
|
||||
|
||||
def get_help_text(self):
|
||||
return _("Your password must contain at least 1 uppercase letter (A-Z).")
|
||||
|
@ -16,9 +16,10 @@ class UppercaseValidator(object):
|
|||
|
||||
class LowercaseValidator(object):
|
||||
def validate(self, password, user=None):
|
||||
if not re.findall('[a-z]', password):
|
||||
if not re.findall("[a-z]", password):
|
||||
raise ValidationError(
|
||||
_("The password must contain at least 1 lowercase letter (a-z)."))
|
||||
_("The password must contain at least 1 lowercase letter (a-z).")
|
||||
)
|
||||
|
||||
def get_help_text(self):
|
||||
return _("Your password must contain at least 1 lowercase letter (a-z).")
|
||||
|
@ -26,19 +27,25 @@ class LowercaseValidator(object):
|
|||
|
||||
class SpecialCharacterValidator(object):
|
||||
def validate(self, password, user=None):
|
||||
if not re.findall('[@#$%^&*()_+/\<>;:!?]', password):
|
||||
if not re.findall("[@#$%^&*()_+/\<>;:!?]", password):
|
||||
raise ValidationError(
|
||||
_("The password must contain at least 1 special character (@, #, $, etc.)."))
|
||||
_(
|
||||
"The password must contain at least 1 special character (@, #, $, etc.)."
|
||||
)
|
||||
)
|
||||
|
||||
def get_help_text(self):
|
||||
return _("Your password must contain at least 1 special character (@, #, $, etc.).")
|
||||
return _(
|
||||
"Your password must contain at least 1 special character (@, #, $, etc.)."
|
||||
)
|
||||
|
||||
|
||||
class NumberValidator(object):
|
||||
def validate(self, password, user=None):
|
||||
if not any(char.isdigit() for char in password):
|
||||
raise ValidationError(
|
||||
_("The password must contain at least one numerical digit (0-9)."))
|
||||
_("The password must contain at least one numerical digit (0-9).")
|
||||
)
|
||||
|
||||
def get_help_text(self):
|
||||
return _("Your password must contain at least numerical digit (0-9).")
|
||||
|
|
|
@ -22,28 +22,27 @@ class CustomUserViewSet(DjoserUserViewSet):
|
|||
user = self.request.user
|
||||
# If user is admin, show all active users
|
||||
if user.is_superuser:
|
||||
key = 'users'
|
||||
key = "users"
|
||||
# Get cache
|
||||
queryset = cache.get(key)
|
||||
# Set cache if stale or does not exist
|
||||
if not queryset:
|
||||
queryset = CustomUser.objects.filter(is_active=True)
|
||||
cache.set(key, queryset, 60*60)
|
||||
cache.set(key, queryset, 60 * 60)
|
||||
return queryset
|
||||
elif not user.user_group:
|
||||
key = f'user:{user.id}'
|
||||
key = f"user:{user.id}"
|
||||
queryset = cache.get(key)
|
||||
if not queryset:
|
||||
queryset = CustomUser.objects.filter(is_active=True)
|
||||
cache.set(key, queryset, 60*60)
|
||||
cache.set(key, queryset, 60 * 60)
|
||||
return queryset
|
||||
elif user.user_group:
|
||||
key = f'usergroup_users:{user.user_group.id}'
|
||||
key = f"usergroup_users:{user.user_group.id}"
|
||||
queryset = cache.get(key)
|
||||
if not queryset:
|
||||
queryset = CustomUser.objects.filter(
|
||||
user_group=user.user_group)
|
||||
cache.set(key, queryset, 60*60)
|
||||
queryset = CustomUser.objects.filter(user_group=user.user_group)
|
||||
cache.set(key, queryset, 60 * 60)
|
||||
return queryset
|
||||
else:
|
||||
return CustomUser.objects.none()
|
||||
|
@ -52,10 +51,10 @@ class CustomUserViewSet(DjoserUserViewSet):
|
|||
user = self.request.user
|
||||
|
||||
# Clear cache
|
||||
cache.delete(f'users')
|
||||
cache.delete(f'user:{user.id}')
|
||||
cache.delete(f"users")
|
||||
cache.delete(f"user:{user.id}")
|
||||
if user.user_group:
|
||||
cache.delete(f'usergroup_users:{user.user_group.id}')
|
||||
cache.delete(f"usergroup_users:{user.user_group.id}")
|
||||
|
||||
super().perform_update(serializer, *args, **kwargs)
|
||||
user = serializer.instance
|
||||
|
@ -84,16 +83,18 @@ class CustomUserViewSet(DjoserUserViewSet):
|
|||
settings.EMAIL.confirmation(self.request, context).send(to)
|
||||
|
||||
# Clear cache
|
||||
cache.delete('users')
|
||||
cache.delete(f'user:{user.id}')
|
||||
cache.delete("users")
|
||||
cache.delete(f"user:{user.id}")
|
||||
if user.user_group:
|
||||
cache.delete(f'usergroup_users:{user.user_group.id}')
|
||||
cache.delete(f"usergroup_users:{user.user_group.id}")
|
||||
|
||||
except Exception as e:
|
||||
print('Warning: Unable to send email')
|
||||
print("Warning: Unable to send email")
|
||||
print(e)
|
||||
|
||||
@action(methods=['post'], detail=False, url_path='activation', url_name='activation')
|
||||
@action(
|
||||
methods=["post"], detail=False, url_path="activation", url_name="activation"
|
||||
)
|
||||
def activation(self, request, *args, **kwargs):
|
||||
serializer = self.get_serializer(data=request.data)
|
||||
serializer.is_valid(raise_exception=True)
|
||||
|
@ -103,16 +104,16 @@ class CustomUserViewSet(DjoserUserViewSet):
|
|||
|
||||
# Construct a response with user's first name, last name, and email
|
||||
user_data = {
|
||||
'first_name': user.first_name,
|
||||
'last_name': user.last_name,
|
||||
'email': user.email,
|
||||
'username': user.username
|
||||
"first_name": user.first_name,
|
||||
"last_name": user.last_name,
|
||||
"email": user.email,
|
||||
"username": user.username,
|
||||
}
|
||||
|
||||
# Clear cache
|
||||
cache.delete('users')
|
||||
cache.delete(f'user:{user.id}')
|
||||
cache.delete("users")
|
||||
cache.delete(f"user:{user.id}")
|
||||
if user.user_group:
|
||||
cache.delete(f'usergroup_users:{user.user_group.id}')
|
||||
cache.delete(f"usergroup_users:{user.user_group.id}")
|
||||
|
||||
return Response(user_data, status=status.HTTP_200_OK)
|
||||
|
|
|
@ -1,27 +1,31 @@
|
|||
from django.conf.urls.static import static
|
||||
from django.contrib.staticfiles.urls import staticfiles_urlpatterns
|
||||
from django.urls import path, include
|
||||
from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView
|
||||
from drf_spectacular.views import (
|
||||
SpectacularAPIView,
|
||||
SpectacularRedocView,
|
||||
SpectacularSwaggerView,
|
||||
)
|
||||
from django.contrib import admin
|
||||
from config.settings import DEBUG, SERVE_MEDIA, MEDIA_ROOT
|
||||
|
||||
urlpatterns = [
|
||||
path('accounts/', include('accounts.urls')),
|
||||
path('subscriptions/', include('subscriptions.urls')),
|
||||
path('notifications/', include('notifications.urls')),
|
||||
path('billing/', include('billing.urls')),
|
||||
path('stripe/', include('payments.urls')),
|
||||
path('admin/', admin.site.urls),
|
||||
path('schema/', SpectacularAPIView.as_view(), name='schema'),
|
||||
path('swagger/',
|
||||
SpectacularSwaggerView.as_view(url_name='schema'), name='swagger-ui'),
|
||||
path('redoc/',
|
||||
SpectacularRedocView.as_view(url_name='schema'), name='redoc'),
|
||||
path("accounts/", include("accounts.urls")),
|
||||
path("subscriptions/", include("subscriptions.urls")),
|
||||
path("notifications/", include("notifications.urls")),
|
||||
path("billing/", include("billing.urls")),
|
||||
path("stripe/", include("payments.urls")),
|
||||
path("admin/", admin.site.urls),
|
||||
path("schema/", SpectacularAPIView.as_view(), name="schema"),
|
||||
path(
|
||||
"swagger/", SpectacularSwaggerView.as_view(url_name="schema"), name="swagger-ui"
|
||||
),
|
||||
path("redoc/", SpectacularRedocView.as_view(url_name="schema"), name="redoc"),
|
||||
]
|
||||
|
||||
# URLs for local development
|
||||
if DEBUG and SERVE_MEDIA:
|
||||
urlpatterns += staticfiles_urlpatterns()
|
||||
urlpatterns += static(
|
||||
'media/', document_root=MEDIA_ROOT)
|
||||
urlpatterns += static("media/", document_root=MEDIA_ROOT)
|
||||
if DEBUG:
|
||||
urlpatterns += [path('silk/', include('silk.urls', namespace='silk'))]
|
||||
urlpatterns += [path("silk/", include("silk.urls", namespace="silk"))]
|
||||
|
|
|
@ -2,6 +2,5 @@ from django.urls import path
|
|||
from billing import views
|
||||
|
||||
urlpatterns = [
|
||||
path('',
|
||||
views.BillingHistoryView.as_view()),
|
||||
path("", views.BillingHistoryView.as_view()),
|
||||
]
|
||||
|
|
|
@ -24,7 +24,7 @@ class BillingHistoryView(APIView):
|
|||
email = requesting_user.email
|
||||
|
||||
# Check cache
|
||||
key = f'billing_user:{requesting_user.id}'
|
||||
key = f"billing_user:{requesting_user.id}"
|
||||
billing_history = cache.get(key)
|
||||
|
||||
if not billing_history:
|
||||
|
@ -39,23 +39,25 @@ class BillingHistoryView(APIView):
|
|||
|
||||
if len(customers.data) > 0:
|
||||
# Retrieve the customer's charges (billing history)
|
||||
charges = stripe.Charge.list(
|
||||
limit=10, customer=customer.id)
|
||||
charges = stripe.Charge.list(limit=10, customer=customer.id)
|
||||
|
||||
# Prepare the response
|
||||
billing_history = [
|
||||
{
|
||||
'email': charge['billing_details']['email'],
|
||||
'amount_charged': int(charge['amount']/100),
|
||||
'paid': charge['paid'],
|
||||
'refunded': int(charge['amount_refunded']/100) > 0,
|
||||
'amount_refunded': int(charge['amount_refunded']/100),
|
||||
'last_4': charge['payment_method_details']['card']['last4'],
|
||||
'receipt_link': charge['receipt_url'],
|
||||
'timestamp': datetime.fromtimestamp(charge['created']).strftime("%m-%d-%Y %I:%M %p"),
|
||||
} for charge in charges.auto_paging_iter()
|
||||
"email": charge["billing_details"]["email"],
|
||||
"amount_charged": int(charge["amount"] / 100),
|
||||
"paid": charge["paid"],
|
||||
"refunded": int(charge["amount_refunded"] / 100) > 0,
|
||||
"amount_refunded": int(charge["amount_refunded"] / 100),
|
||||
"last_4": charge["payment_method_details"]["card"]["last4"],
|
||||
"receipt_link": charge["receipt_url"],
|
||||
"timestamp": datetime.fromtimestamp(
|
||||
charge["created"]
|
||||
).strftime("%m-%d-%Y %I:%M %p"),
|
||||
}
|
||||
for charge in charges.auto_paging_iter()
|
||||
]
|
||||
|
||||
cache.set(key, billing_history, 60*60)
|
||||
cache.set(key, billing_history, 60 * 60)
|
||||
|
||||
return Response(billing_history, status=status.HTTP_200_OK)
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
from .celery import app as celery_app
|
||||
|
||||
__all__ = ('celery_app',)
|
||||
__all__ = ("celery_app",)
|
||||
|
|
|
@ -11,6 +11,6 @@ import os
|
|||
|
||||
from django.core.asgi import get_asgi_application
|
||||
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")
|
||||
|
||||
application = get_asgi_application()
|
||||
|
|
|
@ -3,15 +3,15 @@ import os
|
|||
|
||||
|
||||
# Set the default Django settings module for the 'celery' program.
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")
|
||||
|
||||
app = Celery('config')
|
||||
app = Celery("config")
|
||||
|
||||
# Using a string here means the worker doesn't have to serialize
|
||||
# the configuration object to child processes.
|
||||
# - namespace='CELERY' means all celery-related configuration keys
|
||||
# should have a `CELERY_` prefix.
|
||||
app.config_from_object('django.conf:settings', namespace='CELERY')
|
||||
app.config_from_object("django.conf:settings", namespace="CELERY")
|
||||
|
||||
# Load task modules from all registered Django apps.
|
||||
app.autodiscover_tasks()
|
||||
|
|
|
@ -10,9 +10,9 @@ BASE_DIR = Path(__file__).resolve().parent.parent
|
|||
ROOT_DIR = Path(__file__).resolve().parent.parent.parent
|
||||
|
||||
# If you're hosting this with a secret provider, have this set to True
|
||||
USE_VAULT = bool(os.getenv('USE_VAULT', False) == 'True')
|
||||
USE_VAULT = bool(os.getenv("USE_VAULT", False) == "True")
|
||||
# Have this set to True to serve media and static contents directly via Django
|
||||
SERVE_MEDIA = bool(os.getenv('SERVE_MEDIA', False) == 'True')
|
||||
SERVE_MEDIA = bool(os.getenv("SERVE_MEDIA", False) == "True")
|
||||
|
||||
load_dotenv(find_dotenv())
|
||||
|
||||
|
@ -35,98 +35,97 @@ def get_secret(secret_name):
|
|||
|
||||
|
||||
# URL Prefixes
|
||||
URL_SCHEME = 'https' if (get_secret('USE_HTTPS') == 'True') else 'http'
|
||||
URL_SCHEME = "https" if (get_secret("USE_HTTPS") == "True") else "http"
|
||||
# Backend
|
||||
BACKEND_ADDRESS = get_secret('BACKEND_ADDRESS')
|
||||
BACKEND_PORT = get_secret('BACKEND_PORT')
|
||||
BACKEND_ADDRESS = get_secret("BACKEND_ADDRESS")
|
||||
BACKEND_PORT = get_secret("BACKEND_PORT")
|
||||
# Frontend
|
||||
FRONTEND_ADDRESS = get_secret('FRONTEND_ADDRESS')
|
||||
FRONTEND_PORT = get_secret('FRONTEND_PORT')
|
||||
FRONTEND_ADDRESS = get_secret("FRONTEND_ADDRESS")
|
||||
FRONTEND_PORT = get_secret("FRONTEND_PORT")
|
||||
|
||||
ALLOWED_HOSTS = ['*']
|
||||
ALLOWED_HOSTS = ["*"]
|
||||
CSRF_TRUSTED_ORIGINS = [
|
||||
# Frontend
|
||||
f'{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}',
|
||||
f'{URL_SCHEME}://{FRONTEND_ADDRESS}', # For external domains
|
||||
f"{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}",
|
||||
f"{URL_SCHEME}://{FRONTEND_ADDRESS}", # For external domains
|
||||
# Backend
|
||||
f'{URL_SCHEME}://{BACKEND_ADDRESS}:{BACKEND_PORT}',
|
||||
f'{URL_SCHEME}://{BACKEND_ADDRESS}' # For external domains
|
||||
f"{URL_SCHEME}://{BACKEND_ADDRESS}:{BACKEND_PORT}",
|
||||
f"{URL_SCHEME}://{BACKEND_ADDRESS}", # For external domains
|
||||
# You can also set up https://*.name.xyz for wildcards here
|
||||
]
|
||||
|
||||
# SECURITY WARNING: don't run with debug turned on in production!
|
||||
DEBUG = (get_secret('BACKEND_DEBUG') == 'True')
|
||||
DEBUG = get_secret("BACKEND_DEBUG") == "True"
|
||||
# Determines whether or not to insert test data within tables
|
||||
SEED_DATA = (get_secret('SEED_DATA') == 'True')
|
||||
SEED_DATA = get_secret("SEED_DATA") == "True"
|
||||
# SECURITY WARNING: keep the secret key used in production secret!
|
||||
SECRET_KEY = get_secret('SECRET_KEY')
|
||||
SECRET_KEY = get_secret("SECRET_KEY")
|
||||
# Selenium Config
|
||||
# Initiate CAPTCHA solver in test mode
|
||||
CAPTCHA_TESTING = (get_secret('CAPTCHA_TESTING') == 'True')
|
||||
CAPTCHA_TESTING = get_secret("CAPTCHA_TESTING") == "True"
|
||||
# If using Selenium and/or the provided CAPTCHA solver, determines whether or not to use proxies
|
||||
USE_PROXY = (get_secret('USE_PROXY') == 'True')
|
||||
USE_PROXY = get_secret("USE_PROXY") == "True"
|
||||
|
||||
# Stripe (For payments)
|
||||
STRIPE_SECRET_KEY = get_secret(
|
||||
"STRIPE_SECRET_KEY")
|
||||
STRIPE_SECRET_WEBHOOK = get_secret('STRIPE_SECRET_WEBHOOK')
|
||||
STRIPE_CHECKOUT_URL = f''
|
||||
STRIPE_SECRET_KEY = get_secret("STRIPE_SECRET_KEY")
|
||||
STRIPE_SECRET_WEBHOOK = get_secret("STRIPE_SECRET_WEBHOOK")
|
||||
STRIPE_CHECKOUT_URL = f""
|
||||
|
||||
# Email credentials
|
||||
EMAIL_HOST = get_secret('EMAIL_HOST')
|
||||
EMAIL_HOST_USER = get_secret('EMAIL_HOST_USER')
|
||||
EMAIL_HOST_PASSWORD = get_secret('EMAIL_HOST_PASSWORD')
|
||||
EMAIL_PORT = get_secret('EMAIL_PORT')
|
||||
EMAIL_USE_TLS = (get_secret('EMAIL_USE_TLS') == 'True')
|
||||
EMAIL_ADDRESS = (get_secret('EMAIL_ADDRESS') == 'True')
|
||||
EMAIL_HOST = get_secret("EMAIL_HOST")
|
||||
EMAIL_HOST_USER = get_secret("EMAIL_HOST_USER")
|
||||
EMAIL_HOST_PASSWORD = get_secret("EMAIL_HOST_PASSWORD")
|
||||
EMAIL_PORT = get_secret("EMAIL_PORT")
|
||||
EMAIL_USE_TLS = get_secret("EMAIL_USE_TLS") == "True"
|
||||
EMAIL_ADDRESS = get_secret("EMAIL_ADDRESS") == "True"
|
||||
|
||||
# Application definition
|
||||
|
||||
INSTALLED_APPS = [
|
||||
'config',
|
||||
'unfold',
|
||||
'unfold.contrib.filters',
|
||||
'unfold.contrib.simple_history',
|
||||
'django.contrib.admin',
|
||||
'django.contrib.auth',
|
||||
'django.contrib.contenttypes',
|
||||
'django.contrib.sessions',
|
||||
'django.contrib.messages',
|
||||
'django.contrib.staticfiles',
|
||||
'storages',
|
||||
'django_extensions',
|
||||
'rest_framework',
|
||||
'rest_framework_simplejwt',
|
||||
'django_celery_results',
|
||||
'django_celery_beat',
|
||||
'simple_history',
|
||||
'djoser',
|
||||
'corsheaders',
|
||||
'drf_spectacular',
|
||||
'drf_spectacular_sidecar',
|
||||
'webdriver',
|
||||
'accounts',
|
||||
'user_groups',
|
||||
'subscriptions',
|
||||
'payments',
|
||||
'billing',
|
||||
'emails',
|
||||
'notifications',
|
||||
'search_results'
|
||||
"config",
|
||||
"unfold",
|
||||
"unfold.contrib.filters",
|
||||
"unfold.contrib.simple_history",
|
||||
"django.contrib.admin",
|
||||
"django.contrib.auth",
|
||||
"django.contrib.contenttypes",
|
||||
"django.contrib.sessions",
|
||||
"django.contrib.messages",
|
||||
"django.contrib.staticfiles",
|
||||
"storages",
|
||||
"django_extensions",
|
||||
"rest_framework",
|
||||
"rest_framework_simplejwt",
|
||||
"django_celery_results",
|
||||
"django_celery_beat",
|
||||
"simple_history",
|
||||
"djoser",
|
||||
"corsheaders",
|
||||
"drf_spectacular",
|
||||
"drf_spectacular_sidecar",
|
||||
"webdriver",
|
||||
"accounts",
|
||||
"user_groups",
|
||||
"subscriptions",
|
||||
"payments",
|
||||
"billing",
|
||||
"emails",
|
||||
"notifications",
|
||||
"search_results",
|
||||
]
|
||||
|
||||
if DEBUG:
|
||||
INSTALLED_APPS += ['silk']
|
||||
INSTALLED_APPS += ["silk"]
|
||||
MIDDLEWARE = [
|
||||
'django.middleware.security.SecurityMiddleware',
|
||||
"django.middleware.security.SecurityMiddleware",
|
||||
"silk.middleware.SilkyMiddleware",
|
||||
"django.contrib.sessions.middleware.SessionMiddleware",
|
||||
"corsheaders.middleware.CorsMiddleware",
|
||||
'django.middleware.common.CommonMiddleware',
|
||||
'django.middleware.csrf.CsrfViewMiddleware',
|
||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
||||
'django.contrib.messages.middleware.MessageMiddleware',
|
||||
'django.middleware.clickjacking.XFrameOptionsMiddleware',
|
||||
"django.middleware.common.CommonMiddleware",
|
||||
"django.middleware.csrf.CsrfViewMiddleware",
|
||||
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||
"django.contrib.messages.middleware.MessageMiddleware",
|
||||
"django.middleware.clickjacking.XFrameOptionsMiddleware",
|
||||
]
|
||||
DJANGO_LOG_LEVEL = "DEBUG"
|
||||
# Enables VS Code debugger to break on raised exceptions
|
||||
|
@ -146,111 +145,101 @@ if DEBUG:
|
|||
}
|
||||
else:
|
||||
MIDDLEWARE = [
|
||||
'django.middleware.security.SecurityMiddleware',
|
||||
"django.middleware.security.SecurityMiddleware",
|
||||
"whitenoise.middleware.WhiteNoiseMiddleware",
|
||||
"django.contrib.sessions.middleware.SessionMiddleware",
|
||||
"corsheaders.middleware.CorsMiddleware",
|
||||
'django.middleware.common.CommonMiddleware',
|
||||
'django.middleware.csrf.CsrfViewMiddleware',
|
||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
||||
'django.contrib.messages.middleware.MessageMiddleware',
|
||||
'django.middleware.clickjacking.XFrameOptionsMiddleware',
|
||||
"django.middleware.common.CommonMiddleware",
|
||||
"django.middleware.csrf.CsrfViewMiddleware",
|
||||
"django.contrib.auth.middleware.AuthenticationMiddleware",
|
||||
"django.contrib.messages.middleware.MessageMiddleware",
|
||||
"django.middleware.clickjacking.XFrameOptionsMiddleware",
|
||||
]
|
||||
|
||||
# Static files (CSS, JavaScript, Images)
|
||||
# https://docs.djangoproject.com/en/4.2/howto/static-files/
|
||||
|
||||
ROOT_URLCONF = 'config.urls'
|
||||
ROOT_URLCONF = "config.urls"
|
||||
if SERVE_MEDIA:
|
||||
# Cloud Storage Settings
|
||||
# This is assuming you use the same bucket for media and static containers
|
||||
CLOUD_BUCKET = get_secret('CLOUD_BUCKET')
|
||||
MEDIA_CONTAINER = get_secret('MEDIA_CONTAINER')
|
||||
STATIC_CONTAINER = get_secret('STATIC_CONTAINER')
|
||||
CLOUD_BUCKET = get_secret("CLOUD_BUCKET")
|
||||
MEDIA_CONTAINER = get_secret("MEDIA_CONTAINER")
|
||||
STATIC_CONTAINER = get_secret("STATIC_CONTAINER")
|
||||
|
||||
MEDIA_URL = f'https://{CLOUD_BUCKET}/{MEDIA_CONTAINER}/'
|
||||
MEDIA_ROOT = f'https://{CLOUD_BUCKET}/'
|
||||
MEDIA_URL = f"https://{CLOUD_BUCKET}/{MEDIA_CONTAINER}/"
|
||||
MEDIA_ROOT = f"https://{CLOUD_BUCKET}/"
|
||||
|
||||
STATIC_URL = f'https://{CLOUD_BUCKET}/{STATIC_CONTAINER}/'
|
||||
STATIC_ROOT = f'https://{CLOUD_BUCKET}/{STATIC_CONTAINER}/'
|
||||
STATIC_URL = f"https://{CLOUD_BUCKET}/{STATIC_CONTAINER}/"
|
||||
STATIC_ROOT = f"https://{CLOUD_BUCKET}/{STATIC_CONTAINER}/"
|
||||
|
||||
# Consult django-storages documentation when filling in these values. This will vary depending on your cloud service provider
|
||||
STORAGES = {
|
||||
'default': {
|
||||
"default": {
|
||||
# TODO: Set this up here if you're using cloud storage
|
||||
'BACKEND': None,
|
||||
'OPTIONS': {
|
||||
"BACKEND": None,
|
||||
"OPTIONS": {
|
||||
# Optional parameters
|
||||
},
|
||||
},
|
||||
'staticfiles': {
|
||||
"staticfiles": {
|
||||
# TODO: Set this up here if you're using cloud storage
|
||||
'BACKEND': None,
|
||||
'OPTIONS': {
|
||||
"BACKEND": None,
|
||||
"OPTIONS": {
|
||||
# Optional parameters
|
||||
},
|
||||
},
|
||||
}
|
||||
else:
|
||||
STATIC_URL = 'static/'
|
||||
STATIC_ROOT = os.path.join(BASE_DIR, 'static')
|
||||
STATIC_URL = "static/"
|
||||
STATIC_ROOT = os.path.join(BASE_DIR, "static")
|
||||
STATICFILES_STORAGE = "whitenoise.storage.CompressedManifestStaticFilesStorage"
|
||||
MEDIA_URL = 'api/v1/media/'
|
||||
MEDIA_ROOT = os.path.join(BASE_DIR, 'media')
|
||||
ROOT_URLCONF = 'config.urls'
|
||||
MEDIA_URL = "api/v1/media/"
|
||||
MEDIA_ROOT = os.path.join(BASE_DIR, "media")
|
||||
ROOT_URLCONF = "config.urls"
|
||||
|
||||
TEMPLATES = [
|
||||
{
|
||||
'BACKEND': 'django.template.backends.django.DjangoTemplates',
|
||||
'DIRS': [
|
||||
BASE_DIR / 'emails/templates/',
|
||||
"BACKEND": "django.template.backends.django.DjangoTemplates",
|
||||
"DIRS": [
|
||||
BASE_DIR / "emails/templates/",
|
||||
],
|
||||
'APP_DIRS': True,
|
||||
'OPTIONS': {
|
||||
'context_processors': [
|
||||
'django.template.context_processors.debug',
|
||||
'django.template.context_processors.request',
|
||||
'django.contrib.auth.context_processors.auth',
|
||||
'django.contrib.messages.context_processors.messages',
|
||||
"APP_DIRS": True,
|
||||
"OPTIONS": {
|
||||
"context_processors": [
|
||||
"django.template.context_processors.debug",
|
||||
"django.template.context_processors.request",
|
||||
"django.contrib.auth.context_processors.auth",
|
||||
"django.contrib.messages.context_processors.messages",
|
||||
],
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
REST_FRAMEWORK = {
|
||||
'DEFAULT_AUTHENTICATION_CLASSES': (
|
||||
'rest_framework_simplejwt.authentication.JWTAuthentication',
|
||||
"DEFAULT_AUTHENTICATION_CLASSES": (
|
||||
"rest_framework_simplejwt.authentication.JWTAuthentication",
|
||||
),
|
||||
'DEFAULT_THROTTLE_CLASSES': [
|
||||
|
||||
'rest_framework.throttling.AnonRateThrottle',
|
||||
|
||||
'rest_framework.throttling.UserRateThrottle'
|
||||
|
||||
"DEFAULT_THROTTLE_CLASSES": [
|
||||
"rest_framework.throttling.AnonRateThrottle",
|
||||
"rest_framework.throttling.UserRateThrottle",
|
||||
],
|
||||
|
||||
'DEFAULT_THROTTLE_RATES': {
|
||||
|
||||
'anon': '360/min',
|
||||
|
||||
'user': '1440/min'
|
||||
|
||||
},
|
||||
'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
|
||||
"DEFAULT_THROTTLE_RATES": {"anon": "360/min", "user": "1440/min"},
|
||||
"DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema",
|
||||
}
|
||||
|
||||
# DRF-Spectacular
|
||||
SPECTACULAR_SETTINGS = {
|
||||
'TITLE': 'DRF-Template',
|
||||
'DESCRIPTION': 'A Template Project by Keannu Bernasol',
|
||||
'VERSION': '1.0.0',
|
||||
'SERVE_INCLUDE_SCHEMA': False,
|
||||
'SWAGGER_UI_DIST': 'SIDECAR',
|
||||
'SWAGGER_UI_FAVICON_HREF': 'SIDECAR',
|
||||
'REDOC_DIST': 'SIDECAR',
|
||||
"TITLE": "DRF-Template",
|
||||
"DESCRIPTION": "A Template Project by Keannu Bernasol",
|
||||
"VERSION": "1.0.0",
|
||||
"SERVE_INCLUDE_SCHEMA": False,
|
||||
"SWAGGER_UI_DIST": "SIDECAR",
|
||||
"SWAGGER_UI_FAVICON_HREF": "SIDECAR",
|
||||
"REDOC_DIST": "SIDECAR",
|
||||
}
|
||||
|
||||
WSGI_APPLICATION = 'config.wsgi.application'
|
||||
WSGI_APPLICATION = "config.wsgi.application"
|
||||
|
||||
# If you're using an external connection bouncer (eg. PgBouncer), server side cursors must be disabled to avoid any issues
|
||||
USE_BOUNCER = get_secret("USE_BOUNCER")
|
||||
|
@ -266,15 +255,13 @@ else:
|
|||
DATABASES = {
|
||||
"default": {
|
||||
"ENGINE": "django.db.backends.postgresql",
|
||||
'DISABLE_SERVER_SIDE_CURSORS': DISABLE_SERVER_SIDE_CURSORS,
|
||||
"DISABLE_SERVER_SIDE_CURSORS": DISABLE_SERVER_SIDE_CURSORS,
|
||||
"NAME": get_secret("DB_DATABASE"),
|
||||
"USER": get_secret("DB_USERNAME"),
|
||||
"PASSWORD": get_secret("DB_PASSWORD"),
|
||||
"HOST": DB_HOST,
|
||||
"PORT": DB_PORT,
|
||||
"OPTIONS": {
|
||||
"sslmode": get_secret("DB_SSL_MODE")
|
||||
},
|
||||
"OPTIONS": {"sslmode": get_secret("DB_SSL_MODE")},
|
||||
}
|
||||
}
|
||||
# Django Cache
|
||||
|
@ -284,34 +271,34 @@ CACHES = {
|
|||
"LOCATION": f"redis://{get_secret('REDIS_HOST')}:{get_secret('REDIS_PORT')}/2",
|
||||
"OPTIONS": {
|
||||
"CLIENT_CLASS": "django_redis.client.DefaultClient",
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
AUTH_USER_MODEL = 'accounts.CustomUser'
|
||||
AUTH_USER_MODEL = "accounts.CustomUser"
|
||||
|
||||
DJOSER = {
|
||||
'SEND_ACTIVATION_EMAIL': True,
|
||||
'SEND_CONFIRMATION_EMAIL': True,
|
||||
'PASSWORD_RESET_CONFIRM_URL': 'reset_password_confirm/{uid}/{token}',
|
||||
'ACTIVATION_URL': 'activation/{uid}/{token}',
|
||||
'USER_AUTHENTICATION_RULES': ['djoser.authentication.TokenAuthenticationRule'],
|
||||
'EMAIL': {
|
||||
'activation': 'emails.templates.ActivationEmail',
|
||||
'password_reset': 'emails.templates.PasswordResetEmail'
|
||||
"SEND_ACTIVATION_EMAIL": True,
|
||||
"SEND_CONFIRMATION_EMAIL": True,
|
||||
"PASSWORD_RESET_CONFIRM_URL": "reset_password_confirm/{uid}/{token}",
|
||||
"ACTIVATION_URL": "activation/{uid}/{token}",
|
||||
"USER_AUTHENTICATION_RULES": ["djoser.authentication.TokenAuthenticationRule"],
|
||||
"EMAIL": {
|
||||
"activation": "emails.templates.ActivationEmail",
|
||||
"password_reset": "emails.templates.PasswordResetEmail",
|
||||
},
|
||||
'SERIALIZERS': {
|
||||
'user': 'accounts.serializers.CustomUserSerializer',
|
||||
'current_user': 'accounts.serializers.CustomUserSerializer',
|
||||
'user_create': 'accounts.serializers.UserRegistrationSerializer',
|
||||
"SERIALIZERS": {
|
||||
"user": "accounts.serializers.CustomUserSerializer",
|
||||
"current_user": "accounts.serializers.CustomUserSerializer",
|
||||
"user_create": "accounts.serializers.UserRegistrationSerializer",
|
||||
},
|
||||
'PERMISSIONS': {
|
||||
"PERMISSIONS": {
|
||||
# Disable some unneeded endpoints by setting them to admin only
|
||||
'username_reset': ['rest_framework.permissions.IsAdminUser'],
|
||||
'username_reset_confirm': ['rest_framework.permissions.IsAdminUser'],
|
||||
'set_username': ['rest_framework.permissions.IsAdminUser'],
|
||||
'set_password': ['rest_framework.permissions.IsAdminUser'],
|
||||
}
|
||||
"username_reset": ["rest_framework.permissions.IsAdminUser"],
|
||||
"username_reset_confirm": ["rest_framework.permissions.IsAdminUser"],
|
||||
"set_username": ["rest_framework.permissions.IsAdminUser"],
|
||||
"set_password": ["rest_framework.permissions.IsAdminUser"],
|
||||
},
|
||||
}
|
||||
|
||||
# Password validation
|
||||
|
@ -319,32 +306,32 @@ DJOSER = {
|
|||
|
||||
AUTH_PASSWORD_VALIDATORS = [
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator',
|
||||
"NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator",
|
||||
},
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator',
|
||||
"NAME": "django.contrib.auth.password_validation.MinimumLengthValidator",
|
||||
"OPTIONS": {
|
||||
"min_length": 8,
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator',
|
||||
"NAME": "django.contrib.auth.password_validation.CommonPasswordValidator",
|
||||
},
|
||||
{
|
||||
'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator',
|
||||
"NAME": "django.contrib.auth.password_validation.NumericPasswordValidator",
|
||||
},
|
||||
# Additional password validators
|
||||
{
|
||||
'NAME': 'accounts.validators.SpecialCharacterValidator',
|
||||
"NAME": "accounts.validators.SpecialCharacterValidator",
|
||||
},
|
||||
{
|
||||
'NAME': 'accounts.validators.LowercaseValidator',
|
||||
"NAME": "accounts.validators.LowercaseValidator",
|
||||
},
|
||||
{
|
||||
'NAME': 'accounts.validators.UppercaseValidator',
|
||||
"NAME": "accounts.validators.UppercaseValidator",
|
||||
},
|
||||
{
|
||||
'NAME': 'accounts.validators.NumberValidator',
|
||||
"NAME": "accounts.validators.NumberValidator",
|
||||
},
|
||||
]
|
||||
|
||||
|
@ -352,9 +339,9 @@ AUTH_PASSWORD_VALIDATORS = [
|
|||
# Internationalization
|
||||
# https://docs.djangoproject.com/en/4.2/topics/i18n/
|
||||
|
||||
LANGUAGE_CODE = 'en-us'
|
||||
LANGUAGE_CODE = "en-us"
|
||||
|
||||
TIME_ZONE = get_secret('TIMEZONE')
|
||||
TIME_ZONE = get_secret("TIMEZONE")
|
||||
|
||||
USE_I18N = True
|
||||
|
||||
|
@ -364,14 +351,14 @@ USE_TZ = True
|
|||
# Default primary key field type
|
||||
# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field
|
||||
|
||||
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
|
||||
DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"
|
||||
|
||||
SITE_NAME = 'DRF-Template'
|
||||
SITE_NAME = "DRF-Template"
|
||||
|
||||
# JWT Token Lifetimes
|
||||
SIMPLE_JWT = {
|
||||
"ACCESS_TOKEN_LIFETIME": timedelta(hours=1),
|
||||
"REFRESH_TOKEN_LIFETIME": timedelta(days=3)
|
||||
"REFRESH_TOKEN_LIFETIME": timedelta(days=3),
|
||||
}
|
||||
|
||||
CORS_ALLOW_ALL_ORIGINS = True
|
||||
|
@ -388,11 +375,19 @@ CELERY_RESULT_BACKEND = get_secret("CELERY_RESULT_BACKEND")
|
|||
CELERY_RESULT_EXTENDED = True
|
||||
|
||||
# Celery Beat Options
|
||||
CELERY_BEAT_SCHEDULER = 'django_celery_beat.schedulers:DatabaseScheduler'
|
||||
CELERY_BEAT_SCHEDULER = "django_celery_beat.schedulers:DatabaseScheduler"
|
||||
|
||||
# Maximum number of rows that can be updated within the Django admin panel
|
||||
DATA_UPLOAD_MAX_NUMBER_FIELDS = 20480
|
||||
|
||||
GRAPH_MODELS = {
|
||||
'app_labels': ['accounts', 'user_groups', 'billing', 'emails', 'payments', 'subscriptions', 'search_results']
|
||||
"app_labels": [
|
||||
"accounts",
|
||||
"user_groups",
|
||||
"billing",
|
||||
"emails",
|
||||
"payments",
|
||||
"subscriptions",
|
||||
"search_results",
|
||||
]
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from django.urls import path, include
|
||||
|
||||
urlpatterns = [
|
||||
path('api/v1/', include('api.urls')),
|
||||
path("api/v1/", include("api.urls")),
|
||||
]
|
||||
|
|
|
@ -11,6 +11,6 @@ import os
|
|||
|
||||
from django.core.wsgi import get_wsgi_application
|
||||
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")
|
||||
|
||||
application = get_wsgi_application()
|
||||
|
|
|
@ -2,5 +2,5 @@ from django.apps import AppConfig
|
|||
|
||||
|
||||
class EmailsConfig(AppConfig):
|
||||
default_auto_field = 'django.db.models.BigAutoField'
|
||||
name = 'emails'
|
||||
default_auto_field = "django.db.models.BigAutoField"
|
||||
name = "emails"
|
||||
|
|
|
@ -3,11 +3,11 @@ from django.utils import timezone
|
|||
|
||||
|
||||
class ActivationEmail(email.ActivationEmail):
|
||||
template_name = 'email_activation.html'
|
||||
template_name = "email_activation.html"
|
||||
|
||||
|
||||
class PasswordResetEmail(email.PasswordResetEmail):
|
||||
template_name = 'password_change.html'
|
||||
template_name = "password_change.html"
|
||||
|
||||
|
||||
class SubscriptionAvailedEmail(email.BaseEmailMessage):
|
||||
|
@ -19,7 +19,7 @@ class SubscriptionAvailedEmail(email.BaseEmailMessage):
|
|||
context["subscription_plan"] = context.get("subscription_plan")
|
||||
context["subscription"] = context.get("subscription")
|
||||
context["price_paid"] = context.get("price_paid")
|
||||
context['date'] = timezone.now().strftime("%B %d, %I:%M %p")
|
||||
context["date"] = timezone.now().strftime("%B %d, %I:%M %p")
|
||||
context.update(self.context)
|
||||
return context
|
||||
|
||||
|
@ -32,7 +32,7 @@ class SubscriptionRefundedEmail(email.BaseEmailMessage):
|
|||
context["user"] = context.get("user")
|
||||
context["subscription_plan"] = context.get("subscription_plan")
|
||||
context["refund"] = context.get("refund")
|
||||
context['date'] = timezone.now().strftime("%B %d, %I:%M %p")
|
||||
context["date"] = timezone.now().strftime("%B %d, %I:%M %p")
|
||||
context.update(self.context)
|
||||
return context
|
||||
|
||||
|
@ -44,6 +44,6 @@ class SubscriptionCancelledEmail(email.BaseEmailMessage):
|
|||
context = super().get_context_data()
|
||||
context["user"] = context.get("user")
|
||||
context["subscription_plan"] = context.get("subscription_plan")
|
||||
context['date'] = timezone.now().strftime("%B %d, %I:%M %p")
|
||||
context["date"] = timezone.now().strftime("%B %d, %I:%M %p")
|
||||
context.update(self.context)
|
||||
return context
|
||||
|
|
|
@ -6,7 +6,7 @@ import sys
|
|||
|
||||
def main():
|
||||
"""Run administrative tasks."""
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'config.settings')
|
||||
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "config.settings")
|
||||
try:
|
||||
from django.core.management import execute_from_command_line
|
||||
except ImportError as exc:
|
||||
|
@ -18,5 +18,5 @@ def main():
|
|||
execute_from_command_line(sys.argv)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
@ -6,5 +6,5 @@ from .models import Notification
|
|||
@admin.register(Notification)
|
||||
class NotificationAdmin(ModelAdmin):
|
||||
model = Notification
|
||||
search_fields = ('id', 'content')
|
||||
list_display = ['id', 'dismissed']
|
||||
search_fields = ("id", "content")
|
||||
list_display = ["id", "dismissed"]
|
||||
|
|
|
@ -2,8 +2,8 @@ from django.apps import AppConfig
|
|||
|
||||
|
||||
class NotificationsConfig(AppConfig):
|
||||
default_auto_field = 'django.db.models.BigAutoField'
|
||||
name = 'notifications'
|
||||
default_auto_field = "django.db.models.BigAutoField"
|
||||
name = "notifications"
|
||||
|
||||
def ready(self):
|
||||
import notifications.signals
|
||||
|
|
|
@ -15,13 +15,27 @@ class Migration(migrations.Migration):
|
|||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='Notification',
|
||||
name="Notification",
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('content', models.CharField(max_length=1000, null=True)),
|
||||
('timestamp', models.DateTimeField(auto_now_add=True)),
|
||||
('dismissed', models.BooleanField(default=False)),
|
||||
('recipient', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||
(
|
||||
"id",
|
||||
models.BigAutoField(
|
||||
auto_created=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
verbose_name="ID",
|
||||
),
|
||||
),
|
||||
("content", models.CharField(max_length=1000, null=True)),
|
||||
("timestamp", models.DateTimeField(auto_now_add=True)),
|
||||
("dismissed", models.BooleanField(default=False)),
|
||||
(
|
||||
"recipient",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to=settings.AUTH_USER_MODEL,
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
|
|
@ -2,8 +2,7 @@ from django.db import models
|
|||
|
||||
|
||||
class Notification(models.Model):
|
||||
recipient = models.ForeignKey(
|
||||
'accounts.CustomUser', on_delete=models.CASCADE)
|
||||
recipient = models.ForeignKey("accounts.CustomUser", on_delete=models.CASCADE)
|
||||
content = models.CharField(max_length=1000, null=True)
|
||||
timestamp = models.DateTimeField(auto_now_add=True, editable=False)
|
||||
dismissed = models.BooleanField(default=False)
|
||||
|
|
|
@ -3,10 +3,9 @@ from notifications.models import Notification
|
|||
|
||||
|
||||
class NotificationSerializer(serializers.ModelSerializer):
|
||||
timestamp = serializers.DateTimeField(
|
||||
format="%m-%d-%Y %I:%M %p", read_only=True)
|
||||
timestamp = serializers.DateTimeField(format="%m-%d-%Y %I:%M %p", read_only=True)
|
||||
|
||||
class Meta:
|
||||
model = Notification
|
||||
fields = '__all__'
|
||||
read_only_fields = ('id', 'recipient', 'content', 'timestamp')
|
||||
fields = "__all__"
|
||||
read_only_fields = ("id", "recipient", "content", "timestamp")
|
||||
|
|
|
@ -9,5 +9,5 @@ from django.core.cache import cache
|
|||
@receiver(post_save, sender=Notification)
|
||||
def clear_cache_after_notification_update(sender, instance, **kwargs):
|
||||
# Clear cache
|
||||
cache.delete('notifications')
|
||||
cache.delete(f'notifications_user:{instance.recipient.id}')
|
||||
cache.delete("notifications")
|
||||
cache.delete(f"notifications_user:{instance.recipient.id}")
|
||||
|
|
|
@ -9,5 +9,4 @@ def cleanup_notifications():
|
|||
three_days_ago = timezone.now() - timezone.timedelta(days=3)
|
||||
|
||||
# Delete notifications that are older than 3 days and dismissed
|
||||
Notification.objects.filter(
|
||||
dismissed=True, timestamp__lte=three_days_ago).delete()
|
||||
Notification.objects.filter(dismissed=True, timestamp__lte=three_days_ago).delete()
|
||||
|
|
|
@ -3,8 +3,7 @@ from notifications.views import NotificationViewSet
|
|||
from rest_framework.routers import DefaultRouter
|
||||
|
||||
router = DefaultRouter()
|
||||
router.register(r'', NotificationViewSet,
|
||||
basename="Notifications")
|
||||
router.register(r"", NotificationViewSet, basename="Notifications")
|
||||
urlpatterns = [
|
||||
path('', include(router.urls)),
|
||||
path("", include(router.urls)),
|
||||
]
|
||||
|
|
|
@ -6,30 +6,33 @@ from django.core.cache import cache
|
|||
|
||||
|
||||
class NotificationViewSet(viewsets.ModelViewSet):
|
||||
http_method_names = ['get', 'patch', 'delete']
|
||||
http_method_names = ["get", "patch", "delete"]
|
||||
serializer_class = NotificationSerializer
|
||||
queryset = Notification.objects.all()
|
||||
|
||||
def get_queryset(self):
|
||||
user = self.request.user
|
||||
key = f'notifications_user:{user.id}'
|
||||
key = f"notifications_user:{user.id}"
|
||||
queryset = cache.get(key)
|
||||
if not queryset:
|
||||
queryset = Notification.objects.filter(
|
||||
recipient=user).order_by('-timestamp')
|
||||
cache.set(key, queryset, 60*60)
|
||||
queryset = Notification.objects.filter(recipient=user).order_by(
|
||||
"-timestamp"
|
||||
)
|
||||
cache.set(key, queryset, 60 * 60)
|
||||
return queryset
|
||||
|
||||
def update(self, request, *args, **kwargs):
|
||||
instance = self.get_object()
|
||||
if instance.recipient != request.user:
|
||||
raise PermissionDenied(
|
||||
"You do not have permission to update this notification.")
|
||||
"You do not have permission to update this notification."
|
||||
)
|
||||
return super().update(request, *args, **kwargs)
|
||||
|
||||
def destroy(self, request, *args, **kwargs):
|
||||
instance = self.get_object()
|
||||
if instance.recipient != request.user:
|
||||
raise PermissionDenied(
|
||||
"You do not have permission to delete this notification.")
|
||||
"You do not have permission to delete this notification."
|
||||
)
|
||||
return super().destroy(request, *args, **kwargs)
|
||||
|
|
|
@ -2,5 +2,5 @@ from django.apps import AppConfig
|
|||
|
||||
|
||||
class PaymentsConfig(AppConfig):
|
||||
default_auto_field = 'django.db.models.BigAutoField'
|
||||
name = 'payments'
|
||||
default_auto_field = "django.db.models.BigAutoField"
|
||||
name = "payments"
|
||||
|
|
|
@ -3,6 +3,6 @@ from payments import views
|
|||
|
||||
|
||||
urlpatterns = [
|
||||
path('checkout_session/', views.StripeCheckoutView.as_view()),
|
||||
path('webhook/', views.stripe_webhook_view, name='Stripe Webhook'),
|
||||
path("checkout_session/", views.StripeCheckoutView.as_view()),
|
||||
path("webhook/", views.stripe_webhook_view, name="Stripe Webhook"),
|
||||
]
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
from config.settings import STRIPE_SECRET_KEY, STRIPE_SECRET_WEBHOOK, URL_SCHEME, FRONTEND_ADDRESS, FRONTEND_PORT
|
||||
from config.settings import (
|
||||
STRIPE_SECRET_KEY,
|
||||
STRIPE_SECRET_WEBHOOK,
|
||||
URL_SCHEME,
|
||||
FRONTEND_ADDRESS,
|
||||
FRONTEND_PORT,
|
||||
)
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.response import Response
|
||||
|
@ -12,16 +18,19 @@ from accounts.models import CustomUser
|
|||
from rest_framework.decorators import api_view
|
||||
from subscriptions.tasks import get_user_subscription
|
||||
import json
|
||||
from emails.templates import SubscriptionAvailedEmail, SubscriptionRefundedEmail, SubscriptionCancelledEmail
|
||||
from emails.templates import (
|
||||
SubscriptionAvailedEmail,
|
||||
SubscriptionRefundedEmail,
|
||||
SubscriptionCancelledEmail,
|
||||
)
|
||||
from django.core.cache import cache
|
||||
from payments.serializers import CheckoutSerializer
|
||||
from drf_spectacular.utils import extend_schema
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
|
||||
|
||||
@extend_schema(
|
||||
request=CheckoutSerializer
|
||||
)
|
||||
@extend_schema(request=CheckoutSerializer)
|
||||
class StripeCheckoutView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
|
@ -30,41 +39,46 @@ class StripeCheckoutView(APIView):
|
|||
# Get subscription ID from POST
|
||||
USER = CustomUser.objects.get(id=self.request.user.id)
|
||||
data = json.loads(request.body)
|
||||
subscription_id = data.get('subscription_id')
|
||||
annual = data.get('annual')
|
||||
subscription_id = data.get("subscription_id")
|
||||
annual = data.get("annual")
|
||||
|
||||
# Validation for subscription_id field
|
||||
try:
|
||||
subscription_id = int(subscription_id)
|
||||
except:
|
||||
return Response({
|
||||
'error': 'Invalid value specified in subscription_id field'
|
||||
}, status=status.HTTP_403_FORBIDDEN)
|
||||
return Response(
|
||||
{"error": "Invalid value specified in subscription_id field"},
|
||||
status=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
|
||||
# Validation for annual field
|
||||
try:
|
||||
annual = bool(annual)
|
||||
except:
|
||||
return Response({
|
||||
'error': 'Invalid value specified in annual field'
|
||||
}, status=status.HTTP_403_FORBIDDEN)
|
||||
return Response(
|
||||
{"error": "Invalid value specified in annual field"},
|
||||
status=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
|
||||
# Return an error if the user already has an active subscription
|
||||
EXISTING_SUBSCRIPTION = get_user_subscription(USER.id)
|
||||
if EXISTING_SUBSCRIPTION:
|
||||
return Response({
|
||||
'error': f'User is already subscribed to: {EXISTING_SUBSCRIPTION.subscription.name}'
|
||||
}, status=status.HTTP_403_FORBIDDEN)
|
||||
return Response(
|
||||
{
|
||||
"error": f"User is already subscribed to: {EXISTING_SUBSCRIPTION.subscription.name}"
|
||||
},
|
||||
status=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
|
||||
# Attempt to query the subscription
|
||||
SUBSCRIPTION = SubscriptionPlan.objects.filter(
|
||||
id=subscription_id).first()
|
||||
SUBSCRIPTION = SubscriptionPlan.objects.filter(id=subscription_id).first()
|
||||
|
||||
# Return an error if the plan does not exist
|
||||
if not SUBSCRIPTION:
|
||||
return Response({
|
||||
'error': 'Subscription plan not found'
|
||||
}, status=status.HTTP_404_NOT_FOUND)
|
||||
return Response(
|
||||
{"error": "Subscription plan not found"},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
# Get the stripe_price_id from the related StripePrice instances
|
||||
if annual:
|
||||
|
@ -74,52 +88,58 @@ class StripeCheckoutView(APIView):
|
|||
|
||||
# Return 404 if no price is set
|
||||
if not PRICE:
|
||||
return Response({
|
||||
'error': 'Specified price does not exist for plan'
|
||||
}, status=status.HTTP_404_NOT_FOUND)
|
||||
return Response(
|
||||
{"error": "Specified price does not exist for plan"},
|
||||
status=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
PRICE_ID = PRICE.stripe_price_id
|
||||
prorated = PRICE.prorated
|
||||
|
||||
# Return an error if a user is in a user_group and is availing pro-rated plans
|
||||
if not USER.user_group and SUBSCRIPTION.group_exclusive:
|
||||
return Response({
|
||||
'error': 'Regular users cannot avail prorated plans'
|
||||
}, status=status.HTTP_403_FORBIDDEN)
|
||||
return Response(
|
||||
{"error": "Regular users cannot avail prorated plans"},
|
||||
status=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
|
||||
success_url = f'{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}' + \
|
||||
'/user/subscription/payment?success=true&agency=False&session_id={CHECKOUT_SESSION_ID}'
|
||||
cancel_url = f'{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}' + \
|
||||
'/user/subscription/payment?success=false&user_group=False'
|
||||
success_url = (
|
||||
f"{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}"
|
||||
+ "/user/subscription/payment?success=true&agency=False&session_id={CHECKOUT_SESSION_ID}"
|
||||
)
|
||||
cancel_url = (
|
||||
f"{URL_SCHEME}://{FRONTEND_ADDRESS}:{FRONTEND_PORT}"
|
||||
+ "/user/subscription/payment?success=false&user_group=False"
|
||||
)
|
||||
|
||||
checkout_session = stripe.checkout.Session.create(
|
||||
line_items=[
|
||||
{
|
||||
'price': PRICE_ID,
|
||||
'quantity': 1
|
||||
} if not prorated else
|
||||
{
|
||||
'price': PRICE_ID,
|
||||
}
|
||||
(
|
||||
{"price": PRICE_ID, "quantity": 1}
|
||||
if not prorated
|
||||
else {
|
||||
"price": PRICE_ID,
|
||||
}
|
||||
)
|
||||
],
|
||||
mode='subscription',
|
||||
payment_method_types=['card'],
|
||||
mode="subscription",
|
||||
payment_method_types=["card"],
|
||||
success_url=success_url,
|
||||
cancel_url=cancel_url,
|
||||
)
|
||||
return Response({"url": checkout_session.url})
|
||||
except Exception as e:
|
||||
logging.error(str(e))
|
||||
return Response({
|
||||
'error': str(e)
|
||||
}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
return Response(
|
||||
{"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
|
||||
@ api_view(['POST'])
|
||||
@ csrf_exempt
|
||||
@api_view(["POST"])
|
||||
@csrf_exempt
|
||||
def stripe_webhook_view(request):
|
||||
payload = request.body
|
||||
sig_header = request.META['HTTP_STRIPE_SIGNATURE']
|
||||
sig_header = request.META["HTTP_STRIPE_SIGNATURE"]
|
||||
event = None
|
||||
|
||||
try:
|
||||
|
@ -133,12 +153,12 @@ def stripe_webhook_view(request):
|
|||
# Invalid signature
|
||||
return Response(status=401)
|
||||
|
||||
if event['type'] == 'customer.subscription.created':
|
||||
subscription = event['data']['object']
|
||||
if event["type"] == "customer.subscription.created":
|
||||
subscription = event["data"]["object"]
|
||||
# Get the Invoice object from the Subscription object
|
||||
invoice = stripe.Invoice.retrieve(subscription['latest_invoice'])
|
||||
invoice = stripe.Invoice.retrieve(subscription["latest_invoice"])
|
||||
# Get the Charge object from the Invoice object
|
||||
charge = stripe.Charge.retrieve(invoice['charge'])
|
||||
charge = stripe.Charge.retrieve(invoice["charge"])
|
||||
|
||||
# Get paying user
|
||||
customer = stripe.Customer.retrieve(subscription["customer"])
|
||||
|
@ -146,18 +166,20 @@ def stripe_webhook_view(request):
|
|||
|
||||
product = subscription["items"]["data"][0]
|
||||
SUBSCRIPTION_PLAN = SubscriptionPlan.objects.get(
|
||||
stripe_product_id=product["plan"]["product"])
|
||||
stripe_product_id=product["plan"]["product"]
|
||||
)
|
||||
SUBSCRIPTION = UserSubscription.objects.create(
|
||||
subscription=SUBSCRIPTION_PLAN,
|
||||
annual=product["plan"]["interval"] == "year",
|
||||
valid=True,
|
||||
user=USER,
|
||||
stripe_id=subscription['id'])
|
||||
stripe_id=subscription["id"],
|
||||
)
|
||||
email = SubscriptionAvailedEmail()
|
||||
|
||||
paid = {
|
||||
"amount": charge['amount']/100,
|
||||
"currency": str(charge['currency']).upper()
|
||||
"amount": charge["amount"] / 100,
|
||||
"currency": str(charge["currency"]).upper(),
|
||||
}
|
||||
|
||||
email.context = {
|
||||
|
@ -169,19 +191,20 @@ def stripe_webhook_view(request):
|
|||
email.send(to=[customer.email])
|
||||
|
||||
# Clear cache
|
||||
cache.delete(f'billing_user:{USER.id}')
|
||||
cache.delete(f'subscriptions_user:{USER.id}')
|
||||
cache.delete(f"billing_user:{USER.id}")
|
||||
cache.delete(f"subscriptions_user:{USER.id}")
|
||||
|
||||
# On chargebacks/refunds, invalidate the subscription
|
||||
elif event['type'] == 'charge.refunded':
|
||||
charge = event['data']['object']
|
||||
elif event["type"] == "charge.refunded":
|
||||
charge = event["data"]["object"]
|
||||
|
||||
# Get the Invoice object from the Charge object
|
||||
invoice = stripe.Invoice.retrieve(charge['invoice'])
|
||||
invoice = stripe.Invoice.retrieve(charge["invoice"])
|
||||
|
||||
# Check if the subscription exists
|
||||
SUBSCRIPTION = UserSubscription.objects.filter(
|
||||
stripe_id=invoice['subscription']).first()
|
||||
stripe_id=invoice["subscription"]
|
||||
).first()
|
||||
|
||||
if not (SUBSCRIPTION):
|
||||
return HttpResponse(status=404)
|
||||
|
@ -196,8 +219,8 @@ def stripe_webhook_view(request):
|
|||
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
|
||||
|
||||
refund = {
|
||||
"amount": charge['amount_refunded']/100,
|
||||
"currency": str(charge['currency']).upper()
|
||||
"amount": charge["amount_refunded"] / 100,
|
||||
"currency": str(charge["currency"]).upper(),
|
||||
}
|
||||
|
||||
# Send an email
|
||||
|
@ -206,13 +229,13 @@ def stripe_webhook_view(request):
|
|||
email.context = {
|
||||
"user": USER,
|
||||
"subscription_plan": SUBSCRIPTION_PLAN,
|
||||
"refund": refund
|
||||
"refund": refund,
|
||||
}
|
||||
|
||||
email.send(to=[USER.email])
|
||||
|
||||
# Clear cache
|
||||
cache.delete(f'billing_user:{USER.id}')
|
||||
cache.delete(f"billing_user:{USER.id}")
|
||||
|
||||
elif SUBSCRIPTION.user_group:
|
||||
OWNER = SUBSCRIPTION.user_group.owner
|
||||
|
@ -223,8 +246,8 @@ def stripe_webhook_view(request):
|
|||
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
|
||||
|
||||
refund = {
|
||||
"amount": charge['amount_refunded']/100,
|
||||
"currency": str(charge['currency']).upper()
|
||||
"amount": charge["amount_refunded"] / 100,
|
||||
"currency": str(charge["currency"]).upper(),
|
||||
}
|
||||
|
||||
# Send en email
|
||||
|
@ -233,36 +256,38 @@ def stripe_webhook_view(request):
|
|||
email.context = {
|
||||
"user": OWNER,
|
||||
"subscription_plan": SUBSCRIPTION_PLAN,
|
||||
"refund": refund
|
||||
"refund": refund,
|
||||
}
|
||||
email.send(to=[OWNER.email])
|
||||
|
||||
# Clear cache
|
||||
cache.delete(f'billing_user:{USER.id}')
|
||||
cache.delete(f'subscriptions_user:{USER.id}')
|
||||
cache.delete(f"billing_user:{USER.id}")
|
||||
cache.delete(f"subscriptions_user:{USER.id}")
|
||||
|
||||
elif event['type'] == 'customer.subscription.updated':
|
||||
subscription = event['data']['object']
|
||||
elif event["type"] == "customer.subscription.updated":
|
||||
subscription = event["data"]["object"]
|
||||
|
||||
# Check if the subscription exists
|
||||
SUBSCRIPTION = UserSubscription.objects.filter(
|
||||
stripe_id=subscription['id']).first()
|
||||
stripe_id=subscription["id"]
|
||||
).first()
|
||||
|
||||
if not (SUBSCRIPTION):
|
||||
return HttpResponse(status=404)
|
||||
|
||||
# Check if a subscription has been upgraded/downgraded
|
||||
new_stripe_product_id = subscription['items']['data'][0]['plan']['product']
|
||||
new_stripe_product_id = subscription["items"]["data"][0]["plan"]["product"]
|
||||
current_stripe_product_id = SUBSCRIPTION.subscription.stripe_product_id
|
||||
if new_stripe_product_id != current_stripe_product_id:
|
||||
SUBSCRIPTION_PLAN = SubscriptionPlan.objects.get(
|
||||
stripe_product_id=new_stripe_product_id)
|
||||
stripe_product_id=new_stripe_product_id
|
||||
)
|
||||
SUBSCRIPTION.subscription = SUBSCRIPTION_PLAN
|
||||
SUBSCRIPTION.save()
|
||||
# TODO: Add a plan upgraded email message here
|
||||
|
||||
# Subscription activation/reactivation
|
||||
if subscription['status'] == 'active':
|
||||
if subscription["status"] == "active":
|
||||
SUBSCRIPTION.valid = True
|
||||
SUBSCRIPTION.save()
|
||||
|
||||
|
@ -270,26 +295,24 @@ def stripe_webhook_view(request):
|
|||
USER = SUBSCRIPTION.user
|
||||
|
||||
# Clear cache
|
||||
cache.delete(f'billing_user:{USER.id}')
|
||||
cache.delete(
|
||||
f'subscriptions_user:{USER.id}')
|
||||
cache.delete(f"billing_user:{USER.id}")
|
||||
cache.delete(f"subscriptions_user:{USER.id}")
|
||||
|
||||
elif SUBSCRIPTION.user_group:
|
||||
OWNER = SUBSCRIPTION.user_group.owner
|
||||
|
||||
# Clear cache
|
||||
cache.delete(f'billing_user:{OWNER.id}')
|
||||
cache.delete(
|
||||
f'subscriptions_usergroup:{SUBSCRIPTION.user_group.id}')
|
||||
cache.delete(f"billing_user:{OWNER.id}")
|
||||
cache.delete(f"subscriptions_usergroup:{SUBSCRIPTION.user_group.id}")
|
||||
|
||||
# TODO: Add notification here to inform users if their plan has been reactivated
|
||||
|
||||
elif subscription['status'] == 'past_due':
|
||||
elif subscription["status"] == "past_due":
|
||||
# TODO: Add notification here to inform users if their payment method for an existing subscription payment is failing
|
||||
pass
|
||||
|
||||
# If subscriptions get cancelled due to non-payment, invalidate the UserSubscription
|
||||
elif subscription['status'] == 'cancelled':
|
||||
elif subscription["status"] == "cancelled":
|
||||
if SUBSCRIPTION.user:
|
||||
USER = SUBSCRIPTION.user
|
||||
|
||||
|
@ -310,8 +333,8 @@ def stripe_webhook_view(request):
|
|||
email.send(to=[USER.email])
|
||||
|
||||
# Clear cache
|
||||
cache.delete(f'billing_user:{USER.id}')
|
||||
cache.delete(f'subscriptions_user:{USER.id}')
|
||||
cache.delete(f"billing_user:{USER.id}")
|
||||
cache.delete(f"subscriptions_user:{USER.id}")
|
||||
|
||||
elif SUBSCRIPTION.user_group:
|
||||
OWNER = SUBSCRIPTION.user_group.owner
|
||||
|
@ -325,24 +348,21 @@ def stripe_webhook_view(request):
|
|||
|
||||
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
|
||||
|
||||
email.context = {
|
||||
"user": OWNER,
|
||||
"subscription_plan": SUBSCRIPTION_PLAN
|
||||
}
|
||||
email.context = {"user": OWNER, "subscription_plan": SUBSCRIPTION_PLAN}
|
||||
email.send(to=[OWNER.email])
|
||||
|
||||
# Clear cache
|
||||
cache.delete(f'billing_user:{OWNER.id}')
|
||||
cache.delete(
|
||||
f'subscriptions_usergroup:{SUBSCRIPTION.user_group.id}')
|
||||
cache.delete(f"billing_user:{OWNER.id}")
|
||||
cache.delete(f"subscriptions_usergroup:{SUBSCRIPTION.user_group.id}")
|
||||
|
||||
# If a subscription gets cancelled, invalidate it
|
||||
elif event['type'] == 'customer.subscription.deleted':
|
||||
subscription = event['data']['object']
|
||||
elif event["type"] == "customer.subscription.deleted":
|
||||
subscription = event["data"]["object"]
|
||||
|
||||
# Check if the subscription exists
|
||||
SUBSCRIPTION = UserSubscription.objects.filter(
|
||||
stripe_id=subscription['id']).first()
|
||||
stripe_id=subscription["id"]
|
||||
).first()
|
||||
|
||||
if not (SUBSCRIPTION):
|
||||
return HttpResponse(status=404)
|
||||
|
@ -367,7 +387,7 @@ def stripe_webhook_view(request):
|
|||
email.send(to=[USER.email])
|
||||
|
||||
# Clear cache
|
||||
cache.delete(f'billing_user:{USER.id}')
|
||||
cache.delete(f"billing_user:{USER.id}")
|
||||
|
||||
elif SUBSCRIPTION.user_group:
|
||||
OWNER = SUBSCRIPTION.user_group.owner
|
||||
|
@ -381,14 +401,11 @@ def stripe_webhook_view(request):
|
|||
|
||||
SUBSCRIPTION_PLAN = SUBSCRIPTION.subscription
|
||||
|
||||
email.context = {
|
||||
"user": OWNER,
|
||||
"subscription_plan": SUBSCRIPTION_PLAN
|
||||
}
|
||||
email.context = {"user": OWNER, "subscription_plan": SUBSCRIPTION_PLAN}
|
||||
email.send(to=[OWNER.email])
|
||||
|
||||
# Clear cache
|
||||
cache.delete(f'billing_user:{OWNER.id}')
|
||||
cache.delete(f"billing_user:{OWNER.id}")
|
||||
|
||||
# Passed signature verification
|
||||
return HttpResponse(status=200)
|
||||
|
|
|
@ -805,6 +805,9 @@ components:
|
|||
first_name:
|
||||
type: string
|
||||
maxLength: 150
|
||||
is_new:
|
||||
type: string
|
||||
readOnly: true
|
||||
last_name:
|
||||
type: string
|
||||
maxLength: 150
|
||||
|
@ -824,6 +827,7 @@ components:
|
|||
- group_member
|
||||
- group_owner
|
||||
- id
|
||||
- is_new
|
||||
- user_group
|
||||
- username
|
||||
Notification:
|
||||
|
@ -885,6 +889,9 @@ components:
|
|||
first_name:
|
||||
type: string
|
||||
maxLength: 150
|
||||
is_new:
|
||||
type: string
|
||||
readOnly: true
|
||||
last_name:
|
||||
type: string
|
||||
maxLength: 150
|
||||
|
|
|
@ -7,12 +7,11 @@ from unfold.contrib.filters.admin import RangeDateFilter
|
|||
@admin.register(SearchResult)
|
||||
class SearchResultAdmin(ModelAdmin):
|
||||
model = SearchResult
|
||||
search_fields = ('id', 'title', 'link')
|
||||
list_display = ['id', 'title', 'timestamp']
|
||||
search_fields = ("id", "title", "link")
|
||||
list_display = ["id", "title", "timestamp"]
|
||||
|
||||
list_filter_submit = True
|
||||
list_filter = ((
|
||||
"timestamp", RangeDateFilter
|
||||
), (
|
||||
"timestamp", RangeDateFilter
|
||||
),)
|
||||
list_filter = (
|
||||
("timestamp", RangeDateFilter),
|
||||
("timestamp", RangeDateFilter),
|
||||
)
|
||||
|
|
|
@ -2,5 +2,5 @@ from django.apps import AppConfig
|
|||
|
||||
|
||||
class SearchResultsConfig(AppConfig):
|
||||
default_auto_field = 'django.db.models.BigAutoField'
|
||||
name = 'search_results'
|
||||
default_auto_field = "django.db.models.BigAutoField"
|
||||
name = "search_results"
|
||||
|
|
|
@ -7,17 +7,24 @@ class Migration(migrations.Migration):
|
|||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
]
|
||||
dependencies = []
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='SearchResult',
|
||||
name="SearchResult",
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('title', models.CharField(max_length=1000)),
|
||||
('link', models.CharField(max_length=1000)),
|
||||
('timestamp', models.DateTimeField(auto_now_add=True)),
|
||||
(
|
||||
"id",
|
||||
models.BigAutoField(
|
||||
auto_created=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
verbose_name="ID",
|
||||
),
|
||||
),
|
||||
("title", models.CharField(max_length=1000)),
|
||||
("link", models.CharField(max_length=1000)),
|
||||
("timestamp", models.DateTimeField(auto_now_add=True)),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
|
|
@ -1,16 +1,13 @@
|
|||
|
||||
|
||||
from celery import shared_task
|
||||
from .models import SearchResult
|
||||
|
||||
|
||||
@shared_task(autoretry_for=(Exception,), retry_kwargs={'max_retries': 0, 'countdown': 5})
|
||||
@shared_task(
|
||||
autoretry_for=(Exception,), retry_kwargs={"max_retries": 0, "countdown": 5}
|
||||
)
|
||||
def create_search_result(title, link):
|
||||
if SearchResult.objects.filter(title=title, link=link).exists():
|
||||
return ("SearchResult entry already exists")
|
||||
return "SearchResult entry already exists"
|
||||
else:
|
||||
SearchResult.objects.create(
|
||||
title=title,
|
||||
link=link
|
||||
)
|
||||
SearchResult.objects.create(title=title, link=link)
|
||||
return f"Created new SearchResult entry titled: {title}"
|
||||
|
|
|
@ -6,10 +6,24 @@ from unfold.contrib.filters.admin import RangeDateFilter
|
|||
|
||||
@admin.register(StripePrice)
|
||||
class StripePriceAdmin(ModelAdmin):
|
||||
search_fields = ["id", "lookup_key",
|
||||
"stripe_price_id","price","currency", "prorated", "annual"]
|
||||
list_display = ["id", "lookup_key",
|
||||
"stripe_price_id", "price", "currency", "prorated", "annual"]
|
||||
search_fields = [
|
||||
"id",
|
||||
"lookup_key",
|
||||
"stripe_price_id",
|
||||
"price",
|
||||
"currency",
|
||||
"prorated",
|
||||
"annual",
|
||||
]
|
||||
list_display = [
|
||||
"id",
|
||||
"lookup_key",
|
||||
"stripe_price_id",
|
||||
"price",
|
||||
"currency",
|
||||
"prorated",
|
||||
"annual",
|
||||
]
|
||||
|
||||
|
||||
@admin.register(SubscriptionPlan)
|
||||
|
@ -21,9 +35,6 @@ class SubscriptionPlanAdmin(ModelAdmin):
|
|||
@admin.register(UserSubscription)
|
||||
class UserSubscriptionAdmin(ModelAdmin):
|
||||
list_filter_submit = True
|
||||
list_filter = ((
|
||||
"date", RangeDateFilter
|
||||
),)
|
||||
list_display = ["id", "__str__", "valid", "annual",
|
||||
"date"]
|
||||
list_filter = (("date", RangeDateFilter),)
|
||||
list_display = ["id", "__str__", "valid", "annual", "date"]
|
||||
search_fields = ["id", "date"]
|
||||
|
|
|
@ -2,8 +2,8 @@ from django.apps import AppConfig
|
|||
|
||||
|
||||
class SubscriptionConfig(AppConfig):
|
||||
default_auto_field = 'django.db.models.BigAutoField'
|
||||
name = 'subscriptions'
|
||||
default_auto_field = "django.db.models.BigAutoField"
|
||||
name = "subscriptions"
|
||||
|
||||
def ready(self):
|
||||
import subscriptions.signals
|
||||
|
|
|
@ -11,46 +11,118 @@ class Migration(migrations.Migration):
|
|||
initial = True
|
||||
|
||||
dependencies = [
|
||||
('user_groups', '0001_initial'),
|
||||
("user_groups", "0001_initial"),
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='StripePrice',
|
||||
name="StripePrice",
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('annual', models.BooleanField(default=False)),
|
||||
('stripe_price_id', models.CharField(max_length=100)),
|
||||
('price', models.DecimalField(decimal_places=2, default=0.0, max_digits=10)),
|
||||
('currency', models.CharField(max_length=20)),
|
||||
('lookup_key', models.CharField(blank=True, max_length=100, null=True)),
|
||||
('prorated', models.BooleanField(default=False)),
|
||||
(
|
||||
"id",
|
||||
models.BigAutoField(
|
||||
auto_created=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
verbose_name="ID",
|
||||
),
|
||||
),
|
||||
("annual", models.BooleanField(default=False)),
|
||||
("stripe_price_id", models.CharField(max_length=100)),
|
||||
(
|
||||
"price",
|
||||
models.DecimalField(decimal_places=2, default=0.0, max_digits=10),
|
||||
),
|
||||
("currency", models.CharField(max_length=20)),
|
||||
("lookup_key", models.CharField(blank=True, max_length=100, null=True)),
|
||||
("prorated", models.BooleanField(default=False)),
|
||||
],
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='SubscriptionPlan',
|
||||
name="SubscriptionPlan",
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('name', models.CharField(max_length=100)),
|
||||
('description', models.TextField(max_length=1024, null=True)),
|
||||
('stripe_product_id', models.CharField(max_length=100)),
|
||||
('group_exclusive', models.BooleanField(default=False)),
|
||||
('annual_price', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='annual_plan', to='subscriptions.stripeprice')),
|
||||
('monthly_price', models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='monthly_plan', to='subscriptions.stripeprice')),
|
||||
(
|
||||
"id",
|
||||
models.BigAutoField(
|
||||
auto_created=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
verbose_name="ID",
|
||||
),
|
||||
),
|
||||
("name", models.CharField(max_length=100)),
|
||||
("description", models.TextField(max_length=1024, null=True)),
|
||||
("stripe_product_id", models.CharField(max_length=100)),
|
||||
("group_exclusive", models.BooleanField(default=False)),
|
||||
(
|
||||
"annual_price",
|
||||
models.ForeignKey(
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
related_name="annual_plan",
|
||||
to="subscriptions.stripeprice",
|
||||
),
|
||||
),
|
||||
(
|
||||
"monthly_price",
|
||||
models.ForeignKey(
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
related_name="monthly_plan",
|
||||
to="subscriptions.stripeprice",
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name='UserSubscription',
|
||||
name="UserSubscription",
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('stripe_id', models.CharField(max_length=100)),
|
||||
('date', models.DateTimeField(default=django.utils.timezone.now, editable=False)),
|
||||
('valid', models.BooleanField()),
|
||||
('annual', models.BooleanField()),
|
||||
('subscription', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to='subscriptions.subscriptionplan')),
|
||||
('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||
('user_group', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to='user_groups.usergroup')),
|
||||
(
|
||||
"id",
|
||||
models.BigAutoField(
|
||||
auto_created=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
verbose_name="ID",
|
||||
),
|
||||
),
|
||||
("stripe_id", models.CharField(max_length=100)),
|
||||
(
|
||||
"date",
|
||||
models.DateTimeField(
|
||||
default=django.utils.timezone.now, editable=False
|
||||
),
|
||||
),
|
||||
("valid", models.BooleanField()),
|
||||
("annual", models.BooleanField()),
|
||||
(
|
||||
"subscription",
|
||||
models.ForeignKey(
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
to="subscriptions.subscriptionplan",
|
||||
),
|
||||
),
|
||||
(
|
||||
"user",
|
||||
models.ForeignKey(
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to=settings.AUTH_USER_MODEL,
|
||||
),
|
||||
),
|
||||
(
|
||||
"user_group",
|
||||
models.ForeignKey(
|
||||
blank=True,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="user_groups.usergroup",
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
from django.db import models
|
||||
from accounts.models import CustomUser
|
||||
from user_groups.models import UserGroup
|
||||
|
@ -25,9 +24,11 @@ class SubscriptionPlan(models.Model):
|
|||
description = models.TextField(max_length=1024, null=True)
|
||||
stripe_product_id = models.CharField(max_length=100)
|
||||
annual_price = models.ForeignKey(
|
||||
StripePrice, on_delete=models.SET_NULL, related_name='annual_plan', null=True)
|
||||
StripePrice, on_delete=models.SET_NULL, related_name="annual_plan", null=True
|
||||
)
|
||||
monthly_price = models.ForeignKey(
|
||||
StripePrice, on_delete=models.SET_NULL, related_name='monthly_plan', null=True)
|
||||
StripePrice, on_delete=models.SET_NULL, related_name="monthly_plan", null=True
|
||||
)
|
||||
group_exclusive = models.BooleanField(default=False)
|
||||
|
||||
def __str__(self):
|
||||
|
@ -39,11 +40,14 @@ class SubscriptionPlan(models.Model):
|
|||
|
||||
class UserSubscription(models.Model):
|
||||
user = models.ForeignKey(
|
||||
CustomUser, on_delete=models.CASCADE, blank=True, null=True)
|
||||
CustomUser, on_delete=models.CASCADE, blank=True, null=True
|
||||
)
|
||||
user_group = models.ForeignKey(
|
||||
UserGroup, on_delete=models.CASCADE, blank=True, null=True)
|
||||
UserGroup, on_delete=models.CASCADE, blank=True, null=True
|
||||
)
|
||||
subscription = models.ForeignKey(
|
||||
SubscriptionPlan, on_delete=models.SET_NULL, blank=True, null=True)
|
||||
SubscriptionPlan, on_delete=models.SET_NULL, blank=True, null=True
|
||||
)
|
||||
stripe_id = models.CharField(max_length=100)
|
||||
date = models.DateTimeField(default=now, editable=False)
|
||||
valid = models.BooleanField()
|
||||
|
@ -51,6 +55,6 @@ class UserSubscription(models.Model):
|
|||
|
||||
def __str__(self):
|
||||
if self.user:
|
||||
return f'Subscription {self.subscription.name} for {self.user}'
|
||||
return f"Subscription {self.subscription.name} for {self.user}"
|
||||
else:
|
||||
return f'Subscription {self.subscription.name} for {self.user_group}'
|
||||
return f"Subscription {self.subscription.name} for {self.user_group}"
|
||||
|
|
|
@ -7,38 +7,46 @@ class SimpleStripePriceSerializer(serializers.ModelSerializer):
|
|||
|
||||
class Meta:
|
||||
model = StripePrice
|
||||
fields = ['price', 'currency', 'prorated']
|
||||
fields = ["price", "currency", "prorated"]
|
||||
|
||||
|
||||
class SubscriptionPlanSerializer(serializers.ModelSerializer):
|
||||
|
||||
class Meta:
|
||||
model = SubscriptionPlan
|
||||
fields = ['id', 'name', 'description',
|
||||
'annual_price', 'monthly_price', 'group_exclusive']
|
||||
fields = [
|
||||
"id",
|
||||
"name",
|
||||
"description",
|
||||
"annual_price",
|
||||
"monthly_price",
|
||||
"group_exclusive",
|
||||
]
|
||||
|
||||
def to_representation(self, instance):
|
||||
representation = super().to_representation(instance)
|
||||
representation['annual_price'] = SimpleStripePriceSerializer(
|
||||
instance.annual_price, many=False).data
|
||||
representation['monthly_price'] = SimpleStripePriceSerializer(
|
||||
instance.monthly_price, many=False).data
|
||||
representation["annual_price"] = SimpleStripePriceSerializer(
|
||||
instance.annual_price, many=False
|
||||
).data
|
||||
representation["monthly_price"] = SimpleStripePriceSerializer(
|
||||
instance.monthly_price, many=False
|
||||
).data
|
||||
return representation
|
||||
|
||||
|
||||
class UserSubscriptionSerializer(serializers.ModelSerializer):
|
||||
date = serializers.DateTimeField(
|
||||
format="%m-%d-%Y %I:%M %p", read_only=True)
|
||||
date = serializers.DateTimeField(format="%m-%d-%Y %I:%M %p", read_only=True)
|
||||
|
||||
class Meta:
|
||||
model = UserSubscription
|
||||
fields = ['id', 'user', 'user_group', 'subscription',
|
||||
'date', 'valid', 'annual']
|
||||
fields = ["id", "user", "user_group", "subscription", "date", "valid", "annual"]
|
||||
|
||||
def to_representation(self, instance):
|
||||
representation = super().to_representation(instance)
|
||||
representation['user'] = SimpleCustomUserSerializer(
|
||||
instance.user, many=False).data
|
||||
representation['subscription'] = SubscriptionPlanSerializer(
|
||||
instance.subscription, many=False).data
|
||||
representation["user"] = SimpleCustomUserSerializer(
|
||||
instance.user, many=False
|
||||
).data
|
||||
representation["subscription"] = SubscriptionPlanSerializer(
|
||||
instance.subscription, many=False
|
||||
).data
|
||||
return representation
|
||||
|
|
|
@ -4,6 +4,7 @@ from .models import UserSubscription, StripePrice, SubscriptionPlan
|
|||
from django.core.cache import cache
|
||||
from config.settings import STRIPE_SECRET_KEY
|
||||
import stripe
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
|
||||
# Template for running actions after user have paid for a subscription
|
||||
|
@ -12,7 +13,7 @@ stripe.api_key = STRIPE_SECRET_KEY
|
|||
@receiver(post_save, sender=SubscriptionPlan)
|
||||
def clear_cache_after_plan_updates(sender, instance, **kwargs):
|
||||
# Clear cache
|
||||
cache.delete('subscriptionplans')
|
||||
cache.delete("subscriptionplans")
|
||||
|
||||
|
||||
@receiver(post_save, sender=UserSubscription)
|
||||
|
@ -25,8 +26,8 @@ def scan_after_payment(sender, instance, **kwargs):
|
|||
|
||||
@receiver(post_migrate)
|
||||
def create_subscriptions(sender, **kwargs):
|
||||
if sender.name == 'subscriptions':
|
||||
print('Importing data from Stripe')
|
||||
if sender.name == "subscriptions":
|
||||
print("Importing data from Stripe")
|
||||
created_prices = 0
|
||||
created_plans = 0
|
||||
skipped_prices = 0
|
||||
|
@ -35,16 +36,19 @@ def create_subscriptions(sender, **kwargs):
|
|||
prices = stripe.Price.list(expand=["data.tiers"], active=True)
|
||||
|
||||
# Create the StripePrice
|
||||
for price in prices['data']:
|
||||
annual = (price['recurring']['interval'] ==
|
||||
'year') if price['recurring'] else False
|
||||
for price in prices["data"]:
|
||||
annual = (
|
||||
(price["recurring"]["interval"] == "year")
|
||||
if price["recurring"]
|
||||
else False
|
||||
)
|
||||
STRIPE_PRICE, CREATED = StripePrice.objects.get_or_create(
|
||||
stripe_price_id=price['id'],
|
||||
price=price['unit_amount'] / 100,
|
||||
stripe_price_id=price["id"],
|
||||
price=price["unit_amount"] / 100,
|
||||
annual=annual,
|
||||
lookup_key=price['lookup_key'],
|
||||
prorated=price['recurring']['usage_type'] == 'metered',
|
||||
currency=price['currency']
|
||||
lookup_key=price["lookup_key"],
|
||||
prorated=price["recurring"]["usage_type"] == "metered",
|
||||
currency=price["currency"],
|
||||
)
|
||||
if CREATED:
|
||||
created_prices += 1
|
||||
|
@ -52,13 +56,13 @@ def create_subscriptions(sender, **kwargs):
|
|||
skipped_prices += 1
|
||||
|
||||
# Create the SubscriptionPlan
|
||||
for product in products['data']:
|
||||
for product in products["data"]:
|
||||
ANNUAL_PRICE = None
|
||||
MONTHLY_PRICE = None
|
||||
for price in prices['data']:
|
||||
if price['product'] == product['id']:
|
||||
for price in prices["data"]:
|
||||
if price["product"] == product["id"]:
|
||||
STRIPE_PRICE = StripePrice.objects.get(
|
||||
stripe_price_id=price['id'],
|
||||
stripe_price_id=price["id"],
|
||||
)
|
||||
if STRIPE_PRICE.annual:
|
||||
ANNUAL_PRICE = STRIPE_PRICE
|
||||
|
@ -66,12 +70,12 @@ def create_subscriptions(sender, **kwargs):
|
|||
MONTHLY_PRICE = STRIPE_PRICE
|
||||
if ANNUAL_PRICE or MONTHLY_PRICE:
|
||||
SUBSCRIPTION_PLAN, CREATED = SubscriptionPlan.objects.get_or_create(
|
||||
name=product['name'],
|
||||
description=product['description'],
|
||||
stripe_product_id=product['id'],
|
||||
name=product["name"],
|
||||
description=product["description"],
|
||||
stripe_product_id=product["id"],
|
||||
annual_price=ANNUAL_PRICE,
|
||||
monthly_price=MONTHLY_PRICE,
|
||||
group_exclusive=product['metadata']['group_exclusive'] == 'True'
|
||||
group_exclusive=product["metadata"]["group_exclusive"] == "True",
|
||||
)
|
||||
if CREATED:
|
||||
created_plans += 1
|
||||
|
@ -79,13 +83,12 @@ def create_subscriptions(sender, **kwargs):
|
|||
skipped_plans += 1
|
||||
# Skip over plans with missing pricing rates
|
||||
else:
|
||||
print('Skipping plan' +
|
||||
product['name'] + 'with missing pricing data')
|
||||
print("Skipping plan" + product["name"] + "with missing pricing data")
|
||||
|
||||
# Assign the StripePrice to the SubscriptionPlan
|
||||
SUBSCRIPTION_PLAN.save()
|
||||
|
||||
print('Created', created_plans, 'new plans')
|
||||
print('Skipped', skipped_plans, 'existing plans')
|
||||
print('Created', created_prices, 'new prices')
|
||||
print('Skipped', skipped_prices, 'existing prices')
|
||||
print("Created", created_plans, "new plans")
|
||||
print("Skipped", skipped_plans, "existing plans")
|
||||
print("Created", created_prices, "new prices")
|
||||
print("Skipped", skipped_prices, "existing prices")
|
||||
|
|
|
@ -12,10 +12,10 @@ def get_user_subscription(user_id):
|
|||
active_subscriptions = None
|
||||
if USER.user_group:
|
||||
active_subscriptions = UserSubscription.objects.filter(
|
||||
user_group=USER.user_group, valid=True)
|
||||
user_group=USER.user_group, valid=True
|
||||
)
|
||||
else:
|
||||
active_subscriptions = UserSubscription.objects.filter(
|
||||
user=USER, valid=True)
|
||||
active_subscriptions = UserSubscription.objects.filter(user=USER, valid=True)
|
||||
|
||||
# Return first valid subscription if there is one
|
||||
if len(active_subscriptions) > 0:
|
||||
|
@ -33,7 +33,8 @@ def get_user_group_subscription(user_group):
|
|||
# Get a list of subscriptions for the specified user
|
||||
active_subscriptions = None
|
||||
active_subscriptions = UserSubscription.objects.filter(
|
||||
user_group=USER_GROUP, valid=True)
|
||||
user_group=USER_GROUP, valid=True
|
||||
)
|
||||
|
||||
# Return first valid subscription if there is one
|
||||
if len(active_subscriptions) > 0:
|
||||
|
|
|
@ -3,12 +3,11 @@ from subscriptions import views
|
|||
from rest_framework.routers import DefaultRouter
|
||||
|
||||
router = DefaultRouter()
|
||||
router.register(r'plans', views.SubscriptionPlanViewset,
|
||||
basename="Subscription Plans")
|
||||
router.register(r'self', views.UserSubscriptionViewset,
|
||||
basename="Self Subscriptions")
|
||||
router.register(r'user_group', views.UserGroupSubscriptionViewet,
|
||||
basename="Group Subscriptions")
|
||||
router.register(r"plans", views.SubscriptionPlanViewset, basename="Subscription Plans")
|
||||
router.register(r"self", views.UserSubscriptionViewset, basename="Self Subscriptions")
|
||||
router.register(
|
||||
r"user_group", views.UserGroupSubscriptionViewet, basename="Group Subscriptions"
|
||||
)
|
||||
urlpatterns = [
|
||||
path('', include(router.urls)),
|
||||
path("", include(router.urls)),
|
||||
]
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
from subscriptions.serializers import SubscriptionPlanSerializer, UserSubscriptionSerializer
|
||||
from subscriptions.serializers import (
|
||||
SubscriptionPlanSerializer,
|
||||
UserSubscriptionSerializer,
|
||||
)
|
||||
from subscriptions.models import SubscriptionPlan, UserSubscription
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
from rest_framework import viewsets
|
||||
|
@ -6,38 +9,38 @@ from django.core.cache import cache
|
|||
|
||||
|
||||
class SubscriptionPlanViewset(viewsets.ModelViewSet):
|
||||
http_method_names = ['get']
|
||||
http_method_names = ["get"]
|
||||
serializer_class = SubscriptionPlanSerializer
|
||||
permission_classes = [AllowAny]
|
||||
queryset = SubscriptionPlan.objects.all()
|
||||
|
||||
def get_queryset(self):
|
||||
key = 'subscriptionplans'
|
||||
key = "subscriptionplans"
|
||||
queryset = cache.get(key)
|
||||
if not queryset:
|
||||
queryset = super().get_queryset()
|
||||
cache.set(key, queryset, 60*60)
|
||||
cache.set(key, queryset, 60 * 60)
|
||||
return queryset
|
||||
|
||||
|
||||
class UserSubscriptionViewset(viewsets.ModelViewSet):
|
||||
http_method_names = ['get']
|
||||
http_method_names = ["get"]
|
||||
serializer_class = UserSubscriptionSerializer
|
||||
permission_classes = [IsAuthenticated]
|
||||
queryset = UserSubscription.objects.all()
|
||||
|
||||
def get_queryset(self):
|
||||
user = self.request.user
|
||||
key = f'subscriptions_user:{user.id}'
|
||||
key = f"subscriptions_user:{user.id}"
|
||||
queryset = cache.get(key)
|
||||
if not queryset:
|
||||
queryset = UserSubscription.objects.filter(user=user)
|
||||
cache.set(key, queryset, 60*60)
|
||||
cache.set(key, queryset, 60 * 60)
|
||||
return queryset
|
||||
|
||||
|
||||
class UserGroupSubscriptionViewet(viewsets.ModelViewSet):
|
||||
http_method_names = ['get']
|
||||
http_method_names = ["get"]
|
||||
serializer_class = UserSubscriptionSerializer
|
||||
permission_classes = [IsAuthenticated]
|
||||
queryset = UserSubscription.objects.all()
|
||||
|
@ -47,10 +50,9 @@ class UserGroupSubscriptionViewet(viewsets.ModelViewSet):
|
|||
if not user.user_group:
|
||||
return UserSubscription.objects.none()
|
||||
else:
|
||||
key = f'subscriptions_usergroup:{user.user_group.id}'
|
||||
key = f"subscriptions_usergroup:{user.user_group.id}"
|
||||
queryset = cache.get(key)
|
||||
if not cache:
|
||||
queryset = UserSubscription.objects.filter(
|
||||
user_group=user.user_group)
|
||||
cache.set(key, queryset, 60*60)
|
||||
queryset = UserSubscription.objects.filter(user_group=user.user_group)
|
||||
cache.set(key, queryset, 60 * 60)
|
||||
return queryset
|
||||
|
|
|
@ -7,9 +7,7 @@ from unfold.contrib.filters.admin import RangeDateFilter
|
|||
@admin.register(UserGroup)
|
||||
class UserGroupAdmin(ModelAdmin):
|
||||
list_filter_submit = True
|
||||
list_filter = ((
|
||||
"date_created", RangeDateFilter
|
||||
),)
|
||||
list_filter = (("date_created", RangeDateFilter),)
|
||||
|
||||
list_display = ['id', 'name']
|
||||
search_fields = ['id', 'name']
|
||||
list_display = ["id", "name"]
|
||||
search_fields = ["id", "name"]
|
||||
|
|
|
@ -8,16 +8,28 @@ class Migration(migrations.Migration):
|
|||
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
]
|
||||
dependencies = []
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='UserGroup',
|
||||
name="UserGroup",
|
||||
fields=[
|
||||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
|
||||
('name', models.CharField(max_length=128)),
|
||||
('date_created', models.DateTimeField(default=django.utils.timezone.now, editable=False)),
|
||||
(
|
||||
"id",
|
||||
models.BigAutoField(
|
||||
auto_created=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
verbose_name="ID",
|
||||
),
|
||||
),
|
||||
("name", models.CharField(max_length=128)),
|
||||
(
|
||||
"date_created",
|
||||
models.DateTimeField(
|
||||
default=django.utils.timezone.now, editable=False
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
|
|
@ -8,24 +8,33 @@ from django.db import migrations, models
|
|||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('user_groups', '0001_initial'),
|
||||
("user_groups", "0001_initial"),
|
||||
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='usergroup',
|
||||
name='managers',
|
||||
field=models.ManyToManyField(related_name='usergroup_managers', to=settings.AUTH_USER_MODEL),
|
||||
model_name="usergroup",
|
||||
name="managers",
|
||||
field=models.ManyToManyField(
|
||||
related_name="usergroup_managers", to=settings.AUTH_USER_MODEL
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='usergroup',
|
||||
name='members',
|
||||
field=models.ManyToManyField(related_name='usergroup_members', to=settings.AUTH_USER_MODEL),
|
||||
model_name="usergroup",
|
||||
name="members",
|
||||
field=models.ManyToManyField(
|
||||
related_name="usergroup_members", to=settings.AUTH_USER_MODEL
|
||||
),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name='usergroup',
|
||||
name='owner',
|
||||
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='usergroup_owner', to=settings.AUTH_USER_MODEL),
|
||||
model_name="usergroup",
|
||||
name="owner",
|
||||
field=models.ForeignKey(
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
related_name="usergroup_owner",
|
||||
to=settings.AUTH_USER_MODEL,
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -2,17 +2,24 @@ from django.db import models
|
|||
from django.utils.timezone import now
|
||||
from config.settings import STRIPE_SECRET_KEY
|
||||
import stripe
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
|
||||
|
||||
class UserGroup(models.Model):
|
||||
name = models.CharField(max_length=128, null=False)
|
||||
owner = models.ForeignKey(
|
||||
'accounts.CustomUser', on_delete=models.SET_NULL, null=True, related_name='usergroup_owner')
|
||||
"accounts.CustomUser",
|
||||
on_delete=models.SET_NULL,
|
||||
null=True,
|
||||
related_name="usergroup_owner",
|
||||
)
|
||||
managers = models.ManyToManyField(
|
||||
'accounts.CustomUser', related_name='usergroup_managers')
|
||||
"accounts.CustomUser", related_name="usergroup_managers"
|
||||
)
|
||||
members = models.ManyToManyField(
|
||||
'accounts.CustomUser', related_name='usergroup_members')
|
||||
"accounts.CustomUser", related_name="usergroup_members"
|
||||
)
|
||||
date_created = models.DateTimeField(default=now, editable=False)
|
||||
|
||||
# Derived from email of owner, may be used for billing
|
||||
|
|
|
@ -3,10 +3,9 @@ from .models import UserGroup
|
|||
|
||||
|
||||
class SimpleUserGroupSerializer(serializers.ModelSerializer):
|
||||
date_created = serializers.DateTimeField(
|
||||
format="%m-%d-%Y %I:%M %p", read_only=True)
|
||||
date_created = serializers.DateTimeField(format="%m-%d-%Y %I:%M %p", read_only=True)
|
||||
|
||||
class Meta:
|
||||
model = UserGroup
|
||||
fields = ['id', 'name', 'date_created']
|
||||
read_only_fields = ['id', 'name', 'date_created']
|
||||
fields = ["id", "name", "date_created"]
|
||||
read_only_fields = ["id", "name", "date_created"]
|
||||
|
|
|
@ -8,15 +8,16 @@ from config.settings import STRIPE_SECRET_KEY, ROOT_DIR
|
|||
import os
|
||||
import json
|
||||
import stripe
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
|
||||
|
||||
@receiver(m2m_changed, sender=UserGroup.managers.through)
|
||||
def update_group_managers(sender, instance, action, **kwargs):
|
||||
# When adding new managers to a UserGroup, associate them with it
|
||||
if action == 'post_add':
|
||||
if action == "post_add":
|
||||
# Get the newly added managers
|
||||
new_managers = kwargs.get('pk_set', set())
|
||||
new_managers = kwargs.get("pk_set", set())
|
||||
for manager in new_managers:
|
||||
# Retrieve the member
|
||||
USER = CustomUser.objects.get(pk=manager)
|
||||
|
@ -27,8 +28,8 @@ def update_group_managers(sender, instance, action, **kwargs):
|
|||
if USER not in instance.members.all():
|
||||
instance.members.add(USER)
|
||||
# When removing managers from a UserGroup, remove their association with it
|
||||
elif action == 'post_remove':
|
||||
for manager in kwargs['pk_set']:
|
||||
elif action == "post_remove":
|
||||
for manager in kwargs["pk_set"]:
|
||||
# Retrieve the manager
|
||||
USER = CustomUser.objects.get(pk=manager)
|
||||
if USER not in instance.members.all():
|
||||
|
@ -39,9 +40,9 @@ def update_group_managers(sender, instance, action, **kwargs):
|
|||
@receiver(m2m_changed, sender=UserGroup.members.through)
|
||||
def update_group_members(sender, instance, action, **kwargs):
|
||||
# When adding new members to a UserGroup, associate them with it
|
||||
if action == 'post_add':
|
||||
if action == "post_add":
|
||||
# Get the newly added members
|
||||
new_members = kwargs.get('pk_set', set())
|
||||
new_members = kwargs.get("pk_set", set())
|
||||
for member in new_members:
|
||||
# Retrieve the member
|
||||
USER = CustomUser.objects.get(pk=member)
|
||||
|
@ -50,10 +51,13 @@ def update_group_members(sender, instance, action, **kwargs):
|
|||
USER.user_group = instance
|
||||
USER.save()
|
||||
# When removing members from a UserGroup, remove their association with it
|
||||
elif action == 'post_remove':
|
||||
for client in kwargs['pk_set']:
|
||||
elif action == "post_remove":
|
||||
for client in kwargs["pk_set"]:
|
||||
USER = CustomUser.objects.get(pk=client)
|
||||
if USER not in instance.members.all() and USER not in instance.managers.all():
|
||||
if (
|
||||
USER not in instance.members.all()
|
||||
and USER not in instance.managers.all()
|
||||
):
|
||||
USER.user_group = None
|
||||
USER.save()
|
||||
# Update usage records
|
||||
|
@ -66,42 +70,42 @@ def update_group_members(sender, instance, action, **kwargs):
|
|||
stripe.SubscriptionItem.create_usage_record(
|
||||
SUBSCRIPTION_ITEM.stripe_id,
|
||||
quantity=len(instance.members.all()),
|
||||
action="set"
|
||||
action="set",
|
||||
)
|
||||
except:
|
||||
print(
|
||||
f'Warning: Unable to update usage record for SubscriptionGroup ID:{instance.id}')
|
||||
f"Warning: Unable to update usage record for SubscriptionGroup ID:{instance.id}"
|
||||
)
|
||||
|
||||
|
||||
@receiver(post_migrate)
|
||||
def create_groups(sender, **kwargs):
|
||||
if sender.name == "agencies":
|
||||
with open(os.path.join(ROOT_DIR, 'seed_data.json'), "r") as f:
|
||||
with open(os.path.join(ROOT_DIR, "seed_data.json"), "r") as f:
|
||||
seed_data = json.loads(f.read())
|
||||
for user_group in seed_data['user_groups']:
|
||||
OWNER = CustomUser.objects.filter(
|
||||
email=user_group['owner']).first()
|
||||
for user_group in seed_data["user_groups"]:
|
||||
OWNER = CustomUser.objects.filter(email=user_group["owner"]).first()
|
||||
USER_GROUP, CREATED = UserGroup.objects.get_or_create(
|
||||
owner=OWNER,
|
||||
agency_name=user_group['name'],
|
||||
agency_name=user_group["name"],
|
||||
)
|
||||
if CREATED:
|
||||
print(f"Created UserGroup {USER_GROUP.agency_name}")
|
||||
|
||||
# Add managers
|
||||
USERS = CustomUser.objects.filter(
|
||||
email__in=user_group['managers'])
|
||||
USERS = CustomUser.objects.filter(email__in=user_group["managers"])
|
||||
for USER in USERS:
|
||||
if USER not in USER_GROUP.managers.all():
|
||||
print(
|
||||
f"Adding User {USER.full_name} as manager to UserGroup {USER_GROUP.agency_name}")
|
||||
f"Adding User {USER.full_name} as manager to UserGroup {USER_GROUP.agency_name}"
|
||||
)
|
||||
USER_GROUP.managers.add(USER)
|
||||
# Add members
|
||||
USERS = CustomUser.objects.filter(
|
||||
email__in=user_group['members'])
|
||||
USERS = CustomUser.objects.filter(email__in=user_group["members"])
|
||||
for USER in USERS:
|
||||
if USER not in USER_GROUP.members.all():
|
||||
print(
|
||||
f"Adding User {USER.full_name} as member to UserGroup {USER_GROUP.agency_name}")
|
||||
f"Adding User {USER.full_name} as member to UserGroup {USER_GROUP.agency_name}"
|
||||
)
|
||||
USER_GROUP.clients.add(USER)
|
||||
USER_GROUP.save()
|
||||
|
|
|
@ -2,5 +2,5 @@ from django.apps import AppConfig
|
|||
|
||||
|
||||
class EmailsConfig(AppConfig):
|
||||
default_auto_field = 'django.db.models.BigAutoField'
|
||||
name = 'webdriver'
|
||||
default_auto_field = "django.db.models.BigAutoField"
|
||||
name = "webdriver"
|
||||
|
|
|
@ -1,11 +1,19 @@
|
|||
from celery import shared_task
|
||||
from webdriver.utils import setup_webdriver, selenium_action_template, google_search, get_element, get_elements
|
||||
from webdriver.utils import (
|
||||
setup_webdriver,
|
||||
selenium_action_template,
|
||||
google_search,
|
||||
get_element,
|
||||
get_elements,
|
||||
)
|
||||
from selenium.webdriver.common.by import By
|
||||
from search_results.tasks import create_search_result
|
||||
|
||||
|
||||
# Task template
|
||||
@shared_task(autoretry_for=(Exception,), retry_kwargs={'max_retries': 3, 'countdown': 5})
|
||||
@shared_task(
|
||||
autoretry_for=(Exception,), retry_kwargs={"max_retries": 3, "countdown": 5}
|
||||
)
|
||||
def sample_selenium_task():
|
||||
|
||||
driver = setup_webdriver(use_proxy=False, use_saved_session=False)
|
||||
|
@ -18,27 +26,29 @@ def sample_selenium_task():
|
|||
driver.close()
|
||||
driver.quit()
|
||||
|
||||
|
||||
# Sample task to scrape Google for search results based on a keyword
|
||||
|
||||
|
||||
@shared_task(autoretry_for=(Exception,), retry_kwargs={'max_retries': 3, 'countdown': 5})
|
||||
@shared_task(
|
||||
autoretry_for=(Exception,), retry_kwargs={"max_retries": 3, "countdown": 5}
|
||||
)
|
||||
def simple_google_search():
|
||||
driver = setup_webdriver(driver_type="firefox",
|
||||
use_proxy=False, use_saved_session=False)
|
||||
driver = setup_webdriver(
|
||||
driver_type="firefox", use_proxy=False, use_saved_session=False
|
||||
)
|
||||
driver.get(f"https://google.com/")
|
||||
|
||||
google_search(driver, search_term="cat blog posts")
|
||||
|
||||
# Count number of Google search results
|
||||
search_items = get_elements(
|
||||
driver, "xpath", '//*[@id="search"]/div[1]/div[1]/*')
|
||||
search_items = get_elements(driver, "xpath", '//*[@id="search"]/div[1]/div[1]/*')
|
||||
|
||||
for item in search_items:
|
||||
title = item.find_element(By.TAG_NAME, 'h3').text
|
||||
link = item.find_element(By.TAG_NAME, 'a').get_attribute('href')
|
||||
title = item.find_element(By.TAG_NAME, "h3").text
|
||||
link = item.find_element(By.TAG_NAME, "a").get_attribute("href")
|
||||
|
||||
create_search_result.apply_async(
|
||||
kwargs={"title": title, "link": link})
|
||||
create_search_result.apply_async(kwargs={"title": title, "link": link})
|
||||
|
||||
driver.close()
|
||||
driver.quit()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Settings file to hold constants and functions
|
||||
"""
|
||||
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.common.keys import Keys
|
||||
from config.settings import get_secret
|
||||
|
@ -18,24 +19,26 @@ import os
|
|||
import random
|
||||
|
||||
|
||||
def take_snapshot(driver, filename='dump.png'):
|
||||
def take_snapshot(driver, filename="dump.png"):
|
||||
# Set window size
|
||||
required_width = driver.execute_script(
|
||||
'return document.body.parentNode.scrollWidth')
|
||||
"return document.body.parentNode.scrollWidth"
|
||||
)
|
||||
required_height = driver.execute_script(
|
||||
'return document.body.parentNode.scrollHeight')
|
||||
driver.set_window_size(
|
||||
required_width, required_height+(required_height*0.05))
|
||||
"return document.body.parentNode.scrollHeight"
|
||||
)
|
||||
driver.set_window_size(required_width, required_height + (required_height * 0.05))
|
||||
|
||||
# Take the snapshot
|
||||
driver.find_element(By.TAG_NAME,
|
||||
'body').screenshot('/dumps/'+filename) # avoids any scrollbars
|
||||
print('Snapshot saved')
|
||||
driver.find_element(By.TAG_NAME, "body").screenshot(
|
||||
"/dumps/" + filename
|
||||
) # avoids any scrollbars
|
||||
print("Snapshot saved")
|
||||
|
||||
|
||||
def dump_html(driver, filename='dump.html'):
|
||||
def dump_html(driver, filename="dump.html"):
|
||||
# Save the page source to error.html
|
||||
with open(('/dumps/'+filename), 'w', encoding='utf-8') as file:
|
||||
with open(("/dumps/" + filename), "w", encoding="utf-8") as file:
|
||||
file.write(driver.page_source)
|
||||
|
||||
|
||||
|
@ -44,83 +47,83 @@ def setup_webdriver(driver_type="chrome", use_proxy=True, use_saved_session=Fals
|
|||
if not USE_PROXY:
|
||||
use_proxy = False
|
||||
if use_proxy:
|
||||
print('Running driver with proxy enabled')
|
||||
print("Running driver with proxy enabled")
|
||||
else:
|
||||
print('Running driver with proxy disabled')
|
||||
print("Running driver with proxy disabled")
|
||||
|
||||
if use_saved_session:
|
||||
print('Running with saved session')
|
||||
print("Running with saved session")
|
||||
else:
|
||||
print('Running without using saved session')
|
||||
print("Running without using saved session")
|
||||
|
||||
if driver_type == "chrome":
|
||||
print('Using Chrome driver')
|
||||
print("Using Chrome driver")
|
||||
opts = uc.ChromeOptions()
|
||||
|
||||
if use_saved_session:
|
||||
if os.path.exists("/tmp_chrome_profile"):
|
||||
print('Existing Chrome ephemeral profile found')
|
||||
print("Existing Chrome ephemeral profile found")
|
||||
else:
|
||||
print('No existing Chrome ephemeral profile found')
|
||||
print("No existing Chrome ephemeral profile found")
|
||||
os.system("mkdir /tmp_chrome_profile")
|
||||
if os.path.exists('/chrome'):
|
||||
print('Copying Chrome Profile to ephemeral directory')
|
||||
if os.path.exists("/chrome"):
|
||||
print("Copying Chrome Profile to ephemeral directory")
|
||||
# Flush any non-essential cache directories from the existing profile as they may balloon in size overtime
|
||||
os.system(
|
||||
'rm -rf "/chrome/Selenium Profile/Code Cache/*"')
|
||||
os.system('rm -rf "/chrome/Selenium Profile/Code Cache/*"')
|
||||
# Create a copy of the Chrome Profile
|
||||
os.system("cp -r /chrome/* /tmp_chrome_profile")
|
||||
try:
|
||||
# Remove some items related to file locks
|
||||
os.remove('/tmp_chrome_profile/SingletonLock')
|
||||
os.remove('/tmp_chrome_profile/SingletonSocket')
|
||||
os.remove('/tmp_chrome_profile/SingletonLock')
|
||||
os.remove("/tmp_chrome_profile/SingletonLock")
|
||||
os.remove("/tmp_chrome_profile/SingletonSocket")
|
||||
os.remove("/tmp_chrome_profile/SingletonLock")
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
print('No existing Chrome Profile found. Creating one from scratch')
|
||||
print("No existing Chrome Profile found. Creating one from scratch")
|
||||
|
||||
if use_saved_session:
|
||||
# Specify the user data directory
|
||||
opts.add_argument(f'--user-data-dir=/tmp_chrome_profile')
|
||||
opts.add_argument('--profile-directory=Selenium Profile')
|
||||
opts.add_argument(f"--user-data-dir=/tmp_chrome_profile")
|
||||
opts.add_argument("--profile-directory=Selenium Profile")
|
||||
|
||||
# Set proxy
|
||||
if use_proxy:
|
||||
opts.add_argument(
|
||||
f'--proxy-server=socks5://{get_secret("PROXY_IP")}:{get_secret("PROXY_PORT_IP_AUTH")}')
|
||||
f'--proxy-server=socks5://{get_secret("PROXY_IP")}:{get_secret("PROXY_PORT_IP_AUTH")}'
|
||||
)
|
||||
|
||||
opts.add_argument("--disable-extensions")
|
||||
opts.add_argument('--disable-application-cache')
|
||||
opts.add_argument("--disable-application-cache")
|
||||
opts.add_argument("--disable-setuid-sandbox")
|
||||
opts.add_argument('--disable-dev-shm-usage')
|
||||
opts.add_argument("--disable-dev-shm-usage")
|
||||
opts.add_argument("--disable-gpu")
|
||||
opts.add_argument("--no-sandbox")
|
||||
opts.add_argument("--headless=new")
|
||||
driver = uc.Chrome(options=opts)
|
||||
|
||||
elif driver_type == "firefox":
|
||||
print('Using firefox driver')
|
||||
print("Using firefox driver")
|
||||
opts = FirefoxOptions()
|
||||
if use_saved_session:
|
||||
if not os.path.exists("/firefox"):
|
||||
print('No profile found')
|
||||
print("No profile found")
|
||||
os.makedirs("/firefox")
|
||||
else:
|
||||
print('Existing profile found')
|
||||
print("Existing profile found")
|
||||
# Specify a profile if it exists
|
||||
opts.profile = "/firefox"
|
||||
|
||||
# Set proxy
|
||||
if use_proxy:
|
||||
opts.set_preference('network.proxy.type', 1)
|
||||
opts.set_preference('network.proxy.socks',
|
||||
get_secret('PROXY_IP'))
|
||||
opts.set_preference('network.proxy.socks_port',
|
||||
int(get_secret('PROXY_PORT_IP_AUTH')))
|
||||
opts.set_preference('network.proxy.socks_remote_dns', False)
|
||||
opts.set_preference("network.proxy.type", 1)
|
||||
opts.set_preference("network.proxy.socks", get_secret("PROXY_IP"))
|
||||
opts.set_preference(
|
||||
"network.proxy.socks_port", int(get_secret("PROXY_PORT_IP_AUTH"))
|
||||
)
|
||||
opts.set_preference("network.proxy.socks_remote_dns", False)
|
||||
|
||||
opts.add_argument('--disable-dev-shm-usage')
|
||||
opts.add_argument("--disable-dev-shm-usage")
|
||||
opts.add_argument("--headless")
|
||||
opts.add_argument("--disable-gpu")
|
||||
driver = webdriver.Firefox(options=opts)
|
||||
|
@ -128,13 +131,15 @@ def setup_webdriver(driver_type="chrome", use_proxy=True, use_saved_session=Fals
|
|||
driver.maximize_window()
|
||||
|
||||
# Check if proxy is working
|
||||
driver.get('https://api.ipify.org/')
|
||||
driver.get("https://api.ipify.org/")
|
||||
body = WebDriverWait(driver, 10).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body")))
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
ip_address = body.text
|
||||
print(f'External IP: {ip_address}')
|
||||
print(f"External IP: {ip_address}")
|
||||
return driver
|
||||
|
||||
|
||||
# These are wrapper function for quickly automating multiple steps in webscraping (logins, button presses, text inputs, etc.)
|
||||
# Depending on your use case, you may have to opt out of using this
|
||||
|
||||
|
@ -151,10 +156,11 @@ def get_element(driver, by, key, hidden_element=False, timeout=8):
|
|||
wait = WebDriverWait(driver, timeout=timeout)
|
||||
if not hidden_element:
|
||||
element = wait.until(
|
||||
EC.element_to_be_clickable((by, key)) and EC.visibility_of_element_located((by, key)))
|
||||
EC.element_to_be_clickable((by, key))
|
||||
and EC.visibility_of_element_located((by, key))
|
||||
)
|
||||
else:
|
||||
element = wait.until(EC.presence_of_element_located(
|
||||
(by, key)))
|
||||
element = wait.until(EC.presence_of_element_located((by, key)))
|
||||
return element
|
||||
except Exception:
|
||||
dump_html(driver)
|
||||
|
@ -173,13 +179,12 @@ def get_elements(driver, by, key, hidden_element=False, timeout=8):
|
|||
wait = WebDriverWait(driver, timeout=timeout)
|
||||
|
||||
if hidden_element:
|
||||
elements = wait.until(
|
||||
EC.presence_of_all_elements_located((by, key)))
|
||||
elements = wait.until(EC.presence_of_all_elements_located((by, key)))
|
||||
else:
|
||||
visible_elements = wait.until(
|
||||
EC.visibility_of_any_elements_located((by, key)))
|
||||
elements = [
|
||||
element for element in visible_elements if element.is_enabled()]
|
||||
EC.visibility_of_any_elements_located((by, key))
|
||||
)
|
||||
elements = [element for element in visible_elements if element.is_enabled()]
|
||||
|
||||
return elements
|
||||
except Exception:
|
||||
|
@ -193,17 +198,22 @@ def get_elements(driver, by, key, hidden_element=False, timeout=8):
|
|||
def execute_selenium_elements(driver, timeout, elements):
|
||||
try:
|
||||
for index, element in enumerate(elements):
|
||||
print('Waiting...')
|
||||
print("Waiting...")
|
||||
# Element may have a keyword specified, check if that exists before running any actions
|
||||
if "keyword" in element:
|
||||
# Skip a step if the keyword does not exist
|
||||
if element['keyword'] not in driver.page_source:
|
||||
if element["keyword"] not in driver.page_source:
|
||||
print(
|
||||
f'Keyword {element["keyword"]} does not exist. Skipping step: {index+1} - {element["name"]}')
|
||||
f'Keyword {element["keyword"]} does not exist. Skipping step: {index+1} - {element["name"]}'
|
||||
)
|
||||
continue
|
||||
elif element['keyword'] in driver.page_source and element['type'] == 'skip':
|
||||
elif (
|
||||
element["keyword"] in driver.page_source
|
||||
and element["type"] == "skip"
|
||||
):
|
||||
print(
|
||||
f'Keyword {element["keyword"]} does exists. Stopping at step: {index+1} - {element["name"]}')
|
||||
f'Keyword {element["keyword"]} does exists. Stopping at step: {index+1} - {element["name"]}'
|
||||
)
|
||||
break
|
||||
print(f'Step: {index+1} - {element["name"]}')
|
||||
# Revert to default iframe action
|
||||
|
@ -217,31 +227,47 @@ def execute_selenium_elements(driver, timeout, elements):
|
|||
else:
|
||||
values = element["input"]
|
||||
if type(values) is list:
|
||||
raise Exception(
|
||||
'Invalid input value specified for "callback" type')
|
||||
raise Exception('Invalid input value specified for "callback" type')
|
||||
else:
|
||||
# For single input values
|
||||
driver.execute_script(
|
||||
f'onRecaptcha("{values}");')
|
||||
driver.execute_script(f'onRecaptcha("{values}");')
|
||||
continue
|
||||
try:
|
||||
# Try to get default element
|
||||
if "hidden" in element:
|
||||
site_element = get_element(
|
||||
driver, element["default"]["type"], element["default"]["key"], hidden_element=True, timeout=timeout)
|
||||
driver,
|
||||
element["default"]["type"],
|
||||
element["default"]["key"],
|
||||
hidden_element=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
site_element = get_element(
|
||||
driver, element["default"]["type"], element["default"]["key"], timeout=timeout)
|
||||
driver,
|
||||
element["default"]["type"],
|
||||
element["default"]["key"],
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f'Failed to find primary element')
|
||||
print(f"Failed to find primary element")
|
||||
# If that fails, try to get the failover one
|
||||
print('Trying to find legacy element')
|
||||
print("Trying to find legacy element")
|
||||
if "hidden" in element:
|
||||
site_element = get_element(
|
||||
driver, element["failover"]["type"], element["failover"]["key"], hidden_element=True, timeout=timeout)
|
||||
driver,
|
||||
element["failover"]["type"],
|
||||
element["failover"]["key"],
|
||||
hidden_element=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
site_element = get_element(
|
||||
driver, element["failover"]["type"], element["failover"]["key"], timeout=timeout)
|
||||
driver,
|
||||
element["failover"]["type"],
|
||||
element["failover"]["key"],
|
||||
timeout=timeout,
|
||||
)
|
||||
# Clicking an element
|
||||
if element["type"] == "click":
|
||||
site_element.click()
|
||||
|
@ -272,11 +298,13 @@ def execute_selenium_elements(driver, timeout, elements):
|
|||
values = element["input"]
|
||||
if type(values) is list:
|
||||
raise Exception(
|
||||
'Invalid input value specified for "input_replace" type')
|
||||
'Invalid input value specified for "input_replace" type'
|
||||
)
|
||||
else:
|
||||
# For single input values
|
||||
driver.execute_script(
|
||||
f'arguments[0].value = "{values}";', site_element)
|
||||
f'arguments[0].value = "{values}";', site_element
|
||||
)
|
||||
except Exception as e:
|
||||
take_snapshot(driver)
|
||||
dump_html(driver)
|
||||
|
@ -285,30 +313,33 @@ def execute_selenium_elements(driver, timeout, elements):
|
|||
raise Exception(e)
|
||||
|
||||
|
||||
def solve_captcha(site_key, url, retry_attempts=3, version='v2', enterprise=False, use_proxy=True):
|
||||
def solve_captcha(
|
||||
site_key, url, retry_attempts=3, version="v2", enterprise=False, use_proxy=True
|
||||
):
|
||||
# Manual proxy override set via $ENV
|
||||
if not USE_PROXY:
|
||||
use_proxy = False
|
||||
if CAPTCHA_TESTING:
|
||||
print('Initializing CAPTCHA solver in dummy mode')
|
||||
print("Initializing CAPTCHA solver in dummy mode")
|
||||
code = random.randint()
|
||||
print("CAPTCHA Successful")
|
||||
return code
|
||||
|
||||
elif use_proxy:
|
||||
print('Using CAPTCHA solver with proxy')
|
||||
print("Using CAPTCHA solver with proxy")
|
||||
else:
|
||||
print('Using CAPTCHA solver without proxy')
|
||||
print("Using CAPTCHA solver without proxy")
|
||||
|
||||
captcha_params = {
|
||||
"url": url,
|
||||
"sitekey": site_key,
|
||||
"version": version,
|
||||
"enterprise": 1 if enterprise else 0,
|
||||
"proxy": {
|
||||
'type': 'socks5',
|
||||
'uri': get_secret('PROXY_USER_AUTH')
|
||||
} if use_proxy else None
|
||||
"proxy": (
|
||||
{"type": "socks5", "uri": get_secret("PROXY_USER_AUTH")}
|
||||
if use_proxy
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
# Keep retrying until max attempts is reached
|
||||
|
@ -316,12 +347,12 @@ def solve_captcha(site_key, url, retry_attempts=3, version='v2', enterprise=Fals
|
|||
# Solver uses 2CAPTCHA by default
|
||||
solver = TwoCaptcha(get_secret("CAPTCHA_API_KEY"))
|
||||
try:
|
||||
print('Waiting for CAPTCHA code...')
|
||||
print("Waiting for CAPTCHA code...")
|
||||
code = solver.recaptcha(**captcha_params)["code"]
|
||||
print("CAPTCHA Successful")
|
||||
return code
|
||||
except Exception as e:
|
||||
print(f'CAPTCHA Failed! {e}')
|
||||
print(f"CAPTCHA Failed! {e}")
|
||||
|
||||
raise Exception(f"CAPTCHA API Failed!")
|
||||
|
||||
|
@ -339,13 +370,12 @@ def save_browser_session(driver):
|
|||
# Copy over the profile once we finish logging in
|
||||
if isinstance(driver, webdriver.Firefox):
|
||||
# Copy process for Firefox
|
||||
print('Updating saved Firefox profile')
|
||||
print("Updating saved Firefox profile")
|
||||
# Get the current profile directory from about:support page
|
||||
driver.get("about:support")
|
||||
box = get_element(
|
||||
driver, "id", "profile-dir-box", timeout=4)
|
||||
box = get_element(driver, "id", "profile-dir-box", timeout=4)
|
||||
temp_profile_path = os.path.join(os.getcwd(), box.text)
|
||||
profile_path = '/firefox'
|
||||
profile_path = "/firefox"
|
||||
# Create the command
|
||||
copy_command = "cp -r " + temp_profile_path + "/* " + profile_path
|
||||
# Copy over the Firefox profile
|
||||
|
@ -353,13 +383,13 @@ def save_browser_session(driver):
|
|||
print("Firefox profile saved")
|
||||
elif isinstance(driver, uc.Chrome):
|
||||
# Copy the Chrome profile
|
||||
print('Updating non-ephemeral Chrome profile')
|
||||
print("Updating non-ephemeral Chrome profile")
|
||||
# Flush Code Cache again to speed up copy
|
||||
os.system(
|
||||
'rm -rf "/tmp_chrome_profile/SimpleDMCA Profile/Code Cache/*"')
|
||||
os.system('rm -rf "/tmp_chrome_profile/SimpleDMCA Profile/Code Cache/*"')
|
||||
if os.system("cp -r /tmp_chrome_profile/* /chrome"):
|
||||
print("Chrome profile saved")
|
||||
|
||||
|
||||
# Sample function
|
||||
# Call this within a Celery task
|
||||
# TODO: Modify as needed to your needs
|
||||
|
@ -370,7 +400,7 @@ def selenium_action_template(driver):
|
|||
info = {
|
||||
"sample_field1": "sample_data",
|
||||
"sample_field2": "sample_data",
|
||||
"captcha_code": lambda: solve_captcha('SITE_KEY', 'SITE_URL')
|
||||
"captcha_code": lambda: solve_captcha("SITE_KEY", "SITE_URL"),
|
||||
}
|
||||
|
||||
elements = [
|
||||
|
@ -382,13 +412,10 @@ def selenium_action_template(driver):
|
|||
"default": {
|
||||
# See get_element() for possible selector types
|
||||
"type": "xpath",
|
||||
"key": ''
|
||||
"key": "",
|
||||
},
|
||||
# If a site implements canary design releases, you can place the ID for the element in the old design here
|
||||
"failover": {
|
||||
"type": "xpath",
|
||||
"key": ''
|
||||
}
|
||||
"failover": {"type": "xpath", "key": ""},
|
||||
},
|
||||
]
|
||||
|
||||
|
@ -398,8 +425,8 @@ def selenium_action_template(driver):
|
|||
|
||||
# Fill in final fstring values in elements
|
||||
for element in elements:
|
||||
if 'input' in element and '{' in element['input']:
|
||||
a = element['input'].strip('{}')
|
||||
if "input" in element and "{" in element["input"]:
|
||||
a = element["input"].strip("{}")
|
||||
if a in info:
|
||||
value = info[a]
|
||||
# Check if the value is a callable (a lambda function) and call it if so
|
||||
|
@ -411,11 +438,12 @@ def selenium_action_template(driver):
|
|||
# Use the stored value
|
||||
value = site_form_values[a]
|
||||
# Replace the placeholder with the actual value
|
||||
element['input'] = str(value)
|
||||
element["input"] = str(value)
|
||||
|
||||
# Execute the selenium actions
|
||||
execute_selenium_elements(driver, 8, elements)
|
||||
|
||||
|
||||
# Sample task for Google search
|
||||
|
||||
|
||||
|
@ -429,40 +457,28 @@ def google_search(driver, search_term):
|
|||
"name": "Type in search term",
|
||||
"type": "input",
|
||||
"input": "{search_term}",
|
||||
"default": {
|
||||
"type": "xpath",
|
||||
"key": '//*[@id="APjFqb"]'
|
||||
},
|
||||
"failover": {
|
||||
"type": "xpath",
|
||||
"key": '//*[@id="APjFqb"]'
|
||||
}
|
||||
"default": {"type": "xpath", "key": '//*[@id="APjFqb"]'},
|
||||
"failover": {"type": "xpath", "key": '//*[@id="APjFqb"]'},
|
||||
},
|
||||
{
|
||||
"name": "Press enter",
|
||||
"type": "input_enter",
|
||||
"default": {
|
||||
"type": "xpath",
|
||||
"key": '//*[@id="APjFqb"]'
|
||||
},
|
||||
"failover": {
|
||||
"type": "xpath",
|
||||
"key": '//*[@id="APjFqb"]'
|
||||
}
|
||||
"default": {"type": "xpath", "key": '//*[@id="APjFqb"]'},
|
||||
"failover": {"type": "xpath", "key": '//*[@id="APjFqb"]'},
|
||||
},
|
||||
]
|
||||
|
||||
site_form_values = {}
|
||||
|
||||
for element in elements:
|
||||
if 'input' in element and '{' in element['input']:
|
||||
a = element['input'].strip('{}')
|
||||
if "input" in element and "{" in element["input"]:
|
||||
a = element["input"].strip("{}")
|
||||
if a in info:
|
||||
value = info[a]
|
||||
if callable(value):
|
||||
if a not in site_form_values:
|
||||
site_form_values[a] = value()
|
||||
value = site_form_values[a]
|
||||
element['input'] = str(value)
|
||||
element["input"] = str(value)
|
||||
|
||||
execute_selenium_elements(driver, 8, elements)
|
||||
|
|
|
@ -5,7 +5,7 @@ services:
|
|||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
image: drf_template:latest
|
||||
image: drf_template
|
||||
ports:
|
||||
- "${BACKEND_PORT}:${BACKEND_PORT}"
|
||||
environment:
|
||||
|
@ -23,7 +23,7 @@ services:
|
|||
env_file: .env
|
||||
environment:
|
||||
- RUN_TYPE=worker
|
||||
image: drf_template:latest
|
||||
image: drf_template
|
||||
volumes:
|
||||
- .:/code
|
||||
- ./chrome:/chrome
|
||||
|
@ -42,7 +42,7 @@ services:
|
|||
env_file: .env
|
||||
environment:
|
||||
- RUN_TYPE=beat
|
||||
image: drf_template:latest
|
||||
image: drf_template
|
||||
volumes:
|
||||
- .:/code
|
||||
depends_on:
|
||||
|
@ -58,7 +58,7 @@ services:
|
|||
env_file: .env
|
||||
environment:
|
||||
- RUN_TYPE=monitor
|
||||
image: drf_template:latest
|
||||
image: drf_template
|
||||
ports:
|
||||
- "${CELERY_FLOWER_PORT}:5555"
|
||||
volumes:
|
||||
|
|
Loading…
Reference in a new issue